This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

The following commit(s) were added to refs/heads/main by this push:
     new 85fd0e1102 Optimise and migrate to SA2-compatible syntax for TaskReschedule (#33720)
85fd0e1102 is described below

commit 85fd0e1102d664337d6bb08e590979867372f61d
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Tue Oct 17 01:18:37 2023 +0400

    Optimise and migrate to SA2-compatible syntax for TaskReschedule (#33720)

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 airflow/models/taskinstance.py                     |  10 +-
 airflow/models/taskreschedule.py                   |  54 +++++-
 airflow/sensors/base.py                            |  16 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  11 +-
 tests/models/test_taskinstance.py                  |  35 ++--
 tests/sensors/test_base.py                         |  32 ++--
 tests/ti_deps/deps/test_ready_to_reschedule_dep.py | 208 ++++++++++-----------
 7 files changed, 212 insertions(+), 154 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 1ab748c551..62c6533188 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2080,9 +2080,13 @@ class TaskInstance(Base, LoggingMixin):
         # If the task continues after being deferred (next_method is set), use the original start_date
         self.start_date = self.start_date if self.next_method else timezone.utcnow()
         if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
-            task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
-            if task_reschedule:
-                self.start_date = task_reschedule.start_date
+            tr_start_date = session.scalar(
+                TR.stmt_for_task_instance(self, descending=False)
+                .with_only_columns(TR.start_date)
+                .limit(1)
+            )
+            if tr_start_date:
+                self.start_date = tr_start_date
 
         # Secondly we find non-runnable but requeueable tis. We reset its state.
# This is because we might have hit concurrency limits, diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 53fd163ed5..1305932d3b 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -18,12 +18,14 @@ """TaskReschedule tracks rescheduled task instances.""" from __future__ import annotations +import warnings from typing import TYPE_CHECKING -from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text +from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, select, text from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import relationship +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime @@ -32,6 +34,7 @@ if TYPE_CHECKING: import datetime from sqlalchemy.orm import Query, Session + from sqlalchemy.sql import Select from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance @@ -97,6 +100,38 @@ class TaskReschedule(Base): self.reschedule_date = reschedule_date self.duration = (self.end_date - self.start_date).total_seconds() + @classmethod + def stmt_for_task_instance( + cls, + ti: TaskInstance, + *, + try_number: int | None = None, + descending: bool = False, + ) -> Select: + """ + Statement for task reschedules for a given the task instance. + + :param ti: the task instance to find task reschedules for + :param descending: If True then records are returned in descending order + :param try_number: Look for TaskReschedule of the given try_number. Default is None which + looks for the same try_number of the given task_instance. 
+ :meta private: + """ + if try_number is None: + try_number = ti.try_number + + return ( + select(cls) + .where( + cls.dag_id == ti.dag_id, + cls.task_id == ti.task_id, + cls.run_id == ti.run_id, + cls.map_index == ti.map_index, + cls.try_number == try_number, + ) + .order_by(desc(cls.id) if descending else asc(cls.id)) + ) + @staticmethod @provide_session def query_for_task_instance( @@ -106,7 +141,7 @@ class TaskReschedule(Base): try_number: int | None = None, ) -> Query: """ - Return query for task reschedules for a given the task instance. + Return query for task reschedules for a given the task instance (deprecated). :param session: the database session object :param task_instance: the task instance to find task reschedules for @@ -114,6 +149,12 @@ class TaskReschedule(Base): :param try_number: Look for TaskReschedule of the given try_number. Default is None which looks for the same try_number of the given task_instance. """ + warnings.warn( + "Using this method is no longer advised, and it is expected to be removed in the future.", + category=RemovedInAirflow3Warning, + stacklevel=2, + ) + if try_number is None: try_number = task_instance.try_number @@ -145,8 +186,13 @@ class TaskReschedule(Base): :param try_number: Look for TaskReschedule of the given try_number. Default is None which looks for the same try_number of the given task_instance. 
""" - return TaskReschedule.query_for_task_instance( - task_instance, session=session, try_number=try_number + warnings.warn( + "Using this method is no longer advised, and it is expected to be removed in the future.", + category=RemovedInAirflow3Warning, + stacklevel=2, + ) + return session.scalars( + TaskReschedule.stmt_for_task_instance(ti=task_instance, try_number=try_number, descending=False) ).all() diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 1a505f5de6..8bf40c913b 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -48,6 +48,7 @@ from airflow.utils import timezone # Google Provider before 3.0.0 imported apply_defaults from here. # See https://github.com/apache/airflow/issues/16035 from airflow.utils.decorators import apply_defaults # noqa: F401 +from airflow.utils.session import create_session if TYPE_CHECKING: from airflow.utils.context import Context @@ -212,13 +213,16 @@ class BaseSensorOperator(BaseOperator, SkipMixin): # If reschedule, use the start date of the first try (first try can be either the very # first execution of the task, or the first execution after the task was cleared.) 
first_try_number = context["ti"].max_tries - self.retries + 1 - task_reschedules = TaskReschedule.find_for_task_instance( - context["ti"], try_number=first_try_number - ) - if not task_reschedules: + with create_session() as session: + start_date = session.scalar( + TaskReschedule.stmt_for_task_instance( + context["ti"], try_number=first_try_number, descending=False + ) + .with_only_columns(TaskReschedule.start_date) + .limit(1) + ) + if not start_date: start_date = timezone.utcnow() - else: - start_date = task_reschedules[0].start_date started_at = start_date def run_duration() -> float: diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index 0eaa52c1eb..2e49b5bc48 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -71,12 +71,12 @@ class ReadyToRescheduleDep(BaseTIDep): ) return - task_reschedule = ( - TaskReschedule.query_for_task_instance(task_instance=ti, descending=True, session=session) - .with_entities(TaskReschedule.reschedule_date) - .first() + next_reschedule_date = session.scalar( + TaskReschedule.stmt_for_task_instance(ti, descending=True) + .with_only_columns(TaskReschedule.reschedule_date) + .limit(1) ) - if not task_reschedule: + if not next_reschedule_date: # Because mapped sensors don't have the reschedule property, here's the last resort # and we need a slightly different passing reason if is_mapped: @@ -86,7 +86,6 @@ class ReadyToRescheduleDep(BaseTIDep): return now = timezone.utcnow() - next_reschedule_date = task_reschedule.reschedule_date if now >= next_reschedule_date: yield self._passing_status(reason="Task instance id ready for reschedule.") return diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 0ad1da24b4..27ff39763f 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -108,6 +108,15 @@ def test_pool(): session.rollback() +@pytest.fixture +def 
task_reschedules_for_ti(): + def wrapper(ti): + with create_session() as session: + return session.scalars(TaskReschedule.stmt_for_task_instance(ti=ti, descending=False)).all() + + return wrapper + + class CallbackWrapper: task_id: str | None = None dag_id: str | None = None @@ -735,7 +744,7 @@ class TestTaskInstance: date = ti.next_retry_datetime() assert date == ti.end_date + datetime.timedelta(seconds=1) - def test_reschedule_handling(self, dag_maker): + def test_reschedule_handling(self, dag_maker, task_reschedules_for_ti): """ Test that task reschedules are handled properly """ @@ -785,8 +794,7 @@ class TestTaskInstance: assert ti.start_date == expected_start_date assert ti.end_date == expected_end_date assert ti.duration == expected_duration - trs = TaskReschedule.find_for_task_instance(ti) - assert len(trs) == expected_task_reschedule_count + assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count date1 = timezone.utcnow() date2 = date1 + datetime.timedelta(minutes=1) @@ -833,7 +841,7 @@ class TestTaskInstance: done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) - def test_mapped_reschedule_handling(self, dag_maker): + def test_mapped_reschedule_handling(self, dag_maker, task_reschedules_for_ti): """ Test that mapped task reschedules are handled properly """ @@ -884,8 +892,7 @@ class TestTaskInstance: assert ti.start_date == expected_start_date assert ti.end_date == expected_end_date assert ti.duration == expected_duration - trs = TaskReschedule.find_for_task_instance(ti) - assert len(trs) == expected_task_reschedule_count + assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count date1 = timezone.utcnow() date2 = date1 + datetime.timedelta(minutes=1) @@ -933,7 +940,7 @@ class TestTaskInstance: run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) @pytest.mark.usefixtures("test_pool") - def test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker): + def 
test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker, task_reschedules_for_ti): """ Test that mapped task reschedules clearing are handled properly """ @@ -983,8 +990,7 @@ class TestTaskInstance: assert ti.start_date == expected_start_date assert ti.end_date == expected_end_date assert ti.duration == expected_duration - trs = TaskReschedule.find_for_task_instance(ti) - assert len(trs) == expected_task_reschedule_count + assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count date1 = timezone.utcnow() @@ -997,11 +1003,10 @@ class TestTaskInstance: assert ti.state == State.NONE assert ti._try_number == 0 # Check that reschedules for ti have also been cleared. - trs = TaskReschedule.find_for_task_instance(ti) - assert not trs + assert not task_reschedules_for_ti(ti) @pytest.mark.usefixtures("test_pool") - def test_reschedule_handling_clear_reschedules(self, dag_maker): + def test_reschedule_handling_clear_reschedules(self, dag_maker, task_reschedules_for_ti): """ Test that task reschedules clearing are handled properly """ @@ -1051,8 +1056,7 @@ class TestTaskInstance: assert ti.start_date == expected_start_date assert ti.end_date == expected_end_date assert ti.duration == expected_duration - trs = TaskReschedule.find_for_task_instance(ti) - assert len(trs) == expected_task_reschedule_count + assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count date1 = timezone.utcnow() @@ -1065,8 +1069,7 @@ class TestTaskInstance: assert ti.state == State.NONE assert ti._try_number == 0 # Check that reschedules for ti have also been cleared. 
- trs = TaskReschedule.find_for_task_instance(ti) - assert not trs + assert not task_reschedules_for_ti(ti) def test_depends_on_past(self, dag_maker): with dag_maker(dag_id="test_depends_on_past"): diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 3be780b2e5..02e1f2ed5a 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -55,6 +55,7 @@ from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor impor from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone +from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timezone import datetime from tests.test_utils import db @@ -69,6 +70,15 @@ SENSOR_OP = "sensor_op" DEV_NULL = "dev/null" +@pytest.fixture +def task_reschedules_for_ti(): + def wrapper(ti): + with create_session() as session: + return session.scalars(TaskReschedule.stmt_for_task_instance(ti=ti, descending=False)).all() + + return wrapper + + class DummySensor(BaseSensorOperator): def __init__(self, return_value=False, **kwargs): super().__init__(**kwargs) @@ -216,7 +226,7 @@ class TestBaseSensor: if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_ok_with_reschedule(self, make_sensor, time_machine): + def test_ok_with_reschedule(self, make_sensor, time_machine, task_reschedules_for_ti): sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode="reschedule") sensor.poke = Mock(side_effect=[False, False, True]) @@ -233,7 +243,7 @@ class TestBaseSensor: # verify task start date is the initial one assert ti.start_date == date1 # verify one row in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 1 assert task_reschedules[0].start_date == date1 assert 
task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval) @@ -253,7 +263,7 @@ class TestBaseSensor: # verify task start date is the initial one assert ti.start_date == date1 # verify two rows in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 2 assert task_reschedules[1].start_date == date2 assert task_reschedules[1].reschedule_date == date2 + timedelta(seconds=sensor.poke_interval) @@ -328,7 +338,7 @@ class TestBaseSensor: if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_ok_with_reschedule_and_retry(self, make_sensor, time_machine): + def test_ok_with_reschedule_and_retry(self, make_sensor, time_machine, task_reschedules_for_ti): sensor, dr = make_sensor( return_value=None, poke_interval=10, @@ -349,7 +359,7 @@ class TestBaseSensor: if ti.task_id == SENSOR_OP: assert ti.state == State.UP_FOR_RESCHEDULE # verify one row in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 1 assert task_reschedules[0].start_date == date1 assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval) @@ -382,7 +392,7 @@ class TestBaseSensor: if ti.task_id == SENSOR_OP: assert ti.state == State.UP_FOR_RESCHEDULE # verify one row in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 1 assert task_reschedules[0].start_date == date3 assert task_reschedules[0].reschedule_date == date3 + timedelta(seconds=sensor.poke_interval) @@ -413,7 +423,7 @@ class TestBaseSensor: with pytest.raises(AirflowException): DummySensor(task_id="a", mode="foo") - def test_ok_with_custom_reschedule_exception(self, make_sensor): + def test_ok_with_custom_reschedule_exception(self, 
make_sensor, task_reschedules_for_ti): sensor, dr = make_sensor(return_value=None, mode="reschedule") date1 = timezone.utcnow() date2 = date1 + timedelta(seconds=60) @@ -432,7 +442,7 @@ class TestBaseSensor: # verify task is re-scheduled, i.e. state set to NONE assert ti.state == State.UP_FOR_RESCHEDULE # verify one row in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 1 assert task_reschedules[0].start_date == date1 assert task_reschedules[0].reschedule_date == date2 @@ -449,7 +459,7 @@ class TestBaseSensor: # verify task is re-scheduled, i.e. state set to NONE assert ti.state == State.UP_FOR_RESCHEDULE # verify two rows in task_reschedule table - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 2 assert task_reschedules[1].start_date == date2 assert task_reschedules[1].reschedule_date == date3 @@ -467,7 +477,7 @@ class TestBaseSensor: if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_reschedule_with_test_mode(self, make_sensor): + def test_reschedule_with_test_mode(self, make_sensor, task_reschedules_for_ti): sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode="reschedule") sensor.poke = Mock(side_effect=[False]) @@ -482,7 +492,7 @@ class TestBaseSensor: # in test mode state is not modified assert ti.state == State.NONE # in test mode no reschedule request is recorded - task_reschedules = TaskReschedule.find_for_task_instance(ti) + task_reschedules = task_reschedules_for_ti(ti) assert len(task_reschedules) == 0 if ti.task_id == DUMMY_OP: assert ti.state == State.NONE diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index c9da328925..1d56da8ff3 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ 
b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -18,166 +18,158 @@ from __future__ import annotations from datetime import timedelta -from unittest.mock import Mock, patch +from unittest.mock import patch + +import pytest +import time_machine +from slugify import slugify -from airflow.models.dag import DAG -from airflow.models.taskinstance import TaskInstance from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep +from airflow.utils import timezone +from airflow.utils.session import create_session from airflow.utils.state import State -from airflow.utils.timezone import utcnow +from tests.test_utils import db +DEFAULT_DATE = timezone.datetime(2016, 1, 1) -class TestNotInReschedulePeriodDep: - def _get_task_instance(self, state): - dag = DAG("test_dag") - task = Mock(dag=dag, reschedule=True, is_mapped=False) - ti = TaskInstance(task=task, state=state, run_id=None) - return ti - def _get_task_reschedule(self, reschedule_date): - task = Mock(dag_id="test_dag", task_id="test_task", is_mapped=False) - reschedule = TaskReschedule( - task=task, - run_id=None, - try_number=None, - start_date=reschedule_date, - end_date=reschedule_date, - reschedule_date=reschedule_date, - ) - return reschedule +@pytest.fixture +def not_expected_tr_db_call(): + def side_effect(*args, **kwargs): + raise SystemError("Not expected DB call to `TaskReschedule` statement.") + + with patch("airflow.models.taskreschedule.TaskReschedule.stmt_for_task_instance") as m: + m.side_effect = side_effect + yield m - def _get_mapped_task_instance(self, state): - dag = DAG("test_dag") - task = Mock(dag=dag, reschedule=True, is_mapped=True) - ti = TaskInstance(task=task, state=state, run_id=None) - return ti - def _get_mapped_task_reschedule(self, reschedule_date): - task = Mock(dag_id="test_dag", task_id="test_task", is_mapped=True) - reschedule = TaskReschedule( - task=task, 
- run_id=None, - try_number=None, - start_date=reschedule_date, - end_date=reschedule_date, - reschedule_date=reschedule_date, +class TestNotInReschedulePeriodDep: + @pytest.fixture(autouse=True) + def setup_test_cases(self, request, create_task_instance): + db.clear_db_runs() + db.clear_rendered_ti_fields() + + self.dag_id = f"dag_{slugify(request.cls.__name__)}" + self.task_id = f"task_{slugify(request.node.name, max_length=40)}" + self.run_id = f"run_{slugify(request.node.name, max_length=40)}" + self.ti_maker = create_task_instance + + with time_machine.travel(DEFAULT_DATE, tick=False): + yield + db.clear_rendered_ti_fields() + db.clear_db_runs() + + def _get_task_instance(self, state, *, map_index=-1): + """Helper which create fake task_instance""" + ti = self.ti_maker( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + execution_date=DEFAULT_DATE, + map_index=map_index, + state=state, ) - return reschedule + ti.task.reschedule = True + return ti - def test_should_pass_if_ignore_in_reschedule_period_is_set(self): + def _create_task_reschedule(self, ti, minutes: int | list[int]): + """Helper which create fake task_reschedule(s) from task_instance.""" + if isinstance(minutes, int): + minutes = [minutes] + trs = [] + for minutes_timedelta in minutes: + dt = ti.execution_date + timedelta(minutes=minutes_timedelta) + trs.append( + TaskReschedule( + task=ti.task, + run_id=ti.run_id, + try_number=ti.try_number, + map_index=ti.map_index, + start_date=dt, + end_date=dt, + reschedule_date=dt, + ) + ) + with create_session() as session: + session.add_all(trs) + session.commit() + + def test_should_pass_if_ignore_in_reschedule_period_is_set(self, not_expected_tr_db_call): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) dep_context = DepContext(ignore_in_reschedule_period=True) assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context) - def test_should_pass_if_not_reschedule_mode(self): + def test_should_pass_if_not_reschedule_mode(self, 
not_expected_tr_db_call): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) del ti.task.reschedule assert ReadyToRescheduleDep().is_met(ti=ti) - def test_should_pass_if_not_in_none_state(self): + def test_should_pass_if_not_in_none_state(self, not_expected_tr_db_call): ti = self._get_task_instance(State.UP_FOR_RETRY) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [] + def test_should_pass_if_no_reschedule_record_exists(self): ti = self._get_task_instance(State.NONE) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( - self._get_task_reschedule(utcnow() - timedelta(minutes=1)) - ) + def test_should_pass_after_reschedule_date_one(self): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) + self._create_task_reschedule(ti, -1) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ - self._get_task_reschedule(utcnow() - timedelta(minutes=21)), - self._get_task_reschedule(utcnow() - timedelta(minutes=11)), - self._get_task_reschedule(utcnow() - timedelta(minutes=1)), - ][-1] + def test_should_pass_after_reschedule_date_multiple(self): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) + self._create_task_reschedule(ti, [-21, -11, -1]) assert ReadyToRescheduleDep().is_met(ti=ti) - 
@patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( - self._get_task_reschedule(utcnow() + timedelta(minutes=1)) - ) - + def test_should_fail_before_reschedule_date_one(self): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) + self._create_task_reschedule(ti, 1) assert not ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ - self._get_task_reschedule(utcnow() - timedelta(minutes=19)), - self._get_task_reschedule(utcnow() - timedelta(minutes=9)), - self._get_task_reschedule(utcnow() + timedelta(minutes=1)), - ][-1] + def test_should_fail_before_reschedule_date_multiple(self): ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) + self._create_task_reschedule(ti, [-19, -9, 1]) + # Last TaskReschedule doesn't meet requirements assert not ReadyToRescheduleDep().is_met(ti=ti) - def test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self): - ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self, not_expected_tr_db_call): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) dep_context = DepContext(ignore_in_reschedule_period=True) assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_task_should_pass_if_not_reschedule_mode(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [] - ti = 
self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_task_should_pass_if_not_reschedule_mode(self, not_expected_tr_db_call): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) del ti.task.reschedule assert ReadyToRescheduleDep().is_met(ti=ti) - def test_mapped_task_should_pass_if_not_in_none_state(self): - ti = self._get_mapped_task_instance(State.UP_FOR_RETRY) + def test_mapped_task_should_pass_if_not_in_none_state(self, not_expected_tr_db_call): + ti = self._get_task_instance(State.UP_FOR_RETRY, map_index=42) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [] - ti = self._get_mapped_task_instance(State.NONE) + def test_mapped_should_pass_if_no_reschedule_record_exists(self): + ti = self._get_task_instance(State.NONE, map_index=42) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( - self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)) - ) - ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_should_pass_after_reschedule_date_one(self): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) + self._create_task_reschedule(ti, [-1]) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_task_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ - 
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=21)), - self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=11)), - self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)), - ][-1] - ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_task_should_pass_after_reschedule_date_multiple(self): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) + self._create_task_reschedule(ti, [-21, -11, -1]) assert ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_task_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( - self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)) - ) - - ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_task_should_fail_before_reschedule_date_one(self): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) + self._create_task_reschedule(ti, 1) assert not ReadyToRescheduleDep().is_met(ti=ti) - @patch("airflow.models.taskreschedule.TaskReschedule.query_for_task_instance") - def test_mapped_task_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ - self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=19)), - self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=9)), - self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)), - ][-1] - ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + def test_mapped_task_should_fail_before_reschedule_date_multiple(self): + ti = self._get_task_instance(State.UP_FOR_RESCHEDULE, map_index=42) + self._create_task_reschedule(ti, [-19, -9, 1]) + # Last TaskReschedule doesn't meet requirements assert not ReadyToRescheduleDep().is_met(ti=ti)