This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit bfb5f585de73ed2c29dabe1431ea4f69bd1bc20b
Author: Hussein Awala <huss...@awala.fr>
AuthorDate: Sun Oct 15 00:01:00 2023 +0200

    Return only the TIs of the readable dags when ~ is provided as a dag_id (#34939)
    
    (cherry picked from commit 33ec72948f74f56f2adb5e2d388e60e88e8a3fa3)
---
 .../endpoints/task_instance_endpoint.py            |  3 ++
 airflow/api_connexion/security.py                  | 10 ++++-
 .../endpoints/test_task_instance_endpoint.py       | 46 ++++++++++++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index b9a5ac9777..b8dc6dc7ee 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -42,6 +42,7 @@ from airflow.api_connexion.schemas.task_instance_schema import (
     task_instance_reference_schema,
     task_instance_schema,
 )
+from airflow.api_connexion.security import get_readable_dags
 from airflow.api_connexion.types import APIResponse
 from airflow.models import SlaMiss
 from airflow.models.dagrun import DagRun as DR
@@ -338,6 +339,8 @@ def get_task_instances(
 
     if dag_id != "~":
         base_query = base_query.where(TI.dag_id == dag_id)
+    else:
+        base_query = base_query.where(TI.dag_id.in_(get_readable_dags()))
     if dag_run_id != "~":
         base_query = base_query.where(TI.run_id == dag_run_id)
     base_query = _apply_range_filter(
diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py
index b108adc2c3..b19f15257c 100644
--- a/airflow/api_connexion/security.py
+++ b/airflow/api_connexion/security.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 from functools import wraps
 from typing import Callable, Sequence, TypeVar, cast
 
-from flask import Response
+from flask import Response, g
 
 from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated
 from airflow.utils.airflow_flask_app import get_airflow_app
@@ -55,3 +55,11 @@ def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Cal
         return cast(T, decorated)
 
     return requires_access_decorator
+
+
+def get_readable_dags() -> list[str]:
+    return get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
+
+
+def can_read_dag(dag_id: str) -> bool:
+    return get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user)
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 5056f7736d..676722a237 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -658,6 +658,52 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
         assert response.json["total_entries"] == expected_ti
         assert len(response.json["task_instances"]) == expected_ti
 
+    @pytest.mark.parametrize(
+        "task_instances, user, expected_ti",
+        [
+            pytest.param(
+                {
+                    "example_python_operator": 2,
+                    "example_skip_dag": 1,
+                },
+                "test_read_only_one_dag",
+                2,
+            ),
+            pytest.param(
+                {
+                    "example_python_operator": 1,
+                    "example_skip_dag": 2,
+                },
+                "test_read_only_one_dag",
+                1,
+            ),
+            pytest.param(
+                {
+                    "example_python_operator": 1,
+                    "example_skip_dag": 2,
+                },
+                "test",
+                3,
+            ),
+        ],
+    )
+    def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session):
+        for dag_id in task_instances:
+            self.create_task_instances(
+                session,
+                task_instances=[
+                    {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)}
+                    for i in range(task_instances[dag_id])
+                ],
+                dag_id=dag_id,
+            )
+        response = self.client.get(
+            "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user}
+        )
+        assert response.status_code == 200
+        assert response.json["total_entries"] == expected_ti
+        assert len(response.json["task_instances"]) == expected_ti
+
     def test_should_respond_200_for_dag_id_filter(self, session):
         self.create_task_instances(session)
         self.create_task_instances(session, dag_id="example_skip_dag")

Reply via email to