diff --git a/build_variables.bzl b/build_variables.bzl
index ba856c5a97b..25f167191ab 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -875,6 +875,7 @@ libtorch_python_xpu_sources = [
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
     "torch/csrc/xpu/Stream.cpp",
+    "torch/csrc/xpu/XPUPluggableAllocator.cpp",
     "torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp",
     "torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
 ]
diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp
index dfcccc94c9e..92dffc91539 100644
--- a/c10/xpu/XPUCachingAllocator.cpp
+++ b/c10/xpu/XPUCachingAllocator.cpp
@@ -1353,7 +1353,7 @@ class NativeCachingAllocator : public XPUAllocator {
  public:
   std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
 
-  void init(DeviceIndex device_count) {
+  void init(DeviceIndex device_count) override {
     const auto size = static_cast<DeviceIndex>(device_allocators.size());
     if (size < device_count) {
       device_allocators.resize(device_count);
@@ -1538,88 +1538,62 @@ class NativeCachingAllocator : public XPUAllocator {
   }
 };
 
-static NativeCachingAllocator allocator;
+static NativeCachingAllocator native_allocator;
 
 void local_raw_delete(void* ptr) {
-  allocator.free(ptr);
+  native_allocator.free(ptr);
 }
 
-Allocator* get() {
-  return &allocator;
-}
+std::atomic<XPUAllocator*> allocator;
 
-void init(DeviceIndex device_count) {
-  return allocator.init(device_count);
-}
+struct NativeAllocatorStaticInitializer {
+  NativeAllocatorStaticInitializer() {
+    allocator.store(&native_allocator);
+    c10::SetAllocator(c10::kXPU, &native_allocator, 0);
+  }
+};
 
-void emptyCache(MempoolId_t mempool_id) {
-  return allocator.emptyCache(mempool_id);
-}
-
-void resetPeakStats(DeviceIndex device) {
-  return allocator.resetPeakStats(device);
-}
-
-void resetAccumulatedStats(DeviceIndex device) {
-  return allocator.resetAccumulatedStats(device);
-}
-
-DeviceStats getDeviceStats(DeviceIndex device) {
-  return allocator.getDeviceStats(device);
-}
-
-void* raw_alloc(size_t size) {
-  return allocator.raw_alloc(size);
-}
-
-void raw_delete(void* ptr) {
-  return allocator.raw_delete(ptr);
-}
-
-void recordStream(const DataPtr& dataPtr, XPUStream stream) {
-  return allocator.recordStream(dataPtr, stream);
-}
+static NativeAllocatorStaticInitializer native_allocator_static_initializer;
 
 void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
-  return allocator.enablePeerAccess(dev, dev_to_access);
+  return native_allocator.enablePeerAccess(dev, dev_to_access);
 }
 
 double getMemoryFraction(DeviceIndex device) {
-  return allocator.getMemoryFraction(device);
+  return native_allocator.getMemoryFraction(device);
 }
 
 void setMemoryFraction(double fraction, DeviceIndex device) {
-  return allocator.setMemoryFraction(fraction, device);
+  return native_allocator.setMemoryFraction(fraction, device);
 }
 
 void createOrIncrefPool(
     c10::DeviceIndex device,
     MempoolId_t mempool_id,
     XPUAllocator* allocator_ptr) {
-  return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
+  return native_allocator.createOrIncrefPool(device, mempool_id, allocator_ptr);
 }
 
 void beginAllocateToPool(
     c10::DeviceIndex device,
     MempoolId_t mempool_id,
    std::function<bool(sycl::queue*)> filter) {
-  return allocator.beginAllocateToPool(device, mempool_id, std::move(filter));
+  return native_allocator.beginAllocateToPool(
+      device, mempool_id, std::move(filter));
 }
 
 void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.endAllocateToPool(device, mempool_id);
+  return native_allocator.endAllocateToPool(device, mempool_id);
 }
 
 void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.releasePool(device, mempool_id);
+  return native_allocator.releasePool(device, mempool_id);
 }
 
 int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  return allocator.getPoolUseCount(device, mempool_id);
+  return native_allocator.getPoolUseCount(device, mempool_id);
 }
 
-REGISTER_ALLOCATOR(kXPU, &allocator)
-
 } // namespace c10::xpu::XPUCachingAllocator
 
 namespace c10::xpu {
diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h
index 0054e359e77..54c7387cc38 100644
--- a/c10/xpu/XPUCachingAllocator.h
+++ b/c10/xpu/XPUCachingAllocator.h
@@ -8,28 +8,49 @@ namespace c10::xpu::XPUCachingAllocator {
 
 class XPUAllocator : public DeviceAllocator {
  public:
+  virtual void init(c10::DeviceIndex device_count) = 0;
   virtual void* raw_alloc(size_t nbytes) = 0;
   virtual void raw_delete(void* ptr) = 0;
 };
 
-C10_XPU_API Allocator* get();
+C10_XPU_API extern std::atomic<XPUAllocator*> allocator;
 
-C10_XPU_API void init(DeviceIndex device_count);
+inline XPUAllocator* get() {
+  return allocator.load();
+}
 
-C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
+inline void init(c10::DeviceIndex device_count) {
+  get()->init(device_count);
+}
 
-C10_XPU_API void resetPeakStats(DeviceIndex device);
+inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
+  get()->emptyCache(mempool_id);
+}
 
-C10_XPU_API void resetAccumulatedStats(DeviceIndex device);
+inline void resetPeakStats(DeviceIndex device) {
+  get()->resetPeakStats(device);
+}
 
-C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
-    DeviceIndex device);
+inline void resetAccumulatedStats(DeviceIndex device) {
+  get()->resetAccumulatedStats(device);
+}
 
-C10_XPU_API void* raw_alloc(size_t size);
+inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
+    DeviceIndex device) {
+  return get()->getDeviceStats(device);
+}
 
-C10_XPU_API void raw_delete(void* ptr);
+inline void* raw_alloc(size_t size) {
+  return get()->raw_alloc(size);
+}
 
-C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
+inline void raw_delete(void* ptr) {
+  get()->raw_delete(ptr);
+}
+
+inline void recordStream(const DataPtr& dataPtr, XPUStream stream) {
+  get()->recordStream(dataPtr, stream);
+}
 
 C10_XPU_API void enablePeerAccess(
     c10::DeviceIndex dev,
diff --git a/torch/csrc/xpu/XPUPluggableAllocator.cpp b/torch/csrc/xpu/XPUPluggableAllocator.cpp
new file mode 100644
index 00000000000..6534ac94f15
--- /dev/null
+++ b/torch/csrc/xpu/XPUPluggableAllocator.cpp
@@ -0,0 +1,147 @@
+#include <torch/csrc/xpu/XPUPluggableAllocator.h>
+
+namespace torch::xpu::XPUPluggableAllocator {
+
+void custom_raw_deleter(void* ptr);
+
+static c10::DeviceIndex device_count_ = 0;
+
+void* XPUPluggableAllocator::malloc(
+    size_t size,
+    c10::DeviceIndex device,
+    sycl::queue* queue) {
+  void* r = alloc_fn_(size, device, queue);
+  {
+    const std::lock_guard<std::mutex> lock(allocator_mutex_);
+    allocation_metadata_.emplace(r, _AllocationMetadata(size, device, queue));
+  }
+  return r;
+}
+
+c10::DataPtr XPUPluggableAllocator::allocate(size_t size) {
+  auto device = c10::xpu::current_device();
+  sycl::queue& queue = c10::xpu::getCurrentXPUStream(device);
+  void* r = this->malloc(size, device, &queue);
+  return {r, r, raw_deleter(), c10::Device(c10::kXPU, device)};
+}
+
+void* XPUPluggableAllocator::raw_alloc(size_t nbytes) {
+  auto device = c10::xpu::current_device();
+  sycl::queue& queue = c10::xpu::getCurrentXPUStream(device);
+  return malloc(nbytes, device, &queue);
+}
+
+c10::DeleterFnPtr XPUPluggableAllocator::raw_deleter() const {
+  return &custom_raw_deleter;
+}
+
+void XPUPluggableAllocator::raw_delete(void* ptr) {
+  sycl::queue* queue = nullptr;
+  c10::DeviceIndex device_idx = -1;
+  size_t size = 0;
+  {
+    const std::lock_guard<std::mutex> lock(allocator_mutex_);
+    TORCH_CHECK(
+        allocation_metadata_.count(ptr),
+        "Trying to free a pointer not allocated here");
+    _AllocationMetadata& metadata = allocation_metadata_[ptr];
+    size = metadata.size;
+    device_idx = metadata.device_idx;
+    queue = metadata.queue;
+    allocation_metadata_.erase(ptr);
+  }
+  free_fn_(ptr, size, device_idx, queue);
+}
+
+void XPUPluggableAllocator::init(c10::DeviceIndex device_count) {
+  if (init_fn_) {
+    init_fn_(device_count);
+  }
+  device_count_ = device_count;
+  initialized_ = true;
+}
+
+bool XPUPluggableAllocator::initialized() {
+  return initialized_;
+}
+
+void XPUPluggableAllocator::copy_data(
+    void* dest,
+    const void* src,
+    std::size_t count) const {
+  c10::xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
+}
+
+void XPUPluggableAllocator::recordStream(
+    const c10::DataPtr& ptr,
+    c10::Stream stream) {
+  if (record_stream_fn_) {
+    auto xpu_stream = c10::xpu::XPUStream(stream);
+    record_stream_fn_(ptr.get(), &xpu_stream.queue());
+  }
+}
+
+void XPUPluggableAllocator::emptyCache(
+    /*unused*/ c10::MempoolId_t mempool_id) {
+  TORCH_CHECK(
+      false,
+      "XPUPluggableAllocator does not yet support emptyCache. "
+      "If you need it, please file an issue describing your use case.");
+}
+
+c10::CachingDeviceAllocator::DeviceStats XPUPluggableAllocator::getDeviceStats(
+    c10::DeviceIndex device) {
+  TORCH_CHECK(
+      false,
+      "XPUPluggableAllocator does not yet support getDeviceStats. "
+      "If you need it, please file an issue describing your use case.");
+}
+
+void XPUPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
+  TORCH_CHECK(
+      false,
+      "XPUPluggableAllocator does not yet support resetAccumulatedStats. "
+      "If you need it, please file an issue describing your use case.");
+}
+
+void XPUPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
+  TORCH_CHECK(
+      false,
+      "XPUPluggableAllocator does not yet support resetPeakStats. "
+      "If you need it, please file an issue describing your use case.");
+}
+
+std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
+    current_custom_allocator;
+
+std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
+getCurrentAllocator() {
+  return current_custom_allocator;
+}
+
+std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
+createCustomAllocator(
+    std::function<void*(size_t, c10::DeviceIndex, sycl::queue*)> alloc_fn,
+    std::function<void(void*, size_t, c10::DeviceIndex, sycl::queue*)>
+        free_fn) {
+  auto allocator = std::make_shared<XPUPluggableAllocator>(
+      std::move(alloc_fn), std::move(free_fn));
+  allocator->init(device_count_);
+  return allocator;
+}
+
+void changeCurrentAllocator(
+    const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
+        allocator) {
+  TORCH_CHECK(
+      !c10::xpu::XPUCachingAllocator::get()->initialized(),
+      "Can't swap an already initialized allocator");
+  c10::xpu::XPUCachingAllocator::allocator.store(allocator.get());
+  c10::SetAllocator(c10::kXPU, allocator.get());
+  current_custom_allocator = allocator;
+}
+
+void custom_raw_deleter(void* ptr) {
+  current_custom_allocator->raw_delete(ptr);
+}
+
+} // namespace torch::xpu::XPUPluggableAllocator
diff --git a/torch/csrc/xpu/XPUPluggableAllocator.h b/torch/csrc/xpu/XPUPluggableAllocator.h
new file mode 100644
index 00000000000..5133955c588
--- /dev/null
+++ b/torch/csrc/xpu/XPUPluggableAllocator.h
@@ -0,0 +1,80 @@
+#pragma once
+
+#include <c10/xpu/XPUCachingAllocator.h>
+#include <torch/csrc/Export.h>
+
+namespace torch::xpu::XPUPluggableAllocator {
+
+struct _AllocationMetadata {
+  _AllocationMetadata() {}
+  _AllocationMetadata(
+      size_t size,
+      c10::DeviceIndex device_idx,
+      sycl::queue* queue)
+      : size(size), device_idx(device_idx), queue(queue) {}
+  size_t size{0};
+  c10::DeviceIndex device_idx{-1};
+  sycl::queue* queue{};
+};
+
+struct TORCH_PYTHON_API XPUPluggableAllocator
+    : public c10::xpu::XPUCachingAllocator::XPUAllocator {
+  XPUPluggableAllocator(
+      std::function<void*(size_t, c10::DeviceIndex, sycl::queue*)> alloc_fn,
+      std::function<void(void*, size_t, c10::DeviceIndex, sycl::queue*)>
+          free_fn)
+      : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
+
+  C10_DISABLE_COPY_AND_ASSIGN(XPUPluggableAllocator);
+
+  ~XPUPluggableAllocator() override = default;
+
+  void* malloc(size_t size, c10::DeviceIndex device, sycl::queue* stream);
+
+  c10::DataPtr allocate(size_t size) override;
+  c10::DeleterFnPtr raw_deleter() const override;
+
+  void* raw_alloc(size_t nbytes) override;
+  void raw_delete(void* ptr) override;
+  void init(c10::DeviceIndex device_count) override;
+  bool initialized() override;
+  void copy_data(void* dest, const void* src, std::size_t count) const final;
+
+  void recordStream(const c10::DataPtr&, c10::Stream stream) override;
+  void emptyCache(c10::MempoolId_t mempool_id = {0, 0}) override;
+  c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
+      c10::DeviceIndex device) override;
+  void resetAccumulatedStats(c10::DeviceIndex device) override;
+  void resetPeakStats(c10::DeviceIndex device) override;
+
+  void set_init_fn(std::function<void(c10::DeviceIndex)> init_fn) {
+    init_fn_ = std::move(init_fn);
+  }
+  void set_record_stream_fn(
+      std::function<void(void* ptr, sycl::queue* queue)> record_stream_fn) {
+    record_stream_fn_ = std::move(record_stream_fn);
+  }
+
+ protected:
+  std::function<void*(size_t, c10::DeviceIndex, sycl::queue*)> alloc_fn_;
+  std::function<void(void*, size_t, c10::DeviceIndex, sycl::queue*)> free_fn_;
+  std::function<void(c10::DeviceIndex)> init_fn_;
+  std::function<void(void* ptr, sycl::queue* queue)> record_stream_fn_;
+  std::mutex allocator_mutex_;
+  // We do the bookkeeping here in order to simplify custom allocators
+  std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;
+  bool initialized_ = false;
+};
+
+TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
+getCurrentAllocator();
+
+TORCH_XPU_API std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>
+createCustomAllocator(
+    std::function<void*(size_t, c10::DeviceIndex, sycl::queue*)> alloc_fn,
+    std::function<void(void*, size_t, c10::DeviceIndex, sycl::queue*)>
+        free_fn);
+
+TORCH_XPU_API void changeCurrentAllocator(
+    const std::shared_ptr<c10::xpu::XPUCachingAllocator::XPUAllocator>&
+        allocator);
+
+} // namespace torch::xpu::XPUPluggableAllocator