[CI] fix test_pointwise_ops.py test_mul_div_scalar_partial (#170510)

Support any world size; 2, 3 or 4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170510
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-12-16 06:24:40 +00:00
committed by PyTorch MergeBot
parent 66407ac9cb
commit 225496166b

View File

@@ -467,7 +467,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertTrue(res._spec.placements[0].is_partial())
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
self.assertEqual(res, 12)
expected = sum(i * 2 for i in range(self.world_size))
self.assertEqual(res, expected)
res = aten.div.Scalar(dt, 2)
self.assertEqual(
@@ -478,7 +479,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertTrue(res._spec.placements[0].is_partial())
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
self.assertEqual(res, 3)
expected = expected / 4.0
self.assertEqual(res, expected)
@with_comms
@parametrize("op,reduce_op", [(torch.maximum, "max"), (torch.minimum, "min")])