Makes tape.watch() work with ResourceVariables.

To this end, also adds a property, `device`, to TensorNode.

PiperOrigin-RevId: 165726368
This commit is contained in:
Ali Yahya
2017-08-18 11:02:26 -07:00
committed by TensorFlower Gardener
parent 80bd004cdc
commit 360bff8ae5
3 changed files with 8 additions and 7 deletions

View File

@@ -55,13 +55,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":core",
":tape",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:framework_ops",
"//third_party/py/numpy",
],
)
@@ -88,6 +82,7 @@ py_library(
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:util",
],
)

View File

@@ -25,6 +25,7 @@ from autograd import core as ag_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@@ -143,6 +144,10 @@ def watch(tensor):
Returns:
The tensor, potentially wrapped by all tapes in the stack.
"""
if isinstance(tensor, resource_variable_ops.ResourceVariable):
tensor._handle = watch(tensor.handle) # pylint: disable=protected-access
return tensor
for t in _tape_stack.stack:
tensor = _watch_with_tape(t, tensor)
return tensor

View File

@@ -60,6 +60,7 @@ class TensorNode(ag_core.Node):
shape = property(lambda self: self.value.shape)
dtype = property(lambda self: self.value.dtype)
device = property(lambda self: self.value.device)
def get_shape(self):
return self.shape