Skip to content

Commit

Permalink
Scale down nodes with the number unit if not enough nodes. (#490)
Browse files Browse the repository at this point in the history
* Scale down nodes with a number unit if there are not enough nodes

* Rename worker_num_unit to node_unit

* Fix test cases

* The manager notifies workers to restart processes only when the number of nodes is a multiple of the node unit

* Build rendezvous only when the number of new nodes is bigger than the node unit
  • Loading branch information
workingloong authored Jul 12, 2023
1 parent 0a41611 commit 2b3e886
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 226 deletions.
1 change: 1 addition & 0 deletions dlrover/proto/elastic_training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ message RendezvousParams {
int32 min_nodes = 1;
int32 max_nodes = 2;
int32 waiting_timeout = 3;
int32 node_unit = 4;
}

message KeyValuePair {
Expand Down
9 changes: 7 additions & 2 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,14 @@ def network_check_success(self, timeout=180):
return response.success

@retry_grpc_request
def report_rdzv_params(self, min_nodes, max_nodes, waiting_timeout):
def report_rdzv_params(
    self, min_nodes, max_nodes, waiting_timeout, node_unit
):
    """Send the job's rendezvous parameters to the master over gRPC.

    Args:
        min_nodes: minimum number of nodes required for a rendezvous.
        max_nodes: maximum number of nodes allowed in a rendezvous.
        waiting_timeout: timeout (seconds) to wait before completing
            a rendezvous.
        node_unit: node-count granularity; presumably the master keeps
            the participating node count a multiple of this — confirm
            against the master-side handler.

    Returns:
        bool: whether the master accepted the parameters.
    """
    params = elastic_training_pb2.RendezvousParams()
    # Populate the protobuf fields from the keyword values.
    for field_name, value in (
        ("min_nodes", min_nodes),
        ("max_nodes", max_nodes),
        ("waiting_timeout", waiting_timeout),
        ("node_unit", node_unit),
    ):
        setattr(params, field_name, value)
    reply = self._stub.report_rdzv_params(params)
    return reply.success

Expand Down Expand Up @@ -582,7 +585,9 @@ def network_check_success(self, node_id):
def report_node_status(self, normal):
    """Local (master-less) stub: accept any node status report.

    The *normal* flag is ignored; reporting always "succeeds".
    """
    return True

def report_rdzv_params(self, min_nodes, max_nodes, waiting_timeout):
def report_rdzv_params(
    self, min_nodes, max_nodes, waiting_timeout, node_unit
):
    """Local (master-less) stub: accept any rendezvous parameters.

    All arguments are ignored; reporting always "succeeds".
    """
    return True

def kv_store_set(self, key, value):
Expand Down
16 changes: 4 additions & 12 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ def __init__(self, name, rank_id, rdzv_params: RendezvousParameters):
self._client = GlobalMasterClient.MASTER_CLIENT
self._store = MasterKVStore(self._name, timedelta(seconds=60))
lastcall_timeout = int(rdzv_params.get("lastcall_timeout", 60))
node_unit = int(rdzv_params.get("node_unit", "1"))
if self._rank_id == 0:
self._client.report_rdzv_params(
rdzv_params.min_nodes,
rdzv_params.max_nodes,
lastcall_timeout,
node_unit,
)

def get_backend(self) -> str:
Expand Down Expand Up @@ -318,7 +320,7 @@ def _assign_worker_ranks(
return workers

def _initialize_workers(self, worker_group):
if self._config.network_check and self._restart_count == 0:
if self._config.network_check:
run_network_check(self._config, self._entrypoint)
super()._initialize_workers(worker_group)

Expand Down Expand Up @@ -359,8 +361,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
self._report_failure_to_master(run_result.failures)
has_fatal_error = self._has_fatal_error(run_result)
if not has_fatal_error and self._remaining_failovers > 0:
if self._remaining_failovers > 0:
logger.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_failovers}/{spec.max_restarts}"
Expand All @@ -369,7 +370,6 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
self._remaining_failovers -= 1
self._restart_workers(self._worker_group)
else:
logger.info("Cannot restart workers with fatal error.")
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
return run_result
Expand All @@ -380,14 +380,6 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
else:
raise Exception(f"[{role}] Worker group in {state.name} state")

def _has_fatal_error(self, run_result: RunResult):
"""The error with exitcode 1 is the Python exception and we cannot
recover it by restarting workers."""
for pfailure in run_result.failures.values():
if pfailure.exitcode == 1:
return True
return False

def _report_failure_to_master(self, failures: Dict[int, ProcessFailure]):
errors = {}
for rank, failure in failures.items():
Expand Down
Loading

0 comments on commit 2b3e886

Please sign in to comment.