This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dae49a37f58 AIP-76: Consume task-emitted partition keys on asset
events (#66782)
dae49a37f58 is described below
commit dae49a37f58f91b5bb43c4566cc474497d42513c
Author: Anish Giri <[email protected]>
AuthorDate: Thu May 21 04:38:23 2026 -0500
AIP-76: Consume task-emitted partition keys on asset events (#66782)
---
airflow-core/src/airflow/models/taskinstance.py | 106 +++++++++-----
.../tests/unit/models/test_taskinstance.py | 157 +++++++++++++++++++++
2 files changed, 224 insertions(+), 39 deletions(-)
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 36e37692649..ea8656237fd 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -26,7 +26,7 @@ import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NamedTuple
from urllib.parse import quote
from uuid import UUID
@@ -130,6 +130,13 @@ if TYPE_CHECKING:
PAST_DEPENDS_MET = "past_depends_met"
+class OutletEventPayload(NamedTuple):
+ """A single outlet emission carrying its ``extra`` payload and optional per-emission ``partition_key``."""
+
+ extra: dict
+ partition_key: str | None
+
+
@provide_session
def _add_log(
event,
@@ -1485,10 +1492,33 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
SerializedAssetUriRef,
)
- # TODO: AIP-76 should we provide an interface to override this, so that the task can
- # tell the truth if for some reason it touches a different partition?
- # https://github.com/apache/airflow/issues/58474
- partition_key = ti.dag_run.partition_key
+ payloads_by_asset: dict[SerializedAssetUniqueKey,
list[OutletEventPayload]] = defaultdict(list)
+ runtime_pks: set[str] = set()
+ for outlet_event in outlet_events:
+ # Alias-emitted events are handled separately further down via
+ # register_asset_change_for_alias, which uses the DagRun-level
+ # partition_key. Per-emission partition keys do not fan out through
+ # the alias path — emission via an alias produces one event per
+ # resolved asset, all carrying the same dag_run_partition_key.
+ if "source_alias_name" in outlet_event:
+ continue
+ asset_key =
SerializedAssetUniqueKey(**outlet_event["dest_asset_key"])
+ partition_key = outlet_event.get("partition_key")
+ payloads_by_asset[asset_key].append(
+ OutletEventPayload(extra=outlet_event["extra"],
partition_key=partition_key)
+ )
+ if partition_key is not None:
+ runtime_pks.add(partition_key)
+
+ # Back-fill DagRun.partition_key from the task emission when the task
+ # emitted exactly one distinct partition_key across all outlet events
+ # and the DagRun did not already have one set. This lets a task that
+ # discovers the partition at runtime (rather than via params) act as
+ # the source of truth for the DagRun-level key.
+ if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
+ ti.dag_run.partition_key = next(iter(runtime_pks))
+ dag_run_partition_key = ti.dag_run.partition_key
+
asset_keys = {
SerializedAssetUniqueKey(o.name, o.uri)
for o in task_outlets
@@ -1515,11 +1545,25 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
)
}
- asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
- SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
- for event in outlet_events
- if "source_alias_name" not in event
- }
+ def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
+ payloads_for_asset = payloads_by_asset.get(key, [])
+ if not payloads_for_asset:
+ asset_manager.register_asset_change(
+ task_instance=ti,
+ asset=am,
+ extra=None,
+ partition_key=dag_run_partition_key,
+ session=session,
+ )
+ return
+ for payload in payloads_for_asset:
+ asset_manager.register_asset_change(
+ task_instance=ti,
+ asset=am,
+ extra=payload.extra,
+ partition_key=payload.partition_key,
+ session=session,
+ )
for key in asset_keys:
try:
@@ -1532,52 +1576,36 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
)
continue
ti.log.debug("register event for asset %s", am)
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=am,
- extra=asset_event_extras.get(key),
- partition_key=partition_key,
- session=session,
- )
+ _register(am, key)
if asset_name_refs:
- asset_models_by_name = {key.name: am for key, am in
asset_models.items()}
- asset_event_extras_by_name = {key.name: extra for key, extra in
asset_event_extras.items()}
+ asset_models_by_name: dict[str, tuple[SerializedAssetUniqueKey,
AssetModel]] = {
+ key.name: (key, am) for key, am in asset_models.items()
+ }
for nref in asset_name_refs:
try:
- am = asset_models_by_name[nref.name]
+ key, am = asset_models_by_name[nref.name]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(name=%s)" in
inlets or outlets', nref.name
)
continue
ti.log.debug("register event for asset name ref %s", am)
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=am,
- extra=asset_event_extras_by_name.get(nref.name),
- partition_key=partition_key,
- session=session,
- )
+ _register(am, key)
if asset_uri_refs:
- asset_models_by_uri = {key.uri: am for key, am in
asset_models.items()}
- asset_event_extras_by_uri = {key.uri: extra for key, extra in
asset_event_extras.items()}
+ asset_models_by_uri: dict[str, tuple[SerializedAssetUniqueKey,
AssetModel]] = {
+ key.uri: (key, am) for key, am in asset_models.items()
+ }
for uref in asset_uri_refs:
try:
- am = asset_models_by_uri[uref.uri]
+ key, am = asset_models_by_uri[uref.uri]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(uri=%s)" in
inlets or outlets', uref.uri
)
continue
ti.log.debug("register event for asset uri ref %s", am)
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=am,
- extra=asset_event_extras_by_uri.get(uref.uri),
- partition_key=partition_key,
- session=session,
- )
+ _register(am, key)
def _asset_event_extras_from_aliases() ->
dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
d = defaultdict(set)
@@ -1616,7 +1644,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
- partition_key=partition_key,
+ partition_key=dag_run_partition_key,
session=session,
)
if event is None:
@@ -1628,7 +1656,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
- partition_key=partition_key,
+ partition_key=dag_run_partition_key,
session=session,
)
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index b9e5452855f..bea68206183 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3513,6 +3513,163 @@ def
test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
assert pakl.target_dag_id == "asset_event_listener"
+def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
+ """Single runtime key on a PartitionAtRuntime-style run
(dag_run.partition_key=None) back-fills the run."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ dr = dag_maker.create_dagrun(session=session)
+ assert dr.partition_key is None
+ [ti] = dr.get_task_instances(session=session)
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "us"},
+ ],
+ session=session,
+ )
+ event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id
== dag.dag_id))
+ assert event.partition_key == "us"
+ session.refresh(dr)
+ assert dr.partition_key == "us"
+
+
+def
test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker,
session):
+ """Task-emitted key lands on the AssetEvent but does NOT overwrite a
scheduler-set DagRun.partition_key."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_pk_no_overwrite", schedule=None) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ dr = dag_maker.create_dagrun(partition_key="scheduler-key",
session=session)
+ [ti] = dr.get_task_instances(session=session)
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "task-key"},
+ ],
+ session=session,
+ )
+ event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id
== dag.dag_id))
+ assert event.partition_key == "task-key"
+ session.refresh(dr)
+ assert dr.partition_key == "scheduler-key"
+
+
+def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker,
session):
+ """Multiple distinct runtime keys produce one AssetEvent each;
DagRun.partition_key stays None."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_pk_fanout", schedule=None) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ dr = dag_maker.create_dagrun(session=session)
+ assert dr.partition_key is None
+ [ti] = dr.get_task_instances(session=session)
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "us"},
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "eu"},
+ ],
+ session=session,
+ )
+ events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id
== dag.dag_id)).all()
+ assert {e.partition_key for e in events} == {"us", "eu"}
+ session.refresh(dr)
+ assert dr.partition_key is None
+
+
+def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker,
session):
+ """An outlet event without partition_key produces an AssetEvent with
partition_key=None."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_pk_none", schedule=None) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+ [ti] = dr.get_task_instances(session=session)
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra":
{"x": 1}},
+ ],
+ session=session,
+ )
+ event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id
== dag.dag_id))
+ assert event.partition_key is None
+
+
+def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
+ """One event with partition_key + one without produce two AssetEvents (key + None)."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+ [ti] = dr.get_task_instances(session=session)
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "us"},
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}},
+ ],
+ session=session,
+ )
+ events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id
== dag.dag_id)).all()
+ assert {e.partition_key for e in events} == {"us", None}
+ session.refresh(dr)
+ assert dr.partition_key == "from-run"
+
+
+def
test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
+ dag_maker,
+ session,
+):
+ """Runtime-emitted fan-out populates PartitionedAssetKeyLog +
AssetPartitionDagRun per key."""
+ asset = Asset(name="hello")
+ with dag_maker(dag_id="rt_producer", schedule=None, session=session) as
dag:
+ EmptyOperator(task_id="hi", outlets=[asset])
+ producer_dag_id = dag.dag_id
+ dr = dag_maker.create_dagrun(session=session)
+ assert dr.partition_key is None
+ [ti] = dr.get_task_instances(session=session)
+ session.commit()
+
+ with dag_maker(
+ dag_id="rt_consumer",
+ schedule=PartitionedAssetTimetable(
+ assets=Asset(name="hello"),
default_partition_mapper=IdentityMapper()
+ ),
+ session=session,
+ ):
+ EmptyOperator(task_id="hi")
+ session.commit()
+
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[ensure_serialized_asset(asset).asprofile()],
+ outlet_events=[
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "us"},
+ {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {},
"partition_key": "eu"},
+ ],
+ session=session,
+ )
+ session.commit()
+ events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id
== producer_dag_id)).all()
+ assert {e.partition_key for e in events} == {"us", "eu"}
+ pakls = session.scalars(select(PartitionedAssetKeyLog)).all()
+ apdrs = session.scalars(select(AssetPartitionDagRun)).all()
+ assert {p.source_partition_key for p in pakls} == {"us", "eu"}
+ assert {p.target_partition_key for p in pakls} == {"us", "eu"}
+ assert {p.target_dag_id for p in pakls} == {"rt_consumer"}
+ assert {a.partition_key for a in apdrs} == {"us", "eu"}
+ assert {a.target_dag_id for a in apdrs} == {"rt_consumer"}
+
+
async def empty_callback_for_deadline():
"""Used in deadline tests to confirm that Deadlines and DeadlineAlerts
function correctly."""
pass