From 76a0609b6bddb2bc40f1eb4ade12885023653d59 Mon Sep 17 00:00:00 2001 From: "Liao, Wei" Date: Mon, 11 Aug 2025 19:43:11 +0000 Subject: [PATCH] port distributed pipeline test files for Intel GPU (#159033) In this PR we port all distributed pipeline test files. We enable Intel GPU with the following methods, trying our best to keep the original code style: 1. use instantiate_device_type_tests() 2. use "torch.accelerator.current_accelerator()" to determine the accelerator backend 3. use "requires_accelerator_dist_backend()" to replace requires_nccl() 4. use "get_default_backend_for_device()" to get the backend 5. enable XPU for some test paths 6. add TEST_MULTIACCELERATOR in common_utils for all backends. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159033 Approved by: https://github.com/guangyey, https://github.com/d4l3k Co-authored-by: Daisy Deng --- test/distributed/pipelining/test_schedule.py | 10 +-- .../pipelining/test_schedule_multiproc.py | 89 ++++++++++++------- test/distributed/pipelining/test_stage.py | 51 ++++++----- .../pipelining/test_transformer.py | 4 +- test/distributed/pipelining/test_unflatten.py | 4 +- torch/testing/_internal/common_utils.py | 1 + 6 files changed, 102 insertions(+), 57 deletions(-) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index b1ad9b757a8..6f5b4df82a4 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -38,7 +38,7 @@ from torch.distributed.pipelining.schedules import ( W, ) from torch.distributed.pipelining.stage import _PipelineStageBase, PipelineStage -from torch.testing._internal.common_distributed import requires_nccl +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import ( check_leaked_tensors, instantiate_parametrized_tests, @@ -51,6 +51,8 @@ from torch.testing._internal.distributed.fake_pg import 
FakeStore ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts") +device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + logger = logging.getLogger(__name__) torch.manual_seed(0) @@ -657,7 +659,7 @@ class TestScheduleLowering(TestCase): # print(_format_pipeline_order(simulated_schedule)) self.assertEqual(num_steps, 113) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_grad_with_v_schedule(self): """ We have a special case for V schedules where 2 adjacent stages are on the same rank. @@ -677,7 +679,6 @@ class TestScheduleLowering(TestCase): d_hid = 512 batch_size = 256 n_stages = 2 - device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) @@ -776,7 +777,7 @@ class TestScheduleLowering(TestCase): torch.distributed.destroy_process_group() - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_grad_with_split_b_w(self): """ Ensure that separate dInput and dWeight computations are correctly executed. 
@@ -789,7 +790,6 @@ class TestScheduleLowering(TestCase): d_hid = 512 batch_size = 256 n_stages = 1 - device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index ae91911bc6a..a87d9245415 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -26,10 +26,9 @@ from torch.distributed.pipelining import ( ScheduleZBVZeroBubble, ) from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, - requires_nccl, + requires_accelerator_dist_backend, ) from torch.testing._internal.common_utils import ( check_leaked_tensors, @@ -37,6 +36,7 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) @@ -45,7 +45,8 @@ logger = logging.getLogger(__name__) d_hid = 512 batch_size = 64 torch.manual_seed(0) -device_type = "cuda" +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +backend = dist.get_default_backend_for_device(device_type) class ScheduleTest(MultiProcContinousTest): @@ -53,8 +54,7 @@ class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - # Testing with NCCL backend - return "nccl" + return backend @property def device(self) -> torch.device: @@ -180,8 +180,10 @@ class ScheduleTest(MultiProcContinousTest): for stage_module in stage_modules: stage_module.zero_grad() - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", 
[_ScheduleForwardOnly]) def test_forward_only(self, ScheduleClass): mod, mod_ref, x, _, _ = self._setup_models_and_data() @@ -210,8 +212,10 @@ class ScheduleTest(MultiProcContinousTest): x_clone = mod_ref(x_clone) torch.testing.assert_close(x_clone, out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -283,8 +287,10 @@ class ScheduleTest(MultiProcContinousTest): if self.rank == self.world_size - 1: self.assertTrue(len(losses) > 0, "Losses should be computed during eval()") - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_multi_iter(self, ScheduleClass): mod, _, x, target, loss_fn = self._setup_models_and_data() @@ -302,8 +308,10 @@ class ScheduleTest(MultiProcContinousTest): else: schedule.step() - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): # Model has two stages only, thus limiting group size to 2 @@ -359,8 +367,10 @@ class ScheduleTest(MultiProcContinousTest): torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) torch.testing.assert_close(pipe_loss, ref_loss) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + 
@skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_grad_with_tracer(self, ScheduleClass): mod, ref_mod, x, target, loss_fn = self._setup_models_and_data() @@ -398,8 +408,10 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("shape_inference", [True, False]) def test_grad_with_manual(self, ScheduleClass, shape_inference): @@ -453,8 +465,10 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -563,8 +577,10 @@ class ScheduleTest(MultiProcContinousTest): stage_modules, ref_mod, submod_names, rtol=5e-3, atol=5e-3 ) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) @@ -621,9 +637,16 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, 
submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize("ScheduleClass", [ScheduleWithReorderedB]) + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) + @parametrize( + "ScheduleClass", + [ + ScheduleWithReorderedB, + ], + ) def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): n_stages = 2 stages_per_rank = 1 @@ -679,8 +702,10 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] ) @@ -740,8 +765,10 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 @@ -820,8 +847,10 @@ class ScheduleTest(MultiProcContinousTest): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} 
test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B], diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index a711cec64d7..acb5bec7d84 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -14,11 +14,10 @@ from torch.distributed.pipelining import ( ScheduleGPipe, ) from torch.distributed.pipelining._utils import PipeliningShapeError -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, MultiProcessTestCase, - requires_nccl, + requires_accelerator_dist_backend, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -26,6 +25,7 @@ from torch.testing._internal.common_utils import ( run_tests, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) from torch.utils._pytree import tree_map_only @@ -34,8 +34,8 @@ d_hid = 512 batch_size = 256 chunks = 4 -device_type = "cuda" - +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +backend = dist.get_default_backend_for_device(device_type) torch.manual_seed(0) @@ -66,8 +66,7 @@ def get_flatten_hook(): class StageTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - # Testing with NCCL backend - return "nccl" + return backend @classmethod def device_type(cls) -> str: @@ -77,8 +76,10 @@ class StageTest(MultiProcContinousTest): def device(self) -> torch.device: return torch.device(device_type, self.rank) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ModelClass", [ExampleCode, MultiMLP]) def test_tracer(self, ModelClass): mod = 
ModelClass(d_hid, self.world_size) @@ -121,8 +122,10 @@ class StageTest(MultiProcContinousTest): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ModelClass", [ModelWithKwargs]) def test_tracer_kwargs(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -170,8 +173,10 @@ class StageTest(MultiProcContinousTest): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_manual(self): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) @@ -202,8 +207,10 @@ class StageTest(MultiProcContinousTest): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_custom_dw_with_fb_schedule(self): """Tests that separate weight grad function 'dw_runner' gets run under a schedule that's only aware of F/B.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -262,8 +269,10 @@ class StageTest(MultiProcContinousTest): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not 
TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_output_chunks_memory_usage(self): """Test that output_chunks doesn't store memory for non-first stages.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -347,14 +356,14 @@ class StageNegativeTest(MultiProcessTestCase): def init_pg(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( - backend="nccl", + backend=backend, store=store, rank=self.rank, world_size=self.world_size, device_id=self.device, ) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle("Flaky in CI") def test_shape_prop_mismatch(self): """Tests shape prop errors are raised""" @@ -402,8 +411,10 @@ class StageNegativeTest(MultiProcessTestCase): with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): _run_step(x) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_custom_dw_errors(self): """Tests expected errors are raised""" self.init_pg() diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 7e58129186a..20e830547de 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -73,7 +73,9 @@ class TransformerTests(TestCase): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests(TransformerTests, globals(), only_for=devices) +instantiate_device_type_tests( + TransformerTests, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index ae1e684d7c2..0493f39b16c 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ 
b/test/distributed/pipelining/test_unflatten.py @@ -73,7 +73,9 @@ class UnflattenTests(TestCase): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices) +instantiate_device_type_tests( + UnflattenTests, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bfc568bc146..f3c0648b462 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1422,6 +1422,7 @@ MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1) TEST_XPU = torch.xpu.is_available() TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False TEST_CUDA = torch.cuda.is_available() +TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) TEST_PRIVATEUSE1 = is_privateuse1_backend_available() TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()