This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new da3bdbf26c3 Implement asset.multi (#44711)
da3bdbf26c3 is described below
commit da3bdbf26c39cd9dd65b824e2dfac787124a4048
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 6 22:22:26 2024 +0800
Implement asset.multi (#44711)
This allows a single function to emit multiple assets. In this case, you are
responsible for giving each asset a proper name yourself, but it works.
Also includes a refactoring of the existing decorator mechanism so we
don't need to repeat code (especially the DAG-like arguments shared by
``@asset`` and ``@asset.multi``).
---
.../airflow/sdk/definitions/asset/decorators.py | 140 +++++++++++++++------
task_sdk/tests/defintions/test_asset_decorators.py | 78 +++++++++---
2 files changed, 164 insertions(+), 54 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
index 45b6686d059..1cb1ea4e316 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -18,27 +18,43 @@
from __future__ import annotations
import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
import attrs
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetRef
+from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset
if TYPE_CHECKING:
- from collections.abc import Collection, Iterator, Mapping
+ from collections.abc import Callable, Collection, Iterator, Mapping
from airflow.io.path import ObjectStoragePath
- from airflow.models.dag import DagStateChangeCallback, ScheduleArg
from airflow.models.param import ParamsDict
+ from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
+ from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback,
ScheduleArg
+ from airflow.serialization.dag_dependency import DagDependency
from airflow.triggers.base import BaseTrigger
+ from airflow.typing_compat import Self
class _AssetMainOperator(PythonOperator):
def __init__(self, *, definition_name: str, uri: str | None = None,
**kwargs) -> None:
super().__init__(**kwargs)
self._definition_name = definition_name
- self._uri = uri
+
+ @classmethod
+ def from_definition(cls, definition: AssetDefinition |
MultiAssetDefinition) -> Self:
+ return cls(
+ task_id="__main__",
+ inlets=[
+ AssetRef(name=inlet_asset_name)
+ for inlet_asset_name in
inspect.signature(definition._function).parameters
+ if inlet_asset_name not in ("self", "context")
+ ],
+ outlets=[v for _, v in definition.iter_assets()],
+ python_callable=definition._function,
+ definition_name=definition._function.__name__,
+ )
def _iter_kwargs(
self, context: Mapping[str, Any], active_assets: dict[str, Asset]
@@ -81,41 +97,53 @@ class AssetDefinition(Asset):
_source: asset
def __attrs_post_init__(self) -> None:
- from airflow.models.dag import DAG
-
- with DAG(
- dag_id=self.name,
- schedule=self._source.schedule,
- is_paused_upon_creation=self._source.is_paused_upon_creation,
- dag_display_name=self._source.display_name or self.name,
- description=self._source.description,
- params=self._source.params,
- on_success_callback=self._source.on_success_callback,
- on_failure_callback=self._source.on_failure_callback,
- auto_register=True,
- ):
- _AssetMainOperator(
- task_id="__main__",
- inlets=[
- AssetRef(name=inlet_asset_name)
- for inlet_asset_name in
inspect.signature(self._function).parameters
- if inlet_asset_name not in ("self", "context")
- ],
- outlets=[self],
- python_callable=self._function,
- definition_name=self.name,
- uri=self.uri,
- )
+ with self._source.create_dag(dag_id=self.name):
+ _AssetMainOperator.from_definition(self)
@attrs.define(kw_only=True)
-class asset:
- """Create an asset by decorating a materialization function."""
+class MultiAssetDefinition(BaseAsset):
+ """
+ Representation from decorating a function with ``@asset.multi``.
+
+ This is implemented as an "asset-like" object that can be used in all
places
+ that accept asset-ish things (e.g. normal assets, aliases, AssetAll,
+ AssetAny).
+
+ :meta private:
+ """
+
+ _function: Callable
+ _source: asset.multi
+
+ def __attrs_post_init__(self) -> None:
+ with self._source.create_dag(dag_id=self._function.__name__):
+ _AssetMainOperator.from_definition(self)
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return all(o.evaluate(statuses=statuses) for o in self._source.outlets)
+
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
+ for o in self._source.outlets:
+ yield from o.iter_assets()
+
+ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+ for o in self._source.outlets:
+ yield from o.iter_asset_aliases()
+
+ def iter_dag_dependencies(self, *, source: str, target: str) ->
Iterator[DagDependency]:
+ for obj in self._source.outlets:
+ yield from obj.iter_dag_dependencies(source=source, target=target)
- uri: str | ObjectStoragePath | None = None
- group: str = Asset.asset_type
- extra: dict[str, Any] = attrs.field(factory=dict)
- watchers: list[BaseTrigger] = attrs.field(factory=list)
+
[email protected](kw_only=True)
+class _DAGFactory:
+ """
+ Common class for things that take DAG-like arguments.
+
+ This exists so we don't need to define these arguments separately for
+ ``@asset`` and ``@asset.multi``.
+ """
schedule: ScheduleArg
is_paused_upon_creation: bool | None = None
@@ -130,6 +158,46 @@ class asset:
access_control: dict[str, dict[str, Collection[str]]] | None = None
owner_links: dict[str, str] | None = None
+ def create_dag(self, *, dag_id: str) -> DAG:
+ from airflow.models.dag import DAG # TODO: Use the SDK DAG when it
works.
+
+ return DAG(
+ dag_id=dag_id,
+ schedule=self.schedule,
+ is_paused_upon_creation=self.is_paused_upon_creation,
+ dag_display_name=self.display_name or dag_id,
+ description=self.description,
+ params=self.params,
+ on_success_callback=self.on_success_callback,
+ on_failure_callback=self.on_failure_callback,
+ auto_register=True,
+ )
+
+
[email protected](kw_only=True)
+class asset(_DAGFactory):
+ """Create an asset by decorating a materialization function."""
+
+ uri: str | ObjectStoragePath | None = None
+ group: str = Asset.asset_type
+ extra: dict[str, Any] = attrs.field(factory=dict)
+ watchers: list[BaseTrigger] = attrs.field(factory=list)
+
+ @attrs.define(kw_only=True)
+ class multi(_DAGFactory):
+ """Create a one-task DAG that emits multiple assets."""
+
+ outlets: Collection[BaseAsset] # TODO: Support non-asset outlets?
+
+ def __call__(self, f: Callable) -> MultiAssetDefinition:
+ if self.schedule is not None:
+ raise NotImplementedError("asset scheduling not implemented
yet")
+ if f.__name__ != f.__qualname__:
+ raise ValueError("nested function not supported")
+ if not self.outlets:
+ raise ValueError("no outlets provided")
+ return MultiAssetDefinition(function=f, source=self)
+
def __call__(self, f: Callable) -> AssetDefinition:
if self.schedule is not None:
raise NotImplementedError("asset scheduling not implemented yet")
diff --git a/task_sdk/tests/defintions/test_asset_decorators.py
b/task_sdk/tests/defintions/test_asset_decorators.py
index aeaa8632901..2e714237b41 100644
--- a/task_sdk/tests/defintions/test_asset_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -17,7 +17,6 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import ANY
import pytest
@@ -108,12 +107,22 @@ class TestAssetDecorator:
assert err.value.args[0].startswith("prohibited name for asset: ")
+class TestAssetMultiDecorator:
+ def test_multi_asset(self, example_asset_func):
+ definition = asset.multi(
+ schedule=None,
+ outlets=[Asset(name="a"), Asset(name="b")],
+ )(example_asset_func)
+
+ assert definition._function == example_asset_func
+ assert definition._source.schedule is None
+ assert definition._source.outlets == [Asset(name="a"), Asset(name="b")]
+
+
class TestAssetDefinition:
- @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator")
+
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
@mock.patch("airflow.models.dag.DAG")
- def test__attrs_post_init__(
- self, DAG, _AssetMainOperator,
example_asset_func_with_valid_arg_as_inlet_asset
- ):
+ def test__attrs_post_init__(self, DAG, from_definition,
example_asset_func_with_valid_arg_as_inlet_asset):
asset_definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
example_asset_func_with_valid_arg_as_inlet_asset
)
@@ -129,23 +138,56 @@ class TestAssetDefinition:
params=None,
auto_register=True,
)
- _AssetMainOperator.assert_called_once_with(
- task_id="__main__",
- inlets=[
- AssetRef(name="inlet_asset_1"),
- AssetRef(name="inlet_asset_2"),
- ],
- outlets=[asset_definition],
- python_callable=ANY,
- definition_name="example_asset_func",
- uri="s3://bucket/object",
- )
+ from_definition.assert_called_once_with(asset_definition)
- python_callable =
_AssetMainOperator.call_args.kwargs["python_callable"]
- assert python_callable ==
example_asset_func_with_valid_arg_as_inlet_asset
+
+class TestMultiAssetDefinition:
+
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
+ @mock.patch("airflow.models.dag.DAG")
+ def test__attrs_post_init__(self, DAG, from_definition,
example_asset_func_with_valid_arg_as_inlet_asset):
+ definition = asset.multi(
+ schedule=None,
+ outlets=[Asset(name="a"), Asset(name="b")],
+ )(example_asset_func_with_valid_arg_as_inlet_asset)
+
+ DAG.assert_called_once_with(
+ dag_id="example_asset_func",
+ dag_display_name="example_asset_func",
+ description=None,
+ schedule=None,
+ is_paused_upon_creation=None,
+ on_failure_callback=None,
+ on_success_callback=None,
+ params=None,
+ auto_register=True,
+ )
+ from_definition.assert_called_once_with(definition)
class Test_AssetMainOperator:
+ def test_from_definition(self,
example_asset_func_with_valid_arg_as_inlet_asset):
+ definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
+ example_asset_func_with_valid_arg_as_inlet_asset
+ )
+ op = _AssetMainOperator.from_definition(definition)
+ assert op.task_id == "__main__"
+ assert op.inlets == [AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2")]
+ assert op.outlets == [definition]
+ assert op.python_callable ==
example_asset_func_with_valid_arg_as_inlet_asset
+ assert op._definition_name == "example_asset_func"
+
+ def test_from_definition_multi(self,
example_asset_func_with_valid_arg_as_inlet_asset):
+ definition = asset.multi(
+ schedule=None,
+ outlets=[Asset(name="a"), Asset(name="b")],
+ )(example_asset_func_with_valid_arg_as_inlet_asset)
+ op = _AssetMainOperator.from_definition(definition)
+ assert op.task_id == "__main__"
+ assert op.inlets == [AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2")]
+ assert op.outlets == [Asset(name="a"), Asset(name="b")]
+ assert op.python_callable ==
example_asset_func_with_valid_arg_as_inlet_asset
+ assert op._definition_name == "example_asset_func"
+
@mock.patch("airflow.models.asset.fetch_active_assets_by_name")
@mock.patch("airflow.utils.session.create_session")
def test_determine_kwargs(