This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f01c53a7357 AIP-84 Add ordering to task instances and mapped task instances list (#45802)
f01c53a7357 is described below
commit f01c53a73573a5dacb2107944d32a5fd731d64f6
Author: Sneha Prabhu <[email protected]>
AuthorDate: Thu Jan 23 14:38:15 2025 +0530
AIP-84 Add ordering to task instances and mapped task instances list (#45802)
* add attributes to task instance ordering
* add attributes to task instance ordering
* add additional ordering attributes and add test
* add test for mapped task instances, parametrized the tests
---------
Co-authored-by: Sneha Prabhu <[email protected]>
---
.../core_api/routes/public/task_instances.py | 39 +++++-
.../core_api/routes/public/test_task_instances.py | 135 +++++++++++----------
2 files changed, 110 insertions(+), 64 deletions(-)
diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 91e4cb0ddcc..1ecc18a6513 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -126,8 +126,25 @@ def get_mapped_task_instances(
SortParam,
Depends(
SortParam(
- ["id", "state", "duration", "start_date", "end_date", "map_index", "rendered_map_index"],
+ [
+ "id",
+ "state",
+ "duration",
+ "start_date",
+ "end_date",
+ "map_index",
+ "try_number",
+ "logical_date",
+ "data_interval_start",
+ "data_interval_end",
+ "rendered_map_index",
+ ],
TI,
+ to_replace={
+ "logical_date": DagRun.logical_date,
+ "data_interval_start": DagRun.data_interval_start,
+ "data_interval_end": DagRun.data_interval_end,
+ },
).dynamic_depends(default="map_index")
),
],
@@ -349,8 +366,25 @@ def get_task_instances(
SortParam,
Depends(
SortParam(
- ["id", "state", "duration", "start_date", "end_date", "map_index"],
+ [
+ "id",
+ "state",
+ "duration",
+ "start_date",
+ "end_date",
+ "map_index",
+ "try_number",
+ "logical_date",
+ "data_interval_start",
+ "data_interval_end",
+ "rendered_map_index",
+ ],
TI,
+ to_replace={
+ "logical_date": DagRun.logical_date,
+ "data_interval_start": DagRun.data_interval_start,
+ "data_interval_end": DagRun.data_interval_end,
+ },
).dynamic_depends(default="map_index")
),
],
@@ -399,6 +433,7 @@ def get_task_instances(
limit=limit,
session=session,
)
+
task_instances = session.scalars(task_instance_select)
return TaskInstanceCollectionResponse(
task_instances=task_instances,
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index f09f6293e90..2b53bbaeaa2 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -491,7 +491,11 @@ class TestGetMappedTaskInstances:
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2",
executor="default").expand(arg2=task1.output)
- dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")
+ dr = dag_maker.create_dagrun(
+ run_id=f"run_{dag_id}",
+ logical_date=DEFAULT_DATETIME_1,
+ data_interval=(DEFAULT_DATETIME_1, DEFAULT_DATETIME_2),
+ )
session.add(
TaskMap(
@@ -614,53 +618,49 @@ class TestGetMappedTaskInstances:
assert len(body["task_instances"]) == 10
assert list(range(4, 14)) == [ti["map_index"] for ti in
body["task_instances"]]
- def test_order(self, test_client, one_task_with_many_mapped_tis):
- response = test_client.get(
- "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- )
- assert response.status_code == 200
- body = response.json()
- assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 100
- assert list(range(100)) == [ti["map_index"] for ti in
body["task_instances"]]
-
- def test_mapped_task_instances_reverse_order(self, test_client,
one_task_with_many_mapped_tis):
+ @pytest.mark.parametrize(
+ "params , expected_map_indexes",
+ [
+ ({"order_by": "map_index", "limit": 100}, list(range(100))),
+ ({"order_by": "-map_index", "limit": 100}, list(range(109, 9, -1))),
+ (
+ {"order_by": "state", "limit": 108},
+ list(range(5, 25)) + list(range(25, 110)) + list(range(3)),
+ ),
+ (
+ {"order_by": "-state", "limit": 100},
+ list(range(5)[::-1]) + list(range(25, 110)[::-1]) +
list(range(15, 25)[::-1]),
+ ),
+ ({"order_by": "logical_date", "limit": 100}, list(range(100))),
+ ({"order_by": "-logical_date", "limit": 100}, list(range(109, 9, -1))),
+ ({"order_by": "data_interval_start", "limit": 100},
list(range(100))),
+ ({"order_by": "-data_interval_start", "limit": 100},
list(range(109, 9, -1))),
+ ],
+ )
+ def test_mapped_instances_order(
+ self, test_client, session, params, expected_map_indexes,
one_task_with_many_mapped_tis
+ ):
response = test_client.get(
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params={"order_by": "-map_index"},
+ params=params,
)
- assert response.status_code == 200
- body = response.json()
- assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 100
- assert list(range(109, 9, -1)) == [ti["map_index"] for ti in
body["task_instances"]]
- def test_state_order(self, test_client, one_task_with_many_mapped_tis):
- response = test_client.get(
- "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params={"order_by": "-state"},
- )
- assert response.status_code == 200
- body = response.json()
- assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 100
- assert list(range(5)[::-1]) + list(range(25, 110)[::-1]) +
list(range(15, 25)[::-1]) == [
- ti["map_index"] for ti in body["task_instances"]
- ]
- # State ascending
- response = test_client.get(
- "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params={"order_by": "state", "limit": 108},
- )
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 108
- assert list(range(5, 25)) + list(range(25, 110)) + list(range(3)) == [
- ti["map_index"] for ti in body["task_instances"]
- ]
+ assert len(body["task_instances"]) == params["limit"]
+ assert expected_map_indexes == [ti["map_index"] for ti in
body["task_instances"]]
- def test_rendered_map_index_order(self, test_client, session,
one_task_with_many_mapped_tis):
+ @pytest.mark.parametrize(
+ "params, expected_map_indexes",
+ [
+ ({"order_by": "rendered_map_index", "limit": 108}, [0] +
list(range(1, 108))), # Asc
+ ({"order_by": "-rendered_map_index", "limit": 100}, [0] +
list(range(11, 110)[::-1])), # Desc
+ ],
+ )
+ def test_rendered_map_index_order(
+ self, test_client, session, params, expected_map_indexes,
one_task_with_many_mapped_tis
+ ):
ti = (
session.query(TaskInstance)
.where(TaskInstance.task_id == "task_2", TaskInstance.map_index ==
0)
@@ -673,23 +673,13 @@ class TestGetMappedTaskInstances:
response = test_client.get(
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params={"order_by": "-rendered_map_index"},
+ params=params,
)
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 100
- assert [0] + list(range(11, 110)[::-1]) == [ti["map_index"] for ti in
body["task_instances"]]
- # State ascending
- response = test_client.get(
- "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params={"order_by": "rendered_map_index", "limit": 108},
- )
- assert response.status_code == 200
- body = response.json()
- assert body["total_entries"] == 110
- assert len(body["task_instances"]) == 108
- assert [0] + list(range(1, 108)) == [ti["map_index"] for ti in
body["task_instances"]]
+ assert len(body["task_instances"]) == params["limit"]
+ assert expected_map_indexes == [ti["map_index"] for ti in
body["task_instances"]]
def test_with_date(self, test_client, one_task_with_mapped_tis):
response = test_client.get(
@@ -1032,12 +1022,33 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
assert count == response.json()["total_entries"]
assert count == len(response.json()["task_instances"])
- def test_should_respond_200_for_order_by(self, test_client, session):
+ @pytest.mark.parametrize(
+ "order_by_field", ["start_date", "logical_date",
"data_interval_start", "data_interval_end"]
+ )
+ def test_should_respond_200_for_order_by(self, order_by_field,
test_client, session):
dag_id = "example_python_operator"
+
+ dag_runs = [
+ DagRun(
+ dag_id=dag_id,
+ run_id=f"run_{i}",
+ run_type=DagRunType.MANUAL,
+ logical_date=DEFAULT_DATETIME_1 + dt.timedelta(days=i),
+ data_interval=(
+ DEFAULT_DATETIME_1 + dt.timedelta(days=i),
+ DEFAULT_DATETIME_1 + dt.timedelta(days=i, hours=1),
+ ),
+ )
+ for i in range(10)
+ ]
+ session.add_all(dag_runs)
+ session.commit()
+
self.create_task_instances(
session,
task_instances=[
- {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i +
1))} for i in range(10)
+ {"run_id": f"run_{i}", "start_date": DEFAULT_DATETIME_1 +
dt.timedelta(minutes=(i + 1))}
+ for i in range(10)
],
dag_id=dag_id,
)
@@ -1046,7 +1057,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
# Ascending order
response_asc = test_client.get(
- "/public/dags/~/dagRuns/~/taskInstances", params={"order_by":
"start_date"}
+ "/public/dags/~/dagRuns/~/taskInstances", params={"order_by":
order_by_field}
)
assert response_asc.status_code == 200
assert response_asc.json()["total_entries"] == ti_count
@@ -1054,18 +1065,18 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
# Descending order
response_desc = test_client.get(
- "/public/dags/~/dagRuns/~/taskInstances", params={"order_by":
"-start_date"}
+ "/public/dags/~/dagRuns/~/taskInstances", params={"order_by":
f"-{order_by_field}"}
)
assert response_desc.status_code == 200
assert response_desc.json()["total_entries"] == ti_count
assert len(response_desc.json()["task_instances"]) == ti_count
# Compare
- start_dates_asc = [ti["start_date"] for ti in
response_asc.json()["task_instances"]]
- assert len(start_dates_asc) == ti_count
- start_dates_desc = [ti["start_date"] for ti in
response_desc.json()["task_instances"]]
- assert len(start_dates_desc) == ti_count
- assert start_dates_asc == list(reversed(start_dates_desc))
+ field_asc = [ti["id"] for ti in response_asc.json()["task_instances"]]
+ assert len(field_asc) == ti_count
+ field_desc = [ti["id"] for ti in
response_desc.json()["task_instances"]]
+ assert len(field_desc) == ti_count
+ assert field_asc == list(reversed(field_desc))
def test_should_respond_200_for_pagination(self, test_client, session):
dag_id = "example_python_operator"