Saves summaries in the mnist example.

PiperOrigin-RevId: 173690505
This commit is contained in:
Alexandre Passos
2017-10-27 10:49:59 -07:00
committed by TensorFlower Gardener
parent 6b05b36cd2
commit 16538dab77
4 changed files with 72 additions and 43 deletions

View File

@@ -90,7 +90,7 @@ class Metric(object):
# We create the variable scope now to get the unique name that will
# be used as a variable prefix when build() calls add_variable().
with variable_scope.variable_scope(
None, default_name=scope_name, use_resource=True, reuse=False) as scope:
scope_name, use_resource=True, reuse=False) as scope:
pos = scope.name.rfind(scope_name)
self._name = name + scope.name[pos + len(scope_name):]
self._scope = scope

View File

@@ -120,24 +120,18 @@ class MetricsTest(test.TestCase):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean()
m2 = metrics.Mean()
m1(0)
m2(2)
self.assertEqual(0, m1.result().numpy())
self.assertEqual(2, m2.result().numpy())
self.assertNotEqual(m1.name, m2.name)
with self.assertRaises(ValueError):
m2 = metrics.Mean()
m2(2)
def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean("has space")
m2 = metrics.Mean("has space")
m2(2)
m1(0)
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
self.assertEqual(m2.name, "has space_1")
self.assertEqual(m2.numer.name, "has_space_1/numer:0")
def testGraph(self):
with context.graph_mode(), self.test_session() as sess:
@@ -158,16 +152,12 @@ class MetricsTest(test.TestCase):
def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
with context.graph_mode(), self.test_session() as sess:
with context.graph_mode():
m1 = metrics.Mean()
m2 = metrics.Mean()
accumulate1 = m1(0)
accumulate2 = m2(2)
m1.init_variables().run()
m2.init_variables().run()
sess.run([accumulate1, accumulate2])
self.assertEqual(0, m1.result().eval())
self.assertEqual(2, m2.result().eval())
m1(0)
with self.assertRaises(ValueError):
m2 = metrics.Mean()
m2(2)
if __name__ == "__main__":

View File

@@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
@@ -57,7 +59,8 @@ def record_summaries_every_n_global_steps(n):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
collection_ref[:] = [training_util.get_global_step() % n == 0]
with ops.device("cpu:0"):
collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
yield
collection_ref[:] = old
@@ -97,13 +100,17 @@ class SummaryWriter(object):
@tf_contextlib.contextmanager
def as_default(self):
old = context.context().summary_writer_resource
context.context().summary_writer_resource = self._resource
yield
# Flushes the summary writer in eager mode or in graph functions, but not in
# legacy graph mode (you're on your own there).
gen_summary_ops.flush_summary_writer(self._resource)
context.context().summary_writer_resource = old
if self._resource is None:
yield
else:
old = context.context().summary_writer_resource
context.context().summary_writer_resource = self._resource
yield
# Flushes the summary writer in eager mode or in graph functions, but not
# in legacy graph mode (you're on your own there).
with ops.device("cpu:0"):
gen_summary_ops.flush_summary_writer(self._resource)
context.context().summary_writer_resource = old
def create_summary_file_writer(logdir,
@@ -111,21 +118,40 @@ def create_summary_file_writer(logdir,
flush_secs=None,
filename_suffix=None,
name=None):
"""Creates a summary file writer in the current context."""
if max_queue is None:
max_queue = constant_op.constant(10)
if flush_secs is None:
flush_secs = constant_op.constant(120)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
resource = gen_summary_ops.summary_writer(shared_name=name)
# TODO(apassos) ensure the initialization op runs when in graph mode; consider
# calling session.run here.
ops.add_to_collection(
_SUMMARY_WRITER_INIT_COLLECTION_NAME,
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix))
return SummaryWriter(resource)
"""Creates a summary file writer in the current context.
Args:
logdir: a string, or None. If a string, creates a summary file writer
which writes to the directory named by the string. If None, returns
a mock object which acts like a summary writer but does nothing,
useful to use as a context manager.
max_queue: the largest number of summaries to keep in a queue; will
flush once the queue gets bigger than this.
flush_secs: the largest interval (in seconds) between flushes.
filename_suffix: optional suffix for the event file name.
name: name for the summary writer.
Returns:
Either a summary writer or an empty object which can be used as a
summary writer.
"""
if logdir is None:
return SummaryWriter(None)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
if flush_secs is None:
flush_secs = constant_op.constant(120)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
resource = gen_summary_ops.summary_writer(shared_name=name)
# TODO(apassos) ensure the initialization op runs when in graph mode;
# consider calling session.run here.
ops.add_to_collection(
_SUMMARY_WRITER_INIT_COLLECTION_NAME,
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix))
return SummaryWriter(resource)
def _nothing():
@@ -168,6 +194,8 @@ def summary_writer_function(name, tensor, function, family=None):
with ops.control_dependencies([function(tag, scope)]):
return constant_op.constant(True)
if context.context().summary_writer_resource is None:
return control_flow_ops.no_op()
with ops.device("cpu:0"):
op = utils.smart_cond(
should_record_summaries(), record, _nothing, name="")

View File

@@ -119,13 +119,24 @@ def create_global_step(graph=None):
raise ValueError('"global_step" already exists.')
# Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None):
if context.in_eager_mode():
with ops.device('cpu:0'):
return variable_scope.get_variable(
ops.GraphKeys.GLOBAL_STEP,
shape=[],
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
return variable_scope.get_variable(
ops.GraphKeys.GLOBAL_STEP,
shape=[],
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
def get_or_create_global_step(graph=None):