Fix integer div/mod overflow issues for Eigen, XLA, MLIR.

The original implementations of `google_floor_div`, XLA `FloorDiv` and
MLIR `TF_FloorDivOp` all suffered from overflows for
`abs(x) + abs(y) > INT_MAX`. The new implementation addresses this.

Also fixed FPE issues for the Eigen backend when trying to divide
`INT_MIN / -1`. This previously caused TF to crash. Here we decide to
overflow so that `INT_MIN / -1 = INT_MIN`.  This is consistent with
`int16` and `int8` types (which already overflow without crashing in C++),
and maintains the equality
```
 x = div(x, y) * y + mod(x, y)
```

Also added tests for these edge cases.

Fixes #46887, #45771.

PiperOrigin-RevId: 372993714
Change-Id: If38fdd9e53d562011e529c58b415786ba015a0be
This commit is contained in:
A. Unique TensorFlower
2021-05-10 12:58:26 -07:00
committed by TensorFlower Gardener
parent bf5b1056fc
commit b47be308c4
6 changed files with 174 additions and 94 deletions

View File

@@ -740,21 +740,18 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
return %0: tensor<2x3xi32>
@@ -762,21 +759,18 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0: tensor<2x3xi32>
@@ -814,21 +808,18 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>

View File

@@ -126,13 +126,8 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
// additional correction for a negative numerator / denominator. Equivalent
// pseudocode is shown below:
//
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
// T abs_y = std::abs(y);
// return -(abs_x + abs_y - 1) / abs_y;
// } else {
// return x / y;
// }
// T z = x / y
// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z
//
// BroadcastToDimensions is used to compute the broadcast attr to higher
// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r')
@@ -142,31 +137,39 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
(HLO_SelectOp
(HLOClient_BroadcastAndOp
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastMulOp:$mul
(HLOClient_BroadcastDivOp:$div $l, $r,
(BinBroadcastDimensions $l, $r)),
$r, (BinBroadcastDimensions $div, $r)),
$l, (BinBroadcastDimensions $mul, $l), HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp:$l_cmp $l,
(HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
(HLOClient_BroadcastDivOp
(HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
(HLOClient_BroadcastSubOp (HLO_AbsOp $r),
(HLO_ConstOp (GetScalarOfType<1> $r)),
(NullDenseIntElementsAttr)),
(BinBroadcastDimensions $l, $r))),
(HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
[(SignedIntTensor $l)]>;
(HLOClient_BroadcastCompareOp:$r_cmp $r,
(HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(BinBroadcastDimensions $l_cmp, $r_cmp), HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)),
(NullDenseIntElementsAttr)),
(HLOClient_BroadcastSubOp $div,
(HLO_ConstOp:$ones (GetScalarOfType<1> $div)),
(NullDenseIntElementsAttr)), $div),
[(SignedIntTensor $l)]>;
// Performs a substitution of FloorMod designed to correct for possibly negative
// values. Pseudocode shown below:
//
// T trunc_mod = std::fmod(x, y);
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
// : trunc_mod
//
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),

View File

@@ -710,6 +710,17 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
if dtype in self.signed_int_types:
# Overflow cases.
int_min = np.iinfo(dtype).min
int_max = np.iinfo(dtype).max
self._testBinary(
gen_math_ops.floor_div,
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
expected=np.array([[1, 0, -1, -1], [int_min, 1, -1, -int_max],
[int_min, -1, 1, int_max], [-2, -1, 0, 1]],
dtype=dtype))
def testIntDivision(self):
for dtype in self.signed_int_types:
@@ -731,6 +742,24 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([3, 3, -1, -8], dtype=dtype),
np.array([2, -2, 7, -4], dtype=dtype),
expected=np.array([1, 1, -1, 0], dtype=dtype))
if dtype in self.signed_int_types:
# Overflow cases.
int_min = np.iinfo(dtype).min
int_max = np.iinfo(dtype).max
self._testBinary(
gen_math_ops.floor_mod,
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
expected=np.array([[0, -1, -int_max, -1], [0, 0, 0, 0], [0, 0, 0, 0],
[int_max - 1, int_max - 1, 1, 0]],
dtype=dtype))
self._testBinary(
gen_math_ops.truncate_mod,
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
expected=np.array(
[[0, -1, 1, int_max], [0, 0, 0, 0], [0, 0, 0, 0], [-1, -1, 1, 0]],
dtype=dtype))
def testIntRemainder(self):
for dtype in self.signed_int_types - {np.int8}:

View File

@@ -106,12 +106,11 @@ XLA_MAKE_BINARY(MulNoNan,
//
// For floating-point values, simply returns floor(x / y). For integers, does:
//
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
// T abs_y = std::abs(y);
// return -(abs_x + abs_y - 1) / abs_y;
// z = x / y
// if (z * y != x && (x < 0) != (y < 0)) {
// return z - 1;
// } else {
// return x / y;
// return z;
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
@@ -134,11 +133,10 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
}
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
auto abs_x = xla::Abs(x);
auto abs_y = xla::Abs(y);
auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
auto x_div_y = xla::Div(x, y);
auto round_down = xla::And(xla::Ne(xla::Mul(x_div_y, y), x),
xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)));
return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y);
}
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));

View File

@@ -131,7 +131,15 @@ struct safe_div_or_mod_op {
const T& b) const {
const T safe_b = tensorflow::internal::SubtleMustCopy(b);
if (TF_PREDICT_TRUE(safe_b != 0)) {
return DivOrMod()(a, safe_b);
// Avoid FPE for INT_MIN/-1.
const T safe_a = tensorflow::internal::SubtleMustCopy(a);
if (TF_PREDICT_FALSE(std::is_signed<T>::value &&
safe_a == std::numeric_limits<T>::min() &&
safe_b == T(-1))) {
// Prefer to overflow 'a' instead of crashing.
return DivOrMod()(-safe_a, 1);
}
return DivOrMod()(safe_a, safe_b);
} else {
*error = true;
return 0;
@@ -435,20 +443,10 @@ template <typename T, typename Enable = void>
struct google_floor_div {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
const T& y) const {
if ((x < T(0)) != (y < T(0))) {
// HIP does not have the device version of the abs routine defined
// for all datatypes that T can resolve to
#if defined(__HIP_DEVICE_COMPILE__)
T abs_x = (x < T(0)) ? -x : x;
T abs_y = (y < T(0)) ? -y : y;
#else
T abs_x = std::abs(x);
T abs_y = std::abs(y);
#endif
return -(abs_x + abs_y - 1) / abs_y;
} else {
return x / y;
}
const T z = x / y;
// Subtract one if there is a remainder and if the inputs have opposite
// signs. This approach avoids unnecessary overflows.
return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
@@ -457,11 +455,9 @@ struct google_floor_div {
Packet x_mask = pcmp_lt(x, zeros);
Packet y_mask = pcmp_lt(y, zeros);
Packet x_div_y = pdiv(x, y);
Packet abs_x = pabs(x);
Packet abs_y = pabs(y);
Packet ones = pones(x);
Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y);
return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y);
Packet x_div_y_times_y = pmul(x_div_y, y);
return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
psub(x_div_y, pones(x)));
}
};

View File

@@ -539,13 +539,36 @@ class DivAndModTest(test_util.TensorFlowTestCase):
divs = np.arange(-3, 0, .25).reshape(1, 12)
return nums, divs
def numpySafeFloorDivInt(self, x, y):
z = x // y
# Numpy produces 0 for INT_MIN/-1, but we expect an overflow to INT_MIN
# so that (INT_MIN/-1) + (INT_MIN % -1) = INT_MIN + 0 = INT_MIN.
z[(x == np.iinfo(x.dtype).min) & (y == -1)] = np.iinfo(x.dtype).min
return z
def numpySafeFloorModInt(self, x, y):
# Numpy crashes with a FPE for INT_MIN % -1.
z = self.numpySafeFloorDivInt(x, y)
return x - z * y
def numpySafeTruncateDivInt(self, x, y):
z = self.numpySafeFloorDivInt(x, y)
# Round up if non-zero remainder and inputs have opposite signs.
z[(x != z * y) & ((x < 0) != (y < 0))] += 1
return z
def numpySafeTruncateModInt(self, x, y):
# Numpy crashes with a FPE for INT_MIN % -1.
z = self.numpySafeTruncateDivInt(x, y)
return x - z * y
def testFloorModInt(self):
nums, divs = self.intTestData()
for dtype in [np.int32, np.int64]:
x = nums.astype(dtype)
y = divs.astype(dtype)
tf_result = math_ops.floormod(x, y)
np_result = x % y
np_result = self.numpySafeFloorModInt(x, y)
self.assertAllEqual(tf_result, np_result)
tf2_result = (array_ops.constant(x) % array_ops.constant(y))
self.assertAllEqual(tf2_result, tf_result)
@@ -581,14 +604,20 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = np.fmod(nums, divs)
self.assertAllEqual(tf_result, np_result)
def testDivideInt(self):
def testFloorDivideInt(self):
nums, divs = self.intTestData()
tf_result = math_ops.floor_div(nums, divs)
np_result = nums // divs
np_result = self.numpySafeFloorDivInt(nums, divs)
self.assertAllEqual(tf_result, np_result)
tf2_result = (array_ops.constant(nums) // array_ops.constant(divs))
self.assertAllEqual(tf2_result, tf_result)
def testTruncateDivideInt(self):
nums, divs = self.intTestData()
tf_result = math_ops.truncatediv(nums, divs)
np_result = self.numpySafeTruncateDivInt(nums, divs)
self.assertAllEqual(tf_result, np_result)
@test_util.deprecated_graph_mode_only
def testDivideName(self):
op = math_ops.divide(
@@ -673,6 +702,40 @@ class DivAndModTest(test_util.TensorFlowTestCase):
x = math_ops.divide(5, array_ops.constant(2.0))
self.assertIsInstance(x, ops.Tensor)
def intEdgeTestData(self, dtype):
"""Edge-case test data for integer types."""
nums = np.array([np.iinfo(dtype).min, -1, 1,
np.iinfo(dtype).max],
dtype=dtype).reshape([4, 1])
divs = nums.reshape([1, 4])
return nums, divs
def testFloorDivModIntEdges(self):
for dtype in [np.int32, np.int64]:
x, y = self.intEdgeTestData(dtype)
tf_floor_div = math_ops.floor_div(x, y)
np_floor_div = self.numpySafeFloorDivInt(x, y)
self.assertAllEqual(tf_floor_div, np_floor_div)
tf_floor_mod = math_ops.floormod(x, y)
np_floor_mod = self.numpySafeFloorModInt(x, y)
self.assertAllEqual(tf_floor_mod, np_floor_mod)
z = math_ops.add(math_ops.multiply(tf_floor_div, y), tf_floor_mod)
# x = floor_div(x, y) * y + floor_mod(x, y)
self.assertAllEqual(z, np.broadcast_to(x, z.shape))
def testTruncateDivModIntEdges(self):
for dtype in [np.int32, np.int64]:
x, y = self.intEdgeTestData(dtype)
tf_truncate_div = math_ops.truncatediv(x, y)
np_truncate_div = self.numpySafeTruncateDivInt(x, y)
self.assertAllEqual(tf_truncate_div, np_truncate_div)
tf_truncate_mod = math_ops.truncatemod(x, y)
np_truncate_mod = self.numpySafeTruncateModInt(x, y)
self.assertAllEqual(tf_truncate_mod, np_truncate_mod)
z = math_ops.add(math_ops.multiply(tf_truncate_div, y), tf_truncate_mod)
# x = truncatediv(x, y) * y + truncatemod(x, y)
self.assertAllEqual(z, np.broadcast_to(x, z.shape))
@test_util.run_all_in_graph_and_eager_modes
class DivNoNanTest(test_util.TensorFlowTestCase, parameterized.TestCase):