From ce0238198052358d102ca7786ad9be60a5e76d28 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 30 Oct 2017 08:07:11 -0700 Subject: [PATCH] Add ability to fetch return nodes and unused input mappings from C API GraphDef import This change introduces yet another ImportGraphDef function to the C API (TF_GraphImportGraphDefWithResults), but this one has extensible return values so we shouldn't have to add more in the future. This change also modifies the ImportGraphDef C interface to manage all string data for the user. PiperOrigin-RevId: 173894710 --- tensorflow/c/c_api.cc | 241 ++++++++++++++------- tensorflow/c/c_api.h | 57 ++++- tensorflow/c/c_api_internal.h | 16 ++ tensorflow/c/c_api_test.cc | 135 +++++++++++- tensorflow/core/graph/graph_constructor.cc | 2 +- tensorflow/core/graph/graph_constructor.h | 2 +- 6 files changed, 369 insertions(+), 84 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cd98393e0a5..b43d202f4e8 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -86,6 +86,7 @@ using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; using tensorflow::mutex_lock; +using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -366,7 +367,7 @@ namespace { // Reset helper for converting character arrays to string vectors. void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, int ncontainers, TF_Status* status) { - std::vector container_names(ncontainers); + std::vector container_names(ncontainers); for (int i = 0; i < ncontainers; ++i) { container_names[i] = containers[i]; } @@ -482,7 +483,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { const char* limit = input + src_size; *dst = Tensor(static_cast(src->dtype), src->shape); - auto dstarray = dst->flat(); + auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = reinterpret_cast(input)[i]; @@ -556,9 +557,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, // Compute bytes needed for encoding. size_t size = 0; - const auto& srcarray = src.flat(); + const auto& srcarray = src.flat(); for (int i = 0; i < srcarray.size(); ++i) { - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); // uint64 starting_offset, TF_StringEncode-d string. size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size()); } @@ -572,7 +573,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, for (int i = 0; i < srcarray.size(); ++i) { *offsets = (dst - data_start); offsets++; - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); if (!status->status.ok()) { status->status = InvalidArgument( @@ -637,10 +638,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, } } -static bool TF_Run_Inputs( - TF_Tensor* const* c_inputs, - std::vector>* input_pairs, - TF_Status* status) { +static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, + std::vector>* input_pairs, + TF_Status* status) { const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); @@ -652,13 +652,12 @@ static bool TF_Run_Inputs( static void TF_Run_Helper( Session* session, const char* handle, const TF_Buffer* run_options, // Input tensors - const std::vector>& input_pairs, + const std::vector>& input_pairs, // Output tensors - const std::vector& output_tensor_names, - TF_Tensor** c_outputs, + const std::vector& output_tensor_names, TF_Tensor** c_outputs, // Target nodes - const std::vector& target_oper_names, - TF_Buffer* run_metadata, TF_Status* status) { + const std::vector& target_oper_names, TF_Buffer* run_metadata, + TF_Status* status) { const int noutputs = output_tensor_names.size(); std::vector outputs(noutputs); Status result; @@ -718,16 +717,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, const char** c_target_oper_names, int ntargets, TF_Buffer* run_metadata, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector target_oper_names(ntargets); + std::vector target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -745,9 +744,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s, const char** handle, TF_Status* status) { *handle = nullptr; - std::vector input_names(ninputs); - std::vector output_names(noutputs); - std::vector target_oper_names(ntargets); + std::vector input_names(ninputs); + std::vector output_names(noutputs); + std::vector target_oper_names(ntargets); for (int i = 0; i < ninputs; ++i) { input_names[i] = c_input_names[i]; } @@ -757,7 +756,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } - tensorflow::string new_handle; + string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); if (status->status.ok()) { @@ -776,17 +775,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, const char** c_target_oper_names, int ntargets, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector target_oper_names(ntargets); + std::vector target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -881,7 +880,7 @@ TF_Operation* ToOperation(Node* node) { return static_cast(static_cast(node)); } -tensorflow::string OutputName(const TF_Output& output) { +string OutputName(const TF_Output& output) { return StrCat(output.oper->node.name(), ":", output.index); } @@ -1254,7 +1253,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, return; } desc->colocation_constraints.clear(); - for (const tensorflow::string& location : attr_value.list().s()) { + for (const string& location : attr_value.list().s()) { desc->colocation_constraints.insert(location); } } else { @@ -1276,8 +1275,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, if (!desc->colocation_constraints.empty()) { desc->node_builder.Attr( tensorflow::kColocationAttrName, - std::vector(desc->colocation_constraints.begin(), - desc->colocation_constraints.end())); + std::vector(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); @@ -1500,7 +1499,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { const auto& a = oper->node.op_def().attr(i); if (a.name().compare(attr_name) != 0) continue; - const tensorflow::string& typestr = a.type(); + const string& typestr = a.type(); if (typestr == "list(string)") { metadata.type = TF_ATTR_STRING; } else if (typestr == "list(int)") { @@ -1580,7 +1579,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, const auto len = std::min(max_values, attr->list().s_size()); char* p = static_cast(storage); for (int i = 0; i < len; ++i) { - const tensorflow::string& s = attr->list().s(i); + const string& s = attr->list().s(i); values[i] = p; lengths[i] = s.size(); if ((p + s.size()) > (static_cast(storage) + storage_size)) { @@ -1824,7 +1823,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { - opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst); + opts->tensor_id_data.push_back(src_name); + const string& src_name_str = opts->tensor_id_data.back(); + // We don't need to store dst's name in tensor_id_data, since `dst` must + // outlive the ImportGraphDef call. + opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); } void TF_ImportGraphDefOptionsRemapControlDependency( @@ -1840,7 +1843,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency( void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, const char* oper_name, int index) { - opts->opts.return_tensors.push_back({oper_name, index}); + opts->tensor_id_data.push_back(oper_name); + const string& oper_name_str = opts->tensor_id_data.back(); + opts->opts.return_tensors.emplace_back(oper_name_str, index); } int TF_ImportGraphDefOptionsNumReturnOutputs( @@ -1848,15 +1853,116 @@ int TF_ImportGraphDefOptionsNumReturnOutputs( return opts->opts.return_tensors.size(); } +void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, + const char* oper_name) { + opts->opts.return_nodes.push_back(oper_name); +} + +int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_nodes.size(); +} + +void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, + int* num_outputs, + TF_Output** outputs) { + *num_outputs = results->return_tensors.size(); + *outputs = results->return_tensors.data(); +} + +void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, + int* num_opers, + TF_Operation*** opers) { + *num_opers = results->return_nodes.size(); + *opers = results->return_nodes.data(); +} + +void TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_unused_input_mappings, + const char*** src_names, int** src_indexes) { + *num_unused_input_mappings = results->unused_key_names.size(); + *src_names = results->unused_key_names.data(); + *src_indexes = results->unused_key_indexes.data(); +} + +void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { + delete results; +} + static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, const TF_ImportGraphDefOptions* opts, - TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) + TF_ImportGraphDefResults* tf_results, + TF_Status* status) EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - if (num_return_outputs != opts->opts.return_tensors.size()) { + const int last_node_id = graph->graph.num_node_ids(); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); + if (!status->status.ok()) return; + + // Add new nodes to name_map + for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { + auto* node = graph->graph.FindNodeId(i); + if (node != nullptr) graph->name_map[node->name()] = node; + } + + // Populate return_tensors + DCHECK(tf_results->return_tensors.empty()); + tf_results->return_tensors.resize(results.return_tensors.size()); + for (int i = 0; i < results.return_tensors.size(); ++i) { + tf_results->return_tensors[i].oper = + ToOperation(results.return_tensors[i].first); + tf_results->return_tensors[i].index = results.return_tensors[i].second; + } + + // Populate return_nodes + DCHECK(tf_results->return_nodes.empty()); + tf_results->return_nodes.resize(results.return_nodes.size()); + for (int i = 0; i < results.return_nodes.size(); ++i) { + tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); + } + + // Populate unused map keys + DCHECK(tf_results->unused_key_names.empty()); + DCHECK(tf_results->unused_key_indexes.empty()); + DCHECK(tf_results->unused_key_names_data.empty()); + tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); + tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); + for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { + TensorId id = results.unused_input_map_keys[i]; + tf_results->unused_key_names_data.push_back(id.first.ToString()); + tf_results->unused_key_names[i] = + tf_results->unused_key_names_data.back().c_str(); + tf_results->unused_key_indexes[i] = id.second; + } +} + +TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status) { + GraphDef def; + if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return nullptr; + } + auto results = new TF_ImportGraphDefResults(); + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, results, status); + if (!status->status.ok()) { + delete results; + return nullptr; + } + return results; +} + +void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status) { + if (num_return_outputs != options->opts.return_tensors.size()) { status->status = InvalidArgument("Expected 'num_return_outputs' to be ", - opts->opts.return_tensors.size(), ", got ", - num_return_outputs); + options->opts.return_tensors.size(), + ", got ", num_return_outputs); return; } if (num_return_outputs > 0 && return_outputs == nullptr) { @@ -1864,41 +1970,25 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, "'return_outputs' must be preallocated to length ", num_return_outputs); return; } - const int last_node_id = graph->graph.num_node_ids(); - tensorflow::ImportGraphDefResults results; - status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, - &graph->refiner, &results); - if (!status->status.ok()) return; - for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { - auto* node = graph->graph.FindNodeId(i); - if (node != nullptr) graph->name_map[node->name()] = node; - } - DCHECK_EQ(results.return_tensors.size(), num_return_outputs); - for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(results.return_tensors[i].first); - return_outputs[i].index = results.return_tensors[i].second; - } -} - -void TF_GraphImportGraphDefWithReturnOutputs( - TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) { GraphDef def; if (!def.ParseFromArray(graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } + TF_ImportGraphDefResults results; mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, opts, return_outputs, - num_return_outputs, status); + GraphImportGraphDefLocked(graph, def, options, &results, status); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); + memcpy(return_outputs, results.return_tensors.data(), + num_return_outputs * sizeof(TF_Output)); } void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0, - status); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); + TF_DeleteImportGraphDefResults(results); } // While loop functions ------------------------------------------------------- @@ -1930,7 +2020,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, tensorflow::ShapeRefiner* dst_refiner, const TF_Output* src_inputs, const std::vector& dst_inputs, - const tensorflow::string& prefix, + const string& prefix, const std::vector& control_deps, const TF_Output* nodes_to_return, int nreturn_nodes, std::vector* return_nodes) { @@ -2257,9 +2347,9 @@ TF_Session* TF_LoadSessionFromSavedModel( return nullptr; } - std::unordered_set tag_set; + std::unordered_set tag_set; for (int i = 0; i < tags_len; i++) { - tag_set.insert(tensorflow::string(tags[i])); + tag_set.insert(string(tags[i])); } tensorflow::SavedModelBundle bundle; @@ -2275,8 +2365,9 @@ TF_Session* TF_LoadSessionFromSavedModel( // TODO(jhseu): When Session is modified to take Graphs instead of // GraphDefs, return the Graph generated in LoadSavedModel(). TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefResults results; GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), - import_opts, nullptr, 0, status); + import_opts, &results, status); TF_DeleteImportGraphDefOptions(import_opts); if (TF_GetCode(status) != TF_OK) return nullptr; @@ -2372,20 +2463,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } @@ -2406,22 +2497,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, return; } - std::vector input_names(ninputs); + std::vector input_names(ninputs); for (int i = 0; i < ninputs; ++i) { input_names[i] = OutputName(inputs[i]); } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } - tensorflow::string new_handle; + string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); if (status->status.ok()) { @@ -2452,20 +2543,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 1e8bfdc7b06..ca5c934634d 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -914,7 +914,62 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); +// Add an operation in `graph_def` to be returned via the `return_opers` output +// parameter of TF_GraphImportGraphDef(). +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( + TF_ImportGraphDefOptions* opts, const char* oper_name); + +// Returns the number of return operations added via +// TF_ImportGraphDefOptionsAddReturnOperation(). +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts); + +// TF_ImportGraphDefResults holds results that are generated by +// TF_GraphImportGraphDefWithResults(). +typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; + +// Fetches the return outputs requested via +// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is +// returned in `num_outputs`. The array of return outputs is returned in +// `outputs`. `*outputs` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( + TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); + +// Fetches the return operations requested via +// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched +// operations is returned in `num_opers`. The array of return operations is +// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( + TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); + +// Fetches any input mappings requested via +// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any +// node in the imported graph def. The number of fetched mappings is returned in +// `num_unused_input_mappings`. The array of each mapping's source node name is +// returned in `src_names`, and the array of each mapping's source index is +// returned in `src_indexes`. +// +// `*src_names`, `*src_indexes`, and the memory backing each string in +// `src_names` are owned by and have the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_unused_input_mappings, + const char*** src_names, int** src_indexes); + +// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( + TF_ImportGraphDefResults* results); + +// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and +// a bad status on error. Otherwise, returns a populated +// TF_ImportGraphDefResults instance. The returned instance must be deleted via +// TF_DeleteImportGraphDefResults(). +TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status); + // Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when only return outputs are needed. // // `num_return_outputs` must be the number of return outputs added (i.e. the // result of TF_ImportGraphDefOptionsNumReturnOutputs()). If @@ -926,7 +981,7 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( int num_return_outputs, TF_Status* status); // Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when no return outputs have been added. +// Convenience function for when no results are needed. TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 23ec1fac6f4..bb04e01beec 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/c/c_api.h" +#include #include +#include #include #include @@ -124,6 +126,20 @@ struct TF_Session { struct TF_ImportGraphDefOptions { tensorflow::ImportGraphDefOptions opts; + + // Backing memory for TensorId fields in opts. + // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. + std::list tensor_id_data; +}; + +struct TF_ImportGraphDefResults { + std::vector return_tensors; + std::vector return_nodes; + std::vector unused_key_names; + std::vector unused_key_indexes; + + // Backing memory for unused_key_names values. + std::list unused_key_names_data; }; struct TF_DeviceList { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d220bc5e95f..05881e619ba 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -573,7 +573,7 @@ TEST(CAPI, ImportGraphDef) { TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - // Import it again, with a prefix, in a fresh graph. + // Import it, with a prefix, in a fresh graph. TF_DeleteGraph(graph); graph = TF_NewGraph(); TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); @@ -588,8 +588,8 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); - // Import it again, with an input mapping and return outputs, into the same - // graph. + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); @@ -597,9 +597,10 @@ TEST(CAPI, ImportGraphDef) { TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); - TF_Output return_outputs[2]; - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, - return_outputs, 2, s); + TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts)); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar"); @@ -615,11 +616,26 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(0, neg_input.index); // Check return outputs + TF_Output* return_outputs; + int num_return_outputs; + TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs, + &return_outputs); + ASSERT_EQ(2, num_return_outputs); EXPECT_EQ(feed2, return_outputs[0].oper); EXPECT_EQ(0, return_outputs[0].index); EXPECT_EQ(scalar, return_outputs[1].oper); // remapped EXPECT_EQ(0, return_outputs[1].index); + // Check return operation + TF_Operation** return_opers; + int num_return_opers; + TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers, + &return_opers); + ASSERT_EQ(1, num_return_opers); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + + TF_DeleteImportGraphDefResults(results); + // Import again, with control dependencies, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); @@ -689,6 +705,113 @@ TEST(CAPI, ImportGraphDef) { TF_DeleteStatus(s); } +TEST(CAPI, ImportGraphDef_WithReturnOutputs) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph with return outputs. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + TF_Output return_outputs[2]; + TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, + return_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + TF_Operation* feed = TF_GraphOperationByName(graph, "feed"); + TF_Operation* neg = TF_GraphOperationByName(graph, "neg"); + ASSERT_TRUE(scalar != nullptr); + ASSERT_TRUE(feed != nullptr); + ASSERT_TRUE(neg != nullptr); + + // Check return outputs + EXPECT_EQ(feed, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); + EXPECT_EQ(0, return_outputs[1].index); + + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, ImportGraphDef_UnusedInputMappings) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + + // Import it in a fresh graph with an unused input mapping. + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0}); + TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0}); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Check unused input mappings + int num_unused_input_mappings; + const char** src_names; + int* src_indexes; + TF_ImportGraphDefResultsUnusedInputMappings( + results, &num_unused_input_mappings, &src_names, &src_indexes); + ASSERT_EQ(1, num_unused_input_mappings); + EXPECT_EQ(string("fake"), string(src_names[0])); + EXPECT_EQ(0, src_indexes[0]); + + TF_DeleteImportGraphDefResults(results); + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + TEST(CAPI, Session) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index b2c193b050b..9432775ff32 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -90,7 +90,7 @@ class GraphConstructor { bool skip_mapped_nodes; std::vector control_dependencies; std::vector return_tensors; - std::vector return_nodes; + std::vector return_nodes; // TODO(ashankar): This bool exists to separate out functionality required // to make ImportGraphDef a close equivalent of Python's import_graph_def diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 6cd9347d965..a3644788788 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -110,7 +110,7 @@ struct ImportGraphDefOptions { // Unlike `return_tensors`, `input_map` has no effect on the nodes // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. - std::vector return_nodes; + std::vector return_nodes; // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef.