Use tf.app.run in gcs_smoke, so that the flags are explicitly parsed, instead of parsed when first accessed.

PiperOrigin-RevId: 173702828
This commit is contained in:
A. Unique TensorFlower
2017-10-27 12:16:44 -07:00
committed by TensorFlower Gardener
parent 3d39b32b9a
commit 9158f974a3

View File

@@ -35,6 +35,7 @@ flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
FLAGS = flags.FLAGS
def create_examples(num_examples, input_mean):
"""Create ExampleProto's containing data."""
ids = np.arange(num_examples).reshape([num_examples, 1])
@@ -49,6 +50,7 @@ def create_examples(num_examples, input_mean):
examples.append(ex)
return examples
def create_dir_test():
"""Verifies file_io directory handling methods."""
@@ -122,6 +124,7 @@ def create_dir_test():
print("Deleted directory recursively %s in %s milliseconds" % (
dir_name, elapsed_ms))
def create_object_test():
"""Verifies file_io's object manipulation methods ."""
starttime_ms = int(round(time.time() * 1000))
@@ -142,7 +145,8 @@ def create_object_test():
print("Creating file %s." % file_name)
file_io.write_string_to_file(file_name, "test file creation.")
elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
print("Created %d files in %s milliseconds" % (len(files_to_create), elapsed_ms))
print("Created %d files in %s milliseconds" % (
len(files_to_create), elapsed_ms))
# Listing files of pattern1.
list_files_pattern = "%s/test_file*.txt" % dir_name
@@ -185,7 +189,9 @@ def create_object_test():
file_io.delete_recursively(dir_name)
if __name__ == "__main__":
def main(argv):
del argv # Unused.
# Sanity check on the GCS bucket URL.
if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
@@ -210,7 +216,7 @@ if __name__ == "__main__":
# tf_record_iterator works.
record_iter = tf.python_io.tf_record_iterator(input_path)
read_count = 0
for r in record_iter:
for _ in record_iter:
read_count += 1
print("Read %d records using tf_record_iterator" % read_count)
@@ -222,7 +228,7 @@ if __name__ == "__main__":
# Verify that running the read op in a session works.
print("\n=== Testing TFRecordReader.read op in a session... ===")
with tf.Graph().as_default() as g:
with tf.Graph().as_default():
filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
@@ -249,3 +255,7 @@ if __name__ == "__main__":
create_dir_test()
create_object_test()
if __name__ == "__main__":
tf.app.run(main)