Skip to content

Commit

Permalink
refactor due to reviewer's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andylizf committed Oct 14, 2024
1 parent 129bdbf commit 12ec5a4
Showing 1 changed file with 6 additions and 25 deletions.
31 changes: 6 additions & 25 deletions sky/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Dag:
"""Dag: a user application, represented as a DAG of Tasks.
This class allows users to define and manage directed acyclic graphs
(DAGs) of tasks, representing complex workflows or pipelines.
(DAGs) of tasks, representing complex workflows.
Examples:
>>> import sky
Expand All @@ -26,39 +26,20 @@ class Dag:
>>> task1 >> task2
"""

def __init__(
self,
name: Optional[str] = None,
tasks: Optional[List['task.Task']] = None,
dependencies: Optional[Dict[TaskOrName, Union[List[TaskOrName],
TaskOrName]]] = None
) -> None:
def __init__(self, name: Optional[str] = None) -> None:
"""Initialize a new DAG.
Args:
name: Optional name for the DAG.
tasks: Optional list of Task objects to add to the DAG.
dependencies: Optional dictionary specifying task dependencies.
Keys are dependent tasks, values are lists of tasks
they depend on.
"""

self.name = name
self.tasks: List['task.Task'] = []
self._task_name_lookup: Dict[str, 'task.Task'] = {}
self.dependencies: Dict['task.Task', Set['task.Task']] = {}

self.graph = nx.DiGraph()

# Add tasks
if tasks:
for task in tasks:
self.add(task)

# Add dependencies
if dependencies:
for dependent, deps in dependencies.items():
self.set_dependencies(dependent, deps)

def _get_task(self, task_or_name: TaskOrName) -> 'task.Task':
"""Get a task object from a task or its name.
Expand Down Expand Up @@ -90,12 +71,12 @@ def add(self, task: 'task.Task') -> None:
"""
if task in self.tasks:
raise ValueError(f'Task {task.name} already exists in the DAG.')
if task.name in self._task_name_lookup:
raise ValueError(
f'Task name "{task.name}" is already used in the DAG.')
self.graph.add_node(task)
self.tasks.append(task)
if task.name is not None:
if task.name in self._task_name_lookup:
raise ValueError(
f'Task name "{task.name}" is already used in the DAG.')
self._task_name_lookup[task.name] = task

def remove(self, task: Union['task.Task', str]) -> None:
Expand Down

0 comments on commit 12ec5a4

Please sign in to comment.