mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Remove object metadata when saving SavedModel.
This change also fixes a few bugs when loading the metadata file, and fixes Keras tests so that they use model.save instead of tf.saved_model.save PiperOrigin-RevId: 378963258 Change-Id: I261d9ee80b3982a253a9479bc312e4fce6165981
This commit is contained in:
committed by
TensorFlower Gardener
parent
e269917bc3
commit
0733c41e5c
@@ -58,7 +58,8 @@
|
||||
enable this feature. The documentation in [Advanced autodiff]
|
||||
(https://www.tensorflow.org/guide/advanced_autodiff#custom_gradients)
|
||||
has been updated.
|
||||
|
||||
* Object metadata has now been deprecated and no longer saved to the
|
||||
SavedModel.
|
||||
* TF Core:
|
||||
* Added `tf.config.experimental.reset_memory_stats` to reset the tracked
|
||||
peak memory returned by `tf.config.experimental.get_memory_info`.
|
||||
|
||||
@@ -77,11 +77,12 @@ message SavedUserObject {
|
||||
string identifier = 1;
|
||||
// Version information from the producer of this SavedUserObject.
|
||||
VersionDef version = 2;
|
||||
// Metadata for deserializing this object.
|
||||
//
|
||||
// Deprecated! At the time of deprecation, Keras was the only user of this
|
||||
// field, and its saving and loading code will be updated shortly.
|
||||
// Please save your application-specific metadata to separate file
|
||||
// Initialization-related metadata.
|
||||
string metadata = 3;
|
||||
// Please save your application-specific metadata to a separate file.
|
||||
string metadata = 3 [deprecated = true];
|
||||
}
|
||||
|
||||
// A SavedAsset points to an asset in the MetaGraph.
|
||||
|
||||
@@ -3048,6 +3048,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
|
||||
@property
|
||||
def _tracking_metadata(self):
|
||||
"""Info about this layer to be saved into the SavedModel."""
|
||||
return self._trackable_saved_model_saver.tracking_metadata
|
||||
|
||||
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||
|
||||
@@ -215,8 +215,8 @@ class TpuStrategyTest(tf.test.TestCase):
|
||||
serving_fn = create_serving_signature(model)
|
||||
|
||||
saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
tf.saved_model.save(
|
||||
model, saved_model_dir, signatures={"serving_default": serving_fn})
|
||||
model.save(saved_model_dir, save_format="tf",
|
||||
signatures={"serving_default": serving_fn})
|
||||
|
||||
# Test the saved_model.
|
||||
loaded_serving_fn = tf.keras.models.load_model(
|
||||
|
||||
@@ -195,6 +195,12 @@ def _read_legacy_metadata(object_graph_def, metadata):
|
||||
for node_id, proto in enumerate(object_graph_def.nodes):
|
||||
if (proto.WhichOneof('kind') == 'user_object' and
|
||||
proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
|
||||
if not proto.user_object.metadata:
|
||||
raise ValueError('Unable to create a Keras model from this SavedModel. '
|
||||
'This SavedModel was created with '
|
||||
'`tf.saved_model.save`, and lacks the Keras metadata.'
|
||||
'Please save your Keras model by calling `model.save`'
|
||||
'or `tf.keras.models.save_model`.')
|
||||
metadata.nodes.add(
|
||||
node_id=node_id,
|
||||
node_path=node_paths[node_id],
|
||||
@@ -253,7 +259,7 @@ class KerasObjectLoader(object):
|
||||
"""
|
||||
|
||||
def __init__(self, metadata, object_graph_def):
|
||||
self._metadata = metadata
|
||||
self._metadata = {x.node_id: x for x in metadata.nodes}
|
||||
self._proto = object_graph_def
|
||||
|
||||
self._node_paths = {node_data.node_id: node_data.node_path
|
||||
@@ -309,7 +315,7 @@ class KerasObjectLoader(object):
|
||||
self._traversed_nodes_from_config.add(node_id)
|
||||
obj._maybe_initialize_trackable()
|
||||
if isinstance(obj, base_layer.Layer) and not obj.built:
|
||||
metadata = json_utils.decode(proto.user_object.metadata)
|
||||
metadata = json_utils.decode(self._metadata[node_id].metadata)
|
||||
self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))
|
||||
|
||||
# Create list of all possible children
|
||||
@@ -378,7 +384,7 @@ class KerasObjectLoader(object):
|
||||
# and layers will create the metric when initialized (this avoids wasting
|
||||
# time by creating objects multiple times).
|
||||
metric_list = []
|
||||
for node_metadata in self._metadata.nodes:
|
||||
for node_metadata in self._metadata.values():
|
||||
if node_metadata.identifier == constants.METRIC_IDENTIFIER:
|
||||
metric_list.append(node_metadata)
|
||||
continue
|
||||
@@ -666,8 +672,7 @@ class KerasObjectLoader(object):
|
||||
|
||||
def _reconstruct_model(self, model_id, model, layers):
|
||||
"""Reconstructs the network structure."""
|
||||
config = json_utils.decode(
|
||||
self._proto.nodes[model_id].user_object.metadata)['config']
|
||||
config = json_utils.decode(self._metadata[model_id].metadata)['config']
|
||||
|
||||
# Set up model inputs
|
||||
if model.inputs:
|
||||
|
||||
@@ -23,6 +23,7 @@ Tests that focus on the model structure should go in revive_test.py
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@@ -151,7 +152,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
|
||||
def _save_and_load(self, model):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
return loaded
|
||||
|
||||
@@ -194,7 +195,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
def test_trainable_weights(self):
|
||||
"""Tests that trainable status of individual weights is preserved."""
|
||||
layer = keras.layers.Dense(4, name='custom_layer')
|
||||
layer.build([3,])
|
||||
layer.build([None, 3])
|
||||
layer.add_weight(
|
||||
'extra_weight', shape=[],
|
||||
initializer=init_ops.constant_initializer(11),
|
||||
@@ -203,12 +204,15 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
'extra_weight_2', shape=[],
|
||||
initializer=init_ops.constant_initializer(12),
|
||||
trainable=False)
|
||||
model = keras.Sequential([keras.Input([3,]), layer])
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
self.evaluate(variables.variables_initializer(layer.variables))
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded_model = keras_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded_model.variables))
|
||||
|
||||
loaded = loaded_model.layers[-1]
|
||||
|
||||
equal_attrs = ['name', '_expects_training_arg', 'trainable']
|
||||
for attr in equal_attrs:
|
||||
@@ -259,7 +263,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
previous_losses = model.losses[:]
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
with previous_losses[0].graph.as_default():
|
||||
# If we try to compare symbolic Tensors in eager mode assertAllEqual will
|
||||
@@ -270,15 +274,18 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
# Test that eager losses are maintained.
|
||||
model(input_arr) # Calls model eagerly, creating eager losses.
|
||||
previous_losses = model.losses[:]
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
def test_layer_with_learning_phase(self):
|
||||
layer = LayerWithLearningPhase()
|
||||
layer.build([None, None])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer], input_shape=[None], model_type='functional')
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded_model = keras_load.load(saved_model_dir)
|
||||
loaded = loaded_model.layers[-1]
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
|
||||
# Run the layer, and use the keras backend learning phase
|
||||
@@ -306,7 +313,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
model.predict(np.random.random((1, 3)).astype(np.float32))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
loaded = tf_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
@@ -338,7 +345,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
# Compile and save model.
|
||||
model.compile('rmsprop', 'mse')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
actual_predict = loaded.predict(input_arr)
|
||||
@@ -364,7 +371,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
super(LayerWithNestedSpec, self).__init__()
|
||||
self.input_spec = {
|
||||
'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
|
||||
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
|
||||
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='int32')}
|
||||
|
||||
@property
|
||||
def _use_input_spec_as_call_signature(self):
|
||||
@@ -372,12 +379,17 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
|
||||
layer = LayerWithNestedSpec()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer], model_type='subclass')
|
||||
model({'a': constant_op.constant([[2, 4]]),
|
||||
'b': array_ops.ones([1, 2, 3], dtype=dtypes.int32)})
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded_model = keras_load.load(saved_model_dir)
|
||||
loaded = loaded_model.layers[-1]
|
||||
self.assertEqual(3, loaded.input_spec['a'].max_ndim)
|
||||
self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
|
||||
self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
|
||||
self.assertEqual('float16', loaded.input_spec['b'].dtype)
|
||||
self.assertEqual('int32', loaded.input_spec['b'].dtype)
|
||||
|
||||
def test_must_restore_from_config_fails_if_layer_is_not_in_scope(self):
|
||||
|
||||
@@ -386,7 +398,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
|
||||
layer = LayerThatShouldFailIfNotAdded()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer], input_shape=[3], model_type='functional')
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
with self.assertRaisesRegex(RuntimeError, 'Unable to restore a layer of'):
|
||||
_ = keras_load.load(saved_model_dir)
|
||||
|
||||
@@ -396,8 +410,10 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
_must_restore_from_config = True
|
||||
|
||||
layer = LayerThatShouldFailIfNotAdded()
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer], input_shape=[3], model_type='functional')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
with generic_utils.CustomObjectScope(
|
||||
{'LayerThatShouldFailIfNotAdded': LayerThatShouldFailIfNotAdded}):
|
||||
_ = keras_load.load(saved_model_dir)
|
||||
@@ -405,7 +421,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
def test_must_restore_from_config_registration(self):
|
||||
layer = GlobalLayerThatShouldFailIfNotAdded()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer], input_shape=[3], model_type='functional')
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
_ = keras_load.load(saved_model_dir)
|
||||
|
||||
def test_multi_input_model(self):
|
||||
@@ -454,16 +472,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
# TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
|
||||
# metadata warning.
|
||||
# with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
|
||||
# Assert that saving does not log deprecation warnings
|
||||
# (even if it needs to set learning phase for compat reasons)
|
||||
# if context.executing_eagerly():
|
||||
# self.assertNotIn('deprecated', captured_logs.contents())
|
||||
if context.executing_eagerly():
|
||||
self.assertNotIn('deprecated', captured_logs.contents())
|
||||
|
||||
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
|
||||
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
|
||||
@@ -806,7 +822,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
np.zeros((batch, 64)).astype('float32'))
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
loaded_layer = loaded.layers[1]
|
||||
@@ -834,7 +850,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
del model
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
@@ -874,7 +890,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
model = keras.Model(f_inputs, out)
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
@@ -973,7 +989,7 @@ class TestSavedModelFormat(test.TestCase):
|
||||
inp = constant_op.constant([[1.0]])
|
||||
model(inp)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.assertAllEqual([[1.0]], self.evaluate(loaded(inp)))
|
||||
@@ -1026,6 +1042,13 @@ class TestSavedModelFormat(test.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'I said do not trace'):
|
||||
loaded.attached_layer(constant_op.constant([1.]))
|
||||
|
||||
def test_load_non_keras_saved_model(self):
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
with self.assertRaisesRegex(ValueError, 'Unable to create a Keras model'):
|
||||
keras_load.load(saved_model_dir)
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@@ -1124,6 +1147,16 @@ class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(previous_losses, layer.losses)
|
||||
|
||||
|
||||
@generic_utils.register_keras_serializable('Testing')
|
||||
class CustomMeanMetric(keras.metrics.Mean):
|
||||
|
||||
def update_state(self, *args): # pylint: disable=useless-super-delegation
|
||||
# Sometimes built-in metrics return an op in update_state. Custom
|
||||
# metrics don't support returning ops, so wrap the update_state method
|
||||
# while returning nothing.
|
||||
super(CustomMeanMetric, self).update_state(*args)
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@@ -1145,8 +1178,12 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
shape=(1, 5),
|
||||
test_sample_weight=True):
|
||||
with self.cached_session():
|
||||
tf_save.save(metric, save_dir)
|
||||
loaded = keras_load.load(save_dir)
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[keras.layers.Layer()], input_shape=[3], model_type='functional')
|
||||
model.saved_metric = metric
|
||||
model.save(save_dir, save_format='tf')
|
||||
loaded_model = keras_load.load(save_dir)
|
||||
loaded = loaded_model.saved_metric
|
||||
self.evaluate([v.initializer for v in loaded.variables])
|
||||
self.assertEqual(metric.name, loaded.name)
|
||||
self.assertEqual(metric.dtype, loaded.dtype)
|
||||
@@ -1236,15 +1273,6 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_registered_custom_metric(self):
|
||||
|
||||
@generic_utils.register_keras_serializable('Testing')
|
||||
class CustomMeanMetric(keras.metrics.Mean):
|
||||
|
||||
def update_state(self, *args): # pylint: disable=useless-super-delegation
|
||||
# Sometimes built-in metrics return an op in update_state. Custom
|
||||
# metrics don't support returning ops, so wrap the update_state method
|
||||
# while returning nothing.
|
||||
super(CustomMeanMetric, self).update_state(*args)
|
||||
|
||||
with self.cached_session():
|
||||
metric = CustomMeanMetric()
|
||||
save_dir = self._save_model_dir('first_save')
|
||||
@@ -1298,7 +1326,7 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
metrics=[CustomMetric(), zero_metric])
|
||||
model.fit(x, y)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'custom_objects'):
|
||||
keras_load.load(saved_model_dir)
|
||||
|
||||
@@ -572,7 +572,8 @@ def get_model_from_layers(model_layers,
|
||||
input_dtype=None,
|
||||
name=None,
|
||||
input_ragged=None,
|
||||
input_sparse=None):
|
||||
input_sparse=None,
|
||||
model_type=None):
|
||||
"""Builds a model from a sequence of layers.
|
||||
|
||||
Args:
|
||||
@@ -582,12 +583,14 @@ def get_model_from_layers(model_layers,
|
||||
name: Name for the model.
|
||||
input_ragged: Boolean, whether the input data is a ragged tensor.
|
||||
input_sparse: Boolean, whether the input data is a sparse tensor.
|
||||
model_type: One of "subclass", "subclass_custom_build", "sequential", or
|
||||
"functional". When None, defaults to `get_model_type`.
|
||||
|
||||
Returns:
|
||||
A Keras model.
|
||||
"""
|
||||
|
||||
model_type = get_model_type()
|
||||
if model_type is None:
|
||||
model_type = get_model_type()
|
||||
if model_type == 'subclass':
|
||||
inputs = None
|
||||
if input_ragged or input_sparse:
|
||||
|
||||
@@ -27,7 +27,6 @@ from tensorflow.python.keras.layers.preprocessing import string_lookup
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
|
||||
|
||||
class DistributeKplTestUtils(test.TestCase):
|
||||
@@ -170,8 +169,8 @@ class DistributeKplTestUtils(test.TestCase):
|
||||
label_inverse_lookup_layer)
|
||||
|
||||
saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
tf_save.save(
|
||||
model, saved_model_dir, signatures={"serving_default": serving_fn})
|
||||
model.save(saved_model_dir, save_format="tf",
|
||||
signatures={"serving_default": serving_fn})
|
||||
|
||||
# Test the saved_model.
|
||||
loaded_serving_fn = keras.saving.save.load_model(
|
||||
|
||||
@@ -927,17 +927,14 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
|
||||
if serialized is not None:
|
||||
proto.concrete_functions[name].CopyFrom(serialized)
|
||||
|
||||
saved_object_metadata = False
|
||||
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
||||
has_saved_object_metadata = _write_object_proto(
|
||||
obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
|
||||
saved_object_metadata = saved_object_metadata or has_saved_object_metadata
|
||||
return proto, saved_object_metadata
|
||||
_write_object_proto(obj, obj_proto, asset_file_def_index,
|
||||
saveable_view.function_name_map)
|
||||
return proto
|
||||
|
||||
|
||||
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
"""Saves an object into SavedObject proto."""
|
||||
has_saved_object_metadata = False # The metadata field will be deprecated.
|
||||
if isinstance(obj, tracking.Asset):
|
||||
proto.asset.SetInParent()
|
||||
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
||||
@@ -963,19 +960,13 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
if registered_type_proto is None:
|
||||
# Fallback for types with no matching registration
|
||||
# pylint:disable=protected-access
|
||||
metadata = obj._tracking_metadata
|
||||
if metadata:
|
||||
has_saved_object_metadata = True
|
||||
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
|
||||
identifier=obj._object_identifier,
|
||||
version=versions_pb2.VersionDef(
|
||||
producer=1, min_consumer=1, bad_consumers=[]),
|
||||
metadata=metadata)
|
||||
producer=1, min_consumer=1, bad_consumers=[]))
|
||||
# pylint:enable=protected-access
|
||||
proto.user_object.CopyFrom(registered_type_proto)
|
||||
|
||||
return has_saved_object_metadata
|
||||
|
||||
|
||||
def _export_debug_info(exported_graph, export_dir):
|
||||
"""Exports debug information from graph to file.
|
||||
@@ -1188,8 +1179,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
metrics.IncrementWriteApi(_SAVE_V2_LABEL)
|
||||
result = save_and_return_nodes(
|
||||
obj, export_dir, signatures, options, raise_metadata_warning=True)
|
||||
result = save_and_return_nodes(obj, export_dir, signatures, options)
|
||||
metrics.IncrementWrite()
|
||||
return result
|
||||
|
||||
@@ -1198,7 +1188,6 @@ def save_and_return_nodes(obj,
|
||||
export_dir,
|
||||
signatures=None,
|
||||
options=None,
|
||||
raise_metadata_warning=False,
|
||||
experimental_skip_checkpoint=False):
|
||||
"""Saves a SavedModel while returning all saved nodes and their paths.
|
||||
|
||||
@@ -1210,8 +1199,6 @@ def save_and_return_nodes(obj,
|
||||
signatures: A function or dictionary of functions to save in the SavedModel
|
||||
as signatures.
|
||||
options: `tf.saved_model.SaveOptions` object for configuring save options.
|
||||
raise_metadata_warning: Whether to raise the metadata warning. This arg will
|
||||
be removed in TF 2.5.
|
||||
experimental_skip_checkpoint: If set to `True`, the checkpoint will not
|
||||
be written.
|
||||
|
||||
@@ -1228,8 +1215,7 @@ def save_and_return_nodes(obj,
|
||||
meta_graph_def = saved_model.meta_graphs.add()
|
||||
|
||||
_, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
|
||||
_build_meta_graph(obj, signatures, options, meta_graph_def,
|
||||
raise_metadata_warning))
|
||||
_build_meta_graph(obj, signatures, options, meta_graph_def))
|
||||
saved_model.saved_model_schema_version = (
|
||||
pywrap_libexport.SAVED_MODEL_SCHEMA_VERSION)
|
||||
|
||||
@@ -1322,8 +1308,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||
def _build_meta_graph_impl(obj,
|
||||
signatures,
|
||||
options,
|
||||
meta_graph_def=None,
|
||||
raise_metadata_warning=True):
|
||||
meta_graph_def=None):
|
||||
"""Creates a MetaGraph containing the resources and functions of an object."""
|
||||
if ops.inside_function():
|
||||
raise AssertionError(
|
||||
@@ -1364,26 +1349,10 @@ def _build_meta_graph_impl(obj,
|
||||
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
||||
function_aliases[fdef.name] = alias
|
||||
|
||||
object_graph_proto, saved_object_metadata = _serialize_object_graph(
|
||||
object_graph_proto = _serialize_object_graph(
|
||||
saveable_view, asset_info.asset_index)
|
||||
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||
|
||||
if saved_object_metadata and raise_metadata_warning:
|
||||
tf_logging.warning(
|
||||
'FOR KERAS USERS: The object that you are saving contains one or more '
|
||||
'Keras models or layers. If you are loading the SavedModel with '
|
||||
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
|
||||
'ignore the following instructions). Please change your code to save '
|
||||
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
|
||||
'the file "keras.metadata" exists in the export directory. In the '
|
||||
'future, Keras will only load the SavedModels that have this file. In '
|
||||
'other words, `tf.saved_model.save` will no longer write SavedModels '
|
||||
'that can be recovered as Keras models (this will apply in TF 2.5).'
|
||||
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
|
||||
' this property has been used to save metadata in the SavedModel. The '
|
||||
'metadata field will be deprecated soon, so please move the metadata to'
|
||||
' a different file.')
|
||||
|
||||
return (meta_graph_def, exported_graph, object_saver, asset_info,
|
||||
saveable_view.nodes, saveable_view.node_paths)
|
||||
|
||||
@@ -1391,8 +1360,7 @@ def _build_meta_graph_impl(obj,
|
||||
def _build_meta_graph(obj,
|
||||
signatures,
|
||||
options,
|
||||
meta_graph_def=None,
|
||||
raise_metadata_warning=True):
|
||||
meta_graph_def=None):
|
||||
"""Creates a MetaGraph under a save context.
|
||||
|
||||
Args:
|
||||
@@ -1405,8 +1373,6 @@ def _build_meta_graph(obj,
|
||||
options: `tf.saved_model.SaveOptions` object that specifies options for
|
||||
saving.
|
||||
meta_graph_def: Optional, the MetaGraphDef proto fill.
|
||||
raise_metadata_warning: Whether to raise a warning when user objects contain
|
||||
non-empty metadata.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
|
||||
@@ -1420,5 +1386,4 @@ def _build_meta_graph(obj,
|
||||
"""
|
||||
|
||||
with save_context.save_context(options):
|
||||
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
|
||||
raise_metadata_warning)
|
||||
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
|
||||
|
||||
@@ -704,11 +704,6 @@ class Trackable(object):
|
||||
"""
|
||||
return "_generic_user_object"
|
||||
|
||||
@property
|
||||
def _tracking_metadata(self):
|
||||
"""String containing object metadata, which is saved to the SavedModel."""
|
||||
return ""
|
||||
|
||||
def _no_dependency(self, value):
|
||||
"""If automatic dependency tracking is enabled, ignores `value`."""
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user