Support torch.accelerator.get_device_capability on XPU (#170747)
# Motivation

This PR adds support for `torch.accelerator.get_device_capability` on XPU. At the current stage, it reports a limited set of basic scalar data types, taking potential software emulation into account where native hardware support may not be available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170747
Approved by: https://github.com/EikanWang
Committed by: PyTorch MergeBot
Parent: bfa6f5e073
Commit: 949476b243
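For orientation, here is a minimal usage sketch of the Python-side queries this change feeds into. It relies only on the keys and arguments exercised by the test further below; any behavior beyond that is an assumption.

```python
import torch

# Minimal sketch, assuming an XPU build of PyTorch with at least one device.
if torch.xpu.is_available():
    # Accelerator-agnostic view; the test below checks the "supported_dtypes" key.
    cap = torch.accelerator.get_device_capability()
    supported = cap["supported_dtypes"]
    print(torch.float32 in supported)   # basic scalar types are always reported
    print(torch.bfloat16 in supported)  # expected True: the change always sets the BFloat16 bit

    # XPU-specific view used by the test for cross-checking fp16/fp64 support.
    xpu_cap = torch.xpu.get_device_capability()
    print(xpu_cap["has_fp16"], xpu_cap["has_fp64"])

    # Distinguish native bf16 hardware support from software emulation.
    print(torch.xpu.is_bf16_supported(including_emulation=False))
```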
@@ -45,6 +45,27 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {

```cpp
    c10::xpu::set_device(d.index());
  }

  DeviceCapability getDeviceCapability(Device d) const override {
    DeviceCapability cap;
    cap.capability_data.capability_bits = (1ULL << kIndex_Byte) |
        (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) |
        (1ULL << kIndex_Long) | (1ULL << kIndex_Float) |
        (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool);
    // BFloat16 may be emulated. We always assume BFloat16 is available;
    // users can call is_bf16_supported() to check for native hardware support.
    cap.capability_data.capability_bits |= (1ULL << kIndex_BFloat16);
    auto& device = c10::xpu::get_raw_device(d.index());
    if (device.has(sycl::aspect::fp16)) {
      cap.capability_data.capability_bits |= (1ULL << kIndex_Half);
      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexHalf);
    }
    if (device.has(sycl::aspect::fp64)) {
      cap.capability_data.capability_bits |= (1ULL << kIndex_Double);
      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexDouble);
    }
    return cap;
  }

  Stream getStream(Device d) const override {
    return getCurrentXPUStream(d.index()).unwrap();
  }
```
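To make the encoding above concrete, here is a rough, self-contained sketch of how a `capability_bits` mask can be decoded into a set of dtypes. The bit positions below are hypothetical stand-ins; the real `kIndex_*` constants are defined elsewhere in the C++ sources and are not part of this diff.

```python
import torch

# Hypothetical bit positions standing in for the kIndex_* constants; the real
# values live in the C++ DeviceCapability machinery, not in this diff.
DTYPE_BIT = {
    torch.uint8: 0,      # kIndex_Byte
    torch.int8: 1,       # kIndex_Char
    torch.int16: 2,      # kIndex_Short
    torch.int32: 3,      # kIndex_Int
    torch.int64: 4,      # kIndex_Long
    torch.float16: 5,    # kIndex_Half (set only if sycl::aspect::fp16 is present)
    torch.float32: 6,    # kIndex_Float
    torch.float64: 7,    # kIndex_Double (set only if sycl::aspect::fp64 is present)
    torch.bool: 8,       # kIndex_Bool
    torch.bfloat16: 9,   # kIndex_BFloat16 (always set; may be emulated)
}

def decode_capability_bits(bits: int) -> set:
    """Return the dtypes whose bit is set in a capability mask."""
    return {dtype for dtype, pos in DTYPE_BIT.items() if bits & (1 << pos)}
```

Presumably the accelerator frontend performs an equivalent decoding step to produce the `supported_dtypes` entry checked by the test below; that mapping is not shown in this excerpt.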
@@ -141,6 +141,20 @@ class TestXpu(TestCase):

```python
        )  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        self.assertEqual(len(device_properties.uuid.bytes), 16)

    def test_get_device_capability(self):
        device_capability = torch.xpu.get_device_capability()
        acc_capability = torch.accelerator.get_device_capability()
        supported_dtypes = acc_capability["supported_dtypes"]
        self.assertIn(torch.bool, supported_dtypes)
        self.assertIn(torch.int, supported_dtypes)
        self.assertIn(torch.float, supported_dtypes)
        if device_capability["has_fp16"]:
            self.assertIn(torch.float16, supported_dtypes)
        if device_capability["has_fp64"]:
            self.assertIn(torch.double, supported_dtypes)
        if torch.xpu.is_bf16_supported(including_emulation=True):
            self.assertIn(torch.bfloat16, supported_dtypes)

    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
    def test_wrong_xpu_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(
```