Support torch.accelerator.get_device_capability on XPU (#170747)

# Motivation
This PR adds support for `torch.accelerator.get_device_capability` on XPU. At the current stage, it reports a limited set of basic scalar data types, taking potential software emulation into account where native hardware support may not be available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170747
Approved by: https://github.com/EikanWang
This commit is contained in:
Yu, Guangye
2025-12-23 13:35:39 +00:00
committed by PyTorch MergeBot
parent bfa6f5e073
commit 949476b243
2 changed files with 35 additions and 0 deletions

View File

@@ -45,6 +45,27 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
c10::xpu::set_device(d.index());
}
DeviceCapability getDeviceCapability(Device d) const override {
DeviceCapability cap;
cap.capability_data.capability_bits = (1ULL << kIndex_Byte) |
(1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) |
(1ULL << kIndex_Long) | (1ULL << kIndex_Float) |
(1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool);
// BFloat16 may be emulated. We always assume BFloat16 is available;
// users can call is_bf16_supported() to check for native hardware support.
cap.capability_data.capability_bits |= (1ULL << kIndex_BFloat16);
auto& device = c10::xpu::get_raw_device(d.index());
if (device.has(sycl::aspect::fp16)) {
cap.capability_data.capability_bits |= (1ULL << kIndex_Half);
cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexHalf);
}
if (device.has(sycl::aspect::fp64)) {
cap.capability_data.capability_bits |= (1ULL << kIndex_Double);
cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexDouble);
}
return cap;
}
Stream getStream(Device d) const override {
return getCurrentXPUStream(d.index()).unwrap();
}