This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b7247b13219 feat(core): add per-mapper max_fan_out override for
partition fan-out cap (#67184)
b7247b13219 is described below
commit b7247b132190afd4d21a3221414ed49a6030d175
Author: Wei Lee <[email protected]>
AuthorDate: Thu Jun 11 10:26:40 2026 +0800
feat(core): add per-mapper max_fan_out override for partition fan-out cap
(#67184)
---
airflow-core/newsfragments/67184.feature.rst | 1 +
airflow-core/src/airflow/assets/manager.py | 19 ++-
.../src/airflow/config_templates/config.yml | 15 +-
.../src/airflow/partition_mappers/allowed_key.py | 10 +-
airflow-core/src/airflow/partition_mappers/base.py | 26 +++-
.../src/airflow/partition_mappers/chain.py | 9 +-
.../src/airflow/partition_mappers/product.py | 13 +-
.../src/airflow/partition_mappers/temporal.py | 32 +++-
airflow-core/src/airflow/serialization/encoders.py | 35 ++++-
airflow-core/tests/unit/assets/test_manager.py | 167 +++++++++++++++++++++
.../unit/partition_mappers/test_allowed_key.py | 17 +++
.../tests/unit/partition_mappers/test_base.py | 57 ++++++-
.../tests/unit/partition_mappers/test_chain.py | 19 ++-
.../tests/unit/partition_mappers/test_fan_out.py | 26 ++++
.../tests/unit/partition_mappers/test_identity.py | 15 ++
.../tests/unit/partition_mappers/test_product.py | 21 ++-
.../tests/unit/partition_mappers/test_temporal.py | 20 ++-
.../definitions/partition_mappers/allowed_key.py | 6 +-
.../sdk/definitions/partition_mappers/base.py | 30 +++-
.../sdk/definitions/partition_mappers/chain.py | 8 +-
.../sdk/definitions/partition_mappers/product.py | 14 +-
.../sdk/definitions/partition_mappers/temporal.py | 66 +++++---
.../task_sdk/definitions/test_partition_mappers.py | 34 +++++
23 files changed, 577 insertions(+), 83 deletions(-)
diff --git a/airflow-core/newsfragments/67184.feature.rst
b/airflow-core/newsfragments/67184.feature.rst
new file mode 100644
index 00000000000..759e6921d51
--- /dev/null
+++ b/airflow-core/newsfragments/67184.feature.rst
@@ -0,0 +1 @@
+Add ``max_downstream_keys`` parameter to ``PartitionMapper`` to override ``[scheduler] partition_mapper_max_downstream_keys`` per mapper instance.
diff --git a/airflow-core/src/airflow/assets/manager.py
b/airflow-core/src/airflow/assets/manager.py
index 5f9e1b598ab..f72c533c5a0 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -552,8 +552,7 @@ class AssetManager(LoggingMixin):
)
return
- max_downstream_keys = conf.getint("scheduler",
"partition_mapper_max_downstream_keys")
-
+ global_cap = conf.getint("scheduler",
"partition_mapper_max_downstream_keys")
for target_dag in partition_dags:
if TYPE_CHECKING:
assert partition_key is not None
@@ -573,9 +572,8 @@ class AssetManager(LoggingMixin):
try:
# We'll need to catch every possible exception happen when
mapping partition_key.
- target_key = timetable.get_partition_mapper(
- name=asset_model.name, uri=asset_model.uri
- ).to_downstream(partition_key)
+ mapper = timetable.get_partition_mapper(name=asset_model.name,
uri=asset_model.uri)
+ target_key = mapper.to_downstream(partition_key)
except Exception as err:
log.exception(
"Could not map partition key for asset in target Dag. "
@@ -607,6 +605,14 @@ class AssetManager(LoggingMixin):
target_keys = [target_key]
del target_key
+ mapper_cap = mapper.max_downstream_keys
+ if mapper_cap is not None:
+ max_downstream_keys = mapper_cap
+ cap_source = f"max_downstream_keys={mapper_cap}"
+ else:
+ max_downstream_keys = global_cap
+ cap_source = f"[scheduler] partition_mapper_max_downstream_keys={global_cap}"
+
if len(target_keys) > max_downstream_keys:
log.error(
"Partition mapper produced more downstream keys than allowed; skipping queue.",
@@ -615,6 +621,7 @@ class AssetManager(LoggingMixin):
target_dag=target_dag.dag_id,
produced_keys=len(target_keys),
max_downstream_keys=max_downstream_keys,
+ cap_source=cap_source,
)
session.add(
Log(
@@ -624,7 +631,7 @@ class AssetManager(LoggingMixin):
f"uri='{asset_model.uri}') in target Dag '{target_dag.dag_id}' "
f"produced {len(target_keys)} downstream keys from "
f"partition_key='{partition_key}', exceeding "
- f"[scheduler] partition_mapper_max_downstream_keys={max_downstream_keys}. "
+ f"{cap_source}. "
f"No Dag runs were queued for this event."
),
task_instance=task_instance,
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index 9ed0433c6c4..eebe2d8c951 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -2690,13 +2690,14 @@ scheduler:
see_also: ":ref:`scheduler:ha:tunables`"
partition_mapper_max_downstream_keys:
description: |
- Maximum number of downstream partition keys a single
``PartitionMapper``
- invocation may produce. When any partition mapper (built-in or custom)
- expands one upstream key into more keys than this limit, the scheduler
- skips queuing the runs for that asset event and logs an error against
- the source task instance. This guards against a misconfigured
- ``PartitionMapper`` from queuing an unbounded number of Dag runs per
- upstream event.
+ Maximum number of downstream partition keys produced by a single
+ PartitionMapper invocation, applied to any PartitionMapper that returns
+ multiple keys (e.g. FanOutMapper). When a mapper instance sets a per-instance
+ ``max_downstream_keys`` parameter, that value completely overrides this global
+ cap for that instance — including when the per-mapper value exceeds this
+ global. **Deployment managers cannot enforce this setting as a hard
+ cluster-wide ceiling**; treat this value as a default that user code may
+ override.
version_added: 3.3.0
type: integer
example: ~
diff --git a/airflow-core/src/airflow/partition_mappers/allowed_key.py
b/airflow-core/src/airflow/partition_mappers/allowed_key.py
index 8b560f426aa..ff475eef0d2 100644
--- a/airflow-core/src/airflow/partition_mappers/allowed_key.py
+++ b/airflow-core/src/airflow/partition_mappers/allowed_key.py
@@ -25,7 +25,8 @@ from airflow.partition_mappers.base import PartitionMapper
class AllowedKeyMapper(PartitionMapper):
"""Partition mapper that validates keys against a set of allowed keys."""
- def __init__(self, allowed_keys: list[str]) -> None:
+ def __init__(self, allowed_keys: list[str], *, max_downstream_keys: int |
None = None) -> None:
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.allowed_keys = allowed_keys
def to_downstream(self, key: str) -> str:
@@ -34,8 +35,11 @@ class AllowedKeyMapper(PartitionMapper):
return key
def serialize(self) -> dict[str, Any]:
- return {"allowed_keys": self.allowed_keys}
+ data: dict[str, Any] = {"allowed_keys": self.allowed_keys}
+ if self.max_downstream_keys is not None:
+ data["max_downstream_keys"] = self.max_downstream_keys
+ return data
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
- return cls(allowed_keys=data["allowed_keys"])
+ return cls(allowed_keys=data["allowed_keys"],
max_downstream_keys=data.get("max_downstream_keys"))
diff --git a/airflow-core/src/airflow/partition_mappers/base.py
b/airflow-core/src/airflow/partition_mappers/base.py
index ee70842cdcf..4226eb44916 100644
--- a/airflow-core/src/airflow/partition_mappers/base.py
+++ b/airflow-core/src/airflow/partition_mappers/base.py
@@ -36,6 +36,15 @@ class PartitionMapper(ABC):
is_rollup: ClassVar[bool] = False
+ def __init__(self, *, max_downstream_keys: int | None = None) -> None:
+ if max_downstream_keys is not None and (
+ not isinstance(max_downstream_keys, int) or max_downstream_keys < 1
+ ):
+ raise ValueError(
+ f"max_downstream_keys must be a positive integer or None, got
{max_downstream_keys!r}"
+ )
+ self.max_downstream_keys = max_downstream_keys
+
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
decode_overridden = cls.decode_downstream is not
PartitionMapper.decode_downstream
@@ -107,11 +116,13 @@ class PartitionMapper(ABC):
return None
def serialize(self) -> dict[str, Any]:
- return {}
+ if self.max_downstream_keys is None:
+ return {}
+ return {"max_downstream_keys": self.max_downstream_keys}
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
- return cls()
+ return cls(max_downstream_keys=data.get("max_downstream_keys"))
class RollupMapper(PartitionMapper):
@@ -126,7 +137,9 @@ class RollupMapper(PartitionMapper):
is_rollup: ClassVar[bool] = True
- def __init__(self, *, upstream_mapper: PartitionMapper, window: Window) ->
None:
+ def __init__(
+ self, *, upstream_mapper: PartitionMapper, window: Window,
max_downstream_keys: int | None = None
+ ) -> None:
decode_overridden = type(upstream_mapper).decode_downstream is not
PartitionMapper.decode_downstream
if not decode_overridden and window.expected_decoded_type is not str:
raise TypeError(
@@ -138,6 +151,7 @@ class RollupMapper(PartitionMapper):
f"{window.expected_decoded_type.__name__}, or use a window
whose "
f"'expected_decoded_type' accepts str."
)
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.upstream_mapper = upstream_mapper
self.window = window
@@ -160,10 +174,13 @@ class RollupMapper(PartitionMapper):
def serialize(self) -> dict[str, Any]:
from airflow.serialization.encoders import encode_partition_mapper,
encode_window
- return {
+ data: dict[str, Any] = {
"upstream_mapper": encode_partition_mapper(self.upstream_mapper),
"window": encode_window(self.window),
}
+ if self.max_downstream_keys is not None:
+ data["max_downstream_keys"] = self.max_downstream_keys
+ return data
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -172,6 +189,7 @@ class RollupMapper(PartitionMapper):
return cls(
upstream_mapper=decode_partition_mapper(data["upstream_mapper"]),
window=decode_window(data["window"]),
+ max_downstream_keys=data.get("max_downstream_keys"),
)
diff --git a/airflow-core/src/airflow/partition_mappers/chain.py
b/airflow-core/src/airflow/partition_mappers/chain.py
index a4a2110c256..6c0749feba2 100644
--- a/airflow-core/src/airflow/partition_mappers/chain.py
+++ b/airflow-core/src/airflow/partition_mappers/chain.py
@@ -35,7 +35,9 @@ class ChainMapper(PartitionMapper):
mapper1: PartitionMapper,
/,
*mappers: PartitionMapper,
+ max_downstream_keys: int | None = None,
) -> None:
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.mappers = [mapper0, mapper1, *mappers]
def to_downstream(self, key: str) -> str | Iterable[str]:
@@ -70,11 +72,14 @@ class ChainMapper(PartitionMapper):
def serialize(self) -> dict[str, Any]:
from airflow.serialization.encoders import encode_partition_mapper
- return {"mappers": [encode_partition_mapper(m) for m in self.mappers]}
+ result: dict[str, Any] = {"mappers": [encode_partition_mapper(m) for m
in self.mappers]}
+ if self.max_downstream_keys is not None:
+ result["max_downstream_keys"] = self.max_downstream_keys
+ return result
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
from airflow.serialization.decoders import decode_partition_mapper
mappers = [decode_partition_mapper(m) for m in data["mappers"]]
- return cls(*mappers)
+ return cls(*mappers,
max_downstream_keys=data.get("max_downstream_keys"))
diff --git a/airflow-core/src/airflow/partition_mappers/product.py
b/airflow-core/src/airflow/partition_mappers/product.py
index 5d8882bbbcc..e7c91aad082 100644
--- a/airflow-core/src/airflow/partition_mappers/product.py
+++ b/airflow-core/src/airflow/partition_mappers/product.py
@@ -32,7 +32,9 @@ class ProductMapper(PartitionMapper):
/,
*mappers: PartitionMapper,
delimiter: str = "|",
+ max_downstream_keys: int | None = None,
) -> None:
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.mappers = [mapper0, mapper1, *mappers]
self.delimiter = delimiter
@@ -54,14 +56,21 @@ class ProductMapper(PartitionMapper):
def serialize(self) -> dict[str, Any]:
from airflow.serialization.encoders import encode_partition_mapper
- return {
+ result: dict[str, Any] = {
"delimiter": self.delimiter,
"mappers": [encode_partition_mapper(m) for m in self.mappers],
}
+ if self.max_downstream_keys is not None:
+ result["max_downstream_keys"] = self.max_downstream_keys
+ return result
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
from airflow.serialization.decoders import decode_partition_mapper
mappers = [decode_partition_mapper(m) for m in data["mappers"]]
- return cls(*mappers, delimiter=data.get("delimiter", "|"))
+ return cls(
+ *mappers,
+ delimiter=data.get("delimiter", "|"),
+ max_downstream_keys=data.get("max_downstream_keys"),
+ )
diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py
b/airflow-core/src/airflow/partition_mappers/temporal.py
index 622f0c6a556..1938a18f84d 100644
--- a/airflow-core/src/airflow/partition_mappers/temporal.py
+++ b/airflow-core/src/airflow/partition_mappers/temporal.py
@@ -163,7 +163,9 @@ class _BaseTemporalMapper(PartitionMapper):
timezone: str | Timezone | FixedTimezone = "UTC",
input_format: str = "%Y-%m-%dT%H:%M:%S",
output_format: str | None = None,
+ max_downstream_keys: int | None = None,
):
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.input_format = input_format
self.output_format = output_format or self.default_output_format
if isinstance(timezone, str):
@@ -230,11 +232,14 @@ class _BaseTemporalMapper(PartitionMapper):
def serialize(self) -> dict[str, Any]:
from airflow.serialization.encoders import encode_timezone
- return {
+ result: dict[str, Any] = {
"timezone": encode_timezone(self._timezone),
"input_format": self.input_format,
"output_format": self.output_format,
}
+ if self.max_downstream_keys is not None:
+ result["max_downstream_keys"] = self.max_downstream_keys
+ return result
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -242,6 +247,7 @@ class _BaseTemporalMapper(PartitionMapper):
timezone=parse_timezone(data.get("timezone", "UTC")),
input_format=data["input_format"],
output_format=data["output_format"],
+ max_downstream_keys=data.get("max_downstream_keys"),
)
@@ -286,6 +292,7 @@ class StartOfWeekMapper(_BaseTemporalMapper):
timezone: str | Timezone | FixedTimezone = "UTC",
input_format: str = "%Y-%m-%dT%H:%M:%S",
output_format: str | None = None,
+ max_downstream_keys: int | None = None,
) -> None:
"""
Compile *output_format* eagerly so malformed patterns raise here.
@@ -302,7 +309,12 @@ class StartOfWeekMapper(_BaseTemporalMapper):
**must** include ``%Y``, ``%m``, and ``%d`` so the week-start date
can be recovered for ``to_upstream``.
"""
- super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
+ super().__init__(
+ timezone=timezone,
+ input_format=input_format,
+ output_format=output_format,
+ max_downstream_keys=max_downstream_keys,
+ )
# %V (ISO week) cannot be round-tripped through strptime without %G+%u,
# so derive a named-group regex from output_format and pull out
%Y/%m/%d.
# Compile eagerly so a malformed output_format raises ValueError here
@@ -358,6 +370,7 @@ class StartOfQuarterMapper(_BaseTemporalMapper):
timezone: str | Timezone | FixedTimezone = "UTC",
input_format: str = "%Y-%m-%dT%H:%M:%S",
output_format: str | None = None,
+ max_downstream_keys: int | None = None,
) -> None:
"""
Compile *output_format* eagerly so malformed patterns raise here.
@@ -376,7 +389,12 @@ class StartOfQuarterMapper(_BaseTemporalMapper):
and ``{quarter}`` so the quarter-start date can be recovered for
``to_upstream``.
"""
- super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
+ super().__init__(
+ timezone=timezone,
+ input_format=input_format,
+ output_format=output_format,
+ max_downstream_keys=max_downstream_keys,
+ )
# ``{quarter}`` is a Python-format placeholder, not a strftime
directive,
# so derive a named-group regex from output_format that handles both.
# Compile eagerly so a malformed output_format raises ValueError here
@@ -511,7 +529,9 @@ class FanOutMapper(PartitionMapper):
upstream_mapper: PartitionMapper,
window: Window,
downstream_mapper: PartitionMapper | None = None,
+ max_downstream_keys: int | None = None,
) -> None:
+ super().__init__(max_downstream_keys=max_downstream_keys)
self.upstream_mapper = upstream_mapper
self.window = window
self.downstream_mapper = downstream_mapper or
self._resolve_default_downstream_mapper(window)
@@ -537,11 +557,14 @@ class FanOutMapper(PartitionMapper):
def serialize(self) -> dict[str, Any]:
from airflow.serialization.encoders import encode_partition_mapper,
encode_window
- return {
+ result: dict[str, Any] = {
"upstream_mapper": encode_partition_mapper(self.upstream_mapper),
"window": encode_window(self.window),
"downstream_mapper":
encode_partition_mapper(self.downstream_mapper),
}
+ if self.max_downstream_keys is not None:
+ result["max_downstream_keys"] = self.max_downstream_keys
+ return result
@classmethod
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -551,6 +574,7 @@ class FanOutMapper(PartitionMapper):
upstream_mapper=decode_partition_mapper(data["upstream_mapper"]),
window=decode_window(data["window"]),
downstream_mapper=decode_partition_mapper(data["downstream_mapper"]),
+ max_downstream_keys=data.get("max_downstream_keys"),
)
diff --git a/airflow-core/src/airflow/serialization/encoders.py
b/airflow-core/src/airflow/serialization/encoders.py
index 0c2f030da8f..3e804393b53 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -461,11 +461,17 @@ class _Serializer:
@serialize_partition_mapper.register
def _(self, partition_mapper: ChainMapper) -> dict[str, Any]:
- return {"mappers": [encode_partition_mapper(m) for m in
partition_mapper.mappers]}
+ data: dict[str, Any] = {"mappers": [encode_partition_mapper(m) for m
in partition_mapper.mappers]}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
- return {}
+ data: dict[str, Any] = {}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]:
@@ -486,37 +492,52 @@ class _Serializer:
| StartOfQuarterMapper
| StartOfYearMapper,
) -> dict[str, Any]:
- return {
+ data: dict[str, Any] = {
"timezone": encode_timezone(partition_mapper._timezone),
"input_format": partition_mapper.input_format,
"output_format": partition_mapper.output_format,
}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: ProductMapper) -> dict[str, Any]:
- return {
+ data: dict[str, Any] = {
"delimiter": partition_mapper.delimiter,
"mappers": [encode_partition_mapper(m) for m in
partition_mapper.mappers],
}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: AllowedKeyMapper) -> dict[str, Any]:
- return {"allowed_keys": partition_mapper.allowed_keys}
+ data: dict[str, Any] = {"allowed_keys": partition_mapper.allowed_keys}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: RollupMapper) -> dict[str, Any]:
- return {
+ data: dict[str, Any] = {
"upstream_mapper":
encode_partition_mapper(partition_mapper.upstream_mapper),
"window": encode_window(partition_mapper.window),
}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
@serialize_partition_mapper.register
def _(self, partition_mapper: FanOutMapper) -> dict[str, Any]:
- return {
+ data: dict[str, Any] = {
"upstream_mapper":
encode_partition_mapper(partition_mapper.upstream_mapper),
"window": encode_window(partition_mapper.window),
"downstream_mapper":
encode_partition_mapper(partition_mapper.downstream_mapper),
}
+ if partition_mapper.max_downstream_keys is not None:
+ data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+ return data
BUILTIN_WINDOWS: dict[type, str] = {
HourWindow: "airflow.partition_mappers.window.HourWindow",
diff --git a/airflow-core/tests/unit/assets/test_manager.py
b/airflow-core/tests/unit/assets/test_manager.py
index 47da4d0db2e..687f4d17847 100644
--- a/airflow-core/tests/unit/assets/test_manager.py
+++ b/airflow-core/tests/unit/assets/test_manager.py
@@ -473,6 +473,173 @@ class TestAssetManager:
assert error_call.kwargs["source_partition_key"] ==
"2024-06-03T00:00:00"
assert error_call.kwargs["produced_keys"] == 7
assert error_call.kwargs["max_downstream_keys"] == cap
+ assert error_call.kwargs["cap_source"] == f"[scheduler]
partition_mapper_max_downstream_keys={cap}"
+
+ @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "100"})
+ @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+ def test_partition_fanout_per_mapper_override_stricter_than_global_trips(
+ self, session, dag_maker, mock_task_instance
+ ):
+ """Per-mapper max_downstream_keys=3 trips even when the global cap is
100.
+
+ Proves the per-mapper override takes precedence over a more permissive
global.
+ The Log.extra must mention 'max_downstream_keys=3' and must NOT mention
+ 'partition_mapper_max_downstream_keys' (i.e. the global cap name is
absent from the message).
+ """
+ clear_db_apdr()
+ clear_db_pakl()
+ clear_db_logs()
+
+ asset_def = Asset(uri="s3://bucket/per_mapper_strict",
name="per_mapper_strict")
+ # WeekWindow produces 7 daily keys; per-mapper cap of 3 must trip
first.
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow(), max_downstream_keys=3)
+ with dag_maker(
+ dag_id="per_mapper_strict_dag",
+ schedule=PartitionedAssetTimetable(assets=asset_def,
partition_mapper_config={asset_def: mapper}),
+ serialized=True,
+ ):
+ EmptyOperator(task_id="t")
+ dag_maker.create_dagrun()
+ dag_maker.sync_dagbag_to_db()
+
+ with mock.patch("airflow.assets.manager.log") as mock_log:
+ AssetManager.register_asset_change(
+ task_instance=mock_task_instance,
+ asset=asset_def,
+ session=session,
+ partition_key="2024-06-03T00:00:00",
+ )
+ session.flush()
+
+ assert
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0
+ log_extras = session.scalars(select(Log.extra).where(Log.event ==
"partition fan-out exceeded")).all()
+ assert len(log_extras) == 1
+ assert "max_downstream_keys=3" in log_extras[0]
+ assert "partition_mapper_max_downstream_keys" not in log_extras[0]
+ # Pin the scheduler-log error kwargs for the per-mapper path
symmetrically
+ # with the global-cap path in test_partition_fan_out_cap.
+ mock_log.error.assert_called_once()
+ error_call = mock_log.error.call_args
+ assert error_call.kwargs["cap_source"] == "max_downstream_keys=3"
+
+ @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "3"})
+ @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+ def test_partition_fanout_per_mapper_override_looser_than_global_permits(
+ self, session, dag_maker, mock_task_instance
+ ):
+ """Per-mapper max_downstream_keys=10 permits 7 keys even when the
global cap is 3.
+
+ Proves the per-mapper override can relax, not just tighten, the cap.
+ """
+ clear_db_apdr()
+ clear_db_pakl()
+ clear_db_logs()
+
+ asset_def = Asset(uri="s3://bucket/per_mapper_loose",
name="per_mapper_loose")
+ # Global cap of 3 would block the 7-key WeekWindow fanout, but
per-mapper
+ # max_downstream_keys=10 overrides it and all 7 rows must be queued.
+ mapper = FanOutMapper(
+ upstream_mapper=StartOfWeekMapper(), window=WeekWindow(),
max_downstream_keys=10
+ )
+ with dag_maker(
+ dag_id="per_mapper_loose_dag",
+ schedule=PartitionedAssetTimetable(assets=asset_def,
partition_mapper_config={asset_def: mapper}),
+ serialized=True,
+ ):
+ EmptyOperator(task_id="t")
+ dag_maker.create_dagrun()
+ dag_maker.sync_dagbag_to_db()
+
+ AssetManager.register_asset_change(
+ task_instance=mock_task_instance,
+ asset=asset_def,
+ session=session,
+ partition_key="2024-06-03T00:00:00",
+ )
+ session.flush()
+
+ assert
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 7
+ assert (
+ session.scalar(
+ select(func.count()).select_from(Log).where(Log.event ==
"partition fan-out exceeded")
+ )
+ == 0
+ )
+
+ @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "1"})
+ @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+ def test_partition_fanout_per_mapper_at_cap_is_allowed(self, session,
dag_maker, mock_task_instance):
+ """Per-mapper max_downstream_keys=7 with a 7-key fanout: exactly at
cap is allowed.
+
+ Pairs with test_partition_fanout_per_mapper_one_over_cap_trips to pin
the
+ boundary at '>' (not '>=') on the per-mapper branch.
+ """
+ clear_db_apdr()
+ clear_db_pakl()
+ clear_db_logs()
+
+ asset_def = Asset(uri="s3://bucket/per_mapper_at_cap",
name="per_mapper_at_cap")
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow(), max_downstream_keys=7)
+ with dag_maker(
+ dag_id="per_mapper_at_cap_dag",
+ schedule=PartitionedAssetTimetable(assets=asset_def,
partition_mapper_config={asset_def: mapper}),
+ serialized=True,
+ ):
+ EmptyOperator(task_id="t")
+ dag_maker.create_dagrun()
+ dag_maker.sync_dagbag_to_db()
+
+ AssetManager.register_asset_change(
+ task_instance=mock_task_instance,
+ asset=asset_def,
+ session=session,
+ partition_key="2024-06-03T00:00:00",
+ )
+ session.flush()
+
+ assert
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 7
+ assert (
+ session.scalar(
+ select(func.count()).select_from(Log).where(Log.event ==
"partition fan-out exceeded")
+ )
+ == 0
+ )
+
+ @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "1"})
+ @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+ def test_partition_fanout_per_mapper_one_over_cap_trips(self, session,
dag_maker, mock_task_instance):
+ """Per-mapper max_downstream_keys=6 with a 7-key fanout: one over cap
trips the guard.
+
+ Pairs with test_partition_fanout_per_mapper_at_cap_is_allowed to lock
the
+ boundary: 7 keys at cap=7 is allowed, but 7 keys at cap=6 is not.
+ """
+ clear_db_apdr()
+ clear_db_pakl()
+ clear_db_logs()
+
+ asset_def = Asset(uri="s3://bucket/per_mapper_over_cap",
name="per_mapper_over_cap")
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow(), max_downstream_keys=6)
+ with dag_maker(
+ dag_id="per_mapper_over_cap_dag",
+ schedule=PartitionedAssetTimetable(assets=asset_def,
partition_mapper_config={asset_def: mapper}),
+ serialized=True,
+ ):
+ EmptyOperator(task_id="t")
+ dag_maker.create_dagrun()
+ dag_maker.sync_dagbag_to_db()
+
+ AssetManager.register_asset_change(
+ task_instance=mock_task_instance,
+ asset=asset_def,
+ session=session,
+ partition_key="2024-06-03T00:00:00",
+ )
+ session.flush()
+
+ assert
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0
+ log_extras = session.scalars(select(Log.extra).where(Log.event ==
"partition fan-out exceeded")).all()
+ assert len(log_extras) == 1
+ assert "max_downstream_keys=6" in log_extras[0]
def _make_dag(dag_id: str) -> DagModel:
diff --git a/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
index a04b22e48e9..6fb47b5be31 100644
--- a/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
+++ b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
@@ -19,6 +19,9 @@ from __future__ import annotations
import pytest
from airflow.partition_mappers.allowed_key import AllowedKeyMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestAllowedKeyMapper:
@@ -46,3 +49,17 @@ class TestAllowedKeyMapper:
assert pm.serialize() == {"allowed_keys": []}
with pytest.raises(ValueError, match="not in allowed keys"):
pm.to_downstream("any")
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+
+ mapper = AllowedKeyMapper(["us", "eu", "apac", "latam", "africa"],
max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+
+ mapper = AllowedKeyMapper(["us", "eu"])
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_base.py
b/airflow-core/tests/unit/partition_mappers/test_base.py
index a313acd7e1c..e7b0ba02f07 100644
--- a/airflow-core/tests/unit/partition_mappers/test_base.py
+++ b/airflow-core/tests/unit/partition_mappers/test_base.py
@@ -16,9 +16,17 @@
# under the License.
from __future__ import annotations
+import re
+
import pytest
-from airflow.partition_mappers.base import PartitionMapper
+from airflow.partition_mappers.base import PartitionMapper, RollupMapper
+from airflow.partition_mappers.identity import IdentityMapper
+from airflow.partition_mappers.temporal import StartOfDayMapper
+from airflow.partition_mappers.window import DayWindow
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestPartitionMapperInitSubclass:
@@ -125,3 +133,50 @@ class TestRollupMapperInit:
# Should not raise.
RollupMapper(upstream_mapper=_StringOnlyMapper(),
window=_AlphaWindow())
+
+
+class TestPartitionMapperMaxDownstreamKeysValidator:
+ """Verify the max_downstream_keys validator on the PartitionMapper base
class.
+
+ Uses IdentityMapper as the most lightweight concrete subclass — the
+ validator lives on the base class so any subclass exercises it.
+ """
+
+ def test_max_downstream_keys_none_is_accepted(self):
+ """Default (None) leaves max_downstream_keys as None."""
+ mapper = IdentityMapper()
+ assert mapper.max_downstream_keys is None
+
+ def test_max_downstream_keys_one_is_accepted(self):
+ """Minimum positive integer value is accepted."""
+ mapper = IdentityMapper(max_downstream_keys=1)
+ assert mapper.max_downstream_keys == 1
+
+ @pytest.mark.parametrize(
+ "bad_value",
+ [
+ pytest.param(0, id="zero"),
+ pytest.param(-1, id="negative"),
+ pytest.param(1.0, id="float"),
+ pytest.param("5", id="string"),
+ ],
+ )
+ def test_max_downstream_keys_invalid_raises(self, bad_value):
+ """Reject non-positive-integer values with the full validator
message."""
+ with pytest.raises(
+ ValueError,
+ match=re.escape(f"max_downstream_keys must be a positive integer
or None, got {bad_value!r}"),
+ ):
+ IdentityMapper(max_downstream_keys=bad_value)
+
+
+class TestRollupMapperMaxDownstreamKeys:
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ mapper = RollupMapper(upstream_mapper=StartOfDayMapper(),
window=DayWindow(), max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ mapper = RollupMapper(upstream_mapper=StartOfDayMapper(),
window=DayWindow())
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_chain.py
b/airflow-core/tests/unit/partition_mappers/test_chain.py
index 528a176b1e0..a70230b2230 100644
--- a/airflow-core/tests/unit/partition_mappers/test_chain.py
+++ b/airflow-core/tests/unit/partition_mappers/test_chain.py
@@ -25,6 +25,9 @@ from airflow.partition_mappers.base import PartitionMapper
from airflow.partition_mappers.chain import ChainMapper
from airflow.partition_mappers.identity import IdentityMapper
from airflow.partition_mappers.temporal import StartOfDayMapper,
StartOfHourMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class _InvalidReturnMapper(PartitionMapper):
@@ -69,8 +72,6 @@ class TestChainMapper:
sm.to_downstream("key")
def test_serialize(self):
- from airflow.serialization.encoders import encode_partition_mapper
-
sm = ChainMapper(StartOfHourMapper(),
StartOfDayMapper(input_format="%Y-%m-%dT%H"))
result = sm.serialize()
assert result == {
@@ -87,3 +88,17 @@ class TestChainMapper:
assert isinstance(restored, ChainMapper)
assert len(restored.mappers) == 2
assert restored.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+ mapper = ChainMapper(
+ StartOfHourMapper(), StartOfDayMapper(input_format="%Y-%m-%dT%H"),
max_downstream_keys=5
+ )
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+ mapper = ChainMapper(StartOfHourMapper(),
StartOfDayMapper(input_format="%Y-%m-%dT%H"))
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_fan_out.py
b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
index 17a455a4bbc..d36fea48a63 100644
--- a/airflow-core/tests/unit/partition_mappers/test_fan_out.py
+++ b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
@@ -39,6 +39,9 @@ from airflow.partition_mappers.window import (
Window,
YearWindow,
)
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestFanOutMapper:
@@ -311,3 +314,26 @@ class TestFanOutMapper:
assert list(restored.to_downstream("2024-03-04T00:00:00")) == list(
mapper.to_downstream("2024-03-04T00:00:00")
)
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow(), max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow())
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
+
+ def test_max_downstream_keys_defaults_to_none_when_absent(self):
+ """A payload lacking max_downstream_keys (e.g. serialized by an older
Airflow) decodes to None.
+
+ This is the real backward-compatibility path: the scheduler reads
+ ``mapper.max_downstream_keys`` directly, so the attribute must always
+ exist — ``deserialize`` defaults it to None when the key is absent.
+ """
+ mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=WeekWindow())
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys is None
diff --git a/airflow-core/tests/unit/partition_mappers/test_identity.py
b/airflow-core/tests/unit/partition_mappers/test_identity.py
index 282be34bfcc..9a97e583db5 100644
--- a/airflow-core/tests/unit/partition_mappers/test_identity.py
+++ b/airflow-core/tests/unit/partition_mappers/test_identity.py
@@ -17,6 +17,9 @@
from __future__ import annotations
from airflow.partition_mappers.identity import IdentityMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestIdentityMapper:
@@ -30,3 +33,15 @@ class TestIdentityMapper:
def test_deserialize(self):
assert isinstance(IdentityMapper.deserialize({}), IdentityMapper)
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+ mapper = IdentityMapper(max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+ mapper = IdentityMapper()
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_product.py
b/airflow-core/tests/unit/partition_mappers/test_product.py
index 17269f261ea..9d6eba1eb21 100644
--- a/airflow-core/tests/unit/partition_mappers/test_product.py
+++ b/airflow-core/tests/unit/partition_mappers/test_product.py
@@ -22,6 +22,9 @@ import pytest
from airflow.partition_mappers.identity import IdentityMapper
from airflow.partition_mappers.product import ProductMapper
from airflow.partition_mappers.temporal import StartOfDayMapper,
StartOfHourMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestProductMapper:
@@ -49,8 +52,6 @@ class TestProductMapper:
pm.to_downstream("2024-01-15T10:30:00::2024-01-15T10:30:00::extra")
def test_serialize(self):
- from airflow.serialization.encoders import encode_partition_mapper
-
pm = ProductMapper(StartOfHourMapper(), StartOfDayMapper())
result = pm.serialize()
assert result == {
@@ -62,8 +63,6 @@ class TestProductMapper:
}
def test_serialize_custom_delimiter(self):
- from airflow.serialization.encoders import encode_partition_mapper
-
pm = ProductMapper(StartOfHourMapper(), StartOfDayMapper(),
delimiter="::")
result = pm.serialize()
assert result == {
@@ -95,8 +94,6 @@ class TestProductMapper:
def test_deserialize_backward_compat(self):
"""Deserializing data without delimiter field defaults to '|'."""
- from airflow.serialization.encoders import encode_partition_mapper
-
data = {
"mappers": [
encode_partition_mapper(StartOfHourMapper()),
@@ -111,3 +108,15 @@ class TestProductMapper:
assert (
pm.to_downstream("2024-01-15T10:30:00|2024-01-15T10:30:00|raw") ==
"2024-01-15T10|2024-01-15|raw"
)
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+ mapper = ProductMapper(StartOfHourMapper(), StartOfDayMapper(),
max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+ mapper = ProductMapper(StartOfHourMapper(), StartOfDayMapper())
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py
b/airflow-core/tests/unit/partition_mappers/test_temporal.py
index ab2a79f52ce..e777017b51a 100644
--- a/airflow-core/tests/unit/partition_mappers/test_temporal.py
+++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py
@@ -20,6 +20,7 @@ from datetime import datetime, timezone as dt_timezone
import pendulum
import pytest
+from pendulum.tz.exceptions import InvalidTimezone
from airflow import sdk
from airflow.partition_mappers.base import RollupMapper
@@ -38,6 +39,7 @@ from airflow.partition_mappers.temporal import (
from airflow.partition_mappers.window import HourWindow, WeekWindow
from airflow.serialization.decoders import decode_partition_mapper
from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
class TestTemporalMappers:
@@ -297,8 +299,6 @@ class TestSdkTemporalMappersTimezoneSerialization:
def test_sdk_constructor_invalid_timezone_raises_eagerly(self,
sdk_mapper_name):
"""Passing an unknown timezone string must raise at construction time
(via ``parse_timezone``), not silently fall back to UTC or fail
later."""
- from pendulum.tz.exceptions import InvalidTimezone
-
sdk_cls = getattr(sdk, sdk_mapper_name)
with pytest.raises(InvalidTimezone):
sdk_cls(timezone="Not/A/Real/Zone")
@@ -447,3 +447,19 @@ class TestToPartitionDateDelegation:
)
def test_to_partition_date(self, mapper, downstream_key, expected):
assert mapper.to_partition_date(downstream_key) == expected
+
+
+class TestTemporalMapperMaxDownstreamKeys:
+ """Round-trip and zero-bloat tests for max_downstream_keys on temporal
mappers."""
+
+ def test_max_downstream_keys_encode_decode_roundtrip(self):
+ """max_downstream_keys=5 survives encode_partition_mapper →
decode_partition_mapper."""
+ mapper = StartOfWeekMapper(max_downstream_keys=5)
+ restored = decode_partition_mapper(encode_partition_mapper(mapper))
+ assert restored.max_downstream_keys == 5
+
+ def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+ """max_downstream_keys must NOT appear in the encoded payload when not
set (zero-bloat contract)."""
+ mapper = StartOfWeekMapper()
+ encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+ assert "max_downstream_keys" not in encoded_var
diff --git
a/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
index 7d860bd8796..cc21144d42b 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations
+import attrs
+
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
[email protected]
class AllowedKeyMapper(PartitionMapper):
"""Partition mapper that validates keys against a set of allowed keys."""
- def __init__(self, allowed_keys: list[str]) -> None:
- self.allowed_keys = allowed_keys
+ allowed_keys: list[str]
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
index fc26450f027..351ad4e94fa 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
@@ -18,10 +18,18 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
+import attrs
+
if TYPE_CHECKING:
from airflow.sdk.definitions.partition_mappers.window import Window
+def _validate_max_downstream_keys(instance, attribute, value):
+ if value is not None and (not isinstance(value, int) or value < 1):
+ raise ValueError(f"max_downstream_keys must be a positive integer or
None, got {value!r}")
+
+
[email protected]
class PartitionMapper:
"""
Base partition mapper class.
@@ -36,7 +44,12 @@ class PartitionMapper:
#: Temporal mappers override to ``datetime``.
expected_decoded_type: ClassVar[type] = str
+ max_downstream_keys: int | None = attrs.field(
+ default=None, kw_only=True, validator=_validate_max_downstream_keys
+ )
+
[email protected]
class RollupMapper(PartitionMapper):
"""
Partition mapper that rolls up many upstream keys into one downstream key.
@@ -49,21 +62,22 @@ class RollupMapper(PartitionMapper):
is_rollup: ClassVar[bool] = True
- def __init__(self, *, upstream_mapper: PartitionMapper, window: Window) ->
None:
+ upstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+ window: Window = attrs.field(kw_only=True)
+
+ def __attrs_post_init__(self) -> None:
# Mirrors the core-side ``RollupMapper.__init__`` check so user code
# ``from airflow.sdk import RollupMapper`` fails at Dag parse time
rather
# than slipping through to the scheduler tick (where the
misconfiguration
# would otherwise be swallowed by the bare ``except`` in
# ``_create_dagruns_for_partitioned_asset_dags`` and surface only as
# "Failed to deserialize Dag" spam).
- if upstream_mapper.expected_decoded_type is str and
window.expected_decoded_type is not str:
+ if self.upstream_mapper.expected_decoded_type is str and
self.window.expected_decoded_type is not str:
raise TypeError(
- f"{type(window).__name__} expects decoded values of type "
- f"{window.expected_decoded_type.__name__!r}, but "
- f"{type(upstream_mapper).__name__} decodes to 'str' (SDK
PartitionMapper default). "
+ f"{type(self.window).__name__} expects decoded values of type "
+ f"{self.window.expected_decoded_type.__name__!r}, but "
+ f"{type(self.upstream_mapper).__name__} decodes to 'str' (SDK
PartitionMapper default). "
f"Pair the window with an upstream mapper whose
'expected_decoded_type' is "
- f"{window.expected_decoded_type.__name__}, or use a window
whose "
+ f"{self.window.expected_decoded_type.__name__}, or use a
window whose "
f"'expected_decoded_type' accepts str."
)
- self.upstream_mapper = upstream_mapper
- self.window = window
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
index 0f96e2d2eaf..6ac3b78d1e4 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
@@ -16,17 +16,23 @@
# under the License.
from __future__ import annotations
+import attrs
+
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
[email protected](init=False)
class ChainMapper(PartitionMapper):
"""Partition mapper that applies multiple mappers sequentially."""
+ mappers: list[PartitionMapper]
+
def __init__(
self,
mapper0: PartitionMapper,
mapper1: PartitionMapper,
/,
*mappers: PartitionMapper,
+ max_downstream_keys: int | None = None,
) -> None:
- self.mappers = [mapper0, mapper1, *mappers]
+ self.__attrs_init__(mappers=[mapper0, mapper1, *mappers],
max_downstream_keys=max_downstream_keys)
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
index ffc744ff71e..2ed96dbd31f 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
@@ -16,12 +16,18 @@
# under the License.
from __future__ import annotations
+import attrs
+
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
[email protected](init=False)
class ProductMapper(PartitionMapper):
"""Partition mapper that combines multiple mappers into a
multi-dimensional key."""
+ mappers: list[PartitionMapper]
+ delimiter: str
+
def __init__(
self,
mapper0: PartitionMapper,
@@ -29,6 +35,10 @@ class ProductMapper(PartitionMapper):
/,
*mappers: PartitionMapper,
delimiter: str = "|",
+ max_downstream_keys: int | None = None,
) -> None:
- self.mappers = [mapper0, mapper1, *mappers]
- self.delimiter = delimiter
+ self.__attrs_init__(
+ mappers=[mapper0, mapper1, *mappers],
+ delimiter=delimiter,
+ max_downstream_keys=max_downstream_keys,
+ )
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
index 3eaf4e19d74..0cf0a705d5c 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
@@ -19,6 +19,8 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, ClassVar
+import attrs
+
from airflow.sdk._shared.timezones.timezone import parse_timezone
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
@@ -28,24 +30,31 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.partition_mappers.window import Window
+def _timezone_converter(value: str | Timezone | FixedTimezone) -> Timezone |
FixedTimezone:
+ if isinstance(value, str):
+ return parse_timezone(value)
+ return value
+
+
[email protected]
class _BaseTemporalMapper(PartitionMapper):
"""Base class for Temporal Partition Mappers."""
- default_output_format: str
+ default_output_format: ClassVar[str]
expected_decoded_type: ClassVar[type] = datetime
- def __init__(
- self,
- *,
- timezone: str | Timezone | FixedTimezone = "UTC",
- input_format: str = "%Y-%m-%dT%H:%M:%S",
- output_format: str | None = None,
- ) -> None:
- self.input_format = input_format
- self.output_format = output_format or self.default_output_format
- if isinstance(timezone, str):
- timezone = parse_timezone(timezone)
- self._timezone = timezone
+ _timezone: str | Timezone | FixedTimezone = attrs.field(
+ alias="timezone",
+ default="UTC",
+ kw_only=True,
+ converter=_timezone_converter,
+ )
+ input_format: str = attrs.field(default="%Y-%m-%dT%H:%M:%S", kw_only=True)
+ output_format: str | None = attrs.field(default=None, kw_only=True)
+
+ def __attrs_post_init__(self) -> None:
+ if not self.output_format:
+ self.output_format = self.default_output_format
class StartOfHourMapper(_BaseTemporalMapper):
@@ -108,6 +117,7 @@ class StartOfYearMapper(_BaseTemporalMapper):
default_output_format = "%Y"
[email protected](init=False)
class FanOutMapper(PartitionMapper):
"""
Partition mapper that fans one upstream key out into multiple downstream
keys.
@@ -161,6 +171,25 @@ class FanOutMapper(PartitionMapper):
"YearWindow": StartOfMonthMapper,
}
+ upstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+ window: Window = attrs.field(kw_only=True)
+ downstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+
+ def __init__(
+ self,
+ *,
+ upstream_mapper: PartitionMapper,
+ window: Window,
+ downstream_mapper: PartitionMapper | None = None,
+ max_downstream_keys: int | None = None,
+ ) -> None:
+ self.__attrs_init__(
+ upstream_mapper=upstream_mapper,
+ window=window,
+ downstream_mapper=downstream_mapper or
type(self)._resolve_default_downstream_mapper(window),
+ max_downstream_keys=max_downstream_keys,
+ )
+
@classmethod
def _resolve_default_downstream_mapper(cls, window: Window) ->
PartitionMapper:
"""
@@ -179,14 +208,3 @@ class FanOutMapper(PartitionMapper):
f"{type(window).__name__}; pass downstream_mapper explicitly."
)
return mapper_cls()
-
- def __init__(
- self,
- *,
- upstream_mapper: PartitionMapper,
- window: Window,
- downstream_mapper: PartitionMapper | None = None,
- ) -> None:
- self.upstream_mapper = upstream_mapper
- self.window = window
- self.downstream_mapper = downstream_mapper or
self._resolve_default_downstream_mapper(window)
diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
index 7114b65d6c4..ca811115cb7 100644
--- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
+++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
@@ -23,6 +23,7 @@ import pytest
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper,
RollupMapper
from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
+from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper
from airflow.sdk.definitions.partition_mappers.window import (
DayWindow,
@@ -96,6 +97,39 @@ class TestSdkDirectionValidation:
WeekWindow(direction=bad_value)
+class TestSdkPartitionMapperMaxDownstreamKeysValidator:
+ """Verify the max_downstream_keys attrs field validator on the SDK
PartitionMapper base.
+
+ Uses IdentityMapper as the most lightweight concrete subclass — the
+ validator lives on the base class so any subclass exercises it.
+ Mirrors core's TestPartitionMapperMaxDownstreamKeysValidator (6 cases).
+ """
+
+ def test_max_downstream_keys_none_is_accepted(self):
+ """Default (None) leaves max_downstream_keys as None."""
+ mapper = IdentityMapper()
+ assert mapper.max_downstream_keys is None
+
+ def test_max_downstream_keys_one_is_accepted(self):
+ """Minimum positive integer value is accepted."""
+ mapper = IdentityMapper(max_downstream_keys=1)
+ assert mapper.max_downstream_keys == 1
+
+ @pytest.mark.parametrize(
+ "bad_value",
+ [
+ pytest.param(0, id="zero"),
+ pytest.param(-1, id="negative"),
+ pytest.param(1.0, id="float"),
+ pytest.param("5", id="string"),
+ ],
+ )
+ def test_max_downstream_keys_invalid_raises(self, bad_value):
+ """0, negative integers, floats, and strings are all rejected."""
+ with pytest.raises(ValueError, match="max_downstream_keys"):
+ IdentityMapper(max_downstream_keys=bad_value) # type:
ignore[arg-type]
+
+
class TestSdkWindowExpectedDecodedType:
"""Each SDK temporal window must declare ``datetime`` so the validation
lines up with core."""