Cleanup: Avoid copying dictionary multiple times

PiperOrigin-RevId: 372899187
Change-Id: Ie3b21b089ca7b9df327961ac4109358f66a64824
This commit is contained in:
Thai Nguyen
2021-05-10 04:23:05 -07:00
committed by TensorFlower Gardener
parent 20a908fd41
commit 40caef4454

View File

@@ -434,8 +434,8 @@ class TFLiteConverterBase(object):
self._saved_model_tags = None
self._saved_model_version = 0
self._saved_model_exported_names = []
# Variable for converter metrics.
self._tflite_metrics = metrics.TFLiteConverterMetrics()
self._collected_converter_params = {}
self._experimental_disable_batchmatmul_unfold = False
self._experimental_lower_tensor_list_ops = True
@@ -596,12 +596,11 @@ class TFLiteConverterBase(object):
self._tflite_metrics.export_metrics()
def _save_conversion_params_metric(self,
converter_params,
graph_def=None,
inference_type=None,
inference_input_type=None):
"""Set conversion parameter metrics."""
converter_kwargs = converter_params.copy()
converter_kwargs = self._collected_converter_params
converter_kwargs.update(self._get_base_converter_args())
# Optimization parameters.
@@ -668,6 +667,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
super(TFLiteConverterBaseV2, self).__init__()
self.inference_input_type = _dtypes.float32
self.inference_output_type = _dtypes.float32
self._collected_converter_params.update({"api_version": 2})
def _validate_inference_input_output_types(self, quant_mode):
"""Validate inference_input_type and inference_output_type flags."""
@@ -688,20 +688,6 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
raise ValueError("The inference_input_type and inference_output_type "
"must be tf.float32.")
def _save_conversion_params_metric(self,
converter_params,
graph_def=None,
inference_type=None,
inference_input_type=None):
converter_kwargs = converter_params.copy()
converter_kwargs.update({
"api_version": 2,
})
super(TFLiteConverterBaseV2,
self)._save_conversion_params_metric(converter_kwargs, graph_def,
inference_type,
inference_input_type)
def convert(self, graph_def, input_tensors, output_tensors):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -722,7 +708,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
Invalid quantization parameters.
"""
# Update conversion params with graph_def.
self._save_conversion_params_metric({}, graph_def)
self._save_conversion_params_metric(graph_def)
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
self.representative_dataset, graph_def)
@@ -826,12 +812,8 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
Input shape is not specified.
Invalid quantization parameters.
"""
converter_kwargs = {
"enable_tflite_resource_variables":
self._enable_tflite_resource_variables
}
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric(converter_kwargs)
self._save_conversion_params_metric()
graph = _ops.Graph()
saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
saved_model.load_graph(graph, tags=self._saved_model_tags)
@@ -873,13 +855,17 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
meta_graph.graph_def)
# Update conversion params with graph_def.
self._save_conversion_params_metric(converter_kwargs, meta_graph.graph_def)
self._save_conversion_params_metric(meta_graph.graph_def)
# Get quantization options and do some sanity checks.
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
self.representative_dataset,
meta_graph.graph_def)
self._validate_inference_input_output_types(quant_mode)
converter_kwargs = {
"enable_tflite_resource_variables":
self._enable_tflite_resource_variables
}
converter_kwargs.update(self._get_base_converter_args())
converter_kwargs.update(quant_mode.converter_flags())
@@ -967,7 +953,7 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
Invalid quantization parameters.
"""
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
saved_model_convert_result = self._convert_as_saved_model()
if saved_model_convert_result:
self._increase_conversion_success_metric(saved_model_convert_result)
@@ -1063,7 +1049,7 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
"under development.")
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
frozen_func, graph_def = (
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
self._funcs[0], lower_control_flow=False))
@@ -1528,9 +1514,8 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
return False
return True
def _save_conversion_params_metric(self, converter_params):
converter_kwargs = converter_params.copy()
converter_kwargs.update({
def _save_conversion_params_metric(self):
self._collected_converter_params.update({
"output_format": self.output_format,
"default_ranges_stats": self.default_ranges_stats,
"drop_control_dependency": self.drop_control_dependency,
@@ -1542,8 +1527,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
"api_version": 1,
})
super(TFLiteConverterBaseV1,
self)._save_conversion_params_metric(converter_kwargs,
self._graph_def,
self)._save_conversion_params_metric(self._graph_def,
self.inference_type,
self.inference_input_type)
@@ -1608,7 +1592,7 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
None value for dimension in input_tensor.
"""
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
result = super(TFLiteSavedModelConverter, self).convert()
self._increase_conversion_success_metric(result)
return result
@@ -1719,7 +1703,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
self._output_tensors = result[2]
self._debug_info_func = _build_debug_info_func(result[3])
# Update conversion params with graph_def.
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
return super(TFLiteKerasModelConverter, self).convert()
finally:
shutil.rmtree(temp_dir, True)
@@ -1737,7 +1721,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
None value for dimension in input_tensor.
"""
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
saved_model_convert_result = self._convert_as_saved_model()
if saved_model_convert_result:
self._increase_conversion_success_metric(saved_model_convert_result)
@@ -1816,7 +1800,7 @@ class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
None value for dimension in input_tensor.
"""
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric({})
self._save_conversion_params_metric()
result = super(TFLiteFrozenGraphConverter, self).convert()
self._increase_conversion_success_metric(result)
return result