Merge pull request #1287 from samplise/diagnosis-agent-observe
Refactor diagnosis agent
samplise authored Oct 12, 2024
2 parents 2a1a3f5 + cf40a3c commit f10ba6c
Showing 9 changed files with 312 additions and 96 deletions.
22 changes: 22 additions & 0 deletions dlrover/python/diagnosis/common/diagnose_action.py
@@ -0,0 +1,22 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List


class DiagnoseAction:
def __init__(self):
self._actions: List[str] = []

def add_action(self, action: str):
self._actions.append(action)
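
For context, a minimal usage sketch of the new container; the action strings below are illustrative assumptions, not values defined in this PR:

from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction

# Accumulate diagnosis actions as plain strings; the values are hypothetical.
actions = DiagnoseAction()
actions.add_action("restart worker process")
actions.add_action("check xpu-timer metrics")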
10 changes: 4 additions & 6 deletions dlrover/python/diagnosis/common/inference_chain.py
@@ -20,17 +20,20 @@ class InferenceName:
END = "end"
TRAINING = "training"
NODE = "node"
WORKER = "worker"


class InferenceAttribute:
ISORNOT = "is_or_not"
IS = "is"
NOT = "not"
COLLECT = "collect"


class InferenceDescription:
HANG = "hang"
FAILURE = "failure"
METRICS = "metrics"


@dataclass
@@ -92,12 +95,7 @@ def combine_inferences(
) -> List[Inference]:
inferences = []
for inference2 in inferences2:
is_duplicate = False
for inference1 in inferences1:
if is_same_inference(inference1, inference2):
is_duplicate = True
break
if not is_duplicate:
if not is_inference_included(inferences1, inference2):
inferences.append(inference2)

for inference1 in inferences1:
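
The refactor above replaces the inline duplicate check with is_inference_included, which the diagnosis agent below also imports. A sketch of what that helper presumably does, assuming it wraps the same is_same_inference loop that was removed:

def is_inference_included(
    inferences: List[Inference], inference: Inference
) -> bool:
    # True if the candidate matches any existing inference; equivalent to
    # the loop deleted from combine_inferences above.
    return any(is_same_inference(inf, inference) for inf in inferences)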
21 changes: 21 additions & 0 deletions dlrover/python/diagnosis/inferencechain/coordinator.py
@@ -0,0 +1,21 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
from dlrover.python.diagnosis.common.inference_chain import Inference


def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction:
return DiagnoseAction()
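
coordinate_inferences is a placeholder for now: it ignores the observations and returns an empty DiagnoseAction. A hypothetical sketch of how a later version might map conclusions to actions (the string format is an assumption):

def coordinate_inferences_sketch(observations: List[Inference]) -> DiagnoseAction:
    action = DiagnoseAction()
    for ob in observations:
        # Purely illustrative: encode each conclusion as an action string.
        action.add_action(f"{ob.name}/{ob.attribution}/{ob.description}")
    return action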
65 changes: 65 additions & 0 deletions dlrover/python/diagnosis/inferencechain/inferenceoperator/metrics_collection_operator.py
@@ -0,0 +1,65 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from dlrover.python.common import env_utils
from dlrover.python.diagnosis.common.constants import DiagnosisDataType
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
from dlrover.python.diagnosis.common.inference_chain import (
Inference,
InferenceAttribute,
InferenceDescription,
InferenceName,
InferenceOperator,
)
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
XpuTimerMetricsCollector,
)
from dlrover.python.elastic_agent.master_client import MasterClient


class MetricsCollectionOperator(InferenceOperator):
"""
MetricsCollectionOperator is the operator to collect
worker diagnosis metrics.
"""

def __init__(self):
super().__init__(None)
self._xpu_timer_collector = XpuTimerMetricsCollector()
self._client = MasterClient.singleton_instance()

def is_compatible(self, inference: Inference) -> bool:
if (
inference.name == InferenceName.WORKER
and inference.attribution == InferenceAttribute.COLLECT
and inference.description == InferenceDescription.METRICS
):
return True
else:
return False

def infer(self, inferences: List[Inference]) -> List[Inference]:
xpu_timer_metric = self._xpu_timer_collector.collect_data()
if xpu_timer_metric:
agent_xpu_metric = WorkerTrainingMetric(
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
data_content=xpu_timer_metric,
node_id=env_utils.get_node_id(),
node_type=env_utils.get_node_type(),
node_rank=env_utils.get_node_rank(),
)
self._client.report_diagnosis_agent_metrics(agent_xpu_metric)

return []
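
Elsewhere in this PR the operator is driven through an InferenceChain (see DiagnosisAgent._observe below). A condensed sketch of that flow, assuming a master is reachable so MasterClient.singleton_instance() works:

from dlrover.python.diagnosis.inferencechain.inference_chain import InferenceChain

# The "collect worker metrics" problem this operator is compatible with,
# run through a chain holding only this operator; its infer() reports
# xpu-timer metrics to the master and returns an empty list.
problem = Inference(
    name=InferenceName.WORKER,
    attribution=InferenceAttribute.COLLECT,
    description=InferenceDescription.METRICS,
)
chain = InferenceChain([problem], [MetricsCollectionOperator()])
observations = chain.infer()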
dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py
@@ -14,7 +14,18 @@
from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_failure_node_operator import ( # noqa: E501
CheckFailureNodeOperator,
)
from dlrover.python.diagnosis.inferencechain.inferenceoperator.metrics_collection_operator import ( # noqa: E501
MetricsCollectionOperator,
)


def get_training_failure_operators():
return [CheckFailureNodeOperator()]


def get_worker_observe_operators():
return [MetricsCollectionOperator()]


def get_worker_diagnosis_operators():
return []
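
get_worker_observe_operators() now feeds the agent's observation chain, while get_worker_diagnosis_operators() still returns an empty list, so the diagnosis chain is a no-op. A hypothetical example of how a future operator would be registered (the class name is an assumption, not part of this PR):

# Hypothetical future registration so DiagnosisAgent._diagnose_observations
# has operators to run; SomeWorkerDiagnosisOperator does not exist in this PR.
# def get_worker_diagnosis_operators():
#     return [SomeWorkerDiagnosisOperator()]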
62 changes: 46 additions & 16 deletions dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py
@@ -15,11 +15,10 @@
import threading
import time
from datetime import datetime
from typing import Dict
from typing import Dict, List

from torch.distributed.elastic.multiprocessing.errors import ProcessFailure

from dlrover.python.common import env_utils
from dlrover.python.common.constants import TrainingExceptionLevel
from dlrover.python.common.error import ProcessError
from dlrover.python.common.log import default_logger as logger
@@ -28,25 +27,28 @@
from dlrover.python.diagnosis.common.constants import (
DiagnosisAction,
DiagnosisConstant,
DiagnosisDataType,
InferenceConfigKey,
)
from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
from dlrover.python.diagnosis.common.inference_chain import (
Inference,
InferenceAttribute,
InferenceDescription,
InferenceName,
combine_inferences,
is_inference_included,
)
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
XpuTimerMetricsCollector,
from dlrover.python.diagnosis.inferencechain.coordinator import (
coordinate_inferences,
)
from dlrover.python.diagnosis.inferencechain.inference_chain import (
InferenceChain,
)
from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( # noqa: E501
get_training_failure_operators,
get_worker_diagnosis_operators,
get_worker_observe_operators,
)
from dlrover.python.elastic_agent.master_client import MasterClient

@@ -56,8 +58,16 @@ def __init__(self, training_log_file: str, errors: str):
self._client = MasterClient.singleton_instance()
self._training_log_file = training_log_file
self._errors = errors
self._xpu_timer_metric_collector = XpuTimerMetricsCollector()
self._stopped = False
self._observe_problems: List[Inference] = [
Inference(
name=InferenceName.WORKER,
attribution=InferenceAttribute.COLLECT,
description=InferenceDescription.METRICS,
),
]
self._observe_operators = get_worker_observe_operators()
self._diagnosis_operators = get_worker_diagnosis_operators()

self.start()

@@ -81,23 +91,43 @@ def start(self):
def stop(self):
self._stopped = True

def _observe(self) -> List[Inference]:
observations: List[Inference] = []
for problem in self._observe_problems:
ic = InferenceChain([problem], self._observe_operators)
try:
infs = ic.infer()
if len(infs) > 0:
observations = combine_inferences(observations, infs)
except Exception as e:
logger.error(f"fail to observe problem {problem}: {e}")
return observations

def _diagnose_observations(
self, observations: List[Inference]
) -> DiagnoseAction:
conclusions: List[Inference] = []
for ob in observations:
ic = InferenceChain([ob], self._diagnosis_operators)
try:
infs = ic.infer()
if len(infs) > 0:
conclusions = combine_inferences(conclusions, infs)
except Exception as e:
logger.error(f"fail to diagnose observation {ob}: {e}")
return coordinate_inferences(conclusions)

def _periodically_diagnosis(self):
logger.info("Start periodically diagnosis...")
while True:
if self._stopped:
logger.info("Stop periodically diagnosis.")
break

xpu_timer_metric = self._xpu_timer_metric_collector.collect_data()
if xpu_timer_metric:
agent_xpu_metric = WorkerTrainingMetric(
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
data_content=xpu_timer_metric,
node_id=env_utils.get_node_id(),
node_type=env_utils.get_node_type(),
node_rank=env_utils.get_node_rank(),
)
self._report_metric_to_master(agent_xpu_metric)
observations = self._observe()
if len(observations) > 0:
logger.info(f"Observed problems: {observations}")
self._diagnose_observations(observations)

time.sleep(
DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS
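
With this refactor the periodic loop runs observe, then diagnose, then coordinate, instead of collecting xpu-timer metrics inline. A minimal usage sketch, assuming a reachable master; the log path is illustrative:

from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import DiagnosisAgent

# Constructing the agent starts the background diagnosis thread
# (__init__ calls start()); stop() ends the periodic loop.
agent = DiagnosisAgent(training_log_file="/tmp/training.log", errors="")
# ... training runs; the agent observes and reports metrics periodically ...
agent.stop()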
74 changes: 3 additions & 71 deletions dlrover/python/tests/test_diagnosis_agent.py
@@ -13,26 +13,15 @@

import os
import unittest
from unittest.mock import patch

from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.launcher.api import LaunchConfig

from dlrover.python.common import env_utils
from dlrover.python.common.constants import NodeEnv, NodeType, RendezvousName
from dlrover.python.common.constants import RendezvousName
from dlrover.python.common.worker import WorkerContext
from dlrover.python.diagnosis.common.constants import (
DiagnosisAction,
DiagnosisDataType,
EnvConfigKey,
)
from dlrover.python.diagnosis.common.constants import DiagnosisAction
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
from dlrover.python.diagnosis.datacollector.training_log_collector import (
TrainingLogCollector,
)
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
XpuTimerMetricsCollector,
)
from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import (
DiagnosisAgent,
)
Expand All @@ -49,7 +38,7 @@

class TestDiagnosisAgent(unittest.TestCase):
def setUp(self):
self.master_proc, self.addr = start_local_master()
self._master, self.addr = start_local_master()
MasterClient._instance = build_master_client(self.addr, 1)
launch_config = LaunchConfig(
min_nodes=1,
@@ -109,63 +98,6 @@ def test_diagnose_training(self):
action = agent.diagnose_training_failure(wc)
self.assertEqual(action, DiagnosisAction.RESTART_WORKER)

@patch(
"dlrover.python.diagnosis.datacollector.training_log_collector"
".read_last_n_lines"
)
def test_log_collect(self, mock_file_util):
mock_file_util.return_value = [
"test0",
"DLRover agent started with:",
"test1",
]
training_log_collector = TrainingLogCollector(
log_file="test", n_line=3
)
self.assertTrue(training_log_collector.is_enabled())
result = training_log_collector.collect_data()
self.assertTrue("test0" not in result.logs)
self.assertTrue("test1" in result.logs)

def test_xpu_timer_metric_collect(self):
collector = XpuTimerMetricsCollector()
self.assertFalse(collector.is_enabled())

env_utils.set_env(EnvConfigKey.XPU_TIMER_PORT, 18889)
collector = XpuTimerMetricsCollector()
self.assertTrue(collector.is_enabled())

self.assertEqual(collector.collect_data(), "")

file = "data/xpu_timer_metrics"
file_path = os.path.join(os.path.dirname(__file__), file)
with open(file_path, "r", encoding="utf-8") as file:
test_metrics = file.read()
result = collector._preprocess_metrics(test_metrics)
self.assertTrue(result)
if "#" in result or "exposer" in result:
self.fail()

env_utils.set_env(NodeEnv.NODE_ID, 1)
env_utils.set_env(NodeEnv.NODE_TYPE, NodeType.WORKER)
env_utils.set_env(NodeEnv.NODE_RANK, 1)
agent_xpu_metric = WorkerTrainingMetric(
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
data_content=result,
node_id=env_utils.get_node_id(),
node_type=env_utils.get_node_type(),
node_rank=env_utils.get_node_rank(),
)
self.assertEqual(
agent_xpu_metric.data_type,
DiagnosisDataType.XPU_TIMER_METRIC,
)
self.assertEqual(agent_xpu_metric.data_content, result)
self.assertEqual(agent_xpu_metric.node_id, 1)
self.assertEqual(agent_xpu_metric.node_type, NodeType.WORKER)
self.assertEqual(agent_xpu_metric.node_rank, 1)
self.assertTrue(agent_xpu_metric.timestamp > 0)

def test_worker_training_metric(self):
test = WorkerTrainingMetric(
data_content="test123",