mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Use tf.app.run in gcs_smoke, so that the flags are explicitly parsed, instead of parsed when first accessed.
PiperOrigin-RevId: 173702828
This commit is contained in:
committed by
TensorFlower Gardener
parent
3d39b32b9a
commit
9158f974a3
@@ -35,6 +35,7 @@ flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def create_examples(num_examples, input_mean):
|
||||
"""Create ExampleProto's containing data."""
|
||||
ids = np.arange(num_examples).reshape([num_examples, 1])
|
||||
@@ -49,6 +50,7 @@ def create_examples(num_examples, input_mean):
|
||||
examples.append(ex)
|
||||
return examples
|
||||
|
||||
|
||||
def create_dir_test():
|
||||
"""Verifies file_io directory handling methods."""
|
||||
|
||||
@@ -122,6 +124,7 @@ def create_dir_test():
|
||||
print("Deleted directory recursively %s in %s milliseconds" % (
|
||||
dir_name, elapsed_ms))
|
||||
|
||||
|
||||
def create_object_test():
|
||||
"""Verifies file_io's object manipulation methods ."""
|
||||
starttime_ms = int(round(time.time() * 1000))
|
||||
@@ -142,7 +145,8 @@ def create_object_test():
|
||||
print("Creating file %s." % file_name)
|
||||
file_io.write_string_to_file(file_name, "test file creation.")
|
||||
elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
|
||||
print("Created %d files in %s milliseconds" % (len(files_to_create), elapsed_ms))
|
||||
print("Created %d files in %s milliseconds" % (
|
||||
len(files_to_create), elapsed_ms))
|
||||
|
||||
# Listing files of pattern1.
|
||||
list_files_pattern = "%s/test_file*.txt" % dir_name
|
||||
@@ -185,7 +189,9 @@ def create_object_test():
|
||||
file_io.delete_recursively(dir_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
# Sanity check on the GCS bucket URL.
|
||||
if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
|
||||
print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
|
||||
@@ -210,7 +216,7 @@ if __name__ == "__main__":
|
||||
# tf_record_iterator works.
|
||||
record_iter = tf.python_io.tf_record_iterator(input_path)
|
||||
read_count = 0
|
||||
for r in record_iter:
|
||||
for _ in record_iter:
|
||||
read_count += 1
|
||||
print("Read %d records using tf_record_iterator" % read_count)
|
||||
|
||||
@@ -222,7 +228,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Verify that running the read op in a session works.
|
||||
print("\n=== Testing TFRecordReader.read op in a session... ===")
|
||||
with tf.Graph().as_default() as g:
|
||||
with tf.Graph().as_default():
|
||||
filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
|
||||
reader = tf.TFRecordReader()
|
||||
_, serialized_example = reader.read(filename_queue)
|
||||
@@ -249,3 +255,7 @@ if __name__ == "__main__":
|
||||
|
||||
create_dir_test()
|
||||
create_object_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run(main)
|
||||
|
||||
Reference in New Issue
Block a user