Skip to content

Commit

Permalink
Merge pull request #1293 from majieyue/fix-invalid-master-port
Browse files Browse the repository at this point in the history
add exception handler in _get_master_addr_port since the port might b…
  • Loading branch information
majieyue authored Oct 21, 2024
2 parents e88d7eb + b466f01 commit fbd0bae
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
14 changes: 13 additions & 1 deletion dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None:
spec.master_port,
)

master_addr, master_port = self._get_master_addr_port(store)
master_addr, master_port = self._safe_get_master_addr_port(store)

# compatible with torch 2.4
if not version_less_than_240():
Expand Down Expand Up @@ -543,6 +543,18 @@ def _get_master_addr_port(self, store: Store) -> Tuple[str, int]:
master_port = int(store.get("MASTER_PORT").decode(encoding="UTF-8"))
return (master_addr, master_port)

def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]:
for _ in range(5):
try:
return self._get_master_addr_port(store)
except Exception as e:
logger.warning(
f"_get_master_addr_port failed with exception {e}"
)
time.sleep(10)

raise ValueError("invalid value in _get_master_addr_port")

def _get_socket_with_port(self) -> socket.socket:
"""Return a free port on localhost.
Expand Down
24 changes: 20 additions & 4 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import shutil
import socket
import tempfile
import threading
import time
import unittest
from unittest import mock
Expand Down Expand Up @@ -137,7 +138,7 @@ def test_auto_configure(self):
config.auto_configure_params()
self.assertEqual(config.failure_node_errors, "")

def test_rank0_rendzevous(self):
def test_rank0_rendezvous(self):
agent = ElasticTrainingAgent(
node_rank=0,
config=self.config,
Expand Down Expand Up @@ -166,7 +167,7 @@ def test_rank0_rendzevous(self):
agent._membership_changed("default", self.rdzv_handler)
)

def test_rank1_rendzevous(self):
def test_rank1_rendezvous(self):
agent = ElasticTrainingAgent(
node_rank=1,
config=self.config,
Expand All @@ -180,9 +181,22 @@ def test_rank1_rendzevous(self):
self.rdzv_handler._client.join_rendezvous(
0, 8, self.rdzv_handler._name
)

store = self.rdzv_handler._get_store(round=1, group=0)
store.set("MASTER_ADDR", "127.0.0.1".encode())
store.set("MASTER_PORT", "12345".encode())

def _set_store(store):
time.sleep(5)
store.set("MASTER_ADDR", "127.0.0.1".encode())
store.set("MASTER_PORT", "12345".encode())

_task = threading.Thread(target=_set_store, args=(store,))
_task.start()

addr, port = agent._safe_get_master_addr_port(store)
print(addr)
print(port)
self.assertEqual(addr, "127.0.0.1")
self.assertEqual(port, 12345)

# Set the node id and rank as 1.
agent._client._node_id = 1
Expand All @@ -196,6 +210,8 @@ def test_rank1_rendzevous(self):
self.assertEqual(worker.local_rank, 1)
self.assertEqual(worker.global_rank, 9)
self.assertEqual(worker.world_size, 16)
self.assertEqual(store.get("MASTER_ADDR").decode(), "127.0.0.1")
self.assertEqual(store.get("MASTER_PORT").decode(), "12345")

def test_get_local_ip(self):
local_ip = _get_local_ip()
Expand Down

0 comments on commit fbd0bae

Please sign in to comment.