This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a1c45b950c9 Fix bulk task instance rbac bypass (#64288)
a1c45b950c9 is described below
commit a1c45b950c9a54b7d1691c863def7ee56f07c521
Author: GPK <[email protected]>
AuthorDate: Sat Apr 4 10:05:22 2026 +0100
Fix bulk task instance rbac bypass (#64288)
* Fix bulk task instance RBAC checks across DAGs
* Update tests
* fix up tests
* Fixup tests
* Resolve comments
---
.../core_api/services/public/task_instances.py | 28 +++-
.../core_api/routes/public/test_task_instances.py | 181 ++++++++++++++++++++-
.../services/public/test_task_instances.py | 32 +++-
3 files changed, 232 insertions(+), 9 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index 8ab426bbef6..ce7f1a98964 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from collections.abc import Sequence
+from typing import Literal
import structlog
from fastapi import HTTPException, Query, status
@@ -27,6 +28,8 @@ from sqlalchemy import select, tuple_
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.session import Session
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.models.resource_details import
DagAccessEntity, DagDetails
from airflow.api_fastapi.common.dagbag import DagBagDep,
get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.core_api.datamodels.common import (
@@ -45,6 +48,7 @@ from airflow.api_fastapi.core_api.datamodels.task_instances import (
from airflow.api_fastapi.core_api.security import GetUserDep
from airflow.api_fastapi.core_api.services.public.common import BulkService
from airflow.listeners.listener import get_listener_manager
+from airflow.models.dag import DagModel
from airflow.models.taskinstance import TaskInstance as TI
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.state import TaskInstanceState
@@ -201,6 +205,8 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
self,
entities: Sequence[str | BulkTaskInstanceBody],
results: BulkActionResponse,
+ method: Literal["PUT", "DELETE"],
+ action_name: str,
) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
"""
Validate entities and categorize them into specific and all map index
update sets.
@@ -211,6 +217,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
"""
specific_map_index_task_keys = set()
all_map_index_task_keys = set()
+ dag_authorization_cache: dict[str, bool] = {}
for entity in entities:
dag_id, dag_run_id, task_id, map_index =
self._extract_task_identifiers(entity)
@@ -229,6 +236,23 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
)
continue
+ if dag_id not in dag_authorization_cache:
+ team_name = DagModel.get_team_name(dag_id,
session=self.session)
+ dag_authorization_cache[dag_id] =
get_auth_manager().is_authorized_dag(
+ method=method,
+ access_entity=DagAccessEntity.TASK_INSTANCE,
+ details=DagDetails(id=dag_id, team_name=team_name),
+ user=self.user,
+ )
+ if not dag_authorization_cache[dag_id]:
+ results.errors.append(
+ {
+ "error": f"User is not authorized to {action_name}
task instances for DAG '{dag_id}'",
+ "status_code": status.HTTP_403_FORBIDDEN,
+ }
+ )
+ continue
+
# Separate logic for "update all" vs "update specific"
if map_index is not None:
specific_map_index_task_keys.add((dag_id, dag_run_id, task_id,
map_index))
@@ -318,7 +342,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
"""Bulk Update Task Instances."""
# Validate and categorize entities into specific and all map index
update sets
update_specific_map_index_task_keys, update_all_map_index_task_keys =
self._categorize_entities(
- action.entities, results
+ action.entities, results, method="PUT", action_name=action.action
)
try:
@@ -420,7 +444,7 @@ class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
"""Bulk delete task instances."""
# Validate and categorize entities into specific and all map index
delete sets
delete_specific_map_index_task_keys, delete_all_map_index_task_keys =
self._categorize_entities(
- action.entities, results
+ action.entities, results, method="DELETE",
action_name=action.action
)
try:
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 1a5d689d1ef..192cde004a6 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -26,18 +26,22 @@ from unittest import mock
import pendulum
import pytest
+from fastapi.testclient import TestClient
from sqlalchemy import delete, func, select, update
from airflow._shared.timezones.timezone import datetime
+from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
-from airflow.models import DagRun, Log, TaskInstance
+from airflow.models import DagModel, DagRun, Log, TaskInstance
from airflow.models.dag_version import DagVersion
+from airflow.models.dagbundle import DagBundleModel
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.models.taskmap import TaskMap
+from airflow.models.team import Team
from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import BaseOperator, TaskGroup
@@ -50,6 +54,7 @@ from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import (
clear_db_runs,
+ clear_db_teams,
clear_rendered_ti_fields,
)
from tests_common.test_utils.logs import check_last_log
@@ -5502,6 +5507,14 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
BASH_TASK_ID = "also_run_this"
WILDCARD_ENDPOINT = "/dags/~/dagRuns/~/taskInstances"
+ @pytest.fixture(autouse=True)
+ def clean_db(self, session):
+ clear_db_runs()
+ clear_db_teams()
+ yield
+ clear_db_teams()
+ clear_db_runs()
+
@pytest.mark.parametrize(
("default_ti", "actions", "expected_results", "endpoint_url",
"setup_dags"),
[
@@ -6069,10 +6082,24 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
):
# Setup task instances
if setup_dags:
- for dag_id in setup_dags:
+ if setup_dags == [self.BASH_DAG_ID, self.DAG_ID]:
+ self.create_task_instances(
+ session,
+ task_instances=[{"task_id": self.BASH_TASK_ID, "state":
default_ti[0]["state"]}],
+ dag_id=self.BASH_DAG_ID,
+ update_extras=True,
+ )
self.create_task_instances(
- session, task_instances=default_ti, dag_id=dag_id,
update_extras=True
+ session,
+ task_instances=[{"task_id": self.TASK_ID, "state":
default_ti[1]["state"]}],
+ dag_id=self.DAG_ID,
+ update_extras=True,
)
+ else:
+ for dag_id in setup_dags:
+ self.create_task_instances(
+ session, task_instances=default_ti, dag_id=dag_id,
update_extras=True
+ )
else:
self.create_task_instances(session, task_instances=default_ti)
@@ -6141,6 +6168,154 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
f"Expected map_index={mi} to remain running, got
{ti.state!r}"
)
+ def
test_bulk_task_instances_rejects_unauthorized_dag_ids_from_request_body(self,
test_client, session):
+ restricted_bundle_name = "restricted-bundle-update"
+ restricted_team_name = "restricted-team-update"
+ self.create_task_instances(
+ session,
+ task_instances=[{"task_id": self.BASH_TASK_ID, "state":
State.RUNNING}],
+ dag_id=self.BASH_DAG_ID,
+ update_extras=True,
+ )
+ self.create_task_instances(
+ session,
+ task_instances=[{"task_id": self.TASK_ID, "state": State.RUNNING}],
+ dag_id=self.DAG_ID,
+ update_extras=True,
+ )
+ restricted_bundle = DagBundleModel(name=restricted_bundle_name)
+ restricted_team = Team(name=restricted_team_name)
+ restricted_bundle.teams.append(restricted_team)
+ session.add_all([restricted_bundle, restricted_team])
+ session.flush()
+ session.execute(
+ update(DagModel)
+ .where(DagModel.dag_id == self.BASH_DAG_ID)
+ .values(bundle_name=restricted_bundle_name)
+ )
+ session.commit()
+
+ auth_manager = test_client.app.state.auth_manager
+ token = auth_manager._get_token_signer().generate(
+ auth_manager.serialize_user(
+ SimpleAuthManagerUser(username="limited-user", role="user",
teams=[]),
+ )
+ )
+ with (
+ mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked",
return_value=False),
+ TestClient(
+ test_client.app,
+ headers={"Authorization": f"Bearer {token}"},
+ base_url=str(test_client.base_url),
+ ) as limited_test_client,
+ ):
+ response = limited_test_client.patch(
+ self.WILDCARD_ENDPOINT,
+ json={
+ "actions": [
+ {
+ "action": "update",
+ "entities": [
+ {
+ "dag_id": self.BASH_DAG_ID,
+ "dag_run_id": self.RUN_ID,
+ "task_id": self.BASH_TASK_ID,
+ "new_state": "success",
+ },
+ {
+ "dag_id": self.DAG_ID,
+ "dag_run_id": self.RUN_ID,
+ "task_id": self.TASK_ID,
+ "new_state": "success",
+ },
+ ],
+ }
+ ]
+ },
+ )
+
+ assert response.status_code == 200
+ assert response.json()["update"]["success"] ==
[f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
+ assert response.json()["update"]["errors"] == [
+ {
+ "error": f"User is not authorized to update task instances for
DAG '{self.BASH_DAG_ID}'",
+ "status_code": 403,
+ }
+ ]
+
+ def test_bulk_delete_rejects_unauthorized_dag_ids_from_request_body(self,
test_client, session):
+ restricted_bundle_name = "restricted-bundle-delete"
+ restricted_team_name = "restricted-team-delete"
+ self.create_task_instances(
+ session,
+ task_instances=[{"task_id": self.BASH_TASK_ID, "state":
State.SUCCESS}],
+ dag_id=self.BASH_DAG_ID,
+ update_extras=True,
+ )
+ self.create_task_instances(
+ session,
+ task_instances=[{"task_id": self.TASK_ID, "state": State.SUCCESS}],
+ dag_id=self.DAG_ID,
+ update_extras=True,
+ )
+ restricted_bundle = DagBundleModel(name=restricted_bundle_name)
+ restricted_team = Team(name=restricted_team_name)
+ restricted_bundle.teams.append(restricted_team)
+ session.add_all([restricted_bundle, restricted_team])
+ session.flush()
+ session.execute(
+ update(DagModel)
+ .where(DagModel.dag_id == self.BASH_DAG_ID)
+ .values(bundle_name=restricted_bundle_name)
+ )
+ session.commit()
+
+ auth_manager = test_client.app.state.auth_manager
+ token = auth_manager._get_token_signer().generate(
+ auth_manager.serialize_user(
+ SimpleAuthManagerUser(username="limited-user", role="user",
teams=[]),
+ )
+ )
+ with (
+ mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked",
return_value=False),
+ TestClient(
+ test_client.app,
+ headers={"Authorization": f"Bearer {token}"},
+ base_url=str(test_client.base_url),
+ ) as limited_test_client,
+ ):
+ response = limited_test_client.patch(
+ self.WILDCARD_ENDPOINT,
+ json={
+ "actions": [
+ {
+ "action": "delete",
+ "entities": [
+ {
+ "dag_id": self.BASH_DAG_ID,
+ "dag_run_id": self.RUN_ID,
+ "task_id": self.BASH_TASK_ID,
+ },
+ {
+ "dag_id": self.DAG_ID,
+ "dag_run_id": self.RUN_ID,
+ "task_id": self.TASK_ID,
+ },
+ ],
+ }
+ ]
+ },
+ )
+
+ assert response.status_code == 200
+ assert response.json()["delete"]["success"] ==
[f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
+ assert response.json()["delete"]["errors"] == [
+ {
+ "error": f"User is not authorized to delete task instances for
DAG '{self.BASH_DAG_ID}'",
+ "status_code": 403,
+ }
+ ]
+
def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch(self.ENDPOINT_URL,
json={})
assert response.status_code == 401
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
index f4f051c4e7a..8a29193eae8 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
@@ -17,11 +17,15 @@
from __future__ import annotations
+from unittest import mock
+
import pytest
+from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.api_fastapi.core_api.datamodels.common import BulkActionResponse,
BulkBody
from airflow.api_fastapi.core_api.datamodels.task_instances import
BulkTaskInstanceBody
from airflow.api_fastapi.core_api.services.public.task_instances import
BulkTaskInstanceService
+from airflow.models import DagModel
from airflow.providers.standard.operators.bash import BashOperator
from tests_common.test_utils.db import (
@@ -53,6 +57,10 @@ class TestCategorizeTaskInstances(TestTaskInstanceEndpoint):
self.clear_db()
class MockUser:
+ username = "test_user"
+ role = "admin"
+ teams = ["team1"]
+
def get_id(self) -> str:
return "test_user"
@@ -184,6 +192,10 @@ class TestExtractTaskIdentifiers(TestTaskInstanceEndpoint):
self.clear_db()
class MockUser:
+ username = "test_user"
+ role = "admin"
+ teams = ["team1"]
+
def get_id(self) -> str:
return "test_user"
@@ -260,6 +272,10 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
self.clear_db()
class MockUser:
+ username = "test_user"
+ role = "admin"
+ teams = ["team1"]
+
def get_id(self) -> str:
return "test_user"
@@ -380,7 +396,6 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
expected_error_count,
):
"""Test _categorize_entities with different entity configurations and
wildcard validation."""
-
user = self.MockUser()
bulk_request = BulkBody(actions=[])
service = BulkTaskInstanceService(
@@ -393,9 +408,18 @@ class TestCategorizeEntities(TestTaskInstanceEndpoint):
)
results = BulkActionResponse()
- specific_map_index_task_keys, all_map_index_task_keys =
service._categorize_entities(
- entities, results
- )
+ with (
+ mock.patch.object(DagModel, "get_team_name", return_value="team1"),
+ mock.patch(
+
"airflow.api_fastapi.core_api.services.public.task_instances.get_auth_manager"
+ ) as mock_get_auth_manager,
+ ):
+ auth_manager = mock.create_autospec(BaseAuthManager,
instance=True, spec_set=True)
+ auth_manager.is_authorized_dag.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+ specific_map_index_task_keys, all_map_index_task_keys =
service._categorize_entities(
+ entities, results, method="PUT", action_name="update"
+ )
assert specific_map_index_task_keys == expected_specific_keys
assert all_map_index_task_keys == expected_all_keys