This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dfdcf02ba6a Add multi-team query filtering to triggerer trigger
assignment (#67517)
dfdcf02ba6a is described below
commit dfdcf02ba6a5e714f57b2db57f9d1e1044a277f5
Author: Ramit Kataria <[email protected]>
AuthorDate: Tue May 26 12:12:58 2026 -0700
Add multi-team query filtering to triggerer trigger assignment (#67517)
When core.multi_team is enabled, the triggerer's polling queries now
filter triggers by team_name. A team-scoped triggerer (--team-name X)
only picks up triggers whose team_name is 'X', while a global triggerer
(no --team-name) only picks up triggers whose team_name IS NULL.
When core.multi_team is disabled, no team filtering is applied (even if
a team name is passed) and the queries are identical to the ones used
before multi-team support; queue filtering is unaffected. This covers
the edge case where multi-team is disabled after triggers were already
created with a team assignment: those triggers are still picked up
rather than orphaned.
Team and queue filters combine as AND conditions.
---
.../src/airflow/jobs/triggerer_job_runner.py | 3 +-
airflow-core/src/airflow/models/trigger.py | 26 ++-
airflow-core/tests/unit/jobs/test_triggerer_job.py | 23 ++
airflow-core/tests/unit/models/test_trigger.py | 238 +++++++++++++++++++++
4 files changed, 288 insertions(+), 2 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 14c522d9c29..36102aaa0d3 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -657,8 +657,9 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
self.capacity,
self.health_check_threshold,
queues=self.queues,
+ team_name=self.team_name,
)
- ids = Trigger.ids_for_triggerer(self.job.id, queues=self.queues)
+ ids = Trigger.ids_for_triggerer(self.job.id, queues=self.queues,
team_name=self.team_name)
self.update_triggers(set(ids))
def handle_events(self):
diff --git a/airflow-core/src/airflow/models/trigger.py
b/airflow-core/src/airflow/models/trigger.py
index 0d15d5ac497..09c5a1054e5 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -337,7 +337,11 @@ class Trigger(Base):
@classmethod
@provide_session
def ids_for_triggerer(
- cls, triggerer_id, queues: set[str] | None = None, session: Session =
NEW_SESSION
+ cls,
+ triggerer_id,
+ queues: set[str] | None = None,
+ team_name: str | None = None,
+ session: Session = NEW_SESSION,
) -> list[int]:
"""Retrieve a list of trigger ids."""
query = select(cls.id).where(cls.triggerer_id == triggerer_id)
@@ -349,6 +353,14 @@ class Trigger(Base):
else:
query = query.filter(cls.queue.is_(None))
+ # Check config instead of team_name: if multi-team is disabled after
triggers were
+ # created with a team, those triggers must still be picked up instead
of being orphaned.
+ if conf.getboolean("core", "multi_team"):
+ if team_name:
+ query = query.filter(cls.team_name == team_name)
+ else:
+ query = query.filter(cls.team_name.is_(None))
+
return list(session.scalars(query).all())
@classmethod
@@ -359,6 +371,7 @@ class Trigger(Base):
capacity,
health_check_threshold,
queues: set[str] | None = None,
+ team_name: str | None = None,
session: Session = NEW_SESSION,
) -> None:
"""
@@ -393,6 +406,7 @@ class Trigger(Base):
capacity=capacity,
alive_triggerer_ids=alive_triggerer_ids,
queues=queues,
+ team_name=team_name,
session=session,
)
if trigger_ids_query:
@@ -412,6 +426,7 @@ class Trigger(Base):
alive_triggerer_ids: list[int] | Select,
queues: set[str] | None,
session: Session,
+ team_name: str | None = None,
):
"""
Get sorted triggers based on capacity and alive triggerer ids.
@@ -420,6 +435,7 @@ class Trigger(Base):
:param alive_triggerer_ids: The alive triggerer ids as a list or a
select query.
:param queues: The optional set of trigger queues to filter triggers
by.
:param session: The database session.
+ :param team_name: The team to filter triggers for (None = global
triggerer).
"""
from airflow.models.callback import Callback # to avoid circular
import: Callback -> Trigger
@@ -465,6 +481,14 @@ class Trigger(Base):
else:
filtered_query = query.filter(cls.queue.is_(None))
+ # Check config instead of team_name: if multi-team is disabled
after triggers were
+ # created with a team, those triggers must still be picked up
instead of being orphaned.
+ if conf.getboolean("core", "multi_team"):
+ if team_name:
+ filtered_query = filtered_query.filter(cls.team_name ==
team_name)
+ else:
+ filtered_query =
filtered_query.filter(cls.team_name.is_(None))
+
locked_query =
with_row_locks(filtered_query.limit(remaining_capacity), session,
skip_locked=True)
result.extend(session.execute(locked_query).all())
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py
b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3d3a399c121..f42c84a2bf5 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -501,6 +501,29 @@ def
test_load_triggers_raises_without_job(jobless_supervisor, mocker):
update_triggers.assert_not_called()
+def test_load_triggers_passes_team_name(supervisor_builder, mocker):
+ """load_triggers passes team_name to assign_unassigned and
ids_for_triggerer."""
+ proc = supervisor_builder()
+ proc.team_name = "team_x"
+
+ assign_unassigned =
mocker.patch("airflow.jobs.triggerer_job_runner.Trigger.assign_unassigned")
+ ids_for_triggerer = mocker.patch(
+ "airflow.jobs.triggerer_job_runner.Trigger.ids_for_triggerer",
return_value=[1, 2]
+ )
+ mocker.patch.object(TriggerRunnerSupervisor, "update_triggers")
+
+ proc.load_triggers()
+
+ assign_unassigned.assert_called_once_with(
+ proc.job.id,
+ proc.capacity,
+ proc.health_check_threshold,
+ queues=proc.queues,
+ team_name="team_x",
+ )
+ ids_for_triggerer.assert_called_once_with(proc.job.id, queues=proc.queues,
team_name="team_x")
+
+
def test_create_workload_uses_supervisor_id_without_job(jobless_supervisor,
mocker):
"""_create_workload() should fall back to self.id for the log filename
when job is None."""
trigger = mocker.Mock()
diff --git a/airflow-core/tests/unit/models/test_trigger.py
b/airflow-core/tests/unit/models/test_trigger.py
index 5ee079f1f0f..04998f12e0a 100644
--- a/airflow-core/tests/unit/models/test_trigger.py
+++ b/airflow-core/tests/unit/models/test_trigger.py
@@ -1029,3 +1029,241 @@ def test_asset_trigger_ordering_and_capacity(session):
# Only the three oldest should be returned, in order
assert ids == [triggers[0].id, triggers[1].id, triggers[2].id]
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+ ("team_name_filter", "expect_team", "expect_global"),
+ [
+ pytest.param("testing", True, False, id="team_scoped"),
+ pytest.param(None, False, True, id="global_triggerer"),
+ ],
+)
+def test_get_sorted_triggers_multi_team_enabled(
+ session, create_task_instance, testing_team, team_name_filter,
expect_team, expect_global
+):
+ """When multi_team=True, team_name filters to matching triggers only."""
+ time_now = timezone.utcnow()
+ trigger_team = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
team_name=testing_team.name
+ )
+ session.add(trigger_team)
+ ti_team = create_task_instance(task_id="team_ti", logical_date=time_now,
run_id="team_run")
+ ti_team.trigger_id = trigger_team.id
+ session.add(ti_team)
+
+ trigger_global =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ session.add(trigger_global)
+ ti_global = create_task_instance(
+ task_id="global_ti", logical_date=time_now +
datetime.timedelta(hours=1), run_id="global_run"
+ )
+ ti_global.trigger_id = trigger_global.id
+ session.add(ti_global)
+ session.commit()
+
+ result = Trigger.get_sorted_triggers(
+ capacity=100, alive_triggerer_ids=[], queues=None, session=session,
team_name=team_name_filter
+ )
+ ids = {row[0] for row in result}
+ assert (trigger_team.id in ids) == expect_team
+ assert (trigger_global.id in ids) == expect_global
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "False"})
+def test_get_sorted_triggers_multi_team_disabled(session,
create_task_instance, testing_team):
+ """When multi_team=False, all triggers are returned regardless of
team_name."""
+ time_now = timezone.utcnow()
+ trigger_team = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
team_name=testing_team.name
+ )
+ session.add(trigger_team)
+ ti_team = create_task_instance(task_id="team_ti", logical_date=time_now,
run_id="team_run")
+ ti_team.trigger_id = trigger_team.id
+ session.add(ti_team)
+
+ trigger_global =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ session.add(trigger_global)
+ ti_global = create_task_instance(
+ task_id="global_ti", logical_date=time_now +
datetime.timedelta(hours=1), run_id="global_run"
+ )
+ ti_global.trigger_id = trigger_global.id
+ session.add(ti_global)
+ session.commit()
+
+ result = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[],
queues=None, session=session)
+ ids = {row[0] for row in result}
+ assert trigger_team.id in ids
+ assert trigger_global.id in ids
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+ ("team_name_filter", "expect_team", "expect_global"),
+ [
+ pytest.param("testing", True, False, id="team_scoped"),
+ pytest.param(None, False, True, id="global_triggerer"),
+ ],
+)
+def test_ids_for_triggerer_multi_team_enabled(
+ session, create_task_instance, testing_team, team_name_filter,
expect_team, expect_global
+):
+ """When multi_team=True, ids_for_triggerer filters by team_name."""
+ time_now = timezone.utcnow()
+ job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+ session.add(job)
+ session.flush()
+
+ trigger_team = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
team_name=testing_team.name
+ )
+ trigger_team.triggerer_id = job.id
+ session.add(trigger_team)
+ ti_team = create_task_instance(task_id="team_ti", logical_date=time_now,
run_id="team_run")
+ ti_team.trigger_id = trigger_team.id
+ session.add(ti_team)
+
+ trigger_global =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ trigger_global.triggerer_id = job.id
+ session.add(trigger_global)
+ ti_global = create_task_instance(
+ task_id="global_ti", logical_date=time_now +
datetime.timedelta(hours=1), run_id="global_run"
+ )
+ ti_global.trigger_id = trigger_global.id
+ session.add(ti_global)
+ session.commit()
+
+ ids = set(Trigger.ids_for_triggerer(job.id, team_name=team_name_filter))
+ assert (trigger_team.id in ids) == expect_team
+ assert (trigger_global.id in ids) == expect_global
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "False"})
+def test_ids_for_triggerer_multi_team_disabled(session, create_task_instance,
testing_team):
+ """When multi_team=False, ids_for_triggerer returns all triggers for the
triggerer."""
+ time_now = timezone.utcnow()
+ job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+ session.add(job)
+ session.flush()
+
+ trigger_team = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
team_name=testing_team.name
+ )
+ trigger_team.triggerer_id = job.id
+ session.add(trigger_team)
+ ti_team = create_task_instance(task_id="team_ti", logical_date=time_now,
run_id="team_run")
+ ti_team.trigger_id = trigger_team.id
+ session.add(ti_team)
+
+ trigger_global =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ trigger_global.triggerer_id = job.id
+ session.add(trigger_global)
+ ti_global = create_task_instance(
+ task_id="global_ti", logical_date=time_now +
datetime.timedelta(hours=1), run_id="global_run"
+ )
+ ti_global.trigger_id = trigger_global.id
+ session.add(ti_global)
+ session.commit()
+
+ ids = Trigger.ids_for_triggerer(job.id)
+ assert set(ids) == {trigger_team.id, trigger_global.id}
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+ ("team_name_filter", "expect_team_assigned", "expect_global_assigned"),
+ [
+ pytest.param("testing", True, False, id="team_scoped"),
+ pytest.param(None, False, True, id="global_triggerer"),
+ ],
+)
+def test_assign_unassigned_multi_team(
+ session,
+ create_task_instance,
+ testing_team,
+ team_name_filter,
+ expect_team_assigned,
+ expect_global_assigned,
+):
+ """When multi_team=True, assign_unassigned only assigns triggers matching
team_name."""
+ time_now = timezone.utcnow()
+ job = Job(heartrate=10, state=State.RUNNING, latest_heartbeat=time_now)
+ session.add(job)
+ session.flush()
+
+ trigger_team = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
team_name=testing_team.name
+ )
+ session.add(trigger_team)
+ ti_team = create_task_instance(task_id="team_ti", logical_date=time_now,
run_id="team_run")
+ ti_team.trigger_id = trigger_team.id
+ session.add(ti_team)
+
+ trigger_global =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ session.add(trigger_global)
+ ti_global = create_task_instance(
+ task_id="global_ti", logical_date=time_now +
datetime.timedelta(hours=1), run_id="global_run"
+ )
+ ti_global.trigger_id = trigger_global.id
+ session.add(ti_global)
+ session.commit()
+
+ Trigger.assign_unassigned(job.id, capacity=100, health_check_threshold=30,
team_name=team_name_filter)
+ session.expire_all()
+
+ assert (session.get(Trigger, trigger_team.id).triggerer_id == job.id) ==
expect_team_assigned
+ assert (session.get(Trigger, trigger_global.id).triggerer_id == job.id) ==
expect_global_assigned
+
+
[email protected]_serialized_dag
+@conf_vars({("core", "multi_team"): "True"})
+def test_team_and_queue_combined(session, create_task_instance, testing_team):
+ """Team and queue filters combine as AND conditions."""
+ time_now = timezone.utcnow()
+
+ trigger_team_q1 = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger",
+ kwargs={},
+ team_name=testing_team.name,
+ queue="q1",
+ )
+ session.add(trigger_team_q1)
+ ti1 = create_task_instance(task_id="ti1", logical_date=time_now,
run_id="run1")
+ ti1.trigger_id = trigger_team_q1.id
+ session.add(ti1)
+
+ trigger_team_q2 = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger",
+ kwargs={},
+ team_name=testing_team.name,
+ queue="q2",
+ )
+ session.add(trigger_team_q2)
+ ti2 = create_task_instance(
+ task_id="ti2", logical_date=time_now + datetime.timedelta(hours=1),
run_id="run2"
+ )
+ ti2.trigger_id = trigger_team_q2.id
+ session.add(ti2)
+
+ trigger_global_q1 =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={},
queue="q1")
+ session.add(trigger_global_q1)
+ ti3 = create_task_instance(
+ task_id="ti3", logical_date=time_now + datetime.timedelta(hours=2),
run_id="run3"
+ )
+ ti3.trigger_id = trigger_global_q1.id
+ session.add(ti3)
+ session.commit()
+
+ result = Trigger.get_sorted_triggers(
+ capacity=100,
+ alive_triggerer_ids=[],
+ queues={"q1"},
+ session=session,
+ team_name=testing_team.name,
+ )
+ ids = [row[0] for row in result]
+ assert ids == [trigger_team_q1.id]