This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f01c53a7357 AIP-84 Add ordering to task instances and mapped task instances list (#45802)
f01c53a7357 is described below

commit f01c53a73573a5dacb2107944d32a5fd731d64f6
Author: Sneha Prabhu <sneha.pra...@astronomer.io>
AuthorDate: Thu Jan 23 14:38:15 2025 +0530

    AIP-84 Add ordering to task instances and mapped task instances list (#45802)
    
    * add attributes to task instance ordering
    
    * add attributes to task instance ordering
    
    * add additional ordering attributes and add test
    
    * add test for mapped task instances, parametrized the tests
    
    ---------
    
    Co-authored-by: Sneha Prabhu <snehaprabhu@Snehas-MacBook-Pro.local>
---
 .../core_api/routes/public/task_instances.py       |  39 +++++-
 .../core_api/routes/public/test_task_instances.py  | 135 +++++++++++----------
 2 files changed, 110 insertions(+), 64 deletions(-)

diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 91e4cb0ddcc..1ecc18a6513 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -126,8 +126,25 @@ def get_mapped_task_instances(
         SortParam,
         Depends(
             SortParam(
-                ["id", "state", "duration", "start_date", "end_date", 
"map_index", "rendered_map_index"],
+                [
+                    "id",
+                    "state",
+                    "duration",
+                    "start_date",
+                    "end_date",
+                    "map_index",
+                    "try_number",
+                    "logical_date",
+                    "data_interval_start",
+                    "data_interval_end",
+                    "rendered_map_index",
+                ],
                 TI,
+                to_replace={
+                    "logical_date": DagRun.logical_date,
+                    "data_interval_start": DagRun.data_interval_start,
+                    "data_interval_end": DagRun.data_interval_end,
+                },
             ).dynamic_depends(default="map_index")
         ),
     ],
@@ -349,8 +366,25 @@ def get_task_instances(
         SortParam,
         Depends(
             SortParam(
-                ["id", "state", "duration", "start_date", "end_date", 
"map_index"],
+                [
+                    "id",
+                    "state",
+                    "duration",
+                    "start_date",
+                    "end_date",
+                    "map_index",
+                    "try_number",
+                    "logical_date",
+                    "data_interval_start",
+                    "data_interval_end",
+                    "rendered_map_index",
+                ],
                 TI,
+                to_replace={
+                    "logical_date": DagRun.logical_date,
+                    "data_interval_start": DagRun.data_interval_start,
+                    "data_interval_end": DagRun.data_interval_end,
+                },
             ).dynamic_depends(default="map_index")
         ),
     ],
@@ -399,6 +433,7 @@ def get_task_instances(
         limit=limit,
         session=session,
     )
+
     task_instances = session.scalars(task_instance_select)
     return TaskInstanceCollectionResponse(
         task_instances=task_instances,
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index f09f6293e90..2b53bbaeaa2 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -491,7 +491,11 @@ class TestGetMappedTaskInstances:
                 task1 = BaseOperator(task_id="op1")
                 mapped = MockOperator.partial(task_id="task_2", 
executor="default").expand(arg2=task1.output)
 
-            dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")
+            dr = dag_maker.create_dagrun(
+                run_id=f"run_{dag_id}",
+                logical_date=DEFAULT_DATETIME_1,
+                data_interval=(DEFAULT_DATETIME_1, DEFAULT_DATETIME_2),
+            )
 
             session.add(
                 TaskMap(
@@ -614,53 +618,49 @@ class TestGetMappedTaskInstances:
         assert len(body["task_instances"]) == 10
         assert list(range(4, 14)) == [ti["map_index"] for ti in 
body["task_instances"]]
 
-    def test_order(self, test_client, one_task_with_many_mapped_tis):
-        response = test_client.get(
-            
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-        )
-        assert response.status_code == 200
-        body = response.json()
-        assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 100
-        assert list(range(100)) == [ti["map_index"] for ti in 
body["task_instances"]]
-
-    def test_mapped_task_instances_reverse_order(self, test_client, 
one_task_with_many_mapped_tis):
+    @pytest.mark.parametrize(
+        "params , expected_map_indexes",
+        [
+            ({"order_by": "map_index", "limit": 100}, list(range(100))),
+            ({"order_by": "-map_index", "limit": 100}, list(range(109, 9, 
-1))),
+            (
+                {"order_by": "state", "limit": 108},
+                list(range(5, 25)) + list(range(25, 110)) + list(range(3)),
+            ),
+            (
+                {"order_by": "-state", "limit": 100},
+                list(range(5)[::-1]) + list(range(25, 110)[::-1]) + 
list(range(15, 25)[::-1]),
+            ),
+            ({"order_by": "logical_date", "limit": 100}, list(range(100))),
+            ({"order_by": "-logical_date", "limit": 100}, list(range(109, 9, 
-1))),
+            ({"order_by": "data_interval_start", "limit": 100}, 
list(range(100))),
+            ({"order_by": "-data_interval_start", "limit": 100}, 
list(range(109, 9, -1))),
+        ],
+    )
+    def test_mapped_instances_order(
+        self, test_client, session, params, expected_map_indexes, 
one_task_with_many_mapped_tis
+    ):
         response = test_client.get(
             
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params={"order_by": "-map_index"},
+            params=params,
         )
-        assert response.status_code == 200
-        body = response.json()
-        assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 100
-        assert list(range(109, 9, -1)) == [ti["map_index"] for ti in 
body["task_instances"]]
 
-    def test_state_order(self, test_client, one_task_with_many_mapped_tis):
-        response = test_client.get(
-            
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params={"order_by": "-state"},
-        )
-        assert response.status_code == 200
-        body = response.json()
-        assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 100
-        assert list(range(5)[::-1]) + list(range(25, 110)[::-1]) + 
list(range(15, 25)[::-1]) == [
-            ti["map_index"] for ti in body["task_instances"]
-        ]
-        # State ascending
-        response = test_client.get(
-            
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params={"order_by": "state", "limit": 108},
-        )
         assert response.status_code == 200
         body = response.json()
         assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 108
-        assert list(range(5, 25)) + list(range(25, 110)) + list(range(3)) == [
-            ti["map_index"] for ti in body["task_instances"]
-        ]
+        assert len(body["task_instances"]) == params["limit"]
+        assert expected_map_indexes == [ti["map_index"] for ti in 
body["task_instances"]]
 
-    def test_rendered_map_index_order(self, test_client, session, 
one_task_with_many_mapped_tis):
+    @pytest.mark.parametrize(
+        "params, expected_map_indexes",
+        [
+            ({"order_by": "rendered_map_index", "limit": 108}, [0] + 
list(range(1, 108))),  # Asc
+            ({"order_by": "-rendered_map_index", "limit": 100}, [0] + 
list(range(11, 110)[::-1])),  # Desc
+        ],
+    )
+    def test_rendered_map_index_order(
+        self, test_client, session, params, expected_map_indexes, 
one_task_with_many_mapped_tis
+    ):
         ti = (
             session.query(TaskInstance)
             .where(TaskInstance.task_id == "task_2", TaskInstance.map_index == 
0)
@@ -673,23 +673,13 @@ class TestGetMappedTaskInstances:
 
         response = test_client.get(
             
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params={"order_by": "-rendered_map_index"},
+            params=params,
         )
         assert response.status_code == 200
         body = response.json()
         assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 100
-        assert [0] + list(range(11, 110)[::-1]) == [ti["map_index"] for ti in 
body["task_instances"]]
-        # State ascending
-        response = test_client.get(
-            
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params={"order_by": "rendered_map_index", "limit": 108},
-        )
-        assert response.status_code == 200
-        body = response.json()
-        assert body["total_entries"] == 110
-        assert len(body["task_instances"]) == 108
-        assert [0] + list(range(1, 108)) == [ti["map_index"] for ti in 
body["task_instances"]]
+        assert len(body["task_instances"]) == params["limit"]
+        assert expected_map_indexes == [ti["map_index"] for ti in 
body["task_instances"]]
 
     def test_with_date(self, test_client, one_task_with_mapped_tis):
         response = test_client.get(
@@ -1032,12 +1022,33 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
         assert count == response.json()["total_entries"]
         assert count == len(response.json()["task_instances"])
 
-    def test_should_respond_200_for_order_by(self, test_client, session):
+    @pytest.mark.parametrize(
+        "order_by_field", ["start_date", "logical_date", 
"data_interval_start", "data_interval_end"]
+    )
+    def test_should_respond_200_for_order_by(self, order_by_field, 
test_client, session):
         dag_id = "example_python_operator"
+
+        dag_runs = [
+            DagRun(
+                dag_id=dag_id,
+                run_id=f"run_{i}",
+                run_type=DagRunType.MANUAL,
+                logical_date=DEFAULT_DATETIME_1 + dt.timedelta(days=i),
+                data_interval=(
+                    DEFAULT_DATETIME_1 + dt.timedelta(days=i),
+                    DEFAULT_DATETIME_1 + dt.timedelta(days=i, hours=1),
+                ),
+            )
+            for i in range(10)
+        ]
+        session.add_all(dag_runs)
+        session.commit()
+
         self.create_task_instances(
             session,
             task_instances=[
-                {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 
1))} for i in range(10)
+                {"run_id": f"run_{i}", "start_date": DEFAULT_DATETIME_1 + 
dt.timedelta(minutes=(i + 1))}
+                for i in range(10)
             ],
             dag_id=dag_id,
         )
@@ -1046,7 +1057,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
 
         # Ascending order
         response_asc = test_client.get(
-            "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": 
"start_date"}
+            "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": 
order_by_field}
         )
         assert response_asc.status_code == 200
         assert response_asc.json()["total_entries"] == ti_count
@@ -1054,18 +1065,18 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
 
         # Descending order
         response_desc = test_client.get(
-            "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": 
"-start_date"}
+            "/public/dags/~/dagRuns/~/taskInstances", params={"order_by": 
f"-{order_by_field}"}
         )
         assert response_desc.status_code == 200
         assert response_desc.json()["total_entries"] == ti_count
         assert len(response_desc.json()["task_instances"]) == ti_count
 
         # Compare
-        start_dates_asc = [ti["start_date"] for ti in 
response_asc.json()["task_instances"]]
-        assert len(start_dates_asc) == ti_count
-        start_dates_desc = [ti["start_date"] for ti in 
response_desc.json()["task_instances"]]
-        assert len(start_dates_desc) == ti_count
-        assert start_dates_asc == list(reversed(start_dates_desc))
+        field_asc = [ti["id"] for ti in response_asc.json()["task_instances"]]
+        assert len(field_asc) == ti_count
+        field_desc = [ti["id"] for ti in 
response_desc.json()["task_instances"]]
+        assert len(field_desc) == ti_count
+        assert field_asc == list(reversed(field_desc))
 
     def test_should_respond_200_for_pagination(self, test_client, session):
         dag_id = "example_python_operator"

Reply via email to