Support torch.accelerator.get_device_capability on XPU (#170747)

# Motivation
This PR adds support for `torch.accelerator.get_device_capability` on XPU. At this stage, it reports a limited set of basic scalar data types and takes potential software emulation into account for types (such as `bfloat16`) that may lack native hardware support.
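
A minimal usage sketch (assuming an XPU build of PyTorch with at least one XPU device available; the exact dtype list varies by hardware):

```python
import torch

# Hedged sketch: query the dtype capability reported for the current XPU device.
if torch.xpu.is_available():
    cap = torch.accelerator.get_device_capability()
    # Basic scalar dtypes, e.g. torch.bool, torch.int32, torch.float32, ...
    print(cap["supported_dtypes"])
    # bfloat16 is always reported (it may be emulated); native hardware support
    # can be checked separately via is_bf16_supported(including_emulation=False).
    print(torch.xpu.is_bf16_supported(including_emulation=False))
```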

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170747
Approved by: https://github.com/EikanWang
Author: Yu, Guangye
Date: 2025-12-23 13:35:39 +00:00
Committed by: PyTorch MergeBot
Parent: bfa6f5e073
Commit: 949476b243
2 changed files with 35 additions and 0 deletions


@@ -45,6 +45,27 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
    c10::xpu::set_device(d.index());
  }

  DeviceCapability getDeviceCapability(Device d) const override {
    DeviceCapability cap;
    cap.capability_data.capability_bits = (1ULL << kIndex_Byte) |
        (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) |
        (1ULL << kIndex_Long) | (1ULL << kIndex_Float) |
        (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool);
    // BFloat16 may be emulated. We always assume BFloat16 is available;
    // users can call is_bf16_supported() to check for native hardware support.
    cap.capability_data.capability_bits |= (1ULL << kIndex_BFloat16);
    auto& device = c10::xpu::get_raw_device(d.index());
    if (device.has(sycl::aspect::fp16)) {
      cap.capability_data.capability_bits |= (1ULL << kIndex_Half);
      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexHalf);
    }
    if (device.has(sycl::aspect::fp64)) {
      cap.capability_data.capability_bits |= (1ULL << kIndex_Double);
      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexDouble);
    }
    return cap;
  }

  Stream getStream(Device d) const override {
    return getCurrentXPUStream(d.index()).unwrap();
  }


@@ -141,6 +141,20 @@ class TestXpu(TestCase):
        )  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        self.assertEqual(len(device_properties.uuid.bytes), 16)

    def test_get_device_capability(self):
        device_capability = torch.xpu.get_device_capability()
        acc_capability = torch.accelerator.get_device_capability()
        supported_dtypes = acc_capability["supported_dtypes"]
        self.assertIn(torch.bool, supported_dtypes)
        self.assertIn(torch.int, supported_dtypes)
        self.assertIn(torch.float, supported_dtypes)
        if device_capability["has_fp16"]:
            self.assertIn(torch.float16, supported_dtypes)
        if device_capability["has_fp64"]:
            self.assertIn(torch.double, supported_dtypes)
        if torch.xpu.is_bf16_supported(including_emulation=True):
            self.assertIn(torch.bfloat16, supported_dtypes)

    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
    def test_wrong_xpu_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(