fix the ambiguity of simultaneous support for multiple backends and supplement tests (#166619)

**Background**

The previous implementation used a function-local static variable: static at::DeviceType at_fork_device_type = device_type;. This logic would "lock in" the first registered device type and ignore all subsequent registration requests — leading to ambiguity: the system appeared to support multiple device types, but only the first registered one would actually work. Following discussions in Issue #166256, we reached a consensus: this ambiguity should be eliminated to support safe forking behavior in scenarios where "multiple backends are initialized simultaneously within a single process".

d46d8d6f54/torch/csrc/utils/device_lazy_init.cpp (L79-L92)

**Changes**

In device_lazy_init.cpp:
- Add a globally registered-once atfork callback: at_fork_register_once. Unify the atfork handling logic to avoid the single-device-type limitation and multi-callback registration complexity of the old implementation.
- Introduce and use a per-device-type touch flag: at_fork_registered, which records backends that have been accessed/initialized in the parent process.
- In the child process atfork callback, iterate over all device types with at_fork_registered == true: set their "bad fork" flag and reset their lazy initialization requirement.
- Remove the old per-device once_flag declaration (at_fork_once_flags).

**Tests**

Add coverage for:
Poison fork behavior: after forking, initializing the device in the child process fails with a clear RuntimeError.

We conducted a test to validate the outcome of child-process backend device re-initialization following parent-process multi-device initialization, with results aligning with expectations.
<img width="1749" height="1115" alt="image" src="https://github.com/user-attachments/assets/696918e8-8e5d-49ac-9fea-5904fb499fd4" />
<img width="1504" height="338" alt="image" src="https://github.com/user-attachments/assets/8e019395-df3f-45b5-b68d-205ea05ddbfa" />

Fixes #166256

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166619
Approved by: https://github.com/fffrog, https://github.com/guangyey, https://github.com/albanD
This commit is contained in:
zhudada
2025-12-11 17:00:55 +00:00
committed by PyTorch MergeBot
parent c65ab8680f
commit e49e2bb8e4
2 changed files with 46 additions and 8 deletions

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Owner(s): ["module: unknown"]
import multiprocessing
import os
import random
import shutil
@@ -1020,5 +1021,36 @@ class TestDeprecate(TestCase):
_deprecated_api(1, y=2)
class TestDeviceLazyInit(TestCase):
@unittest.skipIf(IS_WINDOWS, "pthread_atfork not available on Windows")
def test_fork_poison_on_lazy_init(self, device):
torch.empty(1, device=device)
def child(q):
try:
torch.empty(1, device=device)
except Exception as e:
q.put(e)
ctx = multiprocessing.get_context("fork")
q = ctx.Queue()
p = ctx.Process(target=child, args=(q,))
p.start()
p.join()
self.assertTrue(not q.empty())
exc = q.get()
pattern = (
r"Cannot re-initialize .* in forked subprocess\. "
r"To use .* with multiprocessing, you must use the 'spawn' start method"
)
self.assertIsInstance(exc, RuntimeError)
self.assertRegex(str(exc), pattern)
instantiate_device_type_tests(
TestDeviceLazyInit, globals(), except_for=["cpu"], allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@@ -14,8 +14,8 @@ namespace {
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_in_bad_fork{};
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at_fork_once_flags{};
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> at_fork_registered{};
c10::once_flag at_fork_register_once{};
} // anonymous namespace
@@ -77,13 +77,19 @@ void set_device_in_bad_fork(at::DeviceType device_type, bool value) {
// Should be called before the first device runtime call.
void register_fork_handler_for_device_init(at::DeviceType device_type) {
#ifndef WIN32
auto& flag = at_fork_once_flags[static_cast<int>(device_type)];
c10::call_once(flag, [device_type]() {
static at::DeviceType at_fork_device_type = device_type;
at_fork_registered[static_cast<int>(device_type)] = true;
c10::call_once(at_fork_register_once, []() {
pthread_atfork(nullptr, nullptr, []() {
set_device_in_bad_fork(at_fork_device_type, true);
if (is_device_lazy_init_supported(at_fork_device_type)) {
set_requires_device_init(at_fork_device_type, true);
for (int i = 0; i < static_cast<int>(at::COMPILE_TIME_MAX_DEVICE_TYPES);
++i) {
if (!at_fork_registered[i]) {
continue;
}
auto dt = static_cast<at::DeviceType>(i);
set_device_in_bad_fork(dt, true);
if (is_device_lazy_init_supported(dt)) {
set_requires_device_init(dt, true);
}
}
});
});