mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Add tf.contrib.distributions.bijectors.Reshape.
PiperOrigin-RevId: 173740491
This commit is contained in:
committed by
TensorFlower Gardener
parent
729db035e7
commit
09a89ae57d
@@ -913,6 +913,22 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "reshape_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/bijectors/reshape_test.py"],
|
||||
additional_deps = [
|
||||
":bijectors_py",
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "sigmoid_test",
|
||||
size = "small",
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Reshape Bijector."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ReshapeBijectorTest(test.TestCase):
|
||||
"""Tests correctness of the reshape transformation."""
|
||||
|
||||
def setUp(self):
|
||||
self._rng = np.random.RandomState(42)
|
||||
|
||||
def testBijector(self):
|
||||
"""Do a basic sanity check of forward, inverse, jacobian."""
|
||||
expected_x = np.random.randn(4, 3, 2)
|
||||
expected_y = np.reshape(expected_x, [4, 6])
|
||||
|
||||
with self.test_session() as sess:
|
||||
bijector = Reshape(
|
||||
event_shape_out=[6,],
|
||||
event_shape_in=[3, 2],
|
||||
validate_args=True)
|
||||
(x_,
|
||||
y_,
|
||||
fldj_,
|
||||
ildj_) = sess.run((
|
||||
bijector.inverse(expected_y),
|
||||
bijector.forward(expected_x),
|
||||
bijector.forward_log_det_jacobian(expected_x),
|
||||
bijector.inverse_log_det_jacobian(expected_y),
|
||||
))
|
||||
self.assertEqual("reshape", bijector.name)
|
||||
self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(0., fldj_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(0., ildj_, rtol=1e-6, atol=0)
|
||||
|
||||
def testEventShapeDynamicNdims(self):
|
||||
"""Check forward/inverse shape methods with dynamic ndims."""
|
||||
|
||||
shape_in = tensor_shape.TensorShape([6,])
|
||||
shape_in_ph = array_ops.placeholder(dtype=dtypes.int32)
|
||||
|
||||
shape_out = tensor_shape.TensorShape([2, 3])
|
||||
shape_out_ph = array_ops.placeholder(dtype=dtypes.int32)
|
||||
|
||||
bijector = Reshape(
|
||||
event_shape_out=shape_out_ph,
|
||||
event_shape_in=shape_in_ph, validate_args=True)
|
||||
|
||||
# using the _tensor methods, we should always get a fully-specified
|
||||
# result since these are evaluated at graph runtime.
|
||||
with self.test_session() as sess:
|
||||
(shape_out_,
|
||||
shape_in_) = sess.run((
|
||||
bijector.forward_event_shape_tensor(shape_in),
|
||||
bijector.inverse_event_shape_tensor(shape_out),
|
||||
), feed_dict={
|
||||
shape_in_ph: shape_in,
|
||||
shape_out_ph: shape_out,
|
||||
})
|
||||
self.assertAllEqual(shape_out, shape_out_)
|
||||
self.assertAllEqual(shape_in, shape_in_)
|
||||
|
||||
def testEventShapeDynamic(self):
|
||||
"""Check shape methods with static ndims but dynamic shape."""
|
||||
|
||||
shape_in = tensor_shape.TensorShape([6,])
|
||||
shape_in_partial = tensor_shape.TensorShape([None,])
|
||||
shape_in_ph = array_ops.placeholder(
|
||||
shape=[1,], dtype=dtypes.int32)
|
||||
|
||||
shape_out = tensor_shape.TensorShape([2, 3])
|
||||
shape_out_partial = tensor_shape.TensorShape([None, None])
|
||||
shape_out_ph = array_ops.placeholder(
|
||||
shape=[2,], dtype=dtypes.int32)
|
||||
|
||||
bijector = Reshape(
|
||||
event_shape_out=shape_out_ph,
|
||||
event_shape_in=shape_in_ph,
|
||||
validate_args=True)
|
||||
|
||||
# if event shapes are not statically available, should
|
||||
# return partially-specified TensorShapes.
|
||||
self.assertAllEqual(
|
||||
bijector.forward_event_shape(shape_in).as_list(),
|
||||
shape_out_partial.as_list())
|
||||
self.assertAllEqual(
|
||||
bijector.inverse_event_shape(shape_out).as_list(),
|
||||
shape_in_partial.as_list())
|
||||
|
||||
# using the _tensor methods, we should always get a fully-specified
|
||||
# result since these are evaluated at graph runtime.
|
||||
with self.test_session() as sess:
|
||||
(shape_out_,
|
||||
shape_in_) = sess.run((
|
||||
bijector.forward_event_shape_tensor(shape_in),
|
||||
bijector.inverse_event_shape_tensor(shape_out),
|
||||
), feed_dict={
|
||||
shape_in_ph: shape_in,
|
||||
shape_out_ph: shape_out,
|
||||
})
|
||||
self.assertAllEqual(shape_out, shape_out_)
|
||||
self.assertAllEqual(shape_in, shape_in_)
|
||||
|
||||
def testEventShapeStatic(self):
|
||||
"""Check shape methods when shape is statically known."""
|
||||
|
||||
shape_in = tensor_shape.TensorShape([6,])
|
||||
shape_out = tensor_shape.TensorShape([2, 3])
|
||||
|
||||
bijector_static = Reshape(
|
||||
event_shape_out=shape_out,
|
||||
event_shape_in=shape_in,
|
||||
validate_args=True)
|
||||
|
||||
# test that forward_ and inverse_event_shape do sensible things
|
||||
# when shapes are statically known.
|
||||
self.assertEqual(
|
||||
bijector_static.forward_event_shape(shape_in),
|
||||
shape_out)
|
||||
self.assertEqual(
|
||||
bijector_static.inverse_event_shape(shape_out),
|
||||
shape_in)
|
||||
|
||||
with self.test_session() as sess:
|
||||
(shape_out_static_,
|
||||
shape_in_static_,
|
||||
) = sess.run((
|
||||
bijector_static.forward_event_shape_tensor(shape_in),
|
||||
bijector_static.inverse_event_shape_tensor(shape_out),
|
||||
))
|
||||
self.assertAllEqual(shape_out, shape_out_static_)
|
||||
self.assertAllEqual(shape_in, shape_in_static_)
|
||||
|
||||
def testScalarReshape(self):
|
||||
"""Test reshaping to and from a scalar shape ()."""
|
||||
|
||||
expected_x = np.random.randn(4, 3, 1)
|
||||
expected_y = np.reshape(expected_x, [4, 3])
|
||||
|
||||
expected_x_scalar = np.random.randn(1,)
|
||||
expected_y_scalar = expected_x_scalar[0]
|
||||
|
||||
with self.test_session() as sess:
|
||||
bijector = Reshape(
|
||||
event_shape_out=[],
|
||||
event_shape_in=[1,], validate_args=True)
|
||||
|
||||
(x_,
|
||||
y_,
|
||||
x_scalar_,
|
||||
y_scalar_
|
||||
) = sess.run((
|
||||
bijector.inverse(expected_y),
|
||||
bijector.forward(expected_x),
|
||||
bijector.inverse(expected_y_scalar),
|
||||
bijector.forward(expected_x_scalar),
|
||||
))
|
||||
self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0)
|
||||
self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0)
|
||||
|
||||
def testRaisesOpError(self):
|
||||
x1 = np.random.randn(4, 2, 3)
|
||||
x2 = np.random.randn(4, 3, 2)
|
||||
x3 = np.random.randn(4, 5, 1, 1)
|
||||
|
||||
with self.test_session() as sess:
|
||||
shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32)
|
||||
shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32)
|
||||
bijector = Reshape(
|
||||
event_shape_out=shape_out_ph,
|
||||
event_shape_in=shape_in_ph,
|
||||
validate_args=True)
|
||||
|
||||
with self.assertRaisesOpError(
|
||||
"Input `event_shape` does not match `event_shape_in`."):
|
||||
sess.run(bijector.forward(x2),
|
||||
feed_dict={shape_out_ph: [1, 6, 1],
|
||||
shape_in_ph: [2, 3]})
|
||||
|
||||
with self.assertRaisesOpError(
|
||||
"event_shape_out entries must be positive."):
|
||||
sess.run(bijector.forward(x1),
|
||||
feed_dict={shape_out_ph: [-1, -1, 6],
|
||||
shape_in_ph: [2, 3]})
|
||||
|
||||
# test that *all* methods check basic assertions
|
||||
fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]}
|
||||
with self.assertRaisesOpError(
|
||||
"Input/output `event_size`s do not match."):
|
||||
sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
|
||||
with self.assertRaisesOpError(
|
||||
"Input/output `event_size`s do not match."):
|
||||
sess.run(bijector.inverse(x3), feed_dict=fd_mismatched)
|
||||
with self.assertRaisesOpError(
|
||||
"Input/output `event_size`s do not match."):
|
||||
sess.run(bijector.inverse_log_det_jacobian(x3),
|
||||
feed_dict=fd_mismatched)
|
||||
with self.assertRaisesOpError(
|
||||
"Input/output `event_size`s do not match."):
|
||||
sess.run(bijector.forward_log_det_jacobian(x1),
|
||||
feed_dict=fd_mismatched)
|
||||
|
||||
def testBijectiveAndFinite(self):
|
||||
x = np.random.randn(4, 2, 3)
|
||||
y = np.reshape(x, [4, 1, 2, 3])
|
||||
with self.test_session():
|
||||
bijector = Reshape(
|
||||
event_shape_in=[2, 3],
|
||||
event_shape_out=[1, 2, 3],
|
||||
validate_args=True)
|
||||
assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
@@ -29,6 +29,7 @@
|
||||
@@MaskedAutoregressiveFlow
|
||||
@@Permute
|
||||
@@PowerTransform
|
||||
@@Reshape
|
||||
@@Sigmoid
|
||||
@@SigmoidCentered
|
||||
@@SinhArcsinh
|
||||
@@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import *
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Reshape bijector."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = ["Reshape"]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Reshape bijectors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import bijector as bijector_lib
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Reshape",
|
||||
]
|
||||
|
||||
|
||||
class Reshape(bijector_lib.Bijector):
|
||||
"""Reshapes the `event_shape` of a `Tensor`.
|
||||
|
||||
The semantics generally follow that of `tf.reshape()`, with
|
||||
a few differences:
|
||||
* The user must provide both the input and output shape, so that
|
||||
the transformation can be inverted.
|
||||
* The `Reshape` bijector automatically broadcasts over the leftmost
|
||||
dimensions of its input (`sample_shape` and `batch_shape`); only
|
||||
the rightmost `event_ndims_in` dimensions are reshaped. The
|
||||
number of dimensions to reshape is inferred from the provided
|
||||
`event_shape_in` (`event_ndims_in = len(event_shape_in)`).
|
||||
* The `Reshape` bijector does not currently support
|
||||
partially-specified shapes, i.e., those with a dimension
|
||||
implicitly specified by `-1`.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
|
||||
bs = tf.contrib.distributions.bijectors
|
||||
|
||||
reverse = bs.Reshape(event_shape_out=[1,2],
|
||||
event_shape_in=[2,])
|
||||
|
||||
reverse.forward([1., 2.]) # shape [2,]
|
||||
# ==> [[1., 2.]] # shape [1,2]
|
||||
|
||||
reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2]
|
||||
# ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2]
|
||||
|
||||
reverse.inverse([[1., 2.]]) # shape [1,2]
|
||||
# ==> [1., 2.] # shape [2,]
|
||||
|
||||
reverse.forward_log_det_jacobian(any_value)
|
||||
# ==> 0.
|
||||
|
||||
reverse.inverse_log_det_jacobian(any_value)
|
||||
# ==> 0.
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, event_shape_out, event_shape_in,
|
||||
validate_args=False, name=None):
|
||||
"""Creates a `Reshape` bijector.
|
||||
|
||||
Args:
|
||||
event_shape_out: An `int`-like vector-shaped `Tensor`
|
||||
representing the fully specified (no -1's) event shape of the
|
||||
transformed output.
|
||||
event_shape_in: An `int`-like vector-shaped `Tensor`
|
||||
representing the fully specified (no -1's) event shape of the
|
||||
input.
|
||||
validate_args: Python `bool` indicating whether arguments should
|
||||
be checked for correctness.
|
||||
name: Python `str`, name given to ops managed by this object.
|
||||
|
||||
Raises:
|
||||
TypeError: if either `event_shape_in` or `event_shape_out` has
|
||||
non-vector shape (`rank > 1`), or non-integer `dtype`.
|
||||
ValueError: if either `event_shape_in` or `event_shape_out`
|
||||
contains non-positive entries, or if their sizes do not match
|
||||
(`prod(event_shape_in)` != `prod(event_shape_out)`), or if
|
||||
their dimensionality(s) cannot be statically inferred.
|
||||
"""
|
||||
with ops.name_scope(name, "reshape",
|
||||
values=[event_shape_out, event_shape_in]):
|
||||
|
||||
event_shape_out = ops.convert_to_tensor(event_shape_out,
|
||||
name="event_shape_out",
|
||||
preferred_dtype=dtypes.int32)
|
||||
event_shape_in = ops.convert_to_tensor(event_shape_in,
|
||||
name="event_shape_in",
|
||||
preferred_dtype=dtypes.int32)
|
||||
|
||||
# check that input shapes are positive integers
|
||||
assertions = []
|
||||
assertions += self._maybe_check_valid_shape(
|
||||
event_shape_out, "event_shape_out",
|
||||
validate_args=validate_args)
|
||||
assertions += self._maybe_check_valid_shape(
|
||||
event_shape_in, "event_shape_in", validate_args=validate_args)
|
||||
|
||||
# check that prod(event_shape_in) = prod(event_shape_out)
|
||||
assertions += self._maybe_check_matching_sizes(
|
||||
event_shape_in, event_shape_out, validate_args=validate_args)
|
||||
|
||||
self._assertions = assertions
|
||||
self._event_shape_in = event_shape_in
|
||||
self._event_shape_out = event_shape_out
|
||||
self._event_shape_in_static = tensor_util.constant_value_as_shape(
|
||||
event_shape_in)
|
||||
self._event_shape_out_static = tensor_util.constant_value_as_shape(
|
||||
event_shape_out)
|
||||
|
||||
super(Reshape, self).__init__(is_constant_jacobian=True,
|
||||
validate_args=validate_args,
|
||||
name=name or "reshape")
|
||||
|
||||
def _maybe_check_valid_shape(self, shape_tensor, label,
|
||||
validate_args=False):
|
||||
"""Check that a shape Tensor is int-type and positive."""
|
||||
|
||||
assertions = []
|
||||
|
||||
if not shape_tensor.dtype.is_integer:
|
||||
raise TypeError("{} dtype ({}) should be `int`-like.".format(
|
||||
label, shape_tensor.dtype.name))
|
||||
|
||||
shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor))
|
||||
if shape_rank is not None and shape_rank > 1:
|
||||
raise ValueError("{} rank should be <= 1.".format(label))
|
||||
|
||||
s = tensor_util.constant_value(shape_tensor)
|
||||
if s is not None:
|
||||
if (s <= 0).any():
|
||||
raise ValueError("{} entries must be positive, but found {}".format(
|
||||
label, s))
|
||||
elif validate_args:
|
||||
assertions.append(check_ops.assert_positive(
|
||||
shape_tensor, message="{} entries must be positive".format(label)))
|
||||
|
||||
return assertions
|
||||
|
||||
def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
|
||||
validate_args=False):
|
||||
"""Check that prod(event_shape_in)==prod(event_shape_out)."""
|
||||
|
||||
def _get_size_from_shape(shape):
|
||||
"""Computes size from a shape `Tensor`, statically if possible."""
|
||||
s = tensor_util.constant_value(shape)
|
||||
if s is not None:
|
||||
return [np.int32(np.prod(s))]*2
|
||||
return None, math_ops.reduce_prod(shape, name="size")
|
||||
|
||||
# Ensure `event_shape_in` is compatible with `event_shape_out`.
|
||||
event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking
|
||||
event_shape_in)
|
||||
event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking
|
||||
event_shape_out)
|
||||
|
||||
assertions = []
|
||||
if event_size_in_ is not None and event_size_out_ is not None:
|
||||
if event_size_in_ != event_size_out_:
|
||||
raise ValueError(
|
||||
"Input `event_size` ({}) does not match output `event_size` ({}).".
|
||||
format(event_size_in, event_size_out_))
|
||||
elif validate_args:
|
||||
assertions.append(check_ops.assert_equal(
|
||||
event_size_in, event_size_out,
|
||||
message="Input/output `event_size`s do not match."))
|
||||
|
||||
return assertions
|
||||
|
||||
def _reshape_helper(self, x, event_shape_in, event_shape_out):
|
||||
"""Reshape only the event_shape of an input `Tensor`."""
|
||||
|
||||
def _get_rank_from_shape(shape):
|
||||
"""Computes rank from a shape `Tensor`, statically if possible."""
|
||||
# Uses fact that rank is "shape of shape".
|
||||
ndims = shape.shape.with_rank_at_least(1)[0].value
|
||||
if ndims is not None:
|
||||
return ndims, ndims
|
||||
return None, array_ops.shape(shape)[0]
|
||||
|
||||
event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in)
|
||||
|
||||
assertions = []
|
||||
# Ensure x.event_shape is compatible with event_shape_in.
|
||||
if x.shape.ndims is not None:
|
||||
x_ndims_, x_ndims = [x.shape.ndims]*2
|
||||
else:
|
||||
x_ndims_, x_ndims = None, array_ops.rank(x)
|
||||
|
||||
if (event_ndims_in_ is not None
|
||||
and x_ndims_ is not None
|
||||
and x.shape.with_rank_at_least(event_ndims_in_)[
|
||||
x_ndims_-event_ndims_in_:].is_fully_defined()):
|
||||
x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking
|
||||
np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2
|
||||
else:
|
||||
x_event_shape_, x_event_shape = (
|
||||
None, array_ops.shape(x)[x_ndims-event_ndims_in:])
|
||||
|
||||
event_shape_in_ = tensor_util.constant_value(event_shape_in)
|
||||
|
||||
if x_event_shape_ is not None and event_shape_in_ is not None:
|
||||
if not np.equal(x_event_shape_, event_shape_in_).all():
|
||||
raise ValueError(
|
||||
"Input `event_shape` ({}) does not match `event_shape_in` ({}).".
|
||||
format(x_event_shape_, event_shape_in_))
|
||||
elif self.validate_args:
|
||||
assertions.append(check_ops.assert_equal(
|
||||
x_event_shape, event_shape_in,
|
||||
message="Input `event_shape` does not match `event_shape_in`."))
|
||||
|
||||
if assertions:
|
||||
x = control_flow_ops.with_dependencies(assertions, x)
|
||||
|
||||
# get the parts of shape(x) that will not change
|
||||
sample_and_batch_shape = array_ops.shape(x)
|
||||
|
||||
ndims = (x.shape.ndims if x.shape.ndims is not None
|
||||
else array_ops.rank(x))
|
||||
sample_and_batch_shape = sample_and_batch_shape[
|
||||
:(ndims - math_ops.abs(event_ndims_in))]
|
||||
|
||||
new_shape = array_ops.concat(
|
||||
[sample_and_batch_shape, event_shape_out], axis=0)
|
||||
|
||||
return array_ops.reshape(x, new_shape)
|
||||
|
||||
def _forward(self, x):
|
||||
with ops.control_dependencies(self._assertions):
|
||||
return self._reshape_helper(x,
|
||||
self._event_shape_in,
|
||||
self._event_shape_out)
|
||||
|
||||
def _inverse(self, y):
|
||||
with ops.control_dependencies(self._assertions):
|
||||
return self._reshape_helper(y,
|
||||
self._event_shape_out,
|
||||
self._event_shape_in)
|
||||
|
||||
def _inverse_log_det_jacobian(self, y):
|
||||
with ops.control_dependencies(self._assertions):
|
||||
return constant_op.constant(0., dtype=y.dtype)
|
||||
|
||||
def _forward_log_det_jacobian(self, x):
|
||||
with ops.control_dependencies(self._assertions):
|
||||
return constant_op.constant(0., dtype=x.dtype)
|
||||
|
||||
def _forward_event_shape(self, input_shape):
|
||||
self._event_shape_in_static.assert_is_compatible_with(input_shape)
|
||||
return self._event_shape_out_static
|
||||
|
||||
def _inverse_event_shape(self, output_shape):
|
||||
self._event_shape_out_static.assert_is_compatible_with(output_shape)
|
||||
return self._event_shape_in_static
|
||||
|
||||
def _forward_event_shape_tensor(self, input_shape):
|
||||
input_assertions = self._maybe_check_valid_shape(
|
||||
input_shape, "input event shape", validate_args=self.validate_args)
|
||||
input_assertions += self._maybe_check_matching_sizes(
|
||||
input_shape, self._event_shape_out,
|
||||
validate_args=self.validate_args)
|
||||
|
||||
return control_flow_ops.with_dependencies(
|
||||
input_assertions + self._assertions, self._event_shape_out)
|
||||
|
||||
def _inverse_event_shape_tensor(self, output_shape):
|
||||
|
||||
output_assertions = self._maybe_check_valid_shape(
|
||||
output_shape, "output event shape", validate_args=self.validate_args)
|
||||
output_assertions += self._maybe_check_matching_sizes(
|
||||
output_shape, self._event_shape_in, validate_args=self.validate_args)
|
||||
|
||||
return control_flow_ops.with_dependencies(
|
||||
output_assertions + self._assertions, self._event_shape_in)
|
||||
Reference in New Issue
Block a user