mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Summary:
Implements MTIA RNG (Random Number Generator) integration into PyTorch following CUDA's architecture pattern. Adds default generators during device initialization and exposes them via `torch.mtia.default_generators`.
Key Changes
**C++ Infrastructure:**
* Extended `MTIAHooksInterface` with `getDefaultGenerator()` and `getNewGenerator()` methods
* Implemented hooks in `mtia_hooks.{h,cpp}` managing per-device RNG state
* Added `createMTIAGenerator()` factory in `mtia_rng.{h,cpp}`
* Modified `_mtia_init()` to create generators tuple mirroring CUDA implementation.
**Python Integration:**
* Added `default_generators: tuple[torch._C.Generator] = ()` declaration to `torch/mtia/__init__.py`
* Generators populated by C++ during lazy device initialization
**Architecture:**
* Generators for same device share underlying RNG state. This is a key design choice for MTIA that users must be aware of.
* Different devices maintain independent RNG states
* Accessible via `torch.mtia.default_generators[device_idx]`
Test Plan:
buck2 test fbcode//mtia/host_runtime/torch_mtia/tests:test_mtia_rng_generators
https://www.internalfb.com/intern/testinfra/testrun/6192449762838044
Differential Revision: D88012916
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169271
Approved by: https://github.com/patrick-toulme, https://github.com/andyanwang, https://github.com/albanD
Co-authored-by: David Ramon Prados <davidrp@meta.com>
235 lines
6.7 KiB
C++
235 lines
6.7 KiB
C++
#include <ATen/CPUGeneratorImpl.h>
|
|
#include <ATen/detail/MTIAHooksInterface.h>
|
|
#include <c10/core/Device.h>
|
|
#include <c10/core/Stream.h>
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
#include <c10/util/Logging.h>
|
|
#include <torch/csrc/utils/device_lazy_init.h>
|
|
#include <thread>
|
|
namespace torch::mtia {
|
|
|
|
constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA;
|
|
constexpr c10::DeviceIndex kMTIADeviceCount = 2;
|
|
static thread_local c10::DeviceIndex current_device = 0;
|
|
static thread_local std::array<c10::Stream, kMTIADeviceCount> current_streams =
|
|
{c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
|
|
c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};
|
|
static int64_t stream_id_gen = 1;
|
|
static int64_t event_id_gen = 1;
|
|
static std::array<c10::Stream, kMTIADeviceCount> default_streams = {
|
|
c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
|
|
c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};
|
|
struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
MTIAGuardImpl() = default;
|
|
explicit MTIAGuardImpl(c10::DeviceType t) {
|
|
TORCH_INTERNAL_ASSERT(t == kMTIADeviceType);
|
|
}
|
|
c10::DeviceType type() const override {
|
|
return kMTIADeviceType;
|
|
}
|
|
c10::Device exchangeDevice(c10::Device d) const override {
|
|
c10::Device old_device = getDevice();
|
|
if (old_device.index() != d.index()) {
|
|
setDevice(d);
|
|
}
|
|
return old_device;
|
|
}
|
|
c10::Device getDevice() const override {
|
|
return c10::Device(kMTIADeviceType, current_device);
|
|
}
|
|
|
|
void setDevice(c10::Device d) const override {
|
|
if (getDevice().index() != d.index()) {
|
|
current_device = d.index();
|
|
}
|
|
}
|
|
void uncheckedSetDevice(c10::Device d) const noexcept override {
|
|
(void)d;
|
|
}
|
|
c10::Stream getStream(c10::Device d) const noexcept override {
|
|
return current_streams[d.index()];
|
|
}
|
|
c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
|
|
(void)priority;
|
|
return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
|
|
}
|
|
c10::Stream getDefaultStream(c10::Device d) const override {
|
|
return default_streams[d.index()];
|
|
}
|
|
c10::Stream getStreamFromGlobalPool(
|
|
c10::Device d,
|
|
bool isHighPriority = false) const override {
|
|
return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
|
|
}
|
|
// NB: These do NOT set the current device
|
|
c10::Stream exchangeStream(c10::Stream s) const noexcept override {
|
|
c10::Stream old_stream = getStream(s.device());
|
|
return old_stream;
|
|
}
|
|
c10::DeviceIndex deviceCount() const noexcept override {
|
|
return kMTIADeviceCount;
|
|
}
|
|
|
|
void destroyEvent(void* event, const c10::DeviceIndex device_index)
|
|
const noexcept override {
|
|
(void)device_index;
|
|
}
|
|
|
|
void record(
|
|
void** event,
|
|
const c10::Stream& stream,
|
|
const c10::DeviceIndex device_index,
|
|
const c10::EventFlag flag) const override {
|
|
TORCH_CHECK(
|
|
device_index == -1 || device_index == stream.device_index(),
|
|
"Event device index ",
|
|
device_index,
|
|
" does not match recording stream's device index ",
|
|
stream.device_index(),
|
|
".");
|
|
|
|
const auto orig_device = getDevice();
|
|
|
|
setDevice(stream.device());
|
|
|
|
if (*event == nullptr) {
|
|
*event = reinterpret_cast<void*>(event_id_gen++);
|
|
}
|
|
setDevice(orig_device);
|
|
}
|
|
|
|
void block(void* event, const c10::Stream& stream) const override {
|
|
(void)event;
|
|
(void)stream;
|
|
}
|
|
|
|
// May be called from any device
|
|
bool queryEvent(void* event) const override {
|
|
(void)event;
|
|
return true;
|
|
}
|
|
|
|
// Stream-related functions
|
|
bool queryStream(const c10::Stream& stream) const override {
|
|
(void)stream;
|
|
return true;
|
|
}
|
|
|
|
void synchronizeStream(const c10::Stream& stream) const override {
|
|
(void)stream;
|
|
}
|
|
|
|
void recordDataPtrOnStream(
|
|
const c10::DataPtr& data_ptr,
|
|
const c10::Stream& stream) const override {
|
|
(void)data_ptr;
|
|
(void)stream;
|
|
}
|
|
|
|
double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index) const override {
|
|
(void)device_index;
|
|
uint64_t elapsed_time = 1e6;
|
|
return (double)(elapsed_time / 1e6);
|
|
}
|
|
|
|
void synchronizeEvent(void* event) const override {
|
|
(void)event;
|
|
}
|
|
};
|
|
|
|
struct MTIAHooks : public at::MTIAHooksInterface {
|
|
explicit MTIAHooks(at::MTIAHooksArgs) {}
|
|
void init() const override {}
|
|
|
|
bool hasMTIA() const override {
|
|
return true;
|
|
}
|
|
|
|
c10::DeviceIndex deviceCount() const override {
|
|
// Don't call device_lazy_init here to avoid recursive deadlock during initialization
|
|
return c10::DeviceIndex(2);
|
|
}
|
|
|
|
void deviceSynchronize(c10::DeviceIndex device_index) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
(void)device_index;
|
|
}
|
|
|
|
std::string showConfig() const override {
|
|
return "None config";
|
|
}
|
|
|
|
c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
auto orig_device = current_device;
|
|
if (current_device != device) {
|
|
current_device = device;
|
|
}
|
|
return orig_device;
|
|
}
|
|
|
|
c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
auto orig_device = current_device;
|
|
if (current_device != device) {
|
|
current_device = device;
|
|
}
|
|
return orig_device;
|
|
}
|
|
|
|
c10::Stream getDefaultStream(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
return default_streams[device];
|
|
}
|
|
|
|
c10::Stream getCurrentStream(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
return current_streams[device];
|
|
}
|
|
|
|
void setCurrentStream(const c10::Stream& stream) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
current_streams[stream.device_index()] = stream;
|
|
}
|
|
|
|
c10::DeviceIndex getCurrentDevice() const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
return current_device;
|
|
}
|
|
|
|
void setCurrentDevice(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
|
|
if (current_device != device) {
|
|
current_device = device;
|
|
}
|
|
}
|
|
|
|
const at::Generator& getDefaultGenerator(c10::DeviceIndex device) const override {
|
|
// Don't call device_lazy_init here to avoid recursive deadlock during initialization
|
|
static std::vector<at::Generator> default_gens = {
|
|
at::make_generator<at::CPUGeneratorImpl>(),
|
|
at::make_generator<at::CPUGeneratorImpl>()
|
|
};
|
|
return default_gens[device];
|
|
}
|
|
|
|
at::Generator getNewGenerator(c10::DeviceIndex device) const override {
|
|
torch::utils::device_lazy_init(at::kMTIA);
|
|
return at::make_generator<at::CPUGeneratorImpl>();
|
|
}
|
|
};
|
|
|
|
using at::MTIAHooksRegistry;
|
|
using at::RegistererMTIAHooksRegistry;
|
|
|
|
REGISTER_MTIA_HOOKS(MTIAHooks);
|
|
C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl);
|
|
|
|
} // namespace torch::mtia
|