This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dfdcf02ba6a Add multi-team query filtering to triggerer trigger assignment (#67517)
dfdcf02ba6a is described below

commit dfdcf02ba6a5e714f57b2db57f9d1e1044a277f5
Author: Ramit Kataria <[email protected]>
AuthorDate: Tue May 26 12:12:58 2026 -0700

    Add multi-team query filtering to triggerer trigger assignment (#67517)
    
    When core.multi_team is enabled, the triggerer's polling queries now
    filter triggers by team_name. A team-scoped triggerer (--team-name X)
    only picks up triggers with team_name='X', while a global triggerer
    (no --team-name) only picks up triggers with team_name IS NULL.
    
    When core.multi_team is disabled, no team filtering is applied, so the
    queries behave exactly as they did before multi-team support existed.
    This also covers the case where multi-team is disabled after triggers
    were already created with a team assignment: those triggers are still
    picked up rather than orphaned.
    
    Team and queue filters are combined with AND: when both are in effect,
    a trigger must match both to be picked up.
---
 .../src/airflow/jobs/triggerer_job_runner.py       |   3 +-
 airflow-core/src/airflow/models/trigger.py         |  26 ++-
 airflow-core/tests/unit/jobs/test_triggerer_job.py |  23 ++
 airflow-core/tests/unit/models/test_trigger.py     | 238 +++++++++++++++++++++
 4 files changed, 288 insertions(+), 2 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 14c522d9c29..36102aaa0d3 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -657,8 +657,9 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
             self.capacity,
             self.health_check_threshold,
             queues=self.queues,
+            team_name=self.team_name,
         )
-        ids = Trigger.ids_for_triggerer(self.job.id, queues=self.queues)
+        ids = Trigger.ids_for_triggerer(self.job.id, queues=self.queues, 
team_name=self.team_name)
         self.update_triggers(set(ids))
 
     def handle_events(self):
diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py
index 0d15d5ac497..09c5a1054e5 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -337,7 +337,11 @@ class Trigger(Base):
     @classmethod
     @provide_session
     def ids_for_triggerer(
-        cls, triggerer_id, queues: set[str] | None = None, session: Session = 
NEW_SESSION
+        cls,
+        triggerer_id,
+        queues: set[str] | None = None,
+        team_name: str | None = None,
+        session: Session = NEW_SESSION,
     ) -> list[int]:
         """Retrieve a list of trigger ids."""
         query = select(cls.id).where(cls.triggerer_id == triggerer_id)
@@ -349,6 +353,14 @@ class Trigger(Base):
         else:
             query = query.filter(cls.queue.is_(None))
 
+        # Check config instead of team_name: if multi-team is disabled after 
triggers were
+        # created with a team, those triggers must still be picked up instead 
of being orphaned.
+        if conf.getboolean("core", "multi_team"):
+            if team_name:
+                query = query.filter(cls.team_name == team_name)
+            else:
+                query = query.filter(cls.team_name.is_(None))
+
         return list(session.scalars(query).all())
 
     @classmethod
@@ -359,6 +371,7 @@ class Trigger(Base):
         capacity,
         health_check_threshold,
         queues: set[str] | None = None,
+        team_name: str | None = None,
         session: Session = NEW_SESSION,
     ) -> None:
         """
@@ -393,6 +406,7 @@ class Trigger(Base):
             capacity=capacity,
             alive_triggerer_ids=alive_triggerer_ids,
             queues=queues,
+            team_name=team_name,
             session=session,
         )
         if trigger_ids_query:
@@ -412,6 +426,7 @@ class Trigger(Base):
         alive_triggerer_ids: list[int] | Select,
         queues: set[str] | None,
         session: Session,
+        team_name: str | None = None,
     ):
         """
         Get sorted triggers based on capacity and alive triggerer ids.
@@ -420,6 +435,7 @@ class Trigger(Base):
         :param alive_triggerer_ids: The alive triggerer ids as a list or a 
select query.
         :param queues: The optional set of trigger queues to filter triggers 
by.
         :param session: The database session.
+        :param team_name: The team to filter triggers for (None = global 
triggerer).
         """
         from airflow.models.callback import Callback  # to avoid circular 
import: Callback -> Trigger
 
@@ -465,6 +481,14 @@ class Trigger(Base):
             else:
                 filtered_query = query.filter(cls.queue.is_(None))
 
+            # Check config instead of team_name: if multi-team is disabled 
after triggers were
+            # created with a team, those triggers must still be picked up 
instead of being orphaned.
+            if conf.getboolean("core", "multi_team"):
+                if team_name:
+                    filtered_query = filtered_query.filter(cls.team_name == 
team_name)
+                else:
+                    filtered_query = 
filtered_query.filter(cls.team_name.is_(None))
+
             locked_query = 
with_row_locks(filtered_query.limit(remaining_capacity), session, 
skip_locked=True)
             result.extend(session.execute(locked_query).all())
 
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3d3a399c121..f42c84a2bf5 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -501,6 +501,29 @@ def test_load_triggers_raises_without_job(jobless_supervisor, mocker):
     update_triggers.assert_not_called()
 
 
+def test_load_triggers_passes_team_name(supervisor_builder, mocker):
+    """load_triggers passes team_name to assign_unassigned and 
ids_for_triggerer."""
+    proc = supervisor_builder()
+    proc.team_name = "team_x"
+
+    assign_unassigned = 
mocker.patch("airflow.jobs.triggerer_job_runner.Trigger.assign_unassigned")
+    ids_for_triggerer = mocker.patch(
+        "airflow.jobs.triggerer_job_runner.Trigger.ids_for_triggerer", 
return_value=[1, 2]
+    )
+    mocker.patch.object(TriggerRunnerSupervisor, "update_triggers")
+
+    proc.load_triggers()
+
+    assign_unassigned.assert_called_once_with(
+        proc.job.id,
+        proc.capacity,
+        proc.health_check_threshold,
+        queues=proc.queues,
+        team_name="team_x",
+    )
+    ids_for_triggerer.assert_called_once_with(proc.job.id, queues=proc.queues, 
team_name="team_x")
+
+
 def test_create_workload_uses_supervisor_id_without_job(jobless_supervisor, 
mocker):
     """_create_workload() should fall back to self.id for the log filename 
when job is None."""
     trigger = mocker.Mock()
diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py
index 5ee079f1f0f..04998f12e0a 100644
--- a/airflow-core/tests/unit/models/test_trigger.py
+++ b/airflow-core/tests/unit/models/test_trigger.py
@@ -1029,3 +1029,241 @@ def test_asset_trigger_ordering_and_capacity(session):
 
     # Only the three oldest should be returned, in order
     assert ids == [triggers[0].id, triggers[1].id, triggers[2].id]
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+    ("team_name_filter", "expect_team", "expect_global"),
+    [
+        pytest.param("testing", True, False, id="team_scoped"),
+        pytest.param(None, False, True, id="global_triggerer"),
+    ],
+)
+def test_get_sorted_triggers_multi_team_enabled(
+    session, create_task_instance, testing_team, team_name_filter, 
expect_team, expect_global
+):
+    """When multi_team=True, team_name filters to matching triggers only."""
+    time_now = timezone.utcnow()
+    trigger_team = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
team_name=testing_team.name
+    )
+    session.add(trigger_team)
+    ti_team = create_task_instance(task_id="team_ti", logical_date=time_now, 
run_id="team_run")
+    ti_team.trigger_id = trigger_team.id
+    session.add(ti_team)
+
+    trigger_global = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    session.add(trigger_global)
+    ti_global = create_task_instance(
+        task_id="global_ti", logical_date=time_now + 
datetime.timedelta(hours=1), run_id="global_run"
+    )
+    ti_global.trigger_id = trigger_global.id
+    session.add(ti_global)
+    session.commit()
+
+    result = Trigger.get_sorted_triggers(
+        capacity=100, alive_triggerer_ids=[], queues=None, session=session, 
team_name=team_name_filter
+    )
+    ids = {row[0] for row in result}
+    assert (trigger_team.id in ids) == expect_team
+    assert (trigger_global.id in ids) == expect_global
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "False"})
+def test_get_sorted_triggers_multi_team_disabled(session, 
create_task_instance, testing_team):
+    """When multi_team=False, all triggers are returned regardless of 
team_name."""
+    time_now = timezone.utcnow()
+    trigger_team = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
team_name=testing_team.name
+    )
+    session.add(trigger_team)
+    ti_team = create_task_instance(task_id="team_ti", logical_date=time_now, 
run_id="team_run")
+    ti_team.trigger_id = trigger_team.id
+    session.add(ti_team)
+
+    trigger_global = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    session.add(trigger_global)
+    ti_global = create_task_instance(
+        task_id="global_ti", logical_date=time_now + 
datetime.timedelta(hours=1), run_id="global_run"
+    )
+    ti_global.trigger_id = trigger_global.id
+    session.add(ti_global)
+    session.commit()
+
+    result = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], 
queues=None, session=session)
+    ids = {row[0] for row in result}
+    assert trigger_team.id in ids
+    assert trigger_global.id in ids
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+    ("team_name_filter", "expect_team", "expect_global"),
+    [
+        pytest.param("testing", True, False, id="team_scoped"),
+        pytest.param(None, False, True, id="global_triggerer"),
+    ],
+)
+def test_ids_for_triggerer_multi_team_enabled(
+    session, create_task_instance, testing_team, team_name_filter, 
expect_team, expect_global
+):
+    """When multi_team=True, ids_for_triggerer filters by team_name."""
+    time_now = timezone.utcnow()
+    job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+    session.add(job)
+    session.flush()
+
+    trigger_team = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
team_name=testing_team.name
+    )
+    trigger_team.triggerer_id = job.id
+    session.add(trigger_team)
+    ti_team = create_task_instance(task_id="team_ti", logical_date=time_now, 
run_id="team_run")
+    ti_team.trigger_id = trigger_team.id
+    session.add(ti_team)
+
+    trigger_global = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger_global.triggerer_id = job.id
+    session.add(trigger_global)
+    ti_global = create_task_instance(
+        task_id="global_ti", logical_date=time_now + 
datetime.timedelta(hours=1), run_id="global_run"
+    )
+    ti_global.trigger_id = trigger_global.id
+    session.add(ti_global)
+    session.commit()
+
+    ids = set(Trigger.ids_for_triggerer(job.id, team_name=team_name_filter))
+    assert (trigger_team.id in ids) == expect_team
+    assert (trigger_global.id in ids) == expect_global
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "False"})
+def test_ids_for_triggerer_multi_team_disabled(session, create_task_instance, 
testing_team):
+    """When multi_team=False, ids_for_triggerer returns all triggers for the 
triggerer."""
+    time_now = timezone.utcnow()
+    job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+    session.add(job)
+    session.flush()
+
+    trigger_team = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
team_name=testing_team.name
+    )
+    trigger_team.triggerer_id = job.id
+    session.add(trigger_team)
+    ti_team = create_task_instance(task_id="team_ti", logical_date=time_now, 
run_id="team_run")
+    ti_team.trigger_id = trigger_team.id
+    session.add(ti_team)
+
+    trigger_global = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger_global.triggerer_id = job.id
+    session.add(trigger_global)
+    ti_global = create_task_instance(
+        task_id="global_ti", logical_date=time_now + 
datetime.timedelta(hours=1), run_id="global_run"
+    )
+    ti_global.trigger_id = trigger_global.id
+    session.add(ti_global)
+    session.commit()
+
+    ids = Trigger.ids_for_triggerer(job.id)
+    assert set(ids) == {trigger_team.id, trigger_global.id}
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+    ("team_name_filter", "expect_team_assigned", "expect_global_assigned"),
+    [
+        pytest.param("testing", True, False, id="team_scoped"),
+        pytest.param(None, False, True, id="global_triggerer"),
+    ],
+)
+def test_assign_unassigned_multi_team(
+    session,
+    create_task_instance,
+    testing_team,
+    team_name_filter,
+    expect_team_assigned,
+    expect_global_assigned,
+):
+    """When multi_team=True, assign_unassigned only assigns triggers matching 
team_name."""
+    time_now = timezone.utcnow()
+    job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+    session.add(job)
+    session.flush()
+
+    trigger_team = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
team_name=testing_team.name
+    )
+    session.add(trigger_team)
+    ti_team = create_task_instance(task_id="team_ti", logical_date=time_now, 
run_id="team_run")
+    ti_team.trigger_id = trigger_team.id
+    session.add(ti_team)
+
+    trigger_global = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    session.add(trigger_global)
+    ti_global = create_task_instance(
+        task_id="global_ti", logical_date=time_now + 
datetime.timedelta(hours=1), run_id="global_run"
+    )
+    ti_global.trigger_id = trigger_global.id
+    session.add(ti_global)
+    session.commit()
+
+    Trigger.assign_unassigned(job.id, capacity=100, health_check_threshold=30, 
team_name=team_name_filter)
+    session.expire_all()
+
+    assert (session.get(Trigger, trigger_team.id).triggerer_id == job.id) == 
expect_team_assigned
+    assert (session.get(Trigger, trigger_global.id).triggerer_id == job.id) == 
expect_global_assigned
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
+def test_team_and_queue_combined(session, create_task_instance, testing_team):
+    """Team and queue filters combine as AND conditions."""
+    time_now = timezone.utcnow()
+
+    trigger_team_q1 = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        kwargs={},
+        team_name=testing_team.name,
+        queue="q1",
+    )
+    session.add(trigger_team_q1)
+    ti1 = create_task_instance(task_id="ti1", logical_date=time_now, 
run_id="run1")
+    ti1.trigger_id = trigger_team_q1.id
+    session.add(ti1)
+
+    trigger_team_q2 = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        kwargs={},
+        team_name=testing_team.name,
+        queue="q2",
+    )
+    session.add(trigger_team_q2)
+    ti2 = create_task_instance(
+        task_id="ti2", logical_date=time_now + datetime.timedelta(hours=1), 
run_id="run2"
+    )
+    ti2.trigger_id = trigger_team_q2.id
+    session.add(ti2)
+
+    trigger_global_q1 = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, 
queue="q1")
+    session.add(trigger_global_q1)
+    ti3 = create_task_instance(
+        task_id="ti3", logical_date=time_now + datetime.timedelta(hours=2), 
run_id="run3"
+    )
+    ti3.trigger_id = trigger_global_q1.id
+    session.add(ti3)
+    session.commit()
+
+    result = Trigger.get_sorted_triggers(
+        capacity=100,
+        alive_triggerer_ids=[],
+        queues={"q1"},
+        session=session,
+        team_name=testing_team.name,
+    )
+    ids = [row[0] for row in result]
+    assert ids == [trigger_team_q1.id]

Reply via email to