mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Makes tape.watch() work with ResourceVariables.
To this end, also adds a property, `device`, to TensorNode. PiperOrigin-RevId: 165726368
This commit is contained in:
committed by
TensorFlower Gardener
parent
80bd004cdc
commit
360bff8ae5
@@ -55,13 +55,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":context",
|
||||
":core",
|
||||
":tape",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
@@ -88,6 +82,7 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ from autograd import core as ag_core
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
@@ -143,6 +144,10 @@ def watch(tensor):
|
||||
Returns:
|
||||
The tensor, potentially wrapped by all tapes in the stack.
|
||||
"""
|
||||
if isinstance(tensor, resource_variable_ops.ResourceVariable):
|
||||
tensor._handle = watch(tensor.handle) # pylint: disable=protected-access
|
||||
return tensor
|
||||
|
||||
for t in _tape_stack.stack:
|
||||
tensor = _watch_with_tape(t, tensor)
|
||||
return tensor
|
||||
|
||||
@@ -60,6 +60,7 @@ class TensorNode(ag_core.Node):
|
||||
|
||||
shape = property(lambda self: self.value.shape)
|
||||
dtype = property(lambda self: self.value.dtype)
|
||||
device = property(lambda self: self.value.device)
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
Reference in New Issue
Block a user