diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index d18df4dffb7..20532c8ee8e 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -569,5 +569,17 @@ class BackpropTest(test.TestCase):
       var.assign_sub(lr*grad)
     self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
 
+  def testCustomGradientIdentity(self):
+
+    @custom_gradient.custom_gradient
+    def my_identity(x):
+
+      def grad(dresult):
+        return [2 * dresult]
+
+      return x, grad
+
+    self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index 4ac30075b2f..05460ff9968 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -22,6 +22,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 
@@ -72,17 +73,19 @@ def custom_gradient(f):
 
     with tape.stop_recording():
       result, grad_fn = f(*args, **kwargs)
+    flat_result = nest.flatten(result)
+    # TODO(apassos) consider removing the identity below.
+    flat_result = [gen_array_ops.identity(x) for x in flat_result]
 
     def actual_grad_fn(*outputs):
       return nest.flatten(grad_fn(*outputs))
 
-    flat_result = nest.flatten(result)
     tape.record_operation(
         f.__name__,
         flat_result,
         input_tensors,
         actual_grad_fn)
     flat_result = list(flat_result)
-    return result
+    return nest.pack_sequence_as(result, flat_result)
 
   return tf_decorator.make_decorator(f, decorated)