Support torch.accelerator.get_device_capability on XPU (#170747)
# Motivation

This PR adds support for `torch.accelerator.get_device_capability` on XPU. At this stage, it reports a limited set of basic scalar data types, accounting for potential software emulation where native hardware support may be unavailable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170747
Approved by: https://github.com/EikanWang
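For context, a minimal sketch of calling the new Python-level API follows. The exact signature and return format of `get_device_capability` are assumptions here, since this page only shows the C++ side; the PR defines the authoritative interface.

```python
import torch

# A minimal sketch, assuming an XPU build of PyTorch that includes this PR.
# The signature and return format of get_device_capability are assumptions.
if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()  # e.g. device(type='xpu')
    cap = torch.accelerator.get_device_capability()
    print(acc, cap)  # assumed: per-dtype support information
```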
Committed by: PyTorch MergeBot
Parent: bfa6f5e073
Commit: 949476b243
@@ -45,6 +45,27 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     c10::xpu::set_device(d.index());
   }
 
+  DeviceCapability getDeviceCapability(Device d) const override {
+    DeviceCapability cap;
+    cap.capability_data.capability_bits = (1ULL << kIndex_Byte) |
+        (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) |
+        (1ULL << kIndex_Long) | (1ULL << kIndex_Float) |
+        (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool);
+    // BFloat16 may be emulated. We always assume BFloat16 is available;
+    // users can call is_bf16_supported() to check for native hardware support.
+    cap.capability_data.capability_bits |= (1ULL << kIndex_BFloat16);
+    auto& device = c10::xpu::get_raw_device(d.index());
+    if (device.has(sycl::aspect::fp16)) {
+      cap.capability_data.capability_bits |= (1ULL << kIndex_Half);
+      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexHalf);
+    }
+    if (device.has(sycl::aspect::fp64)) {
+      cap.capability_data.capability_bits |= (1ULL << kIndex_Double);
+      cap.capability_data.capability_bits |= (1ULL << kIndex_ComplexDouble);
+    }
+    return cap;
+  }
+
   Stream getStream(Device d) const override {
     return getCurrentXPUStream(d.index()).unwrap();
   }
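The capability report above is a plain bitfield: each supported scalar type sets one bit at a per-dtype index, with the baseline types always set, BFloat16 assumed available (possibly emulated), and fp16/fp64 gated on SYCL device aspects. The sketch below mirrors that technique in Python; the `IDX_*` constants are hypothetical stand-ins for the C++ `kIndex_*` values, which are not defined on this page.

```python
# Illustrative sketch of the capability-bitmask technique from the diff above.
# The bit indices are hypothetical stand-ins for the C++ kIndex_* constants.
IDX_FLOAT, IDX_HALF, IDX_BFLOAT16, IDX_DOUBLE = 0, 1, 2, 3

def build_capability_bits(has_fp16: bool, has_fp64: bool) -> int:
    bits = 1 << IDX_FLOAT          # baseline types are always reported
    bits |= 1 << IDX_BFLOAT16      # assumed available, possibly emulated
    if has_fp16:                   # mirrors device.has(sycl::aspect::fp16)
        bits |= 1 << IDX_HALF
    if has_fp64:                   # mirrors device.has(sycl::aspect::fp64)
        bits |= 1 << IDX_DOUBLE
    return bits

def supports(bits: int, idx: int) -> bool:
    return bool(bits & (1 << idx))

bits = build_capability_bits(has_fp16=True, has_fp64=False)
assert supports(bits, IDX_HALF) and not supports(bits, IDX_DOUBLE)
```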