diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py index 66d021218c12..3b7d357bd5e8 100644 --- a/airflow/decorators/assets.py +++ b/airflow/decorators/assets.py @@ -17,8 +17,9 @@ from __future__ import annotations +import inspect import types -import typing +from typing import TYPE_CHECKING, Any, Iterator, Mapping import attrs @@ -26,10 +27,41 @@ from airflow.models.dag import DAG, ScheduleArg from airflow.operators.python import PythonOperator -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from airflow.io.path import ObjectStoragePath +@attrs.define(kw_only=True) +class AssetRef: + """Reference to an asset.""" + + name: str + + +class _AssetMainOperator(PythonOperator): + def __init__(self, *, definition_name: str, **kwargs) -> None: + super().__init__(**kwargs) + self._definition_name = definition_name + + def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: + for key in inspect.signature(self.python_callable).parameters: + if key == "self": + value: Any = AssetRef(name=self._definition_name) + elif key == "context": + value = context + else: + # TODO: This does not check if the upstream asset actually + # exists. Should we do a second pass in the DAG processor to + # raise parse-time errors if a non-existent asset is referenced? + # How? Should we also fail the task at runtime? Or should the + # dangling reference simply do nothing? + value = AssetRef(name=key) + yield key, value + + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: + return dict(self._iter_kwargs(context)) + + @attrs.define(kw_only=True) class AssetDefinition: """ @@ -44,8 +76,16 @@ class AssetDefinition: schedule: ScheduleArg def __attrs_post_init__(self) -> None: + parameters = inspect.signature(self.function).parameters with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True) as dag: - PythonOperator(task_id="__main__", outlets=[self.asset], python_callable=self.function) + _AssetMainOperator( + task_id="__main__", + # TODO: This should use the name argument instead. + inlets=[Asset(uri=k) for k in parameters if k not in ("self", "context")], + outlets=[self.asset], + python_callable=self.function, + definition_name=self.name, + ) # TODO: Currently this just gets serialized into a string. # When we create UI for assets, we should add logic to serde so the # serialized DAG contains appropriate asset information. @@ -58,11 +98,13 @@ class asset: schedule: ScheduleArg uri: str | ObjectStoragePath | None - extra: dict[str, typing.Any] = attrs.field(factory=dict) + extra: dict[str, Any] = attrs.field(factory=dict) def __call__(self, f: types.FunctionType) -> AssetDefinition: if (name := f.__name__) != f.__qualname__: raise ValueError("nested function not supported") + if name == "self" or name == "context": + raise ValueError(f"prohibited name for asset: {name}") return AssetDefinition( name=name, asset=Asset(