port distributed pipeline test files for Intel GPU (#159033)

This PR ports all distributed pipelining test files to Intel GPU.
Intel GPU is enabled with the following methods, while keeping the original code style as much as possible (a minimal sketch of the resulting pattern follows this list):

1. Use instantiate_device_type_tests() with allow_xpu=True.
2. Use torch.accelerator.current_accelerator() to determine the accelerator backend.
3. Use requires_accelerator_dist_backend() in place of requires_nccl().
4. Use get_default_backend_for_device() to get the backend for the active device.
5. Enable XPU for some test paths.
6. Add TEST_MULTIACCELERATOR to common_utils for all backends.
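
A minimal sketch of the device/backend selection and test-gating pattern used across the ported files, assembled from the diffs below. `ExampleScheduleTest` and `test_example` are hypothetical names for illustration; the helpers and decorators are the ones the diffs import from `torch.testing._internal` and `torch.distributed`:

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_accelerator_dist_backend,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_MULTIACCELERATOR,
)

# Resolve the active accelerator type ("cuda", "xpu", ...) or fall back to "cpu".
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
# Default process-group backend for that device (e.g. nccl for cuda, xccl for xpu).
backend = dist.get_default_backend_for_device(device_type)


class ExampleScheduleTest(MultiProcContinousTest):  # hypothetical test class
    @classmethod
    def backend_str(cls) -> str:
        return backend  # instead of hard-coding "nccl"

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @requires_accelerator_dist_backend(["nccl", "xccl"])
    @skip_but_pass_in_sandcastle_if(
        not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
    )
    def test_example(self):
        # Hypothetical body; the real tests build pipeline stages and schedules here.
        self.assertTrue(dist.is_initialized())


if __name__ == "__main__":
    run_tests()
```

For the single-process TransformerTests and UnflattenTests, the port instead passes allow_xpu=True to instantiate_device_type_tests() so that XPU variants are generated, as shown in the last two file diffs.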

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159033
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Daisy Deng <daisy.deng@intel.com>
Author: Liao, Wei
Date: 2025-08-11 19:43:11 +00:00
Committed by: PyTorch MergeBot
Parent: c8205cb354
Commit: 76a0609b6b
6 changed files with 102 additions and 57 deletions


@@ -38,7 +38,7 @@ from torch.distributed.pipelining.schedules import (
W,
)
from torch.distributed.pipelining.stage import _PipelineStageBase, PipelineStage
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import (
check_leaked_tensors,
instantiate_parametrized_tests,
@@ -51,6 +51,8 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts")
device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
logger = logging.getLogger(__name__)
torch.manual_seed(0)
@@ -657,7 +659,7 @@ class TestScheduleLowering(TestCase):
# print(_format_pipeline_order(simulated_schedule))
self.assertEqual(num_steps, 113)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
def test_grad_with_v_schedule(self):
"""
We have a special case for V schedules where 2 adjacent stages are on the same rank.
@@ -677,7 +679,6 @@ class TestScheduleLowering(TestCase):
d_hid = 512
batch_size = 256
n_stages = 2
device = "cuda"
full_mod = MultiMLP(d_hid, n_layers=n_stages)
full_mod.to(device)
@@ -776,7 +777,7 @@ class TestScheduleLowering(TestCase):
torch.distributed.destroy_process_group()
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
def test_grad_with_split_b_w(self):
"""
Ensure that separate dInput and dWeight computations are correctly executed.
@@ -789,7 +790,6 @@ class TestScheduleLowering(TestCase):
d_hid = 512
batch_size = 256
n_stages = 1
device = "cuda"
full_mod = MultiMLP(d_hid, n_layers=n_stages)
full_mod.to(device)


@@ -26,10 +26,9 @@ from torch.distributed.pipelining import (
ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
requires_nccl,
requires_accelerator_dist_backend,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@@ -37,6 +36,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_MULTIACCELERATOR,
)
@@ -45,7 +45,8 @@ logger = logging.getLogger(__name__)
d_hid = 512
batch_size = 64
torch.manual_seed(0)
device_type = "cuda"
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = dist.get_default_backend_for_device(device_type)
class ScheduleTest(MultiProcContinousTest):
@@ -53,8 +54,7 @@ class ScheduleTest(MultiProcContinousTest):
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
return backend
@property
def device(self) -> torch.device:
@@ -180,8 +180,10 @@ class ScheduleTest(MultiProcContinousTest):
for stage_module in stage_modules:
stage_module.zero_grad()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = self._setup_models_and_data()
@@ -210,8 +212,10 @@ class ScheduleTest(MultiProcContinousTest):
x_clone = mod_ref(x_clone)
torch.testing.assert_close(x_clone, out)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[
@@ -283,8 +287,10 @@ class ScheduleTest(MultiProcContinousTest):
if self.rank == self.world_size - 1:
self.assertTrue(len(losses) > 0, "Losses should be computed during eval()")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = self._setup_models_and_data()
@@ -302,8 +308,10 @@ class ScheduleTest(MultiProcContinousTest):
else:
schedule.step()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_kwargs_with_tracer(self, ScheduleClass):
# Model has two stages only, thus limiting group size to 2
@@ -359,8 +367,10 @@ class ScheduleTest(MultiProcContinousTest):
torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3)
torch.testing.assert_close(pipe_loss, ref_loss)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = self._setup_models_and_data()
@@ -398,8 +408,10 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_module, ref_mod)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
def test_grad_with_manual(self, ScheduleClass, shape_inference):
@@ -453,8 +465,10 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_module, ref_mod)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[
@@ -563,8 +577,10 @@ class ScheduleTest(MultiProcContinousTest):
stage_modules, ref_mod, submod_names, rtol=5e-3, atol=5e-3
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble])
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
print(ScheduleClass)
@@ -621,9 +637,16 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_modules, ref_mod, submod_names)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[
ScheduleWithReorderedB,
],
)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@@ -679,8 +702,10 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_modules, ref_mod, submod_names)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
)
@@ -740,8 +765,10 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_modules, ref_mod, submod_names)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
@@ -820,8 +847,10 @@ class ScheduleTest(MultiProcContinousTest):
# Check gradients using helper method
self._check_gradients(stage_modules, ref_mod, submod_names)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],


@@ -14,11 +14,10 @@ from torch.distributed.pipelining import (
ScheduleGPipe,
)
from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
MultiProcessTestCase,
requires_nccl,
requires_accelerator_dist_backend,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -26,6 +25,7 @@ from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_MULTIACCELERATOR,
)
from torch.utils._pytree import tree_map_only
@@ -34,8 +34,8 @@ d_hid = 512
batch_size = 256
chunks = 4
device_type = "cuda"
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = dist.get_default_backend_for_device(device_type)
torch.manual_seed(0)
@@ -66,8 +66,7 @@ def get_flatten_hook():
class StageTest(MultiProcContinousTest):
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
return backend
@classmethod
def device_type(cls) -> str:
@@ -77,8 +76,10 @@ class StageTest(MultiProcContinousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ModelClass", [ExampleCode, MultiMLP])
def test_tracer(self, ModelClass):
mod = ModelClass(d_hid, self.world_size)
@@ -121,8 +122,10 @@ class StageTest(MultiProcContinousTest):
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ModelClass", [ModelWithKwargs])
def test_tracer_kwargs(self, ModelClass):
mod = ModelClass(d_hid, self.world_size)
@@ -170,8 +173,10 @@ class StageTest(MultiProcContinousTest):
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
def test_manual(self):
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
full_mod.to(self.device)
@@ -202,8 +207,10 @@ class StageTest(MultiProcContinousTest):
ref_out = full_mod(x)
torch.testing.assert_close(out, ref_out)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
def test_custom_dw_with_fb_schedule(self):
"""Tests that separate weight grad function 'dw_runner' gets run under a schedule that's only aware of F/B."""
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
@@ -262,8 +269,10 @@ class StageTest(MultiProcContinousTest):
ref_out = full_mod(x)
torch.testing.assert_close(out, ref_out)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
def test_output_chunks_memory_usage(self):
"""Test that output_chunks doesn't store memory for non-first stages."""
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
@@ -347,14 +356,14 @@ class StageNegativeTest(MultiProcessTestCase):
def init_pg(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
backend=backend,
store=store,
rank=self.rank,
world_size=self.world_size,
device_id=self.device,
)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle("Flaky in CI")
def test_shape_prop_mismatch(self):
"""Tests shape prop errors are raised"""
@@ -402,8 +411,10 @@ class StageNegativeTest(MultiProcessTestCase):
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
def test_custom_dw_errors(self):
"""Tests expected errors are raised"""
self.init_pg()


@@ -73,7 +73,9 @@ class TransformerTests(TestCase):
devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(TransformerTests, globals(), only_for=devices)
instantiate_device_type_tests(
TransformerTests, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()


@@ -73,7 +73,9 @@ class UnflattenTests(TestCase):
devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices)
instantiate_device_type_tests(
UnflattenTests, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()


@@ -1422,6 +1422,7 @@ MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
TEST_XPU = torch.xpu.is_available()
TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()