This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 85fd0e1102 Optimise and migrate to SA2-compatible syntax for TaskReschedule (#33720)
85fd0e1102 is described below

commit 85fd0e1102d664337d6bb08e590979867372f61d
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Tue Oct 17 01:18:37 2023 +0400

    Optimise and migrate to SA2-compatible syntax for TaskReschedule (#33720)
    
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 airflow/models/taskinstance.py                     |  10 +-
 airflow/models/taskreschedule.py                   |  54 +++++-
 airflow/sensors/base.py                            |  16 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  11 +-
 tests/models/test_taskinstance.py                  |  35 ++--
 tests/sensors/test_base.py                         |  32 ++--
 tests/ti_deps/deps/test_ready_to_reschedule_dep.py | 208 ++++++++++-----------
 7 files changed, 212 insertions(+), 154 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 1ab748c551..62c6533188 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2080,9 +2080,13 @@ class TaskInstance(Base, LoggingMixin):
             # If the task continues after being deferred (next_method is set), 
use the original start_date
             self.start_date = self.start_date if self.next_method else 
timezone.utcnow()
             if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
-                task_reschedule: TR = TR.query_for_task_instance(self, 
session=session).first()
-                if task_reschedule:
-                    self.start_date = task_reschedule.start_date
+                tr_start_date = session.scalar(
+                    TR.stmt_for_task_instance(self, descending=False)
+                    .with_only_columns(TR.start_date)
+                    .limit(1)
+                )
+                if tr_start_date:
+                    self.start_date = tr_start_date
 
             # Secondly we find non-runnable but requeueable tis. We reset its 
state.
             # This is because we might have hit concurrency limits,
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 53fd163ed5..1305932d3b 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -18,12 +18,14 @@
 """TaskReschedule tracks rescheduled task instances."""
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING
 
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, 
asc, desc, event, text
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, 
asc, desc, event, select, text
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import relationship
 
+from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
@@ -32,6 +34,7 @@ if TYPE_CHECKING:
     import datetime
 
     from sqlalchemy.orm import Query, Session
+    from sqlalchemy.sql import Select
 
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
@@ -97,6 +100,38 @@ class TaskReschedule(Base):
         self.reschedule_date = reschedule_date
         self.duration = (self.end_date - self.start_date).total_seconds()
 
+    @classmethod
+    def stmt_for_task_instance(
+        cls,
+        ti: TaskInstance,
+        *,
+        try_number: int | None = None,
+        descending: bool = False,
+    ) -> Select:
+        """
+        Statement for task reschedules for a given the task instance.
+
+        :param ti: the task instance to find task reschedules for
+        :param descending: If True then records are returned in descending 
order
+        :param try_number: Look for TaskReschedule of the given try_number. 
Default is None which
+            looks for the same try_number of the given task_instance.
+        :meta private:
+        """
+        if try_number is None:
+            try_number = ti.try_number
+
+        return (
+            select(cls)
+            .where(
+                cls.dag_id == ti.dag_id,
+                cls.task_id == ti.task_id,
+                cls.run_id == ti.run_id,
+                cls.map_index == ti.map_index,
+                cls.try_number == try_number,
+            )
+            .order_by(desc(cls.id) if descending else asc(cls.id))
+        )
+
     @staticmethod
     @provide_session
     def query_for_task_instance(
@@ -106,7 +141,7 @@ class TaskReschedule(Base):
         try_number: int | None = None,
     ) -> Query:
         """
-        Return query for task reschedules for a given the task instance.
+        Return query for task reschedules for a given the task instance 
(deprecated).
 
         :param session: the database session object
         :param task_instance: the task instance to find task reschedules for
@@ -114,6 +149,12 @@ class TaskReschedule(Base):
         :param try_number: Look for TaskReschedule of the given try_number. 
Default is None which
             looks for the same try_number of the given task_instance.
         """
+        warnings.warn(
+            "Using this method is no longer advised, and it is expected to be 
removed in the future.",
+            category=RemovedInAirflow3Warning,
+            stacklevel=2,
+        )
+
         if try_number is None:
             try_number = task_instance.try_number
 
@@ -145,8 +186,13 @@ class TaskReschedule(Base):
         :param try_number: Look for TaskReschedule of the given try_number. 
Default is None which
             looks for the same try_number of the given task_instance.
         """
-        return TaskReschedule.query_for_task_instance(
-            task_instance, session=session, try_number=try_number
+        warnings.warn(
+            "Using this method is no longer advised, and it is expected to be 
removed in the future.",
+            category=RemovedInAirflow3Warning,
+            stacklevel=2,
+        )
+        return session.scalars(
+            TaskReschedule.stmt_for_task_instance(ti=task_instance, 
try_number=try_number, descending=False)
         ).all()
 
 
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 1a505f5de6..8bf40c913b 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -48,6 +48,7 @@ from airflow.utils import timezone
 # Google Provider before 3.0.0 imported apply_defaults from here.
 # See  https://github.com/apache/airflow/issues/16035
 from airflow.utils.decorators import apply_defaults  # noqa: F401
+from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -212,13 +213,16 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
             # If reschedule, use the start date of the first try (first try 
can be either the very
             # first execution of the task, or the first execution after the 
task was cleared.)
             first_try_number = context["ti"].max_tries - self.retries + 1
-            task_reschedules = TaskReschedule.find_for_task_instance(
-                context["ti"], try_number=first_try_number
-            )
-            if not task_reschedules:
+            with create_session() as session:
+                start_date = session.scalar(
+                    TaskReschedule.stmt_for_task_instance(
+                        context["ti"], try_number=first_try_number, 
descending=False
+                    )
+                    .with_only_columns(TaskReschedule.start_date)
+                    .limit(1)
+                )
+            if not start_date:
                 start_date = timezone.utcnow()
-            else:
-                start_date = task_reschedules[0].start_date
             started_at = start_date
 
             def run_duration() -> float:
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index 0eaa52c1eb..2e49b5bc48 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -71,12 +71,12 @@ class ReadyToRescheduleDep(BaseTIDep):
             )
             return
 
-        task_reschedule = (
-            TaskReschedule.query_for_task_instance(task_instance=ti, 
descending=True, session=session)
-            .with_entities(TaskReschedule.reschedule_date)
-            .first()
+        next_reschedule_date = session.scalar(
+            TaskReschedule.stmt_for_task_instance(ti, descending=True)
+            .with_only_columns(TaskReschedule.reschedule_date)
+            .limit(1)
         )
-        if not task_reschedule:
+        if not next_reschedule_date:
             # Because mapped sensors don't have the reschedule property, 
here's the last resort
             # and we need a slightly different passing reason
             if is_mapped:
@@ -86,7 +86,6 @@ class ReadyToRescheduleDep(BaseTIDep):
             return
 
         now = timezone.utcnow()
-        next_reschedule_date = task_reschedule.reschedule_date
         if now >= next_reschedule_date:
             yield self._passing_status(reason="Task instance id ready for 
reschedule.")
             return
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 0ad1da24b4..27ff39763f 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -108,6 +108,15 @@ def test_pool():
         session.rollback()
 
 
+@pytest.fixture
+def task_reschedules_for_ti():
+    def wrapper(ti):
+        with create_session() as session:
+            return 
session.scalars(TaskReschedule.stmt_for_task_instance(ti=ti, 
descending=False)).all()
+
+    return wrapper
+
+
 class CallbackWrapper:
     task_id: str | None = None
     dag_id: str | None = None
@@ -735,7 +744,7 @@ class TestTaskInstance:
         date = ti.next_retry_datetime()
         assert date == ti.end_date + datetime.timedelta(seconds=1)
 
-    def test_reschedule_handling(self, dag_maker):
+    def test_reschedule_handling(self, dag_maker, task_reschedules_for_ti):
         """
         Test that task reschedules are handled properly
         """
@@ -785,8 +794,7 @@ class TestTaskInstance:
             assert ti.start_date == expected_start_date
             assert ti.end_date == expected_end_date
             assert ti.duration == expected_duration
-            trs = TaskReschedule.find_for_task_instance(ti)
-            assert len(trs) == expected_task_reschedule_count
+            assert len(task_reschedules_for_ti(ti)) == 
expected_task_reschedule_count
 
         date1 = timezone.utcnow()
         date2 = date1 + datetime.timedelta(minutes=1)
@@ -833,7 +841,7 @@ class TestTaskInstance:
         done, fail = True, False
         run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
 
-    def test_mapped_reschedule_handling(self, dag_maker):
+    def test_mapped_reschedule_handling(self, dag_maker, 
task_reschedules_for_ti):
         """
         Test that mapped task reschedules are handled properly
         """
@@ -884,8 +892,7 @@ class TestTaskInstance:
             assert ti.start_date == expected_start_date
             assert ti.end_date == expected_end_date
             assert ti.duration == expected_duration
-            trs = TaskReschedule.find_for_task_instance(ti)
-            assert len(trs) == expected_task_reschedule_count
+            assert len(task_reschedules_for_ti(ti)) == 
expected_task_reschedule_count
 
         date1 = timezone.utcnow()
         date2 = date1 + datetime.timedelta(minutes=1)
@@ -933,7 +940,7 @@ class TestTaskInstance:
         run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
 
     @pytest.mark.usefixtures("test_pool")
-    def test_mapped_task_reschedule_handling_clear_reschedules(self, 
dag_maker):
+    def test_mapped_task_reschedule_handling_clear_reschedules(self, 
dag_maker, task_reschedules_for_ti):
         """
         Test that mapped task reschedules clearing are handled properly
         """
@@ -983,8 +990,7 @@ class TestTaskInstance:
             assert ti.start_date == expected_start_date
             assert ti.end_date == expected_end_date
             assert ti.duration == expected_duration
-            trs = TaskReschedule.find_for_task_instance(ti)
-            assert len(trs) == expected_task_reschedule_count
+            assert len(task_reschedules_for_ti(ti)) == 
expected_task_reschedule_count
 
         date1 = timezone.utcnow()
 
@@ -997,11 +1003,10 @@ class TestTaskInstance:
         assert ti.state == State.NONE
         assert ti._try_number == 0
         # Check that reschedules for ti have also been cleared.
-        trs = TaskReschedule.find_for_task_instance(ti)
-        assert not trs
+        assert not task_reschedules_for_ti(ti)
 
     @pytest.mark.usefixtures("test_pool")
-    def test_reschedule_handling_clear_reschedules(self, dag_maker):
+    def test_reschedule_handling_clear_reschedules(self, dag_maker, 
task_reschedules_for_ti):
         """
         Test that task reschedules clearing are handled properly
         """
@@ -1051,8 +1056,7 @@ class TestTaskInstance:
             assert ti.start_date == expected_start_date
             assert ti.end_date == expected_end_date
             assert ti.duration == expected_duration
-            trs = TaskReschedule.find_for_task_instance(ti)
-            assert len(trs) == expected_task_reschedule_count
+            assert len(task_reschedules_for_ti(ti)) == 
expected_task_reschedule_count
 
         date1 = timezone.utcnow()
 
@@ -1065,8 +1069,7 @@ class TestTaskInstance:
         assert ti.state == State.NONE
         assert ti._try_number == 0
         # Check that reschedules for ti have also been cleared.
-        trs = TaskReschedule.find_for_task_instance(ti)
-        assert not trs
+        assert not task_reschedules_for_ti(ti)
 
     def test_depends_on_past(self, dag_maker):
         with dag_maker(dag_id="test_depends_on_past"):
diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py
index 3be780b2e5..02e1f2ed5a 100644
--- a/tests/sensors/test_base.py
+++ b/tests/sensors/test_base.py
@@ -55,6 +55,7 @@ from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor impor
 from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, 
poke_mode_only
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
+from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from tests.test_utils import db
@@ -69,6 +70,15 @@ SENSOR_OP = "sensor_op"
 DEV_NULL = "dev/null"
 
 
+@pytest.fixture
+def task_reschedules_for_ti():
+    def wrapper(ti):
+        with create_session() as session:
+            return 
session.scalars(TaskReschedule.stmt_for_task_instance(ti=ti, 
descending=False)).all()
+
+    return wrapper
+
+
 class DummySensor(BaseSensorOperator):
     def __init__(self, return_value=False, **kwargs):
         super().__init__(**kwargs)
@@ -216,7 +226,7 @@ class TestBaseSensor:
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
 
-    def test_ok_with_reschedule(self, make_sensor, time_machine):
+    def test_ok_with_reschedule(self, make_sensor, time_machine, 
task_reschedules_for_ti):
         sensor, dr = make_sensor(return_value=None, poke_interval=10, 
timeout=25, mode="reschedule")
         sensor.poke = Mock(side_effect=[False, False, True])
 
@@ -233,7 +243,7 @@ class TestBaseSensor:
                 # verify task start date is the initial one
                 assert ti.start_date == date1
                 # verify one row in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 1
                 assert task_reschedules[0].start_date == date1
                 assert task_reschedules[0].reschedule_date == date1 + 
timedelta(seconds=sensor.poke_interval)
@@ -253,7 +263,7 @@ class TestBaseSensor:
                 # verify task start date is the initial one
                 assert ti.start_date == date1
                 # verify two rows in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 2
                 assert task_reschedules[1].start_date == date2
                 assert task_reschedules[1].reschedule_date == date2 + 
timedelta(seconds=sensor.poke_interval)
@@ -328,7 +338,7 @@ class TestBaseSensor:
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
 
-    def test_ok_with_reschedule_and_retry(self, make_sensor, time_machine):
+    def test_ok_with_reschedule_and_retry(self, make_sensor, time_machine, 
task_reschedules_for_ti):
         sensor, dr = make_sensor(
             return_value=None,
             poke_interval=10,
@@ -349,7 +359,7 @@ class TestBaseSensor:
             if ti.task_id == SENSOR_OP:
                 assert ti.state == State.UP_FOR_RESCHEDULE
                 # verify one row in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 1
                 assert task_reschedules[0].start_date == date1
                 assert task_reschedules[0].reschedule_date == date1 + 
timedelta(seconds=sensor.poke_interval)
@@ -382,7 +392,7 @@ class TestBaseSensor:
             if ti.task_id == SENSOR_OP:
                 assert ti.state == State.UP_FOR_RESCHEDULE
                 # verify one row in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 1
                 assert task_reschedules[0].start_date == date3
                 assert task_reschedules[0].reschedule_date == date3 + 
timedelta(seconds=sensor.poke_interval)
@@ -413,7 +423,7 @@ class TestBaseSensor:
         with pytest.raises(AirflowException):
             DummySensor(task_id="a", mode="foo")
 
-    def test_ok_with_custom_reschedule_exception(self, make_sensor):
+    def test_ok_with_custom_reschedule_exception(self, make_sensor, 
task_reschedules_for_ti):
         sensor, dr = make_sensor(return_value=None, mode="reschedule")
         date1 = timezone.utcnow()
         date2 = date1 + timedelta(seconds=60)
@@ -432,7 +442,7 @@ class TestBaseSensor:
                 # verify task is re-scheduled, i.e. state set to NONE
                 assert ti.state == State.UP_FOR_RESCHEDULE
                 # verify one row in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 1
                 assert task_reschedules[0].start_date == date1
                 assert task_reschedules[0].reschedule_date == date2
@@ -449,7 +459,7 @@ class TestBaseSensor:
                 # verify task is re-scheduled, i.e. state set to NONE
                 assert ti.state == State.UP_FOR_RESCHEDULE
                 # verify two rows in task_reschedule table
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 2
                 assert task_reschedules[1].start_date == date2
                 assert task_reschedules[1].reschedule_date == date3
@@ -467,7 +477,7 @@ class TestBaseSensor:
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
 
-    def test_reschedule_with_test_mode(self, make_sensor):
+    def test_reschedule_with_test_mode(self, make_sensor, 
task_reschedules_for_ti):
         sensor, dr = make_sensor(return_value=None, poke_interval=10, 
timeout=25, mode="reschedule")
         sensor.poke = Mock(side_effect=[False])
 
@@ -482,7 +492,7 @@ class TestBaseSensor:
                 # in test mode state is not modified
                 assert ti.state == State.NONE
                 # in test mode no reschedule request is recorded
-                task_reschedules = TaskReschedule.find_for_task_instance(ti)
+                task_reschedules = task_reschedules_for_ti(ti)
                 assert len(task_reschedules) == 0
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index c9da328925..1d56da8ff3 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -18,166 +18,158 @@
 from __future__ import annotations
 
 from datetime import timedelta
-from unittest.mock import Mock, patch
+from unittest.mock import patch
+
+import pytest
+import time_machine
+from slugify import slugify
 
-from airflow.models.dag import DAG
-from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
+from airflow.utils import timezone
+from airflow.utils.session import create_session
 from airflow.utils.state import State
-from airflow.utils.timezone import utcnow
+from tests.test_utils import db
 
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 
-class TestNotInReschedulePeriodDep:
-    def _get_task_instance(self, state):
-        dag = DAG("test_dag")
-        task = Mock(dag=dag, reschedule=True, is_mapped=False)
-        ti = TaskInstance(task=task, state=state, run_id=None)
-        return ti
 
-    def _get_task_reschedule(self, reschedule_date):
-        task = Mock(dag_id="test_dag", task_id="test_task", is_mapped=False)
-        reschedule = TaskReschedule(
-            task=task,
-            run_id=None,
-            try_number=None,
-            start_date=reschedule_date,
-            end_date=reschedule_date,
-            reschedule_date=reschedule_date,
-        )
-        return reschedule
+@pytest.fixture
+def not_expected_tr_db_call():
+    def side_effect(*args, **kwargs):
+        raise SystemError("Not expected DB call to `TaskReschedule` 
statement.")
+
+    with 
patch("airflow.models.taskreschedule.TaskReschedule.stmt_for_task_instance") as 
m:
+        m.side_effect = side_effect
+        yield m
 
-    def _get_mapped_task_instance(self, state):
-        dag = DAG("test_dag")
-        task = Mock(dag=dag, reschedule=True, is_mapped=True)
-        ti = TaskInstance(task=task, state=state, run_id=None)
-        return ti
 
-    def _get_mapped_task_reschedule(self, reschedule_date):
-        task = Mock(dag_id="test_dag", task_id="test_task", is_mapped=True)
-        reschedule = TaskReschedule(
-            task=task,
-            run_id=None,
-            try_number=None,
-            start_date=reschedule_date,
-            end_date=reschedule_date,
-            reschedule_date=reschedule_date,
+class TestNotInReschedulePeriodDep:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, request, create_task_instance):
+        db.clear_db_runs()
+        db.clear_rendered_ti_fields()
+
+        self.dag_id = f"dag_{slugify(request.cls.__name__)}"
+        self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
+        self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
+        self.ti_maker = create_task_instance
+
+        with time_machine.travel(DEFAULT_DATE, tick=False):
+            yield
+        db.clear_rendered_ti_fields()
+        db.clear_db_runs()
+
+    def _get_task_instance(self, state, *, map_index=-1):
+        """Helper which create fake task_instance"""
+        ti = self.ti_maker(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            run_id=self.run_id,
+            execution_date=DEFAULT_DATE,
+            map_index=map_index,
+            state=state,
         )
-        return reschedule
+        ti.task.reschedule = True
+        return ti
 
-    def test_should_pass_if_ignore_in_reschedule_period_is_set(self):
+    def _create_task_reschedule(self, ti, minutes: int | list[int]):
+        """Helper which create fake task_reschedule(s) from task_instance."""
+        if isinstance(minutes, int):
+            minutes = [minutes]
+        trs = []
+        for minutes_timedelta in minutes:
+            dt = ti.execution_date + timedelta(minutes=minutes_timedelta)
+            trs.append(
+                TaskReschedule(
+                    task=ti.task,
+                    run_id=ti.run_id,
+                    try_number=ti.try_number,
+                    map_index=ti.map_index,
+                    start_date=dt,
+                    end_date=dt,
+                    reschedule_date=dt,
+                )
+            )
+        with create_session() as session:
+            session.add_all(trs)
+            session.commit()
+
+    def test_should_pass_if_ignore_in_reschedule_period_is_set(self, 
not_expected_tr_db_call):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
         dep_context = DepContext(ignore_in_reschedule_period=True)
         assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)
 
-    def test_should_pass_if_not_reschedule_mode(self):
+    def test_should_pass_if_not_reschedule_mode(self, not_expected_tr_db_call):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
         del ti.task.reschedule
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    def test_should_pass_if_not_in_none_state(self):
+    def test_should_pass_if_not_in_none_state(self, not_expected_tr_db_call):
         ti = self._get_task_instance(State.UP_FOR_RETRY)
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_should_pass_if_no_reschedule_record_exists(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = []
+    def test_should_pass_if_no_reschedule_record_exists(self):
         ti = self._get_task_instance(State.NONE)
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_should_pass_after_reschedule_date_one(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = (
-            self._get_task_reschedule(utcnow() - timedelta(minutes=1))
-        )
+    def test_should_pass_after_reschedule_date_one(self):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
+        self._create_task_reschedule(ti, -1)
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_should_pass_after_reschedule_date_multiple(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = [
-            self._get_task_reschedule(utcnow() - timedelta(minutes=21)),
-            self._get_task_reschedule(utcnow() - timedelta(minutes=11)),
-            self._get_task_reschedule(utcnow() - timedelta(minutes=1)),
-        ][-1]
+    def test_should_pass_after_reschedule_date_multiple(self):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
+        self._create_task_reschedule(ti, [-21, -11, -1])
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_should_fail_before_reschedule_date_one(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = (
-            self._get_task_reschedule(utcnow() + timedelta(minutes=1))
-        )
-
+    def test_should_fail_before_reschedule_date_one(self):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
+        self._create_task_reschedule(ti, 1)
         assert not ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_should_fail_before_reschedule_date_multiple(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = [
-            self._get_task_reschedule(utcnow() - timedelta(minutes=19)),
-            self._get_task_reschedule(utcnow() - timedelta(minutes=9)),
-            self._get_task_reschedule(utcnow() + timedelta(minutes=1)),
-        ][-1]
+    def test_should_fail_before_reschedule_date_multiple(self):
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
+        self._create_task_reschedule(ti, [-19, -9, 1])
+        # Last TaskReschedule doesn't meet requirements
         assert not ReadyToRescheduleDep().is_met(ti=ti)
 
-    def 
test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self):
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def 
test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self, 
not_expected_tr_db_call):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
         dep_context = DepContext(ignore_in_reschedule_period=True)
         assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_task_should_pass_if_not_reschedule_mode(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = []
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def test_mapped_task_should_pass_if_not_reschedule_mode(self, 
not_expected_tr_db_call):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
         del ti.task.reschedule
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    def test_mapped_task_should_pass_if_not_in_none_state(self):
-        ti = self._get_mapped_task_instance(State.UP_FOR_RETRY)
+    def test_mapped_task_should_pass_if_not_in_none_state(self, 
not_expected_tr_db_call):
+        ti = self._get_task_instance(State.UP_FOR_RETRY, map_index=42)
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_should_pass_if_no_reschedule_record_exists(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = []
-        ti = self._get_mapped_task_instance(State.NONE)
+    def test_mapped_should_pass_if_no_reschedule_record_exists(self):
+        ti = self._get_task_instance(State.NONE, map_index=42)
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_should_pass_after_reschedule_date_one(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = (
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1))
-        )
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def test_mapped_should_pass_after_reschedule_date_one(self):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
+        self._create_task_reschedule(ti, [-1])
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_task_should_pass_after_reschedule_date_multiple(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = [
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=21)),
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=11)),
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)),
-        ][-1]
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def test_mapped_task_should_pass_after_reschedule_date_multiple(self):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
+        self._create_task_reschedule(ti, [-21, -11, -1])
         assert ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_task_should_fail_before_reschedule_date_one(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = (
-            self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1))
-        )
-
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def test_mapped_task_should_fail_before_reschedule_date_one(self):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
+        self._create_task_reschedule(ti, 1)
         assert not ReadyToRescheduleDep().is_met(ti=ti)
 
-    
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance")
-    def test_mapped_task_should_fail_before_reschedule_date_multiple(self, 
mock_query_for_task_instance):
-        
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value
 = [
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=19)),
-            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=9)),
-            self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)),
-        ][-1]
-        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+    def test_mapped_task_should_fail_before_reschedule_date_multiple(self):
+        ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42)
+        self._create_task_reschedule(ti, [-19, -9, 1])
+        # Last TaskReschedule doesn't meet requirements
         assert not ReadyToRescheduleDep().is_met(ti=ti)

Reply via email to