This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 128e11d0617 improve grid/ti_summaries and grid/runs (#64034)
128e11d0617 is described below
commit 128e11d0617b17fd0143679c6597000ce6c0a931
Author: Henry Chen <[email protected]>
AuthorDate: Fri Apr 3 01:58:44 2026 +0800
improve grid/ti_summaries and grid/runs (#64034)
* improve grid/ti_summaries and grid/runs
* remove serdag
---
.../airflow/api_fastapi/core_api/routes/ui/grid.py | 111 +++++++++-----
.../api_fastapi/core_api/services/ui/grid.py | 170 ++++++++++++++-------
.../api_fastapi/core_api/routes/ui/test_grid.py | 23 +++
3 files changed, 205 insertions(+), 99 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
index 0143ae81e14..156982d4097 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
@@ -17,18 +17,18 @@
from __future__ import annotations
-import collections
-from collections.abc import Generator, Sequence
+from collections.abc import Generator, Iterable
from typing import TYPE_CHECKING, Annotated, Any
import structlog
from fastapi import Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy import exists, select
-from sqlalchemy.orm import joinedload, load_only, selectinload
+from sqlalchemy.orm import Session, joinedload, load_only
from airflow.api_fastapi.auth.managers.models.resource_details import
DagAccessEntity
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
+from airflow.api_fastapi.common.db.dag_runs import attach_dag_versions_to_runs
from airflow.api_fastapi.common.parameters import (
QueryDagRunRunTypesFilter,
QueryDagRunStateFilter,
@@ -52,6 +52,7 @@ from airflow.api_fastapi.core_api.datamodels.ui.grid import (
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.api_fastapi.core_api.services.ui.grid import (
+ GridNodeAgg,
_find_aggregates,
_get_aggs_for_node,
_merge_node_dicts,
@@ -60,13 +61,11 @@ from airflow.api_fastapi.core_api.services.ui.task_group import (
get_task_group_children_getter,
task_group_to_dict_grid,
)
-from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.deadline import Deadline
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskinstancehistory import TaskInstanceHistory
log = structlog.get_logger(logger_name=__name__)
grid_router = AirflowRouter(prefix="/grid", tags=["Grid"])
@@ -301,16 +300,8 @@ def get_grid_runs(
DagRun.run_type,
DagRun.bundle_version,
),
-
joinedload(DagRun.dag_model).load_only(DagModel._dag_display_property_value),
joinedload(DagRun.created_dag_version).joinedload(DagVersion.bundle),
- selectinload(DagRun.task_instances)
- .load_only(TaskInstance.dag_version_id)
- .joinedload(TaskInstance.dag_version)
- .joinedload(DagVersion.bundle),
- selectinload(DagRun.task_instances_histories)
- .load_only(TaskInstanceHistory.dag_version_id)
- .joinedload(TaskInstanceHistory.dag_version)
- .joinedload(DagVersion.bundle),
+
joinedload(DagRun.created_dag_version).joinedload(DagVersion.dag_model),
)
)
@@ -332,38 +323,74 @@ def get_grid_runs(
return_total_entries=False,
)
results = session.execute(dag_runs_select_filter).unique().all()
+ dag_runs = [run for run, _ in results]
+ attach_dag_versions_to_runs(dag_runs, session=session)
grid_runs = []
for run, has_missed in results:
- run.has_missed_deadline = has_missed
- grid_runs.append(GridRunsResponse.model_validate(run,
from_attributes=True))
+ grid_runs.append(
+ GridRunsResponse.model_validate(
+ {
+ "dag_id": run.dag_id,
+ "run_id": run.run_id,
+ "queued_at": run.queued_at,
+ "start_date": run.start_date,
+ "end_date": run.end_date,
+ "run_after": run.run_after,
+ "state": run.state,
+ "run_type": run.run_type,
+ "dag_versions": run.dag_versions,
+ "has_missed_deadline": has_missed,
+ }
+ )
+ )
return grid_runs
def _build_ti_summaries(
- dag_id: str, run_id: str, task_instances: Sequence, session, serdag:
SerializedDagModel | None = None
-) -> dict:
- ti_details: dict = collections.defaultdict(list)
+ dag_id: str,
+ run_id: str,
+ task_instances: Iterable[Any],
+ session: Session,
+ *,
+ serdag: SerializedDagModel | None = None,
+ serdag_cache: dict[Any, SerializedDagModel | None] | None = None,
+) -> dict[str, Any] | None:
+ ti_details: dict[str, GridNodeAgg] = {}
+ dag_version_id = None
for ti in task_instances:
- ti_details[ti.task_id].append(
- {
- "state": ti.state,
- "start_date": ti.start_date,
- "end_date": ti.end_date,
- "dag_version_number": getattr(ti, "version_number", None),
- }
+ dag_version_id = dag_version_id or ti.dag_version_id
+ summary = ti_details.get(ti.task_id)
+ if summary is None:
+ summary = ti_details[ti.task_id] = GridNodeAgg()
+ summary.add_ti(
+ state=ti.state,
+ start_date=ti.start_date,
+ end_date=ti.end_date,
+ dag_version_number=getattr(ti, "version_number", None),
)
+ if not ti_details:
+ return None
if serdag is None:
- serdag = _get_serdag(
- dag_id=dag_id,
- dag_version_id=task_instances[0].dag_version_id,
- session=session,
- )
+ if serdag_cache is not None:
+ if dag_version_id not in serdag_cache:
+ serdag_cache[dag_version_id] = _get_serdag(
+ dag_id=dag_id,
+ dag_version_id=dag_version_id,
+ session=session,
+ )
+ serdag = serdag_cache[dag_version_id]
+ else:
+ serdag = _get_serdag(
+ dag_id=dag_id,
+ dag_version_id=dag_version_id,
+ session=session,
+ )
if TYPE_CHECKING:
assert serdag
- def get_node_summaries():
+ def get_node_summaries() -> Iterable[dict[str, Any]]:
yielded_task_ids: set[str] = set()
- for node in _find_aggregates(
+ for node, _ in _find_aggregates(
node=serdag.dag.task_group,
parent_node=None,
ti_details=ti_details,
@@ -441,7 +468,7 @@ def get_grid_ti_summaries_stream(
"""
def _generate() -> Generator[str, None, None]:
- serdag_cache: dict = {}
+ serdag_cache: dict[Any, SerializedDagModel | None] = {}
for run_id in run_ids or []:
tis = session.execute(
select(
@@ -456,13 +483,17 @@ def get_grid_ti_summaries_stream(
.where(TaskInstance.dag_id == dag_id)
.where(TaskInstance.run_id == run_id)
.order_by(TaskInstance.task_id)
- ).all()
- if not tis:
+ .execution_options(yield_per=1000)
+ )
+ summary = _build_ti_summaries(
+ dag_id,
+ run_id,
+ tis,
+ session,
+ serdag_cache=serdag_cache,
+ )
+ if summary is None:
continue
- version_id = tis[0].dag_version_id
- if version_id not in serdag_cache:
- serdag_cache[version_id] = _get_serdag(dag_id, version_id,
session)
- summary = _build_ti_summaries(dag_id, run_id, tis, session,
serdag=serdag_cache[version_id])
yield GridTISummaries.model_validate(summary).model_dump_json() +
"\n"
return StreamingResponse(content=_generate(),
media_type="application/x-ndjson")
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 4ee93f4f10e..f8494cd5d23 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -18,7 +18,9 @@
from __future__ import annotations
from collections import Counter
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
+from dataclasses import dataclass, field
+from datetime import datetime
from typing import Any
import structlog
@@ -33,6 +35,64 @@ from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
log = structlog.get_logger(logger_name=__name__)
+@dataclass
+class GridNodeAgg:
+ """Compact task instance summary used to aggregate grid state without
keeping TI details."""
+
+ child_states: Counter[Any] = field(default_factory=Counter)
+ min_start_date: datetime | None = None
+ max_end_date: datetime | None = None
+ dag_version_number: int | None = None
+
+ def add_ti(
+ self,
+ *,
+ state: Any,
+ start_date: datetime | None,
+ end_date: datetime | None,
+ dag_version_number: int | None,
+ ) -> None:
+ """Merge one task instance row into the summary."""
+ self.child_states[state] += 1
+ if start_date is not None and (self.min_start_date is None or
start_date < self.min_start_date):
+ self.min_start_date = start_date
+ if end_date is not None and (self.max_end_date is None or end_date >
self.max_end_date):
+ self.max_end_date = end_date
+ if dag_version_number is not None and (
+ self.dag_version_number is None or dag_version_number >
self.dag_version_number
+ ):
+ self.dag_version_number = dag_version_number
+
+ def merge(self, other: GridNodeAgg) -> None:
+ """Merge another summary into this one."""
+ self.child_states.update(other.child_states)
+ if other.min_start_date is not None and (
+ self.min_start_date is None or other.min_start_date <
self.min_start_date
+ ):
+ self.min_start_date = other.min_start_date
+ if other.max_end_date is not None and (
+ self.max_end_date is None or other.max_end_date > self.max_end_date
+ ):
+ self.max_end_date = other.max_end_date
+ if other.dag_version_number is not None and (
+ self.dag_version_number is None or other.dag_version_number >
self.dag_version_number
+ ):
+ self.dag_version_number = other.dag_version_number
+
+ def with_placeholder_state(self) -> GridNodeAgg:
+ """Represent mapped tasks without rows as a single no-status square in
the grid."""
+ if self.child_states:
+ return self
+ placeholder = GridNodeAgg(dag_version_number=self.dag_version_number)
+ placeholder.add_ti(
+ state=None,
+ start_date=None,
+ end_date=None,
+ dag_version_number=self.dag_version_number,
+ )
+ return placeholder
+
+
def _merge_node_dicts(current: list[dict[str, Any]], new: list[dict[str, Any]]
| None) -> None:
"""Merge node dictionaries from different DAG versions, handling structure
changes."""
# Handle None case - can occur when merging old DAG versions
@@ -55,90 +115,82 @@ def _merge_node_dicts(current: list[dict[str, Any]], new: list[dict[str, Any]] |
def agg_state(states):
- states = Counter(states)
+ state_counts = states if isinstance(states, Counter) else Counter(states)
for state in state_priority:
- if state in states:
+ if state in state_counts:
return state
return None
-def _get_aggs_for_node(detail):
- states = [x["state"] for x in detail]
- try:
- min_start_date = min(x["start_date"] for x in detail if
x["start_date"])
- except ValueError:
- min_start_date = None
- try:
- max_end_date = max(x["end_date"] for x in detail if x["end_date"])
- except ValueError:
- max_end_date = None
-
- dag_version_numbers = [
- x.get("dag_version_number") for x in detail if
x.get("dag_version_number") is not None
- ]
- dag_version_number = max(dag_version_numbers) if dag_version_numbers else
None
-
+def _get_aggs_for_node(summary: GridNodeAgg) -> dict[str, Any]:
return {
- "state": agg_state(states),
- "min_start_date": min_start_date,
- "max_end_date": max_end_date,
- "child_states": dict(Counter(states)),
- "dag_version_number": dag_version_number,
+ "state": agg_state(summary.child_states),
+ "min_start_date": summary.min_start_date,
+ "max_end_date": summary.max_end_date,
+ "child_states": dict(summary.child_states),
+ "dag_version_number": summary.dag_version_number,
}
def _find_aggregates(
node: SerializedTaskGroup | SerializedBaseOperator | TaskMap,
parent_node: SerializedTaskGroup | SerializedBaseOperator | TaskMap | None,
- ti_details: dict[str, list],
-) -> Iterable[dict]:
+ ti_details: Mapping[str, GridNodeAgg],
+) -> Iterable[tuple[dict[str, Any], GridNodeAgg]]:
"""Recursively fill the Task Group Map."""
node_id = node.node_id
parent_id = parent_node.node_id if parent_node else None
# Do not mutate ti_details by accidental key creation
- details = ti_details.get(node_id, [])
+ summary = ti_details.get(node_id)
+ if summary is None:
+ summary = GridNodeAgg()
if node is None:
return
if isinstance(node, SerializedMappedOperator):
- # For unmapped tasks, reflect a single None state so UI shows one
square
- mapped_details = details or [{"state": None, "start_date": None,
"end_date": None}]
- yield {
- "task_id": node_id,
- "task_display_name": node.task_display_name,
- "type": "mapped_task",
- "parent_id": parent_id,
- **_get_aggs_for_node(mapped_details),
- "details": mapped_details,
- }
+ mapped_summary = summary.with_placeholder_state()
+ yield (
+ {
+ "task_id": node_id,
+ "task_display_name": node.task_display_name,
+ "type": "mapped_task",
+ "parent_id": parent_id,
+ **_get_aggs_for_node(mapped_summary),
+ },
+ mapped_summary,
+ )
return
if isinstance(node, SerializedTaskGroup):
- children_details = []
+ children_summary = GridNodeAgg()
for child in get_task_group_children_getter()(node):
- for child_node in _find_aggregates(node=child, parent_node=node,
ti_details=ti_details):
+ for child_node, child_summary in _find_aggregates(
+ node=child, parent_node=node, ti_details=ti_details
+ ):
if child_node["parent_id"] == node_id:
- # Collect detailed task instance data from all children
- if child_node.get("details"):
- children_details.extend(child_node["details"])
- yield child_node
+ children_summary.merge(child_summary)
+ yield child_node, child_summary
if node_id:
- yield {
- "task_id": node_id,
- "task_display_name": node_id,
- "type": "group",
- "parent_id": parent_id,
- **_get_aggs_for_node(children_details),
- "details": children_details,
- }
+ yield (
+ {
+ "task_id": node_id,
+ "task_display_name": node_id,
+ "type": "group",
+ "parent_id": parent_id,
+ **_get_aggs_for_node(children_summary),
+ },
+ children_summary,
+ )
return
if isinstance(node, SerializedBaseOperator):
- yield {
- "task_id": node_id,
- "task_display_name": node.task_display_name,
- "type": "task",
- "parent_id": parent_id,
- **_get_aggs_for_node(details),
- "details": details,
- }
+ yield (
+ {
+ "task_id": node_id,
+ "task_display_name": node.task_display_name,
+ "type": "task",
+ "parent_id": parent_id,
+ **_get_aggs_for_node(summary),
+ },
+ summary,
+ )
return
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
index dbcf87149e5..b0dab9012a5 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
@@ -641,6 +641,29 @@ class TestGetGridDataEndpoint:
assert response.status_code == 200
assert _strip_dag_version_ids(response.json()) == [GRID_RUN_1,
GRID_RUN_2]
+ def test_get_grid_runs_multiple_dag_versions(self, session, test_client):
+ latest_dag_version =
session.scalar(select(DagModel).where(DagModel.dag_id ==
DAG_ID_5)).dag_versions[
+ -1
+ ]
+ latest_task_instance = session.scalar(
+ select(TaskInstance)
+ .where(TaskInstance.dag_id == DAG_ID_5, TaskInstance.run_id ==
"run_5_2")
+ .limit(1)
+ )
+ latest_task_instance.dag_version = latest_dag_version
+ session.commit()
+
+ response = test_client.get(f"/grid/runs/{DAG_ID_5}?limit=5")
+ assert response.status_code == 200
+ dag_versions_by_run_id = {
+ run["run_id"]: [dag_version["version_number"] for dag_version in
run["dag_versions"]]
+ for run in response.json()
+ }
+ assert dag_versions_by_run_id == {
+ "run_5_1": [1],
+ "run_5_2": [1, 2],
+ }
+
@pytest.mark.parametrize(
("endpoint", "run_type", "expected"),
[