diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py
index 7220230d1..a0c78a756 100644
--- a/dlrover/python/common/grpc.py
+++ b/dlrover/python/common/grpc.py
@@ -509,3 +509,8 @@ class ElasticRunConfigRequest(Message):
 @dataclass
 class ElasticRunConfig(Message):
     configs: Dict[str, str] = field(default_factory=dict)
+
+
+@dataclass
+class SucceededRequest(Message):
+    pass
diff --git a/dlrover/python/common/node.py b/dlrover/python/common/node.py
index 1b146650f..63fbc49ea 100644
--- a/dlrover/python/common/node.py
+++ b/dlrover/python/common/node.py
@@ -216,6 +216,7 @@ def __init__(
         self.migrated = False
         self.unrecoverable_failure_msg = ""
         self.heartbeat_time = 0
+        self.succeeded = False
 
     def exited(self):
         return self.status in [
@@ -340,6 +341,12 @@ def timeout(self, timeout):
         ):
             return True
 
+    def set_as_succeeded(self):
+        self.succeeded = True
+
+    def is_succeeded(self):
+        return self.succeeded
+
     def __repr__(self):
         return (
             f"name:{self.name};"
diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py
index e02554018..e646da113 100644
--- a/dlrover/python/elastic_agent/master_client.py
+++ b/dlrover/python/elastic_agent/master_client.py
@@ -420,6 +420,11 @@ def get_elastic_run_config(self) -> Dict[str, str]:
         response: grpc.ElasticRunConfig = self._get(request)
         return response.configs
 
+    def report_succeeded(self):
+        request = grpc.SucceededRequest()
+        response = self._report(request)
+        return response.success
+
     @classmethod
     def singleton_instance(cls, *args, **kwargs):
         if not cls._instance:
diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py
index 508ed96d4..29eebd7d4 100644
--- a/dlrover/python/elastic_agent/torch/training.py
+++ b/dlrover/python/elastic_agent/torch/training.py
@@ -850,6 +850,8 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
                     logger.info("Async saver stopped.")
                 except Exception as e:
                     logger.warning(f"Unexpected exception when ending: {e}")
+                finally:
+                    self._client.report_succeeded()
 
                 return run_result
             elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py
index 6cc00cea2..cb7151ace 100644
--- a/dlrover/python/master/node/dist_job_manager.py
+++ b/dlrover/python/master/node/dist_job_manager.py
@@ -470,6 +470,7 @@ def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]:
                 and node.start_time
                 and node.create_time
                 and node.status == NodeStatus.RUNNING
+                and not node.is_succeeded()
             ):
                 if (
                     node.heartbeat_time <= node.start_time.timestamp()
@@ -1135,6 +1136,10 @@ def collect_node_heart_beat(self, node_type, node_id, timestamp):
     def update_node_required_info_callback(self):
         self._worker_manager.update_node_required_info(self._nodes_required)
 
+    def update_succeeded_node(self, node_id, node_type):
+        with self._lock:
+            super().update_succeeded_node(node_id, node_type)
+
 
 def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
     critical_worker_index = get_critical_worker_index(args)
diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py
index 77916a1e7..acec71396 100644
--- a/dlrover/python/master/node/job_manager.py
+++ b/dlrover/python/master/node/job_manager.py
@@ -231,3 +231,11 @@ def update_node_required_info_callback(self):
 
     def get_elastic_run_configs(self) -> Dict[str, str]:
         return self._training_node_config.get_elastic_run_configs()
+
+    def update_succeeded_node(self, node_id, node_type):
+        if (
+            node_type in self._job_nodes
+            and node_id in self._job_nodes[node_type]
+        ):
+            logger.info(f"Node {node_id}({node_type}) is set to succeeded.")
+            self._job_nodes[node_type][node_id].set_as_succeeded()
diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py
index 2c08b5938..e8250dc17 100644
--- a/dlrover/python/master/servicer.py
+++ b/dlrover/python/master/servicer.py
@@ -361,6 +361,8 @@ def report(self, request, _):
             success = self._sync_checkpoint(node_type, node_id, message)
         elif isinstance(message, grpc.DiagnosisReportData):
             success = self._report_worker_diagnosis_data(message)
+        elif isinstance(message, grpc.SucceededRequest):
+            success = self._report_succeeded(node_id, node_type)
 
         response.success = success
         return response
@@ -632,6 +634,10 @@ def _report_worker_diagnosis_data(self, message: grpc.DiagnosisReportData):
         self._diagnosis_manager.collect_diagnosis_data(data_obj)
         return True
 
+    def _report_succeeded(self, node_id, node_type):
+        self._job_manager.update_succeeded_node(node_id, node_type)
+        return True
+
     def _sync_training_ports(
         self, node_id, message: grpc.SyncTrainingPort
     ) -> grpc.SyncTrainingPort:
diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py
index ac6840a2f..770ecde85 100644
--- a/dlrover/python/tests/test_job_manager.py
+++ b/dlrover/python/tests/test_job_manager.py
@@ -354,6 +354,7 @@ def test_get_dead_node_event(self):
             node.status = NodeStatus.RUNNING
         events = manager._get_dead_node_event()
         self.assertEqual(len(events), 0)
+
         for index, node in enumerate(
             manager._job_nodes[NodeType.WORKER].values()
         ):
@@ -373,6 +374,23 @@ def test_get_dead_node_event(self):
         self.assertIsNotNone(nodes_time_info)
         self.assertEqual(len(nodes_time_info), 3)
 
+        for index, node in enumerate(
+            manager._job_nodes[NodeType.WORKER].values()
+        ):
+            node.status = NodeStatus.RUNNING
+            now = datetime.now()
+            node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp()
+            if index == 0:
+                node.create_time = now - timedelta(seconds=800)
+                node.start_time = now - timedelta(seconds=600)
+            else:
+                if index == 1:
+                    node.succeeded = True
+                node.create_time = now - timedelta(seconds=1400)
+                node.start_time = now - timedelta(seconds=1200)
+        events = manager._get_dead_node_event()
+        self.assertEqual(len(events), 1)
+
     def test_relaunch_training_master(self):
         params = MockK8sPSJobArgs()
         params.initilize()
@@ -736,3 +754,16 @@ def test_local_job_manager(self):
         worker = job_manager._job_nodes[NodeType.WORKER][0]
         self.assertEqual(worker.paral_config, paral_config)
         job_manager.handle_training_failure(NodeType.WORKER, 3)
+
+        try:
+            self.assertFalse(
+                job_manager._job_nodes[NodeType.WORKER][0].is_succeeded()
+            )
+            job_manager.update_succeeded_node(0, NodeType.WORKER)
+            self.assertTrue(
+                job_manager._job_nodes[NodeType.WORKER][0].is_succeeded()
+            )
+            job_manager.update_succeeded_node(5, NodeType.WORKER)
+            job_manager.update_succeeded_node(0, "unknown")
+        except Exception:
+            self.fail()
diff --git a/dlrover/python/tests/test_node.py b/dlrover/python/tests/test_node.py
index 0139cc726..fcd1fc6bc 100644
--- a/dlrover/python/tests/test_node.py
+++ b/dlrover/python/tests/test_node.py
@@ -46,3 +46,7 @@ def test_is_unrecoverable_failure(self):
         is_unrecoverable = node.is_unrecoverable_failure()
         self.assertEqual(is_unrecoverable, True)
         self.assertEqual("oom" in node.unrecoverable_failure_msg, True)
+
+        self.assertFalse(node.is_succeeded())
+        node.set_as_succeeded()
+        self.assertTrue(node.is_succeeded())
diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py
index 37e109aa3..b30b6e999 100644
--- a/dlrover/python/tests/test_servicer.py
+++ b/dlrover/python/tests/test_servicer.py
@@ -424,6 +424,10 @@ def test_report_worker_diagnosis_data(self):
         )
         self.assertTrue(self.servicer._report_worker_diagnosis_data(request))
 
+    def test_report_succeeded(self):
+        self.assertTrue(self.servicer._report_succeeded(0, NodeType.WORKER))
+        self.assertTrue(self.servicer._report_succeeded(0, "test"))
+
 
 class MasterServicerForRayTest(unittest.TestCase):
     def setUp(self) -> None: