Add tf.contrib.distributions.bijectors.Reshape.
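
For illustration, a minimal usage sketch of the new bijector (mirroring the
class docstring below; `x` stands for any Tensor whose rightmost dimensions
match `event_shape_in`):

```python
bijectors = tf.contrib.distributions.bijectors
reshape = bijectors.Reshape(event_shape_out=[6], event_shape_in=[3, 2])
y = reshape.forward(x)           # reshapes the rightmost [3, 2] dims of x to [6]
x_recovered = reshape.inverse(y) # inverts the transformation exactly
```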

PiperOrigin-RevId: 173740491

Author: A. Unique TensorFlower
Date: 2017-10-27 18:01:37 -07:00
Committed by: TensorFlower Gardener
Parent: 729db035e7
Commit: 09a89ae57d
5 changed files with 586 additions and 0 deletions


@@ -913,6 +913,22 @@ cuda_py_test(
    ],
)

cuda_py_test(
    name = "reshape_test",
    size = "small",
    srcs = ["python/kernel_tests/bijectors/reshape_test.py"],
    additional_deps = [
        ":bijectors_py",
        ":distributions_py",
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_test(
    name = "sigmoid_test",
    size = "small",


@@ -0,0 +1,242 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Reshape Bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test


class ReshapeBijectorTest(test.TestCase):
  """Tests correctness of the reshape transformation."""

  def setUp(self):
    self._rng = np.random.RandomState(42)

  def testBijector(self):
    """Do a basic sanity check of forward, inverse, jacobian."""
    expected_x = np.random.randn(4, 3, 2)
    expected_y = np.reshape(expected_x, [4, 6])

    with self.test_session() as sess:
      bijector = Reshape(
          event_shape_out=[6,],
          event_shape_in=[3, 2],
          validate_args=True)
      (x_,
       y_,
       fldj_,
       ildj_) = sess.run((
           bijector.inverse(expected_y),
           bijector.forward(expected_x),
           bijector.forward_log_det_jacobian(expected_x),
           bijector.inverse_log_det_jacobian(expected_y),
       ))
      self.assertEqual("reshape", bijector.name)
      self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
      self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
      self.assertAllClose(0., fldj_, rtol=1e-6, atol=0)
      self.assertAllClose(0., ildj_, rtol=1e-6, atol=0)

  def testEventShapeDynamicNdims(self):
    """Check forward/inverse shape methods with dynamic ndims."""
    shape_in = tensor_shape.TensorShape([6,])
    shape_in_ph = array_ops.placeholder(dtype=dtypes.int32)

    shape_out = tensor_shape.TensorShape([2, 3])
    shape_out_ph = array_ops.placeholder(dtype=dtypes.int32)

    bijector = Reshape(
        event_shape_out=shape_out_ph,
        event_shape_in=shape_in_ph, validate_args=True)

    # using the _tensor methods, we should always get a fully-specified
    # result since these are evaluated at graph runtime.
    with self.test_session() as sess:
      (shape_out_,
       shape_in_) = sess.run((
           bijector.forward_event_shape_tensor(shape_in),
           bijector.inverse_event_shape_tensor(shape_out),
       ), feed_dict={
           shape_in_ph: shape_in,
           shape_out_ph: shape_out,
       })
      self.assertAllEqual(shape_out, shape_out_)
      self.assertAllEqual(shape_in, shape_in_)

  def testEventShapeDynamic(self):
    """Check shape methods with static ndims but dynamic shape."""
    shape_in = tensor_shape.TensorShape([6,])
    shape_in_partial = tensor_shape.TensorShape([None,])
    shape_in_ph = array_ops.placeholder(
        shape=[1,], dtype=dtypes.int32)

    shape_out = tensor_shape.TensorShape([2, 3])
    shape_out_partial = tensor_shape.TensorShape([None, None])
    shape_out_ph = array_ops.placeholder(
        shape=[2,], dtype=dtypes.int32)

    bijector = Reshape(
        event_shape_out=shape_out_ph,
        event_shape_in=shape_in_ph,
        validate_args=True)

    # if event shapes are not statically available, should
    # return partially-specified TensorShapes.
    self.assertAllEqual(
        bijector.forward_event_shape(shape_in).as_list(),
        shape_out_partial.as_list())
    self.assertAllEqual(
        bijector.inverse_event_shape(shape_out).as_list(),
        shape_in_partial.as_list())

    # using the _tensor methods, we should always get a fully-specified
    # result since these are evaluated at graph runtime.
    with self.test_session() as sess:
      (shape_out_,
       shape_in_) = sess.run((
           bijector.forward_event_shape_tensor(shape_in),
           bijector.inverse_event_shape_tensor(shape_out),
       ), feed_dict={
           shape_in_ph: shape_in,
           shape_out_ph: shape_out,
       })
      self.assertAllEqual(shape_out, shape_out_)
      self.assertAllEqual(shape_in, shape_in_)

  def testEventShapeStatic(self):
    """Check shape methods when shape is statically known."""
    shape_in = tensor_shape.TensorShape([6,])
    shape_out = tensor_shape.TensorShape([2, 3])
    bijector_static = Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)

    # test that forward_ and inverse_event_shape do sensible things
    # when shapes are statically known.
    self.assertEqual(
        bijector_static.forward_event_shape(shape_in),
        shape_out)
    self.assertEqual(
        bijector_static.inverse_event_shape(shape_out),
        shape_in)

    with self.test_session() as sess:
      (shape_out_static_,
       shape_in_static_,
      ) = sess.run((
          bijector_static.forward_event_shape_tensor(shape_in),
          bijector_static.inverse_event_shape_tensor(shape_out),
      ))
      self.assertAllEqual(shape_out, shape_out_static_)
      self.assertAllEqual(shape_in, shape_in_static_)

  def testScalarReshape(self):
    """Test reshaping to and from a scalar shape ()."""
    expected_x = np.random.randn(4, 3, 1)
    expected_y = np.reshape(expected_x, [4, 3])

    expected_x_scalar = np.random.randn(1,)
    expected_y_scalar = expected_x_scalar[0]

    with self.test_session() as sess:
      bijector = Reshape(
          event_shape_out=[],
          event_shape_in=[1,], validate_args=True)
      (x_,
       y_,
       x_scalar_,
       y_scalar_
      ) = sess.run((
          bijector.inverse(expected_y),
          bijector.forward(expected_x),
          bijector.inverse(expected_y_scalar),
          bijector.forward(expected_x_scalar),
      ))
      self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
      self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
      self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0)
      self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0)

  def testRaisesOpError(self):
    x1 = np.random.randn(4, 2, 3)
    x2 = np.random.randn(4, 3, 2)
    x3 = np.random.randn(4, 5, 1, 1)

    with self.test_session() as sess:
      shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32)
      shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32)
      bijector = Reshape(
          event_shape_out=shape_out_ph,
          event_shape_in=shape_in_ph,
          validate_args=True)

      with self.assertRaisesOpError(
          "Input `event_shape` does not match `event_shape_in`."):
        sess.run(bijector.forward(x2),
                 feed_dict={shape_out_ph: [1, 6, 1],
                            shape_in_ph: [2, 3]})

      with self.assertRaisesOpError(
          "event_shape_out entries must be positive."):
        sess.run(bijector.forward(x1),
                 feed_dict={shape_out_ph: [-1, -1, 6],
                            shape_in_ph: [2, 3]})

      # test that *all* methods check basic assertions
      fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]}
      with self.assertRaisesOpError(
          "Input/output `event_size`s do not match."):
        sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
      with self.assertRaisesOpError(
          "Input/output `event_size`s do not match."):
        sess.run(bijector.inverse(x3), feed_dict=fd_mismatched)
      with self.assertRaisesOpError(
          "Input/output `event_size`s do not match."):
        sess.run(bijector.inverse_log_det_jacobian(x3),
                 feed_dict=fd_mismatched)
      with self.assertRaisesOpError(
          "Input/output `event_size`s do not match."):
        sess.run(bijector.forward_log_det_jacobian(x1),
                 feed_dict=fd_mismatched)

  def testBijectiveAndFinite(self):
    x = np.random.randn(4, 2, 3)
    y = np.reshape(x, [4, 1, 2, 3])
    with self.test_session():
      bijector = Reshape(
          event_shape_in=[2, 3],
          event_shape_out=[1, 2, 3],
          validate_args=True)
      assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)


if __name__ == "__main__":
  test.main()
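
The tests above exercise the bijector directly. For context, the intended
end-to-end use is composing `Reshape` with a base distribution via
`TransformedDistribution`; a minimal sketch (not part of this commit,
assuming the TF 1.x contrib API of the same era):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Reshape a vector-valued normal's event from [6] to a [2, 3] matrix.
matrix_mvn = ds.TransformedDistribution(
    distribution=ds.MultivariateNormalDiag(loc=tf.zeros([6])),
    bijector=ds.bijectors.Reshape(event_shape_out=[2, 3],
                                  event_shape_in=[6]))

samples = matrix_mvn.sample(4)        # shape [4, 2, 3]
log_p = matrix_mvn.log_prob(samples)  # shape [4]; the reshape adds zero log-det
```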


@@ -29,6 +29,7 @@
@@MaskedAutoregressiveFlow
@@Permute
@@PowerTransform
@@Reshape
@@Sigmoid
@@SigmoidCentered
@@SinhArcsinh
@@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *


@@ -0,0 +1,29 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reshape bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ["Reshape"]
remove_undocumented(__name__, _allowed_symbols)


@@ -0,0 +1,297 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reshape bijectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector as bijector_lib
__all__ = [
"Reshape",
]


class Reshape(bijector_lib.Bijector):
  """Reshapes the `event_shape` of a `Tensor`.

  The semantics generally follow that of `tf.reshape()`, with
  a few differences:

  * The user must provide both the input and output shape, so that
    the transformation can be inverted.
  * The `Reshape` bijector automatically broadcasts over the leftmost
    dimensions of its input (`sample_shape` and `batch_shape`); only
    the rightmost `event_ndims_in` dimensions are reshaped. The
    number of dimensions to reshape is inferred from the provided
    `event_shape_in` (`event_ndims_in = len(event_shape_in)`).
  * The `Reshape` bijector does not currently support
    partially-specified shapes, i.e., those with a dimension
    implicitly specified by `-1`.

  Example usage:

  ```python
  bs = tf.contrib.distributions.bijectors

  reshape = bs.Reshape(event_shape_out=[1, 2],
                       event_shape_in=[2,])

  reshape.forward([1., 2.])    # shape [2,]
  # ==> [[1., 2.]]             # shape [1, 2]

  reshape.forward([[1., 2.], [3., 4.]])  # shape [2, 2]
  # ==> [[[1., 2.]], [[3., 4.]]]         # shape [2, 1, 2]

  reshape.inverse([[1., 2.]])  # shape [1, 2]
  # ==> [1., 2.]               # shape [2,]

  reshape.forward_log_det_jacobian(any_value)
  # ==> 0.
  reshape.inverse_log_det_jacobian(any_value)
  # ==> 0.
  ```

  Since reshaping only relabels coordinates and never stretches or
  compresses volume, the Jacobian determinant is identically 1 and both
  log-det-Jacobian methods return 0.
  """

  def __init__(self, event_shape_out, event_shape_in,
               validate_args=False, name=None):
    """Creates a `Reshape` bijector.

    Args:
      event_shape_out: An `int`-like vector-shaped `Tensor`
        representing the fully specified (no -1's) event shape of the
        transformed output.
      event_shape_in: An `int`-like vector-shaped `Tensor`
        representing the fully specified (no -1's) event shape of the
        input.
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if either `event_shape_in` or `event_shape_out` has
        non-vector shape (`rank > 1`), or non-integer `dtype`.
      ValueError: if either `event_shape_in` or `event_shape_out`
        contains non-positive entries, or if their sizes do not match
        (`prod(event_shape_in) != prod(event_shape_out)`), or if
        their dimensionality(s) cannot be statically inferred.
    """
    with ops.name_scope(name, "reshape",
                        values=[event_shape_out, event_shape_in]):
      event_shape_out = ops.convert_to_tensor(event_shape_out,
                                              name="event_shape_out",
                                              preferred_dtype=dtypes.int32)
      event_shape_in = ops.convert_to_tensor(event_shape_in,
                                             name="event_shape_in",
                                             preferred_dtype=dtypes.int32)

      # check that input shapes are positive integers
      assertions = []
      assertions += self._maybe_check_valid_shape(
          event_shape_out, "event_shape_out",
          validate_args=validate_args)
      assertions += self._maybe_check_valid_shape(
          event_shape_in, "event_shape_in", validate_args=validate_args)

      # check that prod(event_shape_in) = prod(event_shape_out)
      assertions += self._maybe_check_matching_sizes(
          event_shape_in, event_shape_out, validate_args=validate_args)

      self._assertions = assertions
      self._event_shape_in = event_shape_in
      self._event_shape_out = event_shape_out
      self._event_shape_in_static = tensor_util.constant_value_as_shape(
          event_shape_in)
      self._event_shape_out_static = tensor_util.constant_value_as_shape(
          event_shape_out)

      super(Reshape, self).__init__(is_constant_jacobian=True,
                                    validate_args=validate_args,
                                    name=name or "reshape")

  def _maybe_check_valid_shape(self, shape_tensor, label,
                               validate_args=False):
    """Check that a shape Tensor is int-type and positive."""
    assertions = []

    if not shape_tensor.dtype.is_integer:
      raise TypeError("{} dtype ({}) should be `int`-like.".format(
          label, shape_tensor.dtype.name))

    shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor))
    if shape_rank is not None and shape_rank > 1:
      raise ValueError("{} rank should be <= 1.".format(label))

    s = tensor_util.constant_value(shape_tensor)
    if s is not None:
      if (s <= 0).any():
        raise ValueError("{} entries must be positive, but found {}".format(
            label, s))
    elif validate_args:
      assertions.append(check_ops.assert_positive(
          shape_tensor, message="{} entries must be positive".format(label)))
    return assertions

  def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
                                  validate_args=False):
    """Check that prod(event_shape_in) == prod(event_shape_out)."""

    def _get_size_from_shape(shape):
      """Computes size from a shape `Tensor`, statically if possible."""
      s = tensor_util.constant_value(shape)
      if s is not None:
        return [np.int32(np.prod(s))]*2
      return None, math_ops.reduce_prod(shape, name="size")

    # Ensure `event_shape_in` is compatible with `event_shape_out`.
    event_size_in_, event_size_in = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
        event_shape_in)
    event_size_out_, event_size_out = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
        event_shape_out)

    assertions = []
    if event_size_in_ is not None and event_size_out_ is not None:
      if event_size_in_ != event_size_out_:
        raise ValueError(
            "Input `event_size` ({}) does not match output `event_size` "
            "({}).".format(event_size_in_, event_size_out_))
    elif validate_args:
      assertions.append(check_ops.assert_equal(
          event_size_in, event_size_out,
          message="Input/output `event_size`s do not match."))
    return assertions

  def _reshape_helper(self, x, event_shape_in, event_shape_out):
    """Reshape only the event_shape of an input `Tensor`."""

    def _get_rank_from_shape(shape):
      """Computes rank from a shape `Tensor`, statically if possible."""
      # Uses fact that rank is "shape of shape".
      ndims = shape.shape.with_rank_at_least(1)[0].value
      if ndims is not None:
        return ndims, ndims
      return None, array_ops.shape(shape)[0]

    event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in)

    assertions = []

    # Ensure x.event_shape is compatible with event_shape_in.
    if x.shape.ndims is not None:
      x_ndims_, x_ndims = [x.shape.ndims]*2
    else:
      x_ndims_, x_ndims = None, array_ops.rank(x)
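
    # Extract the rightmost `event_ndims_in` dims of x's shape: statically
    # when both x's rank and those trailing dims are fully known at graph
    # construction, otherwise as a dynamic Tensor slice at runtime.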
    if (event_ndims_in_ is not None
        and x_ndims_ is not None
        and x.shape.with_rank_at_least(event_ndims_in_)[
            x_ndims_-event_ndims_in_:].is_fully_defined()):
      x_event_shape_, x_event_shape = [  # pylint: disable=unbalanced-tuple-unpacking
          np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2
    else:
      x_event_shape_, x_event_shape = (
          None, array_ops.shape(x)[x_ndims-event_ndims_in:])

    event_shape_in_ = tensor_util.constant_value(event_shape_in)
    if x_event_shape_ is not None and event_shape_in_ is not None:
      if not np.equal(x_event_shape_, event_shape_in_).all():
        raise ValueError(
            "Input `event_shape` ({}) does not match `event_shape_in` "
            "({}).".format(x_event_shape_, event_shape_in_))
    elif self.validate_args:
      assertions.append(check_ops.assert_equal(
          x_event_shape, event_shape_in,
          message="Input `event_shape` does not match `event_shape_in`."))

    if assertions:
      x = control_flow_ops.with_dependencies(assertions, x)

    # get the parts of shape(x) that will not change
    sample_and_batch_shape = array_ops.shape(x)
    ndims = (x.shape.ndims if x.shape.ndims is not None
             else array_ops.rank(x))
    sample_and_batch_shape = sample_and_batch_shape[
        :(ndims - math_ops.abs(event_ndims_in))]

    new_shape = array_ops.concat(
        [sample_and_batch_shape, event_shape_out], axis=0)
    return array_ops.reshape(x, new_shape)

  def _forward(self, x):
    with ops.control_dependencies(self._assertions):
      return self._reshape_helper(x,
                                  self._event_shape_in,
                                  self._event_shape_out)

  def _inverse(self, y):
    with ops.control_dependencies(self._assertions):
      return self._reshape_helper(y,
                                  self._event_shape_out,
                                  self._event_shape_in)

  def _inverse_log_det_jacobian(self, y):
    with ops.control_dependencies(self._assertions):
      return constant_op.constant(0., dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    with ops.control_dependencies(self._assertions):
      return constant_op.constant(0., dtype=x.dtype)

  def _forward_event_shape(self, input_shape):
    self._event_shape_in_static.assert_is_compatible_with(input_shape)
    return self._event_shape_out_static

  def _inverse_event_shape(self, output_shape):
    self._event_shape_out_static.assert_is_compatible_with(output_shape)
    return self._event_shape_in_static

  def _forward_event_shape_tensor(self, input_shape):
    input_assertions = self._maybe_check_valid_shape(
        input_shape, "input event shape", validate_args=self.validate_args)
    input_assertions += self._maybe_check_matching_sizes(
        input_shape, self._event_shape_out,
        validate_args=self.validate_args)
    return control_flow_ops.with_dependencies(
        input_assertions + self._assertions, self._event_shape_out)

  def _inverse_event_shape_tensor(self, output_shape):
    output_assertions = self._maybe_check_valid_shape(
        output_shape, "output event shape", validate_args=self.validate_args)
    output_assertions += self._maybe_check_matching_sizes(
        output_shape, self._event_shape_in, validate_args=self.validate_args)
    return control_flow_ops.with_dependencies(
        output_assertions + self._assertions, self._event_shape_in)