mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
fix the ambiguity of simultaneous support for multiple backends and supplement tests (#166619)
**Background**
The previous implementation used a function-local static variable: `static at::DeviceType at_fork_device_type = device_type;`. This "locked in" the first registered device type and silently ignored all subsequent registration requests, which was ambiguous: the system appeared to support multiple device types, but the fork handler only ever acted on the first one registered. Following the discussion in Issue #166256, we reached a consensus that this ambiguity should be eliminated so that forking is handled safely when multiple backends are initialized simultaneously within a single process.
d46d8d6f54/torch/csrc/utils/device_lazy_init.cpp (L79-L92)
**Changes**
In device_lazy_init.cpp:
- Add a global once flag, at_fork_register_once, which ensures that a single atfork callback is registered only once. This unifies the atfork handling logic and avoids the single-device-type limitation and multi-callback registration complexity of the old implementation.
- Introduce and use a per-device-type touch flag: at_fork_registered, which records backends that have been accessed/initialized in the parent process.
- In the child-process atfork callback, iterate over all device types with at_fork_registered == true: set their "bad fork" flag and, for those that support lazy initialization, mark them as requiring device initialization again.
- Remove the old per-device once_flag declaration (at_fork_once_flags).
**Tests**
Add coverage for:
Poison fork behavior: after forking, initializing the device in the child process fails with a clear RuntimeError.
We also tested re-initializing backend devices in a child process after the parent process had initialized multiple devices; the results matched expectations (see the screenshots below).
<img width="1749" height="1115" alt="image" src="https://github.com/user-attachments/assets/696918e8-8e5d-49ac-9fea-5904fb499fd4" />
<img width="1504" height="338" alt="image" src="https://github.com/user-attachments/assets/8e019395-df3f-45b5-b68d-205ea05ddbfa" />
Fixes #166256
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166619
Approved by: https://github.com/fffrog, https://github.com/guangyey, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c65ab8680f
commit
e49e2bb8e4
@@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
@@ -1020,5 +1021,36 @@ class TestDeprecate(TestCase):
|
||||
_deprecated_api(1, y=2)
|
||||
|
||||
|
||||
class TestDeviceLazyInit(TestCase):
    """Tests for how device backends behave across ``fork``."""

    @unittest.skipIf(IS_WINDOWS, "pthread_atfork not available on Windows")
    def test_fork_poison_on_lazy_init(self, device):
        """Using a device in a forked child must raise a RuntimeError.

        The parent first allocates a tensor on ``device``, then forks a child
        that attempts the same allocation. The child sends the outcome back
        through a queue, and the parent checks that it is a ``RuntimeError``
        whose message matches the "Cannot re-initialize ... in forked
        subprocess" pattern.

        Args:
            device: device string supplied by ``instantiate_device_type_tests``.
        """
        # Use the device in the parent before forking.
        torch.empty(1, device=device)

        def child(q):
            # Always report exactly one item, so the parent can block on
            # ``get`` instead of probing ``empty()``: the exception on
            # failure, or None if the allocation unexpectedly succeeded.
            try:
                torch.empty(1, device=device)
            except Exception as e:
                q.put(e)
            else:
                q.put(None)

        ctx = multiprocessing.get_context("fork")
        q = ctx.Queue()
        p = ctx.Process(target=child, args=(q,))
        p.start()
        try:
            # Drain the queue *before* joining: a process that has put items
            # on a queue may not exit until they are consumed, and
            # ``Queue.empty()`` is documented as unreliable.
            exc = q.get(timeout=60)
        finally:
            p.join(timeout=60)
            if p.is_alive():
                # Do not leave a stuck child behind if something went wrong.
                p.kill()
                p.join()

        self.assertIsNotNone(
            exc, "device initialization unexpectedly succeeded in forked child"
        )
        pattern = (
            r"Cannot re-initialize .* in forked subprocess\. "
            r"To use .* with multiprocessing, you must use the 'spawn' start method"
        )
        self.assertIsInstance(exc, RuntimeError)
        self.assertRegex(str(exc), pattern)
|
||||
|
||||
|
||||
instantiate_device_type_tests(
|
||||
TestDeviceLazyInit, globals(), except_for=["cpu"], allow_xpu=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@@ -14,8 +14,8 @@ namespace {
|
||||
|
||||
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};
|
||||
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_in_bad_fork{};
|
||||
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||
at_fork_once_flags{};
|
||||
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> at_fork_registered{};
|
||||
c10::once_flag at_fork_register_once{};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
@@ -77,13 +77,19 @@ void set_device_in_bad_fork(at::DeviceType device_type, bool value) {
|
||||
// Should be called before the first device runtime call.
|
||||
void register_fork_handler_for_device_init(at::DeviceType device_type) {
|
||||
#ifndef WIN32
|
||||
auto& flag = at_fork_once_flags[static_cast<int>(device_type)];
|
||||
c10::call_once(flag, [device_type]() {
|
||||
static at::DeviceType at_fork_device_type = device_type;
|
||||
at_fork_registered[static_cast<int>(device_type)] = true;
|
||||
c10::call_once(at_fork_register_once, []() {
|
||||
pthread_atfork(nullptr, nullptr, []() {
|
||||
set_device_in_bad_fork(at_fork_device_type, true);
|
||||
if (is_device_lazy_init_supported(at_fork_device_type)) {
|
||||
set_requires_device_init(at_fork_device_type, true);
|
||||
for (int i = 0; i < static_cast<int>(at::COMPILE_TIME_MAX_DEVICE_TYPES);
|
||||
++i) {
|
||||
if (!at_fork_registered[i]) {
|
||||
continue;
|
||||
}
|
||||
auto dt = static_cast<at::DeviceType>(i);
|
||||
set_device_in_bad_fork(dt, true);
|
||||
if (is_device_lazy_init_supported(dt)) {
|
||||
set_requires_device_init(dt, true);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user