mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[xpu][feature] [1/2] Introduce XPUPluggableAllocator in cpp part (#168966)
# Motivation This PR introduces `XPUPluggableAllocator`, kept as simple as possible; a follow-up PR will introduce the Python frontend part. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168966 Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
85a315917e
commit
2e0c2e170f
@@ -875,6 +875,7 @@ libtorch_python_xpu_sources = [
|
||||
"torch/csrc/xpu/Event.cpp",
|
||||
"torch/csrc/xpu/Module.cpp",
|
||||
"torch/csrc/xpu/Stream.cpp",
|
||||
"torch/csrc/xpu/XPUPluggableAllocator.cpp",
|
||||
"torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp",
|
||||
"torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
|
||||
]
|
||||
|
||||
@@ -1353,7 +1353,7 @@ class NativeCachingAllocator : public XPUAllocator {
|
||||
public:
|
||||
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
|
||||
|
||||
void init(DeviceIndex device_count) {
|
||||
void init(DeviceIndex device_count) override {
|
||||
const auto size = static_cast<DeviceIndex>(device_allocators.size());
|
||||
if (size < device_count) {
|
||||
device_allocators.resize(device_count);
|
||||
@@ -1538,88 +1538,62 @@ class NativeCachingAllocator : public XPUAllocator {
|
||||
}
|
||||
};
|
||||
|
||||
static NativeCachingAllocator allocator;
|
||||
static NativeCachingAllocator native_allocator;
|
||||
|
||||
void local_raw_delete(void* ptr) {
|
||||
allocator.free(ptr);
|
||||
native_allocator.free(ptr);
|
||||
}
|
||||
|
||||
Allocator* get() {
|
||||
return &allocator;
|
||||
}
|
||||
std::atomic<XPUAllocator*> allocator;
|
||||
|
||||
void init(DeviceIndex device_count) {
|
||||
return allocator.init(device_count);
|
||||
}
|
||||
struct NativeAllocatorStaticInitializer {
|
||||
NativeAllocatorStaticInitializer() {
|
||||
allocator.store(&native_allocator);
|
||||
c10::SetAllocator(c10::kXPU, &native_allocator, 0);
|
||||
}
|
||||
};
|
||||
|
||||
void emptyCache(MempoolId_t mempool_id) {
|
||||
return allocator.emptyCache(mempool_id);
|
||||
}
|
||||
|
||||
void resetPeakStats(DeviceIndex device) {
|
||||
return allocator.resetPeakStats(device);
|
||||
}
|
||||
|
||||
void resetAccumulatedStats(DeviceIndex device) {
|
||||
return allocator.resetAccumulatedStats(device);
|
||||
}
|
||||
|
||||
DeviceStats getDeviceStats(DeviceIndex device) {
|
||||
return allocator.getDeviceStats(device);
|
||||
}
|
||||
|
||||
void* raw_alloc(size_t size) {
|
||||
return allocator.raw_alloc(size);
|
||||
}
|
||||
|
||||
void raw_delete(void* ptr) {
|
||||
return allocator.raw_delete(ptr);
|
||||
}
|
||||
|
||||
void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
static NativeAllocatorStaticInitializer native_allocator_static_initializer;
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
return allocator.enablePeerAccess(dev, dev_to_access);
|
||||
return native_allocator.enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
return native_allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
return allocator.setMemoryFraction(fraction, device);
|
||||
return native_allocator.setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
void createOrIncrefPool(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id,
|
||||
XPUAllocator* allocator_ptr) {
|
||||
return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
|
||||
return native_allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
|
||||
}
|
||||
|
||||
void beginAllocateToPool(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id,
|
||||
std::function<bool(sycl::queue*)> filter) {
|
||||
return allocator.beginAllocateToPool(device, mempool_id, std::move(filter));
|
||||
return native_allocator.beginAllocateToPool(
|
||||
device, mempool_id, std::move(filter));
|
||||
}
|
||||
|
||||
void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
|
||||
return allocator.endAllocateToPool(device, mempool_id);
|
||||
return native_allocator.endAllocateToPool(device, mempool_id);
|
||||
}
|
||||
|
||||
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
|
||||
return allocator.releasePool(device, mempool_id);
|
||||
return native_allocator.releasePool(device, mempool_id);
|
||||
}
|
||||
|
||||
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
|
||||
return allocator.getPoolUseCount(device, mempool_id);
|
||||
return native_allocator.getPoolUseCount(device, mempool_id);
|
||||
}
|
||||
|
||||
REGISTER_ALLOCATOR(kXPU, &allocator)
|
||||
|
||||
} // namespace c10::xpu::XPUCachingAllocator
|
||||
|
||||
namespace c10::xpu {
|
||||
|
||||
@@ -8,28 +8,49 @@ namespace c10::xpu::XPUCachingAllocator {
|
||||
|
||||
class XPUAllocator : public DeviceAllocator {
|
||||
public:
|
||||
virtual void init(c10::DeviceIndex device_count) = 0;
|
||||
virtual void* raw_alloc(size_t nbytes) = 0;
|
||||
virtual void raw_delete(void* ptr) = 0;
|
||||
};
|
||||
|
||||
C10_XPU_API Allocator* get();
|
||||
C10_XPU_API extern std::atomic<XPUAllocator*> allocator;
|
||||
|
||||
C10_XPU_API void init(DeviceIndex device_count);
|
||||
inline XPUAllocator* get() {
|
||||
return allocator.load();
|
||||
}
|
||||
|
||||
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
|
||||
inline void init(c10::DeviceIndex device_count) {
|
||||
get()->init(device_count);
|
||||
}
|
||||
|
||||
C10_XPU_API void resetPeakStats(DeviceIndex device);
|
||||
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
|
||||
get()->emptyCache(mempool_id);
|
||||
}
|
||||
|
||||
C10_XPU_API void resetAccumulatedStats(DeviceIndex device);
|
||||
inline void resetPeakStats(DeviceIndex device) {
|
||||
get()->resetPeakStats(device);
|
||||
}
|
||||
|
||||
C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
|
||||
DeviceIndex device);
|
||||
inline void resetAccumulatedStats(DeviceIndex device) {
|
||||
get()->resetAccumulatedStats(device);
|
||||
}
|
||||
|
||||
C10_XPU_API void* raw_alloc(size_t size);
|
||||
inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
|
||||
DeviceIndex device) {
|
||||
return get()->getDeviceStats(device);
|
||||
}
|
||||
|
||||
C10_XPU_API void raw_delete(void* ptr);
|
||||
inline void* raw_alloc(size_t size) {
|
||||
return get()->raw_alloc(size);
|
||||
}
|
||||
|
||||
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
|
||||
inline void raw_delete(void* ptr) {
|
||||
get()->raw_delete(ptr);
|
||||
}
|
||||
|
||||
inline void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
get()->recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
C10_XPU_API void enablePeerAccess(
|
||||
c10::DeviceIndex dev,
|
||||
|
||||
147
torch/csrc/xpu/XPUPluggableAllocator.cpp
Normal file
147
torch/csrc/xpu/XPUPluggableAllocator.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
#include <torch/csrc/xpu/XPUPluggableAllocator.h>
|
||||
|
||||
namespace torch::xpu::XPUPluggableAllocator {
|
||||
|
||||
void custom_raw_deleter(void* ptr);
|
||||
|
||||
static c10::DeviceIndex device_count_ = 0;
|
||||
|
||||
void* XPUPluggableAllocator::malloc(
|
||||
size_t size,
|
||||
c10::DeviceIndex device,
|
||||
sycl::queue* queue) {
|
||||
void* r = alloc_fn_(size, device, queue);
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock(allocator_mutex_);
|
||||
allocation_metadata_.emplace(r, _AllocationMetadata(size, device, queue));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Allocator interface entry point: allocates on the current XPU device
// against the current stream's queue and wraps the result in a DataPtr
// whose deleter dispatches through custom_raw_deleter.
c10::DataPtr XPUPluggableAllocator::allocate(size_t size) {
  const auto device = c10::xpu::current_device();
  sycl::queue& current_queue = c10::xpu::getCurrentXPUStream(device);
  void* data = this->malloc(size, device, &current_queue);
  return {data, data, raw_deleter(), c10::Device(c10::kXPU, device)};
}
|
||||
|
||||
void* XPUPluggableAllocator::raw_alloc(size_t nbytes) {
|
||||
auto device = c10::xpu::current_device();
|
||||
sycl::queue& queue = c10::xpu::getCurrentXPUStream(device);
|
||||
return malloc(nbytes, device, &queue);
|
||||
}
|
||||
|
||||
c10::DeleterFnPtr XPUPluggableAllocator::raw_deleter() const {
|
||||
return &custom_raw_deleter;
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::raw_delete(void* ptr) {
|
||||
sycl::queue* queue = nullptr;
|
||||
c10::DeviceIndex device_idx = -1;
|
||||
size_t size = 0;
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock(allocator_mutex_);
|
||||
TORCH_CHECK(
|
||||
allocation_metadata_.count(ptr),
|
||||
"Trying to free a pointer not allocated here");
|
||||
_AllocationMetadata& metadata = allocation_metadata_[ptr];
|
||||
size = metadata.size;
|
||||
device_idx = metadata.device_idx;
|
||||
queue = metadata.queue;
|
||||
allocation_metadata_.erase(ptr);
|
||||
}
|
||||
free_fn_(ptr, size, device_idx, queue);
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::init(c10::DeviceIndex device_count) {
|
||||
if (init_fn_) {
|
||||
init_fn_(device_count);
|
||||
}
|
||||
device_count_ = device_count;
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool XPUPluggableAllocator::initialized() {
|
||||
return initialized_;
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::copy_data(
|
||||
void* dest,
|
||||
const void* src,
|
||||
std::size_t count) const {
|
||||
c10::xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::recordStream(
|
||||
const c10::DataPtr& ptr,
|
||||
c10::Stream stream) {
|
||||
if (record_stream_fn_) {
|
||||
auto xpu_stream = c10::xpu::XPUStream(stream);
|
||||
record_stream_fn_(ptr.get(), &xpu_stream.queue());
|
||||
}
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::emptyCache(
|
||||
/*unused*/ c10::MempoolId_t mempool_id) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"XPUPluggableAllocator does not yet support emptyCache. "
|
||||
"If you need it, please file an issue describing your use case.");
|
||||
}
|
||||
|
||||
c10::CachingDeviceAllocator::DeviceStats XPUPluggableAllocator::getDeviceStats(
|
||||
c10::DeviceIndex device) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"XPUPluggableAllocator does not yet support getDeviceStats. "
|
||||
"If you need it, please file an issue describing your use case.");
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"XPUPluggableAllocator does not yet support resetAccumulatedStats. "
|
||||
"If you need it, please file an issue describing your use case.");
|
||||
}
|
||||
|
||||
void XPUPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"XPUPluggableAllocator does not yet support resetPeakStats. "
|
||||
"If you need it, please file an issue describing your use case.");
|
||||
}
|
||||
|
||||
std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
|
||||
current_custom_allocator;
|
||||
|
||||
std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
|
||||
getCurrentAllocator() {
|
||||
return current_custom_allocator;
|
||||
}
|
||||
|
||||
std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
|
||||
createCustomAllocator(
|
||||
std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
|
||||
std::function<void(void*, size_t, int, sycl::queue*)> free_fn) {
|
||||
auto allocator = std::make_shared<XPUPluggableAllocator>(
|
||||
std::move(alloc_fn), std::move(free_fn));
|
||||
allocator->init(device_count_);
|
||||
return allocator;
|
||||
}
|
||||
|
||||
void changeCurrentAllocator(
|
||||
const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
|
||||
allocator) {
|
||||
TORCH_CHECK(
|
||||
!c10::xpu::XPUCachingAllocator::get()->initialized(),
|
||||
"Can't swap an already initialized allocator");
|
||||
c10::xpu::XPUCachingAllocator::allocator.store(allocator.get());
|
||||
c10::SetAllocator(c10::kXPU, allocator.get());
|
||||
current_custom_allocator = allocator;
|
||||
}
|
||||
|
||||
void custom_raw_deleter(void* ptr) {
|
||||
current_custom_allocator->raw_delete(ptr);
|
||||
}
|
||||
|
||||
} // namespace torch::xpu::XPUPluggableAllocator
|
||||
80
torch/csrc/xpu/XPUPluggableAllocator.h
Normal file
80
torch/csrc/xpu/XPUPluggableAllocator.h
Normal file
@@ -0,0 +1,80 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/xpu/XPUCachingAllocator.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
|
||||
namespace torch::xpu::XPUPluggableAllocator {
|
||||
|
||||
struct _AllocationMetadata {
|
||||
_AllocationMetadata() {}
|
||||
_AllocationMetadata(
|
||||
size_t size,
|
||||
c10::DeviceIndex device_idx,
|
||||
sycl::queue* queue)
|
||||
: size(size), device_idx(device_idx), queue(queue) {}
|
||||
size_t size{0};
|
||||
c10::DeviceIndex device_idx{-1};
|
||||
sycl::queue* queue{};
|
||||
};
|
||||
|
||||
struct TORCH_PYTHON_API XPUPluggableAllocator
|
||||
: public c10::xpu::XPUCachingAllocator::XPUAllocator {
|
||||
XPUPluggableAllocator(
|
||||
std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
|
||||
std::function<void(void*, size_t, int, sycl::queue*)> free_fn)
|
||||
: alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(XPUPluggableAllocator);
|
||||
|
||||
~XPUPluggableAllocator() override = default;
|
||||
|
||||
void* malloc(size_t size, c10::DeviceIndex device, sycl::queue* stream);
|
||||
|
||||
c10::DataPtr allocate(size_t size) override;
|
||||
c10::DeleterFnPtr raw_deleter() const override;
|
||||
|
||||
void* raw_alloc(size_t nbytes) override;
|
||||
void raw_delete(void* ptr) override;
|
||||
void init(c10::DeviceIndex device_count) override;
|
||||
bool initialized() override;
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final;
|
||||
|
||||
void recordStream(const c10::DataPtr&, c10::Stream stream) override;
|
||||
void emptyCache(c10::MempoolId_t mempool_id = {0, 0}) override;
|
||||
c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
|
||||
c10::DeviceIndex device) override;
|
||||
void resetAccumulatedStats(c10::DeviceIndex device) override;
|
||||
void resetPeakStats(c10::DeviceIndex device) override;
|
||||
|
||||
void set_init_fn(std::function<void(int)> init_fn) {
|
||||
init_fn_ = std::move(init_fn);
|
||||
}
|
||||
void set_record_stream_fn(
|
||||
std::function<void(void* ptr, sycl::queue* queue)> record_stream_fn) {
|
||||
record_stream_fn_ = std::move(record_stream_fn);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::function<void*(size_t, int, sycl::queue*)> alloc_fn_;
|
||||
std::function<void(void*, size_t, int, sycl::queue*)> free_fn_;
|
||||
std::function<void(int)> init_fn_;
|
||||
std::function<void(void* ptr, sycl::queue*)> record_stream_fn_;
|
||||
std::mutex allocator_mutex_;
|
||||
// We do the bookkeeping here in order to simplify custom allocators
|
||||
std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
// Returns the allocator most recently installed through
// changeCurrentAllocator(), or null if none has been installed.
TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
getCurrentAllocator();

// Creates a pluggable allocator from user alloc/free callbacks and
// initializes it with the currently known device count.
TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
createCustomAllocator(
    std::function<void*(size_t, int, sycl::queue*)> alloc_fn,
    std::function<void(void*, size_t, int, sycl::queue*)> free_fn);

// Installs `allocator` as the process-wide XPU allocator; must be called
// before the currently installed allocator has been initialized.
TORCH_XPU_API void changeCurrentAllocator(
    const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
        allocator);

} // namespace torch::xpu::XPUPluggableAllocator
|
||||
Reference in New Issue
Block a user