This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b41214e395 Enforce pool team ownership in scheduling loop (#68649)
6b41214e395 is described below

commit 6b41214e395680b0298e988c3318c8fdc0789734
Author: Niko Oliveira <[email protected]>
AuthorDate: Wed Jun 17 06:37:46 2026 -0700

    Enforce pool team ownership in scheduling loop (#68649)
    
    The scheduler's _executable_task_instances_to_queued bypassed pool team
    ownership checks, allowing tasks from one team to use pools owned by
    another team. The existing PoolSlotsAvailableDep had the correct logic
    but was never invoked in the scheduler's critical section.
    
    Add an inline team check that batch-resolves pool-to-team mappings and
    blocks cross-team pool access before tasks are queued for execution.
---
 .../src/airflow/jobs/scheduler_job_runner.py       | 18 ++++++
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 64 ++++++++++++++++++++++
 2 files changed, 82 insertions(+)

diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index ec5b3248178..51178847267 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -575,6 +575,10 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}
 
+        pool_to_team_name: dict[str, str | None] = {}
+        if self._multi_team:
+            pool_to_team_name = Pool.get_name_to_team_name_mapping(list(pools.keys()), session=session)
+
         # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
         concurrency_map = ConcurrencyMap()
         concurrency_map.load(session=session)
@@ -749,6 +753,20 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     starved_pools.add(pool_name)
                     continue
 
+                if pool_team := pool_to_team_name.get(pool_name):
+                    dag_team = dag_id_to_team_name.get(task_instance.dag_id)
+                    if dag_team != pool_team:
+                        self.log.debug(
+                            "Not executing %s. Pool '%s' is assigned to team '%s' "
+                            "but task's DAG belongs to team '%s'",
+                            task_instance,
+                            pool_name,
+                            pool_team,
+                            dag_team,
+                        )
+                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                        continue
+
                 # Make sure to emit metrics if pool has no starving tasks
                 pool_num_starving_tasks.setdefault(pool_name, 0)
 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 540bdc46054..af1b34423da 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -1425,6 +1425,70 @@ class TestSchedulerJob:
         assert tis[3].key in res_keys
         session.rollback()
 
+    @conf_vars({("core", "multi_team"): "true"})
+    def test_find_executable_task_instances_pool_team_enforcement(self, dag_maker, session):
+        """Tasks using a pool owned by another team are not scheduled."""
+        clear_db_teams()
+        clear_db_dag_bundles()
+
+        team_a = Team(name="team_a")
+        team_b = Team(name="team_b")
+        session.add_all([team_a, team_b])
+        session.flush()
+
+        bundle_a = DagBundleModel(name="bundle_a")
+        bundle_a.teams.append(team_a)
+        bundle_b = DagBundleModel(name="bundle_b")
+        bundle_b.teams.append(team_b)
+        session.add_all([bundle_a, bundle_b])
+        session.flush()
+
+        # Pool owned by team_a
+        pool_a = Pool(pool="pool_a", slots=10, include_deferred=False, team_name="team_a")
+        # Shared pool (no team)
+        pool_shared = Pool(pool="pool_shared", slots=10, include_deferred=False)
+        session.add_all([pool_a, pool_shared])
+        session.flush()
+
+        # DAG in team_a using pool_a (allowed)
+        with dag_maker(dag_id="dag_a", bundle_name="bundle_a", session=session):
+            EmptyOperator(task_id="task_a", pool="pool_a")
+        dr_a = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        ti_a = dr_a.get_task_instance("task_a", session=session)
+        ti_a.state = State.SCHEDULED
+        session.merge(ti_a)
+
+        # DAG in team_b using pool_a (should be blocked)
+        with dag_maker(dag_id="dag_b_cross", bundle_name="bundle_b", session=session):
+            EmptyOperator(task_id="task_cross", pool="pool_a")
+        dr_b = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        ti_b = dr_b.get_task_instance("task_cross", session=session)
+        ti_b.state = State.SCHEDULED
+        session.merge(ti_b)
+
+        # DAG in team_b using shared pool (allowed)
+        with dag_maker(dag_id="dag_b_shared", bundle_name="bundle_b", session=session):
+            EmptyOperator(task_id="task_shared", pool="pool_shared")
+        dr_b2 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        ti_b2 = dr_b2.get_task_instance("task_shared", session=session)
+        ti_b2.state = State.SCHEDULED
+        session.merge(ti_b2)
+        session.flush()
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        queued_keys = {ti.key for ti in res}
+
+        # team_a task using its own pool: allowed
+        assert ti_a.key in queued_keys
+        # team_b task using team_a's pool: blocked
+        assert ti_b.key not in queued_keys
+        # team_b task using shared pool: allowed
+        assert ti_b2.key in queued_keys
+        session.rollback()
+
     @pytest.mark.parametrize(
         ("state", "total_executed_ti"),
         [

Reply via email to