This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dae49a37f58 AIP-76: Consume task-emitted partition keys on asset events (#66782)
dae49a37f58 is described below

commit dae49a37f58f91b5bb43c4566cc474497d42513c
Author: Anish Giri <[email protected]>
AuthorDate: Thu May 21 04:38:23 2026 -0500

    AIP-76: Consume task-emitted partition keys on asset events (#66782)
---
 airflow-core/src/airflow/models/taskinstance.py    | 106 +++++++++-----
 .../tests/unit/models/test_taskinstance.py         | 157 +++++++++++++++++++++
 2 files changed, 224 insertions(+), 39 deletions(-)

diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 36e37692649..ea8656237fd 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -26,7 +26,7 @@ import warnings
 from collections import defaultdict
 from collections.abc import Collection, Iterable
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NamedTuple
 from urllib.parse import quote
 from uuid import UUID
 
@@ -130,6 +130,13 @@ if TYPE_CHECKING:
 PAST_DEPENDS_MET = "past_depends_met"
 
 
+class OutletEventPayload(NamedTuple):
+    """A single outlet emission carrying its ``extra`` payload and optional 
per-emission ``partition_key``."""
+
+    extra: dict
+    partition_key: str | None
+
+
 @provide_session
 def _add_log(
     event,
@@ -1485,10 +1492,33 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
             SerializedAssetUriRef,
         )
 
-        # TODO: AIP-76 should we provide an interface to override this, so 
that the task can
-        #  tell the truth if for some reason it touches a different partition?
-        #  https://github.com/apache/airflow/issues/58474
-        partition_key = ti.dag_run.partition_key
+        payloads_by_asset: dict[SerializedAssetUniqueKey, 
list[OutletEventPayload]] = defaultdict(list)
+        runtime_pks: set[str] = set()
+        for outlet_event in outlet_events:
+            # Alias-emitted events are handled separately further down via
+            # register_asset_change_for_alias, which uses the DagRun-level
+            # partition_key. Per-emission partition keys do not fan out through
+            # the alias path — emission via an alias produces one event per
+            # resolved asset, all carrying the same dag_run_partition_key.
+            if "source_alias_name" in outlet_event:
+                continue
+            asset_key = 
SerializedAssetUniqueKey(**outlet_event["dest_asset_key"])
+            partition_key = outlet_event.get("partition_key")
+            payloads_by_asset[asset_key].append(
+                OutletEventPayload(extra=outlet_event["extra"], 
partition_key=partition_key)
+            )
+            if partition_key is not None:
+                runtime_pks.add(partition_key)
+
+        # Back-fill DagRun.partition_key from the task emission when the task
+        # emitted exactly one distinct partition_key across all outlet events
+        # and the DagRun did not already have one set. This lets a task that
+        # discovers the partition at runtime (rather than via params) act as
+        # the source of truth for the DagRun-level key.
+        if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
+            ti.dag_run.partition_key = next(iter(runtime_pks))
+        dag_run_partition_key = ti.dag_run.partition_key
+
         asset_keys = {
             SerializedAssetUniqueKey(o.name, o.uri)
             for o in task_outlets
@@ -1515,11 +1545,25 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
             )
         }
 
-        asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
-            SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
-            for event in outlet_events
-            if "source_alias_name" not in event
-        }
+        def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
+            payloads_for_asset = payloads_by_asset.get(key, [])
+            if not payloads_for_asset:
+                asset_manager.register_asset_change(
+                    task_instance=ti,
+                    asset=am,
+                    extra=None,
+                    partition_key=dag_run_partition_key,
+                    session=session,
+                )
+                return
+            for payload in payloads_for_asset:
+                asset_manager.register_asset_change(
+                    task_instance=ti,
+                    asset=am,
+                    extra=payload.extra,
+                    partition_key=payload.partition_key,
+                    session=session,
+                )
 
         for key in asset_keys:
             try:
@@ -1532,52 +1576,36 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
                 )
                 continue
             ti.log.debug("register event for asset %s", am)
-            asset_manager.register_asset_change(
-                task_instance=ti,
-                asset=am,
-                extra=asset_event_extras.get(key),
-                partition_key=partition_key,
-                session=session,
-            )
+            _register(am, key)
 
         if asset_name_refs:
-            asset_models_by_name = {key.name: am for key, am in 
asset_models.items()}
-            asset_event_extras_by_name = {key.name: extra for key, extra in 
asset_event_extras.items()}
+            asset_models_by_name: dict[str, tuple[SerializedAssetUniqueKey, 
AssetModel]] = {
+                key.name: (key, am) for key, am in asset_models.items()
+            }
             for nref in asset_name_refs:
                 try:
-                    am = asset_models_by_name[nref.name]
+                    key, am = asset_models_by_name[nref.name]
                 except KeyError:
                     ti.log.warning(
                         'Task has inactive assets "Asset.ref(name=%s)" in 
inlets or outlets', nref.name
                     )
                     continue
                 ti.log.debug("register event for asset name ref %s", am)
-                asset_manager.register_asset_change(
-                    task_instance=ti,
-                    asset=am,
-                    extra=asset_event_extras_by_name.get(nref.name),
-                    partition_key=partition_key,
-                    session=session,
-                )
+                _register(am, key)
         if asset_uri_refs:
-            asset_models_by_uri = {key.uri: am for key, am in 
asset_models.items()}
-            asset_event_extras_by_uri = {key.uri: extra for key, extra in 
asset_event_extras.items()}
+            asset_models_by_uri: dict[str, tuple[SerializedAssetUniqueKey, 
AssetModel]] = {
+                key.uri: (key, am) for key, am in asset_models.items()
+            }
             for uref in asset_uri_refs:
                 try:
-                    am = asset_models_by_uri[uref.uri]
+                    key, am = asset_models_by_uri[uref.uri]
                 except KeyError:
                     ti.log.warning(
                         'Task has inactive assets "Asset.ref(uri=%s)" in 
inlets or outlets', uref.uri
                     )
                     continue
                 ti.log.debug("register event for asset uri ref %s", am)
-                asset_manager.register_asset_change(
-                    task_instance=ti,
-                    asset=am,
-                    extra=asset_event_extras_by_uri.get(uref.uri),
-                    partition_key=partition_key,
-                    session=session,
-                )
+                _register(am, key)
 
         def _asset_event_extras_from_aliases() -> 
dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
             d = defaultdict(set)
@@ -1616,7 +1644,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
                     asset=asset,
                     source_alias_names=event_aliase_names,
                     extra=asset_event_extra,
-                    partition_key=partition_key,
+                    partition_key=dag_run_partition_key,
                     session=session,
                 )
                 if event is None:
@@ -1628,7 +1656,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
                         asset=asset,
                         source_alias_names=event_aliase_names,
                         extra=asset_event_extra,
-                        partition_key=partition_key,
+                        partition_key=dag_run_partition_key,
                         session=session,
                     )
 
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index b9e5452855f..bea68206183 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3513,6 +3513,163 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
     assert pakl.target_dag_id == "asset_event_listener"
 
 
+def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
+    """Single runtime key on a PartitionAtRuntime-style run 
(dag_run.partition_key=None) back-fills the run."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "us"},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id 
== dag.dag_id))
+    assert event.partition_key == "us"
+    session.refresh(dr)
+    assert dr.partition_key == "us"
+
+
+def 
test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, 
session):
+    """Task-emitted key lands on the AssetEvent but does NOT overwrite a 
scheduler-set DagRun.partition_key."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_no_overwrite", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="scheduler-key", 
session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "task-key"},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id 
== dag.dag_id))
+    assert event.partition_key == "task-key"
+    session.refresh(dr)
+    assert dr.partition_key == "scheduler-key"
+
+
+def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, 
session):
+    """Multiple distinct runtime keys produce one AssetEvent each; 
DagRun.partition_key stays None."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_fanout", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "eu"},
+        ],
+        session=session,
+    )
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id 
== dag.dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", "eu"}
+    session.refresh(dr)
+    assert dr.partition_key is None
+
+
+def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker, 
session):
+    """An outlet event without partition_key produces an AssetEvent with 
partition_key=None."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_none", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": 
{"x": 1}},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id 
== dag.dag_id))
+    assert event.partition_key is None
+
+
+def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
+    """One event with partition_key + one without produce two AssetEvents (key 
+ None)."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}},
+        ],
+        session=session,
+    )
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id 
== dag.dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", None}
+    session.refresh(dr)
+    assert dr.partition_key == "from-run"
+
+
+def 
test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
+    dag_maker,
+    session,
+):
+    """Runtime-emitted fan-out populates PartitionedAssetKeyLog + 
AssetPartitionDagRun per key."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_producer", schedule=None, session=session) as 
dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    producer_dag_id = dag.dag_id
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+    session.commit()
+
+    with dag_maker(
+        dag_id="rt_consumer",
+        schedule=PartitionedAssetTimetable(
+            assets=Asset(name="hello"), 
default_partition_mapper=IdentityMapper()
+        ),
+        session=session,
+    ):
+        EmptyOperator(task_id="hi")
+    session.commit()
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, 
"partition_key": "eu"},
+        ],
+        session=session,
+    )
+    session.commit()
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id 
== producer_dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", "eu"}
+    pakls = session.scalars(select(PartitionedAssetKeyLog)).all()
+    apdrs = session.scalars(select(AssetPartitionDagRun)).all()
+    assert {p.source_partition_key for p in pakls} == {"us", "eu"}
+    assert {p.target_partition_key for p in pakls} == {"us", "eu"}
+    assert {p.target_dag_id for p in pakls} == {"rt_consumer"}
+    assert {a.partition_key for a in apdrs} == {"us", "eu"}
+    assert {a.target_dag_id for a in apdrs} == {"rt_consumer"}
+
+
 async def empty_callback_for_deadline():
     """Used in deadline tests to confirm that Deadlines and DeadlineAlerts 
function correctly."""
     pass

Reply via email to