This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new f01c53a7357 AIP-84 Add ordering to task instances and mapped task instances list (#45802) f01c53a7357 is described below commit f01c53a73573a5dacb2107944d32a5fd731d64f6 Author: Sneha Prabhu <sneha.pra...@astronomer.io> AuthorDate: Thu Jan 23 14:38:15 2025 +0530 AIP-84 Add ordering to task instances and mapped task instances list (#45802) * add attributes to task instance ordering * add attributes to task instance ordering * add additional ordering attributes and add test * add test for mapped task instances, parametrized the tests --------- Co-authored-by: Sneha Prabhu <snehaprabhu@Snehas-MacBook-Pro.local> --- .../core_api/routes/public/task_instances.py | 39 +++++- .../core_api/routes/public/test_task_instances.py | 135 +++++++++++---------- 2 files changed, 110 insertions(+), 64 deletions(-) diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 91e4cb0ddcc..1ecc18a6513 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -126,8 +126,25 @@ def get_mapped_task_instances( SortParam, Depends( SortParam( - ["id", "state", "duration", "start_date", "end_date", "map_index", "rendered_map_index"], + [ + "id", + "state", + "duration", + "start_date", + "end_date", + "map_index", + "try_number", + "logical_date", + "data_interval_start", + "data_interval_end", + "rendered_map_index", + ], TI, + to_replace={ + "logical_date": DagRun.logical_date, + "data_interval_start": DagRun.data_interval_start, + "data_interval_end": DagRun.data_interval_end, + }, ).dynamic_depends(default="map_index") ), ], @@ -349,8 +366,25 @@ def get_task_instances( SortParam, Depends( SortParam( - ["id", "state", "duration", "start_date", "end_date", "map_index"], + [ + "id", + "state", + "duration", + "start_date", + "end_date", + "map_index", + 
"try_number", + "logical_date", + "data_interval_start", + "data_interval_end", + "rendered_map_index", + ], TI, + to_replace={ + "logical_date": DagRun.logical_date, + "data_interval_start": DagRun.data_interval_start, + "data_interval_end": DagRun.data_interval_end, + }, ).dynamic_depends(default="map_index") ), ], @@ -399,6 +433,7 @@ def get_task_instances( limit=limit, session=session, ) + task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( task_instances=task_instances, diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index f09f6293e90..2b53bbaeaa2 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -491,7 +491,11 @@ class TestGetMappedTaskInstances: task1 = BaseOperator(task_id="op1") mapped = MockOperator.partial(task_id="task_2", executor="default").expand(arg2=task1.output) - dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}") + dr = dag_maker.create_dagrun( + run_id=f"run_{dag_id}", + logical_date=DEFAULT_DATETIME_1, + data_interval=(DEFAULT_DATETIME_1, DEFAULT_DATETIME_2), + ) session.add( TaskMap( @@ -614,53 +618,49 @@ class TestGetMappedTaskInstances: assert len(body["task_instances"]) == 10 assert list(range(4, 14)) == [ti["map_index"] for ti in body["task_instances"]] - def test_order(self, test_client, one_task_with_many_mapped_tis): - response = test_client.get( - "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - ) - assert response.status_code == 200 - body = response.json() - assert body["total_entries"] == 110 - assert len(body["task_instances"]) == 100 - assert list(range(100)) == [ti["map_index"] for ti in body["task_instances"]] - - def test_mapped_task_instances_reverse_order(self, test_client, one_task_with_many_mapped_tis): + @pytest.mark.parametrize( + "params , 
expected_map_indexes", + [ + ({"order_by": "map_index", "limit": 100}, list(range(100))), + ({"order_by": "-map_index", "limit": 100}, list(range(109, 9, -1))), + ( + {"order_by": "state", "limit": 108}, + list(range(5, 25)) + list(range(25, 110)) + list(range(3)), + ), + ( + {"order_by": "-state", "limit": 100}, + list(range(5)[::-1]) + list(range(25, 110)[::-1]) + list(range(15, 25)[::-1]), + ), + ({"order_by": "logical_date", "limit": 100}, list(range(100))), + ({"order_by": "-logical_date", "limit": 100}, list(range(109, 9, -1))), + ({"order_by": "data_interval_start", "limit": 100}, list(range(100))), + ({"order_by": "-data_interval_start", "limit": 100}, list(range(109, 9, -1))), + ], + ) + def test_mapped_instances_order( + self, test_client, session, params, expected_map_indexes, one_task_with_many_mapped_tis + ): response = test_client.get( "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params={"order_by": "-map_index"}, + params=params, ) - assert response.status_code == 200 - body = response.json() - assert body["total_entries"] == 110 - assert len(body["task_instances"]) == 100 - assert list(range(109, 9, -1)) == [ti["map_index"] for ti in body["task_instances"]] - def test_state_order(self, test_client, one_task_with_many_mapped_tis): - response = test_client.get( - "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params={"order_by": "-state"}, - ) - assert response.status_code == 200 - body = response.json() - assert body["total_entries"] == 110 - assert len(body["task_instances"]) == 100 - assert list(range(5)[::-1]) + list(range(25, 110)[::-1]) + list(range(15, 25)[::-1]) == [ - ti["map_index"] for ti in body["task_instances"] - ] - # State ascending - response = test_client.get( - "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params={"order_by": "state", "limit": 108}, - ) assert response.status_code == 200 body = response.json() assert 
body["total_entries"] == 110 - assert len(body["task_instances"]) == 108 - assert list(range(5, 25)) + list(range(25, 110)) + list(range(3)) == [ - ti["map_index"] for ti in body["task_instances"] - ] + assert len(body["task_instances"]) == params["limit"] + assert expected_map_indexes == [ti["map_index"] for ti in body["task_instances"]] - def test_rendered_map_index_order(self, test_client, session, one_task_with_many_mapped_tis): + @pytest.mark.parametrize( + "params, expected_map_indexes", + [ + ({"order_by": "rendered_map_index", "limit": 108}, [0] + list(range(1, 108))), # Asc + ({"order_by": "-rendered_map_index", "limit": 100}, [0] + list(range(11, 110)[::-1])), # Desc + ], + ) + def test_rendered_map_index_order( + self, test_client, session, params, expected_map_indexes, one_task_with_many_mapped_tis + ): ti = ( session.query(TaskInstance) .where(TaskInstance.task_id == "task_2", TaskInstance.map_index == 0) @@ -673,23 +673,13 @@ class TestGetMappedTaskInstances: response = test_client.get( "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params={"order_by": "-rendered_map_index"}, + params=params, ) assert response.status_code == 200 body = response.json() assert body["total_entries"] == 110 - assert len(body["task_instances"]) == 100 - assert [0] + list(range(11, 110)[::-1]) == [ti["map_index"] for ti in body["task_instances"]] - # State ascending - response = test_client.get( - "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params={"order_by": "rendered_map_index", "limit": 108}, - ) - assert response.status_code == 200 - body = response.json() - assert body["total_entries"] == 110 - assert len(body["task_instances"]) == 108 - assert [0] + list(range(1, 108)) == [ti["map_index"] for ti in body["task_instances"]] + assert len(body["task_instances"]) == params["limit"] + assert expected_map_indexes == [ti["map_index"] for ti in body["task_instances"]] def test_with_date(self, 
test_client, one_task_with_mapped_tis): response = test_client.get( @@ -1032,12 +1022,33 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): assert count == response.json()["total_entries"] assert count == len(response.json()["task_instances"]) - def test_should_respond_200_for_order_by(self, test_client, session): + @pytest.mark.parametrize( + "order_by_field", ["start_date", "logical_date", "data_interval_start", "data_interval_end"] + ) + def test_should_respond_200_for_order_by(self, order_by_field, test_client, session): dag_id = "example_python_operator" + + dag_runs = [ + DagRun( + dag_id=dag_id, + run_id=f"run_{i}", + run_type=DagRunType.MANUAL, + logical_date=DEFAULT_DATETIME_1 + dt.timedelta(days=i), + data_interval=( + DEFAULT_DATETIME_1 + dt.timedelta(days=i), + DEFAULT_DATETIME_1 + dt.timedelta(days=i, hours=1), + ), + ) + for i in range(10) + ] + session.add_all(dag_runs) + session.commit() + self.create_task_instances( session, task_instances=[ - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 1))} for i in range(10) + {"run_id": f"run_{i}", "start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 1))} + for i in range(10) ], dag_id=dag_id, ) @@ -1046,7 +1057,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): # Ascending order response_asc = test_client.get( - "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": "start_date"} + "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": order_by_field} ) assert response_asc.status_code == 200 assert response_asc.json()["total_entries"] == ti_count @@ -1054,18 +1065,18 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): # Descending order response_desc = test_client.get( - "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": "-start_date"} + "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": f"-{order_by_field}"} ) assert response_desc.status_code == 200 assert response_desc.json()["total_entries"] == ti_count assert 
len(response_desc.json()["task_instances"]) == ti_count # Compare - start_dates_asc = [ti["start_date"] for ti in response_asc.json()["task_instances"]] - assert len(start_dates_asc) == ti_count - start_dates_desc = [ti["start_date"] for ti in response_desc.json()["task_instances"]] - assert len(start_dates_desc) == ti_count - assert start_dates_asc == list(reversed(start_dates_desc)) + field_asc = [ti["id"] for ti in response_asc.json()["task_instances"]] + assert len(field_asc) == ti_count + field_desc = [ti["id"] for ti in response_desc.json()["task_instances"]] + assert len(field_desc) == ti_count + assert field_asc == list(reversed(field_desc)) def test_should_respond_200_for_pagination(self, test_client, session): dag_id = "example_python_operator"