mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
tf.learn: Adding update op to tf.contrib.layers.optimize_loss to support batch_norm like layers.
Change: 123324232
This commit is contained in:
committed by
TensorFlower Gardener
parent
a9112b4a0a
commit
037bfaf05a
@@ -53,6 +53,7 @@ def optimize_loss(loss,
|
||||
clip_gradients=None,
|
||||
moving_average_decay=0.9,
|
||||
learning_rate_decay_fn=None,
|
||||
update_ops=None,
|
||||
variables=None,
|
||||
name=None):
|
||||
"""Given loss and parameters for optimizer, returns a training op.
|
||||
@@ -81,6 +82,8 @@ def optimize_loss(loss,
|
||||
Can be used to implement any learning rate decay
|
||||
functions.
|
||||
For example: tf.train.exponential_decay.
|
||||
update_ops: list of update `Operation`s to execute at each step. If `None`,
|
||||
uses elements of UPDATE_OPS collection.
|
||||
variables: list of variables to optimize or
|
||||
`None` to use all trainable variables.
|
||||
name: The name for this operation is used to scope operations and summaries.
|
||||
@@ -92,6 +95,15 @@ def optimize_loss(loss,
|
||||
ValueError: if optimizer is wrong type.
|
||||
"""
|
||||
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
|
||||
# Update ops take UPDATE_OPS collection if not provided.
|
||||
update_ops = (set(update_ops or []) or
|
||||
set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
|
||||
# Make sure update ops are ran before computing loss.
|
||||
if update_ops:
|
||||
with ops.control_dependencies(update_ops):
|
||||
barrier = control_flow_ops.no_op(name="update_barrier")
|
||||
loss = control_flow_ops.with_dependencies([barrier], loss)
|
||||
|
||||
# Moving average of the loss with decay.
|
||||
if moving_average_decay is not None:
|
||||
# Generate moving averages of the loss.
|
||||
|
||||
@@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase):
|
||||
tf.contrib.layers.optimize_loss(
|
||||
loss, global_step, learning_rate=0.1, optimizer="SGD")
|
||||
|
||||
def testUpdateOp(self):
|
||||
optimizers = ["SGD", tf.train.GradientDescentOptimizer,
|
||||
tf.train.GradientDescentOptimizer(learning_rate=0.1)]
|
||||
for optimizer in optimizers:
|
||||
with tf.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as session:
|
||||
x, var, loss, global_step = _setup_model()
|
||||
update_op = tf.assign(var, 20)
|
||||
train = tf.contrib.layers.optimize_loss(loss,
|
||||
global_step,
|
||||
learning_rate=0.1,
|
||||
optimizer=optimizer,
|
||||
update_ops=[update_op])
|
||||
tf.initialize_all_variables().run()
|
||||
session.run(train, feed_dict={x: 5})
|
||||
var_value, global_step_value = session.run([var, global_step])
|
||||
# 19.5, due to update of var to 20 before loss computation.
|
||||
self.assertEqual(var_value, 19.5)
|
||||
self.assertEqual(global_step_value, 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
||||
Reference in New Issue
Block a user