This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 04d3f4560f Order triggers by descending TI priority_weight when assigning unassigned triggers (#32318)
04d3f4560f is described below

commit 04d3f4560f3204eeb72a5ee14fb3316f5bdf35cb
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Aug 14 12:50:38 2023 +0200

    Order triggers by descending TI priority_weight when assigning unassigned triggers (#32318)
    
    * Order triggers by descending TI priority_weight when assigning unassigned triggers
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * Update airflow/models/trigger.py
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * Replace outer join with inner join and use coalesce to handle None values
    
    * fix unit tests
    
    ---------
    
    Signed-off-by: Hussein Awala <[email protected]>
    Co-authored-by: Tzu-ping Chung <[email protected]>
    Co-authored-by: eladkal <[email protected]>
---
 airflow/models/trigger.py        |   4 +-
 tests/jobs/test_triggerer_job.py |  18 +++++-
 tests/models/test_trigger.py     | 124 +++++++++++++++++++++++++++++++++++----
 3 files changed, 130 insertions(+), 16 deletions(-)

diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 6ba1d45852..55b49e45c5 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -22,6 +22,7 @@ from typing import Any, Iterable
 
 from sqlalchemy import Column, Integer, String, delete, func, or_, select, 
update
 from sqlalchemy.orm import Session, joinedload, relationship
+from sqlalchemy.sql.functions import coalesce
 
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.models.base import Base
@@ -244,8 +245,9 @@ class Trigger(Base):
     def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
         query = with_row_locks(
             select(cls.id)
+            .join(TaskInstance, cls.id == TaskInstance.trigger_id, 
isouter=False)
             .where(or_(cls.triggerer_id.is_(None), 
cls.triggerer_id.not_in(alive_triggerer_ids)))
-            .order_by(cls.created_date)
+            .order_by(coalesce(TaskInstance.priority_weight, 0).desc(), 
cls.created_date)
             .limit(capacity),
             session,
             skip_locked=True,
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 35ebe99b31..bd71c3c9ed 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -414,7 +414,7 @@ def test_trigger_create_race_condition_18392(session, 
tmp_path):
     assert len(instances) == 1
 
 
-def test_trigger_from_dead_triggerer(session):
+def test_trigger_from_dead_triggerer(session, create_task_instance):
     """
     Checks that the triggerer will correctly claim a Trigger that is assigned 
to a
     triggerer that does not exist.
@@ -425,6 +425,13 @@ def test_trigger_from_dead_triggerer(session):
     trigger_orm.id = 1
     trigger_orm.triggerer_id = 999  # Non-existent triggerer
     session.add(trigger_orm)
+    ti_orm = create_task_instance(
+        task_id="ti_orm",
+        execution_date=datetime.datetime.utcnow(),
+        run_id="orm_run_id",
+    )
+    ti_orm.trigger_id = trigger_orm.id
+    session.add(trigger_orm)
     session.commit()
     # Make a TriggererJobRunner and have it retrieve DB tasks
     job = Job()
@@ -434,7 +441,7 @@ def test_trigger_from_dead_triggerer(session):
     assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
 
 
-def test_trigger_from_expired_triggerer(session):
+def test_trigger_from_expired_triggerer(session, create_task_instance):
     """
     Checks that the triggerer will correctly claim a Trigger that is assigned 
to a
     triggerer that has an expired heartbeat.
@@ -445,6 +452,13 @@ def test_trigger_from_expired_triggerer(session):
     trigger_orm.id = 1
     trigger_orm.triggerer_id = 42
     session.add(trigger_orm)
+    ti_orm = create_task_instance(
+        task_id="ti_orm",
+        execution_date=datetime.datetime.utcnow(),
+        run_id="orm_run_id",
+    )
+    ti_orm.trigger_id = trigger_orm.id
+    session.add(trigger_orm)
     # Use a TriggererJobRunner with an expired heartbeat
     triggerer_job_orm = Job(TriggererJobRunner.job_type)
     triggerer_job_orm.id = 42
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index 98c570c935..3626c94670 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -171,19 +171,47 @@ def test_assign_unassigned(session, create_task_instance):
     trigger_on_healthy_triggerer = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
     trigger_on_healthy_triggerer.id = 1
     trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id
+    session.add(trigger_on_healthy_triggerer)
+    ti_trigger_on_healthy_triggerer = create_task_instance(
+        task_id="ti_trigger_on_healthy_triggerer",
+        execution_date=time_now,
+        run_id="trigger_on_healthy_triggerer_run_id",
+    )
+    ti_trigger_on_healthy_triggerer.trigger_id = 
trigger_on_healthy_triggerer.id
+    session.add(ti_trigger_on_healthy_triggerer)
     trigger_on_unhealthy_triggerer = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
     trigger_on_unhealthy_triggerer.id = 2
     trigger_on_unhealthy_triggerer.triggerer_id = unhealthy_triggerer.id
+    session.add(trigger_on_unhealthy_triggerer)
+    ti_trigger_on_unhealthy_triggerer = create_task_instance(
+        task_id="ti_trigger_on_unhealthy_triggerer",
+        execution_date=time_now + datetime.timedelta(hours=1),
+        run_id="trigger_on_unhealthy_triggerer_run_id",
+    )
+    ti_trigger_on_unhealthy_triggerer.trigger_id = 
trigger_on_unhealthy_triggerer.id
+    session.add(ti_trigger_on_unhealthy_triggerer)
     trigger_on_killed_triggerer = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
     trigger_on_killed_triggerer.id = 3
     trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id
+    session.add(trigger_on_killed_triggerer)
+    ti_trigger_on_killed_triggerer = create_task_instance(
+        task_id="ti_trigger_on_killed_triggerer",
+        execution_date=time_now + datetime.timedelta(hours=2),
+        run_id="trigger_on_killed_triggerer_run_id",
+    )
+    ti_trigger_on_killed_triggerer.trigger_id = trigger_on_killed_triggerer.id
+    session.add(ti_trigger_on_killed_triggerer)
     trigger_unassigned_to_triggerer = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
     trigger_unassigned_to_triggerer.id = 4
-    assert trigger_unassigned_to_triggerer.triggerer_id is None
-    session.add(trigger_on_healthy_triggerer)
-    session.add(trigger_on_unhealthy_triggerer)
-    session.add(trigger_on_killed_triggerer)
     session.add(trigger_unassigned_to_triggerer)
+    ti_trigger_unassigned_to_triggerer = create_task_instance(
+        task_id="ti_trigger_unassigned_to_triggerer",
+        execution_date=time_now + datetime.timedelta(hours=3),
+        run_id="trigger_unassigned_to_triggerer_run_id",
+    )
+    ti_trigger_unassigned_to_triggerer.trigger_id = 
trigger_unassigned_to_triggerer.id
+    session.add(ti_trigger_unassigned_to_triggerer)
+    assert trigger_unassigned_to_triggerer.triggerer_id is None
     session.commit()
     assert session.query(Trigger).count() == 4
     Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
@@ -209,31 +237,101 @@ def test_assign_unassigned(session, 
create_task_instance):
     )
 
 
-def test_get_sorted_triggers(session, create_task_instance):
+def test_get_sorted_triggers_same_priority_weight(session, 
create_task_instance):
     """
-    Tests that triggers are sorted by the creation_date.
+    Tests that triggers are sorted by the creation_date if they have the same 
priority.
     """
+    old_execution_date = datetime.datetime(
+        2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+    )
     trigger_old = Trigger(
         classpath="airflow.triggers.testing.SuccessTrigger",
         kwargs={},
-        created_date=datetime.datetime(
-            2023, 5, 9, 12, 16, 14, 474415, 
tzinfo=pytz.timezone("Africa/Abidjan")
-        ),
+        created_date=old_execution_date + datetime.timedelta(seconds=30),
     )
     trigger_old.id = 1
+    session.add(trigger_old)
+    TI_old = create_task_instance(
+        task_id="old",
+        execution_date=old_execution_date,
+        run_id="old_run_id",
+    )
+    TI_old.priority_weight = 1
+    TI_old.trigger_id = trigger_old.id
+    session.add(TI_old)
+
+    new_execution_date = datetime.datetime(
+        2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+    )
     trigger_new = Trigger(
         classpath="airflow.triggers.testing.SuccessTrigger",
         kwargs={},
-        created_date=datetime.datetime(
-            2023, 5, 9, 12, 17, 14, 474415, 
tzinfo=pytz.timezone("Africa/Abidjan")
-        ),
+        created_date=new_execution_date + datetime.timedelta(seconds=30),
     )
     trigger_new.id = 2
-    session.add(trigger_old)
     session.add(trigger_new)
+    TI_new = create_task_instance(
+        task_id="new",
+        execution_date=new_execution_date,
+        run_id="new_run_id",
+    )
+    TI_new.priority_weight = 1
+    TI_new.trigger_id = trigger_new.id
+    session.add(TI_new)
+
     session.commit()
     assert session.query(Trigger).count() == 2
 
     trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, 
alive_triggerer_ids=[], session=session)
 
     assert trigger_ids_query == [(1,), (2,)]
+
+
+def test_get_sorted_triggers_different_priority_weights(session, 
create_task_instance):
+    """
+    Tests that triggers are sorted by the priority_weight.
+    """
+    old_execution_date = datetime.datetime(
+        2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+    )
+    trigger_old = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        kwargs={},
+        created_date=old_execution_date + datetime.timedelta(seconds=30),
+    )
+    trigger_old.id = 1
+    session.add(trigger_old)
+    TI_old = create_task_instance(
+        task_id="old",
+        execution_date=old_execution_date,
+        run_id="old_run_id",
+    )
+    TI_old.priority_weight = 1
+    TI_old.trigger_id = trigger_old.id
+    session.add(TI_old)
+
+    new_execution_date = datetime.datetime(
+        2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
+    )
+    trigger_new = Trigger(
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        kwargs={},
+        created_date=new_execution_date + datetime.timedelta(seconds=30),
+    )
+    trigger_new.id = 2
+    session.add(trigger_new)
+    TI_new = create_task_instance(
+        task_id="new",
+        execution_date=new_execution_date,
+        run_id="new_run_id",
+    )
+    TI_new.priority_weight = 2
+    TI_new.trigger_id = trigger_new.id
+    session.add(TI_new)
+
+    session.commit()
+    assert session.query(Trigger).count() == 2
+
+    trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, 
alive_triggerer_ids=[], session=session)
+
+    assert trigger_ids_query == [(2,), (1,)]

Reply via email to