From 09a89ae57d92b9753c76fa298d373468cb05cc6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Oct 2017 18:01:37 -0700 Subject: [PATCH] Add `tf.contrib.distributions.bijectors.Reshape`. PiperOrigin-RevId: 173740491 --- tensorflow/contrib/distributions/BUILD | 16 + .../kernel_tests/bijectors/reshape_test.py | 242 ++++++++++++++ .../python/ops/bijectors/__init__.py | 2 + .../python/ops/bijectors/reshape.py | 29 ++ .../python/ops/bijectors/reshape_impl.py | 297 ++++++++++++++++++ 5 files changed, 586 insertions(+) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/bijectors/reshape.py create mode 100644 tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index bc72bc37a7f..4a4f3789016 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -913,6 +913,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "reshape_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/reshape_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py new file mode 100644 index 00000000000..38b3a23c2d6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -0,0 +1,242 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Reshape Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class ReshapeBijectorTest(test.TestCase): + """Tests correctness of the reshape transformation.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + """Do a basic sanity check of forward, inverse, jacobian.""" + expected_x = np.random.randn(4, 3, 2) + expected_y = np.reshape(expected_x, [4, 6]) + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[6,], + event_shape_in=[3, 2], + validate_args=True) + (x_, + y_, + fldj_, + ildj_) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + )) + self.assertEqual("reshape", bijector.name) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) + + def testEventShapeDynamicNdims(self): + """Check forward/inverse shape methods with dynamic ndims.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, validate_args=True) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. + with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeDynamic(self): + """Check shape methods with static ndims but dynamic shape.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_partial = tensor_shape.TensorShape([None,]) + shape_in_ph = array_ops.placeholder( + shape=[1,], dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_partial = tensor_shape.TensorShape([None, None]) + shape_out_ph = array_ops.placeholder( + shape=[2,], dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + # if event shapes are not statically available, should + # return partially-specified TensorShapes. + self.assertAllEqual( + bijector.forward_event_shape(shape_in).as_list(), + shape_out_partial.as_list()) + self.assertAllEqual( + bijector.inverse_event_shape(shape_out).as_list(), + shape_in_partial.as_list()) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. 
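+    # (the event-shape placeholders above carry no static shape, so even
+    # the event ndims are unknown until the feeds below supply them.)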
+ with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeStatic(self): + """Check shape methods when shape is statically known.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_out = tensor_shape.TensorShape([2, 3]) + + bijector_static = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector_static.forward_event_shape(shape_in), + shape_out) + self.assertEqual( + bijector_static.inverse_event_shape(shape_out), + shape_in) + + with self.test_session() as sess: + (shape_out_static_, + shape_in_static_, + ) = sess.run(( + bijector_static.forward_event_shape_tensor(shape_in), + bijector_static.inverse_event_shape_tensor(shape_out), + )) + self.assertAllEqual(shape_out, shape_out_static_) + self.assertAllEqual(shape_in, shape_in_static_) + + def testScalarReshape(self): + """Test reshaping to and from a scalar shape ().""" + + expected_x = np.random.randn(4, 3, 1) + expected_y = np.reshape(expected_x, [4, 3]) + + expected_x_scalar = np.random.randn(1,) + expected_y_scalar = expected_x_scalar[0] + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[], + event_shape_in=[1,], validate_args=True) + + (x_, + y_, + x_scalar_, + y_scalar_ + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.inverse(expected_y_scalar), + bijector.forward(expected_x_scalar), + )) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 3, 2) + x3 = np.random.randn(4, 5, 1, 1) + + with self.test_session() as sess: + shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + with self.assertRaisesOpError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x2), + feed_dict={shape_out_ph: [1, 6, 1], + shape_in_ph: [2, 3]}) + + with self.assertRaisesOpError( + "event_shape_out entries must be positive."): + sess.run(bijector.forward(x1), + feed_dict={shape_out_ph: [-1, -1, 6], + shape_in_ph: [2, 3]}) + + # test that *all* methods check basic assertions + fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward(x1), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse_log_det_jacobian(x3), + feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward_log_det_jacobian(x1), + 
feed_dict=fd_mismatched) + + def testBijectiveAndFinite(self): + x = np.random.randn(4, 2, 3) + y = np.reshape(x, [4, 1, 2, 3]) + with self.test_session(): + bijector = Reshape( + event_shape_in=[2, 3], + event_shape_out=[1, 2, 3], + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index fd6c5094469..bc0ec7f195a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -29,6 +29,7 @@ @@MaskedAutoregressiveFlow @@Permute @@PowerTransform +@@Reshape @@Sigmoid @@SigmoidCentered @@SinhArcsinh @@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py new file mode 100644 index 00000000000..8997f7ab692 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reshape bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Reshape"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py new file mode 100644 index 00000000000..93682639aa3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py @@ -0,0 +1,297 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reshape bijectors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector as bijector_lib
+
+
+__all__ = [
+    "Reshape",
+]
+
+
+class Reshape(bijector_lib.Bijector):
+  """Reshapes the `event_shape` of a `Tensor`.
+
+  The semantics generally follow those of `tf.reshape()`, with
+  a few differences:
+  * The user must provide both the input and output shape, so that
+    the transformation can be inverted.
+  * The `Reshape` bijector automatically broadcasts over the leftmost
+    dimensions of its input (`sample_shape` and `batch_shape`); only
+    the rightmost `event_ndims_in` dimensions are reshaped. The
+    number of dimensions to reshape is inferred from the provided
+    `event_shape_in` (`event_ndims_in = len(event_shape_in)`).
+  * The `Reshape` bijector does not currently support
+    partially-specified shapes, i.e., those with a dimension
+    implicitly specified by `-1`.
+
+  Example usage:
+  ```python
+
+  bs = tf.contrib.distributions.bijectors
+
+  reshape = bs.Reshape(event_shape_out=[1, 2],
+                       event_shape_in=[2,])
+
+  reshape.forward([1., 2.])  # shape [2,]
+  # ==> [[1., 2.]]           # shape [1, 2]
+
+  reshape.forward([[1., 2.], [3., 4.]])  # shape [2, 2]
+  # ==> [[[1., 2.]], [[3., 4.]]]         # shape [2, 1, 2]
+
+  reshape.inverse([[1., 2.]])  # shape [1, 2]
+  # ==> [1., 2.]               # shape [2,]
+
+  reshape.forward_log_det_jacobian(any_value)
+  # ==> 0.
+
+  reshape.inverse_log_det_jacobian(any_value)
+  # ==> 0.
+  ```
+
+  """
+
+  def __init__(self, event_shape_out, event_shape_in,
+               validate_args=False, name=None):
+    """Creates a `Reshape` bijector.
+
+    Args:
+      event_shape_out: An `int`-like vector-shaped `Tensor`
+        representing the fully specified (no -1's) event shape of the
+        transformed output.
+      event_shape_in: An `int`-like vector-shaped `Tensor`
+        representing the fully specified (no -1's) event shape of the
+        input.
+      validate_args: Python `bool` indicating whether arguments should
+        be checked for correctness.
+      name: Python `str`, name given to ops managed by this object.
+
+    Raises:
+      TypeError: if either `event_shape_in` or `event_shape_out` has
+        non-vector shape (`rank > 1`), or non-integer `dtype`.
+      ValueError: if either `event_shape_in` or `event_shape_out`
+        contains non-positive entries, or if their sizes do not match
+        (`prod(event_shape_in)` != `prod(event_shape_out)`), or if
+        their dimensionality cannot be statically inferred.
+    """
+    with ops.name_scope(name, "reshape",
+                        values=[event_shape_out, event_shape_in]):
+
+      event_shape_out = ops.convert_to_tensor(event_shape_out,
+                                              name="event_shape_out",
+                                              preferred_dtype=dtypes.int32)
+      event_shape_in = ops.convert_to_tensor(event_shape_in,
+                                             name="event_shape_in",
+                                             preferred_dtype=dtypes.int32)
+
+      # check that input shapes are positive integers
+      assertions = []
+      assertions += self._maybe_check_valid_shape(
+          event_shape_out, "event_shape_out",
+          validate_args=validate_args)
+      assertions += self._maybe_check_valid_shape(
+          event_shape_in, "event_shape_in", validate_args=validate_args)
+
+      # check that prod(event_shape_in) = prod(event_shape_out)
+      assertions += self._maybe_check_matching_sizes(
+          event_shape_in, event_shape_out, validate_args=validate_args)
+
+      self._assertions = assertions
+      self._event_shape_in = event_shape_in
+      self._event_shape_out = event_shape_out
+      self._event_shape_in_static = tensor_util.constant_value_as_shape(
+          event_shape_in)
+      self._event_shape_out_static = tensor_util.constant_value_as_shape(
+          event_shape_out)
+
+      super(Reshape, self).__init__(is_constant_jacobian=True,
+                                    validate_args=validate_args,
+                                    name=name or "reshape")
+
+  def _maybe_check_valid_shape(self, shape_tensor, label,
+                               validate_args=False):
+    """Check that a shape Tensor is int-type and positive."""
+
+    assertions = []
+
+    if not shape_tensor.dtype.is_integer:
+      raise TypeError("{} dtype ({}) should be `int`-like.".format(
+          label, shape_tensor.dtype.name))
+
+    shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor))
+    if shape_rank is not None and shape_rank > 1:
+      raise ValueError("{} rank should be <= 1.".format(label))
+
+    s = tensor_util.constant_value(shape_tensor)
+    if s is not None:
+      if (s <= 0).any():
+        raise ValueError("{} entries must be positive, but found {}".format(
+            label, s))
+    elif validate_args:
+      assertions.append(check_ops.assert_positive(
+          shape_tensor, message="{} entries must be positive".format(label)))
+
+    return assertions
+
+  def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
+                                  validate_args=False):
+    """Check that prod(event_shape_in)==prod(event_shape_out)."""
+
+    def _get_size_from_shape(shape):
+      """Computes size from a shape `Tensor`, statically if possible."""
+      s = tensor_util.constant_value(shape)
+      if s is not None:
+        return [np.int32(np.prod(s))]*2
+      return None, math_ops.reduce_prod(shape, name="size")
+
+    # Ensure `event_shape_in` is compatible with `event_shape_out`.
+    event_size_in_, event_size_in = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
+        event_shape_in)
+    event_size_out_, event_size_out = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
+        event_shape_out)
+
+    assertions = []
+    if event_size_in_ is not None and event_size_out_ is not None:
+      if event_size_in_ != event_size_out_:
+        raise ValueError(
+            "Input `event_size` ({}) does not match output `event_size` ({}).".
+            format(event_size_in_, event_size_out_))
+    elif validate_args:
+      assertions.append(check_ops.assert_equal(
+          event_size_in, event_size_out,
+          message="Input/output `event_size`s do not match."))
+
+    return assertions
+
+  def _reshape_helper(self, x, event_shape_in, event_shape_out):
+    """Reshape only the event_shape of an input `Tensor`."""
+
+    def _get_rank_from_shape(shape):
+      """Computes rank from a shape `Tensor`, statically if possible."""
+      # Uses fact that rank is "shape of shape".
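+      # Returns a (static, dynamic) pair, in the same style as
+      # `_get_size_from_shape` above: a Python int if the rank is known at
+      # graph-construction time (`None` otherwise), plus a `Tensor` that can
+      # be evaluated at graph runtime.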
+ ndims = shape.shape.with_rank_at_least(1)[0].value + if ndims is not None: + return ndims, ndims + return None, array_ops.shape(shape)[0] + + event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) + + assertions = [] + # Ensure x.event_shape is compatible with event_shape_in. + if x.shape.ndims is not None: + x_ndims_, x_ndims = [x.shape.ndims]*2 + else: + x_ndims_, x_ndims = None, array_ops.rank(x) + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + if not np.equal(x_event_shape_, event_shape_in_).all(): + raise ValueError( + "Input `event_shape` ({}) does not match `event_shape_in` ({}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + assertions.append(check_ops.assert_equal( + x_event_shape, event_shape_in, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + self._event_shape_in_static.assert_is_compatible_with(input_shape) + return self._event_shape_out_static + + def _inverse_event_shape(self, output_shape): + self._event_shape_out_static.assert_is_compatible_with(output_shape) + return self._event_shape_in_static + + def _forward_event_shape_tensor(self, input_shape): + input_assertions = self._maybe_check_valid_shape( + input_shape, "input event shape", validate_args=self.validate_args) + input_assertions += self._maybe_check_matching_sizes( + input_shape, self._event_shape_out, + validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + input_assertions + self._assertions, self._event_shape_out) + + def _inverse_event_shape_tensor(self, output_shape): + + output_assertions = self._maybe_check_valid_shape( + output_shape, "output event shape", validate_args=self.validate_args) + output_assertions += self._maybe_check_matching_sizes( + output_shape, self._event_shape_in, validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + output_assertions + self._assertions, self._event_shape_in)
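
Usage note (not part of the patch): a typical reason to want a shape-changing
bijector is to compose it with `TransformedDistribution` and re-shape an
existing distribution's `event_shape`. Below is a minimal sketch, assuming the
TF 1.4-era `tf.contrib.distributions` API; `MultivariateNormalDiag` and
`TransformedDistribution` are pre-existing contrib symbols, not introduced by
this patch.

```python
import tensorflow as tf

tfd = tf.contrib.distributions
bs = tfd.bijectors

# Base distribution: a standard normal over length-6 vectors.
base = tfd.MultivariateNormalDiag(loc=tf.zeros([6]))

# Lift it to a distribution over [2, 3] matrices. Reshape is
# volume-preserving, so log_prob needs no Jacobian correction.
matrix_dist = tfd.TransformedDistribution(
    distribution=base,
    bijector=bs.Reshape(event_shape_out=[2, 3], event_shape_in=[6]))

samples = matrix_dist.sample(4)            # shape [4, 2, 3]
log_probs = matrix_dist.log_prob(samples)  # shape [4]

with tf.Session() as sess:
  samples_, log_probs_ = sess.run([samples, log_probs])
```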