diff --git a/tensorflow/lite/experimental/tensorboard/BUILD b/tensorflow/lite/experimental/tensorboard/BUILD new file mode 100644 index 00000000000..b156b5b1549 --- /dev/null +++ b/tensorflow/lite/experimental/tensorboard/BUILD @@ -0,0 +1,25 @@ +# TFLite modules to support TensorBoard plugin. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "ops_util", + srcs = ["ops_util.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/lite/toco/python:tensorflow_wrap_toco", + "//tensorflow/python:util", + ], +) + +py_test( + name = "ops_util_test", + srcs = ["ops_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops_util", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/lite/experimental/tensorboard/README.md b/tensorflow/lite/experimental/tensorboard/README.md new file mode 100644 index 00000000000..fdbd160a2a5 --- /dev/null +++ b/tensorflow/lite/experimental/tensorboard/README.md @@ -0,0 +1,4 @@ +This folder contains basic modules to support TFLite plugin for TensorBoard. + +Warning: Everything in this directory is experimental and highly subject to +changes. diff --git a/tensorflow/lite/experimental/tensorboard/ops_util.py b/tensorflow/lite/experimental/tensorboard/ops_util.py new file mode 100644 index 00000000000..8f48f77b301 --- /dev/null +++ b/tensorflow/lite/experimental/tensorboard/ops_util.py @@ -0,0 +1,49 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops util to handle ops for Lite.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from tensorflow.lite.toco.python import tensorflow_wrap_toco +from tensorflow.python.util.tf_export import tf_export + + +class SupportedOp(collections.namedtuple("SupportedOp", ["op"])): + """Spec of supported ops. + + Args: + op: string of op name. + """ + + +@tf_export("lite.experimental.get_potentially_supported_ops") +def get_potentially_supported_ops(): + """Returns operations potentially supported by TensorFlow Lite. + + The potentially support list contains a list of ops that are partially or + fully supported, which is derived by simply scanning op names to check whether + they can be handled without real conversion and specific parameters. + + Given that some ops may be partially supported, the optimal way to determine + if a model's operations are supported is by converting using the TensorFlow + Lite converter. + + Returns: + A list of SupportedOp. + """ + ops = tensorflow_wrap_toco.TocoGetPotentiallySupportedOps() + return [SupportedOp(o["op"]) for o in ops] diff --git a/tensorflow/lite/experimental/tensorboard/ops_util_test.py b/tensorflow/lite/experimental/tensorboard/ops_util_test.py new file mode 100644 index 00000000000..38966676dde --- /dev/null +++ b/tensorflow/lite/experimental/tensorboard/ops_util_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for backend.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.lite.experimental.tensorboard import ops_util +from tensorflow.python.platform import test + + +class OpsUtilTest(test.TestCase): + + def testGetPotentiallySupportedOps(self): + ops = ops_util.get_potentially_supported_ops() + # See GetTensorFlowNodeConverterMap() in + # tensorflow/lite/toco/import_tensorflow.cc + self.assertIsInstance(ops, list) + # Test partial ops that surely exist in the list. + self.assertIn(ops_util.SupportedOp("Add"), ops) + self.assertIn(ops_util.SupportedOp("Log"), ops) + self.assertIn(ops_util.SupportedOp("Sigmoid"), ops) + self.assertIn(ops_util.SupportedOp("Softmax"), ops) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index fae6540028c..a1bfe4898bc 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -73,6 +73,7 @@ py_library( ":op_hint", ":util", "//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops", + "//tensorflow/lite/experimental/tensorboard:ops_util", "//tensorflow/lite/python/optimize:calibrator", "//tensorflow/python:graph_util", "//tensorflow/python/keras", diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index dd5a89bf90f..751f07bb209 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -24,9 +24,11 @@ from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError +from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import +from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import @@ -47,7 +49,6 @@ from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_te from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes -from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.eager import def_function as _def_function diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index a218497b63b..0ff6f5865ee 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -1394,5 +1394,11 @@ class FromKerasFile(test_util.TensorFlowTestCase): interpreter.allocate_tensors() +class ImportOpsUtilTest(test_util.TensorFlowTestCase): + + def testGetPotentiallySupportedOps(self): + self.assertIsNotNone(lite.get_potentially_supported_ops()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 6b5e68c533f..852b2b7480e 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2629,4 +2629,16 @@ std::unique_ptr ImportTensorFlowGraphDef( } return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph); } + +std::vector GetPotentiallySupportedOps() { + std::vector supported_ops; + const internal::ConverterMapType& converter_map = + internal::GetTensorFlowNodeConverterMap(); + + for (const auto& item : converter_map) { + supported_ops.push_back(item.first); + } + return supported_ops; +} + } // namespace toco diff --git a/tensorflow/lite/toco/import_tensorflow.h b/tensorflow/lite/toco/import_tensorflow.h index 5b74ff2bc31..4ada25e2fbe 100644 --- a/tensorflow/lite/toco/import_tensorflow.h +++ b/tensorflow/lite/toco/import_tensorflow.h @@ -17,9 +17,9 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" -#include "tensorflow/core/framework/graph.pb.h" namespace toco { @@ -34,14 +34,20 @@ struct TensorFlowImportFlags { bool import_all_ops_as_unsupported = false; }; +// Converts TOCO model from TensorFlow GraphDef with given flags. std::unique_ptr ImportTensorFlowGraphDef( const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags, const tensorflow::GraphDef& graph_def); +// Converts TOCO model from the file content of TensorFlow GraphDef with given +// flags. std::unique_ptr ImportTensorFlowGraphDef( const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags, const string& input_file_contents); +// Gets a list of supported ops by their names. +std::vector GetPotentiallySupportedOps(); + } // namespace toco #endif // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_ diff --git a/tensorflow/lite/toco/python/toco.i b/tensorflow/lite/toco/python/toco.i index c7dfdc35ab2..de10fca99e8 100644 --- a/tensorflow/lite/toco/python/toco.i +++ b/tensorflow/lite/toco/python/toco.i @@ -32,4 +32,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return = false); +// Returns a list of names of all ops potentially supported by tflite. +PyObject* TocoGetPotentiallySupportedOps(); + } // namespace toco \ No newline at end of file diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 6fad092f35a..22557a34cc5 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" +#include "tensorflow/lite/toco/import_tensorflow.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/python/toco_python_api.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -49,21 +51,32 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, bool error; std::string model_flags_proto_txt = ConvertArg(model_flags_proto_txt_raw, &error); - if (error) return nullptr; + if (error) { + PyErr_SetString(PyExc_ValueError, "Model flags are invalid."); + return nullptr; + } std::string toco_flags_proto_txt = ConvertArg(toco_flags_proto_txt_raw, &error); - if (error) return nullptr; + if (error) { + PyErr_SetString(PyExc_ValueError, "Toco flags are invalid."); + return nullptr; + } std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); - if (error) return nullptr; + if (error) { + PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid."); + return nullptr; + } // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { - LOG(FATAL) << "Model proto failed to parse." << std::endl; + PyErr_SetString(PyExc_ValueError, "Model proto failed to parse."); + return nullptr; } toco::TocoFlags toco_flags; if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { - LOG(FATAL) << "Toco proto failed to parse." << std::endl; + PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse."); + return nullptr; } auto& dump_options = *GraphVizDumpOptions::singleton(); @@ -100,4 +113,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, output_file_contents_txt.data(), output_file_contents_txt.size()); } +PyObject* TocoGetPotentiallySupportedOps() { + std::vector supported_ops = toco::GetPotentiallySupportedOps(); + PyObject* list = PyList_New(supported_ops.size()); + for (size_t i = 0; i < supported_ops.size(); ++i) { + const string& op = supported_ops[i]; + PyObject* op_dict = PyDict_New(); + PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str())); + PyList_SetItem(list, i, op_dict); + } + return list; +} + } // namespace toco diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index 4ab0961e127..20390c32a5e 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -31,6 +31,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return = false); +// Returns a list of names of all ops potentially supported by tflite. +PyObject* TocoGetPotentiallySupportedOps(); + } // namespace toco #endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index ab82ee9fd41..f23cf827470 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -125,12 +125,12 @@ def freeze_graph_with_def_protos(input_graph_def, # 'input_checkpoint' may be a prefix if we're using Saver V2 format if (not input_saved_model_dir and not checkpoint_management.checkpoint_exists(input_checkpoint)): - print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") - return -1 + raise ValueError("Input checkpoint '" + input_checkpoint + + "' doesn't exist!") if not output_node_names: - print("You need to supply the name of a node to --output_node_names.") - return -1 + raise ValueError( + "You need to supply the name of a node to --output_node_names.") # Remove all the explicit device specifications for this node. This helps to # make the graph more portable. @@ -193,14 +193,15 @@ def freeze_graph_with_def_protos(input_graph_def, # tensors. Partition variables are Identity tensors that cannot be # handled by Saver. if has_partition_var: - print("Models containing partition variables cannot be converted " - "from checkpoint files. Please pass in a SavedModel using " - "the flag --input_saved_model_dir.") - return -1 + raise ValueError( + "Models containing partition variables cannot be converted " + "from checkpoint files. Please pass in a SavedModel using " + "the flag --input_saved_model_dir.") # Models that have been frozen previously do not contain Variables. elif _has_no_variables(sess): - print("No variables were found in this model. It is likely the model " - "was frozen previously. You cannot freeze a graph twice.") + raise ValueError( + "No variables were found in this model. It is likely the model " + "was frozen previously. You cannot freeze a graph twice.") return 0 else: raise e @@ -242,8 +243,7 @@ def freeze_graph_with_def_protos(input_graph_def, def _parse_input_graph_proto(input_graph, input_binary): """Parses input tensorflow graph into GraphDef proto.""" if not gfile.Exists(input_graph): - print("Input graph file '" + input_graph + "' does not exist!") - return -1 + raise IOError("Input graph file '" + input_graph + "' does not exist!") input_graph_def = graph_pb2.GraphDef() mode = "rb" if input_binary else "r" with gfile.GFile(input_graph, mode) as f: @@ -257,8 +257,7 @@ def _parse_input_graph_proto(input_graph, input_binary): def _parse_input_meta_graph_proto(input_graph, input_binary): """Parses input tensorflow graph into MetaGraphDef proto.""" if not gfile.Exists(input_graph): - print("Input meta graph file '" + input_graph + "' does not exist!") - return -1 + raise IOError("Input meta graph file '" + input_graph + "' does not exist!") input_meta_graph_def = MetaGraphDef() mode = "rb" if input_binary else "r" with gfile.GFile(input_graph, mode) as f: @@ -273,8 +272,7 @@ def _parse_input_meta_graph_proto(input_graph, input_binary): def _parse_input_saver_proto(input_saver, input_binary): """Parses input tensorflow Saver into SaverDef proto.""" if not gfile.Exists(input_saver): - print("Input saver file '" + input_saver + "' does not exist!") - return -1 + raise IOError("Input saver file '" + input_saver + "' does not exist!") mode = "rb" if input_binary else "r" with gfile.GFile(input_saver, mode) as f: saver_def = saver_pb2.SaverDef() @@ -369,9 +367,8 @@ def main(unused_args, flags): elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2 else: - print("Invalid checkpoint version (must be '1' or '2'): %d" % - flags.checkpoint_version) - return -1 + raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" % + flags.checkpoint_version) freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, flags.input_checkpoint, flags.output_node_names, flags.restore_op_name, flags.filename_tensor_name, @@ -380,7 +377,9 @@ def main(unused_args, flags): flags.input_meta_graph, flags.input_saved_model_dir, flags.saved_model_tags, checkpoint_version) + def run_main(): + """Main function of freeze_graph.""" parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( @@ -487,5 +486,6 @@ def run_main(): my_main = lambda unused_args: main(unused_args, flags) app.run(main=my_main, argv=[sys.argv[0]] + unparsed) -if __name__ == '__main__': + +if __name__ == "__main__": run_main() diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index d7edf4ec65d..0d054c00d33 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -316,17 +316,17 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): output_node_names = "save/restore_all" output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) - return_value = freeze_graph.freeze_graph_with_def_protos( - input_graph_def=sess.graph_def, - input_saver_def=None, - input_checkpoint=checkpoint_path, - output_node_names=output_node_names, - restore_op_name="save/restore_all", # default value - filename_tensor_name="save/Const:0", # default value - output_graph=output_graph_path, - clear_devices=False, - initializer_nodes="") - self.assertTrue(return_value, -1) + with self.assertRaises(ValueError): + freeze_graph.freeze_graph_with_def_protos( + input_graph_def=sess.graph_def, + input_saver_def=None, + input_checkpoint=checkpoint_path, + output_node_names=output_node_names, + restore_op_name="save/restore_all", # default value + filename_tensor_name="save/Const:0", # default value + output_graph=output_graph_path, + clear_devices=False, + initializer_nodes="") if __name__ == "__main__": diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.pbtxt index 354a7086d60..e4250ac75d4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.pbtxt @@ -8,4 +8,8 @@ tf_module { name: "convert_op_hints_to_stubs" argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \' instance>\'], " } + member_method { + name: "get_potentially_supported_ops" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.pbtxt index 354a7086d60..e4250ac75d4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.pbtxt @@ -8,4 +8,8 @@ tf_module { name: "convert_op_hints_to_stubs" argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \' instance>\'], " } + member_method { + name: "get_potentially_supported_ops" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } }