Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/Enhance node management. #1301

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,20 +556,22 @@

def _process_list_nodes(self, nodes: List[Node]):
"""Callback with node list by the list api of k8s."""
if not nodes:
return

exist_nodes: Dict[str, List[int]] = {}
for node_type in self._job_nodes.keys():
exist_nodes[node_type] = []
for node in nodes:
exist_nodes[node.type].append(node.id)
if node.status == NodeStatus.DELETED:
type = NodeEventType.DELETED
else:
type = NodeEventType.MODIFIED
# Mock event to avoid missing events
event = NodeEvent(type, node)
self._process_event(event)

if nodes:
for node in nodes:
exist_nodes[node.type].append(node.id)
if node.status == NodeStatus.DELETED:
event_type = NodeEventType.DELETED

Check warning on line 568 in dlrover/python/master/node/dist_job_manager.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/node/dist_job_manager.py#L568

Added line #L568 was not covered by tests
else:
event_type = NodeEventType.MODIFIED
# Mock event to avoid missing events
event = NodeEvent(event_type, node)
self._process_event(event)
logger.debug(f"Got list nodes: {exist_nodes}")

for node_type in self._job_nodes.keys():
# Avoid dictionary keys changed during iteration
Expand All @@ -581,9 +583,8 @@
and node.id not in exist_nodes[node_type]
):
logger.info(
"Node %s %s is deleted without the event",
node_type,
node.id,
f"Node {node_type} {node.id} is deleted "
"without the event"
)
node.is_released = True
new_node = copy.deepcopy(node)
Expand All @@ -594,9 +595,9 @@
def close_job(self):
plan = ScalePlan()
ps_resource = NodeGroupResource.new_empty()
worker_reource = NodeGroupResource.new_empty()
worker_resource = NodeGroupResource.new_empty()

Check warning on line 598 in dlrover/python/master/node/dist_job_manager.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/node/dist_job_manager.py#L598

Added line #L598 was not covered by tests
plan.node_group_resources = {
"worker": worker_reource,
"worker": worker_resource,
"ps": ps_resource,
}
self._scaler.scale(plan=plan)
Expand Down Expand Up @@ -633,6 +634,7 @@
and len(pods.items) > 0
and any(
pod.status.phase == NodeStatus.RUNNING
and not pod.metadata.deletion_timestamp
for pod in pods.items
)
):
Expand Down
34 changes: 33 additions & 1 deletion dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import patch

from kubernetes import client

Expand All @@ -41,7 +42,10 @@
from dlrover.python.master.dist_master import DistributedJobMaster
from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.dist_job_manager import create_job_manager
from dlrover.python.master.node.dist_job_manager import (
DistributedJobManager,
create_job_manager,
)
from dlrover.python.master.node.event_callback import (
ClusterContext,
TaskRescheduleCallback,
Expand Down Expand Up @@ -427,6 +431,34 @@ def test_process_list_nodes(self):
ps_ids = list(manager._job_nodes[NodeType.PS].keys())
self.assertListEqual(ps_ids, [0, 1, 2])

@patch.object(DistributedJobManager, "_process_event")
def test_process_list_nodes_for_empty_case(self, mock_method):
    """Check `_process_list_nodes` when the listing comes back empty.

    The manager is seeded with one RUNNING PS node and one RUNNING
    worker node, then handed an empty node list. `_process_event` is
    replaced by a mock, and the test asserts it was invoked exactly
    twice.
    NOTE(review): presumably one call per tracked node that is absent
    from the listing -- confirm against `_process_list_nodes`.
    """

    def _running_node(kind, index):
        # Both seeded nodes share everything except their type and id.
        return Node(
            node_type=kind,
            node_id=index,
            status=NodeStatus.RUNNING,
            config_resource=NodeResource(1, 4096),
            max_relaunch_count=1,
        )

    job_args = MockK8sPSJobArgs()
    job_args.initilize()
    manager = create_job_manager(job_args, SpeedMonitor())

    # Overwrite the tracked nodes so the expected call count is fixed.
    manager._job_nodes = {
        NodeType.PS: {0: _running_node(NodeType.PS, 0)},
        NodeType.WORKER: {1: _running_node(NodeType.WORKER, 1)},
    }

    manager._process_list_nodes([])

    self.assertEqual(mock_method.call_count, 2)

def test_create_allreduce_job_manager(self):
params = MockK8sPSJobArgs()
params.initilize()
Expand Down
Loading