elastic/rendezvous: make barrier and rank assignment operations O(n) instead of O(n^2) (#124982)

Summary:
This makes barrier and rank assignment operations linear instead of quadratic in the number of workers. This drastically improves rendezvous performance when running with over 1000 hosts.

This uses two approaches for the two affected areas:

* local rank assignment: each worker does 1 set and 1 get; the ranks are assigned on the rank 0 host in an O(n) pass, which reduces the total number of store operations to be linear in the number of workers (see the first sketch below).
* exit_barrier: use a counter and a final flag so that each worker does at most 1 set, 1 get, and 1 add (see the second sketch below).
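
To make the first approach concrete, here is a minimal sketch of the centralized rank assignment, assuming a c10d-style store that provides `set`/`get`/`multi_get`/`multi_set` and a blocking `get`. The key names and the standalone `assign_ranks` helper are illustrative, not the exact torchelastic API:

```python
import json
from collections import defaultdict

def assign_ranks(store, group_rank: int, group_world_size: int,
                 role: str, local_world_size: int):
    # Each agent publishes its role info with a single set: O(1) per agent.
    store.set(f"role_info/{group_rank}", json.dumps([role, local_world_size]))

    if group_rank == 0:
        # Rank 0 is co-located with the store, so it gathers all role infos
        # at once and computes every assignment in a single O(n) pass.
        infos = [
            json.loads(raw)
            for raw in store.multi_get(
                [f"role_info/{i}" for i in range(group_world_size)]
            )
        ]
        role_sizes = defaultdict(int)
        global_size = 0
        for r, lws in infos:
            role_sizes[r] += lws
            global_size += lws

        # Prefix sums give each agent a base global rank and base role rank.
        base_global, base_role = 0, defaultdict(int)
        keys, values = [], []
        for i, (r, lws) in enumerate(infos):
            keys.append(f"assigned_ranks/{i}")
            values.append(json.dumps(
                [base_global, global_size, base_role[r], role_sizes[r]]
            ))
            base_global += lws
            base_role[r] += lws
        store.multi_set(keys, values)

    # Each agent blocks on a single get until rank 0 publishes its slot;
    # its workers then occupy [base_global_rank, base_global_rank + lws).
    return json.loads(store.get(f"assigned_ranks/{group_rank}"))
```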
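
The second approach mirrors the `_barrier_nonblocking` helper in the diff below; a minimal sketch, again assuming a blocking `get`:

```python
def exit_barrier(store, world_size: int, key_prefix: str) -> None:
    # Every worker performs exactly one atomic increment.
    arrived = store.add(key_prefix + "/num_members", 1)
    if arrived == world_size:
        # Only the last arrival does a set, flipping the completion flag.
        store.set(key_prefix + "/last_member", "<val_ignored>")
    # Every worker does one blocking get on the flag, so the store sees at
    # most one add, one set, and one get per worker: O(n) operations total.
    store.get(key_prefix + "/last_member")
```

Since the counter and flag are never deleted from the store, each barrier needs a fresh `key_prefix`, which is why the diff notes that the barrier is single-use per prefix.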

At 4000 hosts, torchelastic rendezvous completes in as little as 10 seconds, down from 373 seconds.

Test Plan:
This was tested using many small test jobs running on a remote cluster.

{D56549942}

```
torchx run --scheduler mast -- --image=torchelastic_benchmark --j=4000x1
```

Differential Revision: D56605193

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124982
Approved by: https://github.com/kiukchung, https://github.com/kurman
Author: Tristan Rice
Date: 2024-04-27 02:21:44 +00:00
Committed by: PyTorch MergeBot
Parent: 1a6fef15ef
Commit: dc4c75ba72
6 changed files with 327 additions and 217 deletions

View File

@@ -579,6 +579,7 @@ coverage_ignore_functions = [
"barrier",
"get_all",
"synchronize",
"store_timeout",
# torch.distributed.fsdp.wrap
"always_wrap_policy",
"enable_wrap",

View File

@@ -11,8 +11,11 @@
import signal
import unittest
import uuid
from typing import Any, Dict
from unittest.mock import call, MagicMock, patch
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List
from unittest.mock import call, patch
import torch.distributed as dist
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.agent.server.api import (
@@ -20,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import (
_RoleInstanceInfo,
RunResult,
SimpleElasticAgent,
Worker,
WorkerGroup,
WorkerSpec,
WorkerState,
@@ -470,22 +474,6 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts)
def test_get_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
_RoleInstanceInfo("trainer", 1, 1),
_RoleInstanceInfo("trainer", 2, 2),
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
spec = self._get_worker_spec(
max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8
)
agent = TestAgent(spec)
total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos))
self.assertEqual(15, total_sum)
self.assertEqual([0, 1, 2, 3], list(ranks))
def test_assign_worker_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
@@ -494,56 +482,64 @@ class SimpleElasticAgentTest(unittest.TestCase):
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
num_agents = len(role_infos)
with patch.object(TestAgent, "_share_and_gather", return_value=role_infos):
self.verify_worker_ranks(
role_infos[0], num_agents, [0, 1, 2, 3], [0, 1, 2, 3]
store = dist.HashStore()
def f(info) -> List[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role_info.role,
local_world_size=role_info.local_world_size,
)
self.verify_worker_ranks(role_infos[1], num_agents, [4], [0])
self.verify_worker_ranks(role_infos[2], num_agents, [5, 6], [1, 2])
self.verify_worker_ranks(role_infos[3], num_agents, [7, 8, 9], [3, 4, 5])
def verify_worker_ranks(
self, agent_config, total_agents, expected_global_ranks, expected_role_ranks
):
role, agent_rank, local_world_size = (
agent_config.role,
agent_config.rank,
agent_config.local_world_size,
)
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role,
local_world_size=local_world_size,
)
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(None, agent_rank, total_agents, spec)
self.assertEqual(
expected_global_ranks, [worker.global_rank for worker in workers]
)
self.assertEqual(expected_role_ranks, [worker.role_rank for worker in workers])
@patch("torch.distributed.elastic.utils.store.synchronize")
def test_share_and_gather(self, sync_mock):
# when the state is unknown we exit immediately; no retries
spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
agent = TestAgent(spec)
expected_agent_infos = [
_RoleInstanceInfo("trainer", 0, 10),
_RoleInstanceInfo("trainer", 1, 10),
_RoleInstanceInfo("validator", 2, 10),
]
sync_mock.return_value = [obj.serialize() for obj in expected_agent_infos]
result = agent._share_and_gather(MagicMock(), 1, 3, spec)
sync_mock.assert_called_once()
for expected_role_info, actual_role_info in zip(expected_agent_infos, result):
self.assertEqual(expected_role_info.role, actual_role_info.role)
self.assertEqual(expected_role_info.rank, actual_role_info.rank)
self.assertEqual(
expected_role_info.local_world_size, actual_role_info.local_world_size
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(
store, role_info.rank, len(role_infos), spec
)
return [
(
w.local_rank,
w.role_rank,
w.global_rank,
w.world_size,
w.role_world_size,
)
for w in workers
]
with ThreadPool(len(role_infos)) as pool:
out = pool.map(f, enumerate(role_infos))
self.assertListEqual(
out,
[
[
(0, 0, 0, 15, 9),
(1, 1, 1, 15, 9),
(2, 2, 2, 15, 9),
(3, 3, 3, 15, 9),
],
[
(0, 0, 4, 15, 6),
],
[
(0, 1, 5, 15, 6),
(1, 2, 6, 15, 6),
],
[
(0, 3, 7, 15, 6),
(1, 4, 8, 15, 6),
(2, 5, 9, 15, 6),
],
[
(0, 4, 10, 15, 9),
(1, 5, 11, 15, 9),
(2, 6, 12, 15, 9),
(3, 7, 13, 15, 9),
(4, 8, 14, 15, 9),
],
],
)
def test_get_event(self):
spec = self._get_worker_spec(max_restarts=1)

View File

@@ -7,77 +7,142 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest import mock
import datetime
from multiprocessing.pool import ThreadPool
from typing import List
import torch.distributed as dist
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
from torch.testing._internal.common_utils import run_tests, TestCase
class MockStore:
def __init__(self):
self.ops = []
def set_timeout(self, timeout: float) -> None:
self.ops.append(("set_timeout", timeout))
@property
def timeout(self) -> datetime.timedelta:
self.ops.append(("timeout",))
return datetime.timedelta(seconds=1234)
def set(self, key: str, value: str) -> None:
self.ops.append(("set", key, value))
def get(self, key: str) -> str:
self.ops.append(("get", key))
return "value"
def multi_get(self, keys: List[str]) -> List[str]:
self.ops.append(("multi_get", keys))
return ["value"] * len(keys)
def add(self, key: str, val: int) -> int:
self.ops.append(("add", key, val))
return 3
class StoreUtilTest(TestCase):
def test_get_all_rank_0(self):
store = mock.MagicMock()
world_size = 3
store_util.get_all(store, 0, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store0.FIN"], actual_set_call_args)
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
("test/store0.FIN",),
("test/store1.FIN",),
("test/store2.FIN",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
store = MockStore()
store_util.get_all(store, 0, "test/store", world_size)
self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
],
)
def test_get_all_rank_n(self):
store = mock.MagicMock()
store = MockStore()
world_size = 3
store_util.get_all(store, 1, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store1.FIN"], actual_set_call_args)
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
],
)
def test_synchronize(self):
store_mock = mock.MagicMock()
data = b"data0"
store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test")
actual_set_call_args = store_mock.set.call_args_list
# omit empty kwargs
actual_set_call_args = [call_args[0] for call_args in actual_set_call_args]
expected_set_call_args = [
("torchelastic/test0", b"data0"),
("torchelastic/test0.FIN", b"FIN"),
]
self.assertListEqual(expected_set_call_args, actual_set_call_args)
store = MockStore()
expected_get_call_args = [
("torchelastic/test0",),
("torchelastic/test1",),
("torchelastic/test2",),
("torchelastic/test0.FIN",),
("torchelastic/test1.FIN",),
("torchelastic/test2.FIN",),
]
actual_get_call_args = store_mock.get.call_args_list
actual_get_call_args = [call_args[0] for call_args in actual_get_call_args]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
data = b"data0"
store_util.synchronize(store, data, 0, 3, key_prefix="test/store")
self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("set", "test/store0", data),
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)
def test_synchronize_hash_store(self) -> None:
N = 4
store = dist.HashStore()
def f(i: int):
return store_util.synchronize(
store, f"data{i}", i, N, key_prefix="test/store"
)
with ThreadPool(N) as pool:
out = pool.map(f, range(N))
self.assertListEqual(out, [[f"data{i}".encode() for i in range(N)]] * N)
def test_barrier(self):
store = MockStore()
store_util.barrier(store, 3, key_prefix="test/store")
self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("add", "test/store/num_members", 1),
("set", "test/store/last_member", "<val_ignored>"),
("get", "test/store/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)
def test_barrier_hash_store(self) -> None:
N = 4
store = dist.HashStore()
def f(i: int):
store_util.barrier(store, N, key_prefix="test/store")
with ThreadPool(N) as pool:
out = pool.map(f, range(N))
self.assertEqual(out, [None] * N)
class UtilTest(TestCase):

View File

@@ -7,7 +7,6 @@
# LICENSE file in the root directory of this source tree.
import abc
import functools
import json
import os
import signal
@@ -30,6 +29,7 @@ from torch.distributed.elastic.multiprocessing import (
ProcessFailure,
SignalException,
)
from collections import defaultdict
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
@@ -592,26 +592,6 @@ class SimpleElasticAgent(ElasticAgent):
}
)
def _get_ranks(
self,
role_infos: List[_RoleInstanceInfo],
role_idx: int,
start_idx: int = 0,
end_idx: int = -1,
) -> Tuple[int, List[int]]:
if end_idx == -1:
end_idx = len(role_infos)
prefix_sum = 0
total_sum = 0
for idx in range(start_idx, end_idx):
if role_idx > idx:
prefix_sum += role_infos[idx].local_world_size
total_sum += role_infos[idx].local_world_size
return (
total_sum,
list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -624,63 +604,86 @@ class SimpleElasticAgent(ElasticAgent):
1. Each agent writes its configuration (group_rank, group_world_size,
num_workers) to the common store.
2. Each agent retrieves configuration for all agents
and performs two level sort using role and rank.
3. Determine the global rank: the global rank of the workers for the current
agent is the offset of the infos array up to group_rank of the agent.
The offset is computed as a sum of local_world_size of all agents that
have rank less than the group_rank. The workers would have the ranks:
[offset, offset+local_world_size)
2. The rank 0 agent reads all the role_info from the store and
determines each agent's worker ranks.
3. Determine the global rank: the global rank of the workers is computed
as the cumulative sum of the local_world_size of all agents before it.
For efficiency reasons each agent is assigned a base global rank
such that its workers are in the range [base_global_rank,
base_global_rank + local_world_size).
4. Determine the role rank: The role rank is determined using the algorithm
in the point 3 with the exception that the offset is done from the first
agent that has the same role as current one and has the minimum group rank.
in point 3, with the exception that the ranks are calculated with
respect to the role name.
5. The rank 0 agent writes the assigned ranks to the store.
6. Each agent reads the assigned ranks from the store.
Time complexity: each worker O(1), rank0 O(n), overall O(n)
"""
role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
my_role_info = role_infos[group_rank]
worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
role_infos = sorted(
role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"
agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
role_infos, my_role_info.role
)
role_pos = next(
idx
for idx, role_info in enumerate(role_infos)
if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
)
role_world_size, role_ranks = self._get_ranks(
role_infos, role_pos, role_start_idx, role_end_idx + 1
store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
# The TCP store is co-located with rank 0, so we can use it to do extra compute to reduce the overall number of operations.
if group_rank == 0:
role_infos_bytes = store.multi_get(
[f"torchelastic/role_info/{i}" for i in range(group_world_size)]
)
role_infos = [
_RoleInstanceInfo.deserialize(info_bytes)
for info_bytes in role_infos_bytes
]
role_sizes = defaultdict(lambda: 0)
global_size = 0
for role_info in role_infos:
role_sizes[role_info.role] += role_info.local_world_size
global_size += role_info.local_world_size
base_global_rank = 0
role_ranks = defaultdict(lambda: 0)
keys = []
values = []
for i, role_info in enumerate(role_infos):
keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
values.append(
json.dumps(
[
base_global_rank,
global_size,
role_ranks[role_info.role],
role_sizes[role_info.role],
]
)
)
base_global_rank += role_info.local_world_size
role_ranks[role_info.role] += role_info.local_world_size
store.multi_set(keys, values)
# get will block until the data is available in the store.
base_global_rank, global_world_size, base_role_rank, role_world_size = json.loads(
store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")
)
workers = []
for ind in range(spec.local_world_size):
for local_rank in range(spec.local_world_size):
worker = Worker(
local_rank=ind,
global_rank=worker_global_ranks[ind],
role_rank=role_ranks[ind],
world_size=worker_world_size,
local_rank=local_rank,
global_rank=base_global_rank + local_rank,
role_rank=base_role_rank + local_rank,
world_size=global_world_size,
role_world_size=role_world_size,
)
workers.append(worker)
return workers
def _share_and_gather(
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List:
agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
key_prefix = "torchelastic/role_info"
agent_config_enc = agent_role_info.serialize()
role_infos_bytes = store_util.synchronize(
store, agent_config_enc, group_rank, group_world_size, key_prefix
)
role_infos = [
_RoleInstanceInfo.deserialize(role_info_bytes)
for role_info_bytes in role_infos_bytes
]
return role_infos
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -935,9 +938,8 @@ class SimpleElasticAgent(ElasticAgent):
start = time.time()
try:
store_util.barrier(
self._store,
self._worker_group.group_rank,
self._worker_group.group_world_size,
store=self._store,
world_size=self._worker_group.group_world_size,
key_prefix=_TERMINAL_STATE_SYNC_ID,
barrier_timeout=self._exit_barrier_timeout,
)

View File

@@ -13,6 +13,7 @@ from typing import Optional
import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.elastic.utils.store import barrier
logger = get_logger(__name__)
@@ -20,8 +21,7 @@ logger = get_logger(__name__)
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"
_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
_TCP_STORE_INIT = "_tcp_store/num_members"
def create_c10d_store(
@@ -54,8 +54,9 @@ def create_c10d_store(
"Creating c10d store on %s:%s\n"
" world_size : %s\n"
" is_server : %s\n"
" timeout(sec): %s\n",
server_addr, port, world_size, is_server, timeout
" timeout(sec): %s\n"
" use_libuv : %s\n",
server_addr, port, world_size, is_server, timeout, use_libuv,
)
try:
@@ -75,7 +76,7 @@ def create_c10d_store(
store = store_builder(use_libuv=use_libuv)
# skips full rank check when we don't have to wait for all workers
if wait_for_workers:
_check_full_rank(store, world_size)
_check_full_rank(store, world_size, timeout=timeout)
logger.info("Successfully created c10d store")
return store
except RuntimeError as e:
@@ -98,13 +99,9 @@ def create_c10d_store(
raise
def _check_full_rank(store, world_size):
idx = store.add(_MEMBER_CHECKIN, 1)
if idx == world_size:
store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")
def _check_full_rank(store, world_size, timeout):
try:
store.get(_LAST_MEMBER_CHECKIN)
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
except RuntimeError as e:
if str(e) == _SOCKET_TIMEOUT:
raise TimeoutError(

View File

@@ -8,9 +8,29 @@
from datetime import timedelta
from typing import List
from contextlib import contextmanager
_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"
@contextmanager
def store_timeout(store, timeout: float):
"""
This sets the timeout and then restores the old timeout when the context
manager exits.
Args:
store: the store to set the timeout on
timeout: the timeout to set
"""
old_timeout = store.timeout
store.set_timeout(timedelta(seconds=timeout))
yield
store.set_timeout(old_timeout)
def get_all(store, rank: int, prefix: str, size: int):
def get_all(store, rank: int, prefix: str, world_size: int):
r"""
Given a store and a prefix, the method goes through the array of keys
of the following format: ``{prefix}{idx}``, where idx is in a range
@@ -29,17 +49,20 @@ def get_all(store, rank: int, prefix: str, size: int):
value3 = values[2] # retrieves the data for key torchelastic/data2
"""
data_arr = []
for idx in range(size):
data = store.get(f"{prefix}{idx}")
data_arr.append(data)
store.set(f"{prefix}{rank}.FIN", b"FIN")
data_arr = store.multi_get(
[f"{prefix}{idx}" for idx in range(world_size)]
)
barrier_key = _barrier_nonblocking(
store=store,
world_size=world_size,
key_prefix=f"{prefix}/finished",
)
if rank == 0:
# Rank0 runs the TCPStore daemon, as a result it needs to exit last.
# Otherwise, the barrier may time out if the rank0 process finishes the work
# before other processes finish the `get_all` method
for node_rank in range(size):
store.get(f"{prefix}{node_rank}.FIN")
store.get(barrier_key)
return data_arr
@@ -50,7 +73,7 @@ def synchronize(
rank: int,
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
timeout: float = 300,
) -> List[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
@@ -58,21 +81,47 @@ def synchronize(
Note: The data on the path is not deleted, as a result there can be stale data if
you use the same key_prefix twice.
Time complexity: O(N) per worker, O(N^2) globally.
"""
store.set_timeout(timedelta(seconds=barrier_timeout))
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data
with store_timeout(store, timeout):
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data
def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
"""
Does all the non-blocking operations for a barrier and returns the final key
that can be waited on.
"""
num_members_key = key_prefix + _NUM_MEMBERS
last_member_key = key_prefix + _LAST_MEMBER_CHECKIN
idx = store.add(num_members_key, 1)
if idx == world_size:
store.set(last_member_key, "<val_ignored>")
return last_member_key
def barrier(
store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
store, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
"""
A global lock between agents.
A global lock between agents. This will pause all workers until at least
``world_size`` workers respond.
This uses a fast incrementing index to assign waiting ranks and a success
flag set by the last worker.
Time complexity: O(1) per worker, O(N) globally.
Note: Since the data is not removed from the store, the barrier can only be
used once per unique ``key_prefix``.
"""
data = f"{rank}".encode()
synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
with store_timeout(store, barrier_timeout):
last_member_key = _barrier_nonblocking(store=store, world_size=world_size, key_prefix=key_prefix)
store.get(last_member_key)