diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 795dff548f7..2ba653af4a2 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -90,7 +90,7 @@ class Metric(object):
     # We create the variable scope now to get the unique name that will
     # be used as a variable prefix when build() calls add_variable().
     with variable_scope.variable_scope(
-        None, default_name=scope_name, use_resource=True, reuse=False) as scope:
+        scope_name, use_resource=True, reuse=False) as scope:
       pos = scope.name.rfind(scope_name)
       self._name = name + scope.name[pos + len(scope_name):]
       self._scope = scope
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index a8377a06605..336ce9d307c 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -120,24 +120,18 @@ class MetricsTest(test.TestCase):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
     m1 = metrics.Mean()
-    m2 = metrics.Mean()
     m1(0)
-    m2(2)
-    self.assertEqual(0, m1.result().numpy())
-    self.assertEqual(2, m2.result().numpy())
-    self.assertNotEqual(m1.name, m2.name)
+    with self.assertRaises(ValueError):
+      m2 = metrics.Mean()
+      m2(2)
 
   def testNamesWithSpaces(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
     m1 = metrics.Mean("has space")
-    m2 = metrics.Mean("has space")
-    m2(2)
     m1(0)
     self.assertEqual(m1.name, "has space")
     self.assertEqual(m1.numer.name, "has_space/numer:0")
-    self.assertEqual(m2.name, "has space_1")
-    self.assertEqual(m2.numer.name, "has_space_1/numer:0")
 
   def testGraph(self):
     with context.graph_mode(), self.test_session() as sess:
@@ -158,16 +152,12 @@ class MetricsTest(test.TestCase):
   def testTwoMeansGraph(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
-    with context.graph_mode(), self.test_session() as sess:
+    with context.graph_mode():
       m1 = metrics.Mean()
-      m2 = metrics.Mean()
-      accumulate1 = m1(0)
-      accumulate2 = m2(2)
-      m1.init_variables().run()
-      m2.init_variables().run()
-      sess.run([accumulate1, accumulate2])
-      self.assertEqual(0, m1.result().eval())
-      self.assertEqual(2, m2.result().eval())
+      m1(0)
+      with self.assertRaises(ValueError):
+        m2 = metrics.Mean()
+        m2(2)
 
 
 if __name__ == "__main__":
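The metrics change above drops default_name from the variable_scope call, so a second metric that would reuse an existing scope name now fails instead of being silently uniquified into "mean_1", "mean_2", and so on. A minimal eager-mode sketch of the new contract, mirroring the test's import; the scope name "mean" is illustrative, and the test wraps both the constructor and the first update in assertRaises, so the exact line where the ValueError surfaces is pinned down only to that pair:

    from tensorflow.contrib.eager.python import metrics

    m1 = metrics.Mean()
    m1(0)                 # builds m1's variables under the "mean" scope
    m2 = metrics.Mean()   # second Mean with the same default name
    m2(2)                 # ValueError: the "mean" variables already exist
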
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index b32b093675c..9c71bf7740c 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import summary_op_util
 from tensorflow.python.training import training_util
@@ -57,7 +59,8 @@ def record_summaries_every_n_global_steps(n):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
   old = collection_ref[:]
-  collection_ref[:] = [training_util.get_global_step() % n == 0]
+  with ops.device("cpu:0"):
+    collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
   yield
   collection_ref[:] = old
 
@@ -97,13 +100,17 @@ class SummaryWriter(object):
 
   @tf_contextlib.contextmanager
   def as_default(self):
-    old = context.context().summary_writer_resource
-    context.context().summary_writer_resource = self._resource
-    yield
-    # Flushes the summary writer in eager mode or in graph functions, but not in
-    # legacy graph mode (you're on your own there).
-    gen_summary_ops.flush_summary_writer(self._resource)
-    context.context().summary_writer_resource = old
+    if self._resource is None:
+      yield
+    else:
+      old = context.context().summary_writer_resource
+      context.context().summary_writer_resource = self._resource
+      yield
+      # Flushes the summary writer in eager mode or in graph functions, but not
+      # in legacy graph mode (you're on your own there).
+      with ops.device("cpu:0"):
+        gen_summary_ops.flush_summary_writer(self._resource)
+      context.context().summary_writer_resource = old
 
 
 def create_summary_file_writer(logdir,
@@ -111,21 +118,40 @@ def create_summary_file_writer(logdir,
                                flush_secs=None,
                                filename_suffix=None,
                                name=None):
-  """Creates a summary file writer in the current context."""
-  if max_queue is None:
-    max_queue = constant_op.constant(10)
-  if flush_secs is None:
-    flush_secs = constant_op.constant(120)
-  if filename_suffix is None:
-    filename_suffix = constant_op.constant("")
-  resource = gen_summary_ops.summary_writer(shared_name=name)
-  # TODO(apassos) ensure the initialization op runs when in graph mode; consider
-  # calling session.run here.
-  ops.add_to_collection(
-      _SUMMARY_WRITER_INIT_COLLECTION_NAME,
-      gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
-                                                 flush_secs, filename_suffix))
-  return SummaryWriter(resource)
+  """Creates a summary file writer in the current context.
+
+  Args:
+    logdir: a string, or None. If a string, creates a summary file writer
+      which writes to the directory named by the string. If None, returns
+      a mock object which acts like a summary writer but does nothing,
+      useful to use as a context manager.
+    max_queue: the largest number of summaries to keep in a queue; will
+      flush once the queue gets bigger than this.
+    flush_secs: the largest interval (in seconds) between flushes.
+    filename_suffix: optional suffix for the event file name.
+    name: name for the summary writer.
+
+  Returns:
+    Either a summary writer or an empty object which can be used as a
+    summary writer.
+  """
+  if logdir is None:
+    return SummaryWriter(None)
+  with ops.device("cpu:0"):
+    if max_queue is None:
+      max_queue = constant_op.constant(10)
+    if flush_secs is None:
+      flush_secs = constant_op.constant(120)
+    if filename_suffix is None:
+      filename_suffix = constant_op.constant("")
+    resource = gen_summary_ops.summary_writer(shared_name=name)
+    # TODO(apassos) ensure the initialization op runs when in graph mode;
+    # consider calling session.run here.
+    ops.add_to_collection(
+        _SUMMARY_WRITER_INIT_COLLECTION_NAME,
+        gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
+                                                   flush_secs, filename_suffix))
+    return SummaryWriter(resource)
 
 
 def _nothing():
@@ -168,6 +194,8 @@ def summary_writer_function(name, tensor, function, family=None):
     with ops.control_dependencies([function(tag, scope)]):
       return constant_op.constant(True)
 
+  if context.context().summary_writer_resource is None:
+    return control_flow_ops.no_op()
   with ops.device("cpu:0"):
     op = utils.smart_cond(
         should_record_summaries(), record, _nothing, name="")
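Taken together, the summary_ops.py changes pin writer state and the step arithmetic to cpu:0 and make summary calls safe when no writer is configured: create_summary_file_writer(None) returns a mock writer whose as_default() simply yields, and summary_writer_function short-circuits to a no_op when summary_writer_resource is None. A usage sketch under this patch; the import path and log directory are illustrative:

    from tensorflow.contrib.summary import summary_ops

    # Real writer: events are written under the given directory, and the
    # writer resource and its flush op are placed on cpu:0.
    writer = summary_ops.create_summary_file_writer("/tmp/train_logs")

    # Mock writer: with logdir=None nothing is written, so the same model
    # code can run with summaries disabled.
    null_writer = summary_ops.create_summary_file_writer(None)
    with null_writer.as_default():
        pass  # summary ops recorded here reduce to no-ops
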
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index bdd4ca734eb..89a9e129328 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -119,13 +119,24 @@ def create_global_step(graph=None):
     raise ValueError('"global_step" already exists.')
   # Create in proper graph and base name_scope.
   with graph.as_default() as g, g.name_scope(None):
+    if context.in_eager_mode():
+      with ops.device('cpu:0'):
+        return variable_scope.get_variable(
+            ops.GraphKeys.GLOBAL_STEP,
+            shape=[],
+            dtype=dtypes.int64,
+            initializer=init_ops.zeros_initializer(),
+            trainable=False,
+            collections=[ops.GraphKeys.GLOBAL_VARIABLES,
+                         ops.GraphKeys.GLOBAL_STEP])
     return variable_scope.get_variable(
         ops.GraphKeys.GLOBAL_STEP,
         shape=[],
         dtype=dtypes.int64,
         initializer=init_ops.zeros_initializer(),
         trainable=False,
-        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES,
+                     ops.GraphKeys.GLOBAL_STEP])
 
 
 def get_or_create_global_step(graph=None):
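The training_util.py change is the eager-mode counterpart of the cpu:0 pinning above: create_global_step now places the int64 step variable on cpu:0 in eager mode, so the `global_step % n` computed by record_summaries_every_n_global_steps keeps the counter on the CPU. A sketch, assuming the contrib-era enable_eager_execution entry point:

    import tensorflow.contrib.eager as tfe
    from tensorflow.python.training import training_util

    tfe.enable_eager_execution()
    step = training_util.get_or_create_global_step()
    step.assign_add(1)  # the variable lives on cpu:0 in eager mode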