Add test case for testing keras conversion with location information related to

saved model MLIR importer

PiperOrigin-RevId: 384693660
Change-Id: I82d3f4eb8a796eeb4ae28701fb52beb5ceca5380
This commit is contained in:
Jaesung Chung
2021-07-14 07:24:14 -07:00
committed by TensorFlower Gardener
parent cd064f1309
commit 6645320989

View File

@@ -20,6 +20,7 @@ import tempfile
import time
from unittest import mock
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -360,7 +361,8 @@ def mock_ngrams(data, width, axis=-1, string_separator=' ', name=None):
return func(data)
class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
class ConverterErrorMetricTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
"""Testing conversion error metric."""
def setUp(self):
@@ -397,7 +399,7 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
self.assertIsNone(tflite_model)
except ConverterError as converter_error:
# pylint: disable=g-assert-in-except
self.assertEqual(len(converter_error.errors), 1)
self.assertLen(converter_error.errors, 1)
location = converter_error.errors[0].location
self.assertEqual(location.type, expected_type)
@@ -549,7 +551,12 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
'tensorflow/lite/python/metrics_nonportable_test.py',
])
def test_location_from_keras_model(self):
@parameterized.named_parameters(
('_WithoutLoweringToSavedModel', False, 'keras/engine/functional.py'),
('_WithLoweringToSavedModel', True,
'tensorflow/lite/python/metrics_nonportable_test.py'))
def test_location_from_keras_model(self, lower_to_saved_model,
expected_source):
input_tensor1 = tf.keras.layers.Input(
shape=[None, None, 2, 3, 3], dtype=tf.complex64)
input_tensor2 = tf.keras.layers.Input(
@@ -563,13 +570,12 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
metrics=['accuracy'])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_lower_to_saved_model = lower_to_saved_model
# The location does not contain callsite to the current file.
self.convert_and_check_location_info(
converter,
converter_error_data_pb2.ConverterErrorData.CALLSITELOC,
expected_sources=[
'keras/engine/functional.py',
])
expected_sources=[expected_source])
if __name__ == '__main__':