mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[CI] fix test_pointwise_ops.py test_mul_div_scalar_partial (#170510)
Support any world size; 2, 3 or 4. Pull Request resolved: https://github.com/pytorch/pytorch/pull/170510 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
66407ac9cb
commit
225496166b
@@ -467,7 +467,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
|
||||
self.assertTrue(res._spec.placements[0].is_partial())
|
||||
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
|
||||
self.assertEqual(res, 12)
|
||||
expected = sum(i * 2 for i in range(self.world_size))
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
res = aten.div.Scalar(dt, 2)
|
||||
self.assertEqual(
|
||||
@@ -478,7 +479,8 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
self.assertTrue(res._spec.placements[0].is_partial())
|
||||
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
|
||||
|
||||
self.assertEqual(res, 3)
|
||||
expected = expected / 4.0
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
@with_comms
|
||||
@parametrize("op,reduce_op", [(torch.maximum, "max"), (torch.minimum, "min")])
|
||||
|
||||
Reference in New Issue
Block a user