diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index 5c1ac0651e3..30cc7ed9650 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -467,7 +467,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase): self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) - self.assertEqual(res, 12) + expected = sum(i * 2 for i in range(self.world_size)) + self.assertEqual(res, expected) res = aten.div.Scalar(dt, 2) self.assertEqual( @@ -478,7 +479,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase): self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) - self.assertEqual(res, 3) + expected = expected / 4.0 + self.assertEqual(res, expected) @with_comms @parametrize("op,reduce_op", [(torch.maximum, "max"), (torch.minimum, "min")])