This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 085459e6f96 AIP-76: Add PartitionAtRuntime authoring API to Task SDK (#65447)
085459e6f96 is described below

commit 085459e6f9681a508ff9d124d6275697b22fe8e1
Author: Anish Giri <[email protected]>
AuthorDate: Wed May 20 02:50:01 2026 -0500

    AIP-76: Add PartitionAtRuntime authoring API to Task SDK (#65447)
---
 .../docs/authoring-and-scheduling/assets.rst       |  2 +
 airflow-core/src/airflow/serialization/encoders.py |  7 ++-
 airflow-core/src/airflow/timetables/base.py        |  7 +++
 airflow-core/src/airflow/timetables/simple.py      | 16 +++++
 task-sdk/docs/api.rst                              |  2 +
 task-sdk/src/airflow/sdk/__init__.py               | 11 +++-
 task-sdk/src/airflow/sdk/__init__.pyi              |  2 +
 task-sdk/src/airflow/sdk/bases/timetable.py        |  8 +++
 .../airflow/sdk/definitions/asset/decorators.py    |  8 ++-
 .../airflow/sdk/definitions/timetables/assets.py   |  7 +++
 task-sdk/src/airflow/sdk/execution_time/context.py |  7 +++
 .../src/airflow/sdk/execution_time/task_runner.py  | 18 +++++-
 task-sdk/src/airflow/sdk/types.py                  | 11 +++-
 .../task_sdk/definitions/test_asset_decorators.py  | 68 +++++++++++++++++-----
 task-sdk/tests/task_sdk/definitions/test_dag.py    | 15 ++++-
 .../tests/task_sdk/execution_time/test_context.py  | 29 +++++++++
 .../task_sdk/execution_time/test_task_runner.py    | 38 ++++++++++++
 17 files changed, 235 insertions(+), 21 deletions(-)

diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst b/airflow-core/docs/authoring-and-scheduling/assets.rst
index 8e5cf356394..a983544900c 100644
--- a/airflow-core/docs/authoring-and-scheduling/assets.rst
+++ b/airflow-core/docs/authoring-and-scheduling/assets.rst
@@ -188,6 +188,8 @@ Declaring an ``@asset`` automatically creates:
 * A ``DAG`` with *dag_id* set to the function name.
 * A task inside the ``DAG`` with *task_id* set to the function name, and *outlet* to the created ``Asset``.
 
+The parameter names ``self``, ``context``, and ``outlet_events`` are **reserved** in an ``@asset`` function: they are populated by Airflow at runtime (with the asset itself, the execution context, and the outlet event accessor respectively) and are never treated as inlet asset references.
+
 Attaching extra information to an emitting asset event
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index 9e341cbe783..22384711d1b 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -55,6 +55,7 @@ from airflow.sdk.definitions.asset import AssetRef
 from airflow.sdk.definitions.partition_mappers.temporal import 
StartOfHourMapper
 from airflow.sdk.definitions.timetables.assets import (
     AssetTriggeredTimetable,
+    PartitionAtRuntime,
     PartitionedAssetTimetable,
 )
 from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, 
NullTimetable, OnceTimetable
@@ -295,6 +296,7 @@ class _Serializer:
         MultipleCronTriggerTimetable: 
"airflow.timetables.trigger.MultipleCronTriggerTimetable",
         NullTimetable: "airflow.timetables.simple.NullTimetable",
         OnceTimetable: "airflow.timetables.simple.OnceTimetable",
+        PartitionAtRuntime: "airflow.timetables.simple.PartitionAtRuntime",
         PartitionedAssetTimetable: 
"airflow.timetables.simple.PartitionedAssetTimetable",
     }
 
@@ -320,7 +322,10 @@ class _Serializer:
     @serialize_timetable.register(ContinuousTimetable)
     @serialize_timetable.register(NullTimetable)
     @serialize_timetable.register(OnceTimetable)
-    def _(self, timetable: ContinuousTimetable | NullTimetable | 
OnceTimetable) -> dict[str, Any]:
+    @serialize_timetable.register(PartitionAtRuntime)
+    def _(
+        self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable | 
PartitionAtRuntime
+    ) -> dict[str, Any]:
         return {}
 
     @serialize_timetable.register
diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py
index a92bcd7a1f8..b90377cb295 100644
--- a/airflow-core/src/airflow/timetables/base.py
+++ b/airflow-core/src/airflow/timetables/base.py
@@ -218,6 +218,13 @@ class Timetable(Protocol):
     instead of the traditional logic based on logical dates and data intervals.
     """
 
+    partitioned_at_runtime: bool = False
+    """Whether this timetable defers partition selection to task runtime.
+
+    *True* for :class:`~airflow.timetables.simple.PartitionAtRuntime`;
+    downstream code can branch on this flag instead of using ``isinstance``.
+    """
+
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> Timetable:
         """
diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py
index 01fb12f81dd..086e1153d61 100644
--- a/airflow-core/src/airflow/timetables/simple.py
+++ b/airflow-core/src/airflow/timetables/simple.py
@@ -93,6 +93,7 @@ class NullTimetable(_TrivialTimetable):
     """
 
     can_be_scheduled = False  # TODO (GH-52141): Find a way to keep this and 
one in Core in sync.
+    partitioned_at_runtime = False
     description: str = "Never, external triggers only"
 
     @property
@@ -183,6 +184,21 @@ class ContinuousTimetable(_TrivialTimetable):
         return DagRunInfo.interval(start, end)
 
 
+class PartitionAtRuntime(NullTimetable):
+    """
+    Timetable that never schedules anything; partition keys are set at runtime.
+
+    This corresponds to ``schedule=PartitionAtRuntime()``.
+    """
+
+    description: str = "Never, partition key(s) set at runtime"
+    partitioned_at_runtime = True
+
+    @property
+    def summary(self) -> str:
+        return "PartitionAtRuntime"
+
+
 class AssetTriggeredTimetable(_TrivialTimetable):
     """
     Timetable that never schedules anything.
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 222b28ad53d..6e9ed0a2758 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -209,6 +209,8 @@ Timetables
 
 .. autoapiclass:: airflow.sdk.MultipleCronTriggerTimetable
 
+.. autoapiclass:: airflow.sdk.PartitionAtRuntime
+
 .. autoapiclass:: airflow.sdk.PartitionedAssetTimetable
 
 
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 05ececc2956..d9dcf60b994 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -60,6 +60,7 @@ __all__ = [
     "ObjectStoragePath",
     "Param",
     "ParamsDict",
+    "PartitionAtRuntime",
     "PartitionedAssetTimetable",
     "PartitionMapper",
     "PokeReturnValue",
@@ -119,7 +120,13 @@ if TYPE_CHECKING:
     from airflow.sdk.bases.skipmixin import SkipMixin
     from airflow.sdk.bases.xcom import BaseXCom
     from airflow.sdk.configuration import AirflowSDKConfigParser
-    from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, 
AssetAny, AssetWatcher
+    from airflow.sdk.definitions.asset import (
+        Asset,
+        AssetAlias,
+        AssetAll,
+        AssetAny,
+        AssetWatcher,
+    )
     from airflow.sdk.definitions.asset.decorators import asset
     from airflow.sdk.definitions.asset.metadata import Metadata
     from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
@@ -155,6 +162,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.template import literal
     from airflow.sdk.definitions.timetables.assets import (
         AssetOrTimeSchedule,
+        PartitionAtRuntime,
         PartitionedAssetTimetable,
     )
     from airflow.sdk.definitions.timetables.events import EventsTimetable
@@ -217,6 +225,7 @@ __lazy_imports: dict[str, str] = {
     "ObjectStoragePath": ".io.path",
     "Param": ".definitions.param",
     "ParamsDict": ".definitions.param",
+    "PartitionAtRuntime": ".definitions.timetables.assets",
     "PartitionedAssetTimetable": ".definitions.timetables.assets",
     "PartitionMapper": ".definitions.partition_mappers.base",
     "PokeReturnValue": ".bases.sensor",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index 7e6d211674e..d0b4af5d9e5 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -86,6 +86,7 @@ from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup
 from airflow.sdk.definitions.template import literal as literal
 from airflow.sdk.definitions.timetables.assets import (
     AssetOrTimeSchedule,
+    PartitionAtRuntime,
     PartitionedAssetTimetable,
 )
 from airflow.sdk.definitions.timetables.events import EventsTimetable
@@ -145,6 +146,7 @@ __all__ = [
     "ObjectStoragePath",
     "Param",
     "PokeReturnValue",
+    "PartitionAtRuntime",
     "PartitionedAssetTimetable",
     "PartitionMapper",
     "ProductMapper",
diff --git a/task-sdk/src/airflow/sdk/bases/timetable.py b/task-sdk/src/airflow/sdk/bases/timetable.py
index e732566f153..bd37a6a0a79 100644
--- a/task-sdk/src/airflow/sdk/bases/timetable.py
+++ b/task-sdk/src/airflow/sdk/bases/timetable.py
@@ -47,6 +47,14 @@ class BaseTimetable:
 
     asset_condition: BaseAsset | None = None
 
+    partitioned_at_runtime: bool = False
+    """
+    Whether this timetable defers partition selection to task runtime.
+
+    *True* for :class:`~airflow.sdk.PartitionAtRuntime`; downstream code can
+    branch on this flag instead of using ``isinstance``.
+    """
+
     def validate(self) -> None:
         """
         Validate the timetable is correctly specified.
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index 26205f0c583..784898f2803 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
     from airflow.triggers.base import BaseTrigger
 
 
+_INVALID_INLET_ASSET_NAMES = ("self", "context", "outlet_events")
+
+
 def _validate_asset_function_arguments(f: Callable) -> None:
     for name, param in inspect.signature(f).parameters.items():
         if param.kind == inspect.Parameter.VAR_POSITIONAL:
@@ -62,7 +65,8 @@ class _AssetMainOperator(PythonOperator):
             inlets=[
                 Asset.ref(name=inlet_asset_name)
                 for inlet_asset_name, param in 
inspect.signature(definition._function).parameters.items()
-                if inlet_asset_name not in ("self", "context") and 
param.default is inspect.Parameter.empty
+                if inlet_asset_name not in _INVALID_INLET_ASSET_NAMES
+                and param.default is inspect.Parameter.empty
             ],
             outlets=list(definition.iter_outlets()),
             python_callable=definition._function,
@@ -89,6 +93,8 @@ class _AssetMainOperator(PythonOperator):
                 value = _fetch_asset(self._definition_name)
             elif key == "context":
                 value = context
+            elif key == "outlet_events":
+                value = context["outlet_events"]
             else:
                 value = _fetch_asset(key)
             yield key, value
diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
index e6bb683ebca..a0c64936925 100644
--- a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
+++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
@@ -54,6 +54,13 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
     default_partition_mapper: PartitionMapper = IdentityMapper()
 
 
+class PartitionAtRuntime(BaseTimetable):
+    """Marker timetable indicating that partition key(s) are determined at 
runtime."""
+
+    can_be_scheduled = False
+    partitioned_at_runtime = True
+
+
 def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset:
     if isinstance(o, BaseAsset):
         return o
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 6f72667dd90..1e6874121fc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -776,6 +776,13 @@ class OutletEventAccessor(_AssetRefResolutionMixin):
     key: BaseAssetUniqueKey
     extra: dict[str, JsonValue] = attrs.Factory(dict)
     asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
+    partition_keys: set[str] = attrs.field(factory=set)
+
+    def add_partitions(self, keys: str | list[str]) -> None:
+        """Add one or more partition keys to :attr:`partition_keys`."""
+        if isinstance(keys, str):
+            keys = [keys]
+        self.partition_keys.update(keys)
 
     def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None 
= None) -> None:
         """Add an AssetEvent to an existing Asset."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 61e476e60b9..852b5da9cc6 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,7 +58,13 @@ from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.configuration import conf
 from airflow.sdk.definitions._internal.dag_parsing_context import 
_airflow_parsing_context_manager
 from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, 
is_arg_set
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, 
AssetUniqueKey, AssetUriRef
+from airflow.sdk.definitions.asset import (
+    Asset,
+    AssetAlias,
+    AssetNameRef,
+    AssetUniqueKey,
+    AssetUriRef,
+)
 from airflow.sdk.definitions.mappedoperator import MappedOperator
 from airflow.sdk.definitions.param import process_params
 from airflow.sdk.exceptions import (
@@ -1159,7 +1165,15 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d
     # Further filtering will be done in the API server.
     for key, accessor in events._dict.items():
         if isinstance(key, AssetUniqueKey):
-            yield {"dest_asset_key": attrs.asdict(key), "extra": 
accessor.extra}
+            if accessor.partition_keys:
+                for partition_key in accessor.partition_keys:
+                    yield {
+                        "dest_asset_key": attrs.asdict(key),
+                        "extra": accessor.extra,
+                        "partition_key": partition_key,
+                    }
+            else:
+                yield {"dest_asset_key": attrs.asdict(key), "extra": 
accessor.extra}
         for alias_event in accessor.asset_alias_events:
             yield attrs.asdict(alias_event)
 
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 262e1206a30..3a00fea6d96 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -42,7 +42,13 @@ if TYPE_CHECKING:
         TaskInstanceState,
     )
     from airflow.sdk.bases.operator import BaseOperator
-    from airflow.sdk.definitions.asset import Asset, AssetAlias, 
AssetAliasEvent, AssetRef, BaseAssetUniqueKey
+    from airflow.sdk.definitions.asset import (
+        Asset,
+        AssetAlias,
+        AssetAliasEvent,
+        AssetRef,
+        BaseAssetUniqueKey,
+    )
     from airflow.sdk.definitions.context import Context
     from airflow.sdk.definitions.mappedoperator import MappedOperator
     from airflow.sdk.execution_time.comms import DagResult
@@ -215,6 +221,7 @@ class OutletEventAccessorProtocol(Protocol):
     key: BaseAssetUniqueKey
     extra: dict[str, JsonValue]
     asset_alias_events: list[AssetAliasEvent]
+    partition_keys: set[str]
 
     def __init__(
         self,
@@ -222,8 +229,10 @@ class OutletEventAccessorProtocol(Protocol):
         key: BaseAssetUniqueKey,
         extra: dict[str, JsonValue],
         asset_alias_events: list[AssetAliasEvent],
+        partition_keys: set[str] = ...,
     ) -> None: ...
     def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) -> 
None: ...
+    def add_partitions(self, keys: str | list[str]) -> None: ...
 
 
 class OutletEventAccessorsProtocol(Protocol):
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
index c0b3b617fc0..47a208da04e 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
@@ -24,6 +24,7 @@ from airflow.sdk.definitions.asset import Asset
 from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
 from airflow.sdk.definitions.decorators import task
 from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
+from airflow.sdk.execution_time.context import OutletEventAccessors
 
 
 @pytest.fixture
@@ -77,6 +78,15 @@ def example_asset_func_with_valid_arg_as_inlet_asset_and_default(func_fixer):
     return _example_asset_func
 
 
+@pytest.fixture
+def example_asset_func_with_outlet_events(func_fixer):
+    @func_fixer
+    def _example_asset_func(self, outlet_events):
+        return "This is example_asset"
+
+    return _example_asset_func
+
+
 class TestAssetDecorator:
     def test_without_uri(self, example_asset_func):
         asset_definition = asset(schedule=None)(example_asset_func)
@@ -392,17 +402,18 @@ class Test_AssetMainOperator:
             python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
             definition_name="example_asset_func",
         )
-        assert op.determine_kwargs(context={"k": "v"}) == {
-            "self": Asset(
-                name="example_asset_func",
-                uri="s3://bucket/object",
-                group="MLModel",
-                extra={"k": "v"},
-            ),
-            "context": {"k": "v"},
-            "inlet_asset_1": Asset(name="inlet_asset_1", 
uri="s3://bucket/object1"),
-            "inlet_asset_2": Asset(name="inlet_asset_2"),
-        }
+        outlet_events = OutletEventAccessors()
+        context = {"k": "v", "outlet_events": outlet_events}
+        kwargs = op.determine_kwargs(context=context)
+        assert kwargs["self"] == Asset(
+            name="example_asset_func",
+            uri="s3://bucket/object",
+            group="MLModel",
+            extra={"k": "v"},
+        )
+        assert kwargs["context"] is context
+        assert kwargs["inlet_asset_1"] == Asset(name="inlet_asset_1", 
uri="s3://bucket/object1")
+        assert kwargs["inlet_asset_2"] == Asset(name="inlet_asset_2")
 
         assert mock_supervisor_comms.mock_calls == [
             mock.call.send(GetAssetByName(name="example_asset_func")),
@@ -453,10 +464,39 @@ class Test_AssetMainOperator:
             AssetResult(name="custom_name", uri="s3://bucket/object1", 
group="Asset")
         ]
 
-        assert op.determine_kwargs(context={}) == {
-            "self": Asset(name="custom_name", uri="s3://bucket/object1", 
group="Asset")
-        }
+        kwargs = op.determine_kwargs(context={"outlet_events": 
OutletEventAccessors()})
+        assert list(kwargs) == ["self"]
+        assert kwargs["self"] == Asset(name="custom_name", 
uri="s3://bucket/object1", group="Asset")
 
         assert mock_supervisor_comms.mock_calls == [
             mock.call.send(GetAssetByName(name="custom_name", 
uri="s3://bucket/object1", group="Asset"))
         ]
+
+
+class TestOutletEventsKwarg:
+    def test_determine_kwargs_injects_outlet_events(
+        self, mock_supervisor_comms, example_asset_func_with_outlet_events
+    ):
+        definition = 
asset(schedule=None)(example_asset_func_with_outlet_events)
+        outlet_events = OutletEventAccessors()
+        context = {"outlet_events": outlet_events}
+
+        mock_supervisor_comms.send.side_effect = [
+            AssetResult(name="example_asset_func", uri="example_asset_func", 
group="asset"),
+        ]
+
+        op = _AssetMainOperator(
+            task_id="example_asset_func",
+            inlets=[],
+            outlets=[definition],
+            python_callable=example_asset_func_with_outlet_events,
+            definition_name="example_asset_func",
+        )
+
+        kwargs = op.determine_kwargs(context=context)
+        assert kwargs["outlet_events"] is outlet_events
+
+    def test_from_definition_excludes_outlet_events_from_inlets(self, 
example_asset_func_with_outlet_events):
+        definition = 
asset(schedule=None)(example_asset_func_with_outlet_events)
+        op = _AssetMainOperator.from_definition(definition)
+        assert op.inlets == []
diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py
index a7d1bfbd926..c42eb8dfc80 100644
--- a/task-sdk/tests/task_sdk/definitions/test_dag.py
+++ b/task-sdk/tests/task_sdk/definitions/test_dag.py
@@ -24,10 +24,12 @@ from typing import Any
 
 import pytest
 
-from airflow.sdk import Context, Label, TaskGroup
+from airflow.sdk import Context, Label, PartitionAtRuntime, TaskGroup
 from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.timetable import BaseTimetable
 from airflow.sdk.definitions.dag import DAG, dag as dag_decorator
 from airflow.sdk.definitions.param import DagParam, Param, ParamsDict
+from airflow.sdk.definitions.timetables import assets, events, interval, 
simple, trigger  # noqa: F401
 from airflow.sdk.exceptions import AirflowDagCycleException, 
DuplicateTaskIdFound, RemovedInAirflow4Warning
 from airflow.utils.types import DagRunType
 
@@ -437,6 +439,17 @@ class TestDag:
         with pytest.raises(ValueError, match="ContinuousTimetable requires 
max_active_runs <= 1"):
             dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=25)
 
+    def test_only_partition_at_runtime_has_partitioned_at_runtime_flag(self):
+        """Regression guard: across every BaseTimetable subclass, only 
PartitionAtRuntime sets partitioned_at_runtime=True."""
+
+        def all_subclasses(cls):
+            for sub in cls.__subclasses__():
+                yield sub
+                yield from all_subclasses(sub)
+
+        flagged = {c for c in all_subclasses(BaseTimetable) if 
c.partitioned_at_runtime}
+        assert flagged == {PartitionAtRuntime}
+
     def test_dag_add_task_checks_trigger_rule(self):
         # A non fail stop dag should allow any trigger rule
         from airflow.sdk import TriggerRule
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index a5ff7be9ce8..062645d25b1 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -461,6 +461,35 @@ class TestOutletEventAccessor:
         assert outlet_event_accessor.asset_alias_events == asset_alias_events
 
 
+class TestOutletEventAccessorPartitionKeys:
+    @pytest.fixture
+    def accessor(self) -> OutletEventAccessor:
+        return OutletEventAccessor(key=AssetUniqueKey.from_asset(Asset("a")))
+
+    def test_default_is_empty(self, accessor):
+        assert accessor.partition_keys == set()
+
+    def test_direct_assignment(self, accessor):
+        accessor.partition_keys = {"us", "eu"}
+        assert accessor.partition_keys == {"us", "eu"}
+
+    def test_add_partitions(self, accessor):
+        accessor.add_partitions("us")
+        assert accessor.partition_keys == {"us"}
+
+    def test_add_partitions_appends(self, accessor):
+        accessor.add_partitions("us")
+        accessor.add_partitions("eu")
+        accessor.add_partitions("apac")
+        assert accessor.partition_keys == {"us", "eu", "apac"}
+
+    def test_add_partitions_dedupes(self, accessor):
+        accessor.add_partitions("us")
+        accessor.add_partitions("us")
+        accessor.add_partitions(["us", "eu"])
+        assert accessor.partition_keys == {"us", "eu"}
+
+
 class TestTriggeringAssetEventsAccessor:
     @pytest.fixture(autouse=True)
     def clear_cache(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index d7b72b2c48f..9425ce362ed 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -156,6 +156,7 @@ from airflow.sdk.execution_time.task_runner import (
     _execute_task,
     _make_task_span,
     _push_xcom_if_needed,
+    _serialize_outlet_events,
     _xcom_push,
     finalize,
     get_startup_details,
@@ -1813,6 +1814,43 @@ def test_rendered_map_index_updates_sent_progressively(create_runtime_ti, mock_s
     assert ti.rendered_map_index == "Label: test_task"
 
 
+class TestSerializeOutletEvents:
+    """Tests for the wire format produced by ``_serialize_outlet_events``."""
+
+    def test_emits_single_event_when_no_partition_keys(self):
+        accessors = OutletEventAccessors()
+        accessors[Asset(name="a")].extra = {"x": 1}
+
+        events = list(_serialize_outlet_events(accessors))
+
+        assert events == [{"dest_asset_key": {"name": "a", "uri": "a"}, 
"extra": {"x": 1}}]
+
+    def test_emits_one_event_per_partition_key(self):
+        accessors = OutletEventAccessors()
+        accessors[Asset(name="a")].partition_keys = {"us", "eu"}
+
+        events = list(_serialize_outlet_events(accessors))
+
+        assert sorted(events, key=lambda e: e["partition_key"]) == [
+            {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {}, 
"partition_key": "eu"},
+            {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {}, 
"partition_key": "us"},
+        ]
+
+    def test_dedupes_partition_keys_at_serialization(self):
+        accessors = OutletEventAccessors()
+        accessor = accessors[Asset(name="a")]
+        accessor.add_partitions("us")
+        accessor.add_partitions("us")
+        accessor.add_partitions(["us", "eu"])
+
+        events = list(_serialize_outlet_events(accessors))
+
+        assert sorted(events, key=lambda e: e["partition_key"]) == [
+            {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {}, 
"partition_key": "eu"},
+            {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {}, 
"partition_key": "us"},
+        ]
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, 
make_ti_context):
         """Test get_template_context without ti_context_from_server."""

Reply via email to