Skip to content

Commit

Permalink
Basic inlet dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored and Lee-W committed Oct 22, 2024
1 parent 5c2762f commit bbb0cf5
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions airflow/decorators/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,51 @@

from __future__ import annotations

import inspect
import types
import typing
from typing import TYPE_CHECKING, Any, Iterator, Mapping

import attrs

from airflow.datasets import Dataset as Asset
from airflow.models.dag import DAG, ScheduleArg
from airflow.operators.python import PythonOperator

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath


@attrs.define(kw_only=True)
class AssetRef:
    """
    Reference to an asset, identified only by its name.

    Instances must be created with keyword arguments, e.g.
    ``AssetRef(name="my_asset")``.
    """

    # Name of the asset this reference points to.
    name: str


class _AssetMainOperator(PythonOperator):
    """
    Python operator whose callable arguments are derived from its signature.

    Each parameter name of ``python_callable`` is mapped to a value:

    * ``self`` receives an ``AssetRef`` named after the asset definition.
    * ``context`` receives the context mapping passed in.
    * Any other name receives an ``AssetRef`` carrying that parameter name.
    """

    def __init__(self, *, definition_name: str, **kwargs) -> None:
        """
        Create the operator.

        :param definition_name: Name used for the ``AssetRef`` supplied to a
            ``self`` parameter of the callable.
        :param kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        self._definition_name = definition_name

    def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]:
        """Yield ``(parameter name, value)`` pairs for the wrapped callable."""
        signature = inspect.signature(self.python_callable)
        for param_name in signature.parameters:
            if param_name == "context":
                yield param_name, context
                continue
            # TODO: This does not check if the upstream asset actually
            # exists. Should we do a second pass in the DAG processor to
            # raise parse-time errors if a non-existent asset is referenced?
            # How? Should we also fail the task at runtime? Or should the
            # dangling reference simply do nothing?
            ref_name = self._definition_name if param_name == "self" else param_name
            yield param_name, AssetRef(name=ref_name)

    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Build the keyword arguments to call ``python_callable`` with.

        :param context: Mapping handed to a ``context`` parameter, if any.
        :return: A new dict keyed by the callable's parameter names.
        """
        return {param_name: value for param_name, value in self._iter_kwargs(context)}


@attrs.define(kw_only=True)
class AssetDefinition:
"""
Expand All @@ -44,8 +76,16 @@ class AssetDefinition:
schedule: ScheduleArg

def __attrs_post_init__(self) -> None:
parameters = inspect.signature(self.function).parameters
with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True) as dag:
PythonOperator(task_id="__main__", outlets=[self.asset], python_callable=self.function)
_AssetMainOperator(
task_id="__main__",
# TODO: This should use the name argument instead.
inlets=[Asset(uri=k) for k in parameters if k not in ("self", "context")],
outlets=[self.asset],
python_callable=self.function,
definition_name=self.name,
)
# TODO: Currently this just gets serialized into a string.
# When we create UI for assets, we should add logic to serde so the
# serialized DAG contains appropriate asset information.
Expand All @@ -58,11 +98,13 @@ class asset:

schedule: ScheduleArg
uri: str | ObjectStoragePath | None
extra: dict[str, typing.Any] = attrs.field(factory=dict)
extra: dict[str, Any] = attrs.field(factory=dict)

def __call__(self, f: types.FunctionType) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")
if name == "self" or name == "context":
raise ValueError(f"prohibited name for asset: {name}")
return AssetDefinition(
name=name,
asset=Asset(
Expand Down

0 comments on commit bbb0cf5

Please sign in to comment.