tf.learn: Add update_ops to tf.contrib.layers.optimize_loss to support batch_norm-like layers.

Change: 123324232
Illia Polosukhin
2016-05-26 08:04:08 -08:00
committed by TensorFlower Gardener
parent a9112b4a0a
commit 037bfaf05a
2 changed files with 31 additions and 0 deletions
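
The motivation: layers such as batch_norm maintain auxiliary state (moving mean and
variance) through update ops that sit outside the gradient path, so a plain training
op never runs them. A minimal usage sketch of the new hook (not part of this commit;
it assumes tf.contrib.layers.batch_norm registers its updates in the UPDATE_OPS
collection, which was its default):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 10])
    # batch_norm adds its moving-mean/variance update ops to the
    # tf.GraphKeys.UPDATE_OPS collection instead of the gradient path.
    h = tf.contrib.layers.batch_norm(x, decay=0.9, is_training=True)
    loss = tf.reduce_mean(tf.square(h))
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # With update_ops=None (the default), optimize_loss now picks the
    # update ops out of UPDATE_OPS and runs them before each step.
    train_op = tf.contrib.layers.optimize_loss(
        loss, global_step, learning_rate=0.1, optimizer="SGD")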

tensorflow/contrib/layers/python/layers/optimizers.py

@@ -53,6 +53,7 @@ def optimize_loss(loss,
                   clip_gradients=None,
                   moving_average_decay=0.9,
                   learning_rate_decay_fn=None,
+                  update_ops=None,
                   variables=None,
                   name=None):
   """Given loss and parameters for optimizer, returns a training op.
@@ -81,6 +82,8 @@ def optimize_loss(loss,
                             Can be used to implement any learning rate decay
                             functions.
                             For example: tf.train.exponential_decay.
+    update_ops: list of update `Operation`s to execute at each step. If `None`,
+                uses elements of UPDATE_OPS collection.
     variables: list of variables to optimize or
                `None` to use all trainable variables.
     name: The name for this operation is used to scope operations and summaries.
@@ -92,6 +95,15 @@ def optimize_loss(loss,
     ValueError: if optimizer is wrong type.
   """
   with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
+    # Update ops take UPDATE_OPS collection if not provided.
+    update_ops = (set(update_ops or []) or
+                  set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
+    # Make sure update ops are run before computing loss.
+    if update_ops:
+      with ops.control_dependencies(update_ops):
+        barrier = control_flow_ops.no_op(name="update_barrier")
+      loss = control_flow_ops.with_dependencies([barrier], loss)
+
     # Moving average of the loss with decay.
     if moving_average_decay is not None:
       # Generate moving averages of the loss.

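The heart of the change is the control-dependency "barrier": all update ops are
grouped behind a single no-op, and the returned loss is rewired to depend on that
no-op, so every step that evaluates the loss also executes the updates. A standalone
sketch of the same pattern using only public ops (illustrative, not code from this
commit):

    import tensorflow as tf

    var = tf.Variable(10.0)
    update_op = tf.assign(var, 20.0)  # stands in for e.g. batch_norm updates
    loss = tf.square(var)

    # Group all update ops behind one no-op "barrier".
    with tf.control_dependencies([update_op]):
        barrier = tf.no_op(name="update_barrier")
    # Rewire the loss to depend on the barrier; tf.identity under a control
    # dependency plays the role of control_flow_ops.with_dependencies here.
    with tf.control_dependencies([barrier]):
        loss = tf.identity(loss)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        sess.run(loss)        # evaluating the loss also runs update_op
        print(sess.run(var))  # 20.0: the update has been applied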
tensorflow/contrib/layers/python/layers/optimizers_test.py

@@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase):
       tf.contrib.layers.optimize_loss(
           loss, global_step, learning_rate=0.1, optimizer="SGD")
 
+  def testUpdateOp(self):
+    optimizers = ["SGD", tf.train.GradientDescentOptimizer,
+                  tf.train.GradientDescentOptimizer(learning_rate=0.1)]
+    for optimizer in optimizers:
+      with tf.Graph().as_default() as g:
+        with self.test_session(graph=g) as session:
+          x, var, loss, global_step = _setup_model()
+          update_op = tf.assign(var, 20)
+          train = tf.contrib.layers.optimize_loss(loss,
+                                                  global_step,
+                                                  learning_rate=0.1,
+                                                  optimizer=optimizer,
+                                                  update_ops=[update_op])
+          tf.initialize_all_variables().run()
+          session.run(train, feed_dict={x: 5})
+          var_value, global_step_value = session.run([var, global_step])
+          # 19.5, due to update of var to 20 before loss computation.
+          self.assertEqual(var_value, 19.5)
+          self.assertEqual(global_step_value, 1)
 
 
 if __name__ == "__main__":
   tf.test.main()
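
A note on the expected value in testUpdateOp: assuming _setup_model (defined earlier
in the test file, not shown in this diff) builds a loss whose gradient with respect
to var equals the fed x, the update op first sets var to 20, and one SGD step with
learning rate 0.1 and gradient 5 (from x=5) then yields 20 - 0.1 * 5 = 19.5, which
is what the assertion checks.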