From 5a421f2ad0e8921e7e753e64928ea82be7ff0101 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Mon, 25 Jan 2021 06:51:18 -0800 Subject: [PATCH] Add a `.element_spec` property to `DatasetSpec` Currently there is no way of recovering the inner spec. This is an issue for example when using nested datasets when `ds.element_spec` will return the outer dataset spec `{'nested_ds': DatasetSpec({'img': TensorSpec(...)})})` but it's not possible to access the inner DatasetSpec. PiperOrigin-RevId: 353636988 Change-Id: I4bcfb3ab31a0761834a2837075264f8117973861 --- RELEASE.md | 3 +++ tensorflow/python/data/kernel_tests/dataset_spec_test.py | 6 ++++++ tensorflow/python/data/ops/dataset_ops.py | 5 +++++ .../tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt | 4 ++++ .../tensorflow.data.experimental.-dataset-structure.pbtxt | 4 ++++ .../tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt | 4 ++++ 6 files changed, 26 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index f84f7614041..e49397d1dcf 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -56,6 +56,9 @@ the dataset elements. This avoids the need for explicitly specifying the `element_spec` argument of `tf.data.experimental.load` when loading the previously saved dataset. + * Add `.element_spec` property to `tf.data.DatasetSpec` to access the + inner spec. This can be used to extract the structure of nested + datasets. * XLA compilation: * `tf.function(experimental_compile=True)` has become a stable API, renamed `tf.function(jit_compile=True)`. diff --git a/tensorflow/python/data/kernel_tests/dataset_spec_test.py b/tensorflow/python/data/kernel_tests/dataset_spec_test.py index 781a972ea33..1053b0e4a4e 100644 --- a/tensorflow/python/data/kernel_tests/dataset_spec_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_spec_test.py @@ -49,6 +49,12 @@ class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase): fn(dataset) + @combinations.generate(test_base.default_test_combinations()) + def testDatasetSpecInnerSpec(self): + inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32) + ds_spec = dataset_ops.DatasetSpec(inner_spec) + self.assertEqual(ds_spec.element_spec, inner_spec) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index acd2af7bd0b..1c949254f48 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -3285,6 +3285,11 @@ class DatasetSpec(type_spec.BatchableTypeSpec): def value_type(self): return Dataset + @property + def element_spec(self): + """The inner element spec.""" + return self._element_spec + def _serialize(self): return (self._element_spec, self._dataset_shape) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt index 369aef45e9f..f56e1198f10 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt index 474c725a696..fc65345f061 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt index 369aef45e9f..f56e1198f10 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: ""