This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 90255d9d44a649025f588497f6c82177dad48326
Author: Jed Cunningham <[email protected]>
AuthorDate: Tue Feb 20 08:06:54 2024 -0700

    Check permissions for ImportError (#37468)
    
    (cherry picked from commit d944eb0de216d9e1d125fae5ce4af7440154deb4)
---
 .../endpoints/import_error_endpoint.py             |  61 +++++++-
 airflow/www/views.py                               |  51 +++++--
 .../endpoints/test_import_error_endpoint.py        | 162 ++++++++++++++++++++-
 tests/www/views/test_views_home.py                 |  61 ++++++++
 4 files changed, 314 insertions(+), 21 deletions(-)

diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py
index f555400526..59f63c8ffb 100644
--- a/airflow/api_connexion/endpoints/import_error_endpoint.py
+++ b/airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -16,26 +16,29 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence
 
 from sqlalchemy import func, select
 
 from airflow.api_connexion import security
-from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.exceptions import NotFound, PermissionDenied
 from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
 from airflow.api_connexion.schemas.error_schema import (
     ImportErrorCollection,
     import_error_collection_schema,
     import_error_schema,
 )
-from airflow.auth.managers.models.resource_details import AccessView
+from airflow.auth.managers.models.resource_details import AccessView, DagDetails
+from airflow.models.dag import DagModel
 from airflow.models.errors import ImportError as ImportErrorModel
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.www.extensions.init_auth_manager import get_auth_manager
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.api_connexion.types import APIResponse
+    from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
 
 
 @security.requires_access_view(AccessView.IMPORT_ERRORS)
@@ -43,12 +46,29 @@ if TYPE_CHECKING:
 def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse:
     """Get an import error."""
     error = session.get(ImportErrorModel, import_error_id)
-
     if error is None:
         raise NotFound(
             "Import error not found",
             detail=f"The ImportError with import_error_id: `{import_error_id}` was not found",
         )
+    session.expunge(error)
+
+    can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
+    if not can_read_all_dags:
+        readable_dag_ids = security.get_readable_dags()
+        file_dag_ids = {
+            dag_id[0]
+            for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
+        }
+
+        # Can the user read any DAGs in the file?
+        if not readable_dag_ids.intersection(file_dag_ids):
+            raise PermissionDenied(detail="You do not have read permission on any of the DAGs in the file")
+
+        # Check if user has read access to all the DAGs defined in the file
+        if not file_dag_ids.issubset(readable_dag_ids):
+            error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
+
     return import_error_schema.dump(error)
 
 
@@ -65,10 +85,41 @@ def get_import_errors(
     """Get all import errors."""
     to_replace = {"import_error_id": "id"}
     allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
-    total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
+    count_query = select(func.count(ImportErrorModel.id))
     query = select(ImportErrorModel)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
+
+    can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
+
+    if not can_read_all_dags:
+        # if the user doesn't have access to all DAGs, only display errors from visible DAGs
+        readable_dag_ids = security.get_readable_dags()
+        dagfiles_subq = (
+            select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids)).subquery()
+        )
+        query = query.where(ImportErrorModel.filename.in_(dagfiles_subq))
+        count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_subq))
+
+    total_entries = session.scalars(count_query).one()
     import_errors = session.scalars(query.offset(offset).limit(limit)).all()
+
+    if not can_read_all_dags:
+        for import_error in import_errors:
+            # Check if user has read access to all the DAGs defined in the file
+            file_dag_ids = (
+                session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all()
+            )
+            requests: Sequence[IsAuthorizedDagRequest] = [
+                {
+                    "method": "GET",
+                    "details": DagDetails(id=dag_id[0]),
+                }
+                for dag_id in file_dag_ids
+            ]
+            if not get_auth_manager().batch_is_authorized_dag(requests):
+                session.expunge(import_error)
+                import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
+
     return import_error_collection_schema.dump(
         ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
     )
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5252d20ce3..093ff436eb 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -147,6 +147,7 @@ from airflow.www.widgets import AirflowModelListWidget, AirflowVariableShowWidge
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
+    from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
 
@@ -935,20 +936,44 @@ class Airflow(AirflowBaseView):
 
             owner_links_dict = DagOwnerAttributes.get_all(session)
 
-            import_errors = select(errors.ImportError).order_by(errors.ImportError.id)
-
-            if not get_auth_manager().is_authorized_dag(method="GET"):
-                # if the user doesn't have access to all DAGs, only display errors from visible DAGs
-                import_errors = import_errors.join(
-                    DagModel, DagModel.fileloc == errors.ImportError.filename
-                ).where(DagModel.dag_id.in_(filter_dag_ids))
+            if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS):
+                import_errors = select(errors.ImportError).order_by(errors.ImportError.id)
+
+                can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
+                if not can_read_all_dags:
+                    # if the user doesn't have access to all DAGs, only display errors from visible DAGs
+                    import_errors = import_errors.where(
+                        errors.ImportError.filename.in_(
+                            select(DagModel.fileloc)
+                            .distinct()
+                            .where(DagModel.dag_id.in_(filter_dag_ids))
+                            .subquery()
+                        )
+                    )
 
-            import_errors = session.scalars(import_errors)
-            for import_error in import_errors:
-                flash(
-                    f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}",
-                    "dag_import_error",
-                )
+                import_errors = session.scalars(import_errors)
+                for import_error in import_errors:
+                    stacktrace = import_error.stacktrace
+                    if not can_read_all_dags:
+                        # Check if user has read access to all the DAGs defined in the file
+                        file_dag_ids = (
+                            session.query(DagModel.dag_id)
+                            .filter(DagModel.fileloc == import_error.filename)
+                            .all()
+                        )
+                        requests: Sequence[IsAuthorizedDagRequest] = [
+                            {
+                                "method": "GET",
+                                "details": DagDetails(id=dag_id[0]),
+                            }
+                            for dag_id in file_dag_ids
+                        ]
+                        if not get_auth_manager().batch_is_authorized_dag(requests):
+                            stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
+                    flash(
+                        f"Broken DAG: [{import_error.filename}]\r{stacktrace}",
+                        "dag_import_error",
+                    )
 
         from airflow.plugins_manager import import_errors as plugin_import_errors
 
diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py
index 33550862ab..fae1312a32 100644
--- a/tests/api_connexion/endpoints/test_import_error_endpoint.py
+++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py
@@ -21,16 +21,19 @@ from datetime import timedelta
 import pytest
 
 from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from airflow.models.dag import DagModel
 from airflow.models.errors import ImportError
 from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
 from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_import_errors
+from tests.test_utils.db import clear_db_dags, clear_db_import_errors
 
 pytestmark = pytest.mark.db_test
 
+TEST_DAG_IDS = ["test_dag", "test_dag2"]
+
 
 @pytest.fixture(scope="module")
 def configured_app(minimal_app_for_api):
@@ -39,14 +42,34 @@ def configured_app(minimal_app_for_api):
         app,  # type:ignore
         username="test",
         role_name="Test",
-        permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)],  # type: ignore
+        permissions=[
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR),
+        ],  # type: ignore
     )
     create_user(app, username="test_no_permissions", role_name="TestNoPermissions")  # type: ignore
+    create_user(
+        app,  # type:ignore
+        username="test_single_dag",
+        role_name="TestSingleDAG",
+        permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)],  # type: ignore
+    )
+    # For some reason, DAG level permissions are not synced when in the above list of perms,
+    # so do it manually here:
+    app.appbuilder.sm.bulk_sync_roles(
+        [
+            {
+                "role": "TestSingleDAG",
+                "perms": [(permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_DAG_IDS[0]))],
+            }
+        ]
+    )
 
-    yield minimal_app_for_api
+    yield app
 
     delete_user(app, username="test")  # type: ignore
     delete_user(app, username="test_no_permissions")  # type: ignore
+    delete_user(app, username="test_single_dag")  # type: ignore
 
 
 class TestBaseImportError:
@@ -58,9 +81,11 @@ class TestBaseImportError:
         self.client = self.app.test_client()  # type:ignore
 
         clear_db_import_errors()
+        clear_db_dags()
 
     def teardown_method(self) -> None:
         clear_db_import_errors()
+        clear_db_dags()
 
     @staticmethod
     def _normalize_import_errors(import_errors):
@@ -121,6 +146,72 @@ class TestGetImportErrorEndpoint(TestBaseImportError):
         )
         assert response.status_code == 403
 
+    def test_should_raise_403_forbidden_without_dag_read(self, session):
+        import_error = ImportError(
+            filename="Lorem_ipsum.py",
+            stacktrace="Lorem ipsum",
+            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+        )
+        session.add(import_error)
+        session.commit()
+
+        response = self.client.get(
+            f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+        )
+
+        assert response.status_code == 403
+
+    def test_should_return_200_with_single_dag_read(self, session):
+        dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
+        session.add(dag_model)
+        import_error = ImportError(
+            filename="Lorem_ipsum.py",
+            stacktrace="Lorem ipsum",
+            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+        )
+        session.add(import_error)
+        session.commit()
+
+        response = self.client.get(
+            f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+        )
+
+        assert response.status_code == 200
+        response_data = response.json
+        response_data["import_error_id"] = 1
+        assert {
+            "filename": "Lorem_ipsum.py",
+            "import_error_id": 1,
+            "stack_trace": "Lorem ipsum",
+            "timestamp": "2020-06-10T12:00:00+00:00",
+        } == response_data
+
+    def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
+        for dag_id in TEST_DAG_IDS:
+            dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
+            session.add(dag_model)
+        import_error = ImportError(
+            filename="Lorem_ipsum.py",
+            stacktrace="Lorem ipsum",
+            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+        )
+        session.add(import_error)
+        session.commit()
+
+        response = self.client.get(
+            f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+        )
+
+        assert response.status_code == 200
+        response_data = response.json
+        response_data["import_error_id"] = 1
+        assert {
+            "filename": "Lorem_ipsum.py",
+            "import_error_id": 1,
+            "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
+            "timestamp": "2020-06-10T12:00:00+00:00",
+        } == response_data
+
 
 class TestGetImportErrorsEndpoint(TestBaseImportError):
     def test_get_import_errors(self, session):
@@ -231,6 +322,71 @@ class TestGetImportErrorsEndpoint(TestBaseImportError):
 
         assert_401(response)
 
+    def test_get_import_errors_single_dag(self, session):
+        for dag_id in TEST_DAG_IDS:
+            fake_filename = f"/tmp/{dag_id}.py"
+            dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
+            session.add(dag_model)
+            importerror = ImportError(
+                filename=fake_filename,
+                stacktrace="Lorem ipsum",
+                timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+            )
+            session.add(importerror)
+        session.commit()
+
+        response = self.client.get(
+            "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
+        )
+
+        assert response.status_code == 200
+        response_data = response.json
+        self._normalize_import_errors(response_data["import_errors"])
+        assert {
+            "import_errors": [
+                {
+                    "filename": "/tmp/test_dag.py",
+                    "import_error_id": 1,
+                    "stack_trace": "Lorem ipsum",
+                    "timestamp": "2020-06-10T12:00:00+00:00",
+                },
+            ],
+            "total_entries": 1,
+        } == response_data
+
+    def test_get_import_errors_single_dag_in_dagfile(self, session):
+        for dag_id in TEST_DAG_IDS:
+            fake_filename = "/tmp/all_in_one.py"
+            dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
+            session.add(dag_model)
+
+        importerror = ImportError(
+            filename="/tmp/all_in_one.py",
+            stacktrace="Lorem ipsum",
+            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+        )
+        session.add(importerror)
+        session.commit()
+
+        response = self.client.get(
+            "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
+        )
+
+        assert response.status_code == 200
+        response_data = response.json
+        self._normalize_import_errors(response_data["import_errors"])
+        assert {
+            "import_errors": [
+                {
+                    "filename": "/tmp/all_in_one.py",
+                    "import_error_id": 1,
+                    "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
+                    "timestamp": "2020-06-10T12:00:00+00:00",
+                },
+            ],
+            "total_entries": 1,
+        } == response_data
+
 
 class TestGetImportErrorsEndpointPagination(TestBaseImportError):
     @pytest.mark.parametrize(
diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py
index c05eb45101..4787eecccc 100644
--- a/tests/www/views/test_views_home.py
+++ b/tests/www/views/test_views_home.py
@@ -111,6 +111,30 @@ def test_home_status_filter_cookie(admin_client):
         assert "all" == flask.session[FILTER_STATUS_COOKIE]
 
 
+@pytest.fixture(scope="module")
+def user_no_importerror(app):
+    """Create User that cannot access Import Errors"""
+    return create_user(
+        app,
+        username="user_no_importerrors",
+        role_name="role_no_importerrors",
+        permissions=[
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+        ],
+    )
+
+
+@pytest.fixture()
+def client_no_importerror(app, user_no_importerror):
+    """Client for User that cannot access Import Errors"""
+    return client_with_login(
+        app,
+        username="user_no_importerrors",
+        password="user_no_importerrors",
+    )
+
+
 @pytest.fixture(scope="module")
 def user_single_dag(app):
     """Create User that can only access the first DAG from TEST_FILTER_DAG_IDS"""
@@ -120,6 +144,7 @@ def user_single_dag(app):
         role_name="role_single_dag",
         permissions=[
             (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR),
             (permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_FILTER_DAG_IDS[0])),
         ],
     )
@@ -232,6 +257,24 @@ def broken_dags_with_read_perm(tmp_path, working_dags_with_read_perm):
             _process_file(path, session)
 
 
+@pytest.fixture()
+def broken_dags_after_working(tmp_path):
+    # First create and process a DAG file that works
+    path = tmp_path / "all_in_one.py"
+    with create_session() as session:
+        contents = "from airflow import DAG\n"
+        for i, dag_id in enumerate(TEST_FILTER_DAG_IDS):
+            contents += f"dag{i} = DAG('{dag_id}')\n"
+        path.write_text(contents)
+        _process_file(path, session)
+
+    # Then break it!
+    with create_session() as session:
+        contents += "foobar()"
+        path.write_text(contents)
+        _process_file(path, session)
+
+
 def test_home_filter_tags(working_dags, admin_client):
     with admin_client:
         admin_client.get("home?tags=example&tags=data", follow_redirects=True)
@@ -249,6 +292,12 @@ def test_home_importerrors(broken_dags, user_client):
         check_content_in_response(f"/{dag_id}.py", resp)
 
 
+def test_home_no_importerrors_perm(broken_dags, client_no_importerror):
+    # Users without "can read on import errors" don't see any import errors
+    resp = client_no_importerror.get("home", follow_redirects=True)
+    check_content_not_in_response("Import Errors", resp)
+
+
 @pytest.mark.parametrize(
     "page",
     [
@@ -266,11 +315,23 @@ def test_home_importerrors_filtered_singledag_user(broken_dags_with_read_perm, c
     check_content_in_response("Import Errors", resp)
     # They can see the first DAGs import error
     check_content_in_response(f"/{TEST_FILTER_DAG_IDS[0]}.py", resp)
+    check_content_in_response("Traceback", resp)
     # But not the rest
     for dag_id in TEST_FILTER_DAG_IDS[1:]:
         check_content_not_in_response(f"/{dag_id}.py", resp)
 
 
+def test_home_importerrors_missing_read_on_all_dags_in_file(broken_dags_after_working, client_single_dag):
+    # If a user doesn't have READ on all DAGs in a file, that files traceback is redacted
+    resp = client_single_dag.get("home", follow_redirects=True)
+    check_content_in_response("Import Errors", resp)
+    # They can see the DAG file has an import error
+    check_content_in_response("all_in_one.py", resp)
+    # And the traceback is redacted
+    check_content_not_in_response("Traceback", resp)
+    check_content_in_response("REDACTED", resp)
+
+
 def test_home_dag_list(working_dags, user_client):
     # Users with "can read on DAGs" gets all DAGs
     resp = user_client.get("home", follow_redirects=True)

Reply via email to