This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f8a61cb6af1 Simplify asset decorator implementation (#44344)
f8a61cb6af1 is described below
commit f8a61cb6af17a215391d113220150a96d4840211
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Nov 26 18:51:21 2024 +0800
Simplify asset decorator implementation (#44344)
---
.../src/airflow/sdk/definitions/asset/__init__.py | 71 ++++++++++++++++------
.../airflow/sdk/definitions/asset/decorators.py | 44 +++++---------
task_sdk/tests/defintions/test_asset_decorators.py | 2 +-
3 files changed, 68 insertions(+), 49 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 38bd4868085..812c30261bb 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -18,10 +18,10 @@
from __future__ import annotations
import logging
+import operator
import os
import urllib.parse
import warnings
-from collections.abc import Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
@@ -40,6 +40,7 @@ from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
+ from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult
from sqlalchemy.orm.session import Session
@@ -221,11 +222,24 @@ class BaseAsset:
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
- name: str
- uri: str
- group: str
- extra: dict[str, Any]
- watchers: list[BaseTrigger]
+ name: str = attrs.field(
+ validator=[_validate_asset_name],
+ )
+ uri: str = attrs.field(
+ validator=[_validate_non_empty_identifier],
+ converter=_sanitize_uri,
+ )
+ group: str = attrs.field(
+ default=attrs.Factory(operator.attrgetter("asset_type"),
takes_self=True),
+ validator=[_validate_identifier],
+ )
+ extra: dict[str, Any] = attrs.field(
+ factory=dict,
+ converter=_set_extra_default,
+ )
+ watchers: list[BaseTrigger] = attrs.field(
+ factory=list,
+ )
asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1
@@ -236,9 +250,9 @@ class Asset(os.PathLike, BaseAsset):
name: str,
uri: str,
*,
- group: str = "",
+ group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] | None = None,
+ watchers: list[BaseTrigger] = ...,
) -> None:
"""Canonical; both name and uri are provided."""
@@ -247,9 +261,9 @@ class Asset(os.PathLike, BaseAsset):
self,
name: str,
*,
- group: str = "",
+ group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] | None = None,
+ watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the
only positional argument."""
@@ -258,9 +272,9 @@ class Asset(os.PathLike, BaseAsset):
self,
*,
uri: str,
- group: str = "",
+ group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] | None = None,
+ watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""
@@ -269,7 +283,7 @@ class Asset(os.PathLike, BaseAsset):
name: str | None = None,
uri: str | None = None,
*,
- group: str = "",
+ group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
@@ -279,16 +293,35 @@ class Asset(os.PathLike, BaseAsset):
name = uri
elif uri is None:
uri = name
- fields = attrs.fields_dict(Asset)
- self.name = _validate_asset_name(self, fields["name"], name)
- self.uri = _sanitize_uri(_validate_non_empty_identifier(self,
fields["uri"], uri))
- self.group = _validate_identifier(self, fields["group"], group) if
group else self.asset_type
- self.extra = _set_extra_default(extra)
- self.watchers = watchers or []
+
+ if TYPE_CHECKING:
+ assert name is not None
+ assert uri is not None
+
+ # attrs default (and factory) does not kick in if any value is given to
+ # the argument. We need to exclude defaults from the custom __init__.
+ kwargs: dict[str, Any] = {}
+ if group is not None:
+ kwargs["group"] = group
+ if extra is not None:
+ kwargs["extra"] = extra
+ if watchers is not None:
+ kwargs["watchers"] = watchers
+
+ self.__attrs_init__(name=name, uri=uri, **kwargs)
def __fspath__(self) -> str:
return self.uri
+ def __eq__(self, other: Any) -> bool:
+ # The Asset class can be subclassed, and we don't want fields added by a
+ # subclass to break equality. This explicitly filters out only fields
+ # defined by the Asset class for comparison.
+ if not isinstance(other, Asset):
+ return NotImplemented
+ f = attrs.filters.include(*attrs.fields_dict(Asset))
+ return attrs.asdict(self, filter=f) == attrs.asdict(other, filter=f)
+
@property
def normalized_uri(self) -> str | None:
"""
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
index efdf8b70bd7..95876b76e66 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -18,23 +18,19 @@
from __future__ import annotations
import inspect
-from collections.abc import Iterator, Mapping
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
-)
+from typing import TYPE_CHECKING, Any, Callable
import attrs
-from airflow.models.asset import _fetch_active_assets_by_name
-from airflow.models.dag import DAG, ScheduleArg
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.utils.session import create_session
if TYPE_CHECKING:
+ from collections.abc import Iterator, Mapping
+
from airflow.io.path import ObjectStoragePath
+ from airflow.models.dag import ScheduleArg
from airflow.triggers.base import BaseTrigger
@@ -58,7 +54,8 @@ class _AssetMainOperator(PythonOperator):
yield key, value
def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str,
Any]:
- active_assets: dict[str, Asset] = {}
+ from airflow.models.asset import _fetch_active_assets_by_name
+
asset_names = [asset_ref.name for asset_ref in self.inlets if
isinstance(asset_ref, AssetRef)]
if "self" in inspect.signature(self.python_callable).parameters:
asset_names.append(self._definition_name)
@@ -66,6 +63,8 @@ class _AssetMainOperator(PythonOperator):
if asset_names:
with create_session() as session:
active_assets = _fetch_active_assets_by_name(asset_names,
session)
+ else:
+ active_assets = {}
return dict(self._iter_kwargs(context, active_assets))
@@ -81,38 +80,22 @@ class AssetDefinition(Asset):
schedule: ScheduleArg
def __attrs_post_init__(self) -> None:
- parameters = inspect.signature(self.function).parameters
+ from airflow.models.dag import DAG
with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
- for inlet_asset_name in parameters
+ for inlet_asset_name in
inspect.signature(self.function).parameters
if inlet_asset_name not in ("self", "context")
],
- outlets=[self.to_asset()],
+ outlets=[self],
python_callable=self.function,
definition_name=self.name,
uri=self.uri,
)
- def to_asset(self) -> Asset:
- return Asset(
- name=self.name,
- uri=self.uri,
- group=self.group,
- extra=self.extra,
- )
-
- def serialize(self):
- return {
- "uri": self.uri,
- "name": self.name,
- "group": self.group,
- "extra": self.extra,
- }
-
@attrs.define(kw_only=True)
class asset:
@@ -120,11 +103,14 @@ class asset:
schedule: ScheduleArg
uri: str | ObjectStoragePath | None = None
- group: str = ""
+ group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)
def __call__(self, f: Callable) -> AssetDefinition:
+ if self.schedule is not None:
+ raise NotImplementedError("asset scheduling not implemented yet")
+
if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")
diff --git a/task_sdk/tests/defintions/test_asset_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py
index 04650bc6644..e7d7ef45b40 100644
--- a/task_sdk/tests/defintions/test_asset_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -135,7 +135,7 @@ class TestAssetDefinition:
AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2"),
],
- outlets=[asset_definition.to_asset()],
+ outlets=[asset_definition],
python_callable=ANY,
definition_name="example_asset_func",
uri="s3://bucket/object",