From 40caef44549a199eaac327b673fa862194b66fc4 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Mon, 10 May 2021 04:23:05 -0700 Subject: [PATCH] Cleanup: Avoid copying dictionary multiple times PiperOrigin-RevId: 372899187 Change-Id: Ie3b21b089ca7b9df327961ac4109358f66a64824 --- tensorflow/lite/python/lite.py | 54 ++++++++++++---------------------- 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 6a1795a77d9..919e413a149 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -434,8 +434,8 @@ class TFLiteConverterBase(object): self._saved_model_tags = None self._saved_model_version = 0 self._saved_model_exported_names = [] - # Variable for converter metrics. self._tflite_metrics = metrics.TFLiteConverterMetrics() + self._collected_converter_params = {} self._experimental_disable_batchmatmul_unfold = False self._experimental_lower_tensor_list_ops = True @@ -596,12 +596,11 @@ class TFLiteConverterBase(object): self._tflite_metrics.export_metrics() def _save_conversion_params_metric(self, - converter_params, graph_def=None, inference_type=None, inference_input_type=None): """Set conversion parameter metrics.""" - converter_kwargs = converter_params.copy() + converter_kwargs = self._collected_converter_params converter_kwargs.update(self._get_base_converter_args()) # Optimization parameters. @@ -668,6 +667,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): super(TFLiteConverterBaseV2, self).__init__() self.inference_input_type = _dtypes.float32 self.inference_output_type = _dtypes.float32 + self._collected_converter_params.update({"api_version": 2}) def _validate_inference_input_output_types(self, quant_mode): """Validate inference_input_type and inference_output_type flags.""" @@ -688,20 +688,6 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): raise ValueError("The inference_input_type and inference_output_type " "must be tf.float32.") - def _save_conversion_params_metric(self, - converter_params, - graph_def=None, - inference_type=None, - inference_input_type=None): - converter_kwargs = converter_params.copy() - converter_kwargs.update({ - "api_version": 2, - }) - super(TFLiteConverterBaseV2, - self)._save_conversion_params_metric(converter_kwargs, graph_def, - inference_type, - inference_input_type) - def convert(self, graph_def, input_tensors, output_tensors): """Converts a TensorFlow GraphDef based on instance variables. @@ -722,7 +708,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): Invalid quantization parameters. """ # Update conversion params with graph_def. - self._save_conversion_params_metric({}, graph_def) + self._save_conversion_params_metric(graph_def) quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, graph_def) @@ -826,12 +812,8 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): Input shape is not specified. Invalid quantization parameters. """ - converter_kwargs = { - "enable_tflite_resource_variables": - self._enable_tflite_resource_variables - } self._increase_conversion_attempt_metric() - self._save_conversion_params_metric(converter_kwargs) + self._save_conversion_params_metric() graph = _ops.Graph() saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir) saved_model.load_graph(graph, tags=self._saved_model_tags) @@ -873,13 +855,17 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): meta_graph.graph_def) # Update conversion params with graph_def. - self._save_conversion_params_metric(converter_kwargs, meta_graph.graph_def) + self._save_conversion_params_metric(meta_graph.graph_def) # Get quantization options and do some sanity checks. quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, meta_graph.graph_def) self._validate_inference_input_output_types(quant_mode) + converter_kwargs = { + "enable_tflite_resource_variables": + self._enable_tflite_resource_variables + } converter_kwargs.update(self._get_base_converter_args()) converter_kwargs.update(quant_mode.converter_flags()) @@ -967,7 +953,7 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): Invalid quantization parameters. """ self._increase_conversion_attempt_metric() - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() saved_model_convert_result = self._convert_as_saved_model() if saved_model_convert_result: self._increase_conversion_success_metric(saved_model_convert_result) @@ -1063,7 +1049,7 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2): "under development.") self._increase_conversion_attempt_metric() - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() frozen_func, graph_def = ( _convert_to_constants.convert_variables_to_constants_v2_as_graph( self._funcs[0], lower_control_flow=False)) @@ -1528,9 +1514,8 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): return False return True - def _save_conversion_params_metric(self, converter_params): - converter_kwargs = converter_params.copy() - converter_kwargs.update({ + def _save_conversion_params_metric(self): + self._collected_converter_params.update({ "output_format": self.output_format, "default_ranges_stats": self.default_ranges_stats, "drop_control_dependency": self.drop_control_dependency, @@ -1542,8 +1527,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): "api_version": 1, }) super(TFLiteConverterBaseV1, - self)._save_conversion_params_metric(converter_kwargs, - self._graph_def, + self)._save_conversion_params_metric(self._graph_def, self.inference_type, self.inference_input_type) @@ -1608,7 +1592,7 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1): None value for dimension in input_tensor. """ self._increase_conversion_attempt_metric() - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() result = super(TFLiteSavedModelConverter, self).convert() self._increase_conversion_success_metric(result) return result @@ -1719,7 +1703,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1): self._output_tensors = result[2] self._debug_info_func = _build_debug_info_func(result[3]) # Update conversion params with graph_def. - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() return super(TFLiteKerasModelConverter, self).convert() finally: shutil.rmtree(temp_dir, True) @@ -1737,7 +1721,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1): None value for dimension in input_tensor. """ self._increase_conversion_attempt_metric() - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() saved_model_convert_result = self._convert_as_saved_model() if saved_model_convert_result: self._increase_conversion_success_metric(saved_model_convert_result) @@ -1816,7 +1800,7 @@ class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1): None value for dimension in input_tensor. """ self._increase_conversion_attempt_metric() - self._save_conversion_params_metric({}) + self._save_conversion_params_metric() result = super(TFLiteFrozenGraphConverter, self).convert() self._increase_conversion_success_metric(result) return result