mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Saves summaries in the mnist example.
PiperOrigin-RevId: 173690505
This commit is contained in:
committed by
TensorFlower Gardener
parent
6b05b36cd2
commit
16538dab77
@@ -90,7 +90,7 @@ class Metric(object):
|
||||
# We create the variable scope now to get the unique name that will
|
||||
# be used as a variable prefix when build() calls add_variable().
|
||||
with variable_scope.variable_scope(
|
||||
None, default_name=scope_name, use_resource=True, reuse=False) as scope:
|
||||
scope_name, use_resource=True, reuse=False) as scope:
|
||||
pos = scope.name.rfind(scope_name)
|
||||
self._name = name + scope.name[pos + len(scope_name):]
|
||||
self._scope = scope
|
||||
|
||||
@@ -120,24 +120,18 @@ class MetricsTest(test.TestCase):
|
||||
# Verify two metrics with the same class and name don't
|
||||
# accidentally share state.
|
||||
m1 = metrics.Mean()
|
||||
m2 = metrics.Mean()
|
||||
m1(0)
|
||||
m2(2)
|
||||
self.assertEqual(0, m1.result().numpy())
|
||||
self.assertEqual(2, m2.result().numpy())
|
||||
self.assertNotEqual(m1.name, m2.name)
|
||||
with self.assertRaises(ValueError):
|
||||
m2 = metrics.Mean()
|
||||
m2(2)
|
||||
|
||||
def testNamesWithSpaces(self):
|
||||
# Verify two metrics with the same class and name don't
|
||||
# accidentally share state.
|
||||
m1 = metrics.Mean("has space")
|
||||
m2 = metrics.Mean("has space")
|
||||
m2(2)
|
||||
m1(0)
|
||||
self.assertEqual(m1.name, "has space")
|
||||
self.assertEqual(m1.numer.name, "has_space/numer:0")
|
||||
self.assertEqual(m2.name, "has space_1")
|
||||
self.assertEqual(m2.numer.name, "has_space_1/numer:0")
|
||||
|
||||
def testGraph(self):
|
||||
with context.graph_mode(), self.test_session() as sess:
|
||||
@@ -158,16 +152,12 @@ class MetricsTest(test.TestCase):
|
||||
def testTwoMeansGraph(self):
|
||||
# Verify two metrics with the same class and name don't
|
||||
# accidentally share state.
|
||||
with context.graph_mode(), self.test_session() as sess:
|
||||
with context.graph_mode():
|
||||
m1 = metrics.Mean()
|
||||
m2 = metrics.Mean()
|
||||
accumulate1 = m1(0)
|
||||
accumulate2 = m2(2)
|
||||
m1.init_variables().run()
|
||||
m2.init_variables().run()
|
||||
sess.run([accumulate1, accumulate2])
|
||||
self.assertEqual(0, m1.result().eval())
|
||||
self.assertEqual(2, m2.result().eval())
|
||||
m1(0)
|
||||
with self.assertRaises(ValueError):
|
||||
m2 = metrics.Mean()
|
||||
m2(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import summary_op_util
|
||||
from tensorflow.python.training import training_util
|
||||
@@ -57,7 +59,8 @@ def record_summaries_every_n_global_steps(n):
|
||||
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
old = collection_ref[:]
|
||||
collection_ref[:] = [training_util.get_global_step() % n == 0]
|
||||
with ops.device("cpu:0"):
|
||||
collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
|
||||
yield
|
||||
collection_ref[:] = old
|
||||
|
||||
@@ -97,13 +100,17 @@ class SummaryWriter(object):
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def as_default(self):
|
||||
old = context.context().summary_writer_resource
|
||||
context.context().summary_writer_resource = self._resource
|
||||
yield
|
||||
# Flushes the summary writer in eager mode or in graph functions, but not in
|
||||
# legacy graph mode (you're on your own there).
|
||||
gen_summary_ops.flush_summary_writer(self._resource)
|
||||
context.context().summary_writer_resource = old
|
||||
if self._resource is None:
|
||||
yield
|
||||
else:
|
||||
old = context.context().summary_writer_resource
|
||||
context.context().summary_writer_resource = self._resource
|
||||
yield
|
||||
# Flushes the summary writer in eager mode or in graph functions, but not
|
||||
# in legacy graph mode (you're on your own there).
|
||||
with ops.device("cpu:0"):
|
||||
gen_summary_ops.flush_summary_writer(self._resource)
|
||||
context.context().summary_writer_resource = old
|
||||
|
||||
|
||||
def create_summary_file_writer(logdir,
|
||||
@@ -111,21 +118,40 @@ def create_summary_file_writer(logdir,
|
||||
flush_secs=None,
|
||||
filename_suffix=None,
|
||||
name=None):
|
||||
"""Creates a summary file writer in the current context."""
|
||||
if max_queue is None:
|
||||
max_queue = constant_op.constant(10)
|
||||
if flush_secs is None:
|
||||
flush_secs = constant_op.constant(120)
|
||||
if filename_suffix is None:
|
||||
filename_suffix = constant_op.constant("")
|
||||
resource = gen_summary_ops.summary_writer(shared_name=name)
|
||||
# TODO(apassos) ensure the initialization op runs when in graph mode; consider
|
||||
# calling session.run here.
|
||||
ops.add_to_collection(
|
||||
_SUMMARY_WRITER_INIT_COLLECTION_NAME,
|
||||
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
|
||||
flush_secs, filename_suffix))
|
||||
return SummaryWriter(resource)
|
||||
"""Creates a summary file writer in the current context.
|
||||
|
||||
Args:
|
||||
logdir: a string, or None. If a string, creates a summary file writer
|
||||
which writes to the directory named by the string. If None, returns
|
||||
a mock object which acts like a summary writer but does nothing,
|
||||
useful to use as a context manager.
|
||||
max_queue: the largest number of summaries to keep in a queue; will
|
||||
flush once the queue gets bigger than this.
|
||||
flush_secs: the largest interval (in seconds) between flushes.
|
||||
filename_suffix: optional suffix for the event file name.
|
||||
name: name for the summary writer.
|
||||
|
||||
Returns:
|
||||
Either a summary writer or an empty object which can be used as a
|
||||
summary writer.
|
||||
"""
|
||||
if logdir is None:
|
||||
return SummaryWriter(None)
|
||||
with ops.device("cpu:0"):
|
||||
if max_queue is None:
|
||||
max_queue = constant_op.constant(10)
|
||||
if flush_secs is None:
|
||||
flush_secs = constant_op.constant(120)
|
||||
if filename_suffix is None:
|
||||
filename_suffix = constant_op.constant("")
|
||||
resource = gen_summary_ops.summary_writer(shared_name=name)
|
||||
# TODO(apassos) ensure the initialization op runs when in graph mode;
|
||||
# consider calling session.run here.
|
||||
ops.add_to_collection(
|
||||
_SUMMARY_WRITER_INIT_COLLECTION_NAME,
|
||||
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
|
||||
flush_secs, filename_suffix))
|
||||
return SummaryWriter(resource)
|
||||
|
||||
|
||||
def _nothing():
|
||||
@@ -168,6 +194,8 @@ def summary_writer_function(name, tensor, function, family=None):
|
||||
with ops.control_dependencies([function(tag, scope)]):
|
||||
return constant_op.constant(True)
|
||||
|
||||
if context.context().summary_writer_resource is None:
|
||||
return control_flow_ops.no_op()
|
||||
with ops.device("cpu:0"):
|
||||
op = utils.smart_cond(
|
||||
should_record_summaries(), record, _nothing, name="")
|
||||
|
||||
@@ -119,13 +119,24 @@ def create_global_step(graph=None):
|
||||
raise ValueError('"global_step" already exists.')
|
||||
# Create in proper graph and base name_scope.
|
||||
with graph.as_default() as g, g.name_scope(None):
|
||||
if context.in_eager_mode():
|
||||
with ops.device('cpu:0'):
|
||||
return variable_scope.get_variable(
|
||||
ops.GraphKeys.GLOBAL_STEP,
|
||||
shape=[],
|
||||
dtype=dtypes.int64,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
trainable=False,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.GLOBAL_STEP])
|
||||
return variable_scope.get_variable(
|
||||
ops.GraphKeys.GLOBAL_STEP,
|
||||
shape=[],
|
||||
dtype=dtypes.int64,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
trainable=False,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.GLOBAL_STEP])
|
||||
|
||||
|
||||
def get_or_create_global_step(graph=None):
|
||||
|
||||
Reference in New Issue
Block a user