[xpu][feature] [1/2] Introduce XPUPluggableAllocator in cpp part (#168966)

# Motivation
This PR introduces `XPUPluggableAllocator` on the C++ side and keeps it as simple as possible. A follow-up PR will add the corresponding Python frontend.
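
For illustration, a minimal sketch of how the new C++ API introduced in this PR is expected to be used. The `my_alloc`/`my_free` callbacks are hypothetical stand-ins backed by plain SYCL USM calls; `createCustomAllocator` and `changeCurrentAllocator` come from the diff below:

```cpp
#include <torch/csrc/xpu/XPUPluggableAllocator.h>
#include <sycl/sycl.hpp>

// Hypothetical user-supplied callbacks; any allocation strategy works as
// long as the signatures match what createCustomAllocator expects.
void* my_alloc(size_t size, int device, sycl::queue* queue) {
  return sycl::malloc_device(size, *queue);  // device USM allocation
}

void my_free(void* ptr, size_t size, int device, sycl::queue* queue) {
  sycl::free(ptr, *queue);
}

void install_custom_allocator() {
  namespace xpa = torch::xpu::XPUPluggableAllocator;
  auto allocator = xpa::createCustomAllocator(my_alloc, my_free);
  // Swapping is only allowed before the native caching allocator is
  // initialized (enforced by a TORCH_CHECK in changeCurrentAllocator).
  xpa::changeCurrentAllocator(allocator);
}
```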
Pull Request resolved: https://github.com/pytorch/pytorch/pull/168966
Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/eellison
Yu, Guangye
2025-12-03 18:59:35 +00:00
committed by PyTorch MergeBot
parent 85a315917e
commit 2e0c2e170f
5 changed files with 279 additions and 56 deletions

build_variables.bzl

@@ -875,6 +875,7 @@ libtorch_python_xpu_sources = [
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
     "torch/csrc/xpu/Stream.cpp",
+    "torch/csrc/xpu/XPUPluggableAllocator.cpp",
     "torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp",
     "torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
 ]

c10/xpu/XPUCachingAllocator.cpp

@@ -1353,7 +1353,7 @@ class NativeCachingAllocator : public XPUAllocator {
  public:
   std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
 
-  void init(DeviceIndex device_count) {
+  void init(DeviceIndex device_count) override {
     const auto size = static_cast<DeviceIndex>(device_allocators.size());
     if (size < device_count) {
       device_allocators.resize(device_count);
@@ -1538,88 +1538,62 @@ class NativeCachingAllocator : public XPUAllocator {
   }
 };
 
-static NativeCachingAllocator allocator;
+static NativeCachingAllocator native_allocator;
 
 void local_raw_delete(void* ptr) {
-  allocator.free(ptr);
+  native_allocator.free(ptr);
 }
 
-Allocator* get() {
-  return &allocator;
-}
+std::atomic<XPUAllocator*> allocator;
 
-void init(DeviceIndex device_count) {
-  return allocator.init(device_count);
-}
+struct NativeAllocatorStaticInitializer {
+  NativeAllocatorStaticInitializer() {
+    allocator.store(&native_allocator);
+    c10::SetAllocator(c10::kXPU, &native_allocator, 0);
+  }
+};
 
-void emptyCache(MempoolId_t mempool_id) {
-  return allocator.emptyCache(mempool_id);
-}
-
-void resetPeakStats(DeviceIndex device) {
-  return allocator.resetPeakStats(device);
-}
-
-void resetAccumulatedStats(DeviceIndex device) {
-  return allocator.resetAccumulatedStats(device);
-}
-
-DeviceStats getDeviceStats(DeviceIndex device) {
-  return allocator.getDeviceStats(device);
-}
-
-void* raw_alloc(size_t size) {
-  return allocator.raw_alloc(size);
-}
-
-void raw_delete(void* ptr) {
-  return allocator.raw_delete(ptr);
-}
-
-void recordStream(const DataPtr& dataPtr, XPUStream stream) {
-  return allocator.recordStream(dataPtr, stream);
-}
+static NativeAllocatorStaticInitializer native_allocator_static_initializer;
 
 void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
-  return allocator.enablePeerAccess(dev, dev_to_access);
+  return native_allocator.enablePeerAccess(dev, dev_to_access);
 }
 
 double getMemoryFraction(DeviceIndex device) {
-  return allocator.getMemoryFraction(device);
+  return native_allocator.getMemoryFraction(device);
 }
 
 void setMemoryFraction(double fraction, DeviceIndex device) {
-  return allocator.setMemoryFraction(fraction, device);
+  return native_allocator.setMemoryFraction(fraction, device);
 }
 
 void createOrIncrefPool(
     c10::DeviceIndex device,
     MempoolId_t mempool_id,
     XPUAllocator* allocator_ptr) {
-  return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
+  return native_allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
 }
 
 void beginAllocateToPool(
     c10::DeviceIndex device,
     MempoolId_t mempool_id,
     std::function<bool(sycl::queue*)> filter) {
-  return allocator.beginAllocateToPool(device, mempool_id, std::move(filter));
+  return native_allocator.beginAllocateToPool(
+      device, mempool_id, std::move(filter));
 }
 
 void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.endAllocateToPool(device, mempool_id);
+  return native_allocator.endAllocateToPool(device, mempool_id);
 }
 
 void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.releasePool(device, mempool_id);
+  return native_allocator.releasePool(device, mempool_id);
 }
 
 int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.getPoolUseCount(device, mempool_id);
+  return native_allocator.getPoolUseCount(device, mempool_id);
 }
 
-REGISTER_ALLOCATOR(kXPU, &allocator)
-
 } // namespace c10::xpu::XPUCachingAllocator
 
 namespace c10::xpu {
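
The hunk above replaces direct use of the concrete `NativeCachingAllocator` singleton with an atomic `XPUAllocator*` that a static initializer points at the native allocator during startup; `changeCurrentAllocator` (added later in this commit) can then store a different backend before first use. A minimal, self-contained sketch of that pattern, with simplified names that are not the actual PyTorch code:

```cpp
#include <atomic>
#include <cstddef>
#include <new>

struct Alloc {
  virtual void* raw_alloc(size_t nbytes) = 0;
  virtual ~Alloc() = default;
};

struct Native final : Alloc {
  void* raw_alloc(size_t nbytes) override { return ::operator new(nbytes); }
};

static Native native;        // concrete default backend
std::atomic<Alloc*> active;  // the pointer everything dispatches through

// Runs during static initialization, so `active` is valid before user code.
static struct Init {
  Init() { active.store(&native); }
} init_once;

// Free functions keep their signatures but now dispatch virtually, so a
// custom backend can be swapped in by storing a different Alloc*.
void* raw_alloc(size_t nbytes) { return active.load()->raw_alloc(nbytes); }
```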

c10/xpu/XPUCachingAllocator.h

@@ -8,28 +8,49 @@ namespace c10::xpu::XPUCachingAllocator {
 
 class XPUAllocator : public DeviceAllocator {
  public:
+  virtual void init(c10::DeviceIndex device_count) = 0;
+  virtual void* raw_alloc(size_t nbytes) = 0;
+  virtual void raw_delete(void* ptr) = 0;
 };
 
-C10_XPU_API Allocator* get();
+C10_XPU_API extern std::atomic<XPUAllocator*> allocator;
 
-C10_XPU_API void init(DeviceIndex device_count);
+inline XPUAllocator* get() {
+  return allocator.load();
+}
 
-C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
+inline void init(c10::DeviceIndex device_count) {
+  get()->init(device_count);
+}
 
-C10_XPU_API void resetPeakStats(DeviceIndex device);
+inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
+  get()->emptyCache(mempool_id);
+}
 
-C10_XPU_API void resetAccumulatedStats(DeviceIndex device);
+inline void resetPeakStats(DeviceIndex device) {
+  get()->resetPeakStats(device);
+}
 
-C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
-    DeviceIndex device);
+inline void resetAccumulatedStats(DeviceIndex device) {
+  get()->resetAccumulatedStats(device);
+}
 
-C10_XPU_API void* raw_alloc(size_t size);
+inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
+    DeviceIndex device) {
+  return get()->getDeviceStats(device);
+}
 
-C10_XPU_API void raw_delete(void* ptr);
+inline void* raw_alloc(size_t size) {
+  return get()->raw_alloc(size);
+}
 
-C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
+inline void raw_delete(void* ptr) {
+  get()->raw_delete(ptr);
+}
+
+inline void recordStream(const DataPtr& dataPtr, XPUStream stream) {
+  get()->recordStream(dataPtr, stream);
+}
 
 C10_XPU_API void enablePeerAccess(
     c10::DeviceIndex dev,

torch/csrc/xpu/XPUPluggableAllocator.cpp (new file)

@@ -0,0 +1,147 @@
#include <torch/csrc/xpu/XPUPluggableAllocator.h>

namespace torch::xpu::XPUPluggableAllocator {

void custom_raw_deleter(void* ptr);

static c10::DeviceIndex device_count_ = 0;

void* XPUPluggableAllocator::malloc(
    size_t size,
    c10::DeviceIndex device,
    sycl::queue* queue) {
  void* r = alloc_fn_(size, device, queue);
  {
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    allocation_metadata_.emplace(r, _AllocationMetadata(size, device, queue));
  }
  return r;
}

c10::DataPtr XPUPluggableAllocator::allocate(size_t size) {
  auto device = c10::xpu::current_device();
  sycl::queue& queue = c10::xpu::getCurrentXPUStream(device);
  void* r = this->malloc(size, device, &queue);
  return {r, r, raw_deleter(), c10::Device(c10::kXPU, device)};
}

void* XPUPluggableAllocator::raw_alloc(size_t nbytes) {
  auto device = c10::xpu::current_device();
  sycl::queue& queue = c10::xpu::getCurrentXPUStream(device);
  return malloc(nbytes, device, &queue);
}

c10::DeleterFnPtr XPUPluggableAllocator::raw_deleter() const {
  return &custom_raw_deleter;
}

void XPUPluggableAllocator::raw_delete(void* ptr) {
  sycl::queue* queue = nullptr;
  c10::DeviceIndex device_idx = -1;
  size_t size = 0;
  {
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    TORCH_CHECK(
        allocation_metadata_.count(ptr),
        "Trying to free a pointer not allocated here");
    _AllocationMetadata& metadata = allocation_metadata_[ptr];
    size = metadata.size;
    device_idx = metadata.device_idx;
    queue = metadata.queue;
    allocation_metadata_.erase(ptr);
  }
  free_fn_(ptr, size, device_idx, queue);
}

void XPUPluggableAllocator::init(c10::DeviceIndex device_count) {
  if (init_fn_) {
    init_fn_(device_count);
  }
  device_count_ = device_count;
  initialized_ = true;
}

bool XPUPluggableAllocator::initialized() {
  return initialized_;
}

void XPUPluggableAllocator::copy_data(
    void* dest,
    const void* src,
    std::size_t count) const {
  c10::xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
}

void XPUPluggableAllocator::recordStream(
    const c10::DataPtr& ptr,
    c10::Stream stream) {
  if (record_stream_fn_) {
    auto xpu_stream = c10::xpu::XPUStream(stream);
    record_stream_fn_(ptr.get(), &xpu_stream.queue());
  }
}

void XPUPluggableAllocator::emptyCache(
    /*unused*/ c10::MempoolId_t mempool_id) {
  TORCH_CHECK(
      false,
      "XPUPluggableAllocator does not yet support emptyCache. "
      "If you need it, please file an issue describing your use case.");
}

c10::CachingDeviceAllocator::DeviceStats XPUPluggableAllocator::getDeviceStats(
    c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "XPUPluggableAllocator does not yet support getDeviceStats. "
      "If you need it, please file an issue describing your use case.");
}

void XPUPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "XPUPluggableAllocator does not yet support resetAccumulatedStats. "
      "If you need it, please file an issue describing your use case.");
}

void XPUPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "XPUPluggableAllocator does not yet support resetPeakStats. "
      "If you need it, please file an issue describing your use case.");
}

std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
    current_custom_allocator;

std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
getCurrentAllocator() {
  return current_custom_allocator;
}

std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
createCustomAllocator(
    std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
    std::function<void(void*, size_t, int, sycl::queue*)> free_fn) {
  auto allocator = std::make_shared<XPUPluggableAllocator>(
      std::move(alloc_fn), std::move(free_fn));
  allocator->init(device_count_);
  return allocator;
}

void changeCurrentAllocator(
    const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
        allocator) {
  TORCH_CHECK(
      !c10::xpu::XPUCachingAllocator::get()->initialized(),
      "Can't swap an already initialized allocator");
  c10::xpu::XPUCachingAllocator::allocator.store(allocator.get());
  c10::SetAllocator(c10::kXPU, allocator.get());
  current_custom_allocator = allocator;
}

void custom_raw_deleter(void* ptr) {
  current_custom_allocator->raw_delete(ptr);
}

} // namespace torch::xpu::XPUPluggableAllocator
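
To make the optional hooks concrete, here is a hedged sketch of wiring them up on a freshly constructed allocator, reusing the hypothetical `my_alloc`/`my_free` callbacks from the sketch in the Motivation section (the lambda bodies are placeholders, not part of this PR):

```cpp
#include <torch/csrc/xpu/XPUPluggableAllocator.h>

void install_with_hooks() {
  namespace xpa = torch::xpu::XPUPluggableAllocator;
  // Construct directly (rather than via createCustomAllocator) so the
  // concrete type is available for the optional setters.
  auto allocator =
      std::make_shared<xpa::XPUPluggableAllocator>(my_alloc, my_free);

  // Optional: invoked with the device count when init() runs.
  allocator->set_init_fn([](int device_count) {
    // hypothetical per-device setup
  });

  // Optional: forwarded from recordStream() so the backend can defer reuse
  // of a block until work queued on `queue` has completed.
  allocator->set_record_stream_fn([](void* ptr, sycl::queue* queue) {
    // hypothetical stream bookkeeping
  });

  xpa::changeCurrentAllocator(allocator);
}
```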

torch/csrc/xpu/XPUPluggableAllocator.h (new file)

@@ -0,0 +1,80 @@
#pragma once

#include <c10/xpu/XPUCachingAllocator.h>
#include <torch/csrc/Export.h>

namespace torch::xpu::XPUPluggableAllocator {

struct _AllocationMetadata {
  _AllocationMetadata() {}
  _AllocationMetadata(
      size_t size,
      c10::DeviceIndex device_idx,
      sycl::queue* queue)
      : size(size), device_idx(device_idx), queue(queue) {}
  size_t size{0};
  c10::DeviceIndex device_idx{-1};
  sycl::queue* queue{};
};

struct TORCH_PYTHON_API XPUPluggableAllocator
    : public c10::xpu::XPUCachingAllocator::XPUAllocator {
  XPUPluggableAllocator(
      std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
      std::function<void(void*, size_t, int, sycl::queue*)> free_fn)
      : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}

  C10_DISABLE_COPY_AND_ASSIGN(XPUPluggableAllocator);
  ~XPUPluggableAllocator() override = default;

  void* malloc(size_t size, c10::DeviceIndex device, sycl::queue* stream);

  c10::DataPtr allocate(size_t size) override;
  c10::DeleterFnPtr raw_deleter() const override;
  void* raw_alloc(size_t nbytes) override;
  void raw_delete(void* ptr) override;
  void init(c10::DeviceIndex device_count) override;
  bool initialized() override;
  void copy_data(void* dest, const void* src, std::size_t count) const final;
  void recordStream(const c10::DataPtr&, c10::Stream stream) override;
  void emptyCache(c10::MempoolId_t mempool_id = {0, 0}) override;
  c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
      c10::DeviceIndex device) override;
  void resetAccumulatedStats(c10::DeviceIndex device) override;
  void resetPeakStats(c10::DeviceIndex device) override;

  void set_init_fn(std::function<void(int)> init_fn) {
    init_fn_ = std::move(init_fn);
  }

  void set_record_stream_fn(
      std::function<void(void* ptr, sycl::queue* queue)> record_stream_fn) {
    record_stream_fn_ = std::move(record_stream_fn);
  }

 protected:
  std::function<void*(size_t, int, sycl::queue*)> alloc_fn_;
  std::function<void(void*, size_t, int, sycl::queue*)> free_fn_;
  std::function<void(int)> init_fn_;
  std::function<void(void* ptr, sycl::queue*)> record_stream_fn_;
  std::mutex allocator_mutex_;
  // We do the bookkeeping here in order to simplify custom allocators
  std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;
  bool initialized_ = false;
};

TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
getCurrentAllocator();

TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
createCustomAllocator(
    std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
    std::function<void(void*, size_t, int, sycl::queue*)> free_fn);

TORCH_XPU_API void changeCurrentAllocator(
    const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
        allocator);

} // namespace torch::xpu::XPUPluggableAllocator