This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 128e11d0617 improve grid/ti_summaries and grid/runs (#64034)
128e11d0617 is described below

commit 128e11d0617b17fd0143679c6597000ce6c0a931
Author: Henry Chen <[email protected]>
AuthorDate: Fri Apr 3 01:58:44 2026 +0800

    improve grid/ti_summaries and grid/runs (#64034)
    
    * improve grid/ti_summaries and grid/runs
    
    * remove serdag
---
 .../airflow/api_fastapi/core_api/routes/ui/grid.py | 111 +++++++++-----
 .../api_fastapi/core_api/services/ui/grid.py       | 170 ++++++++++++++-------
 .../api_fastapi/core_api/routes/ui/test_grid.py    |  23 +++
 3 files changed, 205 insertions(+), 99 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
index 0143ae81e14..156982d4097 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
@@ -17,18 +17,18 @@
 
 from __future__ import annotations
 
-import collections
-from collections.abc import Generator, Sequence
+from collections.abc import Generator, Iterable
 from typing import TYPE_CHECKING, Annotated, Any
 
 import structlog
 from fastapi import Depends, HTTPException, Query, status
 from fastapi.responses import StreamingResponse
 from sqlalchemy import exists, select
-from sqlalchemy.orm import joinedload, load_only, selectinload
+from sqlalchemy.orm import Session, joinedload, load_only
 
 from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
 from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
+from airflow.api_fastapi.common.db.dag_runs import attach_dag_versions_to_runs
 from airflow.api_fastapi.common.parameters import (
     QueryDagRunRunTypesFilter,
     QueryDagRunStateFilter,
@@ -52,6 +52,7 @@ from airflow.api_fastapi.core_api.datamodels.ui.grid import (
 from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
 from airflow.api_fastapi.core_api.security import requires_access_dag
 from airflow.api_fastapi.core_api.services.ui.grid import (
+    GridNodeAgg,
     _find_aggregates,
     _get_aggs_for_node,
     _merge_node_dicts,
@@ -60,13 +61,11 @@ from airflow.api_fastapi.core_api.services.ui.task_group import (
     get_task_group_children_getter,
     task_group_to_dict_grid,
 )
-from airflow.models.dag import DagModel
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagrun import DagRun
 from airflow.models.deadline import Deadline
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskinstancehistory import TaskInstanceHistory
 
 log = structlog.get_logger(logger_name=__name__)
 grid_router = AirflowRouter(prefix="/grid", tags=["Grid"])
@@ -301,16 +300,8 @@ def get_grid_runs(
                 DagRun.run_type,
                 DagRun.bundle_version,
             ),
-            joinedload(DagRun.dag_model).load_only(DagModel._dag_display_property_value),
             joinedload(DagRun.created_dag_version).joinedload(DagVersion.bundle),
-            selectinload(DagRun.task_instances)
-            .load_only(TaskInstance.dag_version_id)
-            .joinedload(TaskInstance.dag_version)
-            .joinedload(DagVersion.bundle),
-            selectinload(DagRun.task_instances_histories)
-            .load_only(TaskInstanceHistory.dag_version_id)
-            .joinedload(TaskInstanceHistory.dag_version)
-            .joinedload(DagVersion.bundle),
+            joinedload(DagRun.created_dag_version).joinedload(DagVersion.dag_model),
         )
     )
 
@@ -332,38 +323,74 @@ def get_grid_runs(
         return_total_entries=False,
     )
     results = session.execute(dag_runs_select_filter).unique().all()
+    dag_runs = [run for run, _ in results]
+    attach_dag_versions_to_runs(dag_runs, session=session)
     grid_runs = []
     for run, has_missed in results:
-        run.has_missed_deadline = has_missed
-        grid_runs.append(GridRunsResponse.model_validate(run, from_attributes=True))
+        grid_runs.append(
+            GridRunsResponse.model_validate(
+                {
+                    "dag_id": run.dag_id,
+                    "run_id": run.run_id,
+                    "queued_at": run.queued_at,
+                    "start_date": run.start_date,
+                    "end_date": run.end_date,
+                    "run_after": run.run_after,
+                    "state": run.state,
+                    "run_type": run.run_type,
+                    "dag_versions": run.dag_versions,
+                    "has_missed_deadline": has_missed,
+                }
+            )
+        )
     return grid_runs
 
 
 def _build_ti_summaries(
-    dag_id: str, run_id: str, task_instances: Sequence, session, serdag: SerializedDagModel | None = None
-) -> dict:
-    ti_details: dict = collections.defaultdict(list)
+    dag_id: str,
+    run_id: str,
+    task_instances: Iterable[Any],
+    session: Session,
+    *,
+    serdag: SerializedDagModel | None = None,
+    serdag_cache: dict[Any, SerializedDagModel | None] | None = None,
+) -> dict[str, Any] | None:
+    ti_details: dict[str, GridNodeAgg] = {}
+    dag_version_id = None
     for ti in task_instances:
-        ti_details[ti.task_id].append(
-            {
-                "state": ti.state,
-                "start_date": ti.start_date,
-                "end_date": ti.end_date,
-                "dag_version_number": getattr(ti, "version_number", None),
-            }
+        dag_version_id = dag_version_id or ti.dag_version_id
+        summary = ti_details.get(ti.task_id)
+        if summary is None:
+            summary = ti_details[ti.task_id] = GridNodeAgg()
+        summary.add_ti(
+            state=ti.state,
+            start_date=ti.start_date,
+            end_date=ti.end_date,
+            dag_version_number=getattr(ti, "version_number", None),
         )
+    if not ti_details:
+        return None
     if serdag is None:
-        serdag = _get_serdag(
-            dag_id=dag_id,
-            dag_version_id=task_instances[0].dag_version_id,
-            session=session,
-        )
+        if serdag_cache is not None:
+            if dag_version_id not in serdag_cache:
+                serdag_cache[dag_version_id] = _get_serdag(
+                    dag_id=dag_id,
+                    dag_version_id=dag_version_id,
+                    session=session,
+                )
+            serdag = serdag_cache[dag_version_id]
+        else:
+            serdag = _get_serdag(
+                dag_id=dag_id,
+                dag_version_id=dag_version_id,
+                session=session,
+            )
     if TYPE_CHECKING:
         assert serdag
 
-    def get_node_summaries():
+    def get_node_summaries() -> Iterable[dict[str, Any]]:
         yielded_task_ids: set[str] = set()
-        for node in _find_aggregates(
+        for node, _ in _find_aggregates(
             node=serdag.dag.task_group,
             parent_node=None,
             ti_details=ti_details,
@@ -441,7 +468,7 @@ def get_grid_ti_summaries_stream(
     """
 
     def _generate() -> Generator[str, None, None]:
-        serdag_cache: dict = {}
+        serdag_cache: dict[Any, SerializedDagModel | None] = {}
         for run_id in run_ids or []:
             tis = session.execute(
                 select(
@@ -456,13 +483,17 @@ def get_grid_ti_summaries_stream(
                 .where(TaskInstance.dag_id == dag_id)
                 .where(TaskInstance.run_id == run_id)
                 .order_by(TaskInstance.task_id)
-            ).all()
-            if not tis:
+                .execution_options(yield_per=1000)
+            )
+            summary = _build_ti_summaries(
+                dag_id,
+                run_id,
+                tis,
+                session,
+                serdag_cache=serdag_cache,
+            )
+            if summary is None:
                 continue
-            version_id = tis[0].dag_version_id
-            if version_id not in serdag_cache:
-                serdag_cache[version_id] = _get_serdag(dag_id, version_id, session)
-            summary = _build_ti_summaries(dag_id, run_id, tis, session, serdag=serdag_cache[version_id])
             yield GridTISummaries.model_validate(summary).model_dump_json() + "\n"
 
     return StreamingResponse(content=_generate(), media_type="application/x-ndjson")
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 4ee93f4f10e..f8494cd5d23 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -18,7 +18,9 @@
 from __future__ import annotations
 
 from collections import Counter
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
+from dataclasses import dataclass, field
+from datetime import datetime
 from typing import Any
 
 import structlog
@@ -33,6 +35,64 @@ from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
 log = structlog.get_logger(logger_name=__name__)
 
 
+@dataclass
+class GridNodeAgg:
+    """Compact task instance summary used to aggregate grid state without keeping TI details."""
+
+    child_states: Counter[Any] = field(default_factory=Counter)
+    min_start_date: datetime | None = None
+    max_end_date: datetime | None = None
+    dag_version_number: int | None = None
+
+    def add_ti(
+        self,
+        *,
+        state: Any,
+        start_date: datetime | None,
+        end_date: datetime | None,
+        dag_version_number: int | None,
+    ) -> None:
+        """Merge one task instance row into the summary."""
+        self.child_states[state] += 1
+        if start_date is not None and (self.min_start_date is None or start_date < self.min_start_date):
+            self.min_start_date = start_date
+        if end_date is not None and (self.max_end_date is None or end_date > self.max_end_date):
+            self.max_end_date = end_date
+        if dag_version_number is not None and (
+            self.dag_version_number is None or dag_version_number > self.dag_version_number
+        ):
+            self.dag_version_number = dag_version_number
+
+    def merge(self, other: GridNodeAgg) -> None:
+        """Merge another summary into this one."""
+        self.child_states.update(other.child_states)
+        if other.min_start_date is not None and (
+            self.min_start_date is None or other.min_start_date < self.min_start_date
+        ):
+            self.min_start_date = other.min_start_date
+        if other.max_end_date is not None and (
+            self.max_end_date is None or other.max_end_date > self.max_end_date
+        ):
+            self.max_end_date = other.max_end_date
+        if other.dag_version_number is not None and (
+            self.dag_version_number is None or other.dag_version_number > self.dag_version_number
+        ):
+            self.dag_version_number = other.dag_version_number
+
+    def with_placeholder_state(self) -> GridNodeAgg:
+        """Represent mapped tasks without rows as a single no-status square in the grid."""
+        if self.child_states:
+            return self
+        placeholder = GridNodeAgg(dag_version_number=self.dag_version_number)
+        placeholder.add_ti(
+            state=None,
+            start_date=None,
+            end_date=None,
+            dag_version_number=self.dag_version_number,
+        )
+        return placeholder
+
+
 def _merge_node_dicts(current: list[dict[str, Any]], new: list[dict[str, Any]] | None) -> None:
     """Merge node dictionaries from different DAG versions, handling structure changes."""
     # Handle None case - can occur when merging old DAG versions
@@ -55,90 +115,82 @@ def _merge_node_dicts(current: list[dict[str, Any]], new: list[dict[str, Any]] |
 
 
 def agg_state(states):
-    states = Counter(states)
+    state_counts = states if isinstance(states, Counter) else Counter(states)
     for state in state_priority:
-        if state in states:
+        if state in state_counts:
             return state
     return None
 
 
-def _get_aggs_for_node(detail):
-    states = [x["state"] for x in detail]
-    try:
-        min_start_date = min(x["start_date"] for x in detail if x["start_date"])
-    except ValueError:
-        min_start_date = None
-    try:
-        max_end_date = max(x["end_date"] for x in detail if x["end_date"])
-        max_end_date = None
-
-    dag_version_numbers = [
-        x.get("dag_version_number") for x in detail if x.get("dag_version_number") is not None
-    ]
-    dag_version_number = max(dag_version_numbers) if dag_version_numbers else None
-
+def _get_aggs_for_node(summary: GridNodeAgg) -> dict[str, Any]:
     return {
-        "state": agg_state(states),
-        "min_start_date": min_start_date,
-        "max_end_date": max_end_date,
-        "child_states": dict(Counter(states)),
-        "dag_version_number": dag_version_number,
+        "state": agg_state(summary.child_states),
+        "min_start_date": summary.min_start_date,
+        "max_end_date": summary.max_end_date,
+        "child_states": dict(summary.child_states),
+        "dag_version_number": summary.dag_version_number,
     }
 
 
 def _find_aggregates(
     node: SerializedTaskGroup | SerializedBaseOperator | TaskMap,
     parent_node: SerializedTaskGroup | SerializedBaseOperator | TaskMap | None,
-    ti_details: dict[str, list],
-) -> Iterable[dict]:
+    ti_details: Mapping[str, GridNodeAgg],
+) -> Iterable[tuple[dict[str, Any], GridNodeAgg]]:
     """Recursively fill the Task Group Map."""
     node_id = node.node_id
     parent_id = parent_node.node_id if parent_node else None
     # Do not mutate ti_details by accidental key creation
-    details = ti_details.get(node_id, [])
+    summary = ti_details.get(node_id)
+    if summary is None:
+        summary = GridNodeAgg()
 
     if node is None:
         return
     if isinstance(node, SerializedMappedOperator):
-        # For unmapped tasks, reflect a single None state so UI shows one square
-        mapped_details = details or [{"state": None, "start_date": None, "end_date": None}]
-        yield {
-            "task_id": node_id,
-            "task_display_name": node.task_display_name,
-            "type": "mapped_task",
-            "parent_id": parent_id,
-            **_get_aggs_for_node(mapped_details),
-            "details": mapped_details,
-        }
+        mapped_summary = summary.with_placeholder_state()
+        yield (
+            {
+                "task_id": node_id,
+                "task_display_name": node.task_display_name,
+                "type": "mapped_task",
+                "parent_id": parent_id,
+                **_get_aggs_for_node(mapped_summary),
+            },
+            mapped_summary,
+        )
 
         return
     if isinstance(node, SerializedTaskGroup):
-        children_details = []
+        children_summary = GridNodeAgg()
         for child in get_task_group_children_getter()(node):
-            for child_node in _find_aggregates(node=child, parent_node=node, ti_details=ti_details):
+            for child_node, child_summary in _find_aggregates(
+                node=child, parent_node=node, ti_details=ti_details
+            ):
                 if child_node["parent_id"] == node_id:
-                    # Collect detailed task instance data from all children
-                    if child_node.get("details"):
-                        children_details.extend(child_node["details"])
-                yield child_node
+                    children_summary.merge(child_summary)
+                yield child_node, child_summary
         if node_id:
-            yield {
-                "task_id": node_id,
-                "task_display_name": node_id,
-                "type": "group",
-                "parent_id": parent_id,
-                **_get_aggs_for_node(children_details),
-                "details": children_details,
-            }
+            yield (
+                {
+                    "task_id": node_id,
+                    "task_display_name": node_id,
+                    "type": "group",
+                    "parent_id": parent_id,
+                    **_get_aggs_for_node(children_summary),
+                },
+                children_summary,
+            )
         return
     if isinstance(node, SerializedBaseOperator):
-        yield {
-            "task_id": node_id,
-            "task_display_name": node.task_display_name,
-            "type": "task",
-            "parent_id": parent_id,
-            **_get_aggs_for_node(details),
-            "details": details,
-        }
+        yield (
+            {
+                "task_id": node_id,
+                "task_display_name": node.task_display_name,
+                "type": "task",
+                "parent_id": parent_id,
+                **_get_aggs_for_node(summary),
+            },
+            summary,
+        )
         return
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
index dbcf87149e5..b0dab9012a5 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
@@ -641,6 +641,29 @@ class TestGetGridDataEndpoint:
         assert response.status_code == 200
         assert _strip_dag_version_ids(response.json()) == [GRID_RUN_1, GRID_RUN_2]
 
+    def test_get_grid_runs_multiple_dag_versions(self, session, test_client):
+        latest_dag_version = session.scalar(select(DagModel).where(DagModel.dag_id == DAG_ID_5)).dag_versions[
+            -1
+        ]
+        latest_task_instance = session.scalar(
+            select(TaskInstance)
+            .where(TaskInstance.dag_id == DAG_ID_5, TaskInstance.run_id == "run_5_2")
+            .limit(1)
+        )
+        latest_task_instance.dag_version = latest_dag_version
+        session.commit()
+
+        response = test_client.get(f"/grid/runs/{DAG_ID_5}?limit=5")
+        assert response.status_code == 200
+        dag_versions_by_run_id = {
+            run["run_id"]: [dag_version["version_number"] for dag_version in run["dag_versions"]]
+            for run in response.json()
+        }
+        assert dag_versions_by_run_id == {
+            "run_5_1": [1],
+            "run_5_2": [1, 2],
+        }
+
     @pytest.mark.parametrize(
         ("endpoint", "run_type", "expected"),
         [

Reply via email to