This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da3bdbf26c3 Implement asset.multi (#44711)
da3bdbf26c3 is described below

commit da3bdbf26c39cd9dd65b824e2dfac787124a4048
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 6 22:22:26 2024 +0800

    Implement asset.multi (#44711)
    
    This allows a function to emit multiple assets. In this case, you are
    responsible for giving each asset a proper name yourself, but it works.
    
    Also refactors the existing decorator mechanism so we don't need to
    repeat code (especially the shared DAG arguments).
---
 .../airflow/sdk/definitions/asset/decorators.py    | 140 +++++++++++++++------
 task_sdk/tests/defintions/test_asset_decorators.py |  78 +++++++++---
 2 files changed, 164 insertions(+), 54 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
index 45b6686d059..1cb1ea4e316 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -18,27 +18,43 @@
 from __future__ import annotations
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 import attrs
 
 from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetRef
+from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset
 
 if TYPE_CHECKING:
-    from collections.abc import Collection, Iterator, Mapping
+    from collections.abc import Callable, Collection, Iterator, Mapping
 
     from airflow.io.path import ObjectStoragePath
-    from airflow.models.dag import DagStateChangeCallback, ScheduleArg
     from airflow.models.param import ParamsDict
+    from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
+    from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg
+    from airflow.serialization.dag_dependency import DagDependency
     from airflow.triggers.base import BaseTrigger
+    from airflow.typing_compat import Self
 
 
 class _AssetMainOperator(PythonOperator):
     def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
         super().__init__(**kwargs)
         self._definition_name = definition_name
-        self._uri = uri
+
+    @classmethod
+    def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> Self:
+        return cls(
+            task_id="__main__",
+            inlets=[
+                AssetRef(name=inlet_asset_name)
+                for inlet_asset_name in inspect.signature(definition._function).parameters
+                if inlet_asset_name not in ("self", "context")
+            ],
+            outlets=[v for _, v in definition.iter_assets()],
+            python_callable=definition._function,
+            definition_name=definition._function.__name__,
+        )
 
     def _iter_kwargs(
         self, context: Mapping[str, Any], active_assets: dict[str, Asset]
@@ -81,41 +97,53 @@ class AssetDefinition(Asset):
     _source: asset
 
     def __attrs_post_init__(self) -> None:
-        from airflow.models.dag import DAG
-
-        with DAG(
-            dag_id=self.name,
-            schedule=self._source.schedule,
-            is_paused_upon_creation=self._source.is_paused_upon_creation,
-            dag_display_name=self._source.display_name or self.name,
-            description=self._source.description,
-            params=self._source.params,
-            on_success_callback=self._source.on_success_callback,
-            on_failure_callback=self._source.on_failure_callback,
-            auto_register=True,
-        ):
-            _AssetMainOperator(
-                task_id="__main__",
-                inlets=[
-                    AssetRef(name=inlet_asset_name)
-                    for inlet_asset_name in inspect.signature(self._function).parameters
-                    if inlet_asset_name not in ("self", "context")
-                ],
-                outlets=[self],
-                python_callable=self._function,
-                definition_name=self.name,
-                uri=self.uri,
-            )
+        with self._source.create_dag(dag_id=self.name):
+            _AssetMainOperator.from_definition(self)
 
 
 @attrs.define(kw_only=True)
-class asset:
-    """Create an asset by decorating a materialization function."""
+class MultiAssetDefinition(BaseAsset):
+    """
+    Representation from decorating a function with ``@asset.multi``.
+
+    This is implemented as an "asset-like" object that can be used in all places
+    that accept asset-ish things (e.g. normal assets, aliases, AssetAll,
+    AssetAny).
+
+    :meta private:
+    """
+
+    _function: Callable
+    _source: asset.multi
+
+    def __attrs_post_init__(self) -> None:
+        with self._source.create_dag(dag_id=self._function.__name__):
+            _AssetMainOperator.from_definition(self)
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return all(o.evaluate(statuses=statuses) for o in self._source.outlets)
+
+    def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
+        for o in self._source.outlets:
+            yield from o.iter_assets()
+
+    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+        for o in self._source.outlets:
+            yield from o.iter_asset_aliases()
+
+    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
+        for obj in self._source.outlets:
+            yield from obj.iter_dag_dependencies(source=source, target=target)
 
-    uri: str | ObjectStoragePath | None = None
-    group: str = Asset.asset_type
-    extra: dict[str, Any] = attrs.field(factory=dict)
-    watchers: list[BaseTrigger] = attrs.field(factory=list)
+
+@attrs.define(kw_only=True)
+class _DAGFactory:
+    """
+    Common class for things that take DAG-like arguments.
+
+    This exists so we don't need to define these arguments separately for
+    ``@asset`` and ``@asset.multi``.
+    """
 
     schedule: ScheduleArg
     is_paused_upon_creation: bool | None = None
@@ -130,6 +158,46 @@ class asset:
     access_control: dict[str, dict[str, Collection[str]]] | None = None
     owner_links: dict[str, str] | None = None
 
+    def create_dag(self, *, dag_id: str) -> DAG:
+        from airflow.models.dag import DAG  # TODO: Use the SDK DAG when it works.
+
+        return DAG(
+            dag_id=dag_id,
+            schedule=self.schedule,
+            is_paused_upon_creation=self.is_paused_upon_creation,
+            dag_display_name=self.display_name or dag_id,
+            description=self.description,
+            params=self.params,
+            on_success_callback=self.on_success_callback,
+            on_failure_callback=self.on_failure_callback,
+            auto_register=True,
+        )
+
+
+@attrs.define(kw_only=True)
+class asset(_DAGFactory):
+    """Create an asset by decorating a materialization function."""
+
+    uri: str | ObjectStoragePath | None = None
+    group: str = Asset.asset_type
+    extra: dict[str, Any] = attrs.field(factory=dict)
+    watchers: list[BaseTrigger] = attrs.field(factory=list)
+
+    @attrs.define(kw_only=True)
+    class multi(_DAGFactory):
+        """Create a one-task DAG that emits multiple assets."""
+
+        outlets: Collection[BaseAsset]  # TODO: Support non-asset outlets?
+
+        def __call__(self, f: Callable) -> MultiAssetDefinition:
+            if self.schedule is not None:
+                raise NotImplementedError("asset scheduling not implemented yet")
+            if f.__name__ != f.__qualname__:
+                raise ValueError("nested function not supported")
+            if not self.outlets:
+                raise ValueError("no outlets provided")
+            return MultiAssetDefinition(function=f, source=self)
+
     def __call__(self, f: Callable) -> AssetDefinition:
         if self.schedule is not None:
             raise NotImplementedError("asset scheduling not implemented yet")
diff --git a/task_sdk/tests/defintions/test_asset_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py
index aeaa8632901..2e714237b41 100644
--- a/task_sdk/tests/defintions/test_asset_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import ANY
 
 import pytest
 
@@ -108,12 +107,22 @@ class TestAssetDecorator:
         assert err.value.args[0].startswith("prohibited name for asset: ")
 
 
+class TestAssetMultiDecorator:
+    def test_multi_asset(self, example_asset_func):
+        definition = asset.multi(
+            schedule=None,
+            outlets=[Asset(name="a"), Asset(name="b")],
+        )(example_asset_func)
+
+        assert definition._function == example_asset_func
+        assert definition._source.schedule is None
+        assert definition._source.outlets == [Asset(name="a"), Asset(name="b")]
+
+
 class TestAssetDefinition:
-    @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator")
+    @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
     @mock.patch("airflow.models.dag.DAG")
-    def test__attrs_post_init__(
-        self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
-    ):
+    def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset):
         asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
             example_asset_func_with_valid_arg_as_inlet_asset
         )
@@ -129,23 +138,56 @@ class TestAssetDefinition:
             params=None,
             auto_register=True,
         )
-        _AssetMainOperator.assert_called_once_with(
-            task_id="__main__",
-            inlets=[
-                AssetRef(name="inlet_asset_1"),
-                AssetRef(name="inlet_asset_2"),
-            ],
-            outlets=[asset_definition],
-            python_callable=ANY,
-            definition_name="example_asset_func",
-            uri="s3://bucket/object",
-        )
+        from_definition.assert_called_once_with(asset_definition)
 
-        python_callable = _AssetMainOperator.call_args.kwargs["python_callable"]
-        assert python_callable == example_asset_func_with_valid_arg_as_inlet_asset
+
+class TestMultiAssetDefinition:
+    @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
+    @mock.patch("airflow.models.dag.DAG")
+    def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset):
+        definition = asset.multi(
+            schedule=None,
+            outlets=[Asset(name="a"), Asset(name="b")],
+        )(example_asset_func_with_valid_arg_as_inlet_asset)
+
+        DAG.assert_called_once_with(
+            dag_id="example_asset_func",
+            dag_display_name="example_asset_func",
+            description=None,
+            schedule=None,
+            is_paused_upon_creation=None,
+            on_failure_callback=None,
+            on_success_callback=None,
+            params=None,
+            auto_register=True,
+        )
+        from_definition.assert_called_once_with(definition)
 
 
 class Test_AssetMainOperator:
+    def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset):
+        definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
+            example_asset_func_with_valid_arg_as_inlet_asset
+        )
+        op = _AssetMainOperator.from_definition(definition)
+        assert op.task_id == "__main__"
+        assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")]
+        assert op.outlets == [definition]
+        assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset
+        assert op._definition_name == "example_asset_func"
+
+    def test_from_definition_multi(self, example_asset_func_with_valid_arg_as_inlet_asset):
+        definition = asset.multi(
+            schedule=None,
+            outlets=[Asset(name="a"), Asset(name="b")],
+        )(example_asset_func_with_valid_arg_as_inlet_asset)
+        op = _AssetMainOperator.from_definition(definition)
+        assert op.task_id == "__main__"
+        assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")]
+        assert op.outlets == [Asset(name="a"), Asset(name="b")]
+        assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset
+        assert op._definition_name == "example_asset_func"
+
     @mock.patch("airflow.models.asset.fetch_active_assets_by_name")
     @mock.patch("airflow.utils.session.create_session")
     def test_determine_kwargs(

Reply via email to