Optimize heartbeat timeout judgement if worker already succeeded. #1289

Merged
5 changes: 5 additions & 0 deletions dlrover/python/common/grpc.py
@@ -509,3 +509,8 @@ class ElasticRunConfigRequest(Message):
@dataclass
class ElasticRunConfig(Message):
    configs: Dict[str, str] = field(default_factory=dict)


@dataclass
class SucceededRequest(Message):
    pass
7 changes: 7 additions & 0 deletions dlrover/python/common/node.py
@@ -216,6 +216,7 @@ def __init__(
        self.migrated = False
        self.unrecoverable_failure_msg = ""
        self.heartbeat_time = 0
        self.succeeded = False

    def exited(self):
        return self.status in [
@@ -340,6 +341,12 @@ def timeout(self, timeout):
        ):
            return True

    def set_as_succeeded(self):
        self.succeeded = True

    def is_succeeded(self):
        return self.succeeded

    def __repr__(self):
        return (
            f"name:{self.name};"
5 changes: 5 additions & 0 deletions dlrover/python/elastic_agent/master_client.py
@@ -420,6 +420,11 @@ def get_elastic_run_config(self) -> Dict[str, str]:
        response: grpc.ElasticRunConfig = self._get(request)
        return response.configs

    def report_succeeded(self):
        request = grpc.SucceededRequest()
        response = self._report(request)
        return response.success

    @classmethod
    def singleton_instance(cls, *args, **kwargs):
        if not cls._instance:
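For reference, a minimal sketch of how a caller might use the new report_succeeded() API through the client singleton. The MasterClient class name and the shape of the response are assumptions based on the context lines above, not guaranteed by this diff.

    # Hypothetical usage sketch; MasterClient and singleton_instance() are
    # assumed from the surrounding context of master_client.py.
    from dlrover.python.elastic_agent.master_client import MasterClient

    def notify_master_of_success() -> bool:
        client = MasterClient.singleton_instance()
        # report_succeeded() sends a grpc.SucceededRequest via _report() and
        # returns the success flag from the master's response.
        return client.report_succeeded()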
2 changes: 2 additions & 0 deletions dlrover/python/elastic_agent/torch/training.py
@@ -850,6 +850,8 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
                    logger.info("Async saver stopped.")
                except Exception as e:
                    logger.warning(f"Unexpected exception when ending: {e}")
                finally:
                    self._client.report_succeeded()

                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
5 changes: 5 additions & 0 deletions dlrover/python/master/node/dist_job_manager.py
@@ -470,6 +470,7 @@ def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]:
                    and node.start_time
                    and node.create_time
                    and node.status == NodeStatus.RUNNING
                    and not node.is_succeeded()
                ):
                    if (
                        node.heartbeat_time <= node.start_time.timestamp()
@@ -1135,6 +1136,10 @@ def collect_node_heart_beat(self, node_type, node_id, timestamp):
    def update_node_required_info_callback(self):
        self._worker_manager.update_node_required_info(self._nodes_required)

    def update_succeeded_node(self, node_id, node_type):
        with self._lock:
            super().update_succeeded_node(node_id, node_type)


def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
    critical_worker_index = get_critical_worker_index(args)
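The practical effect of the new not node.is_succeeded() guard: a RUNNING node whose heartbeat has gone stale is no longer reported as dead once it has been marked succeeded. A minimal illustration of the fields the check consults (the Node constructor signature and import paths are assumptions, not shown in this diff):

    # Illustrative only: mirrors the attributes read by _get_dead_node_event().
    from datetime import datetime, timedelta

    from dlrover.python.common.constants import NodeStatus, NodeType  # paths assumed
    from dlrover.python.common.node import Node

    node = Node(NodeType.WORKER, 0)  # constructor arguments assumed
    node.status = NodeStatus.RUNNING
    now = datetime.now()
    node.create_time = now - timedelta(seconds=1400)
    node.start_time = now - timedelta(seconds=1200)
    node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp()

    # Such a node satisfies the stale-heartbeat condition, but after
    # set_as_succeeded() the added guard excludes it from dead-node events.
    node.set_as_succeeded()
    assert node.is_succeeded()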
8 changes: 8 additions & 0 deletions dlrover/python/master/node/job_manager.py
@@ -231,3 +231,11 @@ def update_node_required_info_callback(self):

    def get_elastic_run_configs(self) -> Dict[str, str]:
        return self._training_node_config.get_elastic_run_configs()

    def update_succeeded_node(self, node_id, node_type):
        if (
            node_type in self._job_nodes
            and node_id in self._job_nodes[node_type]
        ):
            logger.info(f"Node {node_id}({node_type}) to succeeded.")
            self._job_nodes[node_type][node_id].set_as_succeeded()
6 changes: 6 additions & 0 deletions dlrover/python/master/servicer.py
@@ -361,6 +361,8 @@ def report(self, request, _):
            success = self._sync_checkpoint(node_type, node_id, message)
        elif isinstance(message, grpc.DiagnosisReportData):
            success = self._report_worker_diagnosis_data(message)
        elif isinstance(message, grpc.SucceededRequest):
            success = self._report_succeeded(node_id, node_type)

        response.success = success
        return response
@@ -632,6 +634,10 @@ def _report_worker_diagnosis_data(self, message: grpc.DiagnosisReportData):
        self._diagnosis_manager.collect_diagnosis_data(data_obj)
        return True

    def _report_succeeded(self, node_id, node_type):
        self._job_manager.update_succeeded_node(node_id, node_type)
        return True

    def _sync_training_ports(
        self, node_id, message: grpc.SyncTrainingPort
    ) -> grpc.SyncTrainingPort:
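Taken together, the path added by this PR is: the agent calls report_succeeded() in the finally block of _invoke_run(), the master's report() handler routes the grpc.SucceededRequest to _report_succeeded(), and the job manager marks the node so heartbeat-timeout detection skips it. A condensed sketch of that dispatch branch, assuming grpc here is the module from dlrover/python/common/grpc.py:

    # Condensed sketch of the new branch in report(); not the full servicer.
    from dlrover.python.common import grpc

    def handle_report(job_manager, message, node_id, node_type) -> bool:
        if isinstance(message, grpc.SucceededRequest):
            # Flips Node.succeeded via JobManager.update_succeeded_node().
            job_manager.update_succeeded_node(node_id, node_type)
            return True
        return False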
31 changes: 31 additions & 0 deletions dlrover/python/tests/test_job_manager.py
@@ -354,6 +354,7 @@
            node.status = NodeStatus.RUNNING
        events = manager._get_dead_node_event()
        self.assertEqual(len(events), 0)

        for index, node in enumerate(
            manager._job_nodes[NodeType.WORKER].values()
        ):
@@ -373,6 +374,23 @@
        self.assertIsNotNone(nodes_time_info)
        self.assertEqual(len(nodes_time_info), 3)

        for index, node in enumerate(
            manager._job_nodes[NodeType.WORKER].values()
        ):
            node.status = NodeStatus.RUNNING
            now = datetime.now()
            node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp()
            if index == 0:
                node.create_time = now - timedelta(seconds=800)
                node.start_time = now - timedelta(seconds=600)
            else:
                if index == 1:
                    node.succeeded = True
                node.create_time = now - timedelta(seconds=1400)
                node.start_time = now - timedelta(seconds=1200)
        events = manager._get_dead_node_event()
        self.assertEqual(len(events), 1)

    def test_relaunch_training_master(self):
        params = MockK8sPSJobArgs()
        params.initilize()
@@ -736,3 +754,16 @@
        worker = job_manager._job_nodes[NodeType.WORKER][0]
        self.assertEqual(worker.paral_config, paral_config)
        job_manager.handle_training_failure(NodeType.WORKER, 3)

        try:
            self.assertFalse(
                job_manager._job_nodes[NodeType.WORKER][0].is_succeeded()
            )
            job_manager.update_succeeded_node(0, NodeType.WORKER)
            self.assertTrue(
                job_manager._job_nodes[NodeType.WORKER][0].is_succeeded()
            )
            job_manager.update_succeeded_node(5, NodeType.WORKER)
            job_manager.update_succeeded_node(0, "unknown")
        except Exception:
            self.fail()

Codecov / codecov/patch: added lines #L768 - L769 in dlrover/python/tests/test_job_manager.py were not covered by tests.
4 changes: 4 additions & 0 deletions dlrover/python/tests/test_node.py
@@ -46,3 +46,7 @@ def test_is_unrecoverable_failure(self):
        is_unrecoverable = node.is_unrecoverable_failure()
        self.assertEqual(is_unrecoverable, True)
        self.assertEqual("oom" in node.unrecoverable_failure_msg, True)

        self.assertFalse(node.is_succeeded())
        node.set_as_succeeded()
        self.assertTrue(node.is_succeeded())
4 changes: 4 additions & 0 deletions dlrover/python/tests/test_servicer.py
@@ -424,6 +424,10 @@ def test_report_worker_diagnosis_data(self):
        )
        self.assertTrue(self.servicer._report_worker_diagnosis_data(request))

    def test_report_succeeded(self):
        self.assertTrue(self.servicer._report_succeeded(0, NodeType.WORKER))
        self.assertTrue(self.servicer._report_succeeded(0, "test"))


class MasterServicerForRayTest(unittest.TestCase):
    def setUp(self) -> None: