This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 04d3f4560f Order triggers by descending TI priority_weight when assigning
unassigned triggers (#32318)
04d3f4560f is described below
commit 04d3f4560f3204eeb72a5ee14fb3316f5bdf35cb
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Aug 14 12:50:38 2023 +0200
Order triggers by descending TI priority_weight when assigning unassigned
triggers (#32318)
* Order triggers by descending TI priority_weight when assigning unassigned triggers
Signed-off-by: Hussein Awala <[email protected]>
* Update airflow/models/trigger.py
Co-authored-by: Tzu-ping Chung <[email protected]>
* Replace outer join by inner join and use coalesce to handle None values
* fix unit tests
---------
Signed-off-by: Hussein Awala <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: eladkal <[email protected]>
---
airflow/models/trigger.py | 4 +-
tests/jobs/test_triggerer_job.py | 18 +++++-
tests/models/test_trigger.py | 124 +++++++++++++++++++++++++++++++++++----
3 files changed, 130 insertions(+), 16 deletions(-)
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 6ba1d45852..55b49e45c5 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -22,6 +22,7 @@ from typing import Any, Iterable
from sqlalchemy import Column, Integer, String, delete, func, or_, select,
update
from sqlalchemy.orm import Session, joinedload, relationship
+from sqlalchemy.sql.functions import coalesce
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base
@@ -244,8 +245,9 @@ class Trigger(Base):
def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
query = with_row_locks(
select(cls.id)
+ .join(TaskInstance, cls.id == TaskInstance.trigger_id,
isouter=False)
.where(or_(cls.triggerer_id.is_(None),
cls.triggerer_id.not_in(alive_triggerer_ids)))
- .order_by(cls.created_date)
+ .order_by(coalesce(TaskInstance.priority_weight, 0).desc(),
cls.created_date)
.limit(capacity),
session,
skip_locked=True,
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 35ebe99b31..bd71c3c9ed 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -414,7 +414,7 @@ def test_trigger_create_race_condition_18392(session,
tmp_path):
assert len(instances) == 1
-def test_trigger_from_dead_triggerer(session):
+def test_trigger_from_dead_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned
to a
triggerer that does not exist.
@@ -425,6 +425,13 @@ def test_trigger_from_dead_triggerer(session):
trigger_orm.id = 1
trigger_orm.triggerer_id = 999 # Non-existent triggerer
session.add(trigger_orm)
+ ti_orm = create_task_instance(
+ task_id="ti_orm",
+ execution_date=datetime.datetime.utcnow(),
+ run_id="orm_run_id",
+ )
+ ti_orm.trigger_id = trigger_orm.id
+ session.add(trigger_orm)
session.commit()
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
@@ -434,7 +441,7 @@ def test_trigger_from_dead_triggerer(session):
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
-def test_trigger_from_expired_triggerer(session):
+def test_trigger_from_expired_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned
to a
triggerer that has an expired heartbeat.
@@ -445,6 +452,13 @@ def test_trigger_from_expired_triggerer(session):
trigger_orm.id = 1
trigger_orm.triggerer_id = 42
session.add(trigger_orm)
+ ti_orm = create_task_instance(
+ task_id="ti_orm",
+ execution_date=datetime.datetime.utcnow(),
+ run_id="orm_run_id",
+ )
+ ti_orm.trigger_id = trigger_orm.id
+ session.add(trigger_orm)
# Use a TriggererJobRunner with an expired heartbeat
triggerer_job_orm = Job(TriggererJobRunner.job_type)
triggerer_job_orm.id = 42
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index 98c570c935..3626c94670 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -171,19 +171,47 @@ def test_assign_unassigned(session, create_task_instance):
trigger_on_healthy_triggerer =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_healthy_triggerer.id = 1
trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id
+ session.add(trigger_on_healthy_triggerer)
+ ti_trigger_on_healthy_triggerer = create_task_instance(
+ task_id="ti_trigger_on_healthy_triggerer",
+ execution_date=time_now,
+ run_id="trigger_on_healthy_triggerer_run_id",
+ )
+ ti_trigger_on_healthy_triggerer.trigger_id =
trigger_on_healthy_triggerer.id
+ session.add(ti_trigger_on_healthy_triggerer)
trigger_on_unhealthy_triggerer =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_unhealthy_triggerer.id = 2
trigger_on_unhealthy_triggerer.triggerer_id = unhealthy_triggerer.id
+ session.add(trigger_on_unhealthy_triggerer)
+ ti_trigger_on_unhealthy_triggerer = create_task_instance(
+ task_id="ti_trigger_on_unhealthy_triggerer",
+ execution_date=time_now + datetime.timedelta(hours=1),
+ run_id="trigger_on_unhealthy_triggerer_run_id",
+ )
+ ti_trigger_on_unhealthy_triggerer.trigger_id =
trigger_on_unhealthy_triggerer.id
+ session.add(ti_trigger_on_unhealthy_triggerer)
trigger_on_killed_triggerer =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_killed_triggerer.id = 3
trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id
+ session.add(trigger_on_killed_triggerer)
+ ti_trigger_on_killed_triggerer = create_task_instance(
+ task_id="ti_trigger_on_killed_triggerer",
+ execution_date=time_now + datetime.timedelta(hours=2),
+ run_id="trigger_on_killed_triggerer_run_id",
+ )
+ ti_trigger_on_killed_triggerer.trigger_id = trigger_on_killed_triggerer.id
+ session.add(ti_trigger_on_killed_triggerer)
trigger_unassigned_to_triggerer =
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_unassigned_to_triggerer.id = 4
- assert trigger_unassigned_to_triggerer.triggerer_id is None
- session.add(trigger_on_healthy_triggerer)
- session.add(trigger_on_unhealthy_triggerer)
- session.add(trigger_on_killed_triggerer)
session.add(trigger_unassigned_to_triggerer)
+ ti_trigger_unassigned_to_triggerer = create_task_instance(
+ task_id="ti_trigger_unassigned_to_triggerer",
+ execution_date=time_now + datetime.timedelta(hours=3),
+ run_id="trigger_unassigned_to_triggerer_run_id",
+ )
+ ti_trigger_unassigned_to_triggerer.trigger_id =
trigger_unassigned_to_triggerer.id
+ session.add(ti_trigger_unassigned_to_triggerer)
+ assert trigger_unassigned_to_triggerer.triggerer_id is None
session.commit()
assert session.query(Trigger).count() == 4
Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
@@ -209,31 +237,101 @@ def test_assign_unassigned(session,
create_task_instance):
)
-def test_get_sorted_triggers(session, create_task_instance):
+def test_get_sorted_triggers_same_priority_weight(session,
create_task_instance):
"""
- Tests that triggers are sorted by the creation_date.
+ Tests that triggers are sorted by the creation_date if they have the same
priority.
"""
+ old_execution_date = datetime.datetime(
+ 2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+ )
trigger_old = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
- created_date=datetime.datetime(
- 2023, 5, 9, 12, 16, 14, 474415,
tzinfo=pytz.timezone("Africa/Abidjan")
- ),
+ created_date=old_execution_date + datetime.timedelta(seconds=30),
)
trigger_old.id = 1
+ session.add(trigger_old)
+ TI_old = create_task_instance(
+ task_id="old",
+ execution_date=old_execution_date,
+ run_id="old_run_id",
+ )
+ TI_old.priority_weight = 1
+ TI_old.trigger_id = trigger_old.id
+ session.add(TI_old)
+
+ new_execution_date = datetime.datetime(
+ 2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+ )
trigger_new = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
- created_date=datetime.datetime(
- 2023, 5, 9, 12, 17, 14, 474415,
tzinfo=pytz.timezone("Africa/Abidjan")
- ),
+ created_date=new_execution_date + datetime.timedelta(seconds=30),
)
trigger_new.id = 2
- session.add(trigger_old)
session.add(trigger_new)
+ TI_new = create_task_instance(
+ task_id="new",
+ execution_date=new_execution_date,
+ run_id="new_run_id",
+ )
+ TI_new.priority_weight = 1
+ TI_new.trigger_id = trigger_new.id
+ session.add(TI_new)
+
session.commit()
assert session.query(Trigger).count() == 2
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100,
alive_triggerer_ids=[], session=session)
assert trigger_ids_query == [(1,), (2,)]
+
+
+def test_get_sorted_triggers_different_priority_weights(session,
create_task_instance):
+ """
+ Tests that triggers are sorted by the priority_weight.
+ """
+ old_execution_date = datetime.datetime(
+ 2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+ )
+ trigger_old = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger",
+ kwargs={},
+ created_date=old_execution_date + datetime.timedelta(seconds=30),
+ )
+ trigger_old.id = 1
+ session.add(trigger_old)
+ TI_old = create_task_instance(
+ task_id="old",
+ execution_date=old_execution_date,
+ run_id="old_run_id",
+ )
+ TI_old.priority_weight = 1
+ TI_old.trigger_id = trigger_old.id
+ session.add(TI_old)
+
+ new_execution_date = datetime.datetime(
+ 2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+ )
+ trigger_new = Trigger(
+ classpath="airflow.triggers.testing.SuccessTrigger",
+ kwargs={},
+ created_date=new_execution_date + datetime.timedelta(seconds=30),
+ )
+ trigger_new.id = 2
+ session.add(trigger_new)
+ TI_new = create_task_instance(
+ task_id="new",
+ execution_date=new_execution_date,
+ run_id="new_run_id",
+ )
+ TI_new.priority_weight = 2
+ TI_new.trigger_id = trigger_new.id
+ session.add(TI_new)
+
+ session.commit()
+ assert session.query(Trigger).count() == 2
+
+ trigger_ids_query = Trigger.get_sorted_triggers(capacity=100,
alive_triggerer_ids=[], session=session)
+
+ assert trigger_ids_query == [(2,), (1,)]