mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Fix integer div/mod overflow issues for Eigen, XLA, MLIR.
The original implementations of `google_floor_div`, XLA `FloorDiv` and MLIR `TF_FloorDivOp` all overflowed when `abs(x) + abs(y) > INT_MAX`. The new implementation avoids this by computing the truncated quotient `z = x / y` and subtracting one only when there is a remainder and the operands have opposite signs. Also fixed a floating-point exception (FPE) in the Eigen backend when dividing `INT_MIN / -1`, which previously caused TF to crash. We now deliberately let the result overflow so that `INT_MIN / -1 = INT_MIN`. This is consistent with the `int16` and `int8` types (which already overflow without crashing in C++), and maintains the equality `x = div(x, y) * y + mod(x, y)`. Also added tests for these edge cases. Fixes #46887, #45771. PiperOrigin-RevId: 372993714 Change-Id: If38fdd9e53d562011e529c58b415786ba015a0be
This commit is contained in:
committed by
TensorFlower Gardener
parent
bf5b1056fc
commit
b47be308c4
@@ -740,21 +740,18 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @floordiv_broadcast_i32
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
return %0: tensor<2x3xi32>
|
||||
@@ -762,21 +759,18 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
|
||||
|
||||
// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
|
||||
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0: tensor<2x3xi32>
|
||||
@@ -814,21 +808,18 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
|
||||
|
||||
// CHECK-LABEL: func @floordiv_dynamic
|
||||
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
|
||||
return %0: tensor<?x?xi32>
|
||||
|
||||
@@ -126,13 +126,8 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
|
||||
// additional correction for a negative numerator / denominator. Equivalent
|
||||
// pseudocode is shown below:
|
||||
//
|
||||
// if ((x < 0) != (y < 0)) {
|
||||
// T abs_x = std::abs(x);
|
||||
// T abs_y = std::abs(y);
|
||||
// return -(abs_x + abs_y - 1) / abs_y;
|
||||
// } else {
|
||||
// return x / y;
|
||||
// }
|
||||
// T z = x / y
|
||||
// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z
|
||||
//
|
||||
// BroadcastToDimensions is used to compute the broadcast attr to higher
|
||||
// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r')
|
||||
@@ -142,31 +137,39 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
|
||||
// Requires static shaped inputs to create constant splats and computation of
|
||||
// broadcast attributes.
|
||||
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
(HLO_SelectOp
|
||||
(HLOClient_BroadcastCompareOp
|
||||
(HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
|
||||
(HLO_SelectOp
|
||||
(HLOClient_BroadcastAndOp
|
||||
(HLOClient_BroadcastCompareOp
|
||||
(HLOClient_BroadcastMulOp:$mul
|
||||
(HLOClient_BroadcastDivOp:$div $l, $r,
|
||||
(BinBroadcastDimensions $l, $r)),
|
||||
$r, (BinBroadcastDimensions $div, $r)),
|
||||
$l, (BinBroadcastDimensions $mul, $l), HLO_COMPARISON_DIRECTION_NE,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(HLOClient_BroadcastCompareOp
|
||||
(HLOClient_BroadcastCompareOp:$l_cmp $l,
|
||||
(HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
|
||||
(HLOClient_BroadcastDivOp
|
||||
(HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
|
||||
(HLOClient_BroadcastSubOp (HLO_AbsOp $r),
|
||||
(HLO_ConstOp (GetScalarOfType<1> $r)),
|
||||
(NullDenseIntElementsAttr)),
|
||||
(BinBroadcastDimensions $l, $r))),
|
||||
(HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
|
||||
[(SignedIntTensor $l)]>;
|
||||
(HLOClient_BroadcastCompareOp:$r_cmp $r,
|
||||
(HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(BinBroadcastDimensions $l_cmp, $r_cmp), HLO_COMPARISON_DIRECTION_NE,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
(NullDenseIntElementsAttr)),
|
||||
(HLOClient_BroadcastSubOp $div,
|
||||
(HLO_ConstOp:$ones (GetScalarOfType<1> $div)),
|
||||
(NullDenseIntElementsAttr)), $div),
|
||||
[(SignedIntTensor $l)]>;
|
||||
|
||||
// Performs a substitution of FloorMod designed to correct for possibly negative
|
||||
// values. Pseudocode shown below:
|
||||
//
|
||||
// T trunc_mod = std::fmod(x, y);
|
||||
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
||||
// : trunc_mod
|
||||
//
|
||||
// Requires static shaped inputs to create constant splats and computation of
|
||||
// broadcast attributes.
|
||||
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
|
||||
|
||||
@@ -710,6 +710,17 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, 2, -4], dtype=dtype),
|
||||
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
|
||||
if dtype in self.signed_int_types:
|
||||
# Overflow cases.
|
||||
int_min = np.iinfo(dtype).min
|
||||
int_max = np.iinfo(dtype).max
|
||||
self._testBinary(
|
||||
gen_math_ops.floor_div,
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
|
||||
expected=np.array([[1, 0, -1, -1], [int_min, 1, -1, -int_max],
|
||||
[int_min, -1, 1, int_max], [-2, -1, 0, 1]],
|
||||
dtype=dtype))
|
||||
|
||||
def testIntDivision(self):
|
||||
for dtype in self.signed_int_types:
|
||||
@@ -731,6 +742,24 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([3, 3, -1, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, -4], dtype=dtype),
|
||||
expected=np.array([1, 1, -1, 0], dtype=dtype))
|
||||
if dtype in self.signed_int_types:
|
||||
# Overflow cases.
|
||||
int_min = np.iinfo(dtype).min
|
||||
int_max = np.iinfo(dtype).max
|
||||
self._testBinary(
|
||||
gen_math_ops.floor_mod,
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
|
||||
expected=np.array([[0, -1, -int_max, -1], [0, 0, 0, 0], [0, 0, 0, 0],
|
||||
[int_max - 1, int_max - 1, 1, 0]],
|
||||
dtype=dtype))
|
||||
self._testBinary(
|
||||
gen_math_ops.truncate_mod,
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
|
||||
np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
|
||||
expected=np.array(
|
||||
[[0, -1, 1, int_max], [0, 0, 0, 0], [0, 0, 0, 0], [-1, -1, 1, 0]],
|
||||
dtype=dtype))
|
||||
|
||||
def testIntRemainder(self):
|
||||
for dtype in self.signed_int_types - {np.int8}:
|
||||
|
||||
@@ -106,12 +106,11 @@ XLA_MAKE_BINARY(MulNoNan,
|
||||
//
|
||||
// For floating-point values, simply returns floor(x / y). For integers, does:
|
||||
//
|
||||
// if ((x < 0) != (y < 0)) {
|
||||
// T abs_x = std::abs(x);
|
||||
// T abs_y = std::abs(y);
|
||||
// return -(abs_x + abs_y - 1) / abs_y;
|
||||
// z = x / y
|
||||
// if (z * y != x && (x < 0) != (y < 0)) {
|
||||
// return z - 1;
|
||||
// } else {
|
||||
// return x / y;
|
||||
// return z;
|
||||
// }
|
||||
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
xla::XlaOp y, const BCast& broadcast_helper) {
|
||||
@@ -134,11 +133,10 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
}
|
||||
auto zero = XlaHelpers::Zero(b, dtype);
|
||||
auto one = XlaHelpers::One(b, dtype);
|
||||
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
|
||||
auto abs_x = xla::Abs(x);
|
||||
auto abs_y = xla::Abs(y);
|
||||
auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
|
||||
return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
|
||||
auto x_div_y = xla::Div(x, y);
|
||||
auto round_down = xla::And(xla::Ne(xla::Mul(x_div_y, y), x),
|
||||
xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)));
|
||||
return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y);
|
||||
}
|
||||
XLA_MAKE_BINARY(FloorDiv,
|
||||
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||
|
||||
@@ -131,7 +131,15 @@ struct safe_div_or_mod_op {
|
||||
const T& b) const {
|
||||
const T safe_b = tensorflow::internal::SubtleMustCopy(b);
|
||||
if (TF_PREDICT_TRUE(safe_b != 0)) {
|
||||
return DivOrMod()(a, safe_b);
|
||||
// Avoid FPE for INT_MIN/-1.
|
||||
const T safe_a = tensorflow::internal::SubtleMustCopy(a);
|
||||
if (TF_PREDICT_FALSE(std::is_signed<T>::value &&
|
||||
safe_a == std::numeric_limits<T>::min() &&
|
||||
safe_b == T(-1))) {
|
||||
// Prefer to overflow 'a' instead of crashing.
|
||||
return DivOrMod()(-safe_a, 1);
|
||||
}
|
||||
return DivOrMod()(safe_a, safe_b);
|
||||
} else {
|
||||
*error = true;
|
||||
return 0;
|
||||
@@ -435,20 +443,10 @@ template <typename T, typename Enable = void>
|
||||
struct google_floor_div {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
|
||||
const T& y) const {
|
||||
if ((x < T(0)) != (y < T(0))) {
|
||||
// HIP does not have the device version of the abs routine defined
|
||||
// for all datatypes that T can resolve to
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
T abs_x = (x < T(0)) ? -x : x;
|
||||
T abs_y = (y < T(0)) ? -y : y;
|
||||
#else
|
||||
T abs_x = std::abs(x);
|
||||
T abs_y = std::abs(y);
|
||||
#endif
|
||||
return -(abs_x + abs_y - 1) / abs_y;
|
||||
} else {
|
||||
return x / y;
|
||||
}
|
||||
const T z = x / y;
|
||||
// Subtract one if there is a remainder and if the inputs have opposite
|
||||
// signs. This approach avoids unnecessary overflows.
|
||||
return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z;
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
|
||||
@@ -457,11 +455,9 @@ struct google_floor_div {
|
||||
Packet x_mask = pcmp_lt(x, zeros);
|
||||
Packet y_mask = pcmp_lt(y, zeros);
|
||||
Packet x_div_y = pdiv(x, y);
|
||||
Packet abs_x = pabs(x);
|
||||
Packet abs_y = pabs(y);
|
||||
Packet ones = pones(x);
|
||||
Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y);
|
||||
return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y);
|
||||
Packet x_div_y_times_y = pmul(x_div_y, y);
|
||||
return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
|
||||
psub(x_div_y, pones(x)));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -539,13 +539,36 @@ class DivAndModTest(test_util.TensorFlowTestCase):
|
||||
divs = np.arange(-3, 0, .25).reshape(1, 12)
|
||||
return nums, divs
|
||||
|
||||
def numpySafeFloorDivInt(self, x, y):
|
||||
z = x // y
|
||||
# Numpy produces 0 for INT_MIN/-1, but we expect an overflow to INT_MIN
|
||||
# so that (INT_MIN/-1) + (INT_MIN % -1) = INT_MIN + 0 = INT_MIN.
|
||||
z[(x == np.iinfo(x.dtype).min) & (y == -1)] = np.iinfo(x.dtype).min
|
||||
return z
|
||||
|
||||
def numpySafeFloorModInt(self, x, y):
|
||||
# Numpy crashes with a FPE for INT_MIN % -1.
|
||||
z = self.numpySafeFloorDivInt(x, y)
|
||||
return x - z * y
|
||||
|
||||
def numpySafeTruncateDivInt(self, x, y):
|
||||
z = self.numpySafeFloorDivInt(x, y)
|
||||
# Round up if non-zero remainder and inputs have opposite signs.
|
||||
z[(x != z * y) & ((x < 0) != (y < 0))] += 1
|
||||
return z
|
||||
|
||||
def numpySafeTruncateModInt(self, x, y):
|
||||
# Numpy crashes with a FPE for INT_MIN % -1.
|
||||
z = self.numpySafeTruncateDivInt(x, y)
|
||||
return x - z * y
|
||||
|
||||
def testFloorModInt(self):
|
||||
nums, divs = self.intTestData()
|
||||
for dtype in [np.int32, np.int64]:
|
||||
x = nums.astype(dtype)
|
||||
y = divs.astype(dtype)
|
||||
tf_result = math_ops.floormod(x, y)
|
||||
np_result = x % y
|
||||
np_result = self.numpySafeFloorModInt(x, y)
|
||||
self.assertAllEqual(tf_result, np_result)
|
||||
tf2_result = (array_ops.constant(x) % array_ops.constant(y))
|
||||
self.assertAllEqual(tf2_result, tf_result)
|
||||
@@ -581,14 +604,20 @@ class DivAndModTest(test_util.TensorFlowTestCase):
|
||||
np_result = np.fmod(nums, divs)
|
||||
self.assertAllEqual(tf_result, np_result)
|
||||
|
||||
def testDivideInt(self):
|
||||
def testFloorDivideInt(self):
|
||||
nums, divs = self.intTestData()
|
||||
tf_result = math_ops.floor_div(nums, divs)
|
||||
np_result = nums // divs
|
||||
np_result = self.numpySafeFloorDivInt(nums, divs)
|
||||
self.assertAllEqual(tf_result, np_result)
|
||||
tf2_result = (array_ops.constant(nums) // array_ops.constant(divs))
|
||||
self.assertAllEqual(tf2_result, tf_result)
|
||||
|
||||
def testTruncateDivideInt(self):
|
||||
nums, divs = self.intTestData()
|
||||
tf_result = math_ops.truncatediv(nums, divs)
|
||||
np_result = self.numpySafeTruncateDivInt(nums, divs)
|
||||
self.assertAllEqual(tf_result, np_result)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testDivideName(self):
|
||||
op = math_ops.divide(
|
||||
@@ -673,6 +702,40 @@ class DivAndModTest(test_util.TensorFlowTestCase):
|
||||
x = math_ops.divide(5, array_ops.constant(2.0))
|
||||
self.assertIsInstance(x, ops.Tensor)
|
||||
|
||||
def intEdgeTestData(self, dtype):
|
||||
"""Edge-case test data for integer types."""
|
||||
nums = np.array([np.iinfo(dtype).min, -1, 1,
|
||||
np.iinfo(dtype).max],
|
||||
dtype=dtype).reshape([4, 1])
|
||||
divs = nums.reshape([1, 4])
|
||||
return nums, divs
|
||||
|
||||
def testFloorDivModIntEdges(self):
|
||||
for dtype in [np.int32, np.int64]:
|
||||
x, y = self.intEdgeTestData(dtype)
|
||||
tf_floor_div = math_ops.floor_div(x, y)
|
||||
np_floor_div = self.numpySafeFloorDivInt(x, y)
|
||||
self.assertAllEqual(tf_floor_div, np_floor_div)
|
||||
tf_floor_mod = math_ops.floormod(x, y)
|
||||
np_floor_mod = self.numpySafeFloorModInt(x, y)
|
||||
self.assertAllEqual(tf_floor_mod, np_floor_mod)
|
||||
z = math_ops.add(math_ops.multiply(tf_floor_div, y), tf_floor_mod)
|
||||
# x = floor_div(x, y) * y + floor_mod(x, y)
|
||||
self.assertAllEqual(z, np.broadcast_to(x, z.shape))
|
||||
|
||||
def testTruncateDivModIntEdges(self):
|
||||
for dtype in [np.int32, np.int64]:
|
||||
x, y = self.intEdgeTestData(dtype)
|
||||
tf_truncate_div = math_ops.truncatediv(x, y)
|
||||
np_truncate_div = self.numpySafeTruncateDivInt(x, y)
|
||||
self.assertAllEqual(tf_truncate_div, np_truncate_div)
|
||||
tf_truncate_mod = math_ops.truncatemod(x, y)
|
||||
np_truncate_mod = self.numpySafeTruncateModInt(x, y)
|
||||
self.assertAllEqual(tf_truncate_mod, np_truncate_mod)
|
||||
z = math_ops.add(math_ops.multiply(tf_truncate_div, y), tf_truncate_mod)
|
||||
# x = truncatediv(x, y) * y + truncatemod(x, y)
|
||||
self.assertAllEqual(z, np.broadcast_to(x, z.shape))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DivNoNanTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user