diff --git a/dlrover/proto/elastic_training.proto b/dlrover/proto/elastic_training.proto index e493ef1dd..7f479a871 100644 --- a/dlrover/proto/elastic_training.proto +++ b/dlrover/proto/elastic_training.proto @@ -219,6 +219,7 @@ message RendezvousParams { int32 min_nodes = 1; int32 max_nodes = 2; int32 waiting_timeout = 3; + int32 node_unit = 4; } message KeyValuePair { diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 3e550380b..d1a424797 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -367,11 +367,14 @@ def network_check_success(self, timeout=180): return response.success @retry_grpc_request - def report_rdzv_params(self, min_nodes, max_nodes, waiting_timeout): + def report_rdzv_params( + self, min_nodes, max_nodes, waiting_timeout, node_unit + ): request = elastic_training_pb2.RendezvousParams() request.min_nodes = min_nodes request.max_nodes = max_nodes request.waiting_timeout = waiting_timeout + request.node_unit = node_unit response = self._stub.report_rdzv_params(request) return response.success @@ -582,7 +585,9 @@ def network_check_success(self, node_id): def report_node_status(self, normal): return True - def report_rdzv_params(self, min_nodes, max_nodes, waiting_timeout): + def report_rdzv_params( + self, min_nodes, max_nodes, waiting_timeout, node_unit + ): return True def kv_store_set(self, key, value): diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index ca07859b8..8a87dcc74 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -79,11 +79,13 @@ def __init__(self, name, rank_id, rdzv_params: RendezvousParameters): self._client = GlobalMasterClient.MASTER_CLIENT self._store = MasterKVStore(self._name, timedelta(seconds=60)) lastcall_timeout = int(rdzv_params.get("lastcall_timeout", 60)) + node_unit = int(rdzv_params.get("node_unit", "1")) if self._rank_id == 0: self._client.report_rdzv_params( rdzv_params.min_nodes, rdzv_params.max_nodes, lastcall_timeout, + node_unit, ) def get_backend(self) -> str: @@ -318,7 +320,7 @@ def _assign_worker_ranks( return workers def _initialize_workers(self, worker_group): - if self._config.network_check and self._restart_count == 0: + if self._config.network_check: run_network_check(self._config, self._entrypoint) super()._initialize_workers(worker_group) @@ -359,8 +361,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: return run_result elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: self._report_failure_to_master(run_result.failures) - has_fatal_error = self._has_fatal_error(run_result) - if not has_fatal_error and self._remaining_failovers > 0: + if self._remaining_failovers > 0: logger.info( f"[{role}] Worker group {state.name}. 
" f"{self._remaining_failovers}/{spec.max_restarts}" @@ -369,7 +370,6 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: self._remaining_failovers -= 1 self._restart_workers(self._worker_group) else: - logger.info("Cannot restart workers with fatal error.") self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED return run_result @@ -380,14 +380,6 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: else: raise Exception(f"[{role}] Worker group in {state.name} state") - def _has_fatal_error(self, run_result: RunResult): - """The error with exitcode 1 is the Python exception and we cannot - recover it by restarting workers.""" - for pfailure in run_result.failures.values(): - if pfailure.exitcode == 1: - return True - return False - def _report_failure_to_master(self, failures: Dict[int, ProcessFailure]): errors = {} for rank, failure in failures.items(): diff --git a/dlrover/python/master/elastic_training/rdzv_manager.py b/dlrover/python/master/elastic_training/rdzv_manager.py index 638cb870e..88bb75a6e 100644 --- a/dlrover/python/master/elastic_training/rdzv_manager.py +++ b/dlrover/python/master/elastic_training/rdzv_manager.py @@ -17,11 +17,38 @@ from threading import Lock from typing import Dict, List -from dlrover.python.common.constants import NetworkFailureReason +from dlrover.python.common.constants import ( + NetworkFailureReason, + RendezvousName, +) from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node +class RendezvousParameters(object): + """Holds the parameters to construct rendezvous. + Args: + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + waiting_timeout: + An additional wait amount before completing the rendezvous once + the rendezvous has the minimum number of required participants. + Default 30s, + """ + + def __init__( + self, + min_nodes: int, + max_nodes: int, + waiting_timeout=30, + ): + self.min_nodes = min_nodes + self.max_nodes = max_nodes + self.waiting_timeout = waiting_timeout + + class RendezvousManager(metaclass=ABCMeta): def __init__(self): self._lock = Lock() @@ -32,28 +59,132 @@ def __init__(self): self._lastcall_time = 0 self._rdzv_params = RendezvousParameters(0, 0) self._rdzv_round = 0 + self._node_unit = 1 + self._name = "" + self._latest_rdzv_nodes = [] - def update_rdzv_params(self, min_nodes, max_ndoes, waiting_timeout): - """Update rendezvous parameters""" - - @abstractmethod def add_alive_node(self, node: Node): """When a node is running, the master will add it to alive list.""" - pass + self._alive_nodes.add(node.id) + logger.info( + f"Add alive worker {node.name} to {self._name} rendezvous." + ) - @abstractmethod def remove_alive_node(self, node: Node): """When a node is exited, the master will remove it from alive list.""" - pass + if node.id in self._alive_nodes: + self._alive_nodes.remove(node.id) + logger.info( + f"Remove exited worker {node.name} from " + f"{self._name} rendezvous." + ) + self._waiting_nodes.pop(node.rank_index, 0) - @abstractmethod - def get_comm_world(self, node_id): - """Get communication world of all alive nodes.""" - pass + def update_rdzv_params( + self, min_nodes, max_ndoes, waiting_timeout, node_unit + ): + """Update rendezvous parameters + Args: + min_nodes: The minimum number of nodes. + max_nodes: THe maximum number of nodes. + waiting_timeout: the time to wait more workers. 
+ worker_unit: the number unit of workers to build the communication + world. This is, the number of nodes in a world should be + a multiple of worker_unit. + """ + self._rdzv_params.min_nodes = min_nodes + self._rdzv_params.max_nodes = max_ndoes + self._rdzv_params.waiting_timeout = waiting_timeout + self._node_unit = node_unit + + def _check_rdzv_completed(self): + rdzv_completed = False + waiting_num = len(self._waiting_nodes) + if len(self._waiting_nodes) == self._rdzv_params.max_nodes: + rdzv_completed = True + else: + waiting_time = time.time() - self._lastcall_time + if ( + waiting_num >= self._rdzv_params.min_nodes + and waiting_time >= self._rdzv_params.waiting_timeout + ): + rdzv_completed = True + waiting_num = ( + waiting_num // self._node_unit + ) * self._node_unit + + if rdzv_completed: + node_ids = sorted(self._waiting_nodes.keys())[0:waiting_num] + self._rdzv_nodes = {} + for i in node_ids: + self._rdzv_nodes[i] = self._waiting_nodes[i] + self._latest_rdzv_nodes = list(self._rdzv_nodes.keys()) + self._waiting_nodes = dict() + self._lastcall_time = 0 + logger.info( + f"Completed {self._rdzv_round} round " + f"rendezvous of elastic training is {self._rdzv_nodes}" + ) + return rdzv_completed + + def not_joined_rdzv_nodes(self): + """Return workers which do not join a rendezvous.""" + nodes = [] + if self._rdzv_nodes: + for node_id in self._alive_nodes: + if node_id not in self._rdzv_nodes: + nodes.append(node_id) + return nodes + + def join_rendezvous( + self, + rank_id, + local_world_size, + ): + """The node joins the current rond rendezvous. + Args: + node_id: the node ID which is unique in an ElasticJob of DLrover. + local_world_size: the local world size of a node. + + Returns: + int: the number of rendezvous round. + """ + with self._lock: + if rank_id in self._waiting_nodes: + return + self._waiting_nodes[rank_id] = local_world_size + logger.info(f"{self._name} waiting nodes: {self._waiting_nodes}") + self._rdzv_nodes = {} + self._lastcall_time = time.time() + return self._rdzv_round + + def num_nodes_waiting(self): + """The elastic agent will restart training processes if it + find the number of waiting nodes is not zero. The manager + will notify all nodes to restart training processes immediately if + ab existing node re-joins the next round rendezvous. + If there are new nodes, the master notifies all nodes to re-join + the next round rendezvous only when the number of waiting nodes + is bigger than the number unit of nodes. + """ + with self._lock: + if self._has_node_restart(): + return len(self._waiting_nodes) + elif len(self._waiting_nodes) >= self._node_unit: + return len(self._waiting_nodes) + return 0 + + def _has_node_restart(self): + """The node will restart training processes if it + re-joins the rendezvous.""" + for rank_id in self._waiting_nodes.keys(): + if rank_id in self._latest_rdzv_nodes: + return True + return False @abstractmethod - def join_rendezvous(self, node_id, local_world_size): - """The node joins a rond rendezvous.""" + def get_comm_world(self, rank_id): + """Get communication world of all alive nodes.""" pass @abstractmethod @@ -61,35 +192,6 @@ def report_network_check_result(self, node_id: int, normal: bool): """The node updates its status""" pass - @abstractmethod - def num_nodes_waiting(self): - """Get the number of waiting nodes.""" - pass - - -class RendezvousParameters(object): - """Holds the parameters to construct rendezvous. - Args: - min_nodes: - The minimum number of nodes to admit to the rendezvous. 
- max_nodes: - The maximum number of nodes to admit to the rendezvous. - waiting_timeout: - An additional wait amount before completing the rendezvous once - the rendezvous has the minimum number of required participants. - Default 30s, - """ - - def __init__( - self, - min_nodes: int, - max_nodes: int, - waiting_timeout=30, - ): - self.min_nodes = min_nodes - self.max_nodes = max_nodes - self.waiting_timeout = waiting_timeout - class ElasticTrainingRendezvousManager(RendezvousManager): """ElasticTrainingRendezvousManager runs on the DLRover master. The manager @@ -109,30 +211,9 @@ class ElasticTrainingRendezvousManager(RendezvousManager): def __init__(self): super().__init__() + self._name = RendezvousName.ELASTIC_TRAINING - def update_rdzv_params(self, min_nodes, max_ndoes, waiting_timeout): - """Update rendezvous parameters""" - self._rdzv_params.min_nodes = min_nodes - self._rdzv_params.max_nodes = max_ndoes - self._rdzv_params.waiting_timeout = waiting_timeout - - def add_alive_node(self, node: Node): - """When a node is running, the master will add it to alive list.""" - self._alive_nodes.add(node.id) - logger.info( - f"Add alive worker {node.name} to elastic training rendezvous." - ) - - def remove_alive_node(self, node: Node): - """When a node is exited, the master will remove it from alive list.""" - if node.id in self._alive_nodes: - self._alive_nodes.remove(node.id) - logger.info(f"Remove exited worker {node.name} from Rendezvous.") - - def get_released_workers(self): - return [] - - def get_comm_world(self, node_id): + def get_comm_world(self, rank_id): """Return the communication world if a round rendezvous is completed. The rendezvous is completed if one of the following conditions is satisfied: @@ -146,58 +227,14 @@ def get_comm_world(self, node_id): and the value is the local world size of the node. """ with self._lock: - rdzv_completed = False - if self._rdzv_nodes: - return 0, self._rdzv_nodes - if len(self._waiting_nodes) == self._rdzv_params.max_nodes: - rdzv_completed = True - else: - waiting_num = len(self._waiting_nodes) - alive_num = len(self._alive_nodes) - waiting_time = time.time() - self._lastcall_time - rdzv_completed = ( - waiting_num >= self._rdzv_params.min_nodes - and waiting_num == alive_num - and waiting_time >= self._rdzv_params.waiting_timeout - ) - - if rdzv_completed: - self._rdzv_nodes = dict(sorted(self._waiting_nodes.items())) - self._waiting_nodes = dict() - self._lastcall_time = 0 - logger.info( - f"Completed {self._rdzv_round} round " - f"rendezvous of elastic training is {self._rdzv_nodes}" - ) - self._rdzv_round += 1 - - return 0, self._rdzv_nodes - - def join_rendezvous(self, node_id, local_world_size): - """The node joins the current rond rendezvous. - Args: - node_id: the node ID which is unique in an ElasticJob of DLrover. - local_world_size: the local world size of a node. - - Returns: - int: the number of rendezvous round. - """ - with self._lock: - if node_id in self._waiting_nodes: - return - self._waiting_nodes[node_id] = local_world_size - self._rdzv_nodes = {} - if len(self._waiting_nodes) >= self._rdzv_params.min_nodes: - if self._lastcall_time == 0: - self._lastcall_time = time.time() - return self._rdzv_round + if not self._rdzv_nodes: + rdzv_completed = self._check_rdzv_completed() + if rdzv_completed: + self._rdzv_round += 1 - def num_nodes_waiting(self): - """The number of waiting nodes. The agent of a node will re-join - a rendezvous if it finds there are waiting nodes. 
- """ - with self._lock: - return len(self._waiting_nodes) + if rank_id not in self._rdzv_nodes: + return self._rdzv_round, {} + return self._rdzv_round, self._rdzv_nodes def report_network_check_result(self, node_id, normal): return @@ -223,61 +260,19 @@ class NetworkCheckRendezvousManager(RendezvousManager): def __init__(self): super().__init__() + self._name = RendezvousName.NETWORK_CHECK self._node_status: Dict[int, bool] = {} self._reported_nodes = set() self._node_groups: List[Dict[int, int]] = [] - def update_rdzv_params(self, min_nodes, max_ndoes, waiting_timeout): - """Update rendezvous parameters""" - self._rdzv_params.min_nodes = min_nodes - self._rdzv_params.max_nodes = max_ndoes - self._rdzv_params.waiting_timeout = waiting_timeout - - def add_alive_node(self, node: Node): - """When a node is running, the master will add it to alive list.""" - self._alive_nodes.add(node.id) - logger.info( - f"Add alive worker {node.name} to network check rendezvous." - ) - - def remove_alive_node(self, node: Node): - """When a node is exited, the master will remove it from alive list.""" - if node.id in self._alive_nodes: - self._alive_nodes.remove(node.id) - logger.info(f"Remove exited worker {node.name} from Rendezvous.") - - def get_released_workers(self): - return [] - - def get_comm_world(self, node_id): + def get_comm_world(self, rank_id): """Return the communication world if a round rendezvous is completed. The rendezvous is completed if one of the following conditions. """ with self._lock: - rdzv_completed = False if not self._node_groups: - if len(self._waiting_nodes) == self._rdzv_params.max_nodes: - rdzv_completed = True - else: - waiting_num = len(self._waiting_nodes) - alive_num = len(self._alive_nodes) - waiting_time = time.time() - self._lastcall_time - rdzv_completed = ( - waiting_num >= self._rdzv_params.min_nodes - and waiting_num == alive_num - and waiting_time >= self._rdzv_params.waiting_timeout - ) - + rdzv_completed = self._check_rdzv_completed() if rdzv_completed: - self._rdzv_nodes = dict( - sorted(self._waiting_nodes.items()) - ) - self._waiting_nodes = dict() - self._lastcall_time = 0 - logger.info( - f"Completed {self._rdzv_round} round " - f"rendezvous of network check is {self._rdzv_nodes}" - ) self._node_groups = self._group_nodes(self._rdzv_round) logger.info( f"Round {self._rdzv_round} " @@ -289,7 +284,7 @@ def get_comm_world(self, node_id): self._rdzv_round += 1 for i, group in enumerate(self._node_groups): - if node_id in group: + if rank_id in group: return i, group return 0, {} @@ -350,31 +345,19 @@ def report_network_check_result(self, node_id: int, succeed): def join_rendezvous( self, - node_id, + rank_id, local_world_size, ): """The node joins the current rond rendezvous. Args: - node_id: the node ID which is unique in an ElasticJob of DLrover. + rank_id: the node ID which is unique in an ElasticJob of DLrover. local_world_size: the local world size of a node. Returns: int: the number of rendezvous round. 
""" - with self._lock: - if node_id in self._waiting_nodes: - return - self._waiting_nodes[node_id] = local_world_size - self._rdzv_nodes = {} - self._node_groups = [] - if len(self._waiting_nodes) >= self._rdzv_params.min_nodes: - if self._lastcall_time == 0: - self._lastcall_time = time.time() - return self._rdzv_round - - def num_nodes_waiting(self): - with self._lock: - return len(self._waiting_nodes) + self._node_groups = [] + return super().join_rendezvous(rank_id, local_world_size) def network_check_success(self): """Check the network task is succeed. Each task contains 3 rounds diff --git a/dlrover/python/master/master.py b/dlrover/python/master/master.py index 9a41b1c8a..4fb07d84a 100644 --- a/dlrover/python/master/master.py +++ b/dlrover/python/master/master.py @@ -12,6 +12,7 @@ # limitations under the License. import time +from typing import Dict from dlrover.python.common.constants import ( DistributionStrategy, @@ -26,6 +27,7 @@ from dlrover.python.master.elastic_training.rdzv_manager import ( ElasticTrainingRendezvousManager, NetworkCheckRendezvousManager, + RendezvousManager, ) from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor @@ -64,7 +66,7 @@ def __init__(self, port, args: JobArgs): else None ) elastic_training = RendezvousName.ELASTIC_TRAINING - self.rdzv_managers = { + self.rdzv_managers: Dict[str, RendezvousManager] = { elastic_training: ElasticTrainingRendezvousManager(), RendezvousName.NETWORK_CHECK: NetworkCheckRendezvousManager(), } @@ -198,10 +200,10 @@ def run(self): def _remove_not_participated_workers(self): """Remove workers who do not participate training.""" - et_manager = self.rdzv_managers[RendezvousName.ELASTIC_TRAINING] - workers = et_manager.get_released_workers() - if workers: - self.job_manager.remove_not_participated_workers(workers) + for manager in self.rdzv_managers.values(): + ranks = manager.not_joined_rdzv_nodes() + if ranks: + self.job_manager.remove_not_joined_rdzv_workers(ranks) def stop(self): """ diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py index 29413fef5..3f5eeed99 100644 --- a/dlrover/python/master/node/job_manager.py +++ b/dlrover/python/master/node/job_manager.py @@ -650,8 +650,10 @@ def all_running_node_hanged(self): return all(node_hang) return False - def remove_not_participated_workers(self, workers): - plan = self._worker_manager.remove_not_participated_workers(workers) + def remove_not_joined_rdzv_workers(self, worker_ranks): + plan = self._worker_manager.remove_not_joined_rdzv_workers( + worker_ranks + ) self._scaler.scale(plan) def pend_without_workers(self): diff --git a/dlrover/python/master/node/worker.py b/dlrover/python/master/node/worker.py index 7af0d4354..ff192bbce 100644 --- a/dlrover/python/master/node/worker.py +++ b/dlrover/python/master/node/worker.py @@ -251,12 +251,16 @@ def migrate_workers(self, workers: Dict[str, NodeResource]): plan.remove_nodes.append(old_node) return plan - def remove_not_participated_workers(self, workers): - """Remove workers which do not participate in the training.""" + def remove_not_joined_rdzv_workers(self, worker_ranks: List[int]): + """Remove workers which do not participate in the training. + Args: + worker_ranks: The rank of worker which does not join rendezvous. 
+ """ plan = ScalePlan() - for worker_id, worker in self._nodes.items(): - if worker.name in workers: - p = self.remove_node(worker_id) + for node_id, node in self._nodes.items(): + if node.rank_index in worker_ranks: + p = self.remove_node(node.id) + self._nodes[node_id].relaunchable = False if p: plan.merge(p) return plan diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index ef5968ac4..940cb2207 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -395,8 +395,8 @@ def get_comm_world(self, request, _): group, nodes = rdzv_manager.get_comm_world(request.node_id) res = elastic_training_pb2.RendezvousState() res.group = group - for node_id, worker_num in nodes.items(): - res.world[node_id] = worker_num + for rank_id, worker_num in nodes.items(): + res.world[rank_id] = worker_num return res def join_rendezvous(self, request, _): @@ -409,8 +409,10 @@ def join_rendezvous(self, request, _): return res def num_nodes_waiting(self, request, _): - rdzv_manager = self._rdzv_managers[request.rdzv_name] - waiting_num = rdzv_manager.num_nodes_waiting() + waiting_num = 0 + for rdzv_manager in self._rdzv_managers.values(): + num = rdzv_manager.num_nodes_waiting() + waiting_num = max(num, waiting_num) res = elastic_training_pb2.RendezvousState() res.waiting_num = waiting_num return res @@ -421,6 +423,7 @@ def report_rdzv_params(self, request, _): min_nodes=request.min_nodes, max_ndoes=request.max_nodes, waiting_timeout=request.waiting_timeout, + node_unit=request.node_unit, ) res = elastic_training_pb2.Response() res.success = True diff --git a/dlrover/python/master/watcher/k8s_watcher.py b/dlrover/python/master/watcher/k8s_watcher.py index 52ae6902f..dab43a618 100644 --- a/dlrover/python/master/watcher/k8s_watcher.py +++ b/dlrover/python/master/watcher/k8s_watcher.py @@ -48,9 +48,7 @@ def _get_pod_exit_reason(pod): and pod.status.container_statuses[0].state.terminated ): terminated = pod.status.container_statuses[0].state.terminated - pod_name = pod.metadata.name exit_code = terminated.exit_code - logger.warning(f"Pod {pod_name} exits with exitcode {exit_code}") if terminated.reason == "OOMKilled" or exit_code == ExitCode.OOM_CODE: return NodeExitReason.OOM elif exit_code in [ExitCode.KILLED_CODE, ExitCode.TERMED_CODE]: diff --git a/dlrover/python/tests/test_rdzv_manager.py b/dlrover/python/tests/test_rdzv_manager.py index 3714c8e2c..a52021c8a 100644 --- a/dlrover/python/tests/test_rdzv_manager.py +++ b/dlrover/python/tests/test_rdzv_manager.py @@ -46,7 +46,7 @@ def test_kv_store_api(self): class ElasticTrainingRendezvousManagerTest(unittest.TestCase): def test_max_nodes(self): rdzv_manager = ElasticTrainingRendezvousManager() - rdzv_manager.update_rdzv_params(3, 3, 60) + rdzv_manager.update_rdzv_params(3, 3, 60, 1) rdzv_manager._alive_nodes = [0, 1, 2] rdzv_manager.join_rendezvous(0, 8) rdzv_manager.join_rendezvous(1, 8) @@ -60,7 +60,7 @@ def test_max_nodes(self): def test_min_nodes(self): rdzv_manager = ElasticTrainingRendezvousManager() - rdzv_manager.update_rdzv_params(2, 3, 0.1) + rdzv_manager.update_rdzv_params(2, 3, 0.1, 1) node_1 = Node("worker", 1) rdzv_manager.add_alive_node(node_1) node_0 = Node("worker", 0) @@ -79,11 +79,53 @@ def test_min_nodes(self): self.assertEqual(len(rdzv_manager._rdzv_nodes), 2) self.assertDictEqual(world, {0: 8, 1: 8}) + def test_min_nodes_with_unit(self): + rdzv_manager = ElasticTrainingRendezvousManager() + rdzv_manager.update_rdzv_params(8, 12, 0.1, 4) + for i in range(10): + node = 
Node("worker", i, name=f"worker-{i}") + rdzv_manager.add_alive_node(node) + rdzv_manager.join_rendezvous(i, 8) + self.assertEqual(len(rdzv_manager._alive_nodes), 10) + self.assertEqual(len(rdzv_manager._waiting_nodes), 10) + self.assertEqual(len(rdzv_manager._rdzv_nodes), 0) + time.sleep(0.2) + _, world = rdzv_manager.get_comm_world(1) + self.assertEqual(len(rdzv_manager._waiting_nodes), 0) + self.assertEqual(len(rdzv_manager._rdzv_nodes), 8) + expected_world = {i: 8 for i in range(8)} + self.assertDictEqual(expected_world, world) + _, world = rdzv_manager.get_comm_world(9) + self.assertDictEqual(world, {}) + + # Test the number of waiting nodes is less than the node unit. + rdzv_manager.join_rendezvous(10, 8) + rdzv_manager.join_rendezvous(11, 8) + num = rdzv_manager.num_nodes_waiting() + self.assertEqual(num, 0) + self.assertEqual(len(rdzv_manager._waiting_nodes), 2) + node_10 = Node("worker", 10, name="worker-10") + node_11 = Node("worker", 11, name="worker-11") + + # Test removing nodes from waiting nodes. + rdzv_manager.add_alive_node(node_10) + rdzv_manager.add_alive_node(node_11) + rdzv_manager.remove_alive_node(node_10) + rdzv_manager.remove_alive_node(node_11) + self.assertEqual(len(rdzv_manager._waiting_nodes), 0) + + # Test the number of waiting nodes is equal or + # bigger than the node unit. + for i in range(12, 16): + rdzv_manager.join_rendezvous(i, 8) + num = rdzv_manager.num_nodes_waiting() + self.assertEqual(num, 4) + class NcclCheckRendezvousManagerTest(unittest.TestCase): def test_network_check_rdzv(self): rdzv_manager = NetworkCheckRendezvousManager() - rdzv_manager.update_rdzv_params(4, 4, 60) + rdzv_manager.update_rdzv_params(4, 4, 60, 1) rdzv_manager._alive_nodes = [0, 1, 2, 3] for i in range(4): round = rdzv_manager.join_rendezvous(i, 8) diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index 7be167582..5c1a6113b 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -14,7 +14,7 @@ import uuid from typing import Callable, Union -from torch.distributed.argparse_util import check_env +from torch.distributed.argparse_util import check_env, env from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.launcher.api import LaunchConfig from torch.distributed.run import config_from_args, get_args_parser @@ -27,9 +27,19 @@ def parse_args(args): parser = get_args_parser() parser.add_argument( "--network-check", + "--network_check", action=check_env, help="Whether to check network before starting training process.", ) + parser.add_argument( + "--node_unit", + "--node-unit", + type=int, + action=env, + default=1, + help="The number unit of nodes to schedule. 
The scheduled number of " + "nodes should be a multiple of node_unit.", + ) return parser.parse_args(args) @@ -89,10 +99,12 @@ def run(args): ) config, cmd, cmd_args = config_from_args(args) + setattr(config, "network_check", False) + setattr(config, "node_unit", 1) if hasattr(args, "network_check"): config.network_check = args.network_check - else: - config.network_check = False + if hasattr(args, "node_unit"): + config.rdzv_configs["node_unit"] = args.node_unit elastic_launch( config=config, entrypoint=cmd, diff --git a/dlrover/trainer/torch/run_network_check.py b/dlrover/trainer/torch/run_network_check.py index a602dd523..943744a1f 100644 --- a/dlrover/trainer/torch/run_network_check.py +++ b/dlrover/trainer/torch/run_network_check.py @@ -43,13 +43,10 @@ def main(use_cuda): if __name__ == "__main__": - try: - use_cuda = torch.cuda.is_available() - if use_cuda: - dist.init_process_group("nccl", timeout=timedelta(seconds=180)) - else: - dist.init_process_group("gloo", timeout=timedelta(seconds=180)) - main(use_cuda) - finally: - dist.destroy_process_group() + use_cuda = torch.cuda.is_available() + if use_cuda: + dist.init_process_group("nccl", timeout=timedelta(seconds=180)) + else: + dist.init_process_group("gloo", timeout=timedelta(seconds=180)) + main(use_cuda) logger.info("Finish testing allgather.")
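
Added for illustration, not part of the patch: a minimal, standalone Python sketch of the node_unit selection rule that `RendezvousManager._check_rdzv_completed` applies above. The helper name `select_rdzv_participants` is hypothetical; it only mirrors the rounding-and-slicing logic so the behavior exercised by `test_min_nodes_with_unit` is easy to see in isolation: once at least `min_nodes` have joined and the waiting timeout has expired, the participant count is rounded down to a multiple of `node_unit` and the lowest ranks are admitted.

```python
# Hypothetical, self-contained sketch of the node_unit rounding used by
# RendezvousManager._check_rdzv_completed in this patch. It is not DLRover
# code; the waiting-timeout check is assumed to have been done by the caller.


def select_rdzv_participants(waiting_nodes, min_nodes, max_nodes, node_unit):
    """Return the admitted nodes as {rank: local_world_size}, or {} if the
    rendezvous cannot complete yet."""
    waiting_num = len(waiting_nodes)
    if waiting_num == max_nodes:
        admitted_num = waiting_num
    elif waiting_num >= min_nodes:
        # Round the participant count down to a multiple of node_unit.
        admitted_num = (waiting_num // node_unit) * node_unit
    else:
        return {}
    # Admit the nodes with the smallest ranks, as the manager does.
    ranks = sorted(waiting_nodes.keys())[:admitted_num]
    return {rank: waiting_nodes[rank] for rank in ranks}


if __name__ == "__main__":
    # Mirrors test_min_nodes_with_unit: 10 waiting nodes with min_nodes=8,
    # max_nodes=12 and node_unit=4 admit only ranks 0..7 into the world.
    waiting = {rank: 8 for rank in range(10)}
    world = select_rdzv_participants(
        waiting, min_nodes=8, max_nodes=12, node_unit=4
    )
    assert sorted(world) == list(range(8))
    print(world)
```

With node_unit=1, the default the agent reports when `--node_unit` is not set, the rounding is a no-op and every waiting node is admitted once min_nodes is reached, which matches the previous behavior.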