diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 1485cf958cb..3edd9e70105 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -53,6 +53,7 @@ def optimize_loss(loss,
                    clip_gradients=None,
                    moving_average_decay=0.9,
                    learning_rate_decay_fn=None,
+                   update_ops=None,
                    variables=None,
                    name=None):
   """Given loss and parameters for optimizer, returns a training op.
@@ -81,6 +82,8 @@ def optimize_loss(loss,
                             Can be used to implement any learning rate decay
                             functions.
                             For example: tf.train.exponential_decay.
+    update_ops: list of update `Operation`s to execute at each step. If `None`,
+      uses elements of the UPDATE_OPS collection.
     variables: list of variables to optimize or `None` to use all trainable
       variables.
     name: The name for this operation is used to scope operations and summaries.
@@ -92,6 +95,15 @@ def optimize_loss(loss,
     ValueError: if optimizer is wrong type.
   """
   with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
+    # Update ops default to the UPDATE_OPS collection if not provided.
+    update_ops = (set(update_ops or []) or
+                  set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
+    # Make sure update ops are run before computing the loss.
+    if update_ops:
+      with ops.control_dependencies(update_ops):
+        barrier = control_flow_ops.no_op(name="update_barrier")
+      loss = control_flow_ops.with_dependencies([barrier], loss)
+
     # Moving average of the loss with decay.
     if moving_average_decay is not None:
       # Generate moving averages of the loss.
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 0f0bfe568b8..49baffb6f95 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase):
       tf.contrib.layers.optimize_loss(
           loss, global_step, learning_rate=0.1, optimizer="SGD")
 
+  def testUpdateOp(self):
+    optimizers = ["SGD", tf.train.GradientDescentOptimizer,
+                  tf.train.GradientDescentOptimizer(learning_rate=0.1)]
+    for optimizer in optimizers:
+      with tf.Graph().as_default() as g:
+        with self.test_session(graph=g) as session:
+          x, var, loss, global_step = _setup_model()
+          update_op = tf.assign(var, 20)
+          train = tf.contrib.layers.optimize_loss(loss,
+                                                  global_step,
+                                                  learning_rate=0.1,
+                                                  optimizer=optimizer,
+                                                  update_ops=[update_op])
+          tf.initialize_all_variables().run()
+          session.run(train, feed_dict={x: 5})
+          var_value, global_step_value = session.run([var, global_step])
+          # Expect 19.5: var is updated to 20 before the loss is computed.
+          self.assertEqual(var_value, 19.5)
+          self.assertEqual(global_step_value, 1)
 
 if __name__ == "__main__":
   tf.test.main()
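
For reference, a minimal usage sketch of the new `update_ops` parameter follows. This is not part of the patch: it assumes the TF 0.x-era contrib API the patch targets, and the model, the `counter` variable, and its update op are illustrative stand-ins, not names from the source.

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, [])
  var = tf.Variable(10.0, name="var")
  loss = tf.abs(var * x)
  global_step = tf.Variable(0, trainable=False, name="global_step")

  # Hypothetical op that must run before every training step.
  counter = tf.Variable(0.0, trainable=False, name="counter")
  update_op = tf.assign_add(counter, 1.0)

  # Pass the op explicitly; with update_ops=None, optimize_loss instead runs
  # everything registered in the tf.GraphKeys.UPDATE_OPS collection.
  train_op = tf.contrib.layers.optimize_loss(
      loss, global_step, learning_rate=0.1, optimizer="SGD",
      update_ops=[update_op])

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(train_op, feed_dict={x: 5.0})  # counter increments first

Because the patch rewrites `loss` as `with_dependencies([barrier], loss)`, any op that consumes the loss, including the gradient computation, must wait for the update ops; this ordering is what the test's 19.5 expectation relies on.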