This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-7-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7ebc49456eaac043d5ef7974b9ed9c32ed7b83ad Author: Hussein Awala <huss...@awala.fr> AuthorDate: Sun Oct 15 00:01:00 2023 +0200 Return only the TIs of the readable dags when ~ is provided as a dag_id (#34939) (cherry picked from commit 33ec72948f74f56f2adb5e2d388e60e88e8a3fa3) --- .../endpoints/task_instance_endpoint.py | 3 ++ airflow/api_connexion/security.py | 10 ++++- .../endpoints/test_task_instance_endpoint.py | 46 ++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index b9a5ac9777..b8dc6dc7ee 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -42,6 +42,7 @@ from airflow.api_connexion.schemas.task_instance_schema import ( task_instance_reference_schema, task_instance_schema, ) +from airflow.api_connexion.security import get_readable_dags from airflow.api_connexion.types import APIResponse from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR @@ -338,6 +339,8 @@ def get_task_instances( if dag_id != "~": base_query = base_query.where(TI.dag_id == dag_id) + else: + base_query = base_query.where(TI.dag_id.in_(get_readable_dags())) if dag_run_id != "~": base_query = base_query.where(TI.run_id == dag_run_id) base_query = _apply_range_filter( diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index b108adc2c3..b19f15257c 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -19,7 +19,7 @@ from __future__ import annotations from functools import wraps from typing import Callable, Sequence, TypeVar, cast -from flask import Response +from flask import Response, g from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated from airflow.utils.airflow_flask_app import get_airflow_app @@ -55,3 +55,11 @@ def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Cal return cast(T, decorated) return requires_access_decorator + + +def get_readable_dags() -> list[str]: + return get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + + +def can_read_dag(dag_id: str) -> bool: + return get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user) diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 5056f7736d..676722a237 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -658,6 +658,52 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): assert response.json["total_entries"] == expected_ti assert len(response.json["task_instances"]) == expected_ti + @pytest.mark.parametrize( + "task_instances, user, expected_ti", + [ + pytest.param( + { + "example_python_operator": 2, + "example_skip_dag": 1, + }, + "test_read_only_one_dag", + 2, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test_read_only_one_dag", + 1, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test", + 3, + ), + ], + ) + def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + for dag_id in task_instances: + self.create_task_instances( + session, + task_instances=[ + {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} + for i in range(task_instances[dag_id]) + ], + dag_id=dag_id, + ) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == expected_ti + assert len(response.json["task_instances"]) == expected_ti + def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) self.create_task_instances(session, dag_id="example_skip_dag")