elastic/rendezvous: make barrier and rank assignment operations O(n) instead of O(n^2) (#124982)
Summary:
This makes the barrier and rank assignment operations linear rather than quadratic in the number of workers, which drastically improves rendezvous performance when running with more than 1000 hosts.
This uses two approaches in different areas:
* local rank assignment: each worker does one set and one get; the ranks are assigned on the rank 0 host in a single O(n) pass, which keeps the total number of store operations linear in the number of workers (sketched below).
* exit_barrier: use a counter and a final flag so each worker does at most one set, one get, and one add (sketched immediately after this list).
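For illustration, here is a condensed sketch of the counter-plus-flag exit barrier described in the second bullet. It mirrors the `_barrier_nonblocking`/`barrier` helpers added in the store utility diff further down, but is simplified: the real code splits the non-blocking part into its own helper and restores the timeout via a `store_timeout` context manager, so treat this as a sketch rather than the actual implementation.

```python
from datetime import timedelta

_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"


def barrier_sketch(store, world_size: int, key_prefix: str, barrier_timeout: float = 300) -> None:
    """Block until ``world_size`` workers have checked in; O(1) store ops per worker."""
    old_timeout = store.timeout
    store.set_timeout(timedelta(seconds=barrier_timeout))
    try:
        # Every worker increments the shared counter exactly once (1 add).
        idx = store.add(key_prefix + _NUM_MEMBERS, 1)
        # The last worker to arrive flips the completion flag (1 set).
        if idx == world_size:
            store.set(key_prefix + _LAST_MEMBER_CHECKIN, "<val_ignored>")
        # get() blocks until the flag exists, i.e. until everyone has arrived (1 get).
        store.get(key_prefix + _LAST_MEMBER_CHECKIN)
    finally:
        store.set_timeout(old_timeout)
```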
At 4000 hosts, torchelastic rendezvous completes in as little as 10 seconds, down from 373 seconds.
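The rank assignment path described in the first bullet follows the same idea: push the O(n) bookkeeping to the rank 0 host, which is co-located with the TCPStore. A rough, self-contained sketch is below; the actual `_assign_worker_ranks` in the diff exchanges serialized `_RoleInstanceInfo` objects rather than the plain JSON tuples used here, and this simplified version only returns the base offsets.

```python
import json
from collections import defaultdict

ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"


def assign_ranks_sketch(store, group_rank: int, group_world_size: int, role: str, local_world_size: int):
    # Each agent publishes its own role and worker count exactly once (1 set).
    store.set(f"{ROLE_INFO_PREFIX}{group_rank}", json.dumps([role, local_world_size]))

    # The rank 0 agent does the O(n) work: one multi_get, one pass, one multi_set.
    if group_rank == 0:
        infos = [
            json.loads(v)
            for v in store.multi_get(
                [f"{ROLE_INFO_PREFIX}{i}" for i in range(group_world_size)]
            )
        ]
        global_size = sum(lws for _, lws in infos)
        role_sizes = defaultdict(int)
        for r, lws in infos:
            role_sizes[r] += lws

        base_global_rank = 0
        role_offsets = defaultdict(int)
        keys, values = [], []
        for i, (r, lws) in enumerate(infos):
            keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
            values.append(
                json.dumps([base_global_rank, global_size, role_offsets[r], role_sizes[r]])
            )
            base_global_rank += lws
            role_offsets[r] += lws
        store.multi_set(keys, values)

    # Every agent blocks on a single get for its own assignment (1 get).
    base_global_rank, world_size, base_role_rank, role_world_size = json.loads(
        store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")
    )
    return base_global_rank, world_size, base_role_rank, role_world_size
```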
Test Plan:
This was tested using many small test runs on a remote cluster.
{D56549942}
```
torchx run --scheduler mast -- --image=torchelastic_benchmark --j=4000x1
```
Differential Revision: D56605193
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124982
Approved by: https://github.com/kiukchung, https://github.com/kurman
Committed by: PyTorch MergeBot
Parent: 1a6fef15ef
Commit: dc4c75ba72
@@ -579,6 +579,7 @@ coverage_ignore_functions = [
"barrier",
"get_all",
"synchronize",
"store_timeout",
# torch.distributed.fsdp.wrap
"always_wrap_policy",
"enable_wrap",
@@ -11,8 +11,11 @@
import signal
import unittest
import uuid
from typing import Any, Dict
from unittest.mock import call, MagicMock, patch
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List
from unittest.mock import call, patch

import torch.distributed as dist

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.agent.server.api import (
@@ -20,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import (
_RoleInstanceInfo,
RunResult,
SimpleElasticAgent,
Worker,
WorkerGroup,
WorkerSpec,
WorkerState,
@@ -470,22 +474,6 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts)

def test_get_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
_RoleInstanceInfo("trainer", 1, 1),
_RoleInstanceInfo("trainer", 2, 2),
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
spec = self._get_worker_spec(
max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8
)
agent = TestAgent(spec)
total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos))
self.assertEqual(15, total_sum)
self.assertEqual([0, 1, 2, 3], list(ranks))

def test_assign_worker_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
@@ -494,56 +482,64 @@ class SimpleElasticAgentTest(unittest.TestCase):
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
num_agents = len(role_infos)
with patch.object(TestAgent, "_share_and_gather", return_value=role_infos):
self.verify_worker_ranks(
role_infos[0], num_agents, [0, 1, 2, 3], [0, 1, 2, 3]
store = dist.HashStore()

def f(info) -> List[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role_info.role,
local_world_size=role_info.local_world_size,
)
self.verify_worker_ranks(role_infos[1], num_agents, [4], [0])
self.verify_worker_ranks(role_infos[2], num_agents, [5, 6], [1, 2])
self.verify_worker_ranks(role_infos[3], num_agents, [7, 8, 9], [3, 4, 5])

def verify_worker_ranks(
self, agent_config, total_agents, expected_global_ranks, expected_role_ranks
):
role, agent_rank, local_world_size = (
agent_config.role,
agent_config.rank,
agent_config.local_world_size,
)
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role,
local_world_size=local_world_size,
)
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(None, agent_rank, total_agents, spec)
self.assertEqual(
expected_global_ranks, [worker.global_rank for worker in workers]
)
self.assertEqual(expected_role_ranks, [worker.role_rank for worker in workers])

@patch("torch.distributed.elastic.utils.store.synchronize")
def test_share_and_gather(self, sync_mock):
# when the state is unknown we exit immediately; no retries
spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
agent = TestAgent(spec)
expected_agent_infos = [
_RoleInstanceInfo("trainer", 0, 10),
_RoleInstanceInfo("trainer", 1, 10),
_RoleInstanceInfo("validator", 2, 10),
]

sync_mock.return_value = [obj.serialize() for obj in expected_agent_infos]
result = agent._share_and_gather(MagicMock(), 1, 3, spec)
sync_mock.assert_called_once()
for expected_role_info, actual_role_info in zip(expected_agent_infos, result):
self.assertEqual(expected_role_info.role, actual_role_info.role)
self.assertEqual(expected_role_info.rank, actual_role_info.rank)
self.assertEqual(
expected_role_info.local_world_size, actual_role_info.local_world_size
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(
store, role_info.rank, len(role_infos), spec
)
return [
(
w.local_rank,
w.role_rank,
w.global_rank,
w.world_size,
w.role_world_size,
)
for w in workers
]

with ThreadPool(len(role_infos)) as pool:
out = pool.map(f, enumerate(role_infos))

self.assertListEqual(
out,
[
[
(0, 0, 0, 15, 9),
(1, 1, 1, 15, 9),
(2, 2, 2, 15, 9),
(3, 3, 3, 15, 9),
],
[
(0, 0, 4, 15, 6),
],
[
(0, 1, 5, 15, 6),
(1, 2, 6, 15, 6),
],
[
(0, 3, 7, 15, 6),
(1, 4, 8, 15, 6),
(2, 5, 9, 15, 6),
],
[
(0, 4, 10, 15, 9),
(1, 5, 11, 15, 9),
(2, 6, 12, 15, 9),
(3, 7, 13, 15, 9),
(4, 8, 14, 15, 9),
],
],
)

def test_get_event(self):
spec = self._get_worker_spec(max_restarts=1)
@@ -7,77 +7,142 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock
import datetime
from multiprocessing.pool import ThreadPool
from typing import List

import torch.distributed as dist

import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
from torch.testing._internal.common_utils import run_tests, TestCase


class MockStore:
def __init__(self):
self.ops = []

def set_timeout(self, timeout: float) -> None:
self.ops.append(("set_timeout", timeout))

@property
def timeout(self) -> datetime.timedelta:
self.ops.append(("timeout",))

return datetime.timedelta(seconds=1234)

def set(self, key: str, value: str) -> None:
self.ops.append(("set", key, value))

def get(self, key: str) -> str:
self.ops.append(("get", key))
return "value"

def multi_get(self, keys: List[str]) -> List[str]:
self.ops.append(("multi_get", keys))
return ["value"] * len(keys)

def add(self, key: str, val: int) -> int:
self.ops.append(("add", key, val))
return 3


class StoreUtilTest(TestCase):
def test_get_all_rank_0(self):
store = mock.MagicMock()
world_size = 3
store_util.get_all(store, 0, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store0.FIN"], actual_set_call_args)

actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
("test/store0.FIN",),
("test/store1.FIN",),
("test/store2.FIN",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
store = MockStore()

store_util.get_all(store, 0, "test/store", world_size)

self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
],
)

def test_get_all_rank_n(self):
store = mock.MagicMock()
store = MockStore()
world_size = 3
store_util.get_all(store, 1, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store1.FIN"], actual_set_call_args)

actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
],
)

def test_synchronize(self):
store_mock = mock.MagicMock()
data = b"data0"
store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test")
actual_set_call_args = store_mock.set.call_args_list
# omit empty kwargs
actual_set_call_args = [call_args[0] for call_args in actual_set_call_args]
expected_set_call_args = [
("torchelastic/test0", b"data0"),
("torchelastic/test0.FIN", b"FIN"),
]
self.assertListEqual(expected_set_call_args, actual_set_call_args)
store = MockStore()

expected_get_call_args = [
("torchelastic/test0",),
("torchelastic/test1",),
("torchelastic/test2",),
("torchelastic/test0.FIN",),
("torchelastic/test1.FIN",),
("torchelastic/test2.FIN",),
]
actual_get_call_args = store_mock.get.call_args_list
actual_get_call_args = [call_args[0] for call_args in actual_get_call_args]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
data = b"data0"
store_util.synchronize(store, data, 0, 3, key_prefix="test/store")

self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("set", "test/store0", data),
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)

def test_synchronize_hash_store(self) -> None:
N = 4

store = dist.HashStore()

def f(i: int):
return store_util.synchronize(
store, f"data{i}", i, N, key_prefix="test/store"
)

with ThreadPool(N) as pool:
out = pool.map(f, range(N))

self.assertListEqual(out, [[f"data{i}".encode() for i in range(N)]] * N)

def test_barrier(self):
store = MockStore()

store_util.barrier(store, 3, key_prefix="test/store")

self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("add", "test/store/num_members", 1),
("set", "test/store/last_member", "<val_ignored>"),
("get", "test/store/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)

def test_barrier_hash_store(self) -> None:
N = 4

store = dist.HashStore()

def f(i: int):
store_util.barrier(store, N, key_prefix="test/store")

with ThreadPool(N) as pool:
out = pool.map(f, range(N))

self.assertEqual(out, [None] * N)


class UtilTest(TestCase):
@@ -7,7 +7,6 @@
# LICENSE file in the root directory of this source tree.

import abc
import functools
import json
import os
import signal
@@ -30,6 +29,7 @@ from torch.distributed.elastic.multiprocessing import (
ProcessFailure,
SignalException,
)
from collections import defaultdict
from torch.distributed.elastic.utils.logging import get_logger

__all__ = [
@@ -592,26 +592,6 @@ class SimpleElasticAgent(ElasticAgent):
}
)

def _get_ranks(
self,
role_infos: List[_RoleInstanceInfo],
role_idx: int,
start_idx: int = 0,
end_idx: int = -1,
) -> Tuple[int, List[int]]:
if end_idx == -1:
end_idx = len(role_infos)
prefix_sum = 0
total_sum = 0
for idx in range(start_idx, end_idx):
if role_idx > idx:
prefix_sum += role_infos[idx].local_world_size
total_sum += role_infos[idx].local_world_size
return (
total_sum,
list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -624,63 +604,86 @@ class SimpleElasticAgent(ElasticAgent):

1. Each agent writes its configuration(group_rank, group_world_size
, num_workers) to the common store.
2. Each agent retrieves configuration for all agents
and performs two level sort using role and rank.
3. Determine the global rank: the global rank of the workers for the current
agent is the offset of the infos array up to group_rank of the agent.
The offset is computed as a sum of local_world_size of all agents that
have rank less than the group_rank. The workers would have the ranks:
[offset, offset+local_world_size)
2. The rank 0 agent reads all the role_info from the store and
determines each agents worker ranks.
3. Determine the global rank: the global rank of the workers is computed
by cumulative sum of the local_world_size for all workers in front of it.
For efficiency reasons each worker is assigned a base global rank
such that it's workers are in the range [base_global_rank,
base_global_rank + local_world_size).
4. Determine the role rank: The role rank is determined using the algorithms
in the point 3 with the exception that the offset is done from the first
agent that has the same role as current one and has the minimum group rank.
in the point 3 with the exception that the ranks are calculated with
respect to the role name.
5. The rank 0 agent writes the assigned ranks to the store.
6. Each agent reads the assigned ranks from the store.

Time complexity: each worker O(1), rank0 O(n), overall O(n)
"""
role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
my_role_info = role_infos[group_rank]
worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
role_infos = sorted(
role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)

ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"

agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
role_infos, my_role_info.role
)
role_pos = next(
idx
for idx, role_info in enumerate(role_infos)
if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
)
role_world_size, role_ranks = self._get_ranks(
role_infos, role_pos, role_start_idx, role_end_idx + 1
store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())

# tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations.
if group_rank == 0:
role_infos_bytes = store.multi_get(
[f"torchelastic/role_info/{i}" for i in range(group_world_size)]
)
role_infos = [
_RoleInstanceInfo.deserialize(info_bytes)
for info_bytes in role_infos_bytes
]

role_sizes = defaultdict(lambda: 0)
global_size = 0
for role_info in role_infos:
role_sizes[role_info.role] += role_info.local_world_size
global_size += role_info.local_world_size

base_global_rank = 0
role_ranks = defaultdict(lambda: 0)

keys = []
values = []
for i, role_info in enumerate(role_infos):
keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
values.append(
json.dumps(
[
base_global_rank,
global_size,
role_ranks[role_info.role],
role_sizes[role_info.role],
]
)
)

base_global_rank += role_info.local_world_size
role_ranks[role_info.role] += role_info.local_world_size

store.multi_set(keys, values)

# get will block until the data is available in the store.
base_global_rank, global_world_size, base_role_rank, role_world_size = json.loads(
store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")
)

workers = []
for ind in range(spec.local_world_size):
for local_rank in range(spec.local_world_size):
worker = Worker(
local_rank=ind,
global_rank=worker_global_ranks[ind],
role_rank=role_ranks[ind],
world_size=worker_world_size,
local_rank=local_rank,
global_rank=base_global_rank + local_rank,
role_rank=base_role_rank + local_rank,
world_size=global_world_size,
role_world_size=role_world_size,
)
workers.append(worker)
return workers

def _share_and_gather(
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List:
agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
key_prefix = "torchelastic/role_info"
agent_config_enc = agent_role_info.serialize()
role_infos_bytes = store_util.synchronize(
store, agent_config_enc, group_rank, group_world_size, key_prefix
)
role_infos = [
_RoleInstanceInfo.deserialize(role_info_bytes)
for role_info_bytes in role_infos_bytes
]
return role_infos

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -935,9 +938,8 @@ class SimpleElasticAgent(ElasticAgent):
start = time.time()
try:
store_util.barrier(
self._store,
self._worker_group.group_rank,
self._worker_group.group_world_size,
store=self._store,
world_size=self._worker_group.group_world_size,
key_prefix=_TERMINAL_STATE_SYNC_ID,
barrier_timeout=self._exit_barrier_timeout,
)
@@ -13,6 +13,7 @@ from typing import Optional

import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.elastic.utils.store import barrier


logger = get_logger(__name__)
@@ -20,8 +21,7 @@ logger = get_logger(__name__)
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"

_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
_TCP_STORE_INIT = "_tcp_store/num_members"


def create_c10d_store(
@@ -54,8 +54,9 @@ def create_c10d_store(
"Creating c10d store on %s:%s\n"
" world_size : %s\n"
" is_server : %s\n"
" timeout(sec): %s\n",
server_addr, port, world_size, is_server, timeout
" timeout(sec): %s\n"
" use_libuv : %s\n",
server_addr, port, world_size, is_server, timeout, use_libuv,
)

try:
@@ -75,7 +76,7 @@ def create_c10d_store(
store = store_builder(use_libuv=use_libuv)
# skips full rank check when we don't have to wait for all workers
if wait_for_workers:
_check_full_rank(store, world_size)
_check_full_rank(store, world_size, timeout=timeout)
logger.info("Successfully created c10d store")
return store
except RuntimeError as e:
@@ -98,13 +99,9 @@ def create_c10d_store(
raise


def _check_full_rank(store, world_size):
idx = store.add(_MEMBER_CHECKIN, 1)
if idx == world_size:
store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")

def _check_full_rank(store, world_size, timeout):
try:
store.get(_LAST_MEMBER_CHECKIN)
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
except RuntimeError as e:
if str(e) == _SOCKET_TIMEOUT:
raise TimeoutError(
@@ -8,9 +8,29 @@

from datetime import timedelta
from typing import List
from contextlib import contextmanager

_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"

@contextmanager
def store_timeout(store, timeout: float):
"""
This sets the timeout and then restores the old timeout when the context
manager exits.

Args:
store: the store to set the timeout on
timeout: the timeout to set
"""

old_timeout = store.timeout
store.set_timeout(timedelta(seconds=timeout))
yield
store.set_timeout(old_timeout)


def get_all(store, rank: int, prefix: str, size: int):
def get_all(store, rank: int, prefix: str, world_size: int):
r"""
Given a store and a prefix, the method goes through the array of keys
of the following format: ``{prefix}{idx}``, where idx is in a range
@@ -29,17 +49,20 @@ def get_all(store, rank: int, prefix: str, size: int):
value3 = values[2] # retrieves the data for key torchelastic/data2

"""
data_arr = []
for idx in range(size):
data = store.get(f"{prefix}{idx}")
data_arr.append(data)
store.set(f"{prefix}{rank}.FIN", b"FIN")
data_arr = store.multi_get(
[f"{prefix}{idx}" for idx in range(world_size)]
)

barrier_key = _barrier_nonblocking(
store=store,
world_size=world_size,
key_prefix=f"{prefix}/finished",
)
if rank == 0:
# Rank0 runs the TCPStore daemon, as a result it needs to exit last.
# Otherwise, the barrier may timeout if rank0 process finished the work
# before other processes finished `get_all` method
for node_rank in range(size):
store.get(f"{prefix}{node_rank}.FIN")
store.get(barrier_key)

return data_arr

@@ -50,7 +73,7 @@ def synchronize(
rank: int,
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
timeout: float = 300,
) -> List[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
@@ -58,21 +81,47 @@

Note: The data on the path is not deleted, as a result there can be stale data if
you use the same key_prefix twice.

Time complexity: O(N) per worker, O(N^2) globally.
"""
store.set_timeout(timedelta(seconds=barrier_timeout))
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data
with store_timeout(store, timeout):
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data


def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
"""
Does all the non-blocking operations for a barrier and returns the final key
that can be waited on.
"""
num_members_key = key_prefix + _NUM_MEMBERS
last_member_key = key_prefix + _LAST_MEMBER_CHECKIN


idx = store.add(num_members_key, 1)
if idx == world_size:
store.set(last_member_key, "<val_ignored>")

return last_member_key


def barrier(
store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
store, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
"""
A global lock between agents.
A global lock between agents. This will pause all workers until at least
``world_size`` workers respond.

This uses a fast incrementing index to assign waiting ranks and a success
flag set by the last worker.

Time complexity: O(1) per worker, O(N) globally.

Note: Since the data is not removed from the store, the barrier can be used
once per unique ``key_prefix``.
"""
data = f"{rank}".encode()
synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)

with store_timeout(store, barrier_timeout):
last_member_key = _barrier_nonblocking(store=store, world_size=world_size, key_prefix=key_prefix)
store.get(last_member_key)