pytorch/c10/core/SafePyObject.h
Sam Gross 2b5eabc74b Rework PyObject preservation (v2) (#167564)
Make the PyObject preservation scheme thread-safe under free-threaded (nogil) Python. The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The topmost bit (`kHasPyObject`) of the weakref count now indicates whether the `intrusive_ptr_target` has an associated PyObject. `kHasPyObject` takes one bit, which leaves 31 bits for the weakref count; the strong refcount remains 32 bits.
* When the strong reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the strong reference count decreases from two to one (i.e., there are no C++ references to the `intrusive_ptr_target` other than the one from the Python object), we decref the associated Python object to break the cycle.
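
The transitions described above can be sketched in a few lines. This is a minimal, single-threaded illustration and not the actual `intrusive_ptr` code: the packing of both counts into one 64-bit word, and the names `Target`, `FakePyObject`, `attach`, and `demo_refcount_transitions`, are assumptions made for the example. The real implementation uses atomic operations.

```cpp
#include <cassert>
#include <cstdint>

// Assumed layout, for illustration only: strong count in the low 32 bits,
// weak count in the next 31 bits, kHasPyObject in the topmost bit.
constexpr std::uint64_t kHasPyObject = std::uint64_t{1} << 63;
constexpr std::uint64_t kStrongMask = 0xFFFFFFFFull;

// Hypothetical stand-in for a Python object; it only tracks a refcount.
struct FakePyObject {
  int refcnt = 1;
};

// Hypothetical stand-in for intrusive_ptr_target (non-atomic on purpose).
struct Target {
  std::uint64_t word = 1; // strong count 1: the reference held by the wrapper
  FakePyObject* pyobj = nullptr;

  std::uint32_t strong() const {
    return static_cast<std::uint32_t>(word & kStrongMask);
  }
  bool has_pyobj() const {
    return (word & kHasPyObject) != 0;
  }
  void attach(FakePyObject* obj) {
    pyobj = obj;
    word |= kHasPyObject;
  }
  void incref() {
    word += 1;
    // 1 -> 2: a C++ owner appeared, so keep the Python object alive.
    if (has_pyobj() && strong() == 2) {
      pyobj->refcnt += 1;
    }
  }
  void decref() {
    assert(strong() > 0);
    word -= 1;
    // 2 -> 1: only the Python wrapper is left, so break the cycle.
    if (has_pyobj() && strong() == 1) {
      pyobj->refcnt -= 1;
    }
  }
};

// Returns the Python-side refcount while C++ owners exist (tens digit)
// and after the last C++ owner is gone (units digit).
int demo_refcount_transitions() {
  FakePyObject py;
  Target t;
  t.attach(&py);
  t.incref(); // 1 -> 2: Python object is incref'd
  t.incref(); // 2 -> 3: no Python-side change
  t.decref(); // 3 -> 2: no Python-side change
  int during = py.refcnt;
  t.decref(); // 2 -> 1: Python object is decref'd
  return during * 10 + py.refcnt;
}
```

Only the one-to-two and two-to-one transitions touch the Python object, so the cost for all other refcount operations is the extra flag check.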

Other benefits:

* We can delete a lot of the code copied from Python's internal `subtype_dealloc`
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now.

Risks:

* Extra branch for reference count operations on `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>` even when we're not using Python.
* It's a big change

(Second attempt at https://github.com/pytorch/pytorch/pull/166342)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167564
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-11-17 14:52:02 +00:00


#pragma once
#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <utility>
namespace c10 {
// This is a safe owning holder for a PyObject, akin to pybind11's
// py::object, with two major differences:
//
//  - It is in c10/core; i.e., you can use this type in contexts where
//    you do not have a libpython dependency
//
//  - It is multi-interpreter safe (à la torchdeploy); when you fetch
//    the underlying PyObject* you are required to specify what the current
//    interpreter context is and we will check that you match it.
//
// It is INVALID to store a reference to a Tensor object in this way;
// you should just use TensorImpl directly in that case!
struct C10_API SafePyObject {
  // Steals a reference to data
  SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : data_(data), pyinterpreter_(pyinterpreter) {}
  SafePyObject(SafePyObject&& other) noexcept
      : data_(std::exchange(other.data_, nullptr)),
        pyinterpreter_(other.pyinterpreter_) {}
  // For now it's not used, so we just disallow it.
  SafePyObject& operator=(SafePyObject&&) = delete;
  SafePyObject(SafePyObject const& other)
      : data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
    if (data_ != nullptr) {
      (*pyinterpreter_)->incref(data_);
    }
  }
  SafePyObject& operator=(SafePyObject const& other) {
    if (this == &other) {
      return *this; // Handle self-assignment
    }
    if (other.data_ != nullptr) {
      (*other.pyinterpreter_)->incref(other.data_);
    }
    if (data_ != nullptr) {
      (*pyinterpreter_)->decref(data_);
    }
    data_ = other.data_;
    pyinterpreter_ = other.pyinterpreter_;
    return *this;
  }
  ~SafePyObject() {
    if (data_ != nullptr) {
      (*pyinterpreter_)->decref(data_);
    }
  }
  c10::impl::PyInterpreter& pyinterpreter() const {
    return *pyinterpreter_;
  }
  PyObject* ptr(const c10::impl::PyInterpreter* /*interpreter*/) const;
  // stop tracking the current object, and return it
  PyObject* release() {
    auto rv = data_;
    data_ = nullptr;
    return rv;
  }

 private:
  PyObject* data_;
  c10::impl::PyInterpreter* pyinterpreter_;
};
// A newtype wrapper around SafePyObject for type safety when a python object
// represents a specific type. Note that `T` is only used as a tag and isn't
// actually used for any true purpose.
template <typename T>
struct SafePyObjectT : private SafePyObject {
  SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : SafePyObject(data, pyinterpreter) {}
  ~SafePyObjectT() = default;
  // `other` is an lvalue here, so std::move is needed to select the base
  // move constructor; without it the base copy constructor runs instead.
  SafePyObjectT(SafePyObjectT&& other) noexcept
      : SafePyObject(std::move(other)) {}
  SafePyObjectT(SafePyObjectT const&) = delete;
  SafePyObjectT& operator=(SafePyObjectT const&) = delete;
  SafePyObjectT& operator=(SafePyObjectT&&) = delete;
  using SafePyObject::ptr;
  using SafePyObject::pyinterpreter;
  using SafePyObject::release;
};
// Like SafePyObject, but non-owning. Good for references to global PyObjects
// that will be leaked on interpreter exit. You get a copy constructor/assign
// this way.
struct C10_API SafePyHandle {
  SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
  SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : data_(data), pyinterpreter_(pyinterpreter) {}
  c10::impl::PyInterpreter& pyinterpreter() const {
    return *pyinterpreter_;
  }
  PyObject* ptr(const c10::impl::PyInterpreter* /*interpreter*/) const;
  void reset() {
    data_ = nullptr;
    pyinterpreter_ = nullptr;
  }
  // True when the handle refers to an object.
  operator bool() const {
    return data_ != nullptr;
  }

 private:
  PyObject* data_;
  c10::impl::PyInterpreter* pyinterpreter_;
};
} // namespace c10
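
The ownership rules of `SafePyObject` (the constructor steals a reference, copy increfs, move transfers, `release()` gives up ownership without a decref) can be tried out without libpython or c10. The sketch below is a stand-alone approximation: `MockObject`, `MockInterpreter`, `OwningHolder`, and `demo_holder_refcounts` are hypothetical names invented for the example and are not part of c10.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for PyObject; only tracks a reference count.
struct MockObject {
  int refcnt = 1;
};

// Hypothetical stand-in for the interpreter's incref/decref hooks.
struct MockInterpreter {
  void incref(MockObject* o) const { ++o->refcnt; }
  void decref(MockObject* o) const { --o->refcnt; }
};

// Mirrors the ownership behavior of SafePyObject, assuming the mocks above.
class OwningHolder {
 public:
  // Steals a reference to data.
  OwningHolder(MockObject* data, const MockInterpreter* interp)
      : data_(data), interp_(interp) {}
  // Move: take over the pointer and leave the source empty.
  OwningHolder(OwningHolder&& other) noexcept
      : data_(std::exchange(other.data_, nullptr)), interp_(other.interp_) {}
  // Copy: share the object and take an extra reference.
  OwningHolder(const OwningHolder& other)
      : data_(other.data_), interp_(other.interp_) {
    if (data_ != nullptr) {
      interp_->incref(data_);
    }
  }
  OwningHolder& operator=(const OwningHolder&) = delete;
  OwningHolder& operator=(OwningHolder&&) = delete;
  ~OwningHolder() {
    if (data_ != nullptr) {
      interp_->decref(data_);
    }
  }
  // Stop tracking the object and return it; the caller now owns the reference.
  MockObject* release() { return std::exchange(data_, nullptr); }

 private:
  MockObject* data_;
  const MockInterpreter* interp_;
};

// Encodes three observed refcounts as digits:
// after copy (hundreds), after move (tens), at the end (units).
int demo_holder_refcounts() {
  MockInterpreter interp;
  MockObject obj; // starts with one reference, handed to the holder below
  int after_copy = 0;
  int after_move = 0;
  {
    OwningHolder a(&obj, &interp); // steals the existing reference
    {
      OwningHolder b(a); // copy takes a second reference
      after_copy = obj.refcnt;
      OwningHolder c(std::move(b)); // move leaves the count unchanged
      after_move = obj.refcnt;
    } // c drops its reference; b is empty and does nothing
    MockObject* raw = a.release(); // a gives up ownership
    assert(raw == &obj);
  } // a is empty, so its destructor does not decref
  return after_copy * 100 + after_move * 10 + obj.refcnt;
}
```

A non-owning handle in the style of `SafePyHandle` would simply omit the incref and decref calls, which is why it can be copied freely.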