mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Cleanup: Avoid copying dictionary multiple times
PiperOrigin-RevId: 372899187 Change-Id: Ie3b21b089ca7b9df327961ac4109358f66a64824
This commit is contained in:
committed by
TensorFlower Gardener
parent
20a908fd41
commit
40caef4454
@@ -434,8 +434,8 @@ class TFLiteConverterBase(object):
|
||||
self._saved_model_tags = None
|
||||
self._saved_model_version = 0
|
||||
self._saved_model_exported_names = []
|
||||
# Variable for converter metrics.
|
||||
self._tflite_metrics = metrics.TFLiteConverterMetrics()
|
||||
self._collected_converter_params = {}
|
||||
self._experimental_disable_batchmatmul_unfold = False
|
||||
self._experimental_lower_tensor_list_ops = True
|
||||
|
||||
@@ -596,12 +596,11 @@ class TFLiteConverterBase(object):
|
||||
self._tflite_metrics.export_metrics()
|
||||
|
||||
def _save_conversion_params_metric(self,
|
||||
converter_params,
|
||||
graph_def=None,
|
||||
inference_type=None,
|
||||
inference_input_type=None):
|
||||
"""Set conversion parameter metrics."""
|
||||
converter_kwargs = converter_params.copy()
|
||||
converter_kwargs = self._collected_converter_params
|
||||
converter_kwargs.update(self._get_base_converter_args())
|
||||
|
||||
# Optimization parameters.
|
||||
@@ -668,6 +667,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
|
||||
super(TFLiteConverterBaseV2, self).__init__()
|
||||
self.inference_input_type = _dtypes.float32
|
||||
self.inference_output_type = _dtypes.float32
|
||||
self._collected_converter_params.update({"api_version": 2})
|
||||
|
||||
def _validate_inference_input_output_types(self, quant_mode):
|
||||
"""Validate inference_input_type and inference_output_type flags."""
|
||||
@@ -688,20 +688,6 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
|
||||
raise ValueError("The inference_input_type and inference_output_type "
|
||||
"must be tf.float32.")
|
||||
|
||||
def _save_conversion_params_metric(self,
|
||||
converter_params,
|
||||
graph_def=None,
|
||||
inference_type=None,
|
||||
inference_input_type=None):
|
||||
converter_kwargs = converter_params.copy()
|
||||
converter_kwargs.update({
|
||||
"api_version": 2,
|
||||
})
|
||||
super(TFLiteConverterBaseV2,
|
||||
self)._save_conversion_params_metric(converter_kwargs, graph_def,
|
||||
inference_type,
|
||||
inference_input_type)
|
||||
|
||||
def convert(self, graph_def, input_tensors, output_tensors):
|
||||
"""Converts a TensorFlow GraphDef based on instance variables.
|
||||
|
||||
@@ -722,7 +708,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
|
||||
Invalid quantization parameters.
|
||||
"""
|
||||
# Update conversion params with graph_def.
|
||||
self._save_conversion_params_metric({}, graph_def)
|
||||
self._save_conversion_params_metric(graph_def)
|
||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||
self.representative_dataset, graph_def)
|
||||
|
||||
@@ -826,12 +812,8 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
|
||||
Input shape is not specified.
|
||||
Invalid quantization parameters.
|
||||
"""
|
||||
converter_kwargs = {
|
||||
"enable_tflite_resource_variables":
|
||||
self._enable_tflite_resource_variables
|
||||
}
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric(converter_kwargs)
|
||||
self._save_conversion_params_metric()
|
||||
graph = _ops.Graph()
|
||||
saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
|
||||
saved_model.load_graph(graph, tags=self._saved_model_tags)
|
||||
@@ -873,13 +855,17 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
|
||||
meta_graph.graph_def)
|
||||
|
||||
# Update conversion params with graph_def.
|
||||
self._save_conversion_params_metric(converter_kwargs, meta_graph.graph_def)
|
||||
self._save_conversion_params_metric(meta_graph.graph_def)
|
||||
# Get quantization options and do some sanity checks.
|
||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||
self.representative_dataset,
|
||||
meta_graph.graph_def)
|
||||
self._validate_inference_input_output_types(quant_mode)
|
||||
|
||||
converter_kwargs = {
|
||||
"enable_tflite_resource_variables":
|
||||
self._enable_tflite_resource_variables
|
||||
}
|
||||
converter_kwargs.update(self._get_base_converter_args())
|
||||
converter_kwargs.update(quant_mode.converter_flags())
|
||||
|
||||
@@ -967,7 +953,7 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
||||
Invalid quantization parameters.
|
||||
"""
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
saved_model_convert_result = self._convert_as_saved_model()
|
||||
if saved_model_convert_result:
|
||||
self._increase_conversion_success_metric(saved_model_convert_result)
|
||||
@@ -1063,7 +1049,7 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
|
||||
"under development.")
|
||||
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
frozen_func, graph_def = (
|
||||
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
||||
self._funcs[0], lower_control_flow=False))
|
||||
@@ -1528,9 +1514,8 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _save_conversion_params_metric(self, converter_params):
|
||||
converter_kwargs = converter_params.copy()
|
||||
converter_kwargs.update({
|
||||
def _save_conversion_params_metric(self):
|
||||
self._collected_converter_params.update({
|
||||
"output_format": self.output_format,
|
||||
"default_ranges_stats": self.default_ranges_stats,
|
||||
"drop_control_dependency": self.drop_control_dependency,
|
||||
@@ -1542,8 +1527,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
|
||||
"api_version": 1,
|
||||
})
|
||||
super(TFLiteConverterBaseV1,
|
||||
self)._save_conversion_params_metric(converter_kwargs,
|
||||
self._graph_def,
|
||||
self)._save_conversion_params_metric(self._graph_def,
|
||||
self.inference_type,
|
||||
self.inference_input_type)
|
||||
|
||||
@@ -1608,7 +1592,7 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
result = super(TFLiteSavedModelConverter, self).convert()
|
||||
self._increase_conversion_success_metric(result)
|
||||
return result
|
||||
@@ -1719,7 +1703,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||
self._output_tensors = result[2]
|
||||
self._debug_info_func = _build_debug_info_func(result[3])
|
||||
# Update conversion params with graph_def.
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
return super(TFLiteKerasModelConverter, self).convert()
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, True)
|
||||
@@ -1737,7 +1721,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
saved_model_convert_result = self._convert_as_saved_model()
|
||||
if saved_model_convert_result:
|
||||
self._increase_conversion_success_metric(saved_model_convert_result)
|
||||
@@ -1816,7 +1800,7 @@ class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
self._increase_conversion_attempt_metric()
|
||||
self._save_conversion_params_metric({})
|
||||
self._save_conversion_params_metric()
|
||||
result = super(TFLiteFrozenGraphConverter, self).convert()
|
||||
self._increase_conversion_success_metric(result)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user