# Motivation

There are several issues related to the data types and precision that an accelerator supports (see #165038 and #143112). Today, users have to look these capabilities up in the documentation and then hard-code them. This PR proposes a new unified API for users to query their accelerator's capabilities.

# Changes

This PR creates a new data structure, `DeviceCapability`, containing the capabilities that accelerators commonly have:
- Supported data types (all set to supported by default):
  - `fp16`, `int32`, `complex`, etc.
- Other capabilities (to be discussed)

To expose the structure, this PR defines a new Python API in the accelerator module, `get_device_capability`. It takes a `device` as input and returns a dictionary of capabilities (currently with `supported_dtypes` as the only key).

# Usage

```python
>>> import torch
>>> import torch_openreg
>>> torch.accelerator.get_device_capability('openreg:0')
{'supported_dtypes': [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64, torch.complex128, torch.bool, torch.qint8, torch.quint8, torch.qint32, torch.bfloat16, torch.quint4x2, torch.quint2x4, torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, torch.uint16, torch.uint32, torch.uint64, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.int1, torch.int2, torch.int3, torch.int4, torch.int5, torch.int6, torch.int7, torch.float8_e8m0fnu, torch.float4_e2m1fn_x2]}
```

# TODO

- So far, precision is the only capability tracked, to my knowledge. More common capabilities may be identified later, so the API should be designed for easy extension.
- Support other in-tree accelerators, such as **cuda** and **mps**.
- Clarify whether the capabilities are software- or hardware-supported. (By @guangyey)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165631
Approved by: https://github.com/guangyey, https://github.com/albanD

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
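As a hypothetical sketch of the motivation, code that previously hard-coded per-backend dtype assumptions could instead query the capability dictionary. The `preferred` fallback chain below is illustrative only, not part of this PR, and the example assumes the `torch_openreg` test backend shown above is installed:

```python
import torch
import torch_openreg  # out-of-tree sample backend, as in the usage above

caps = torch.accelerator.get_device_capability('openreg:0')
supported = set(caps['supported_dtypes'])

# Pick the first usable reduced-precision dtype instead of hard-coding
# a per-backend dtype table (illustrative pattern only).
preferred = [torch.bfloat16, torch.float16, torch.float32]
dtype = next(d for d in preferred if d in supported)
x = torch.ones(4, device='openreg:0', dtype=dtype)
```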
#include <c10/core/AllocatorConfig.h>
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

namespace torch::accelerator {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
|
|
// If no accelerator was available at compile time, return None.
|
|
auto acc = at::getAccelerator(false);
|
|
if (acc.has_value()) {
|
|
return acc.value();
|
|
} else {
|
|
return std::nullopt;
|
|
}
|
|
});
|
|
|
|
m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
|
|
// If device index is negative, no-op
|
|
if (device_index < 0) {
|
|
return;
|
|
}
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
at::accelerator::setDeviceIndex(device_index);
|
|
});
|
|
|
|
m.def("_accelerator_getDeviceIndex", []() {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::getDeviceIndex();
|
|
});
|
|
|
|
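  // Returns the capability dictionary backing
  // torch.accelerator.get_device_capability(): each supported ScalarType is
  // converted to its Python torch.dtype object and collected under the
  // "supported_dtypes" key.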
m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
auto caps = at::accelerator::getDeviceCapability(device_index);
|
|
|
|
py::dict dict;
|
|
|
|
py::set dtype_set;
|
|
caps.forEachSupportedScalarType([&](c10::ScalarType dtype) {
|
|
THPDtype* thp_dtype = torch::getTHPDtype(dtype);
|
|
py::object dtype_obj =
|
|
py::reinterpret_borrow<py::object>((PyObject*)thp_dtype);
|
|
dtype_set.add(dtype_obj);
|
|
});
|
|
|
|
dict["supported_dtypes"] = dtype_set;
|
|
return dict;
|
|
});
|
|
|
|
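  // Note: binding a stream also switches the current device to the stream's
  // device before making the stream current.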
m.def("_accelerator_setStream", [](c10::Stream stream) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
// Set the current device to the device of stream
|
|
if (at::accelerator::getDeviceIndex() != stream.device_index()) {
|
|
at::accelerator::setDeviceIndex(stream.device_index());
|
|
}
|
|
at::accelerator::setCurrentStream(stream);
|
|
});
|
|
|
|
m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::getCurrentStream(device_index);
|
|
});
|
|
|
|
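  // Synchronizing an uninitialized device would be a no-op, so skip it when
  // lazy initialization is supported but has not happened yet. The GIL is
  // released around the potentially blocking synchronize call.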
m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
if (torch::utils::is_device_lazy_init_supported(device_type) &&
|
|
!torch::utils::is_device_initialized(device_type)) {
|
|
return;
|
|
}
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
{
|
|
py::gil_scoped_release no_gil;
|
|
at::accelerator::synchronizeDevice(device_index);
|
|
}
|
|
});
|
|
|
|
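  // Exchange-style setters: set the current device and return the previous
  // device index (the "maybe" variant follows the c10 DeviceGuard convention
  // of avoiding unnecessary device context creation).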
m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::exchangeDevice(device_index);
|
|
});
|
|
|
|
m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::maybeExchangeDevice(device_index);
|
|
});
|
|
|
|
m.def("_accelerator_isAllocatorInitialized", []() {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
return at::getDeviceAllocator(device_type)->initialized();
|
|
});
|
|
|
|
m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); });
|
|
|
|
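  // Flatten the allocator's DeviceStats into nested Python dicts: each Stat
  // becomes {current, peak, allocated, freed}, and per-pool stats are broken
  // out by "all", "small_pool", and "large_pool".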
m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) {
|
|
using c10::CachingAllocator::Stat;
|
|
using c10::CachingAllocator::StatArray;
|
|
using c10::CachingAllocator::StatType;
|
|
using c10::CachingDeviceAllocator::DeviceStats;
|
|
|
|
const auto stats = at::accelerator::getDeviceStats(device_index);
|
|
const auto stat_to_dict = [](const Stat& stat) -> py::dict {
|
|
py::dict dict;
|
|
dict["current"] = stat.current;
|
|
dict["peak"] = stat.peak;
|
|
dict["allocated"] = stat.allocated;
|
|
dict["freed"] = stat.freed;
|
|
return dict;
|
|
};
|
|
|
|
const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict {
|
|
const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)>
|
|
kStatTypeNames = {"all", "small_pool", "large_pool"};
|
|
py::dict dict;
|
|
for (const auto i : c10::irange(kStatTypeNames.size())) {
|
|
dict[kStatTypeNames[i]] = stat_to_dict(stats[i]);
|
|
}
|
|
return dict;
|
|
};
|
|
|
|
py::dict result;
|
|
result["num_alloc_retries"] = stats.num_alloc_retries;
|
|
result["num_ooms"] = stats.num_ooms;
|
|
result["max_split_size"] = stats.max_split_size;
|
|
result["num_sync_all_streams"] = stats.num_sync_all_streams;
|
|
result["num_device_alloc"] = stats.num_device_alloc;
|
|
result["num_device_free"] = stats.num_device_free;
|
|
result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes);
|
|
result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes);
|
|
result["active_bytes"] = stat_array_to_dict(stats.active_bytes);
|
|
result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes);
|
|
result["allocation"] = stat_array_to_dict(stats.allocation);
|
|
result["segment"] = stat_array_to_dict(stats.segment);
|
|
result["active"] = stat_array_to_dict(stats.active);
|
|
result["inactive_split"] = stat_array_to_dict(stats.inactive_split);
|
|
result["inactive_split_bytes"] =
|
|
stat_array_to_dict(stats.inactive_split_bytes);
|
|
result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations);
|
|
result["oversize_segments"] = stat_to_dict(stats.oversize_segments);
|
|
return result;
|
|
});
|
|
|
|
  m.def(
      "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) {
        at::accelerator::resetAccumulatedStats(device_index);
      });

  m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) {
    at::accelerator::resetPeakStats(device_index);
  });

m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
py::gil_scoped_release no_gil;
|
|
return at::accelerator::getMemoryInfo(device_index);
|
|
});
|
|
|
|
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
|
|
c10::CachingAllocator::setAllocatorSettings(env);
|
|
});
|
|
}
|
|
|
|
} // namespace torch::accelerator
|