From 7fd261602677d3c251fba05264a20318231deb76 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 30 Oct 2017 09:20:20 -0700 Subject: [PATCH] Add TF_GraphVersions() to C API and use in Graph.graph_def_versions() PiperOrigin-RevId: 173902666 --- tensorflow/c/c_api.cc | 11 +++++++++++ tensorflow/c/c_api.h | 5 +++++ tensorflow/python/framework/ops.py | 11 ++++++++++- tensorflow/python/framework/ops_test.py | 13 ++++++------- 4 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b43d202f4e8..6dd1b999102 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -81,6 +81,7 @@ using tensorflow::TensorBuffer; using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; +using tensorflow::VersionDef; using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; @@ -1809,6 +1810,16 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, status->status = MessageToBuffer(*op_def, output_op_def); } +void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, + TF_Status* status) { + VersionDef versions; + { + mutex_lock l(graph->mu); + versions = graph->graph.versions(); + } + status->status = MessageToBuffer(versions, output_version_def); +} + TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index ca5c934634d..bb569d67fcb 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -871,6 +871,11 @@ TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, TF_Buffer* output_op_def, TF_Status* status); +// Returns the serialized VersionDef proto for this graph. +TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, + TF_Buffer* output_version_def, + TF_Status* status); + // TF_ImportGraphDefOptions holds options that can be passed to // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 63f70a1a9d4..b5e3e548bd0 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2713,7 +2713,16 @@ class Graph(object): A `VersionDef`. """ # pylint: enable=line-too-long - return self._graph_def_versions + if self._c_graph: + with errors.raise_exception_on_not_ok_status() as status: + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphVersions(self._c_graph, buf, status) + data = c_api.TF_GetBuffer(buf) + version_def = versions_pb2.VersionDef() + version_def.ParseFromString(compat.as_bytes(data)) + return version_def + else: + return self._graph_def_versions @property def seed(self): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 59c02884574..b1269b84bd2 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1642,17 +1642,16 @@ class KernelLabelTest(test_util.TensorFlowTestCase): self.assertAllEqual(b"My label is: overload_2", overload_2.eval()) +@test_util.with_c_api class AsGraphDefTest(test_util.TensorFlowTestCase): def testGraphDefVersion(self): """Test that the graphdef version is plumbed through to kernels.""" - for version in range(versions.GRAPH_DEF_VERSION_MIN_PRODUCER, - versions.GRAPH_DEF_VERSION + 2): - with ops.Graph().as_default() as g: - g.graph_def_versions.producer = version - with self.test_session(graph=g): - v = test_ops.graph_def_version().eval() - self.assertEqual(version, v) + with ops.Graph().as_default() as g: + version = g.graph_def_versions.producer + with self.test_session(graph=g): + v = test_ops.graph_def_version().eval() + self.assertEqual(version, v) def testAddShapes(self): with ops.Graph().as_default() as g: