Add TF_GraphVersions() to C API and use in Graph.graph_def_versions()

PiperOrigin-RevId: 173902666
This commit is contained in:
Skye Wanderman-Milne
2017-10-30 09:20:20 -07:00
committed by TensorFlower Gardener
parent 4723f8f6ed
commit 7fd2616026
4 changed files with 32 additions and 8 deletions

View File

@@ -81,6 +81,7 @@ using tensorflow::TensorBuffer;
using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
using tensorflow::error::Code;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
@@ -1809,6 +1810,16 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
status->status = MessageToBuffer(*op_def, output_op_def);
}
void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
TF_Status* status) {
VersionDef versions;
{
mutex_lock l(graph->mu);
versions = graph->graph.versions();
}
status->status = MessageToBuffer(versions, output_version_def);
}
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
return new TF_ImportGraphDefOptions;
}

View File

@@ -871,6 +871,11 @@ TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph,
TF_Buffer* output_op_def,
TF_Status* status);
// Returns the serialized VersionDef proto for this graph.
TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph,
TF_Buffer* output_version_def,
TF_Status* status);
// TF_ImportGraphDefOptions holds options that can be passed to
// TF_GraphImportGraphDef.
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;

View File

@@ -2713,7 +2713,16 @@ class Graph(object):
A `VersionDef`.
"""
# pylint: enable=line-too-long
return self._graph_def_versions
if self._c_graph:
with errors.raise_exception_on_not_ok_status() as status:
with c_api_util.tf_buffer() as buf:
c_api.TF_GraphVersions(self._c_graph, buf, status)
data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
return version_def
else:
return self._graph_def_versions
@property
def seed(self):

View File

@@ -1642,17 +1642,16 @@ class KernelLabelTest(test_util.TensorFlowTestCase):
self.assertAllEqual(b"My label is: overload_2", overload_2.eval())
@test_util.with_c_api
class AsGraphDefTest(test_util.TensorFlowTestCase):
def testGraphDefVersion(self):
"""Test that the graphdef version is plumbed through to kernels."""
for version in range(versions.GRAPH_DEF_VERSION_MIN_PRODUCER,
versions.GRAPH_DEF_VERSION + 2):
with ops.Graph().as_default() as g:
g.graph_def_versions.producer = version
with self.test_session(graph=g):
v = test_ops.graph_def_version().eval()
self.assertEqual(version, v)
with ops.Graph().as_default() as g:
version = g.graph_def_versions.producer
with self.test_session(graph=g):
v = test_ops.graph_def_version().eval()
self.assertEqual(version, v)
def testAddShapes(self):
with ops.Graph().as_default() as g: