This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e31cca1258d Fix N+1 query in bulk task instance delete endpoint (#67304)
e31cca1258d is described below
commit e31cca1258d0f4c9d2067ae9b051c3b12768bbcf
Author: Colten <[email protected]>
AuthorDate: Tue May 26 21:54:37 2026 +0800
Fix N+1 query in bulk task instance delete endpoint (#67304)
* Fix N+1 query pattern in bulk task instance delete endpoint
* Add regression test for bulk task instance delete N+1
* Refactor N+1 regression test to use parametrize pattern
---
.../core_api/services/public/task_instances.py | 23 +++----------
.../core_api/routes/public/test_task_instances.py | 40 ++++++++++++++++++++++
2 files changed, 45 insertions(+), 18 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index f72dfc52fbb..9d6227a695c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -611,7 +611,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
try:
# Handle deletion of specific (dag_id, dag_run_id, task_id,
map_index) tuples
if delete_specific_map_index_task_keys:
- _, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
+ task_instances_map, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
delete_specific_map_index_task_keys
)
not_found_task_ids = [
@@ -625,23 +625,10 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
detail=f"The task instances with these identifiers:
{not_found_task_ids} were not found",
)
- for dag_id, run_id, task_id, map_index in matched_task_keys:
- ti = (
- self.session.execute(
- select(TI).where(
- TI.dag_id == dag_id,
- TI.run_id == run_id,
- TI.task_id == task_id,
- TI.map_index == map_index,
- )
- )
- .scalars()
- .one_or_none()
- )
-
- if ti:
- self.session.delete(ti)
-
results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")
+ for task_key in matched_task_keys:
+ dag_id, run_id, task_id, map_index = task_key
+ self.session.delete(task_instances_map[task_key])
+
results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")
# Handle deletion of all map indexes for certain (dag_id,
dag_run_id, task_id) tuples
if delete_all_map_index_task_keys:
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index d49e36c8c63..1e874f925b3 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -6719,6 +6719,46 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
}
]
+ @pytest.mark.parametrize("task_count", [5, 10, 20])
+ def test_bulk_delete_query_count_scales_linearly_with_task_count(self,
test_client, session, task_count):
+ # Regression guard for the N+1 fix in BulkTaskInstanceService.handle_bulk_delete:
+ # each extra task instance must add exactly QUERIES_PER_TASK_INSTANCE query (its DELETE),
+ # not 2 (DELETE + re-SELECT). A regression that re-queries inside the loop would make
+ # each run strictly exceed BASE_QUERY_COUNT + task_count * QUERIES_PER_TASK_INSTANCE.
+ QUERIES_PER_TASK_INSTANCE = 1
+ BASE_QUERY_COUNT = 5
+
+ self.create_task_instances(
+ session,
+ task_instances=[{"state": State.RUNNING, "map_indexes":
tuple(range(task_count))}],
+ )
+ request_body = {
+ "actions": [
+ {
+ "action": "delete",
+ "entities": [
+ {"task_id": self.TASK_ID, "map_index": map_index} for
map_index in range(task_count)
+ ],
+ "action_on_non_existence": "fail",
+ }
+ ]
+ }
+
+ with count_queries() as result:
+ response = test_client.patch(self.ENDPOINT_URL, json=request_body)
+
+ assert response.status_code == 200
+ assert len(response.json()["delete"]["success"]) == task_count
+
+ query_count = sum(result.values())
+ expected_query_count = BASE_QUERY_COUNT + task_count *
QUERIES_PER_TASK_INSTANCE
+ assert query_count == expected_query_count, (
+ f"Bulk-delete query count {query_count} does not match expected
{expected_query_count} "
+ f"for {task_count} task instances. "
+ f"A regression that re-queries each task instance would give "
+ f"~{BASE_QUERY_COUNT + task_count * 2} queries instead."
+ )
+
def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch(self.ENDPOINT_URL,
json={})
assert response.status_code == 401