Introduce C++ OperationHandle type and move simple methods from ops.py::Operation onto it.

Over time we would like to move most graph manipulation machinery out of Python and into these internal classes.
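
For context, the pattern introduced here: a C++ handle class is bound into the
pywrap extension with pybind11, and the pure-Python class subclasses it so that
methods implemented in C++ are simply inherited. A minimal sketch of that
pattern, assuming only stock pybind11 (`Handle` and `_example` are illustrative
names, not the actual TensorFlow classes):

// handle_example.cc -- hypothetical, simplified illustration.
#include <string>
#include <utility>

#include "pybind11/pybind11.h"

namespace py = pybind11;

// A small C++ "handle" owning the state; simple accessors live in C++.
class Handle {
 public:
  const std::string& type() const { return type_; }
  void set_type(std::string type) { type_ = std::move(type); }

 private:
  std::string type_;
};

PYBIND11_MODULE(_example, m) {
  py::class_<Handle>(m, "Handle")
      .def(py::init<>())
      .def_property("type", &Handle::type, &Handle::set_type);
}

// Python side (mirrors ops.py::Operation deriving from OperationHandle):
//   class Operation(_example.Handle):
//       def describe(self):
//           return "op of type " + self.type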

PiperOrigin-RevId: 509274663
Russell Power
2023-02-13 10:51:45 -08:00
committed by TensorFlower Gardener
parent 137eb17db4
commit d702027b3c
19 changed files with 158 additions and 105 deletions

View File

@@ -115,6 +115,9 @@ struct TF_OperationDescription {
struct TF_Operation {
tensorflow::Node node;
+
+ private:
+  ~TF_Operation() = default;
};
struct TF_Session {

View File

@@ -309,6 +309,7 @@ third_party/py/BUILD:
third_party/py/numpy/BUILD:
third_party/py/python_configure.bzl:
third_party/pybind11.BUILD:
+third_party/pybind11_protobuf/BUILD:
third_party/python_runtime/BUILD:
third_party/remote_config/BUILD.tpl:
third_party/remote_config/BUILD:

View File

@@ -44,6 +44,7 @@ tf_python_pybind_extension(
],
deps = [
"//third_party/eigen3",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
"//tensorflow/python/lib/core:pybind11_lib",
"//tensorflow/python/lib/core:pybind11_status",
"//tensorflow/python/lib/core:safe_pyobject_ptr",

View File

@@ -19,9 +19,10 @@ limitations under the License.
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "pybind11/stl.h"
// clang-format on
#include "Python.h"
// Must be included first
@@ -107,12 +108,54 @@ PYBIND11_MAKE_OPAQUE(TF_Server);
PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Status);
+class TensorHandle {};
+
+class OperationHandle {
+ public:
+  const TF_Operation* op() { return op_; }
+  void set_op(TF_Operation* op) { op_ = op; }
+  py::bytes node_def() {
+    return py::bytes(op_->node.def().SerializeAsString());
+  }
+  py::bytes op_def() {
+    return py::bytes(op_->node.op_def().SerializeAsString());
+  }
+  const std::string& type() { return op_->node.type_string(); }
+  TF_Output _tf_output(int idx) { return TF_Output{op_, idx}; }
+  TF_Input _tf_input(int idx) { return TF_Input{op_, idx}; }
+
+ private:
+  TF_Operation* op_;
+};
+
+class GraphHandle {
+ public:
+};
PYBIND11_MODULE(_pywrap_tf_session, m) {
+  pybind11_protobuf::ImportNativeProtoCasters();
// Numpy initialization code for array checks.
tsl::ImportNumpy();
py::class_<TF_Graph> TF_Graph_class(m, "TF_Graph");
py::class_<TF_Operation> TF_Operation_class(m, "TF_Operation");
py::class_<TF_Operation, std::unique_ptr<TF_Operation, py::nodelete>>
TF_Operation_class(m, "TF_Operation");
py::class_<GraphHandle>(m, "GraphHandle").def(py::init<>());
py::class_<OperationHandle>(m, "OperationHandle")
.def(py::init<>())
.def("_tf_output", &OperationHandle::_tf_output)
.def("_tf_input", &OperationHandle::_tf_input)
.def_property_readonly("_op_def", &OperationHandle::op_def)
.def_property_readonly("_node_def", &OperationHandle::node_def)
.def_property_readonly("type", &OperationHandle::type)
.def_property("_c_op", &OperationHandle::op, &OperationHandle::set_op,
py::return_value_policy::reference);
py::class_<TF_Output>(m, "TF_Output")
.def(py::init<>())
@@ -688,7 +731,8 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
m.def("SetRequestedDevice", tensorflow::SetRequestedDevice);
// TF_Buffer util methods
-  // TODO(amitpatankar): Consolidate Buffer methods into a separate header file.
+  // TODO(amitpatankar): Consolidate Buffer methods into a separate header
+  // file.
m.def("TF_NewBuffer", TF_NewBuffer, py::return_value_policy::reference);
m.def("TF_GetBuffer", [](TF_Buffer* buf) {
TF_Buffer buffer = TF_GetBuffer(buf);
@@ -726,8 +770,8 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
});
// Note: users should prefer using tf.cast or equivalent, and only when
-  // it's infeasible to set the type via OpDef's type constructor and inference
-  // function.
+  // it's infeasible to set the type via OpDef's type constructor and
+  // inference function.
m.def("SetFullType", [](TF_Graph* graph, TF_Operation* op,
const std::string& serialized_full_type) {
tensorflow::FullTypeDef proto;
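
A note on the std::unique_ptr<TF_Operation, py::nodelete> holder above:
TF_Operation's destructor is now private (nodes are owned by their graph), so
the binding must tell pybind11 never to destroy these objects from Python. A
minimal sketch of that idiom, assuming stock pybind11 (`Widget` and
`WidgetStore` are illustrative names, not TensorFlow types):

// nodelete_example.cc -- hypothetical, simplified illustration.
#include <memory>

#include "pybind11/pybind11.h"

namespace py = pybind11;

class Widget {
 public:
  int id() const { return id_; }

 private:
  friend class WidgetStore;  // Only the owning store may destroy a Widget.
  ~Widget() = default;
  int id_ = 0;
};

PYBIND11_MODULE(_nodelete_example, m) {
  // With the default holder, pybind11 would need to call the (private)
  // destructor and the binding would not compile; py::nodelete makes the
  // Python object a non-owning reference.
  py::class_<Widget, std::unique_ptr<Widget, py::nodelete>>(m, "Widget")
      .def_property_readonly("id", &Widget::id);
}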

View File

@@ -1914,7 +1914,7 @@ def _create_c_op(graph,
@tf_export("Operation")
-class Operation(object):
+class Operation(pywrap_tf_session.OperationHandle):
"""Represents a graph node that performs computation on tensors.
An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
@@ -1979,17 +1979,11 @@ class Operation(object):
or if `inputs` and `input_types` are incompatible.
ValueError: if the `node_def` name is not valid.
"""
+    super().__init__()
if not isinstance(g, Graph):
raise TypeError(f"Argument g must be a Graph. "
f"Received an instance of type {type(g)}")
-    # TODO(feyu): This message is redundant with the check below. We raise it
-    # to help users to migrate. Remove this after 07/01/2022.
-    if isinstance(node_def, pywrap_tf_session.TF_Operation):
-      raise ValueError(
-          "Calling Operation() with node_def of a TF_Operation is deprecated. "
-          "Please switch to Operation.from_c_op.")
if not isinstance(node_def, node_def_pb2.NodeDef):
raise TypeError(f"Argument node_def must be a NodeDef. "
f"Received an instance of type: {type(node_def)}.")
@@ -2070,8 +2064,7 @@ class Operation(object):
Returns:
an Operation object.
"""
-    self = object.__new__(cls)
+    self = Operation.__new__(cls)
self._init_from_c_op(c_op=c_op, g=g) # pylint: disable=protected-access
return self
@@ -2087,7 +2080,6 @@ class Operation(object):
f"got {type(c_op)} for argument c_op.")
self._original_op = None
self._graph = g
self._c_op = c_op
@@ -2287,20 +2279,6 @@ class Operation(object):
return output_types
-  def _tf_output(self, output_idx):
-    """Create and return a new TF_Output for output_idx'th output of this op."""
-    tf_output = pywrap_tf_session.TF_Output()
-    tf_output.oper = self._c_op
-    tf_output.index = output_idx
-    return tf_output
-
-  def _tf_input(self, input_idx):
-    """Create and return a new TF_Input for input_idx'th input of this op."""
-    tf_input = pywrap_tf_session.TF_Input()
-    tf_input.oper = self._c_op
-    tf_input.index = input_idx
-    return tf_input
def _set_device(self, device): # pylint: disable=redefined-outer-name
"""Set the device of this operation.
@@ -2515,47 +2493,11 @@ class Operation(object):
]
# pylint: enable=protected-access
-  @property
-  def type(self):
-    """The type of the op (e.g. `"MatMul"`)."""
-    return pywrap_tf_session.TF_OperationOpType(self._c_op)
@property
def graph(self):
"""The `Graph` that contains this operation."""
return self._graph
-  @property
-  def node_def(self):
-    # pylint: disable=line-too-long
-    """Returns the `NodeDef` representation of this operation.
-
-    Returns:
-      A
-      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
-      protocol buffer.
-    """
-    # pylint: enable=line-too-long
-    with c_api_util.tf_buffer() as buf:
-      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
-      data = pywrap_tf_session.TF_GetBuffer(buf)
-    node_def = node_def_pb2.NodeDef()
-    node_def.ParseFromString(compat.as_bytes(data))
-    return node_def
-
-  @property
-  def op_def(self):
-    # pylint: disable=line-too-long
-    """Returns the `OpDef` proto that represents the type of this op.
-
-    Returns:
-      An
-      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
-      protocol buffer.
-    """
-    # pylint: enable=line-too-long
-    return self._graph._get_op_def(self.type)
@property
def traceback(self):
"""Returns the call stack from when this operation was constructed."""
@@ -2563,6 +2505,14 @@ class Operation(object):
# goes out of scope.
return pywrap_tf_session.TF_OperationGetStackTrace(self._c_op)
+  @property
+  def node_def(self):
+    return node_def_pb2.NodeDef.FromString(self._node_def)
+
+  @property
+  def op_def(self):
+    return op_def_pb2.OpDef.FromString(self._op_def)
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
buf = pywrap_tf_session.TF_NewBufferFromString(
@@ -2714,6 +2664,7 @@ class Operation(object):
"""
_run_using_default_session(self, feed_dict, self.graph, session)
# TODO(b/185395742): Clean up usages of _gradient_registry
gradient_registry = _gradient_registry = registry.Registry("gradient")
@@ -2993,7 +2944,7 @@ def resource_creator_scope(resource_type, resource_creator):
@tf_export("Graph")
-class Graph(object):
+class Graph(pywrap_tf_session.GraphHandle):
"""A TensorFlow computation, represented as a dataflow graph.
Graphs are used by `tf.function`s to represent the function's computations.
@@ -3039,6 +2990,7 @@ class Graph(object):
def __init__(self):
"""Creates a new, empty Graph."""
+    super().__init__()
# Protects core state that can be returned via public accessors.
# Thread-safety is provided on a best-effort basis to support buggy
# programs, and is not guaranteed by the public `tf.Graph` API.
@@ -4843,7 +4795,7 @@ class Graph(object):
self._old_stack = None
self._old_control_flow_context = None
  # pylint: disable=protected-access
def __enter__(self):
if self._new_stack:
@@ -4861,7 +4813,7 @@ class Graph(object):
self._graph._control_dependencies_stack = self._old_stack
self._graph._set_control_flow_context(self._old_control_flow_context)
  # pylint: enable=protected-access
@property
def control_inputs(self):

View File

@@ -1435,6 +1435,7 @@ class _OperationWithOutputs(ops.Operation):
"""
def __init__(self, c_op, g):
+    super(ops.Operation, self).__init__()
self._c_op = c_op
self._graph = g
self._outputs = None # Initialized by _duplicate_body_captures_in_cond().

View File

@@ -32,7 +32,6 @@ from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import auto_control_deps
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
@@ -318,9 +317,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._gradient_colocation_stack = []
self._host_compute_core = []
self._name = name
-    self._name_as_bytes = compat.as_bytes(name)
-    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
-        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
+    self._tpu_replicate_attr = attr_value_pb2.AttrValue(s=compat.as_bytes(name))
self._unsupported_ops = []
self._pivot = pivot
self._replicated_vars = {}
@@ -373,7 +370,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
else:
raise ValueError(
"Failed to find a variable on any device in replica {} for "
"current device assignment".format(replica_id))
"current device assignment".format(replica_id)
)
else:
replicated_vars = vars_
@@ -610,8 +608,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if (_TPU_REPLICATE_ATTR in op.node_def.attr and
"_cloned" not in op.node_def.attr):
raise ValueError(f"TPU computations cannot be nested on op ({op})")
-      op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
-                            self._tpu_relicate_attr_buf.buffer)
+      op._set_attr(_TPU_REPLICATE_ATTR, self._tpu_replicate_attr)
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,

View File

@@ -20,7 +20,6 @@ from absl import logging
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
@@ -115,9 +114,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._gradient_colocation_stack = []
self._host_compute_core = []
self._name = name
-    self._name_as_bytes = compat.as_bytes(name)
-    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
-        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
+    self._tpu_replicate_attr = attr_value_pb2.AttrValue(
+        s=compat.as_bytes(self._name)
+    )
self._unsupported_ops = []
self._pivot = pivot
self._replicated_vars = {}
@@ -170,7 +169,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
else:
raise ValueError(
"Failed to find a variable on any device in replica {} for "
"current device assignment".format(replica_id))
"current device assignment".format(replica_id)
)
else:
replicated_vars = vars_
@@ -413,8 +413,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if (_TPU_REPLICATE_ATTR in op.node_def.attr and
"_cloned" not in op.node_def.attr):
raise ValueError(f"TPU computations cannot be nested on op ({op})")
-      op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
-                            self._tpu_relicate_attr_buf.buffer)
+      op._set_attr(_TPU_REPLICATE_ATTR, self._tpu_replicate_attr)
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,

View File

@@ -122,6 +122,43 @@ def _new_mark_used(self, *args, **kwargs):
except AttributeError:
pass
+OVERLOADABLE_OPERATORS = {
+    '__add__',
+    '__radd__',
+    '__sub__',
+    '__rsub__',
+    '__mul__',
+    '__rmul__',
+    '__div__',
+    '__rdiv__',
+    '__truediv__',
+    '__rtruediv__',
+    '__floordiv__',
+    '__rfloordiv__',
+    '__mod__',
+    '__rmod__',
+    '__lt__',
+    '__le__',
+    '__gt__',
+    '__ge__',
+    '__ne__',
+    '__eq__',
+    '__and__',
+    '__rand__',
+    '__or__',
+    '__ror__',
+    '__xor__',
+    '__rxor__',
+    '__getitem__',
+    '__pow__',
+    '__rpow__',
+    '__invert__',
+    '__neg__',
+    '__abs__',
+    '__matmul__',
+    '__rmatmul__',
+}
_WRAPPERS = {}
@@ -144,23 +181,22 @@ def _get_wrapper(x, tf_should_use_helper):
if memoized:
return memoized(x, tf_should_use_helper)
-  tx = copy.deepcopy(type_x)
+  # Make a copy of `object`
+  tx = copy.deepcopy(object)
# Prefer using __orig_bases__, which preserve generic type arguments.
bases = getattr(tx, '__orig_bases__', tx.__bases__)
# Use types.new_class when available, which is preferred over plain type in
# some distributions.
-  if sys.version_info >= (3, 5):
-    def set_body(ns):
-      ns.update(tx.__dict__)
-      return ns
-    copy_tx = types.new_class(tx.__name__, bases, exec_body=set_body)
-  else:
-    copy_tx = type(tx.__name__, bases, dict(tx.__dict__))
+  def set_body(ns):
+    ns.update(tx.__dict__)
+    return ns
+  copy_tx = types.new_class(tx.__name__, bases, exec_body=set_body)
copy_tx.__init__ = _new__init__
copy_tx.__getattribute__ = _new__getattribute__
for op in OVERLOADABLE_OPERATORS:
if hasattr(type_x, op):
setattr(copy_tx, op, getattr(type_x, op))
copy_tx.mark_used = _new_mark_used
copy_tx.__setattr__ = _new__setattr__
_WRAPPERS[type_x] = copy_tx

View File

@@ -1,7 +1,8 @@
path: "tensorflow.Graph"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Graph\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.client._pywrap_tf_session.GraphHandle\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "building_function"
mtype: "<type \'property\'>"

View File

@@ -1,7 +1,8 @@
path: "tensorflow.Operation"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Operation\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.client._pywrap_tf_session.OperationHandle\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "control_inputs"
mtype: "<type \'property\'>"

View File

@@ -98,7 +98,7 @@ tf_module {
}
member {
name: "Graph"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "GraphDef"
@@ -166,7 +166,7 @@ tf_module {
}
member {
name: "Operation"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "OptimizerOptions"

View File

@@ -1,7 +1,8 @@
path: "tensorflow.Graph"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Graph\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.client._pywrap_tf_session.GraphHandle\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "building_function"
mtype: "<type \'property\'>"

View File

@@ -1,7 +1,8 @@
path: "tensorflow.Operation"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Operation\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.client._pywrap_tf_session.OperationHandle\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "control_inputs"
mtype: "<type \'property\'>"

View File

@@ -2,7 +2,8 @@ path: "tensorflow.__internal__.FuncGraph"
tf_class {
is_instance: "<class \'tensorflow.python.framework.func_graph.FuncGraph\'>"
is_instance: "<class \'tensorflow.python.framework.ops.Graph\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.client._pywrap_tf_session.GraphHandle\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "building_function"
mtype: "<type \'property\'>"

View File

@@ -10,7 +10,7 @@ tf_module {
}
member {
name: "FuncGraph"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "autograph"

View File

@@ -22,7 +22,7 @@ tf_module {
}
member {
name: "Graph"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "IndexedSlices"
@@ -38,7 +38,7 @@ tf_module {
}
member {
name: "Operation"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "OptionalSpec"

View File

@@ -876,8 +876,9 @@ def _tf_repositories():
tf_http_archive(
name = "pybind11_protobuf",
urls = tf_mirror_urls("https://github.com/pybind/pybind11_protobuf/archive/80f3440cd8fee124e077e2e47a8a17b78b451363.zip"),
sha256 = "",
sha256 = "c7ab64b1ccf9a678694a89035a8c865a693e4e872803778f91f0965c2f281d78",
strip_prefix = "pybind11_protobuf-80f3440cd8fee124e077e2e47a8a17b78b451363",
patch_file = ["//third_party/pybind11_protobuf:remove_license.patch"],
)
tf_http_archive(

View File

@@ -0,0 +1,13 @@
diff --git third_party/pybind11_protobuf/BUILD third_party/pybind11_protobuf/BUILD
index b62eb91..b7d1240 100644
--- a/pybind11_protobuf/BUILD
+++ b/pybind11_protobuf/BUILD
@@ -3,8 +3,6 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
-licenses(["notice"])
-
pybind_library(
name = "enum_type_caster",
hdrs = ["enum_type_caster.h"],