From 94106cd49d1dea5b1e2dec28ad13cbb5a1138251 Mon Sep 17 00:00:00 2001 From: Ma Jie Yue Date: Sat, 12 Oct 2024 16:53:50 +0800 Subject: [PATCH 1/5] add exception handler in _get_master_addr_port since the port might be null --- dlrover/python/elastic_agent/torch/training.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 508ed96d4..849d3ebb2 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -485,7 +485,7 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: spec.master_port, ) - master_addr, master_port = self._get_master_addr_port(store) + master_addr, master_port = self._safe_get_master_addr_port(store) # compatible with torch 2.4 if not version_less_than_240(): @@ -543,6 +543,17 @@ def _get_master_addr_port(self, store: Store) -> Tuple[str, int]: master_port = int(store.get("MASTER_PORT").decode(encoding="UTF-8")) return (master_addr, master_port) + def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]: + for i in range(1,5): + try: + addr, port = self._get_master_addr_port(store) + return (addr, port) + except Exception as e: + logger.warning(f"_get_master_addr_port failed with exception {e}, will try again") + time.sleep(10) + + raise ValueError("invalid value in _get_master_addr_port") + def _get_socket_with_port(self) -> socket.socket: """Return a free port on localhost. From cf73308c66773e5c3ecd210583ad1dbc3ce7bd7c Mon Sep 17 00:00:00 2001 From: Ma JieYue Date: Sat, 12 Oct 2024 09:11:29 +0000 Subject: [PATCH 2/5] precommit fix --- dlrover/python/elastic_agent/torch/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 849d3ebb2..e0fe7fd9f 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -544,12 +544,14 @@ def _get_master_addr_port(self, store: Store) -> Tuple[str, int]: return (master_addr, master_port) def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]: - for i in range(1,5): + for i in range(1, 5): try: addr, port = self._get_master_addr_port(store) return (addr, port) except Exception as e: - logger.warning(f"_get_master_addr_port failed with exception {e}, will try again") + logger.warning( + f"_get_master_addr_port failed with exception {e}, will try again" + ) time.sleep(10) raise ValueError("invalid value in _get_master_addr_port") From 0df1695e049f51899eed2f519d4d779c7426e9a4 Mon Sep 17 00:00:00 2001 From: Ma Jie Yue Date: Sat, 12 Oct 2024 17:16:48 +0800 Subject: [PATCH 3/5] meet precommit check --- dlrover/python/elastic_agent/torch/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index e0fe7fd9f..c8c1c622e 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -550,7 +550,7 @@ def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]: return (addr, port) except Exception as e: logger.warning( - f"_get_master_addr_port failed with exception {e}, will try again" + f"_get_master_addr_port failed with exception {e}" ) time.sleep(10) From 260b7358edcb2dca3de8a41556d0027413621d85 Mon Sep 17 00:00:00 2001 From: Ma Jie Yue Date: Wed, 16 Oct 2024 09:19:59 +0800 Subject: [PATCH 4/5] change some code style and add UT case for _safe_get_master_addr_port --- .../python/elastic_agent/torch/training.py | 5 ++-- .../tests/test_elastic_training_agent.py | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index c8c1c622e..5d80a3a3e 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -544,10 +544,9 @@ def _get_master_addr_port(self, store: Store) -> Tuple[str, int]: return (master_addr, master_port) def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]: - for i in range(1, 5): + for _ in range(5): try: - addr, port = self._get_master_addr_port(store) - return (addr, port) + return self._get_master_addr_port(store) except Exception as e: logger.warning( f"_get_master_addr_port failed with exception {e}" diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index afcd03ffb..8215e42fa 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -17,6 +17,7 @@ import socket import tempfile import time +import threading import unittest from unittest import mock from unittest.mock import patch @@ -137,7 +138,7 @@ def test_auto_configure(self): config.auto_configure_params() self.assertEqual(config.failure_node_errors, "") - def test_rank0_rendzevous(self): + def test_rank0_rendezvous(self): agent = ElasticTrainingAgent( node_rank=0, config=self.config, @@ -166,7 +167,7 @@ def test_rank0_rendzevous(self): agent._membership_changed("default", self.rdzv_handler) ) - def test_rank1_rendzevous(self): + def test_rank1_rendezvous(self): agent = ElasticTrainingAgent( node_rank=1, config=self.config, @@ -180,9 +181,22 @@ def test_rank1_rendzevous(self): self.rdzv_handler._client.join_rendezvous( 0, 8, self.rdzv_handler._name ) + store = self.rdzv_handler._get_store(round=1, group=0) - store.set("MASTER_ADDR", "127.0.0.1".encode()) - store.set("MASTER_PORT", "12345".encode()) + + def _set_store(store): + time.sleep(5) + store.set("MASTER_ADDR", "127.0.0.1".encode()) + store.set("MASTER_PORT", "12345".encode()) + + _task = threading.Thread(target=_set_store, args=(store,)) + _task.start() + + addr, port = agent._safe_get_master_addr_port(store) + print(addr) + print(port) + self.assertEqual(addr, "127.0.0.1") + self.assertEqual(port, 12345) # Set the node id and rank as 1. agent._client._node_id = 1 @@ -196,6 +210,8 @@ def test_rank1_rendzevous(self): self.assertEqual(worker.local_rank, 1) self.assertEqual(worker.global_rank, 9) self.assertEqual(worker.world_size, 16) + self.assertEqual(store.get("MASTER_ADDR").decode(), "127.0.0.1") + self.assertEqual(store.get("MASTER_PORT").decode(), "12345") def test_get_local_ip(self): local_ip = _get_local_ip() From b466f012a2fcb64a343dd28287367a3e3cb9be4f Mon Sep 17 00:00:00 2001 From: Ma JieYue Date: Wed, 16 Oct 2024 06:49:51 +0000 Subject: [PATCH 5/5] pass pre-commit run -a --- dlrover/python/tests/test_elastic_training_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 8215e42fa..00404c55d 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -16,8 +16,8 @@ import shutil import socket import tempfile -import time import threading +import time import unittest from unittest import mock from unittest.mock import patch