This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b3ffaaf67c2 Extend PartitionedAssetTimetable to allow per asset partition (#60966)
b3ffaaf67c2 is described below
commit b3ffaaf67c2e06d76504756b4f577398a0653a3a
Author: Wei Lee <[email protected]>
AuthorDate: Tue Feb 17 18:31:18 2026 +0800
Extend PartitionedAssetTimetable to allow per asset partition (#60966)
---
airflow-core/src/airflow/assets/manager.py | 12 +-
.../airflow/serialization/definitions/assets.py | 4 +
airflow-core/src/airflow/serialization/encoders.py | 28 +++-
airflow-core/src/airflow/serialization/helpers.py | 4 +-
airflow-core/src/airflow/timetables/simple.py | 83 +++++++++-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 26 ++-
.../tests/unit/models/test_taskinstance.py | 4 +-
.../unit/timetables/test_partitioned_timetable.py | 181 +++++++++++++++++++++
.../airflow/sdk/definitions/timetables/assets.py | 4 +-
9 files changed, 315 insertions(+), 31 deletions(-)
diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py
index b7b0ce5c239..258cd2cfc72 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -403,13 +403,19 @@ class AssetManager(LoggingMixin):
assert partition_key is not None
from airflow.models.serialized_dag import SerializedDagModel
- serdag = SerializedDagModel.get(dag_id=target_dag.dag_id, session=session)
- if not serdag:
+ if not (serdag := SerializedDagModel.get(dag_id=target_dag.dag_id, session=session)):
raise RuntimeError(f"Could not find serialized dag for dag_id={target_dag.dag_id}")
+
timetable = serdag.dag.timetable
if TYPE_CHECKING:
assert isinstance(timetable, PartitionedAssetTimetable)
- target_key = timetable.partition_mapper.to_downstream(partition_key)
+
+ if (asset_model := session.scalar(select(AssetModel).where(AssetModel.id == asset_id))) is None:
+ raise RuntimeError(f"Could not find asset for asset_id={asset_id}")
+
+ target_key = timetable.get_partition_mapper(
+ name=asset_model.name, uri=asset_model.uri
+ ).to_downstream(partition_key)
apdr = cls._get_or_create_apdr(
target_key=target_key,
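
For context, the manager no longer reads a single timetable-wide mapper; it now resolves the mapper per asset. A minimal sketch of the new flow (names as in the hunk above; the ORM and session plumbing is assumed from the surrounding module):

    # Resolve the mapper for the asset that emitted the event, falling
    # back to the timetable's default mapper when no per-asset entry exists.
    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
    target_key = mapper.to_downstream(partition_key)
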
diff --git a/airflow-core/src/airflow/serialization/definitions/assets.py b/airflow-core/src/airflow/serialization/definitions/assets.py
index 9a88568f7c5..51fe0ed3d1d 100644
--- a/airflow-core/src/airflow/serialization/definitions/assets.py
+++ b/airflow-core/src/airflow/serialization/definitions/assets.py
@@ -162,6 +162,10 @@ class SerializedAsset(SerializedAssetBase):
"""
return AssetProfile(name=self.name or None, uri=self.uri or None, type="Asset")
+ def __hash__(self):
+ f = attrs.filters.include(*attrs.fields_dict(SerializedAsset))
+ return hash(json.dumps(attrs.asdict(self, filter=f), sort_keys=True))
+
class SerializedAssetRef(SerializedAssetBase, AttrsInstance):
"""Serialized representation of an asset reference."""
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index 3bb34c3700c..f82af80c70c 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -70,6 +70,7 @@ from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import (
find_registered_custom_partition_mapper,
find_registered_custom_timetable,
+ is_core_partition_mapper_import_path,
is_core_timetable_import_path,
)
from airflow.timetables.base import Timetable as CoreTimetable
@@ -357,7 +358,11 @@ class _Serializer:
def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]:
return {
"asset_condition": encode_asset_like(timetable.asset_condition),
- "partition_mapper":
encode_partition_mapper(timetable.partition_mapper),
+ "default_partition_mapper":
encode_partition_mapper(timetable.default_partition_mapper),
+ "partition_mapper_config": [
+ (encode_asset_like(asset),
encode_partition_mapper(partition_mapper))
+ for asset, partition_mapper in
timetable.partition_mapper_config.items()
+ ],
}
BUILTIN_PARTITION_MAPPERS: dict[type, str] = {
@@ -465,7 +470,7 @@ def ensure_serialized_deadline_alert(obj: DeadlineAlert | SerializedDeadlineAler
return decode_deadline_alert(encode_deadline_alert(obj))
-def encode_partition_mapper(var: PartitionMapper) -> dict[str, Any]:
+def encode_partition_mapper(var: PartitionMapper | CorePartitionMapper) ->
dict[str, Any]:
"""
Encode a PartitionMapper instance.
@@ -474,11 +479,20 @@ def encode_partition_mapper(var: PartitionMapper) -> dict[str, Any]:
:meta private:
"""
- if (importable_string := _serializer.BUILTIN_PARTITION_MAPPERS.get(var_type := type(var), None)) is None:
- find_registered_custom_partition_mapper(
- importable_string := qualname(var_type)
- ) # This raises if not found.
+ var_type = type(var)
+ importable_string = _serializer.BUILTIN_PARTITION_MAPPERS.get(var_type)
+ if importable_string is not None:
+ return {
+ Encoding.TYPE: importable_string,
+ Encoding.VAR: _serializer.serialize_partition_mapper(var),
+ }
+
+ qn = qualname(var)
+ if is_core_partition_mapper_import_path(qn) is False:
+ # This raises if not found.
+ find_registered_custom_partition_mapper(qn)
+
return {
- Encoding.TYPE: importable_string,
+ Encoding.TYPE: qn,
Encoding.VAR: _serializer.serialize_partition_mapper(var),
}
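
The lookup order is now: the builtin-mapper table, then core import paths (allowed without plugin registration), then plugin-registered custom mappers, which still raise if unregistered. The encoded payload shape is unchanged; for a builtin mapper it looks like this (shape taken from the new tests below):

    from airflow.partition_mappers.identity import IdentityMapper
    from airflow.serialization.encoders import encode_partition_mapper

    encoded = encode_partition_mapper(IdentityMapper())
    # {"__type": "airflow.partition_mappers.identity.IdentityMapper", "__var": {}}
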
diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py
index fa0b7dcf533..effca4453dc 100644
--- a/airflow-core/src/airflow/serialization/helpers.py
+++ b/airflow-core/src/airflow/serialization/helpers.py
@@ -135,9 +135,9 @@ def find_registered_custom_partition_mapper(importable_string: str) -> type[Part
"""Find a user-defined custom partition mapper class registered via a plugin."""
from airflow import plugins_manager
- partition_mapper_cls = plugins_manager.get_partition_mapper_plugins()
+ partition_mapper_classes = plugins_manager.get_partition_mapper_plugins()
with contextlib.suppress(KeyError):
- return partition_mapper_cls[importable_string]
+ return partition_mapper_classes[importable_string]
raise PartitionMapperNotFound(importable_string)
diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py
index f6b1f32c2a8..a0a553bb111 100644
--- a/airflow-core/src/airflow/timetables/simple.py
+++ b/airflow-core/src/airflow/timetables/simple.py
@@ -16,10 +16,22 @@
# under the License.
from __future__ import annotations
+from contextlib import suppress
from typing import TYPE_CHECKING, Any, TypeAlias
+import structlog
+
from airflow._shared.timezones import timezone
-from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetAll, SerializedAssetBase
+from airflow.partition_mappers.identity import IdentityMapper
+from airflow.serialization.definitions.assets import (
+ SerializedAsset,
+ SerializedAssetAlias,
+ SerializedAssetAll,
+ SerializedAssetBase,
+ SerializedAssetNameRef,
+ SerializedAssetUriRef,
+)
+from airflow.serialization.encoders import encode_asset_like, encode_partition_mapper
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
try:
@@ -32,6 +44,8 @@ except ModuleNotFoundError:
return o
+log = structlog.get_logger()
+
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
@@ -201,8 +215,6 @@ class AssetTriggeredTimetable(_TrivialTimetable):
return "Asset"
def serialize(self) -> dict[str, Any]:
- from airflow.serialization.encoders import encode_asset_like
-
return {"asset_condition": encode_asset_like(self.asset_condition)}
def generate_run_id(
@@ -235,6 +247,9 @@ class AssetTriggeredTimetable(_TrivialTimetable):
return None
+DEFAULT_PARTITION_MAPPER = IdentityMapper()
+
+
class PartitionedAssetTimetable(AssetTriggeredTimetable):
"""Asset-driven timetable that listens for partitioned assets."""
@@ -242,24 +257,74 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
def summary(self) -> str:
return "Partitioned Asset"
- def __init__(self, assets: SerializedAssetBase, partition_mapper: PartitionMapper) -> None:
+ def __init__(
+ self,
+ *,
+ assets: SerializedAssetBase,
+ partition_mapper_config: dict[SerializedAssetBase, PartitionMapper] | None = None,
+ default_partition_mapper: PartitionMapper = DEFAULT_PARTITION_MAPPER,
+ ) -> None:
super().__init__(assets=assets)
- self.partition_mapper = partition_mapper
+ self.partition_mapper_config = partition_mapper_config or {}
+ self.default_partition_mapper = default_partition_mapper
+
+ self._name_to_partition_mapper: dict[str, PartitionMapper] = {}
+ self._uri_to_partition_mapper: dict[str, PartitionMapper] = {}
+ self._build_name_uri_mapping()
+
+ def _build_name_uri_mapping(self) -> None:
+ for base_asset, partition_mapper in self.partition_mapper_config.items():
+ for unique_key, _ in base_asset.iter_assets():
+ self._name_to_partition_mapper[unique_key.name] = partition_mapper
+ self._uri_to_partition_mapper[unique_key.uri] = partition_mapper
+
+ for s_asset_ref in base_asset.iter_asset_refs():
+ if isinstance(s_asset_ref, SerializedAssetNameRef):
+ self._name_to_partition_mapper[s_asset_ref.name] = partition_mapper
+ elif isinstance(s_asset_ref, SerializedAssetUriRef):
+ self._uri_to_partition_mapper[s_asset_ref.uri] = partition_mapper
+ else:
+ raise ValueError(f"{type(s_asset_ref)} is not supported")
+
+ if isinstance(base_asset, SerializedAssetAlias):
+ log.warning("Partitioned Asset Alias is not supported.")
+
+ def get_partition_mapper(self, *, name: str = "", uri: str = "") ->
PartitionMapper:
+ with suppress(KeyError):
+ if name:
+ return self._name_to_partition_mapper[name]
+
+ if uri:
+ return self._uri_to_partition_mapper[uri]
+
+ return self.default_partition_mapper
def serialize(self) -> dict[str, Any]:
from airflow.serialization.serialized_objects import encode_asset_like
return {
"asset_condition": encode_asset_like(self.asset_condition),
- "partition_mapper": self.partition_mapper.serialize(),
+ "partition_mapper_config": [
+ (encode_asset_like(asset),
encode_partition_mapper(partition_mapper))
+ for asset, partition_mapper in
self.partition_mapper_config.items()
+ ],
+ "default_partition_mapper":
encode_partition_mapper(self.default_partition_mapper),
}
@classmethod
- def deserialize(cls, data: dict[str, Any]) -> Timetable:
+ def deserialize(cls, data: dict[str, Any]) -> PartitionedAssetTimetable:
from airflow.serialization.decoders import decode_partition_mapper
from airflow.serialization.serialized_objects import decode_asset_like
- return cls(
+ default_partition_mapper_data = data["default_partition_mapper"]
+ partition_mapper_mapping_data = data["partition_mapper_config"]
+
+ timetable = cls(
assets=decode_asset_like(data["asset_condition"]),
- partition_mapper=decode_partition_mapper(data["partition_mapper"]),
+ default_partition_mapper=decode_partition_mapper(default_partition_mapper_data),
+ partition_mapper_config={
+ decode_asset_like(ser_asset): decode_partition_mapper(ser_partition_mapper)
+ for ser_asset, ser_partition_mapper in partition_mapper_mapping_data
+ },
)
+ return timetable
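
Note the resolution order in `get_partition_mapper`: the name mapping is tried first, then the uri mapping, then the default; because `suppress(KeyError)` wraps both lookups, a missed name lookup falls through to the default rather than retrying by uri. A minimal construction sketch (`MyMapper` and the asset name are placeholders; `ensure_serialized_asset` is used the same way in the new tests below):

    from airflow.sdk import Asset
    from airflow.serialization.encoders import ensure_serialized_asset
    from airflow.timetables.simple import PartitionedAssetTimetable

    ser = ensure_serialized_asset(Asset("orders"))
    timetable = PartitionedAssetTimetable(
        assets=ser,
        # Per-asset override; MyMapper is hypothetical.
        partition_mapper_config={ser: MyMapper()},
    )
    timetable.get_partition_mapper(name="orders")  # -> MyMapper
    timetable.get_partition_mapper(name="other")   # -> IdentityMapper (the default)
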
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 8cacf6d4e0e..ea9c51220ca 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -8751,6 +8751,7 @@ def _produce_and_register_asset_event(
session.commit()
serialized_outlets = dag.get_task("hi").outlets
+
TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[o.asprofile() for o in serialized_outlets],
@@ -8798,9 +8799,12 @@ def test_partitioned_dag_run_with_customized_mapper(
dag_id="asset-event-consumer",
schedule=PartitionedAssetTimetable(
assets=asset_1,
- # TODO: (GH-57694) this partition mapper interface will be moved into asset as per-asset mapper
- # and the type mismatch will be handled there
- partition_mapper=Key1Mapper(), # type: ignore[arg-type]
+ # Most users should use the partition mapper provided by the task-SDK.
+ # Advanced users can import from core and register their own partition mapper
+ # via an Airflow plugin.
+ # We intentionally exclude core mappers from the public typing
+ # so standard users don't accidentally rely on internal implementations.
+ default_partition_mapper=Key1Mapper(),  # type: ignore[arg-type]
),
session=session,
):
@@ -8848,8 +8852,8 @@ def test_consumer_dag_listen_to_two_partitioned_asset(
with dag_maker(
dag_id="asset-event-consumer",
schedule=PartitionedAssetTimetable(
- assets=asset_1 & asset_2,
- partition_mapper=IdentityMapper(),
+ assets=(Asset.ref(uri="asset-1") & Asset.ref(name="asset-2")),
+ default_partition_mapper=IdentityMapper(),
),
session=session,
):
@@ -8931,9 +8935,15 @@ def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper(
dag_id="asset-event-consumer",
schedule=PartitionedAssetTimetable(
assets=asset_1 & asset_2,
- # TODO: (GH-57694) this partition mapper interface will be moved into asset as per-asset mapper
- # and the type mismatch will be handled there
- partition_mapper=Key1Mapper(),  # type: ignore[arg-type]
+ # Most users should use the partition mapper provided by the task-SDK.
+ # Advanced users can import from core and register their own partition mapper
+ # via an Airflow plugin.
+ # We intentionally exclude core mappers from the public typing
+ # so standard users don't accidentally rely on internal implementations.
+ partition_mapper_config={
+ Asset(name="asset-1"): Key1Mapper(),  # type: ignore[dict-item]
+ Asset(name="asset-2"): Key1Mapper(),  # type: ignore[dict-item]
+ },
),
session=session,
):
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 379d2218dbf..5ea9184a8b3 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3074,7 +3074,9 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
with dag_maker(
dag_id="asset_event_listener",
- schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
+ schedule=PartitionedAssetTimetable(
+ assets=Asset(name="hello"),
default_partition_mapper=IdentityMapper()
+ ),
session=session,
):
EmptyOperator(task_id="hi")
diff --git a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py
new file mode 100644
index 00000000000..b1befe9f5d2
--- /dev/null
+++ b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable, Iterable
+from contextlib import ExitStack
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+
+from airflow._shared.module_loading import qualname
+from airflow.partition_mappers.identity import IdentityMapper as IdentityMapper
+from airflow.sdk import Asset
+from airflow.serialization.definitions.assets import SerializedAsset
+from airflow.serialization.encoders import ensure_serialized_asset
+from airflow.serialization.enums import DagAttributeTypes
+from airflow.timetables.simple import PartitionedAssetTimetable
+
+if TYPE_CHECKING:
+ from airflow.partition_mappers.base import PartitionMapper
+
+
+class Key1Mapper(IdentityMapper):
+ """Partition Mapper that returns only key-1 as downstream key"""
+
+ def to_downstream(self, key: str) -> str:
+ return "key-1"
+
+ def to_upstream(self, key: str) -> Iterable[str]:
+ yield key
+
+
+def _find_registered_custom_partition_mapper(import_string: str) ->
type[PartitionMapper]:
+ if import_string == qualname(Key1Mapper):
+ return Key1Mapper
+ raise ValueError(f"unexpected class {import_string!r}")
+
+
[email protected]
+def custom_partition_mapper_patch() -> Callable[[], ExitStack]:
+ def _patch() -> ExitStack:
+ stack = ExitStack()
+ for mock_target in [
+ "airflow.serialization.encoders.find_registered_custom_partition_mapper",
+ "airflow.serialization.decoders.find_registered_custom_partition_mapper",
+ ]:
+ stack.enter_context(
+ mock.patch(
+ mock_target,
+ _find_registered_custom_partition_mapper,
+ )
+ )
+ return stack
+
+ return _patch
+
+
+class TestPartitionedAssetTimetable:
+ @pytest.mark.parametrize(
+ "asset_obj",
+ [
+ Asset("test_1"),
+ Asset(name="test_1"),
+ Asset(uri="test_1"),
+ Asset(name="test_1", uri="test_1"),
+ ],
+ )
+ def test_get_partition_mapper_without_mapping(self, asset_obj):
+ timetable = PartitionedAssetTimetable(assets=asset_obj)
+ assert timetable.partition_mapper_config == {}
+ assert isinstance(timetable.default_partition_mapper, IdentityMapper)
+ assert isinstance(timetable.get_partition_mapper(name="test_1",
uri="test_1"), IdentityMapper)
+ assert isinstance(timetable.get_partition_mapper(name="test_1"),
IdentityMapper)
+ assert isinstance(timetable.get_partition_mapper(uri="test_1"),
IdentityMapper)
+
+ @pytest.mark.parametrize(
+ "asset_obj",
+ [
+ Asset("test_1"),
+ Asset(name="test_1"),
+ Asset(uri="test_1"),
+ Asset(name="test_1", uri="test_1"),
+ ],
+ )
+ @pytest.mark.usefixtures("custom_partition_mapper_patch")
+ def test_get_partition_mapper_with_mapping(self, asset_obj):
+ ser_asset = ensure_serialized_asset(asset_obj)
+
+ timetable = PartitionedAssetTimetable(
+ assets=ser_asset, partition_mapper_config={ser_asset: Key1Mapper()}
+ )
+ assert isinstance(timetable.default_partition_mapper, IdentityMapper)
+ assert isinstance(timetable.get_partition_mapper(name="test_1",
uri="test_1"), Key1Mapper)
+ assert isinstance(timetable.get_partition_mapper(name="test_1"),
Key1Mapper)
+ assert isinstance(timetable.get_partition_mapper(uri="test_1"),
Key1Mapper)
+
+ def test_serialize(self):
+ ser_asset = ensure_serialized_asset(Asset("test"))
+ timetable = PartitionedAssetTimetable(
+ assets=ser_asset, partition_mapper_config={ser_asset: IdentityMapper()}
+ )
+ assert timetable.serialize() == {
+ "asset_condition": {
+ "__type": DagAttributeTypes.ASSET,
+ "name": "test",
+ "uri": "test",
+ "group": "asset",
+ "extra": {},
+ },
+ "partition_mapper_config": [
+ (
+ {
+ "__type": DagAttributeTypes.ASSET,
+ "name": "test",
+ "uri": "test",
+ "group": "asset",
+ "extra": {},
+ },
+ {
+ "__type":
"airflow.partition_mappers.identity.IdentityMapper",
+ "__var": {},
+ },
+ )
+ ],
+ "default_partition_mapper": {
+ "__type": "airflow.partition_mappers.identity.IdentityMapper",
+ "__var": {},
+ },
+ }
+
+ def test_deserialize(self):
+ timetable = PartitionedAssetTimetable.deserialize(
+ {
+ "asset_condition": {
+ "__type": DagAttributeTypes.ASSET,
+ "name": "test",
+ "uri": "test",
+ "group": "asset",
+ "extra": {},
+ },
+ "partition_mapper_config": [
+ (
+ {
+ "__type": DagAttributeTypes.ASSET,
+ "name": "test",
+ "uri": "test",
+ "group": "asset",
+ "extra": {},
+ },
+ {
+ "__type":
"airflow.partition_mappers.identity.IdentityMapper",
+ "__var": {},
+ },
+ )
+ ],
+ "default_partition_mapper": {
+ "__type":
"airflow.partition_mappers.identity.IdentityMapper",
+ "__var": {},
+ },
+ }
+ )
+ ser_asset = SerializedAsset(name="test", uri="test", group="asset",
extra={}, watchers=[])
+ assert timetable.asset_condition == ser_asset
+ assert isinstance(timetable.default_partition_mapper, IdentityMapper)
+ assert isinstance(timetable.partition_mapper_config[ser_asset], IdentityMapper)
diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
index 421a23f938b..e6bb683ebca 100644
--- a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
+++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
@@ -23,6 +23,7 @@ import attrs
from airflow.sdk.bases.timetable import BaseTimetable
from airflow.sdk.definitions.asset import AssetAll, BaseAsset
+from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
if TYPE_CHECKING:
from collections.abc import Collection
@@ -49,7 +50,8 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
"""Asset-driven timetable that listens for partitioned assets."""
asset_condition: BaseAsset = attrs.field(alias="assets")
- partition_mapper: PartitionMapper
+ partition_mapper_config: dict[BaseAsset, PartitionMapper] = attrs.field(factory=dict)
+ default_partition_mapper: PartitionMapper = IdentityMapper()
def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset:
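
For DAG authors, the SDK-side timetable now takes an optional per-asset mapping plus a default. A hedged usage sketch (the asset names and `MyHourlyToDailyMapper` are placeholders; the import paths come from this diff):

    from airflow.sdk import Asset
    from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
    from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable

    raw = Asset("raw_events")
    agg = Asset("daily_aggregates")

    timetable = PartitionedAssetTimetable(
        assets=raw & agg,
        # Events from `raw` go through the custom mapper; any asset not
        # listed here falls back to default_partition_mapper.
        partition_mapper_config={raw: MyHourlyToDailyMapper()},
        default_partition_mapper=IdentityMapper(),
    )
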