mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Add TF_GraphVersions() to C API and use in Graph.graph_def_versions()
PiperOrigin-RevId: 173902666
This commit is contained in:
committed by
TensorFlower Gardener
parent
4723f8f6ed
commit
7fd2616026
@@ -81,6 +81,7 @@ using tensorflow::TensorBuffer;
|
||||
using tensorflow::TensorId;
|
||||
using tensorflow::TensorShape;
|
||||
using tensorflow::TensorShapeProto;
|
||||
using tensorflow::VersionDef;
|
||||
using tensorflow::error::Code;
|
||||
using tensorflow::errors::FailedPrecondition;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
@@ -1809,6 +1810,16 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
|
||||
status->status = MessageToBuffer(*op_def, output_op_def);
|
||||
}
|
||||
|
||||
void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
|
||||
TF_Status* status) {
|
||||
VersionDef versions;
|
||||
{
|
||||
mutex_lock l(graph->mu);
|
||||
versions = graph->graph.versions();
|
||||
}
|
||||
status->status = MessageToBuffer(versions, output_version_def);
|
||||
}
|
||||
|
||||
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
|
||||
return new TF_ImportGraphDefOptions;
|
||||
}
|
||||
|
||||
@@ -871,6 +871,11 @@ TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph,
|
||||
TF_Buffer* output_op_def,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the serialized VersionDef proto for this graph.
|
||||
TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph,
|
||||
TF_Buffer* output_version_def,
|
||||
TF_Status* status);
|
||||
|
||||
// TF_ImportGraphDefOptions holds options that can be passed to
|
||||
// TF_GraphImportGraphDef.
|
||||
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
|
||||
|
||||
@@ -2713,7 +2713,16 @@ class Graph(object):
|
||||
A `VersionDef`.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
return self._graph_def_versions
|
||||
if self._c_graph:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
c_api.TF_GraphVersions(self._c_graph, buf, status)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
version_def = versions_pb2.VersionDef()
|
||||
version_def.ParseFromString(compat.as_bytes(data))
|
||||
return version_def
|
||||
else:
|
||||
return self._graph_def_versions
|
||||
|
||||
@property
|
||||
def seed(self):
|
||||
|
||||
@@ -1642,17 +1642,16 @@ class KernelLabelTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(b"My label is: overload_2", overload_2.eval())
|
||||
|
||||
|
||||
@test_util.with_c_api
|
||||
class AsGraphDefTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGraphDefVersion(self):
|
||||
"""Test that the graphdef version is plumbed through to kernels."""
|
||||
for version in range(versions.GRAPH_DEF_VERSION_MIN_PRODUCER,
|
||||
versions.GRAPH_DEF_VERSION + 2):
|
||||
with ops.Graph().as_default() as g:
|
||||
g.graph_def_versions.producer = version
|
||||
with self.test_session(graph=g):
|
||||
v = test_ops.graph_def_version().eval()
|
||||
self.assertEqual(version, v)
|
||||
with ops.Graph().as_default() as g:
|
||||
version = g.graph_def_versions.producer
|
||||
with self.test_session(graph=g):
|
||||
v = test_ops.graph_def_version().eval()
|
||||
self.assertEqual(version, v)
|
||||
|
||||
def testAddShapes(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
|
||||
Reference in New Issue
Block a user