This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1c45b950c9 Fix bulk task instance rbac bypass (#64288)
a1c45b950c9 is described below

commit a1c45b950c9a54b7d1691c863def7ee56f07c521
Author: GPK <[email protected]>
AuthorDate: Sat Apr 4 10:05:22 2026 +0100

    Fix bulk task instance rbac bypass (#64288)
    
    * Fix bulk task instance RBAC checks across DAGs
    
    * Update tests
    
    * fix up tests
    
    * Fixup tests
    
    * Resolve comments
---
 .../core_api/services/public/task_instances.py     |  28 +++-
 .../core_api/routes/public/test_task_instances.py  | 181 ++++++++++++++++++++-
 .../services/public/test_task_instances.py         |  32 +++-
 3 files changed, 232 insertions(+), 9 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index 8ab426bbef6..ce7f1a98964 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from typing import Literal
 
 import structlog
 from fastapi import HTTPException, Query, status
@@ -27,6 +28,8 @@ from sqlalchemy import select, tuple_
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm.session import Session
 
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.models.resource_details import 
DagAccessEntity, DagDetails
 from airflow.api_fastapi.common.dagbag import DagBagDep, 
get_latest_version_of_dag
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.core_api.datamodels.common import (
@@ -45,6 +48,7 @@ from airflow.api_fastapi.core_api.datamodels.task_instances import (
 from airflow.api_fastapi.core_api.security import GetUserDep
 from airflow.api_fastapi.core_api.services.public.common import BulkService
 from airflow.listeners.listener import get_listener_manager
+from airflow.models.dag import DagModel
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.utils.state import TaskInstanceState
@@ -201,6 +205,8 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
         self,
         entities: Sequence[str | BulkTaskInstanceBody],
         results: BulkActionResponse,
+        method: Literal["PUT", "DELETE"],
+        action_name: str,
     ) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
         """
         Validate entities and categorize them into specific and all map index 
update sets.
@@ -211,6 +217,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
         """
         specific_map_index_task_keys = set()
         all_map_index_task_keys = set()
+        dag_authorization_cache: dict[str, bool] = {}
 
         for entity in entities:
             dag_id, dag_run_id, task_id, map_index = 
self._extract_task_identifiers(entity)
@@ -229,6 +236,23 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
                 )
                 continue
 
+            if dag_id not in dag_authorization_cache:
+                team_name = DagModel.get_team_name(dag_id, 
session=self.session)
+                dag_authorization_cache[dag_id] = 
get_auth_manager().is_authorized_dag(
+                    method=method,
+                    access_entity=DagAccessEntity.TASK_INSTANCE,
+                    details=DagDetails(id=dag_id, team_name=team_name),
+                    user=self.user,
+                )
+            if not dag_authorization_cache[dag_id]:
+                results.errors.append(
+                    {
+                        "error": f"User is not authorized to {action_name} 
task instances for DAG '{dag_id}'",
+                        "status_code": status.HTTP_403_FORBIDDEN,
+                    }
+                )
+                continue
+
             # Separate logic for "update all" vs "update specific"
             if map_index is not None:
                 specific_map_index_task_keys.add((dag_id, dag_run_id, task_id, 
map_index))
@@ -318,7 +342,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
         """Bulk Update Task Instances."""
         # Validate and categorize entities into specific and all map index 
update sets
         update_specific_map_index_task_keys, update_all_map_index_task_keys = 
self._categorize_entities(
-            action.entities, results
+            action.entities, results, method="PUT", action_name=action.action
         )
 
         try:
@@ -420,7 +444,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
         """Bulk delete task instances."""
         # Validate and categorize entities into specific and all map index 
delete sets
         delete_specific_map_index_task_keys, delete_all_map_index_task_keys = 
self._categorize_entities(
-            action.entities, results
+            action.entities, results, method="DELETE", 
action_name=action.action
         )
 
         try:
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 1a5d689d1ef..192cde004a6 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -26,18 +26,22 @@ from unittest import mock
 
 import pendulum
 import pytest
+from fastapi.testclient import TestClient
 from sqlalchemy import delete, func, select, update
 
 from airflow._shared.timezones.timezone import datetime
+from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
 from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
-from airflow.models import DagRun, Log, TaskInstance
+from airflow.models import DagModel, DagRun, Log, TaskInstance
 from airflow.models.dag_version import DagVersion
+from airflow.models.dagbundle import DagBundleModel
 from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.models.taskmap import TaskMap
+from airflow.models.team import Team
 from airflow.models.trigger import Trigger
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import BaseOperator, TaskGroup
@@ -50,6 +54,7 @@ from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import (
     clear_db_runs,
+    clear_db_teams,
     clear_rendered_ti_fields,
 )
 from tests_common.test_utils.logs import check_last_log
@@ -5502,6 +5507,14 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
     BASH_TASK_ID = "also_run_this"
     WILDCARD_ENDPOINT = "/dags/~/dagRuns/~/taskInstances"
 
+    @pytest.fixture(autouse=True)
+    def clean_db(self, session):
+        clear_db_runs()
+        clear_db_teams()
+        yield
+        clear_db_teams()
+        clear_db_runs()
+
     @pytest.mark.parametrize(
         ("default_ti", "actions", "expected_results", "endpoint_url", 
"setup_dags"),
         [
@@ -6069,10 +6082,24 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
     ):
         # Setup task instances
         if setup_dags:
-            for dag_id in setup_dags:
+            if setup_dags == [self.BASH_DAG_ID, self.DAG_ID]:
+                self.create_task_instances(
+                    session,
+                    task_instances=[{"task_id": self.BASH_TASK_ID, "state": 
default_ti[0]["state"]}],
+                    dag_id=self.BASH_DAG_ID,
+                    update_extras=True,
+                )
                 self.create_task_instances(
-                    session, task_instances=default_ti, dag_id=dag_id, 
update_extras=True
+                    session,
+                    task_instances=[{"task_id": self.TASK_ID, "state": 
default_ti[1]["state"]}],
+                    dag_id=self.DAG_ID,
+                    update_extras=True,
                 )
+            else:
+                for dag_id in setup_dags:
+                    self.create_task_instances(
+                        session, task_instances=default_ti, dag_id=dag_id, 
update_extras=True
+                    )
         else:
             self.create_task_instances(session, task_instances=default_ti)
 
@@ -6141,6 +6168,154 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
                     f"Expected map_index={mi} to remain running, got 
{ti.state!r}"
                 )
 
+    def 
test_bulk_task_instances_rejects_unauthorized_dag_ids_from_request_body(self, 
test_client, session):
+        restricted_bundle_name = "restricted-bundle-update"
+        restricted_team_name = "restricted-team-update"
+        self.create_task_instances(
+            session,
+            task_instances=[{"task_id": self.BASH_TASK_ID, "state": 
State.RUNNING}],
+            dag_id=self.BASH_DAG_ID,
+            update_extras=True,
+        )
+        self.create_task_instances(
+            session,
+            task_instances=[{"task_id": self.TASK_ID, "state": State.RUNNING}],
+            dag_id=self.DAG_ID,
+            update_extras=True,
+        )
+        restricted_bundle = DagBundleModel(name=restricted_bundle_name)
+        restricted_team = Team(name=restricted_team_name)
+        restricted_bundle.teams.append(restricted_team)
+        session.add_all([restricted_bundle, restricted_team])
+        session.flush()
+        session.execute(
+            update(DagModel)
+            .where(DagModel.dag_id == self.BASH_DAG_ID)
+            .values(bundle_name=restricted_bundle_name)
+        )
+        session.commit()
+
+        auth_manager = test_client.app.state.auth_manager
+        token = auth_manager._get_token_signer().generate(
+            auth_manager.serialize_user(
+                SimpleAuthManagerUser(username="limited-user", role="user", 
teams=[]),
+            )
+        )
+        with (
+            mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", 
return_value=False),
+            TestClient(
+                test_client.app,
+                headers={"Authorization": f"Bearer {token}"},
+                base_url=str(test_client.base_url),
+            ) as limited_test_client,
+        ):
+            response = limited_test_client.patch(
+                self.WILDCARD_ENDPOINT,
+                json={
+                    "actions": [
+                        {
+                            "action": "update",
+                            "entities": [
+                                {
+                                    "dag_id": self.BASH_DAG_ID,
+                                    "dag_run_id": self.RUN_ID,
+                                    "task_id": self.BASH_TASK_ID,
+                                    "new_state": "success",
+                                },
+                                {
+                                    "dag_id": self.DAG_ID,
+                                    "dag_run_id": self.RUN_ID,
+                                    "task_id": self.TASK_ID,
+                                    "new_state": "success",
+                                },
+                            ],
+                        }
+                    ]
+                },
+            )
+
+        assert response.status_code == 200
+        assert response.json()["update"]["success"] == 
[f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
+        assert response.json()["update"]["errors"] == [
+            {
+                "error": f"User is not authorized to update task instances for 
DAG '{self.BASH_DAG_ID}'",
+                "status_code": 403,
+            }
+        ]
+
+    def test_bulk_delete_rejects_unauthorized_dag_ids_from_request_body(self, 
test_client, session):
+        restricted_bundle_name = "restricted-bundle-delete"
+        restricted_team_name = "restricted-team-delete"
+        self.create_task_instances(
+            session,
+            task_instances=[{"task_id": self.BASH_TASK_ID, "state": 
State.SUCCESS}],
+            dag_id=self.BASH_DAG_ID,
+            update_extras=True,
+        )
+        self.create_task_instances(
+            session,
+            task_instances=[{"task_id": self.TASK_ID, "state": State.SUCCESS}],
+            dag_id=self.DAG_ID,
+            update_extras=True,
+        )
+        restricted_bundle = DagBundleModel(name=restricted_bundle_name)
+        restricted_team = Team(name=restricted_team_name)
+        restricted_bundle.teams.append(restricted_team)
+        session.add_all([restricted_bundle, restricted_team])
+        session.flush()
+        session.execute(
+            update(DagModel)
+            .where(DagModel.dag_id == self.BASH_DAG_ID)
+            .values(bundle_name=restricted_bundle_name)
+        )
+        session.commit()
+
+        auth_manager = test_client.app.state.auth_manager
+        token = auth_manager._get_token_signer().generate(
+            auth_manager.serialize_user(
+                SimpleAuthManagerUser(username="limited-user", role="user", 
teams=[]),
+            )
+        )
+        with (
+            mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", 
return_value=False),
+            TestClient(
+                test_client.app,
+                headers={"Authorization": f"Bearer {token}"},
+                base_url=str(test_client.base_url),
+            ) as limited_test_client,
+        ):
+            response = limited_test_client.patch(
+                self.WILDCARD_ENDPOINT,
+                json={
+                    "actions": [
+                        {
+                            "action": "delete",
+                            "entities": [
+                                {
+                                    "dag_id": self.BASH_DAG_ID,
+                                    "dag_run_id": self.RUN_ID,
+                                    "task_id": self.BASH_TASK_ID,
+                                },
+                                {
+                                    "dag_id": self.DAG_ID,
+                                    "dag_run_id": self.RUN_ID,
+                                    "task_id": self.TASK_ID,
+                                },
+                            ],
+                        }
+                    ]
+                },
+            )
+
+        assert response.status_code == 200
+        assert response.json()["delete"]["success"] == 
[f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
+        assert response.json()["delete"]["errors"] == [
+            {
+                "error": f"User is not authorized to delete task instances for 
DAG '{self.BASH_DAG_ID}'",
+                "status_code": 403,
+            }
+        ]
+
     def test_should_respond_401(self, unauthenticated_test_client):
         response = unauthenticated_test_client.patch(self.ENDPOINT_URL, 
json={})
         assert response.status_code == 401
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
index f4f051c4e7a..8a29193eae8 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
@@ -17,11 +17,15 @@
 
 from __future__ import annotations
 
+from unittest import mock
+
 import pytest
 
+from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
 from airflow.api_fastapi.core_api.datamodels.common import BulkActionResponse, 
BulkBody
 from airflow.api_fastapi.core_api.datamodels.task_instances import 
BulkTaskInstanceBody
 from airflow.api_fastapi.core_api.services.public.task_instances import 
BulkTaskInstanceService
+from airflow.models import DagModel
 from airflow.providers.standard.operators.bash import BashOperator
 
 from tests_common.test_utils.db import (
@@ -53,6 +57,10 @@ class TestCategorizeTaskInstances(TestTaskInstanceEndpoint):
         self.clear_db()
 
     class MockUser:
+        username = "test_user"
+        role = "admin"
+        teams = ["team1"]
+
         def get_id(self) -> str:
             return "test_user"
 
@@ -184,6 +192,10 @@ class TestExtractTaskIdentifiers(TestTaskInstanceEndpoint):
         self.clear_db()
 
     class MockUser:
+        username = "test_user"
+        role = "admin"
+        teams = ["team1"]
+
         def get_id(self) -> str:
             return "test_user"
 
@@ -260,6 +272,10 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
         self.clear_db()
 
     class MockUser:
+        username = "test_user"
+        role = "admin"
+        teams = ["team1"]
+
         def get_id(self) -> str:
             return "test_user"
 
@@ -380,7 +396,6 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
         expected_error_count,
     ):
         """Test _categorize_entities with different entity configurations and 
wildcard validation."""
-
         user = self.MockUser()
         bulk_request = BulkBody(actions=[])
         service = BulkTaskInstanceService(
@@ -393,9 +408,18 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
         )
 
         results = BulkActionResponse()
-        specific_map_index_task_keys, all_map_index_task_keys = 
service._categorize_entities(
-            entities, results
-        )
+        with (
+            mock.patch.object(DagModel, "get_team_name", return_value="team1"),
+            mock.patch(
+                
"airflow.api_fastapi.core_api.services.public.task_instances.get_auth_manager"
+            ) as mock_get_auth_manager,
+        ):
+            auth_manager = mock.create_autospec(BaseAuthManager, 
instance=True, spec_set=True)
+            auth_manager.is_authorized_dag.return_value = True
+            mock_get_auth_manager.return_value = auth_manager
+            specific_map_index_task_keys, all_map_index_task_keys = 
service._categorize_entities(
+                entities, results, method="PUT", action_name="update"
+            )
 
         assert specific_map_index_task_keys == expected_specific_keys
         assert all_map_index_task_keys == expected_all_keys

Reply via email to