diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 735a407640b..92e4284a296 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -740,21 +740,18 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
 // CHECK-LABEL: func @floordiv_broadcast_i32
 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
   // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
   // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
-  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
-  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
   // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
-  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
-  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
   // CHECK: return [[SELECT]]
   %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
   return %0: tensor<2x3xi32>
@@ -762,21 +759,18 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
 // CHECK-LABEL: func @floordiv_reverse_broadcast_i32
 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
   // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
   // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
-  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
-  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
   // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
-  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
-  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
-  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
   // CHECK: return [[SELECT]]
   %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0: tensor<2x3xi32>
@@ -814,21 +808,18 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
 // CHECK-LABEL: func @floordiv_dynamic
 func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
   // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
   // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
-  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
-  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
   // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
-  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
-  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
   // CHECK: return [[SELECT]]
   %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
   return %0: tensor<?x?xi32>
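The rewritten CHECK-DAG lines above encode the new lowering for integer `tf.FloorDiv`: take the truncating quotient, then subtract one exactly when the quotient times the divisor misses the dividend and the operand signs disagree. A minimal Python model of that dataflow (helper names are mine, not TF or MLIR APIs), cross-checked against Python's built-in floor division:

```python
def trunc_div(x, y):
    """C-style integer division: rounds toward zero, as mhlo divide does."""
    q = abs(x) // abs(y)
    return -q if (x < 0) != (y < 0) else q

def floor_div_lowered(x, y):
    div = trunc_div(x, y)               # [[DIV]]
    rem_nonzero = div * y != x          # [[MUL]], [[CMP1]]
    signs_differ = (x < 0) != (y < 0)   # [[CMP2]]..[[CMP4]]
    # [[AND]], [[SUB]], [[SELECT]]
    return div - 1 if rem_nonzero and signs_differ else div

assert all(floor_div_lowered(x, y) == x // y
           for x in range(-25, 26) for y in range(-25, 26) if y != 0)
```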
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index a3685110730..fdf7e306d33 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -126,13 +126,8 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
 // additional correction for a negative numerator / denominator. Equivalent
 // pseudocode is shown below:
 //
-// if ((x < 0) != (y < 0)) {
-//   T abs_x = std::abs(x);
-//   T abs_y = std::abs(y);
-//   return -(abs_x + abs_y - 1) / abs_y;
-// } else {
-//   return x / y;
-// }
+// T z = x / y
+// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z
 //
 // BroadcastToDimensions is used to compute the broadcast attr to higher
 // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r')
@@ -142,31 +137,39 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
 // Requires static shaped inputs to create constant splats and computation of
 // broadcast attributes.
 def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
-          (HLO_SelectOp
-           (HLOClient_BroadcastCompareOp
-            (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
-             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
-             (HLO_DEFAULT_COMPARISON_TYPE)),
-            (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
-             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
-             (HLO_DEFAULT_COMPARISON_TYPE)),
-            (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
+          (HLO_SelectOp
+           (HLOClient_BroadcastAndOp
+            (HLOClient_BroadcastCompareOp
+             (HLOClient_BroadcastMulOp:$mul
+              (HLOClient_BroadcastDivOp:$div $l, $r,
+               (BinBroadcastDimensions $l, $r)),
+              $r, (BinBroadcastDimensions $div, $r)),
+             $l, (BinBroadcastDimensions $mul, $l), HLO_COMPARISON_DIRECTION_NE,
+             (HLO_DEFAULT_COMPARISON_TYPE)),
+            (HLOClient_BroadcastCompareOp
+             (HLOClient_BroadcastCompareOp:$l_cmp $l,
+              (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
+              (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
              (HLO_DEFAULT_COMPARISON_TYPE)),
-           (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
-           (HLOClient_BroadcastDivOp
-            (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
-             (HLOClient_BroadcastSubOp (HLO_AbsOp $r),
-              (HLO_ConstOp (GetScalarOfType<1> $r)),
-              (NullDenseIntElementsAttr)),
-             (BinBroadcastDimensions $l, $r))),
-            (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
-          [(SignedIntTensor $l)]>;
+             (HLOClient_BroadcastCompareOp:$r_cmp $r,
+              (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
+              (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
+              (HLO_DEFAULT_COMPARISON_TYPE)),
+             (BinBroadcastDimensions $l_cmp, $r_cmp), HLO_COMPARISON_DIRECTION_NE,
+             (HLO_DEFAULT_COMPARISON_TYPE)),
+            (NullDenseIntElementsAttr)),
+           (HLOClient_BroadcastSubOp $div,
+            (HLO_ConstOp:$ones (GetScalarOfType<1> $div)),
+            (NullDenseIntElementsAttr)), $div),
+          [(SignedIntTensor $l)]>;
 
 // Performs a substitution of FloorMod designed to correct for possibly negative
 // values. Pseudocode shown below:
 //
 // T trunc_mod = std::fmod(x, y);
 // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
+//                                                   : trunc_mod
+//
 // Requires static shaped inputs to create constant splats and computation of
 // broadcast attributes.
 def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
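Why the pattern was rewritten, as a sketch in plain Python (exact integers, so the overflow is asserted rather than triggered): the old `-(|x| + |y| - 1) / |y|` expansion computes an intermediate that can leave the i32 range even when the true quotient is representable, and `abs(INT_MIN)` is itself unrepresentable. The new form only ever computes `z` and `z - 1`, and the `z - 1` branch is taken only when `|y| >= 2`, so `|z| <= |x| / 2` and the correction cannot overflow:

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

x, y = INT32_MIN + 1, 7              # signs differ, quotient representable
t = abs(x) + abs(y) - 1              # old intermediate
assert t > INT32_MAX                 # would overflow an i32 tensor
assert abs(INT32_MIN) > INT32_MAX    # abs alone already overflows

z = -(abs(x) // abs(y))              # truncating division (signs differ)
z = z - 1 if z * y != x else z       # new correction
assert z == x // y and INT32_MIN <= z <= INT32_MAX
```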
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index b8abbcb8035..2e54b17807c 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -710,6 +710,17 @@ class BinaryOpsTest(xla_test.XLATestCase):
           np.array([3, 3, -1, -9, -8], dtype=dtype),
           np.array([2, -2, 7, 2, -4], dtype=dtype),
           expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
+      if dtype in self.signed_int_types:
+        # Overflow cases.
+        int_min = np.iinfo(dtype).min
+        int_max = np.iinfo(dtype).max
+        self._testBinary(
+            gen_math_ops.floor_div,
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
+            expected=np.array([[1, 0, -1, -1], [int_min, 1, -1, -int_max],
+                               [int_min, -1, 1, int_max], [-2, -1, 0, 1]],
+                              dtype=dtype))
 
   def testIntDivision(self):
     for dtype in self.signed_int_types:
@@ -731,6 +742,24 @@ class BinaryOpsTest(xla_test.XLATestCase):
           np.array([3, 3, -1, -8], dtype=dtype),
           np.array([2, -2, 7, -4], dtype=dtype),
           expected=np.array([1, 1, -1, 0], dtype=dtype))
+      if dtype in self.signed_int_types:
+        # Overflow cases.
+        int_min = np.iinfo(dtype).min
+        int_max = np.iinfo(dtype).max
+        self._testBinary(
+            gen_math_ops.floor_mod,
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
+            expected=np.array([[0, -1, -int_max, -1], [0, 0, 0, 0], [0, 0, 0, 0],
+                               [int_max - 1, int_max - 1, 1, 0]],
+                              dtype=dtype))
+        self._testBinary(
+            gen_math_ops.truncate_mod,
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([1, 4]),
+            np.array([int_min, -1, 1, int_max], dtype=dtype).reshape([4, 1]),
+            expected=np.array(
+                [[0, -1, 1, int_max], [0, 0, 0, 0], [0, 0, 0, 0], [-1, -1, 1, 0]],
+                dtype=dtype))
 
   def testIntRemainder(self):
     for dtype in self.signed_int_types - {np.int8}:
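The expected floor_div matrix above can be reproduced with exact Python integers wrapped to the target dtype; this is a sketch (the `wrap` helper is mine), reflecting that the test wants `INT_MIN // -1` to wrap back to `INT_MIN` rather than trap. With the `[1, 4]` and `[4, 1]` operand shapes, entry `[i][j]` is `floor_div(vals[j], vals[i])`:

```python
import numpy as np

dtype = np.int32
int_min, int_max = np.iinfo(dtype).min, np.iinfo(dtype).max
vals = [int_min, -1, 1, int_max]

def wrap(v, bits=32):
    """Reduce an exact integer to two's-complement range."""
    return (v + 2**(bits - 1)) % 2**bits - 2**(bits - 1)

table = [[wrap(x // y) for x in vals] for y in vals]
assert table == [[1, 0, -1, -1],
                 [int_min, 1, -1, -int_max],
                 [int_min, -1, 1, int_max],
                 [-2, -1, 0, 1]]
```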
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 59241051df1..d1b5feec8b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -106,12 +106,11 @@ XLA_MAKE_BINARY(MulNoNan,
 //
 // For floating-point values, simply returns floor(x / y). For integers, does:
 //
-// if ((x < 0) != (y < 0)) {
-//   T abs_x = std::abs(x);
-//   T abs_y = std::abs(y);
-//   return -(abs_x + abs_y - 1) / abs_y;
+// z = x / y
+// if (z * y != x && (x < 0) != (y < 0)) {
+//   return z - 1;
 // } else {
-//   return x / y;
+//   return z;
 // }
 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
@@ -134,11 +133,10 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
   }
   auto zero = XlaHelpers::Zero(b, dtype);
   auto one = XlaHelpers::One(b, dtype);
-  auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
-  auto abs_x = xla::Abs(x);
-  auto abs_y = xla::Abs(y);
-  auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
-  return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
+  auto x_div_y = xla::Div(x, y);
+  auto round_down = xla::And(xla::Ne(xla::Mul(x_div_y, y), x),
+                             xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)));
+  return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y);
 }
 XLA_MAKE_BINARY(FloorDiv,
                 FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
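Note that `FloorDivImpl` now evaluates `x / y` unconditionally, so the `INT_MIN / -1` case rests on the quotient wrapping to `INT_MIN`, at which point `z * y == x` holds again and no correction fires; this is the convention the new `binary_ops_test.py` cases assert. A Python sketch of that arithmetic under explicit 32-bit wrapping (helper names are mine; the SIGFPE concern for eager CPU kernels is handled separately in `cwise_ops.h` below):

```python
def wrap32(v):
    """Two's-complement wrap to 32 bits."""
    return (v + 2**31) % 2**32 - 2**31

def c_div32(x, y):
    """Truncating 32-bit division, wrapping instead of trapping."""
    q = abs(x) // abs(y)
    return wrap32(-q if (x < 0) != (y < 0) else q)

def floor_div_impl32(x, y):
    z = c_div32(x, y)
    round_down = wrap32(z * y) != x and (x < 0) != (y < 0)
    return wrap32(z - 1) if round_down else z

INT32_MIN = -2**31
assert floor_div_impl32(INT32_MIN, -1) == INT32_MIN  # wraps, no correction
assert floor_div_impl32(-7, 2) == -4                 # ordinary round toward -inf
```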
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index fac2a418c24..dbeb335180c 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -131,7 +131,15 @@ struct safe_div_or_mod_op {
                                                         const T& b) const {
     const T safe_b = tensorflow::internal::SubtleMustCopy(b);
     if (TF_PREDICT_TRUE(safe_b != 0)) {
-      return DivOrMod()(a, safe_b);
+      // Avoid FPE for INT_MIN/-1.
+      const T safe_a = tensorflow::internal::SubtleMustCopy(a);
+      if (TF_PREDICT_FALSE(std::is_signed<T>::value &&
+                           safe_a == std::numeric_limits<T>::min() &&
+                           safe_b == T(-1))) {
+        // Prefer to overflow 'a' instead of crashing.
+        return DivOrMod()(-safe_a, 1);
+      }
+      return DivOrMod()(safe_a, safe_b);
     } else {
       *error = true;
       return 0;
@@ -435,20 +443,10 @@ template <typename T>
 struct google_floor_div {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
                                                      const T& y) const {
-    if ((x < T(0)) != (y < T(0))) {
-// HIP does not have the device version of the abs routine defined
-// for all datatypes that T can resolve to
-#if defined(__HIP_DEVICE_COMPILE__)
-      T abs_x = (x < T(0)) ? -x : x;
-      T abs_y = (y < T(0)) ? -y : y;
-#else
-      T abs_x = std::abs(x);
-      T abs_y = std::abs(y);
-#endif
-      return -(abs_x + abs_y - 1) / abs_y;
-    } else {
-      return x / y;
-    }
+    const T z = x / y;
+    // Subtract one if there is a remainder and if the inputs have opposite
+    // signs. This approach avoids unnecessary overflows.
+    return z * y != x && (x < T(0)) != (y < T(0)) ? z - T(1) : z;
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
@@ -457,11 +455,9 @@ struct google_floor_div {
     Packet zeros = pzero(x);
     Packet x_mask = pcmp_lt(x, zeros);
     Packet y_mask = pcmp_lt(y, zeros);
     Packet x_div_y = pdiv(x, y);
-    Packet abs_x = pabs(x);
-    Packet abs_y = pabs(y);
-    Packet ones = pones(x);
-    Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y);
-    return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y);
+    Packet x_div_y_times_y = pmul(x_div_y, y);
+    return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
+                   psub(x_div_y, pones(x)));
   }
 };
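The `packetOp` rewrite uses the complemented predicate: instead of selecting `z - 1` when `z * y != x` AND the signs differ, it selects `z` when `z * y == x` OR the signs agree, which maps directly onto `peq`/`por`/`pselect` with the already-computed sign masks. A numpy sketch (not Eigen code) checking that both forms agree with each other and with true floor division on random operands:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(-1000, 1000, size=10_000)
y = rng.integers(1, 1000, size=10_000) * rng.choice([-1, 1], size=10_000)
z = np.sign(x) * np.sign(y) * (np.abs(x) // np.abs(y))  # truncating, like pdiv

ternary = np.where((z * y != x) & ((x < 0) != (y < 0)), z - 1, z)  # scalar form
packet = np.where((z * y == x) | ((x < 0) == (y < 0)), z, z - 1)   # packet form
assert (ternary == packet).all() and (ternary == x // y).all()
```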
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 6b3a8c29857..8d3fa12e9ce 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -539,13 +539,36 @@ class DivAndModTest(test_util.TensorFlowTestCase):
     divs = np.arange(-3, 0, .25).reshape(1, 12)
     return nums, divs
 
+  def numpySafeFloorDivInt(self, x, y):
+    z = x // y
+    # Numpy produces 0 for INT_MIN/-1, but we expect an overflow to INT_MIN
+    # so that (INT_MIN/-1) + (INT_MIN % -1) = INT_MIN + 0 = INT_MIN.
+    z[(x == np.iinfo(x.dtype).min) & (y == -1)] = np.iinfo(x.dtype).min
+    return z
+
+  def numpySafeFloorModInt(self, x, y):
+    # Numpy crashes with a FPE for INT_MIN % -1.
+    z = self.numpySafeFloorDivInt(x, y)
+    return x - z * y
+
+  def numpySafeTruncateDivInt(self, x, y):
+    z = self.numpySafeFloorDivInt(x, y)
+    # Round up if non-zero remainder and inputs have opposite signs.
+    z[(x != z * y) & ((x < 0) != (y < 0))] += 1
+    return z
+
+  def numpySafeTruncateModInt(self, x, y):
+    # Numpy crashes with a FPE for INT_MIN % -1.
+    z = self.numpySafeTruncateDivInt(x, y)
+    return x - z * y
+
   def testFloorModInt(self):
     nums, divs = self.intTestData()
     for dtype in [np.int32, np.int64]:
       x = nums.astype(dtype)
       y = divs.astype(dtype)
       tf_result = math_ops.floormod(x, y)
-      np_result = x % y
+      np_result = self.numpySafeFloorModInt(x, y)
       self.assertAllEqual(tf_result, np_result)
       tf2_result = (array_ops.constant(x) % array_ops.constant(y))
       self.assertAllEqual(tf2_result, tf_result)
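On ordinary (non-edge) values, the four helpers above encode the usual relationship between the two rounding families: truncation rounds the floor quotient up by one when there is a remainder and the signs differ, and each mod flavor is defined so that `q * y + r == x`, with the floor remainder taking the divisor's sign and the truncated remainder the dividend's. A small numpy illustration (not part of the patch):

```python
import numpy as np

x = np.array([-7, 7, -7], dtype=np.int32)
y = np.array([2, -2, -2], dtype=np.int32)

q_floor = x // y                   # [-4, -4,  3]
r_floor = x - q_floor * y          # [ 1, -1, -1]  sign follows y
q_trunc = q_floor + ((x != q_floor * y) & ((x < 0) != (y < 0)))
r_trunc = x - q_trunc * y          # [-1,  1, -1]  sign follows x
assert ((q_floor * y + r_floor == x) & (q_trunc * y + r_trunc == x)).all()
```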
@@ -581,14 +604,20 @@ class DivAndModTest(test_util.TensorFlowTestCase):
       np_result = np.fmod(nums, divs)
       self.assertAllEqual(tf_result, np_result)
 
-  def testDivideInt(self):
+  def testFloorDivideInt(self):
     nums, divs = self.intTestData()
     tf_result = math_ops.floor_div(nums, divs)
-    np_result = nums // divs
+    np_result = self.numpySafeFloorDivInt(nums, divs)
     self.assertAllEqual(tf_result, np_result)
     tf2_result = (array_ops.constant(nums) // array_ops.constant(divs))
     self.assertAllEqual(tf2_result, tf_result)
 
+  def testTruncateDivideInt(self):
+    nums, divs = self.intTestData()
+    tf_result = math_ops.truncatediv(nums, divs)
+    np_result = self.numpySafeTruncateDivInt(nums, divs)
+    self.assertAllEqual(tf_result, np_result)
+
   @test_util.deprecated_graph_mode_only
   def testDivideName(self):
     op = math_ops.divide(
@@ -673,6 +702,40 @@ class DivAndModTest(test_util.TensorFlowTestCase):
     x = math_ops.divide(5, array_ops.constant(2.0))
     self.assertIsInstance(x, ops.Tensor)
 
+  def intEdgeTestData(self, dtype):
+    """Edge-case test data for integer types."""
+    nums = np.array([np.iinfo(dtype).min, -1, 1,
+                     np.iinfo(dtype).max],
+                    dtype=dtype).reshape([4, 1])
+    divs = nums.reshape([1, 4])
+    return nums, divs
+
+  def testFloorDivModIntEdges(self):
+    for dtype in [np.int32, np.int64]:
+      x, y = self.intEdgeTestData(dtype)
+      tf_floor_div = math_ops.floor_div(x, y)
+      np_floor_div = self.numpySafeFloorDivInt(x, y)
+      self.assertAllEqual(tf_floor_div, np_floor_div)
+      tf_floor_mod = math_ops.floormod(x, y)
+      np_floor_mod = self.numpySafeFloorModInt(x, y)
+      self.assertAllEqual(tf_floor_mod, np_floor_mod)
+      z = math_ops.add(math_ops.multiply(tf_floor_div, y), tf_floor_mod)
+      # x = floor_div(x, y) * y + floor_mod(x, y)
+      self.assertAllEqual(z, np.broadcast_to(x, z.shape))
+
+  def testTruncateDivModIntEdges(self):
+    for dtype in [np.int32, np.int64]:
+      x, y = self.intEdgeTestData(dtype)
+      tf_truncate_div = math_ops.truncatediv(x, y)
+      np_truncate_div = self.numpySafeTruncateDivInt(x, y)
+      self.assertAllEqual(tf_truncate_div, np_truncate_div)
+      tf_truncate_mod = math_ops.truncatemod(x, y)
+      np_truncate_mod = self.numpySafeTruncateModInt(x, y)
+      self.assertAllEqual(tf_truncate_mod, np_truncate_mod)
+      z = math_ops.add(math_ops.multiply(tf_truncate_div, y), tf_truncate_mod)
+      # x = truncatediv(x, y) * y + truncatemod(x, y)
+      self.assertAllEqual(z, np.broadcast_to(x, z.shape))
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class DivNoNanTest(test_util.TensorFlowTestCase, parameterized.TestCase):
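Taken together, the new edge tests assert one invariant: with `INT_MIN // -1` defined to wrap to `INT_MIN`, the reconstruction `q * y + r == x` holds modulo 2^32 at every edge pair, for both rounding modes. A standalone numpy sketch of that invariant (exact int64 arithmetic wrapped to the int32 range at each step; `wrap32` is my helper, not a TF API):

```python
import numpy as np

def wrap32(v):
    """Two's-complement wrap of exact int64 values to the int32 range."""
    return (v + 2**31) % 2**32 - 2**31

info = np.iinfo(np.int32)
vals = np.array([info.min, -1, 1, info.max], dtype=np.int64)
x, y = vals.reshape(4, 1), vals.reshape(1, 4)

q = wrap32(x // y)        # floor quotient; INT_MIN // -1 wraps to INT_MIN
r = wrap32(x - q * y)     # floor mod derived from the safe quotient
assert (wrap32(q * y + r) == x).all()

t = q + ((x != wrap32(q * y)) & ((x < 0) != (y < 0)))  # truncating quotient
m = wrap32(x - t * y)                                  # truncate mod
assert (wrap32(t * y + m) == x).all()
```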