This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f8a61cb6af1 Simplify asset decorator implementation (#44344)
f8a61cb6af1 is described below

commit f8a61cb6af17a215391d113220150a96d4840211
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Nov 26 18:51:21 2024 +0800

    Simplify asset decorator implementation (#44344)
---
 .../src/airflow/sdk/definitions/asset/__init__.py  | 71 ++++++++++++++++------
 .../airflow/sdk/definitions/asset/decorators.py    | 44 +++++---------
 task_sdk/tests/defintions/test_asset_decorators.py |  2 +-
 3 files changed, 68 insertions(+), 49 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 38bd4868085..812c30261bb 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -18,10 +18,10 @@
 from __future__ import annotations
 
 import logging
+import operator
 import os
 import urllib.parse
 import warnings
-from collections.abc import Iterable, Iterator
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -40,6 +40,7 @@ from airflow.typing_compat import TypedDict
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator
     from urllib.parse import SplitResult
 
     from sqlalchemy.orm.session import Session
@@ -221,11 +222,24 @@ class BaseAsset:
 class Asset(os.PathLike, BaseAsset):
     """A representation of data asset dependencies between workflows."""
 
-    name: str
-    uri: str
-    group: str
-    extra: dict[str, Any]
-    watchers: list[BaseTrigger]
+    name: str = attrs.field(
+        validator=[_validate_asset_name],
+    )
+    uri: str = attrs.field(
+        validator=[_validate_non_empty_identifier],
+        converter=_sanitize_uri,
+    )
+    group: str = attrs.field(
+        default=attrs.Factory(operator.attrgetter("asset_type"), 
takes_self=True),
+        validator=[_validate_identifier],
+    )
+    extra: dict[str, Any] = attrs.field(
+        factory=dict,
+        converter=_set_extra_default,
+    )
+    watchers: list[BaseTrigger] = attrs.field(
+        factory=list,
+    )
 
     asset_type: ClassVar[str] = "asset"
     __version__: ClassVar[int] = 1
@@ -236,9 +250,9 @@ class Asset(os.PathLike, BaseAsset):
         name: str,
         uri: str,
         *,
-        group: str = "",
+        group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] | None = None,
+        watchers: list[BaseTrigger] = ...,
     ) -> None:
         """Canonical; both name and uri are provided."""
 
@@ -247,9 +261,9 @@ class Asset(os.PathLike, BaseAsset):
         self,
         name: str,
         *,
-        group: str = "",
+        group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] | None = None,
+        watchers: list[BaseTrigger] = ...,
     ) -> None:
         """It's possible to only provide the name, either by keyword or as the only positional argument."""
 
@@ -258,9 +272,9 @@ class Asset(os.PathLike, BaseAsset):
         self,
         *,
         uri: str,
-        group: str = "",
+        group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] | None = None,
+        watchers: list[BaseTrigger] = ...,
     ) -> None:
         """It's possible to only provide the URI as a keyword argument."""
 
@@ -269,7 +283,7 @@ class Asset(os.PathLike, BaseAsset):
         name: str | None = None,
         uri: str | None = None,
         *,
-        group: str = "",
+        group: str | None = None,
         extra: dict | None = None,
         watchers: list[BaseTrigger] | None = None,
     ) -> None:
@@ -279,16 +293,35 @@ class Asset(os.PathLike, BaseAsset):
             name = uri
         elif uri is None:
             uri = name
-        fields = attrs.fields_dict(Asset)
-        self.name = _validate_asset_name(self, fields["name"], name)
-        self.uri = _sanitize_uri(_validate_non_empty_identifier(self, 
fields["uri"], uri))
-        self.group = _validate_identifier(self, fields["group"], group) if 
group else self.asset_type
-        self.extra = _set_extra_default(extra)
-        self.watchers = watchers or []
+
+        if TYPE_CHECKING:
+            assert name is not None
+            assert uri is not None
+
+        # attrs defaults (and factories) only apply when no value is passed for
+        # the argument, so arguments left unset must be omitted when calling
+        # __attrs_init__ from this custom __init__.
+        kwargs: dict[str, Any] = {}
+        if group is not None:
+            kwargs["group"] = group
+        if extra is not None:
+            kwargs["extra"] = extra
+        if watchers is not None:
+            kwargs["watchers"] = watchers
+
+        self.__attrs_init__(name=name, uri=uri, **kwargs)
 
     def __fspath__(self) -> str:
         return self.uri
 
+    def __eq__(self, other: Any) -> bool:
+        # The Asset class can be subclassed, and we don't want fields added by a
+        # subclass to break equality. This explicitly filters out only fields
+        # defined by the Asset class for comparison.
+        if not isinstance(other, Asset):
+            return NotImplemented
+        f = attrs.filters.include(*attrs.fields_dict(Asset))
+        return attrs.asdict(self, filter=f) == attrs.asdict(other, filter=f)
+
     @property
     def normalized_uri(self) -> str | None:
         """
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
index efdf8b70bd7..95876b76e66 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -18,23 +18,19 @@
 from __future__ import annotations
 
 import inspect
-from collections.abc import Iterator, Mapping
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-)
+from typing import TYPE_CHECKING, Any, Callable
 
 import attrs
 
-from airflow.models.asset import _fetch_active_assets_by_name
-from airflow.models.dag import DAG, ScheduleArg
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk.definitions.asset import Asset, AssetRef
 from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator, Mapping
+
     from airflow.io.path import ObjectStoragePath
+    from airflow.models.dag import ScheduleArg
     from airflow.triggers.base import BaseTrigger
 
 
@@ -58,7 +54,8 @@ class _AssetMainOperator(PythonOperator):
             yield key, value
 
     def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, 
Any]:
-        active_assets: dict[str, Asset] = {}
+        from airflow.models.asset import _fetch_active_assets_by_name
+
         asset_names = [asset_ref.name for asset_ref in self.inlets if 
isinstance(asset_ref, AssetRef)]
         if "self" in inspect.signature(self.python_callable).parameters:
             asset_names.append(self._definition_name)
@@ -66,6 +63,8 @@ class _AssetMainOperator(PythonOperator):
         if asset_names:
             with create_session() as session:
                 active_assets = _fetch_active_assets_by_name(asset_names, 
session)
+        else:
+            active_assets = {}
         return dict(self._iter_kwargs(context, active_assets))
 
 
@@ -81,38 +80,22 @@ class AssetDefinition(Asset):
     schedule: ScheduleArg
 
     def __attrs_post_init__(self) -> None:
-        parameters = inspect.signature(self.function).parameters
+        from airflow.models.dag import DAG
 
         with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
             _AssetMainOperator(
                 task_id="__main__",
                 inlets=[
                     AssetRef(name=inlet_asset_name)
-                    for inlet_asset_name in parameters
+                    for inlet_asset_name in 
inspect.signature(self.function).parameters
                     if inlet_asset_name not in ("self", "context")
                 ],
-                outlets=[self.to_asset()],
+                outlets=[self],
                 python_callable=self.function,
                 definition_name=self.name,
                 uri=self.uri,
             )
 
-    def to_asset(self) -> Asset:
-        return Asset(
-            name=self.name,
-            uri=self.uri,
-            group=self.group,
-            extra=self.extra,
-        )
-
-    def serialize(self):
-        return {
-            "uri": self.uri,
-            "name": self.name,
-            "group": self.group,
-            "extra": self.extra,
-        }
-
 
 @attrs.define(kw_only=True)
 class asset:
@@ -120,11 +103,14 @@ class asset:
 
     schedule: ScheduleArg
     uri: str | ObjectStoragePath | None = None
-    group: str = ""
+    group: str = Asset.asset_type
     extra: dict[str, Any] = attrs.field(factory=dict)
     watchers: list[BaseTrigger] = attrs.field(factory=list)
 
     def __call__(self, f: Callable) -> AssetDefinition:
+        if self.schedule is not None:
+            raise NotImplementedError("asset scheduling not implemented yet")
+
         if (name := f.__name__) != f.__qualname__:
             raise ValueError("nested function not supported")
 
diff --git a/task_sdk/tests/defintions/test_asset_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py
index 04650bc6644..e7d7ef45b40 100644
--- a/task_sdk/tests/defintions/test_asset_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -135,7 +135,7 @@ class TestAssetDefinition:
                 AssetRef(name="inlet_asset_1"),
                 AssetRef(name="inlet_asset_2"),
             ],
-            outlets=[asset_definition.to_asset()],
+            outlets=[asset_definition],
             python_callable=ANY,
             definition_name="example_asset_func",
             uri="s3://bucket/object",

Reply via email to