This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 085459e6f96 AIP-76: Add PartitionAtRuntime authoring API to Task SDK
(#65447)
085459e6f96 is described below
commit 085459e6f9681a508ff9d124d6275697b22fe8e1
Author: Anish Giri <[email protected]>
AuthorDate: Wed May 20 02:50:01 2026 -0500
AIP-76: Add PartitionAtRuntime authoring API to Task SDK (#65447)
---
.../docs/authoring-and-scheduling/assets.rst | 2 +
airflow-core/src/airflow/serialization/encoders.py | 7 ++-
airflow-core/src/airflow/timetables/base.py | 7 +++
airflow-core/src/airflow/timetables/simple.py | 16 +++++
task-sdk/docs/api.rst | 2 +
task-sdk/src/airflow/sdk/__init__.py | 11 +++-
task-sdk/src/airflow/sdk/__init__.pyi | 2 +
task-sdk/src/airflow/sdk/bases/timetable.py | 8 +++
.../airflow/sdk/definitions/asset/decorators.py | 8 ++-
.../airflow/sdk/definitions/timetables/assets.py | 7 +++
task-sdk/src/airflow/sdk/execution_time/context.py | 7 +++
.../src/airflow/sdk/execution_time/task_runner.py | 18 +++++-
task-sdk/src/airflow/sdk/types.py | 11 +++-
.../task_sdk/definitions/test_asset_decorators.py | 68 +++++++++++++++++-----
task-sdk/tests/task_sdk/definitions/test_dag.py | 15 ++++-
.../tests/task_sdk/execution_time/test_context.py | 29 +++++++++
.../task_sdk/execution_time/test_task_runner.py | 38 ++++++++++++
17 files changed, 235 insertions(+), 21 deletions(-)
diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst
b/airflow-core/docs/authoring-and-scheduling/assets.rst
index 8e5cf356394..a983544900c 100644
--- a/airflow-core/docs/authoring-and-scheduling/assets.rst
+++ b/airflow-core/docs/authoring-and-scheduling/assets.rst
@@ -188,6 +188,8 @@ Declaring an ``@asset`` automatically creates:
* A ``DAG`` with *dag_id* set to the function name.
* A task inside the ``DAG`` with *task_id* set to the function name, and
*outlet* to the created ``Asset``.
+The parameter names ``self``, ``context``, and ``outlet_events`` are
**reserved** in an ``@asset`` function: they are populated by Airflow at
runtime (with the asset itself, the execution context, and the outlet event
accessors respectively) and are never treated as inlet asset references.
+
Attaching extra information to an emitting asset event
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/airflow-core/src/airflow/serialization/encoders.py
b/airflow-core/src/airflow/serialization/encoders.py
index 9e341cbe783..22384711d1b 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -55,6 +55,7 @@ from airflow.sdk.definitions.asset import AssetRef
from airflow.sdk.definitions.partition_mappers.temporal import
StartOfHourMapper
from airflow.sdk.definitions.timetables.assets import (
AssetTriggeredTimetable,
+ PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.simple import ContinuousTimetable,
NullTimetable, OnceTimetable
@@ -295,6 +296,7 @@ class _Serializer:
MultipleCronTriggerTimetable:
"airflow.timetables.trigger.MultipleCronTriggerTimetable",
NullTimetable: "airflow.timetables.simple.NullTimetable",
OnceTimetable: "airflow.timetables.simple.OnceTimetable",
+ PartitionAtRuntime: "airflow.timetables.simple.PartitionAtRuntime",
PartitionedAssetTimetable:
"airflow.timetables.simple.PartitionedAssetTimetable",
}
@@ -320,7 +322,10 @@ class _Serializer:
@serialize_timetable.register(ContinuousTimetable)
@serialize_timetable.register(NullTimetable)
@serialize_timetable.register(OnceTimetable)
- def _(self, timetable: ContinuousTimetable | NullTimetable |
OnceTimetable) -> dict[str, Any]:
+ @serialize_timetable.register(PartitionAtRuntime)
+ def _(
+ self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable |
PartitionAtRuntime
+ ) -> dict[str, Any]:
return {}
@serialize_timetable.register
diff --git a/airflow-core/src/airflow/timetables/base.py
b/airflow-core/src/airflow/timetables/base.py
index a92bcd7a1f8..b90377cb295 100644
--- a/airflow-core/src/airflow/timetables/base.py
+++ b/airflow-core/src/airflow/timetables/base.py
@@ -218,6 +218,13 @@ class Timetable(Protocol):
instead of the traditional logic based on logical dates and data intervals.
"""
+ partitioned_at_runtime: bool = False
+ """Whether this timetable defers partition selection to task runtime.
+
+ *True* for :class:`~airflow.timetables.simple.PartitionAtRuntime`;
+ downstream code can branch on this flag instead of using ``isinstance``.
+ """
+
@classmethod
def deserialize(cls, data: dict[str, Any]) -> Timetable:
"""
diff --git a/airflow-core/src/airflow/timetables/simple.py
b/airflow-core/src/airflow/timetables/simple.py
index 01fb12f81dd..086e1153d61 100644
--- a/airflow-core/src/airflow/timetables/simple.py
+++ b/airflow-core/src/airflow/timetables/simple.py
@@ -93,6 +93,7 @@ class NullTimetable(_TrivialTimetable):
"""
can_be_scheduled = False # TODO (GH-52141): Find a way to keep this and
one in Core in sync.
+ partitioned_at_runtime = False
description: str = "Never, external triggers only"
@property
@@ -183,6 +184,21 @@ class ContinuousTimetable(_TrivialTimetable):
return DagRunInfo.interval(start, end)
+class PartitionAtRuntime(NullTimetable):
+ """
+ Timetable that never schedules anything; partition keys are set at runtime.
+
+ This corresponds to ``schedule=PartitionAtRuntime()``.
+ """
+
+ description: str = "Never, partition key(s) set at runtime"
+ partitioned_at_runtime = True
+
+ @property
+ def summary(self) -> str:
+ return "PartitionAtRuntime"
+
+
class AssetTriggeredTimetable(_TrivialTimetable):
"""
Timetable that never schedules anything.
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 222b28ad53d..6e9ed0a2758 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -209,6 +209,8 @@ Timetables
.. autoapiclass:: airflow.sdk.MultipleCronTriggerTimetable
+.. autoapiclass:: airflow.sdk.PartitionAtRuntime
+
.. autoapiclass:: airflow.sdk.PartitionedAssetTimetable
diff --git a/task-sdk/src/airflow/sdk/__init__.py
b/task-sdk/src/airflow/sdk/__init__.py
index 05ececc2956..d9dcf60b994 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -60,6 +60,7 @@ __all__ = [
"ObjectStoragePath",
"Param",
"ParamsDict",
+ "PartitionAtRuntime",
"PartitionedAssetTimetable",
"PartitionMapper",
"PokeReturnValue",
@@ -119,7 +120,13 @@ if TYPE_CHECKING:
from airflow.sdk.bases.skipmixin import SkipMixin
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import AirflowSDKConfigParser
- from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll,
AssetAny, AssetWatcher
+ from airflow.sdk.definitions.asset import (
+ Asset,
+ AssetAlias,
+ AssetAll,
+ AssetAny,
+ AssetWatcher,
+ )
from airflow.sdk.definitions.asset.decorators import asset
from airflow.sdk.definitions.asset.metadata import Metadata
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
@@ -155,6 +162,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.template import literal
from airflow.sdk.definitions.timetables.assets import (
AssetOrTimeSchedule,
+ PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.events import EventsTimetable
@@ -217,6 +225,7 @@ __lazy_imports: dict[str, str] = {
"ObjectStoragePath": ".io.path",
"Param": ".definitions.param",
"ParamsDict": ".definitions.param",
+ "PartitionAtRuntime": ".definitions.timetables.assets",
"PartitionedAssetTimetable": ".definitions.timetables.assets",
"PartitionMapper": ".definitions.partition_mappers.base",
"PokeReturnValue": ".bases.sensor",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi
b/task-sdk/src/airflow/sdk/__init__.pyi
index 7e6d211674e..d0b4af5d9e5 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -86,6 +86,7 @@ from airflow.sdk.definitions.taskgroup import TaskGroup as
TaskGroup
from airflow.sdk.definitions.template import literal as literal
from airflow.sdk.definitions.timetables.assets import (
AssetOrTimeSchedule,
+ PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.events import EventsTimetable
@@ -145,6 +146,7 @@ __all__ = [
"ObjectStoragePath",
"Param",
"PokeReturnValue",
+ "PartitionAtRuntime",
"PartitionedAssetTimetable",
"PartitionMapper",
"ProductMapper",
diff --git a/task-sdk/src/airflow/sdk/bases/timetable.py
b/task-sdk/src/airflow/sdk/bases/timetable.py
index e732566f153..bd37a6a0a79 100644
--- a/task-sdk/src/airflow/sdk/bases/timetable.py
+++ b/task-sdk/src/airflow/sdk/bases/timetable.py
@@ -47,6 +47,14 @@ class BaseTimetable:
asset_condition: BaseAsset | None = None
+ partitioned_at_runtime: bool = False
+ """
+ Whether this timetable defers partition selection to task runtime.
+
+ *True* for :class:`~airflow.sdk.PartitionAtRuntime`; downstream code can
+ branch on this flag instead of using ``isinstance``.
+ """
+
def validate(self) -> None:
"""
Validate the timetable is correctly specified.
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index 26205f0c583..784898f2803 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
from airflow.triggers.base import BaseTrigger
+_INVALID_INLET_ASSET_NAMES = ("self", "context", "outlet_events")
+
+
def _validate_asset_function_arguments(f: Callable) -> None:
for name, param in inspect.signature(f).parameters.items():
if param.kind == inspect.Parameter.VAR_POSITIONAL:
@@ -62,7 +65,8 @@ class _AssetMainOperator(PythonOperator):
inlets=[
Asset.ref(name=inlet_asset_name)
for inlet_asset_name, param in
inspect.signature(definition._function).parameters.items()
- if inlet_asset_name not in ("self", "context") and
param.default is inspect.Parameter.empty
+ if inlet_asset_name not in _INVALID_INLET_ASSET_NAMES
+ and param.default is inspect.Parameter.empty
],
outlets=list(definition.iter_outlets()),
python_callable=definition._function,
@@ -89,6 +93,8 @@ class _AssetMainOperator(PythonOperator):
value = _fetch_asset(self._definition_name)
elif key == "context":
value = context
+ elif key == "outlet_events":
+ value = context["outlet_events"]
else:
value = _fetch_asset(key)
yield key, value
diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
index e6bb683ebca..a0c64936925 100644
--- a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
+++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py
@@ -54,6 +54,13 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
default_partition_mapper: PartitionMapper = IdentityMapper()
+class PartitionAtRuntime(BaseTimetable):
+ """Marker timetable indicating that partition key(s) are determined at
runtime."""
+
+ can_be_scheduled = False
+ partitioned_at_runtime = True
+
+
def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset:
if isinstance(o, BaseAsset):
return o
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 6f72667dd90..1e6874121fc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -776,6 +776,13 @@ class OutletEventAccessor(_AssetRefResolutionMixin):
key: BaseAssetUniqueKey
extra: dict[str, JsonValue] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
+ partition_keys: set[str] = attrs.field(factory=set)
+
+ def add_partitions(self, keys: str | list[str]) -> None:
+ """Add one or more partition keys to :attr:`partition_keys`."""
+ if isinstance(keys, str):
+ keys = [keys]
+ self.partition_keys.update(keys)
def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None
= None) -> None:
"""Add an AssetEvent to an existing Asset."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 61e476e60b9..852b5da9cc6 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,7 +58,13 @@ from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.dag_parsing_context import
_airflow_parsing_context_manager
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet,
is_arg_set
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef,
AssetUniqueKey, AssetUriRef
+from airflow.sdk.definitions.asset import (
+ Asset,
+ AssetAlias,
+ AssetNameRef,
+ AssetUniqueKey,
+ AssetUriRef,
+)
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
from airflow.sdk.exceptions import (
@@ -1159,7 +1165,15 @@ def _serialize_outlet_events(events:
OutletEventAccessorsProtocol) -> Iterator[d
# Further filtering will be done in the API server.
for key, accessor in events._dict.items():
if isinstance(key, AssetUniqueKey):
- yield {"dest_asset_key": attrs.asdict(key), "extra":
accessor.extra}
+ if accessor.partition_keys:
+ for partition_key in accessor.partition_keys:
+ yield {
+ "dest_asset_key": attrs.asdict(key),
+ "extra": accessor.extra,
+ "partition_key": partition_key,
+ }
+ else:
+ yield {"dest_asset_key": attrs.asdict(key), "extra":
accessor.extra}
for alias_event in accessor.asset_alias_events:
yield attrs.asdict(alias_event)
diff --git a/task-sdk/src/airflow/sdk/types.py
b/task-sdk/src/airflow/sdk/types.py
index 262e1206a30..3a00fea6d96 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -42,7 +42,13 @@ if TYPE_CHECKING:
TaskInstanceState,
)
from airflow.sdk.bases.operator import BaseOperator
- from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasEvent, AssetRef, BaseAssetUniqueKey
+ from airflow.sdk.definitions.asset import (
+ Asset,
+ AssetAlias,
+ AssetAliasEvent,
+ AssetRef,
+ BaseAssetUniqueKey,
+ )
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.execution_time.comms import DagResult
@@ -215,6 +221,7 @@ class OutletEventAccessorProtocol(Protocol):
key: BaseAssetUniqueKey
extra: dict[str, JsonValue]
asset_alias_events: list[AssetAliasEvent]
+ partition_keys: set[str]
def __init__(
self,
@@ -222,8 +229,10 @@ class OutletEventAccessorProtocol(Protocol):
key: BaseAssetUniqueKey,
extra: dict[str, JsonValue],
asset_alias_events: list[AssetAliasEvent],
+ partition_keys: set[str] = ...,
) -> None: ...
def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) ->
None: ...
+ def add_partitions(self, keys: str | list[str]) -> None: ...
class OutletEventAccessorsProtocol(Protocol):
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
index c0b3b617fc0..47a208da04e 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
@@ -24,6 +24,7 @@ from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
from airflow.sdk.definitions.decorators import task
from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
+from airflow.sdk.execution_time.context import OutletEventAccessors
@pytest.fixture
@@ -77,6 +78,15 @@ def
example_asset_func_with_valid_arg_as_inlet_asset_and_default(func_fixer):
return _example_asset_func
+@pytest.fixture
+def example_asset_func_with_outlet_events(func_fixer):
+ @func_fixer
+ def _example_asset_func(self, outlet_events):
+ return "This is example_asset"
+
+ return _example_asset_func
+
+
class TestAssetDecorator:
def test_without_uri(self, example_asset_func):
asset_definition = asset(schedule=None)(example_asset_func)
@@ -392,17 +402,18 @@ class Test_AssetMainOperator:
python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
definition_name="example_asset_func",
)
- assert op.determine_kwargs(context={"k": "v"}) == {
- "self": Asset(
- name="example_asset_func",
- uri="s3://bucket/object",
- group="MLModel",
- extra={"k": "v"},
- ),
- "context": {"k": "v"},
- "inlet_asset_1": Asset(name="inlet_asset_1",
uri="s3://bucket/object1"),
- "inlet_asset_2": Asset(name="inlet_asset_2"),
- }
+ outlet_events = OutletEventAccessors()
+ context = {"k": "v", "outlet_events": outlet_events}
+ kwargs = op.determine_kwargs(context=context)
+ assert kwargs["self"] == Asset(
+ name="example_asset_func",
+ uri="s3://bucket/object",
+ group="MLModel",
+ extra={"k": "v"},
+ )
+ assert kwargs["context"] is context
+ assert kwargs["inlet_asset_1"] == Asset(name="inlet_asset_1",
uri="s3://bucket/object1")
+ assert kwargs["inlet_asset_2"] == Asset(name="inlet_asset_2")
assert mock_supervisor_comms.mock_calls == [
mock.call.send(GetAssetByName(name="example_asset_func")),
@@ -453,10 +464,39 @@ class Test_AssetMainOperator:
AssetResult(name="custom_name", uri="s3://bucket/object1",
group="Asset")
]
- assert op.determine_kwargs(context={}) == {
- "self": Asset(name="custom_name", uri="s3://bucket/object1",
group="Asset")
- }
+ kwargs = op.determine_kwargs(context={"outlet_events":
OutletEventAccessors()})
+ assert list(kwargs) == ["self"]
+ assert kwargs["self"] == Asset(name="custom_name",
uri="s3://bucket/object1", group="Asset")
assert mock_supervisor_comms.mock_calls == [
mock.call.send(GetAssetByName(name="custom_name",
uri="s3://bucket/object1", group="Asset"))
]
+
+
+class TestOutletEventsKwarg:
+ def test_determine_kwargs_injects_outlet_events(
+ self, mock_supervisor_comms, example_asset_func_with_outlet_events
+ ):
+ definition =
asset(schedule=None)(example_asset_func_with_outlet_events)
+ outlet_events = OutletEventAccessors()
+ context = {"outlet_events": outlet_events}
+
+ mock_supervisor_comms.send.side_effect = [
+ AssetResult(name="example_asset_func", uri="example_asset_func",
group="asset"),
+ ]
+
+ op = _AssetMainOperator(
+ task_id="example_asset_func",
+ inlets=[],
+ outlets=[definition],
+ python_callable=example_asset_func_with_outlet_events,
+ definition_name="example_asset_func",
+ )
+
+ kwargs = op.determine_kwargs(context=context)
+ assert kwargs["outlet_events"] is outlet_events
+
+ def test_from_definition_excludes_outlet_events_from_inlets(self,
example_asset_func_with_outlet_events):
+ definition =
asset(schedule=None)(example_asset_func_with_outlet_events)
+ op = _AssetMainOperator.from_definition(definition)
+ assert op.inlets == []
diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py
b/task-sdk/tests/task_sdk/definitions/test_dag.py
index a7d1bfbd926..c42eb8dfc80 100644
--- a/task-sdk/tests/task_sdk/definitions/test_dag.py
+++ b/task-sdk/tests/task_sdk/definitions/test_dag.py
@@ -24,10 +24,12 @@ from typing import Any
import pytest
-from airflow.sdk import Context, Label, TaskGroup
+from airflow.sdk import Context, Label, PartitionAtRuntime, TaskGroup
from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.timetable import BaseTimetable
from airflow.sdk.definitions.dag import DAG, dag as dag_decorator
from airflow.sdk.definitions.param import DagParam, Param, ParamsDict
+from airflow.sdk.definitions.timetables import assets, events, interval,
simple, trigger # noqa: F401
from airflow.sdk.exceptions import AirflowDagCycleException,
DuplicateTaskIdFound, RemovedInAirflow4Warning
from airflow.utils.types import DagRunType
@@ -437,6 +439,17 @@ class TestDag:
with pytest.raises(ValueError, match="ContinuousTimetable requires
max_active_runs <= 1"):
dag = DAG("continuous", start_date=DEFAULT_DATE,
schedule="@continuous", max_active_runs=25)
+ def test_only_partition_at_runtime_has_partitioned_at_runtime_flag(self):
+ """Regression guard: across every BaseTimetable subclass, only
PartitionAtRuntime sets partitioned_at_runtime=True."""
+
+ def all_subclasses(cls):
+ for sub in cls.__subclasses__():
+ yield sub
+ yield from all_subclasses(sub)
+
+ flagged = {c for c in all_subclasses(BaseTimetable) if
c.partitioned_at_runtime}
+ assert flagged == {PartitionAtRuntime}
+
def test_dag_add_task_checks_trigger_rule(self):
# A non fail stop dag should allow any trigger rule
from airflow.sdk import TriggerRule
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index a5ff7be9ce8..062645d25b1 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -461,6 +461,35 @@ class TestOutletEventAccessor:
assert outlet_event_accessor.asset_alias_events == asset_alias_events
+class TestOutletEventAccessorPartitionKeys:
+ @pytest.fixture
+ def accessor(self) -> OutletEventAccessor:
+ return OutletEventAccessor(key=AssetUniqueKey.from_asset(Asset("a")))
+
+ def test_default_is_empty(self, accessor):
+ assert accessor.partition_keys == set()
+
+ def test_direct_assignment(self, accessor):
+ accessor.partition_keys = {"us", "eu"}
+ assert accessor.partition_keys == {"us", "eu"}
+
+ def test_add_partitions(self, accessor):
+ accessor.add_partitions("us")
+ assert accessor.partition_keys == {"us"}
+
+ def test_add_partitions_appends(self, accessor):
+ accessor.add_partitions("us")
+ accessor.add_partitions("eu")
+ accessor.add_partitions("apac")
+ assert accessor.partition_keys == {"us", "eu", "apac"}
+
+ def test_add_partitions_dedupes(self, accessor):
+ accessor.add_partitions("us")
+ accessor.add_partitions("us")
+ accessor.add_partitions(["us", "eu"])
+ assert accessor.partition_keys == {"us", "eu"}
+
+
class TestTriggeringAssetEventsAccessor:
@pytest.fixture(autouse=True)
def clear_cache(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index d7b72b2c48f..9425ce362ed 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -156,6 +156,7 @@ from airflow.sdk.execution_time.task_runner import (
_execute_task,
_make_task_span,
_push_xcom_if_needed,
+ _serialize_outlet_events,
_xcom_push,
finalize,
get_startup_details,
@@ -1813,6 +1814,43 @@ def
test_rendered_map_index_updates_sent_progressively(create_runtime_ti, mock_s
assert ti.rendered_map_index == "Label: test_task"
+class TestSerializeOutletEvents:
+ """Tests for the wire format produced by ``_serialize_outlet_events``."""
+
+ def test_emits_single_event_when_no_partition_keys(self):
+ accessors = OutletEventAccessors()
+ accessors[Asset(name="a")].extra = {"x": 1}
+
+ events = list(_serialize_outlet_events(accessors))
+
+ assert events == [{"dest_asset_key": {"name": "a", "uri": "a"},
"extra": {"x": 1}}]
+
+ def test_emits_one_event_per_partition_key(self):
+ accessors = OutletEventAccessors()
+ accessors[Asset(name="a")].partition_keys = {"us", "eu"}
+
+ events = list(_serialize_outlet_events(accessors))
+
+ assert sorted(events, key=lambda e: e["partition_key"]) == [
+ {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {},
"partition_key": "eu"},
+ {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {},
"partition_key": "us"},
+ ]
+
+ def test_dedupes_partition_keys_at_serialization(self):
+ accessors = OutletEventAccessors()
+ accessor = accessors[Asset(name="a")]
+ accessor.add_partitions("us")
+ accessor.add_partitions("us")
+ accessor.add_partitions(["us", "eu"])
+
+ events = list(_serialize_outlet_events(accessors))
+
+ assert sorted(events, key=lambda e: e["partition_key"]) == [
+ {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {},
"partition_key": "eu"},
+ {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {},
"partition_key": "us"},
+ ]
+
+
class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse,
make_ti_context):
"""Test get_template_context without ti_context_from_server."""