diff --git a/docs/source/conf.py b/docs/source/conf.py index fb7e2bdb844..0f89d2799fa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -579,6 +579,7 @@ coverage_ignore_functions = [ "barrier", "get_all", "synchronize", + "store_timeout", # torch.distributed.fsdp.wrap "always_wrap_policy", "enable_wrap", diff --git a/test/distributed/elastic/agent/server/test/api_test.py b/test/distributed/elastic/agent/server/test/api_test.py index e57b7b9fcb4..e1dd16bcf96 100644 --- a/test/distributed/elastic/agent/server/test/api_test.py +++ b/test/distributed/elastic/agent/server/test/api_test.py @@ -11,8 +11,11 @@ import signal import unittest import uuid -from typing import Any, Dict -from unittest.mock import call, MagicMock, patch +from multiprocessing.pool import ThreadPool +from typing import Any, Dict, List +from unittest.mock import call, patch + +import torch.distributed as dist import torch.distributed.elastic.rendezvous.registry as rdzv_registry from torch.distributed.elastic.agent.server.api import ( @@ -20,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import ( _RoleInstanceInfo, RunResult, SimpleElasticAgent, + Worker, WorkerGroup, WorkerSpec, WorkerState, @@ -470,22 +474,6 @@ class SimpleElasticAgentTest(unittest.TestCase): self.assertEqual(1, mock_monitor_workers.call_count) self.assertEqual(spec.max_restarts, agent._remaining_restarts) - def test_get_ranks(self): - role_infos = [ - _RoleInstanceInfo("parameter_server", 0, 4), - _RoleInstanceInfo("trainer", 1, 1), - _RoleInstanceInfo("trainer", 2, 2), - _RoleInstanceInfo("trainer", 3, 3), - _RoleInstanceInfo("parameter_server", 4, 5), - ] - spec = self._get_worker_spec( - max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8 - ) - agent = TestAgent(spec) - total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos)) - self.assertEqual(15, total_sum) - self.assertEqual([0, 1, 2, 3], list(ranks)) - def test_assign_worker_ranks(self): role_infos = [ _RoleInstanceInfo("parameter_server", 0, 4), @@ -494,56 +482,64 @@ class SimpleElasticAgentTest(unittest.TestCase): _RoleInstanceInfo("trainer", 3, 3), _RoleInstanceInfo("parameter_server", 4, 5), ] - num_agents = len(role_infos) - with patch.object(TestAgent, "_share_and_gather", return_value=role_infos): - self.verify_worker_ranks( - role_infos[0], num_agents, [0, 1, 2, 3], [0, 1, 2, 3] + store = dist.HashStore() + + def f(info) -> List[Worker]: + i, role_info = info + spec = self._get_worker_spec( + max_restarts=3, + monitor_interval=0.1, + role=role_info.role, + local_world_size=role_info.local_world_size, ) - self.verify_worker_ranks(role_infos[1], num_agents, [4], [0]) - self.verify_worker_ranks(role_infos[2], num_agents, [5, 6], [1, 2]) - self.verify_worker_ranks(role_infos[3], num_agents, [7, 8, 9], [3, 4, 5]) - - def verify_worker_ranks( - self, agent_config, total_agents, expected_global_ranks, expected_role_ranks - ): - role, agent_rank, local_world_size = ( - agent_config.role, - agent_config.rank, - agent_config.local_world_size, - ) - spec = self._get_worker_spec( - max_restarts=3, - monitor_interval=0.1, - role=role, - local_world_size=local_world_size, - ) - agent = TestAgent(spec) - workers = agent._assign_worker_ranks(None, agent_rank, total_agents, spec) - self.assertEqual( - expected_global_ranks, [worker.global_rank for worker in workers] - ) - self.assertEqual(expected_role_ranks, [worker.role_rank for worker in workers]) - - @patch("torch.distributed.elastic.utils.store.synchronize") - def test_share_and_gather(self, 
sync_mock): - # when the state is unknown we exit immediately; no retries - spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1) - agent = TestAgent(spec) - expected_agent_infos = [ - _RoleInstanceInfo("trainer", 0, 10), - _RoleInstanceInfo("trainer", 1, 10), - _RoleInstanceInfo("validator", 2, 10), - ] - - sync_mock.return_value = [obj.serialize() for obj in expected_agent_infos] - result = agent._share_and_gather(MagicMock(), 1, 3, spec) - sync_mock.assert_called_once() - for expected_role_info, actual_role_info in zip(expected_agent_infos, result): - self.assertEqual(expected_role_info.role, actual_role_info.role) - self.assertEqual(expected_role_info.rank, actual_role_info.rank) - self.assertEqual( - expected_role_info.local_world_size, actual_role_info.local_world_size + agent = TestAgent(spec) + workers = agent._assign_worker_ranks( + store, role_info.rank, len(role_infos), spec ) + return [ + ( + w.local_rank, + w.role_rank, + w.global_rank, + w.world_size, + w.role_world_size, + ) + for w in workers + ] + + with ThreadPool(len(role_infos)) as pool: + out = pool.map(f, enumerate(role_infos)) + + self.assertListEqual( + out, + [ + [ + (0, 0, 0, 15, 9), + (1, 1, 1, 15, 9), + (2, 2, 2, 15, 9), + (3, 3, 3, 15, 9), + ], + [ + (0, 0, 4, 15, 6), + ], + [ + (0, 1, 5, 15, 6), + (1, 2, 6, 15, 6), + ], + [ + (0, 3, 7, 15, 6), + (1, 4, 8, 15, 6), + (2, 5, 9, 15, 6), + ], + [ + (0, 4, 10, 15, 9), + (1, 5, 11, 15, 9), + (2, 6, 12, 15, 9), + (3, 7, 13, 15, 9), + (4, 8, 14, 15, 9), + ], + ], + ) def test_get_event(self): spec = self._get_worker_spec(max_restarts=1) diff --git a/test/distributed/elastic/utils/util_test.py b/test/distributed/elastic/utils/util_test.py index 60db327d5b9..ab890b0375a 100644 --- a/test/distributed/elastic/utils/util_test.py +++ b/test/distributed/elastic/utils/util_test.py @@ -7,77 +7,142 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from unittest import mock +import datetime +from multiprocessing.pool import ThreadPool +from typing import List + +import torch.distributed as dist import torch.distributed.elastic.utils.store as store_util from torch.distributed.elastic.utils.logging import get_logger from torch.testing._internal.common_utils import run_tests, TestCase +class MockStore: + def __init__(self): + self.ops = [] + + def set_timeout(self, timeout: float) -> None: + self.ops.append(("set_timeout", timeout)) + + @property + def timeout(self) -> datetime.timedelta: + self.ops.append(("timeout",)) + + return datetime.timedelta(seconds=1234) + + def set(self, key: str, value: str) -> None: + self.ops.append(("set", key, value)) + + def get(self, key: str) -> str: + self.ops.append(("get", key)) + return "value" + + def multi_get(self, keys: List[str]) -> List[str]: + self.ops.append(("multi_get", keys)) + return ["value"] * len(keys) + + def add(self, key: str, val: int) -> int: + self.ops.append(("add", key, val)) + return 3 + + class StoreUtilTest(TestCase): def test_get_all_rank_0(self): - store = mock.MagicMock() world_size = 3 - store_util.get_all(store, 0, "test/store", world_size) - # omit empty kwargs, get only key - actual_set_call_args = [ - call_args[0][0] for call_args in store.set.call_args_list - ] - self.assertListEqual(["test/store0.FIN"], actual_set_call_args) - actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list] - expected_get_call_args = [ - ("test/store0",), - ("test/store1",), - ("test/store2",), - ("test/store0.FIN",), - ("test/store1.FIN",), - ("test/store2.FIN",), - ] - self.assertListEqual(expected_get_call_args, actual_get_call_args) + store = MockStore() + + store_util.get_all(store, 0, "test/store", world_size) + + self.assertListEqual( + store.ops, + [ + ("multi_get", ["test/store0", "test/store1", "test/store2"]), + ("add", "test/store/finished/num_members", 1), + ("set", "test/store/finished/last_member", ""), + ("get", "test/store/finished/last_member"), + ], + ) def test_get_all_rank_n(self): - store = mock.MagicMock() + store = MockStore() world_size = 3 store_util.get_all(store, 1, "test/store", world_size) - # omit empty kwargs, get only key - actual_set_call_args = [ - call_args[0][0] for call_args in store.set.call_args_list - ] - self.assertListEqual(["test/store1.FIN"], actual_set_call_args) - actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list] - expected_get_call_args = [ - ("test/store0",), - ("test/store1",), - ("test/store2",), - ] - self.assertListEqual(expected_get_call_args, actual_get_call_args) + self.assertListEqual( + store.ops, + [ + ("multi_get", ["test/store0", "test/store1", "test/store2"]), + ("add", "test/store/finished/num_members", 1), + ("set", "test/store/finished/last_member", ""), + ], + ) def test_synchronize(self): - store_mock = mock.MagicMock() - data = b"data0" - store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test") - actual_set_call_args = store_mock.set.call_args_list - # omit empty kwargs - actual_set_call_args = [call_args[0] for call_args in actual_set_call_args] - expected_set_call_args = [ - ("torchelastic/test0", b"data0"), - ("torchelastic/test0.FIN", b"FIN"), - ] - self.assertListEqual(expected_set_call_args, actual_set_call_args) + store = MockStore() - expected_get_call_args = [ - ("torchelastic/test0",), - ("torchelastic/test1",), - ("torchelastic/test2",), - ("torchelastic/test0.FIN",), - ("torchelastic/test1.FIN",), - ("torchelastic/test2.FIN",), - ] - 
actual_get_call_args = store_mock.get.call_args_list - actual_get_call_args = [call_args[0] for call_args in actual_get_call_args] - self.assertListEqual(expected_get_call_args, actual_get_call_args) + data = b"data0" + store_util.synchronize(store, data, 0, 3, key_prefix="test/store") + + self.assertListEqual( + store.ops, + [ + ("timeout",), + ("set_timeout", datetime.timedelta(seconds=300)), + ("set", "test/store0", data), + ("multi_get", ["test/store0", "test/store1", "test/store2"]), + ("add", "test/store/finished/num_members", 1), + ("set", "test/store/finished/last_member", ""), + ("get", "test/store/finished/last_member"), + ("set_timeout", datetime.timedelta(seconds=1234)), + ], + ) + + def test_synchronize_hash_store(self) -> None: + N = 4 + + store = dist.HashStore() + + def f(i: int): + return store_util.synchronize( + store, f"data{i}", i, N, key_prefix="test/store" + ) + + with ThreadPool(N) as pool: + out = pool.map(f, range(N)) + + self.assertListEqual(out, [[f"data{i}".encode() for i in range(N)]] * N) + + def test_barrier(self): + store = MockStore() + + store_util.barrier(store, 3, key_prefix="test/store") + + self.assertListEqual( + store.ops, + [ + ("timeout",), + ("set_timeout", datetime.timedelta(seconds=300)), + ("add", "test/store/num_members", 1), + ("set", "test/store/last_member", ""), + ("get", "test/store/last_member"), + ("set_timeout", datetime.timedelta(seconds=1234)), + ], + ) + + def test_barrier_hash_store(self) -> None: + N = 4 + + store = dist.HashStore() + + def f(i: int): + store_util.barrier(store, N, key_prefix="test/store") + + with ThreadPool(N) as pool: + out = pool.map(f, range(N)) + + self.assertEqual(out, [None] * N) class UtilTest(TestCase): diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index dd20703cedb..c9f76e5917b 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -7,7 +7,6 @@ # LICENSE file in the root directory of this source tree. import abc -import functools import json import os import signal @@ -30,6 +29,7 @@ from torch.distributed.elastic.multiprocessing import ( ProcessFailure, SignalException, ) +from collections import defaultdict from torch.distributed.elastic.utils.logging import get_logger __all__ = [ @@ -592,26 +592,6 @@ class SimpleElasticAgent(ElasticAgent): } ) - def _get_ranks( - self, - role_infos: List[_RoleInstanceInfo], - role_idx: int, - start_idx: int = 0, - end_idx: int = -1, - ) -> Tuple[int, List[int]]: - if end_idx == -1: - end_idx = len(role_infos) - prefix_sum = 0 - total_sum = 0 - for idx in range(start_idx, end_idx): - if role_idx > idx: - prefix_sum += role_infos[idx].local_world_size - total_sum += role_infos[idx].local_world_size - return ( - total_sum, - list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)), - ) - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof @@ -624,63 +604,86 @@ class SimpleElasticAgent(ElasticAgent): 1. Each agent writes its configuration(group_rank, group_world_size , num_workers) to the common store. - 2. Each agent retrieves configuration for all agents - and performs two level sort using role and rank. - 3. Determine the global rank: the global rank of the workers for the current - agent is the offset of the infos array up to group_rank of the agent. - The offset is computed as a sum of local_world_size of all agents that - have rank less than the group_rank. 
The workers would have the ranks:
-           [offset, offset+local_world_size)
+        2. The rank 0 agent reads all the role_info from the store and
+           determines each agent's worker ranks.
+        3. Determine the global rank: the global rank of the workers is computed
+           as the cumulative sum of the local_world_size of all workers that
+           come before it. For efficiency reasons each agent is assigned a base
+           global rank such that its workers are in the range [base_global_rank,
+           base_global_rank + local_world_size).
         4. Determine the role rank: The role rank is determined using the algorithms
-           in the point 3 with the exception that the offset is done from the first
-           agent that has the same role as current one and has the minimum group rank.
+           in point 3, with the exception that the ranks are calculated with
+           respect to the role name.
+        5. The rank 0 agent writes the assigned ranks to the store.
+        6. Each agent reads the assigned ranks from the store.
+
+        Time complexity: each agent O(1), rank 0 O(n), overall O(n)
         """
-        role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
-        my_role_info = role_infos[group_rank]
-        worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
-        role_infos = sorted(
-            role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
+
+        ROLE_INFO_PREFIX = "torchelastic/role_info/"
+        ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"
+
+        agent_role_info = _RoleInstanceInfo(
+            spec.role, group_rank, spec.local_world_size
         )
-        role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
-            role_infos, my_role_info.role
-        )
-        role_pos = next(
-            idx
-            for idx, role_info in enumerate(role_infos)
-            if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
-        )
-        role_world_size, role_ranks = self._get_ranks(
-            role_infos, role_pos, role_start_idx, role_end_idx + 1
+        store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
+
+        # The TCP store is collocated with rank 0, so we use it to do extra compute to reduce the overall number of store operations.
+        if group_rank == 0:
+            role_infos_bytes = store.multi_get(
+                [f"{ROLE_INFO_PREFIX}{i}" for i in range(group_world_size)]
+            )
+            role_infos = [
+                _RoleInstanceInfo.deserialize(info_bytes)
+                for info_bytes in role_infos_bytes
+            ]
+
+            role_sizes = defaultdict(lambda: 0)
+            global_size = 0
+            for role_info in role_infos:
+                role_sizes[role_info.role] += role_info.local_world_size
+                global_size += role_info.local_world_size
+
+            base_global_rank = 0
+            role_ranks = defaultdict(lambda: 0)
+
+            keys = []
+            values = []
+            for i, role_info in enumerate(role_infos):
+                keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
+                values.append(
+                    json.dumps(
+                        [
+                            base_global_rank,
+                            global_size,
+                            role_ranks[role_info.role],
+                            role_sizes[role_info.role],
+                        ]
+                    )
+                )
+
+                base_global_rank += role_info.local_world_size
+                role_ranks[role_info.role] += role_info.local_world_size
+
+            store.multi_set(keys, values)
+
+        # get will block until the data is available in the store.
+ base_global_rank, global_world_size, base_role_rank, role_world_size = json.loads( + store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}") ) + workers = [] - for ind in range(spec.local_world_size): + for local_rank in range(spec.local_world_size): worker = Worker( - local_rank=ind, - global_rank=worker_global_ranks[ind], - role_rank=role_ranks[ind], - world_size=worker_world_size, + local_rank=local_rank, + global_rank=base_global_rank + local_rank, + role_rank=base_role_rank + local_rank, + world_size=global_world_size, role_world_size=role_world_size, ) workers.append(worker) return workers - def _share_and_gather( - self, store, group_rank: int, group_world_size: int, spec: WorkerSpec - ) -> List: - agent_role_info = _RoleInstanceInfo( - spec.role, group_rank, spec.local_world_size - ) - key_prefix = "torchelastic/role_info" - agent_config_enc = agent_role_info.serialize() - role_infos_bytes = store_util.synchronize( - store, agent_config_enc, group_rank, group_world_size, key_prefix - ) - role_infos = [ - _RoleInstanceInfo.deserialize(role_info_bytes) - for role_info_bytes in role_infos_bytes - ] - return role_infos - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof @@ -935,9 +938,8 @@ class SimpleElasticAgent(ElasticAgent): start = time.time() try: store_util.barrier( - self._store, - self._worker_group.group_rank, - self._worker_group.group_world_size, + store=self._store, + world_size=self._worker_group.group_world_size, key_prefix=_TERMINAL_STATE_SYNC_ID, barrier_timeout=self._exit_barrier_timeout, ) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 808b965d50c..bf4a537bbf0 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -13,6 +13,7 @@ from typing import Optional import torch.distributed as dist from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.elastic.utils.store import barrier logger = get_logger(__name__) @@ -20,8 +21,7 @@ logger = get_logger(__name__) _ADDRESS_IN_USE = "Address already in use" _SOCKET_TIMEOUT = "Socket Timeout" -_MEMBER_CHECKIN = "_tcp_store/num_members" -_LAST_MEMBER_CHECKIN = "_tcp_store/last_member" +_TCP_STORE_INIT = "_tcp_store/num_members" def create_c10d_store( @@ -54,8 +54,9 @@ def create_c10d_store( "Creating c10d store on %s:%s\n" " world_size : %s\n" " is_server : %s\n" - " timeout(sec): %s\n", - server_addr, port, world_size, is_server, timeout + " timeout(sec): %s\n" + " use_libuv : %s\n", + server_addr, port, world_size, is_server, timeout, use_libuv, ) try: @@ -75,7 +76,7 @@ def create_c10d_store( store = store_builder(use_libuv=use_libuv) # skips full rank check when we don't have to wait for all workers if wait_for_workers: - _check_full_rank(store, world_size) + _check_full_rank(store, world_size, timeout=timeout) logger.info("Successfully created c10d store") return store except RuntimeError as e: @@ -98,13 +99,9 @@ def create_c10d_store( raise -def _check_full_rank(store, world_size): - idx = store.add(_MEMBER_CHECKIN, 1) - if idx == world_size: - store.set(_LAST_MEMBER_CHECKIN, "") - +def _check_full_rank(store, world_size, timeout): try: - store.get(_LAST_MEMBER_CHECKIN) + barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) except RuntimeError as e: if str(e) == _SOCKET_TIMEOUT: raise TimeoutError( diff --git a/torch/distributed/elastic/utils/store.py 
b/torch/distributed/elastic/utils/store.py index 9c7abab9291..719c83b8265 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -8,9 +8,29 @@ from datetime import timedelta from typing import List +from contextlib import contextmanager + +_NUM_MEMBERS = "/num_members" +_LAST_MEMBER_CHECKIN = "/last_member" + +@contextmanager +def store_timeout(store, timeout: float): + """ + This sets the timeout and then restores the old timeout when the context + manager exits. + + Args: + store: the store to set the timeout on + timeout: the timeout to set + """ + + old_timeout = store.timeout + store.set_timeout(timedelta(seconds=timeout)) + yield + store.set_timeout(old_timeout) -def get_all(store, rank: int, prefix: str, size: int): +def get_all(store, rank: int, prefix: str, world_size: int): r""" Given a store and a prefix, the method goes through the array of keys of the following format: ``{prefix}{idx}``, where idx is in a range @@ -29,17 +49,20 @@ def get_all(store, rank: int, prefix: str, size: int): value3 = values[2] # retrieves the data for key torchelastic/data2 """ - data_arr = [] - for idx in range(size): - data = store.get(f"{prefix}{idx}") - data_arr.append(data) - store.set(f"{prefix}{rank}.FIN", b"FIN") + data_arr = store.multi_get( + [f"{prefix}{idx}" for idx in range(world_size)] + ) + + barrier_key = _barrier_nonblocking( + store=store, + world_size=world_size, + key_prefix=f"{prefix}/finished", + ) if rank == 0: # Rank0 runs the TCPStore daemon, as a result it needs to exit last. # Otherwise, the barrier may timeout if rank0 process finished the work # before other processes finished `get_all` method - for node_rank in range(size): - store.get(f"{prefix}{node_rank}.FIN") + store.get(barrier_key) return data_arr @@ -50,7 +73,7 @@ def synchronize( rank: int, world_size: int, key_prefix: str, - barrier_timeout: float = 300, + timeout: float = 300, ) -> List[bytes]: """ Synchronizes ``world_size`` agents between each other using the underlying c10d store. @@ -58,21 +81,47 @@ def synchronize( Note: The data on the path is not deleted, as a result there can be stale data if you use the same key_prefix twice. + + Time complexity: O(N) per worker, O(N^2) globally. """ - store.set_timeout(timedelta(seconds=barrier_timeout)) - store.set(f"{key_prefix}{rank}", data) - agent_data = get_all(store, rank, key_prefix, world_size) - return agent_data + with store_timeout(store, timeout): + store.set(f"{key_prefix}{rank}", data) + agent_data = get_all(store, rank, key_prefix, world_size) + return agent_data + + +def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: + """ + Does all the non-blocking operations for a barrier and returns the final key + that can be waited on. + """ + num_members_key = key_prefix + _NUM_MEMBERS + last_member_key = key_prefix + _LAST_MEMBER_CHECKIN + + + idx = store.add(num_members_key, 1) + if idx == world_size: + store.set(last_member_key, "") + + return last_member_key def barrier( - store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300 + store, world_size: int, key_prefix: str, barrier_timeout: float = 300 ) -> None: """ - A global lock between agents. + A global lock between agents. This will pause all workers until at least + ``world_size`` workers respond. + + This uses a fast incrementing index to assign waiting ranks and a success + flag set by the last worker. + + Time complexity: O(1) per worker, O(N) globally. 
Note: Since the data is not removed from the store, the barrier can be used
         once per unique ``key_prefix``.
     """
-    data = f"{rank}".encode()
-    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
+
+    with store_timeout(store, barrier_timeout):
+        last_member_key = _barrier_nonblocking(
+            store=store, world_size=world_size, key_prefix=key_prefix
+        )
+        store.get(last_member_key)
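
For context on how the reworked helpers in torch/distributed/elastic/utils/store.py are intended to compose, here is a minimal, self-contained sketch driven the same way the new HashStore unit tests drive them. The synchronize and barrier signatures are the ones introduced above; the thread pool, key prefixes and payloads are purely illustrative.

from multiprocessing.pool import ThreadPool

import torch.distributed as dist
import torch.distributed.elastic.utils.store as store_util

N = 4  # number of simulated agents
store = dist.HashStore()  # single-process store, as in the new unit tests


def agent(rank: int):
    # Each agent publishes its payload and receives everyone's payloads back.
    # Internally this is one set(), one multi_get() and a non-blocking barrier,
    # replacing the old per-key get() loop with ".FIN" polling.
    payloads = store_util.synchronize(
        store, f"data{rank}".encode(), rank, N, key_prefix="demo/sync"
    )

    # Pure rendezvous: O(1) store operations per agent (add + set/get).
    store_util.barrier(store, N, key_prefix="demo/barrier")
    return payloads


with ThreadPool(N) as pool:
    results = pool.map(agent, range(N))

# Every agent sees every payload, in rank order.
assert results == [[f"data{i}".encode() for i in range(N)]] * N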
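
To make the expected values in test_assign_worker_ranks easier to check by hand, the following standalone sketch replays the arithmetic of the rank 0 branch of _assign_worker_ranks on the same agent layout, with plain tuples standing in for _RoleInstanceInfo (the tuples and variable names here are illustrative, not part of the change).

import json
from collections import defaultdict

# (role, group_rank, local_world_size) for the five agents used in the test.
role_infos = [
    ("parameter_server", 0, 4),
    ("trainer", 1, 1),
    ("trainer", 2, 2),
    ("trainer", 3, 3),
    ("parameter_server", 4, 5),
]

# First pass: per-role and global world sizes.
role_sizes = defaultdict(int)
global_size = 0
for role, _rank, local_world_size in role_infos:
    role_sizes[role] += local_world_size
    global_size += local_world_size

# Second pass: running offsets give each agent its base global rank and base role rank.
base_global_rank = 0
role_ranks = defaultdict(int)
assigned = []
for role, _rank, local_world_size in role_infos:
    assigned.append(
        json.dumps(
            [base_global_rank, global_size, role_ranks[role], role_sizes[role]]
        )
    )
    base_global_rank += local_world_size
    role_ranks[role] += local_world_size

assert global_size == 15
assert dict(role_sizes) == {"parameter_server": 9, "trainer": 6}
# Agent 3 (the trainer agent with local_world_size=3): base global rank 7, base role
# rank 3, which yields the workers (0, 3, 7, 15, 6), (1, 4, 8, 15, 6), (2, 5, 9, 15, 6).
assert json.loads(assigned[3]) == [7, 15, 3, 6]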