mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Add test case for testing keras conversion with location information related to
saved model MLIR importer PiperOrigin-RevId: 384693660 Change-Id: I82d3f4eb8a796eeb4ae28701fb52beb5ceca5380
This commit is contained in:
committed by
TensorFlower Gardener
parent
cd064f1309
commit
6645320989
@@ -20,6 +20,7 @@ import tempfile
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -360,7 +361,8 @@ def mock_ngrams(data, width, axis=-1, string_separator=' ', name=None):
|
||||
return func(data)
|
||||
|
||||
|
||||
class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
|
||||
class ConverterErrorMetricTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
"""Testing conversion error metric."""
|
||||
|
||||
def setUp(self):
|
||||
@@ -397,7 +399,7 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsNone(tflite_model)
|
||||
except ConverterError as converter_error:
|
||||
# pylint: disable=g-assert-in-except
|
||||
self.assertEqual(len(converter_error.errors), 1)
|
||||
self.assertLen(converter_error.errors, 1)
|
||||
location = converter_error.errors[0].location
|
||||
self.assertEqual(location.type, expected_type)
|
||||
|
||||
@@ -549,7 +551,12 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
|
||||
'tensorflow/lite/python/metrics_nonportable_test.py',
|
||||
])
|
||||
|
||||
def test_location_from_keras_model(self):
|
||||
@parameterized.named_parameters(
|
||||
('_WithoutLoweringToSavedModel', False, 'keras/engine/functional.py'),
|
||||
('_WithLoweringToSavedModel', True,
|
||||
'tensorflow/lite/python/metrics_nonportable_test.py'))
|
||||
def test_location_from_keras_model(self, lower_to_saved_model,
|
||||
expected_source):
|
||||
input_tensor1 = tf.keras.layers.Input(
|
||||
shape=[None, None, 2, 3, 3], dtype=tf.complex64)
|
||||
input_tensor2 = tf.keras.layers.Input(
|
||||
@@ -563,13 +570,12 @@ class ConverterErrorMetricTest(test_util.TensorFlowTestCase):
|
||||
metrics=['accuracy'])
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
converter.experimental_lower_to_saved_model = lower_to_saved_model
|
||||
# The location does not contain callsite to the current file.
|
||||
self.convert_and_check_location_info(
|
||||
converter,
|
||||
converter_error_data_pb2.ConverterErrorData.CALLSITELOC,
|
||||
expected_sources=[
|
||||
'keras/engine/functional.py',
|
||||
])
|
||||
expected_sources=[expected_source])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user