This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b7247b13219 feat(core): add per-mapper max_fan_out override for 
partition fan-out cap (#67184)
b7247b13219 is described below

commit b7247b132190afd4d21a3221414ed49a6030d175
Author: Wei Lee <[email protected]>
AuthorDate: Thu Jun 11 10:26:40 2026 +0800

    feat(core): add per-mapper max_fan_out override for partition fan-out cap 
(#67184)
---
 airflow-core/newsfragments/67184.feature.rst       |   1 +
 airflow-core/src/airflow/assets/manager.py         |  19 ++-
 .../src/airflow/config_templates/config.yml        |  15 +-
 .../src/airflow/partition_mappers/allowed_key.py   |  10 +-
 airflow-core/src/airflow/partition_mappers/base.py |  26 +++-
 .../src/airflow/partition_mappers/chain.py         |   9 +-
 .../src/airflow/partition_mappers/product.py       |  13 +-
 .../src/airflow/partition_mappers/temporal.py      |  32 +++-
 airflow-core/src/airflow/serialization/encoders.py |  35 ++++-
 airflow-core/tests/unit/assets/test_manager.py     | 167 +++++++++++++++++++++
 .../unit/partition_mappers/test_allowed_key.py     |  17 +++
 .../tests/unit/partition_mappers/test_base.py      |  57 ++++++-
 .../tests/unit/partition_mappers/test_chain.py     |  19 ++-
 .../tests/unit/partition_mappers/test_fan_out.py   |  26 ++++
 .../tests/unit/partition_mappers/test_identity.py  |  15 ++
 .../tests/unit/partition_mappers/test_product.py   |  21 ++-
 .../tests/unit/partition_mappers/test_temporal.py  |  20 ++-
 .../definitions/partition_mappers/allowed_key.py   |   6 +-
 .../sdk/definitions/partition_mappers/base.py      |  30 +++-
 .../sdk/definitions/partition_mappers/chain.py     |   8 +-
 .../sdk/definitions/partition_mappers/product.py   |  14 +-
 .../sdk/definitions/partition_mappers/temporal.py  |  66 +++++---
 .../task_sdk/definitions/test_partition_mappers.py |  34 +++++
 23 files changed, 577 insertions(+), 83 deletions(-)

diff --git a/airflow-core/newsfragments/67184.feature.rst 
b/airflow-core/newsfragments/67184.feature.rst
new file mode 100644
index 00000000000..759e6921d51
--- /dev/null
+++ b/airflow-core/newsfragments/67184.feature.rst
@@ -0,0 +1 @@
+Add ``max_downstream_keys`` parameter to ``PartitionMapper`` to override 
``[scheduler] partition_mapper_max_downstream_keys`` per mapper instance.
diff --git a/airflow-core/src/airflow/assets/manager.py 
b/airflow-core/src/airflow/assets/manager.py
index 5f9e1b598ab..f72c533c5a0 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -552,8 +552,7 @@ class AssetManager(LoggingMixin):
             )
             return
 
-        max_downstream_keys = conf.getint("scheduler", 
"partition_mapper_max_downstream_keys")
-
+        global_cap = conf.getint("scheduler", 
"partition_mapper_max_downstream_keys")
         for target_dag in partition_dags:
             if TYPE_CHECKING:
                 assert partition_key is not None
@@ -573,9 +572,8 @@ class AssetManager(LoggingMixin):
 
             try:
                 # We'll need to catch every possible exception happen when 
mapping partition_key.
-                target_key = timetable.get_partition_mapper(
-                    name=asset_model.name, uri=asset_model.uri
-                ).to_downstream(partition_key)
+                mapper = timetable.get_partition_mapper(name=asset_model.name, 
uri=asset_model.uri)
+                target_key = mapper.to_downstream(partition_key)
             except Exception as err:
                 log.exception(
                     "Could not map partition key for asset in target Dag. "
@@ -607,6 +605,14 @@ class AssetManager(LoggingMixin):
                 target_keys = [target_key]
             del target_key
 
+            mapper_cap = mapper.max_downstream_keys
+            if mapper_cap is not None:
+                max_downstream_keys = mapper_cap
+                cap_source = f"max_downstream_keys={mapper_cap}"
+            else:
+                max_downstream_keys = global_cap
+                cap_source = f"[scheduler] 
partition_mapper_max_downstream_keys={global_cap}"
+
             if len(target_keys) > max_downstream_keys:
                 log.error(
                     "Partition mapper produced more downstream keys than 
allowed; skipping queue.",
@@ -615,6 +621,7 @@ class AssetManager(LoggingMixin):
                     target_dag=target_dag.dag_id,
                     produced_keys=len(target_keys),
                     max_downstream_keys=max_downstream_keys,
+                    cap_source=cap_source,
                 )
                 session.add(
                     Log(
@@ -624,7 +631,7 @@ class AssetManager(LoggingMixin):
                             f"uri='{asset_model.uri}') in target Dag 
'{target_dag.dag_id}' "
                             f"produced {len(target_keys)} downstream keys from 
"
                             f"partition_key='{partition_key}', exceeding "
-                            f"[scheduler] 
partition_mapper_max_downstream_keys={max_downstream_keys}. "
+                            f"{cap_source}. "
                             f"No Dag runs were queued for this event."
                         ),
                         task_instance=task_instance,
diff --git a/airflow-core/src/airflow/config_templates/config.yml 
b/airflow-core/src/airflow/config_templates/config.yml
index 9ed0433c6c4..eebe2d8c951 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -2690,13 +2690,14 @@ scheduler:
       see_also: ":ref:`scheduler:ha:tunables`"
     partition_mapper_max_downstream_keys:
       description: |
-        Maximum number of downstream partition keys a single 
``PartitionMapper``
-        invocation may produce. When any partition mapper (built-in or custom)
-        expands one upstream key into more keys than this limit, the scheduler
-        skips queuing the runs for that asset event and logs an error against
-        the source task instance. This guards against a misconfigured
-        ``PartitionMapper`` from queuing an unbounded number of Dag runs per
-        upstream event.
+        Maximum number of downstream partition keys a single ``PartitionMapper``
+        invocation may produce. When a mapper exceeds the effective limit, the
+        scheduler skips queuing runs for that asset event and logs an error
+        against the source task instance. A mapper instance that sets its own
+        ``max_downstream_keys`` parameter fully overrides this global cap for that
+        instance, even when its value is larger than this one. **Deployment
+        managers cannot enforce this setting as a hard cluster-wide ceiling**;
+        treat this value as a default that user code may override.
       version_added: 3.3.0
       type: integer
       example: ~
diff --git a/airflow-core/src/airflow/partition_mappers/allowed_key.py 
b/airflow-core/src/airflow/partition_mappers/allowed_key.py
index 8b560f426aa..ff475eef0d2 100644
--- a/airflow-core/src/airflow/partition_mappers/allowed_key.py
+++ b/airflow-core/src/airflow/partition_mappers/allowed_key.py
@@ -25,7 +25,8 @@ from airflow.partition_mappers.base import PartitionMapper
 class AllowedKeyMapper(PartitionMapper):
     """Partition mapper that validates keys against a set of allowed keys."""
 
-    def __init__(self, allowed_keys: list[str]) -> None:
+    def __init__(self, allowed_keys: list[str], *, max_downstream_keys: int | 
None = None) -> None:
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.allowed_keys = allowed_keys
 
     def to_downstream(self, key: str) -> str:
@@ -34,8 +35,11 @@ class AllowedKeyMapper(PartitionMapper):
         return key
 
     def serialize(self) -> dict[str, Any]:
-        return {"allowed_keys": self.allowed_keys}
+        data: dict[str, Any] = {"allowed_keys": self.allowed_keys}
+        if self.max_downstream_keys is not None:
+            data["max_downstream_keys"] = self.max_downstream_keys
+        return data
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
-        return cls(allowed_keys=data["allowed_keys"])
+        return cls(allowed_keys=data["allowed_keys"], 
max_downstream_keys=data.get("max_downstream_keys"))
diff --git a/airflow-core/src/airflow/partition_mappers/base.py 
b/airflow-core/src/airflow/partition_mappers/base.py
index ee70842cdcf..4226eb44916 100644
--- a/airflow-core/src/airflow/partition_mappers/base.py
+++ b/airflow-core/src/airflow/partition_mappers/base.py
@@ -36,6 +36,15 @@ class PartitionMapper(ABC):
 
     is_rollup: ClassVar[bool] = False
 
+    def __init__(self, *, max_downstream_keys: int | None = None) -> None:
+        if max_downstream_keys is not None and (
+            not isinstance(max_downstream_keys, int) or max_downstream_keys < 1
+        ):
+            raise ValueError(
+                f"max_downstream_keys must be a positive integer or None, got 
{max_downstream_keys!r}"
+            )
+        self.max_downstream_keys = max_downstream_keys
+
     def __init_subclass__(cls, **kwargs: Any) -> None:
         super().__init_subclass__(**kwargs)
         decode_overridden = cls.decode_downstream is not 
PartitionMapper.decode_downstream
@@ -107,11 +116,13 @@ class PartitionMapper(ABC):
         return None
 
     def serialize(self) -> dict[str, Any]:
-        return {}
+        if self.max_downstream_keys is None:
+            return {}
+        return {"max_downstream_keys": self.max_downstream_keys}
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
-        return cls()
+        return cls(max_downstream_keys=data.get("max_downstream_keys"))
 
 
 class RollupMapper(PartitionMapper):
@@ -126,7 +137,9 @@ class RollupMapper(PartitionMapper):
 
     is_rollup: ClassVar[bool] = True
 
-    def __init__(self, *, upstream_mapper: PartitionMapper, window: Window) -> 
None:
+    def __init__(
+        self, *, upstream_mapper: PartitionMapper, window: Window, 
max_downstream_keys: int | None = None
+    ) -> None:
         decode_overridden = type(upstream_mapper).decode_downstream is not 
PartitionMapper.decode_downstream
         if not decode_overridden and window.expected_decoded_type is not str:
             raise TypeError(
@@ -138,6 +151,7 @@ class RollupMapper(PartitionMapper):
                 f"{window.expected_decoded_type.__name__}, or use a window 
whose "
                 f"'expected_decoded_type' accepts str."
             )
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.upstream_mapper = upstream_mapper
         self.window = window
 
@@ -160,10 +174,13 @@ class RollupMapper(PartitionMapper):
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.encoders import encode_partition_mapper, 
encode_window
 
-        return {
+        data: dict[str, Any] = {
             "upstream_mapper": encode_partition_mapper(self.upstream_mapper),
             "window": encode_window(self.window),
         }
+        if self.max_downstream_keys is not None:
+            data["max_downstream_keys"] = self.max_downstream_keys
+        return data
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -172,6 +189,7 @@ class RollupMapper(PartitionMapper):
         return cls(
             upstream_mapper=decode_partition_mapper(data["upstream_mapper"]),
             window=decode_window(data["window"]),
+            max_downstream_keys=data.get("max_downstream_keys"),
         )
 
 
diff --git a/airflow-core/src/airflow/partition_mappers/chain.py 
b/airflow-core/src/airflow/partition_mappers/chain.py
index a4a2110c256..6c0749feba2 100644
--- a/airflow-core/src/airflow/partition_mappers/chain.py
+++ b/airflow-core/src/airflow/partition_mappers/chain.py
@@ -35,7 +35,9 @@ class ChainMapper(PartitionMapper):
         mapper1: PartitionMapper,
         /,
         *mappers: PartitionMapper,
+        max_downstream_keys: int | None = None,
     ) -> None:
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.mappers = [mapper0, mapper1, *mappers]
 
     def to_downstream(self, key: str) -> str | Iterable[str]:
@@ -70,11 +72,14 @@ class ChainMapper(PartitionMapper):
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.encoders import encode_partition_mapper
 
-        return {"mappers": [encode_partition_mapper(m) for m in self.mappers]}
+        result: dict[str, Any] = {"mappers": [encode_partition_mapper(m) for m 
in self.mappers]}
+        if self.max_downstream_keys is not None:
+            result["max_downstream_keys"] = self.max_downstream_keys
+        return result
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
         from airflow.serialization.decoders import decode_partition_mapper
 
         mappers = [decode_partition_mapper(m) for m in data["mappers"]]
-        return cls(*mappers)
+        return cls(*mappers, 
max_downstream_keys=data.get("max_downstream_keys"))
diff --git a/airflow-core/src/airflow/partition_mappers/product.py 
b/airflow-core/src/airflow/partition_mappers/product.py
index 5d8882bbbcc..e7c91aad082 100644
--- a/airflow-core/src/airflow/partition_mappers/product.py
+++ b/airflow-core/src/airflow/partition_mappers/product.py
@@ -32,7 +32,9 @@ class ProductMapper(PartitionMapper):
         /,
         *mappers: PartitionMapper,
         delimiter: str = "|",
+        max_downstream_keys: int | None = None,
     ) -> None:
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.mappers = [mapper0, mapper1, *mappers]
         self.delimiter = delimiter
 
@@ -54,14 +56,21 @@ class ProductMapper(PartitionMapper):
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.encoders import encode_partition_mapper
 
-        return {
+        result: dict[str, Any] = {
             "delimiter": self.delimiter,
             "mappers": [encode_partition_mapper(m) for m in self.mappers],
         }
+        if self.max_downstream_keys is not None:
+            result["max_downstream_keys"] = self.max_downstream_keys
+        return result
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
         from airflow.serialization.decoders import decode_partition_mapper
 
         mappers = [decode_partition_mapper(m) for m in data["mappers"]]
-        return cls(*mappers, delimiter=data.get("delimiter", "|"))
+        return cls(
+            *mappers,
+            delimiter=data.get("delimiter", "|"),
+            max_downstream_keys=data.get("max_downstream_keys"),
+        )
diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py 
b/airflow-core/src/airflow/partition_mappers/temporal.py
index 622f0c6a556..1938a18f84d 100644
--- a/airflow-core/src/airflow/partition_mappers/temporal.py
+++ b/airflow-core/src/airflow/partition_mappers/temporal.py
@@ -163,7 +163,9 @@ class _BaseTemporalMapper(PartitionMapper):
         timezone: str | Timezone | FixedTimezone = "UTC",
         input_format: str = "%Y-%m-%dT%H:%M:%S",
         output_format: str | None = None,
+        max_downstream_keys: int | None = None,
     ):
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.input_format = input_format
         self.output_format = output_format or self.default_output_format
         if isinstance(timezone, str):
@@ -230,11 +232,14 @@ class _BaseTemporalMapper(PartitionMapper):
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.encoders import encode_timezone
 
-        return {
+        result: dict[str, Any] = {
             "timezone": encode_timezone(self._timezone),
             "input_format": self.input_format,
             "output_format": self.output_format,
         }
+        if self.max_downstream_keys is not None:
+            result["max_downstream_keys"] = self.max_downstream_keys
+        return result
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -242,6 +247,7 @@ class _BaseTemporalMapper(PartitionMapper):
             timezone=parse_timezone(data.get("timezone", "UTC")),
             input_format=data["input_format"],
             output_format=data["output_format"],
+            max_downstream_keys=data.get("max_downstream_keys"),
         )
 
 
@@ -286,6 +292,7 @@ class StartOfWeekMapper(_BaseTemporalMapper):
         timezone: str | Timezone | FixedTimezone = "UTC",
         input_format: str = "%Y-%m-%dT%H:%M:%S",
         output_format: str | None = None,
+        max_downstream_keys: int | None = None,
     ) -> None:
         """
         Compile *output_format* eagerly so malformed patterns raise here.
@@ -302,7 +309,12 @@ class StartOfWeekMapper(_BaseTemporalMapper):
             **must** include ``%Y``, ``%m``, and ``%d`` so the week-start date
             can be recovered for ``to_upstream``.
         """
-        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
+        super().__init__(
+            timezone=timezone,
+            input_format=input_format,
+            output_format=output_format,
+            max_downstream_keys=max_downstream_keys,
+        )
         # %V (ISO week) cannot be round-tripped through strptime without %G+%u,
         # so derive a named-group regex from output_format and pull out 
%Y/%m/%d.
         # Compile eagerly so a malformed output_format raises ValueError here
@@ -358,6 +370,7 @@ class StartOfQuarterMapper(_BaseTemporalMapper):
         timezone: str | Timezone | FixedTimezone = "UTC",
         input_format: str = "%Y-%m-%dT%H:%M:%S",
         output_format: str | None = None,
+        max_downstream_keys: int | None = None,
     ) -> None:
         """
         Compile *output_format* eagerly so malformed patterns raise here.
@@ -376,7 +389,12 @@ class StartOfQuarterMapper(_BaseTemporalMapper):
             and ``{quarter}`` so the quarter-start date can be recovered for
             ``to_upstream``.
         """
-        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
+        super().__init__(
+            timezone=timezone,
+            input_format=input_format,
+            output_format=output_format,
+            max_downstream_keys=max_downstream_keys,
+        )
         # ``{quarter}`` is a Python-format placeholder, not a strftime 
directive,
         # so derive a named-group regex from output_format that handles both.
         # Compile eagerly so a malformed output_format raises ValueError here
@@ -511,7 +529,9 @@ class FanOutMapper(PartitionMapper):
         upstream_mapper: PartitionMapper,
         window: Window,
         downstream_mapper: PartitionMapper | None = None,
+        max_downstream_keys: int | None = None,
     ) -> None:
+        super().__init__(max_downstream_keys=max_downstream_keys)
         self.upstream_mapper = upstream_mapper
         self.window = window
         self.downstream_mapper = downstream_mapper or 
self._resolve_default_downstream_mapper(window)
@@ -537,11 +557,14 @@ class FanOutMapper(PartitionMapper):
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.encoders import encode_partition_mapper, 
encode_window
 
-        return {
+        result: dict[str, Any] = {
             "upstream_mapper": encode_partition_mapper(self.upstream_mapper),
             "window": encode_window(self.window),
             "downstream_mapper": 
encode_partition_mapper(self.downstream_mapper),
         }
+        if self.max_downstream_keys is not None:
+            result["max_downstream_keys"] = self.max_downstream_keys
+        return result
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
@@ -551,6 +574,7 @@ class FanOutMapper(PartitionMapper):
             upstream_mapper=decode_partition_mapper(data["upstream_mapper"]),
             window=decode_window(data["window"]),
             
downstream_mapper=decode_partition_mapper(data["downstream_mapper"]),
+            max_downstream_keys=data.get("max_downstream_keys"),
         )
 
 
diff --git a/airflow-core/src/airflow/serialization/encoders.py 
b/airflow-core/src/airflow/serialization/encoders.py
index 0c2f030da8f..3e804393b53 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -461,11 +461,17 @@ class _Serializer:
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: ChainMapper) -> dict[str, Any]:
-        return {"mappers": [encode_partition_mapper(m) for m in 
partition_mapper.mappers]}
+        data: dict[str, Any] = {"mappers": [encode_partition_mapper(m) for m 
in partition_mapper.mappers]}
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
-        return {}
+        data: dict[str, Any] = {}
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]:
@@ -486,37 +492,52 @@ class _Serializer:
         | StartOfQuarterMapper
         | StartOfYearMapper,
     ) -> dict[str, Any]:
-        return {
+        data: dict[str, Any] = {
             "timezone": encode_timezone(partition_mapper._timezone),
             "input_format": partition_mapper.input_format,
             "output_format": partition_mapper.output_format,
         }
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: ProductMapper) -> dict[str, Any]:
-        return {
+        data: dict[str, Any] = {
             "delimiter": partition_mapper.delimiter,
             "mappers": [encode_partition_mapper(m) for m in 
partition_mapper.mappers],
         }
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: AllowedKeyMapper) -> dict[str, Any]:
-        return {"allowed_keys": partition_mapper.allowed_keys}
+        data: dict[str, Any] = {"allowed_keys": partition_mapper.allowed_keys}
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: RollupMapper) -> dict[str, Any]:
-        return {
+        data: dict[str, Any] = {
             "upstream_mapper": 
encode_partition_mapper(partition_mapper.upstream_mapper),
             "window": encode_window(partition_mapper.window),
         }
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: FanOutMapper) -> dict[str, Any]:
-        return {
+        data: dict[str, Any] = {
             "upstream_mapper": 
encode_partition_mapper(partition_mapper.upstream_mapper),
             "window": encode_window(partition_mapper.window),
             "downstream_mapper": 
encode_partition_mapper(partition_mapper.downstream_mapper),
         }
+        if partition_mapper.max_downstream_keys is not None:
+            data["max_downstream_keys"] = partition_mapper.max_downstream_keys
+        return data
 
     BUILTIN_WINDOWS: dict[type, str] = {
         HourWindow: "airflow.partition_mappers.window.HourWindow",
diff --git a/airflow-core/tests/unit/assets/test_manager.py 
b/airflow-core/tests/unit/assets/test_manager.py
index 47da4d0db2e..687f4d17847 100644
--- a/airflow-core/tests/unit/assets/test_manager.py
+++ b/airflow-core/tests/unit/assets/test_manager.py
@@ -473,6 +473,173 @@ class TestAssetManager:
         assert error_call.kwargs["source_partition_key"] == 
"2024-06-03T00:00:00"
         assert error_call.kwargs["produced_keys"] == 7
         assert error_call.kwargs["max_downstream_keys"] == cap
+        assert error_call.kwargs["cap_source"] == f"[scheduler] 
partition_mapper_max_downstream_keys={cap}"
+
+    @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "100"})
+    @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+    def test_partition_fanout_per_mapper_override_stricter_than_global_trips(
+        self, session, dag_maker, mock_task_instance
+    ):
+        """Per-mapper max_downstream_keys=3 trips even when the global cap is 
100.
+
+        Proves the per-mapper override takes precedence over a more permissive 
global.
+        The Log.extra must mention 'max_downstream_keys=3' and must NOT mention
+        'partition_mapper_max_downstream_keys' (i.e. the global cap name is 
absent from the message).
+        """
+        clear_db_apdr()
+        clear_db_pakl()
+        clear_db_logs()
+
+        asset_def = Asset(uri="s3://bucket/per_mapper_strict", 
name="per_mapper_strict")
+        # WeekWindow produces 7 daily keys; per-mapper cap of 3 must trip 
first.
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow(), max_downstream_keys=3)
+        with dag_maker(
+            dag_id="per_mapper_strict_dag",
+            schedule=PartitionedAssetTimetable(assets=asset_def, 
partition_mapper_config={asset_def: mapper}),
+            serialized=True,
+        ):
+            EmptyOperator(task_id="t")
+        dag_maker.create_dagrun()
+        dag_maker.sync_dagbag_to_db()
+
+        with mock.patch("airflow.assets.manager.log") as mock_log:
+            AssetManager.register_asset_change(
+                task_instance=mock_task_instance,
+                asset=asset_def,
+                session=session,
+                partition_key="2024-06-03T00:00:00",
+            )
+            session.flush()
+
+        assert 
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0
+        log_extras = session.scalars(select(Log.extra).where(Log.event == 
"partition fan-out exceeded")).all()
+        assert len(log_extras) == 1
+        assert "max_downstream_keys=3" in log_extras[0]
+        assert "partition_mapper_max_downstream_keys" not in log_extras[0]
+        # Pin the scheduler-log error kwargs for the per-mapper path 
symmetrically
+        # with the global-cap path in test_partition_fan_out_cap.
+        mock_log.error.assert_called_once()
+        error_call = mock_log.error.call_args
+        assert error_call.kwargs["cap_source"] == "max_downstream_keys=3"
+
+    @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "3"})
+    @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+    def test_partition_fanout_per_mapper_override_looser_than_global_permits(
+        self, session, dag_maker, mock_task_instance
+    ):
+        """Per-mapper max_downstream_keys=10 permits 7 keys even when the 
global cap is 3.
+
+        Proves the per-mapper override can relax, not just tighten, the cap.
+        """
+        clear_db_apdr()
+        clear_db_pakl()
+        clear_db_logs()
+
+        asset_def = Asset(uri="s3://bucket/per_mapper_loose", 
name="per_mapper_loose")
+        # Global cap of 3 would block the 7-key WeekWindow fanout, but 
per-mapper
+        # max_downstream_keys=10 overrides it and all 7 rows must be queued.
+        mapper = FanOutMapper(
+            upstream_mapper=StartOfWeekMapper(), window=WeekWindow(), 
max_downstream_keys=10
+        )
+        with dag_maker(
+            dag_id="per_mapper_loose_dag",
+            schedule=PartitionedAssetTimetable(assets=asset_def, 
partition_mapper_config={asset_def: mapper}),
+            serialized=True,
+        ):
+            EmptyOperator(task_id="t")
+        dag_maker.create_dagrun()
+        dag_maker.sync_dagbag_to_db()
+
+        AssetManager.register_asset_change(
+            task_instance=mock_task_instance,
+            asset=asset_def,
+            session=session,
+            partition_key="2024-06-03T00:00:00",
+        )
+        session.flush()
+
+        assert 
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 7
+        assert (
+            session.scalar(
+                select(func.count()).select_from(Log).where(Log.event == 
"partition fan-out exceeded")
+            )
+            == 0
+        )
+
+    @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "1"})
+    @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+    def test_partition_fanout_per_mapper_at_cap_is_allowed(self, session, 
dag_maker, mock_task_instance):
+        """Per-mapper max_downstream_keys=7 with a 7-key fanout: exactly at 
cap is allowed.
+
+        Pairs with test_partition_fanout_per_mapper_one_over_cap_trips to pin 
the
+        boundary at '>' (not '>=') on the per-mapper branch.
+        """
+        clear_db_apdr()
+        clear_db_pakl()
+        clear_db_logs()
+
+        asset_def = Asset(uri="s3://bucket/per_mapper_at_cap", 
name="per_mapper_at_cap")
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow(), max_downstream_keys=7)
+        with dag_maker(
+            dag_id="per_mapper_at_cap_dag",
+            schedule=PartitionedAssetTimetable(assets=asset_def, 
partition_mapper_config={asset_def: mapper}),
+            serialized=True,
+        ):
+            EmptyOperator(task_id="t")
+        dag_maker.create_dagrun()
+        dag_maker.sync_dagbag_to_db()
+
+        AssetManager.register_asset_change(
+            task_instance=mock_task_instance,
+            asset=asset_def,
+            session=session,
+            partition_key="2024-06-03T00:00:00",
+        )
+        session.flush()
+
+        assert 
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 7
+        assert (
+            session.scalar(
+                select(func.count()).select_from(Log).where(Log.event == 
"partition fan-out exceeded")
+            )
+            == 0
+        )
+
+    @conf_vars({("scheduler", "partition_mapper_max_downstream_keys"): "1"})
+    @pytest.mark.usefixtures("clear_assets", "testing_dag_bundle")
+    def test_partition_fanout_per_mapper_one_over_cap_trips(self, session, 
dag_maker, mock_task_instance):
+        """Per-mapper max_downstream_keys=6 with a 7-key fanout: one over cap 
trips the guard.
+
+        Pairs with test_partition_fanout_per_mapper_at_cap_is_allowed to lock 
the
+        boundary: 7 keys at cap=7 is allowed, but 7 keys at cap=6 is not.
+        """
+        clear_db_apdr()
+        clear_db_pakl()
+        clear_db_logs()
+
+        asset_def = Asset(uri="s3://bucket/per_mapper_over_cap", 
name="per_mapper_over_cap")
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow(), max_downstream_keys=6)
+        with dag_maker(
+            dag_id="per_mapper_over_cap_dag",
+            schedule=PartitionedAssetTimetable(assets=asset_def, 
partition_mapper_config={asset_def: mapper}),
+            serialized=True,
+        ):
+            EmptyOperator(task_id="t")
+        dag_maker.create_dagrun()
+        dag_maker.sync_dagbag_to_db()
+
+        AssetManager.register_asset_change(
+            task_instance=mock_task_instance,
+            asset=asset_def,
+            session=session,
+            partition_key="2024-06-03T00:00:00",
+        )
+        session.flush()
+
+        assert 
session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0
+        log_extras = session.scalars(select(Log.extra).where(Log.event == 
"partition fan-out exceeded")).all()
+        assert len(log_extras) == 1
+        assert "max_downstream_keys=6" in log_extras[0]
 
 
 def _make_dag(dag_id: str) -> DagModel:
diff --git a/airflow-core/tests/unit/partition_mappers/test_allowed_key.py 
b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
index a04b22e48e9..6fb47b5be31 100644
--- a/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
+++ b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py
@@ -19,6 +19,9 @@ from __future__ import annotations
 import pytest
 
 from airflow.partition_mappers.allowed_key import AllowedKeyMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestAllowedKeyMapper:
@@ -46,3 +49,17 @@ class TestAllowedKeyMapper:
         assert pm.serialize() == {"allowed_keys": []}
         with pytest.raises(ValueError, match="not in allowed keys"):
             pm.to_downstream("any")
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+
+        mapper = AllowedKeyMapper(["us", "eu", "apac", "latam", "africa"], 
max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+
+        mapper = AllowedKeyMapper(["us", "eu"])
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_base.py 
b/airflow-core/tests/unit/partition_mappers/test_base.py
index a313acd7e1c..e7b0ba02f07 100644
--- a/airflow-core/tests/unit/partition_mappers/test_base.py
+++ b/airflow-core/tests/unit/partition_mappers/test_base.py
@@ -16,9 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import re
+
 import pytest
 
-from airflow.partition_mappers.base import PartitionMapper
+from airflow.partition_mappers.base import PartitionMapper, RollupMapper
+from airflow.partition_mappers.identity import IdentityMapper
+from airflow.partition_mappers.temporal import StartOfDayMapper
+from airflow.partition_mappers.window import DayWindow
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestPartitionMapperInitSubclass:
@@ -125,3 +133,50 @@ class TestRollupMapperInit:
 
         # Should not raise.
         RollupMapper(upstream_mapper=_StringOnlyMapper(), 
window=_AlphaWindow())
+
+
+class TestPartitionMapperMaxDownstreamKeysValidator:
+    """Verify the max_downstream_keys validator on the PartitionMapper base 
class.
+
+    Uses IdentityMapper as the most lightweight concrete subclass — the
+    validator lives on the base class so any subclass exercises it.
+    """
+
+    def test_max_downstream_keys_none_is_accepted(self):
+        """Default (None) leaves max_downstream_keys as None."""
+        mapper = IdentityMapper()
+        assert mapper.max_downstream_keys is None
+
+    def test_max_downstream_keys_one_is_accepted(self):
+        """Minimum positive integer value is accepted."""
+        mapper = IdentityMapper(max_downstream_keys=1)
+        assert mapper.max_downstream_keys == 1
+
+    @pytest.mark.parametrize(
+        "bad_value",
+        [
+            pytest.param(0, id="zero"),
+            pytest.param(-1, id="negative"),
+            pytest.param(1.0, id="float"),
+            pytest.param("5", id="string"),
+        ],
+    )
+    def test_max_downstream_keys_invalid_raises(self, bad_value):
+        """Reject non-positive-integer values with the full validator 
message."""
+        with pytest.raises(
+            ValueError,
+            match=re.escape(f"max_downstream_keys must be a positive integer 
or None, got {bad_value!r}"),
+        ):
+            IdentityMapper(max_downstream_keys=bad_value)
+
+
+class TestRollupMapperMaxDownstreamKeys:
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        mapper = RollupMapper(upstream_mapper=StartOfDayMapper(), 
window=DayWindow(), max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        mapper = RollupMapper(upstream_mapper=StartOfDayMapper(), 
window=DayWindow())
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_chain.py 
b/airflow-core/tests/unit/partition_mappers/test_chain.py
index 528a176b1e0..a70230b2230 100644
--- a/airflow-core/tests/unit/partition_mappers/test_chain.py
+++ b/airflow-core/tests/unit/partition_mappers/test_chain.py
@@ -25,6 +25,9 @@ from airflow.partition_mappers.base import PartitionMapper
 from airflow.partition_mappers.chain import ChainMapper
 from airflow.partition_mappers.identity import IdentityMapper
 from airflow.partition_mappers.temporal import StartOfDayMapper, 
StartOfHourMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class _InvalidReturnMapper(PartitionMapper):
@@ -69,8 +72,6 @@ class TestChainMapper:
             sm.to_downstream("key")
 
     def test_serialize(self):
-        from airflow.serialization.encoders import encode_partition_mapper
-
         sm = ChainMapper(StartOfHourMapper(), 
StartOfDayMapper(input_format="%Y-%m-%dT%H"))
         result = sm.serialize()
         assert result == {
@@ -87,3 +88,17 @@ class TestChainMapper:
         assert isinstance(restored, ChainMapper)
         assert len(restored.mappers) == 2
         assert restored.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+        mapper = ChainMapper(
+            StartOfHourMapper(), StartOfDayMapper(input_format="%Y-%m-%dT%H"), 
max_downstream_keys=5
+        )
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+        mapper = ChainMapper(StartOfHourMapper(), 
StartOfDayMapper(input_format="%Y-%m-%dT%H"))
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_fan_out.py 
b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
index 17a455a4bbc..d36fea48a63 100644
--- a/airflow-core/tests/unit/partition_mappers/test_fan_out.py
+++ b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
@@ -39,6 +39,9 @@ from airflow.partition_mappers.window import (
     Window,
     YearWindow,
 )
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestFanOutMapper:
@@ -311,3 +314,26 @@ class TestFanOutMapper:
         assert list(restored.to_downstream("2024-03-04T00:00:00")) == list(
             mapper.to_downstream("2024-03-04T00:00:00")
         )
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow(), max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow())
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
+
+    def test_max_downstream_keys_defaults_to_none_when_absent(self):
+        """A payload lacking max_downstream_keys (e.g. serialized by an older 
Airflow) decodes to None.
+
+        This is the real backward-compatibility path: the scheduler reads
+        ``mapper.max_downstream_keys`` directly, so the attribute must always
+        exist — ``deserialize`` defaults it to None when the key is absent.
+        """
+        mapper = FanOutMapper(upstream_mapper=StartOfWeekMapper(), 
window=WeekWindow())
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys is None
diff --git a/airflow-core/tests/unit/partition_mappers/test_identity.py 
b/airflow-core/tests/unit/partition_mappers/test_identity.py
index 282be34bfcc..9a97e583db5 100644
--- a/airflow-core/tests/unit/partition_mappers/test_identity.py
+++ b/airflow-core/tests/unit/partition_mappers/test_identity.py
@@ -17,6 +17,9 @@
 from __future__ import annotations
 
 from airflow.partition_mappers.identity import IdentityMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestIdentityMapper:
@@ -30,3 +33,15 @@ class TestIdentityMapper:
 
     def test_deserialize(self):
         assert isinstance(IdentityMapper.deserialize({}), IdentityMapper)
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+        mapper = IdentityMapper(max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+        mapper = IdentityMapper()
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_product.py 
b/airflow-core/tests/unit/partition_mappers/test_product.py
index 17269f261ea..9d6eba1eb21 100644
--- a/airflow-core/tests/unit/partition_mappers/test_product.py
+++ b/airflow-core/tests/unit/partition_mappers/test_product.py
@@ -22,6 +22,9 @@ import pytest
 from airflow.partition_mappers.identity import IdentityMapper
 from airflow.partition_mappers.product import ProductMapper
 from airflow.partition_mappers.temporal import StartOfDayMapper, 
StartOfHourMapper
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestProductMapper:
@@ -49,8 +52,6 @@ class TestProductMapper:
             pm.to_downstream("2024-01-15T10:30:00::2024-01-15T10:30:00::extra")
 
     def test_serialize(self):
-        from airflow.serialization.encoders import encode_partition_mapper
-
         pm = ProductMapper(StartOfHourMapper(), StartOfDayMapper())
         result = pm.serialize()
         assert result == {
@@ -62,8 +63,6 @@ class TestProductMapper:
         }
 
     def test_serialize_custom_delimiter(self):
-        from airflow.serialization.encoders import encode_partition_mapper
-
         pm = ProductMapper(StartOfHourMapper(), StartOfDayMapper(), 
delimiter="::")
         result = pm.serialize()
         assert result == {
@@ -95,8 +94,6 @@ class TestProductMapper:
 
     def test_deserialize_backward_compat(self):
         """Deserializing data without delimiter field defaults to '|'."""
-        from airflow.serialization.encoders import encode_partition_mapper
-
         data = {
             "mappers": [
                 encode_partition_mapper(StartOfHourMapper()),
@@ -111,3 +108,15 @@ class TestProductMapper:
         assert (
             pm.to_downstream("2024-01-15T10:30:00|2024-01-15T10:30:00|raw") == 
"2024-01-15T10|2024-01-15|raw"
         )
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+        mapper = ProductMapper(StartOfHourMapper(), StartOfDayMapper(), 
max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+        mapper = ProductMapper(StartOfHourMapper(), StartOfDayMapper())
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py 
b/airflow-core/tests/unit/partition_mappers/test_temporal.py
index ab2a79f52ce..e777017b51a 100644
--- a/airflow-core/tests/unit/partition_mappers/test_temporal.py
+++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py
@@ -20,6 +20,7 @@ from datetime import datetime, timezone as dt_timezone
 
 import pendulum
 import pytest
+from pendulum.tz.exceptions import InvalidTimezone
 
 from airflow import sdk
 from airflow.partition_mappers.base import RollupMapper
@@ -38,6 +39,7 @@ from airflow.partition_mappers.temporal import (
 from airflow.partition_mappers.window import HourWindow, WeekWindow
 from airflow.serialization.decoders import decode_partition_mapper
 from airflow.serialization.encoders import encode_partition_mapper
+from airflow.serialization.enums import Encoding
 
 
 class TestTemporalMappers:
@@ -297,8 +299,6 @@ class TestSdkTemporalMappersTimezoneSerialization:
     def test_sdk_constructor_invalid_timezone_raises_eagerly(self, 
sdk_mapper_name):
         """Passing an unknown timezone string must raise at construction time
         (via ``parse_timezone``), not silently fall back to UTC or fail 
later."""
-        from pendulum.tz.exceptions import InvalidTimezone
-
         sdk_cls = getattr(sdk, sdk_mapper_name)
         with pytest.raises(InvalidTimezone):
             sdk_cls(timezone="Not/A/Real/Zone")
@@ -447,3 +447,19 @@ class TestToPartitionDateDelegation:
     )
     def test_to_partition_date(self, mapper, downstream_key, expected):
         assert mapper.to_partition_date(downstream_key) == expected
+
+
+class TestTemporalMapperMaxDownstreamKeys:
+    """Round-trip and zero-bloat tests for max_downstream_keys on temporal 
mappers."""
+
+    def test_max_downstream_keys_encode_decode_roundtrip(self):
+        """max_downstream_keys=5 survives encode_partition_mapper → 
decode_partition_mapper."""
+        mapper = StartOfWeekMapper(max_downstream_keys=5)
+        restored = decode_partition_mapper(encode_partition_mapper(mapper))
+        assert restored.max_downstream_keys == 5
+
+    def test_max_downstream_keys_absent_from_default_encoded_payload(self):
+        """max_downstream_keys must NOT appear in the encoded payload when not 
set (zero-bloat contract)."""
+        mapper = StartOfWeekMapper()
+        encoded_var = encode_partition_mapper(mapper)[Encoding.VAR]
+        assert "max_downstream_keys" not in encoded_var
diff --git 
a/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
index 7d860bd8796..cc21144d42b 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py
@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+import attrs
+
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
 
 
[email protected]
 class AllowedKeyMapper(PartitionMapper):
     """Partition mapper that validates keys against a set of allowed keys."""
 
-    def __init__(self, allowed_keys: list[str]) -> None:
-        self.allowed_keys = allowed_keys
+    allowed_keys: list[str]
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
index fc26450f027..351ad4e94fa 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py
@@ -18,10 +18,18 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
 
+import attrs
+
 if TYPE_CHECKING:
     from airflow.sdk.definitions.partition_mappers.window import Window
 
 
+def _validate_max_downstream_keys(instance, attribute, value):
+    if value is not None and (not isinstance(value, int) or value < 1):
+        raise ValueError(f"max_downstream_keys must be a positive integer or 
None, got {value!r}")
+
+
[email protected]
 class PartitionMapper:
     """
     Base partition mapper class.
@@ -36,7 +44,12 @@ class PartitionMapper:
     #: Temporal mappers override to ``datetime``.
     expected_decoded_type: ClassVar[type] = str
 
+    max_downstream_keys: int | None = attrs.field(
+        default=None, kw_only=True, validator=_validate_max_downstream_keys
+    )
 
+
[email protected]
 class RollupMapper(PartitionMapper):
     """
     Partition mapper that rolls up many upstream keys into one downstream key.
@@ -49,21 +62,22 @@ class RollupMapper(PartitionMapper):
 
     is_rollup: ClassVar[bool] = True
 
-    def __init__(self, *, upstream_mapper: PartitionMapper, window: Window) -> 
None:
+    upstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+    window: Window = attrs.field(kw_only=True)
+
+    def __attrs_post_init__(self) -> None:
         # Mirrors the core-side ``RollupMapper.__init__`` check so user code
         # ``from airflow.sdk import RollupMapper`` fails at Dag parse time 
rather
         # than slipping through to the scheduler tick (where the 
misconfiguration
         # would otherwise be swallowed by the bare ``except`` in
         # ``_create_dagruns_for_partitioned_asset_dags`` and surface only as
         # "Failed to deserialize Dag" spam).
-        if upstream_mapper.expected_decoded_type is str and 
window.expected_decoded_type is not str:
+        if self.upstream_mapper.expected_decoded_type is str and 
self.window.expected_decoded_type is not str:
             raise TypeError(
-                f"{type(window).__name__} expects decoded values of type "
-                f"{window.expected_decoded_type.__name__!r}, but "
-                f"{type(upstream_mapper).__name__} decodes to 'str' (SDK 
PartitionMapper default). "
+                f"{type(self.window).__name__} expects decoded values of type "
+                f"{self.window.expected_decoded_type.__name__!r}, but "
+                f"{type(self.upstream_mapper).__name__} decodes to 'str' (SDK 
PartitionMapper default). "
                 f"Pair the window with an upstream mapper whose 
'expected_decoded_type' is "
-                f"{window.expected_decoded_type.__name__}, or use a window 
whose "
+                f"{self.window.expected_decoded_type.__name__}, or use a 
window whose "
                 f"'expected_decoded_type' accepts str."
             )
-        self.upstream_mapper = upstream_mapper
-        self.window = window
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
index 0f96e2d2eaf..6ac3b78d1e4 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
@@ -16,17 +16,23 @@
 # under the License.
 from __future__ import annotations
 
+import attrs
+
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
 
 
[email protected](init=False)
 class ChainMapper(PartitionMapper):
     """Partition mapper that applies multiple mappers sequentially."""
 
+    mappers: list[PartitionMapper]
+
     def __init__(
         self,
         mapper0: PartitionMapper,
         mapper1: PartitionMapper,
         /,
         *mappers: PartitionMapper,
+        max_downstream_keys: int | None = None,
     ) -> None:
-        self.mappers = [mapper0, mapper1, *mappers]
+        self.__attrs_init__(mappers=[mapper0, mapper1, *mappers], 
max_downstream_keys=max_downstream_keys)
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
index ffc744ff71e..2ed96dbd31f 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/product.py
@@ -16,12 +16,18 @@
 # under the License.
 from __future__ import annotations
 
+import attrs
+
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
 
 
[email protected](init=False)
 class ProductMapper(PartitionMapper):
     """Partition mapper that combines multiple mappers into a 
multi-dimensional key."""
 
+    mappers: list[PartitionMapper]
+    delimiter: str
+
     def __init__(
         self,
         mapper0: PartitionMapper,
@@ -29,6 +35,10 @@ class ProductMapper(PartitionMapper):
         /,
         *mappers: PartitionMapper,
         delimiter: str = "|",
+        max_downstream_keys: int | None = None,
     ) -> None:
-        self.mappers = [mapper0, mapper1, *mappers]
-        self.delimiter = delimiter
+        self.__attrs_init__(
+            mappers=[mapper0, mapper1, *mappers],
+            delimiter=delimiter,
+            max_downstream_keys=max_downstream_keys,
+        )
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
index 3eaf4e19d74..0cf0a705d5c 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 from datetime import datetime
 from typing import TYPE_CHECKING, ClassVar
 
+import attrs
+
 from airflow.sdk._shared.timezones.timezone import parse_timezone
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
 
@@ -28,24 +30,31 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.partition_mappers.window import Window
 
 
+def _timezone_converter(value: str | Timezone | FixedTimezone) -> Timezone | 
FixedTimezone:
+    if isinstance(value, str):
+        return parse_timezone(value)
+    return value
+
+
[email protected]
 class _BaseTemporalMapper(PartitionMapper):
     """Base class for Temporal Partition Mappers."""
 
-    default_output_format: str
+    default_output_format: ClassVar[str]
     expected_decoded_type: ClassVar[type] = datetime
 
-    def __init__(
-        self,
-        *,
-        timezone: str | Timezone | FixedTimezone = "UTC",
-        input_format: str = "%Y-%m-%dT%H:%M:%S",
-        output_format: str | None = None,
-    ) -> None:
-        self.input_format = input_format
-        self.output_format = output_format or self.default_output_format
-        if isinstance(timezone, str):
-            timezone = parse_timezone(timezone)
-        self._timezone = timezone
+    _timezone: str | Timezone | FixedTimezone = attrs.field(
+        alias="timezone",
+        default="UTC",
+        kw_only=True,
+        converter=_timezone_converter,
+    )
+    input_format: str = attrs.field(default="%Y-%m-%dT%H:%M:%S", kw_only=True)
+    output_format: str | None = attrs.field(default=None, kw_only=True)
+
+    def __attrs_post_init__(self) -> None:
+        if not self.output_format:
+            self.output_format = self.default_output_format
 
 
 class StartOfHourMapper(_BaseTemporalMapper):
@@ -108,6 +117,7 @@ class StartOfYearMapper(_BaseTemporalMapper):
     default_output_format = "%Y"
 
 
[email protected](init=False)
 class FanOutMapper(PartitionMapper):
     """
     Partition mapper that fans one upstream key out into multiple downstream 
keys.
@@ -161,6 +171,25 @@ class FanOutMapper(PartitionMapper):
         "YearWindow": StartOfMonthMapper,
     }
 
+    upstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+    window: Window = attrs.field(kw_only=True)
+    downstream_mapper: PartitionMapper = attrs.field(kw_only=True)
+
+    def __init__(
+        self,
+        *,
+        upstream_mapper: PartitionMapper,
+        window: Window,
+        downstream_mapper: PartitionMapper | None = None,
+        max_downstream_keys: int | None = None,
+    ) -> None:
+        self.__attrs_init__(
+            upstream_mapper=upstream_mapper,
+            window=window,
+            downstream_mapper=downstream_mapper or 
type(self)._resolve_default_downstream_mapper(window),
+            max_downstream_keys=max_downstream_keys,
+        )
+
     @classmethod
     def _resolve_default_downstream_mapper(cls, window: Window) -> 
PartitionMapper:
         """
@@ -179,14 +208,3 @@ class FanOutMapper(PartitionMapper):
                 f"{type(window).__name__}; pass downstream_mapper explicitly."
             )
         return mapper_cls()
-
-    def __init__(
-        self,
-        *,
-        upstream_mapper: PartitionMapper,
-        window: Window,
-        downstream_mapper: PartitionMapper | None = None,
-    ) -> None:
-        self.upstream_mapper = upstream_mapper
-        self.window = window
-        self.downstream_mapper = downstream_mapper or 
self._resolve_default_downstream_mapper(window)
diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py 
b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
index 7114b65d6c4..ca811115cb7 100644
--- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
+++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
@@ -23,6 +23,7 @@ import pytest
 
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, 
RollupMapper
 from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
+from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
 from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper
 from airflow.sdk.definitions.partition_mappers.window import (
     DayWindow,
@@ -96,6 +97,39 @@ class TestSdkDirectionValidation:
             WeekWindow(direction=bad_value)
 
 
+class TestSdkPartitionMapperMaxDownstreamKeysValidator:
+    """Verify the max_downstream_keys attrs field validator on the SDK 
PartitionMapper base.
+
+    Uses IdentityMapper as the most lightweight concrete subclass — the
+    validator lives on the base class so any subclass exercises it.
+    Mirrors core's TestPartitionMapperMaxDownstreamKeysValidator (6 cases).
+    """
+
+    def test_max_downstream_keys_none_is_accepted(self):
+        """Default (None) leaves max_downstream_keys as None."""
+        mapper = IdentityMapper()
+        assert mapper.max_downstream_keys is None
+
+    def test_max_downstream_keys_one_is_accepted(self):
+        """Minimum positive integer value is accepted."""
+        mapper = IdentityMapper(max_downstream_keys=1)
+        assert mapper.max_downstream_keys == 1
+
+    @pytest.mark.parametrize(
+        "bad_value",
+        [
+            pytest.param(0, id="zero"),
+            pytest.param(-1, id="negative"),
+            pytest.param(1.0, id="float"),
+            pytest.param("5", id="string"),
+        ],
+    )
+    def test_max_downstream_keys_invalid_raises(self, bad_value):
+        """0, negative integers, floats, and strings are all rejected."""
+        with pytest.raises(ValueError, match="max_downstream_keys"):
+            IdentityMapper(max_downstream_keys=bad_value)  # type: 
ignore[arg-type]
+
+
 class TestSdkWindowExpectedDecodedType:
     """Each SDK temporal window must declare ``datetime`` so the validation 
lines up with core."""
 

Reply via email to