From 25909d26299008114893a12cdc533f8e73bf6de5 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 24 Oct 2025 12:20:16 -0700 Subject: [PATCH] Simplify SingletonOrSharedTypePtr (#166183) @neildhar pointed out at PTC yesterday that the assumption SingletonOrSharedTypePtr makes about shared_ptr's pointers being either both null or both non-null is incorrect because of the aliasing constructor, and furthermore that SingletonOrSharedTypePtr needn't be as fancy as it is because said constructor exists. (See also https://github.com/pytorch/pytorch/issues/166152 .) Differential Revision: [D85458769](https://our.internmc.facebook.com/intern/diff/D85458769/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166183 Approved by: https://github.com/Skylion007, https://github.com/cyyever --- aten/src/ATen/core/jit_type_base.h | 160 +++------------------------ aten/src/ATen/test/type_ptr_test.cpp | 6 + 2 files changed, 24 insertions(+), 142 deletions(-) diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 18077ad9f6b..4db1cb18883 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -185,11 +185,11 @@ struct TORCH_API Type { : repr_(nullptr) {} /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) - : repr_(p) {} + : repr_(makeSingletonSharedPtr(p.get())) {} template , bool> = true> /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) - : repr_(SingletonTypePtr(p.get())) {} + : repr_(makeSingletonSharedPtr(static_cast(p.get()))) {} // We need to support construction from T* for pybind. The problem @@ -202,8 +202,8 @@ struct TORCH_API Type { // Case 2: if T is exactly Type, we need to do a dynamic_cast to // check if it's a SharedType and do the right thing. // - // Case 3: Otherwise, T is not a SharedType. (debug-check this - // assumption!) Use a singleton pointer. + // Case 3: Otherwise, T is not a SharedType. Use a singleton + // pointer. template , bool> = true> /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast::type>(p)->shared_from_this()) {} @@ -211,15 +211,15 @@ struct TORCH_API Type { template , bool> = true> /* implicit */ SingletonOrSharedTypePtr(T* p) { if (auto* shared_p = dynamic_cast::type>(p)) { - repr_ = Repr(shared_p->shared_from_this()); + repr_ = shared_p->shared_from_this(); } else { - repr_ = Repr(p); + repr_ = makeSingletonSharedPtr(p); } } template && !std::is_base_of_v, bool> = true> /* implicit */ SingletonOrSharedTypePtr(T* p) - : repr_(p) { + : repr_(makeSingletonSharedPtr(p)) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast::type>(p) == nullptr); } @@ -230,19 +230,19 @@ struct TORCH_API Type { ~SingletonOrSharedTypePtr() = default; T* get() const { - return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast(repr_.rawRepr().first); + return repr_.get(); } operator bool() const { - return repr_.isNonNull(); + return repr_ != nullptr; } bool operator==(std::nullptr_t) const { - return !repr_.isNonNull(); + return repr_ == nullptr; } bool operator!=(std::nullptr_t) const { - return repr_.isNonNull(); + return repr_ != nullptr; } template , void>, bool> = true> @@ -255,138 +255,14 @@ struct TORCH_API Type { } private: - // NOTE: SharedPtrWrapper exists to work around a baffling bug in - // nvcc; see comment in destroy() below. - struct SharedPtrWrapper { - SharedPtrWrapper(std::shared_ptr &&x) - : repr_(std::move(x)) {} - std::shared_ptr repr_; - }; - union Repr { - Repr() : Repr(nullptr) {} + // Use shared_ptr's aliasing constructor to create a non-owning pointer + // to a singleton. The lifetime is tied to the null shared_ptr, so there's + // no reference counting overhead for the singleton itself. + static std::shared_ptr makeSingletonSharedPtr(T* ptr) { + return std::shared_ptr(std::shared_ptr(), ptr); + } - explicit Repr(std::shared_ptr x) - : shared_(std::move(x)) {} - - explicit Repr(std::nullptr_t) - : singletonRepr_(nullptr) {} - - explicit Repr(SingletonTypePtr p) - : singletonRepr_(p.get()) {} - - ~Repr() { - destroy(); - } - - // NOTE: the only non-UB way to access our null state is through - // rawRepr(), because our copy operation doesn't preserve which - // union member is active for null pointers. - Repr(const Repr& rhs) { - if (rhs.isSharedAndNonNull()) { - new (&shared_) SharedPtrWrapper(rhs.shared_); - } else { - singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); - singletonRepr_.unused_ = nullptr; - } - } - - Repr(Repr&& rhs) noexcept { - if (rhs.isSharedAndNonNull()) { - new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); - } else { - singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); - singletonRepr_.unused_ = nullptr; - } - } - - Repr& operator=(const Repr& rhs) { - if (&rhs == this) { - return *this; - } - if (rhs.isSharedAndNonNull()) { - if (isSharedAndNonNull()) { - shared_ = rhs.shared_; - } else { - new (&shared_) SharedPtrWrapper(rhs.shared_); - } - } else { - if (isSharedAndNonNull()) { - destroy(); - } - singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); - singletonRepr_.unused_ = nullptr; - } - return *this; - } - - Repr& operator=(Repr&& rhs) noexcept { - if (&rhs == this) { - return *this; - } - if (rhs.isSharedAndNonNull()) { - if (isSharedAndNonNull()) { - shared_ = std::move(rhs.shared_); - } else { - new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); - } - } else { - if (isSharedAndNonNull()) { - destroy(); - } - singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); - singletonRepr_.unused_ = nullptr; - } - return *this; - } - - SharedPtrWrapper shared_; - - struct SingletonRepr { - explicit SingletonRepr(T* s) : singleton_(s) {} - T* singleton_; - void* unused_ = nullptr; - } singletonRepr_; - struct RawRepr { - void* first; - void* nullIfSingleton_; - }; - - // It is UB to read the singleton part of Repr if it was - // constructed as a shared_ptr and vice versa, but memcpying out - // the representation is always OK, so here's an accessor to obey - // the letter of the law. - RawRepr rawRepr() const { - RawRepr repr{}; - memcpy(&repr, reinterpret_cast(this), sizeof(RawRepr)); - return repr; - } - - bool isNonNull() const { - auto repr = rawRepr(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr); - return repr.first != nullptr; - } - - bool isSharedAndNonNull() const { - return rawRepr().nullIfSingleton_ != nullptr; - } - - private: - void destroy() { - if (isSharedAndNonNull()) { - // Without SharedPtrWrapper, this line would read - // `shared_.~shared_ptr()` and nvcc would complain with - // "error: expected primary-expression before '>' token" - // referring to the "t" in "shared_ptr". SharedPtrWrapper - // exists to work around this compiler bug. - shared_.~SharedPtrWrapper(); - } - } - } repr_; + std::shared_ptr repr_; }; using TypePtr = SingletonOrSharedTypePtr; diff --git a/aten/src/ATen/test/type_ptr_test.cpp b/aten/src/ATen/test/type_ptr_test.cpp index fa1858545e3..f872f897733 100644 --- a/aten/src/ATen/test/type_ptr_test.cpp +++ b/aten/src/ATen/test/type_ptr_test.cpp @@ -37,6 +37,10 @@ TEST(SingletonOrSharedTypePtr, Comparison) { EXPECT_NE(empty, p); EXPECT_NE(p, p2); + + EXPECT_EQ(empty, empty); + EXPECT_EQ(p, p); + EXPECT_EQ(p2, p2); } TEST(SingletonOrSharedTypePtr, SingletonComparison) { @@ -47,6 +51,8 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) { c10::TypePtr type = c10::NoneType::get(); EXPECT_NE(type, c10::StringType::get()); EXPECT_NE(type, c10::DeviceObjType::get()); + EXPECT_EQ(type, type); + EXPECT_EQ(type, c10::NoneType::get()); }