This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f7bece3da14879712a895e897dad339944170ce3
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Tue Jul 6 15:03:27 2021 +0100

    Add 'queued' state to DagRun (#16401)

    This change adds a queued state to DagRun. Newly created DagRuns start in
    the queued state and are moved to the running state once the DAG's
    max_active_runs limit allows it. If the DAG does not set max_active_runs,
    its DagRuns are moved to the running state immediately. Clearing a DagRun
    now sets its state back to queued.

    Closes: #9975, #16366
    (cherry picked from commit 6611ffd399dce0474d8329720de7e83f568df598)
---
 airflow/api_connexion/openapi/v1.yaml              |   4 +-
 airflow/jobs/scheduler_job.py                      | 149 ++--
 ...93827b8_add_queued_at_column_to_dagrun_table.py |  49 ++
 airflow/models/dag.py                              |   4 +-
 airflow/models/dagrun.py                           |  17 +-
 airflow/models/taskinstance.py                     |   6 +-
 airflow/www/static/js/tree.js                      |   4 +-
 airflow/www/views.py                               |   5 +-
 docs/apache-airflow/migrations-ref.rst             |   4 +-
 tests/api/common/experimental/test_mark_tasks.py   |   4 +-
 .../endpoints/test_dag_run_endpoint.py             |  14 +-
 tests/api_connexion/schemas/test_dag_run_schema.py |   3 +
 tests/dag_processing/test_processor.py             | 746 +++++++++++++++++++++
 tests/jobs/test_scheduler_job.py                   | 206 +++---
 tests/models/test_cleartasks.py                    |  37 +
 tests/models/test_dagrun.py                        |  25 +-
 tests/sensors/test_external_task_sensor.py         |   8 +-
 tests/utils/test_dag_processing.py                 |  11 +-
 18 files changed, 1037 insertions(+), 259 deletions(-)
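In plain Python, the behavioural core of this commit is a promotion pass that the
scheduler now runs each loop: count the RUNNING runs per DAG, then flip QUEUED runs
to RUNNING until max_active_runs is reached. The sketch below is illustrative only;
Run, start_queued_runs, and the max_active_runs dict are toy stand-ins for the
DagRun model, SchedulerJob._start_queued_dagruns, and the per-DAG setting, and the
real implementation works through database queries (see the scheduler_job.py hunks
below).

from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Optional

@dataclass
class Run:  # toy stand-in for the DagRun model
    dag_id: str
    state: str  # "queued", "running", ...
    start_date: Optional[datetime] = None

def start_queued_runs(runs: List[Run], max_active_runs: Dict[str, int]) -> None:
    # One pass to count running runs per DAG (the scheduler does this with a
    # single grouped query rather than one query per run).
    active = defaultdict(int)
    for run in runs:
        if run.state == "running":
            active[run.dag_id] += 1
    for run in runs:
        if run.state != "queued":
            continue
        limit = max_active_runs.get(run.dag_id, 0)
        if limit and active[run.dag_id] >= limit:
            continue  # at the cap: the run stays queued for a later loop
        active[run.dag_id] += 1
        run.state = "running"
        # start_date is only set on promotion, not on creation
        run.start_date = datetime.now(timezone.utc)

runs = [Run("d1", "running"), Run("d1", "queued"), Run("d1", "queued")]
start_queued_runs(runs, {"d1": 2})
assert [r.state for r in runs] == ["running", "running", "queued"]

This is also why, after this change, a DagRun's start_date can be NULL while the run
waits in the queue, which the "nullable: true" addition to the API spec below
accounts for.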
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 182f356..47fa3dc 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1853,6 +1853,7 @@ components:
         description: |
           The start time. The time when DAG run was actually created.
         readOnly: true
+        nullable: true
       end_date:
         type: string
         format: date-time
@@ -3025,8 +3026,9 @@ components:
       description: DAG State.
       type: string
       enum:
-        - success
+        - queued
         - running
+        - success
         - failed

     TriggerRule:
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index fe8e0b0..b7506b5 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -804,7 +804,7 @@ class SchedulerJob(BaseJob):
         """
         For all DAG IDs in the DagBag, look for task instances in the
         old_states and set them to new_state if the corresponding DagRun
-        does not exist or exists but is not in the running state. This
+        does not exist or exists but is not in the running or queued state. This
         normally should not happen, but it can if the state of DagRuns are
         changed manually.
@@ -821,7 +821,7 @@ class SchedulerJob(BaseJob):
            .filter(models.TaskInstance.state.in_(old_states))
            .filter(
                or_(
-                   models.DagRun.state != State.RUNNING,
+                   models.DagRun.state.notin_([State.RUNNING, State.QUEUED]),
                    models.DagRun.state.is_(None),
                )
            )
        )
@@ -1489,39 +1489,12 @@ class SchedulerJob(BaseJob):
             if settings.USE_JOB_SCHEDULE:
                 self._create_dagruns_for_dags(guard, session)

-            dag_runs = self._get_next_dagruns_to_examine(session)
+            self._start_queued_dagruns(session)
+            guard.commit()
+            dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session)
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
-            # TODO: This query is probably horribly inefficient (though there is an
-            # index on (dag_id,state)). It is to deal with the case when a user
-            # clears more than max_active_runs older tasks -- we don't want the
-            # scheduler to suddenly go and start running tasks from all of the
-            # runs. (AIRFLOW-137/GH #1442)
-            #
-            # The longer term fix would be to have `clear` do this, and put DagRuns
-            # in to the queued state, then take DRs out of queued before creating
-            # any new ones
-
-            # Build up a set of execution_dates that are "active" for a given
-            # dag_id -- only tasks from those runs will be scheduled.
-            active_runs_by_dag_id = defaultdict(set)
-
-            query = (
-                session.query(
-                    TI.dag_id,
-                    TI.execution_date,
-                )
-                .filter(
-                    TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
-                    TI.state.notin_(list(State.finished) + [State.REMOVED]),
-                )
-                .group_by(TI.dag_id, TI.execution_date)
-            )
-
-            for dag_id, execution_date in query:
-                active_runs_by_dag_id[dag_id].add(execution_date)
-
             for dag_run in dag_runs:
                 # Use try_except to not stop the Scheduler when a Serialized DAG is not found
                 # This takes care of Dynamic DAGs especially
@@ -1530,7 +1503,7 @@
                 # But this would take care of the scenario when the Scheduler is restarted after DagRun is
                 # created and the DAG is deleted / renamed
                 try:
-                    self._schedule_dag_run(dag_run, active_runs_by_dag_id.get(dag_run.dag_id, set()), session)
+                    self._schedule_dag_run(dag_run, session)
                 except SerializedDagNotFound:
                     self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                     continue
@@ -1570,9 +1543,9 @@
         return num_queued_tis

     @retry_db_transaction
-    def _get_next_dagruns_to_examine(self, session):
+    def _get_next_dagruns_to_examine(self, state, session):
         """Get Next DagRuns to Examine with retries"""
-        return DagRun.next_dagruns_to_examine(session)
+        return DagRun.next_dagruns_to_examine(state, session)

     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard, session):
@@ -1593,7 +1566,7 @@
         # as DagModel.dag_id and DagModel.next_dagrun
         # This list is used to verify if the DagRun already exist so that we don't attempt to create
         # duplicate dag runs
-        active_dagruns = (
+        existing_dagruns = (
             session.query(DagRun.dag_id, DagRun.execution_date)
             .filter(
                 tuple_(DagRun.dag_id, DagRun.execution_date).in_(
@@ -1616,89 +1589,83 @@
             # are not updated.
             # We opted to check DagRun existence instead
             # of catching an Integrity error and rolling back the session i.e
-            # we need to run self._update_dag_next_dagruns if the Dag Run already exists or if we
+            # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
             # create a new one. This is so that in the next Scheduling loop we try to create new runs
             # instead of falling in a loop of Integrity Error.
-            if (dag.dag_id, dag_model.next_dagrun) not in active_dagruns:
-                run = dag.create_dagrun(
+            if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns:
+                dag.create_dagrun(
                     run_type=DagRunType.SCHEDULED,
                     execution_date=dag_model.next_dagrun,
-                    start_date=timezone.utcnow(),
-                    state=State.RUNNING,
+                    state=State.QUEUED,
                     external_trigger=False,
                     session=session,
                     dag_hash=dag_hash,
                     creating_job_id=self.id,
                 )
-
-            expected_start_date = dag.following_schedule(run.execution_date)
-            if expected_start_date:
-                schedule_delay = run.start_date - expected_start_date
-                Stats.timing(
-                    f'dagrun.schedule_delay.{dag.dag_id}',
-                    schedule_delay,
-                )
-
-        self._update_dag_next_dagruns(dag_models, session)
+            dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
+                dag_model.next_dagrun
+            )

         # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
         # memory for larger dags?
or expunge_all() - def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Session) -> None: - """ - Bulk update the next_dagrun and next_dagrun_create_after for all the dags. + def _start_queued_dagruns( + self, + session: Session, + ) -> int: + """Find DagRuns in queued state and decide moving them to running state""" + dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session) - We batch the select queries to get info about all the dags at once - """ - # Check max_active_runs, to see if we are _now_ at the limit for any of - # these dag? (we've just created a DagRun for them after all) - active_runs_of_dags = dict( + active_runs_of_dags = defaultdict( + lambda: 0, session.query(DagRun.dag_id, func.count('*')) - .filter( - DagRun.dag_id.in_([o.dag_id for o in dag_models]), + .filter( # We use `list` here because SQLA doesn't accept a set + # We use set to avoid duplicate dag_ids + DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})), DagRun.state == State.RUNNING, - DagRun.external_trigger.is_(False), ) .group_by(DagRun.dag_id) - .all() + .all(), ) - for dag_model in dag_models: - # Get the DAG in a try_except to not stop the Scheduler when a Serialized DAG is not found - # This takes care of Dynamic DAGs especially + def _update_state(dag_run): + dag_run.state = State.RUNNING + dag_run.start_date = timezone.utcnow() + expected_start_date = dag.following_schedule(dag_run.execution_date) + if expected_start_date: + schedule_delay = dag_run.start_date - expected_start_date + Stats.timing( + f'dagrun.schedule_delay.{dag.dag_id}', + schedule_delay, + ) + + for dag_run in dag_runs: try: - dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) except SerializedDagNotFound: - self.log.exception("DAG '%s' not found in serialized_dag table", dag_model.dag_id) + self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id) continue - active_runs_of_dag = active_runs_of_dags.get(dag.dag_id, 0) - if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: - self.log.info( - "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + active_runs = active_runs_of_dags[dag_run.dag_id] + if dag.max_active_runs and active_runs >= dag.max_active_runs: + self.log.debug( + "DAG %s already has %d active runs, not moving any more runs to RUNNING state %s", dag.dag_id, - active_runs_of_dag, - dag.max_active_runs, + active_runs, + dag_run.execution_date, ) - dag_model.next_dagrun_create_after = None else: - dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info( - dag_model.next_dagrun - ) + active_runs_of_dags[dag_run.dag_id] += 1 + _update_state(dag_run) def _schedule_dag_run( self, dag_run: DagRun, - currently_active_runs: Set[datetime.datetime], session: Session, ) -> int: """ Make scheduling decisions about an individual dag run - ``currently_active_runs`` is passed in so that a batch query can be - used to ask this for all dag runs in the batch, to avoid an n+1 query. - :param dag_run: The DagRun to schedule - :param currently_active_runs: Number of currently active runs of this DAG :return: Number of tasks scheduled """ dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) @@ -1725,9 +1692,6 @@ class SchedulerJob(BaseJob): session.flush() self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) - # Work out if we should allow creating a new DagRun now? 
- self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session) - callback_to_execute = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=dag.dag_id, @@ -1745,19 +1709,6 @@ class SchedulerJob(BaseJob): self.log.error("Execution date is in future: %s", dag_run.execution_date) return 0 - if dag.max_active_runs: - if ( - len(currently_active_runs) >= dag.max_active_runs - and dag_run.execution_date not in currently_active_runs - ): - self.log.info( - "DAG %s already has %d active runs, not queuing any tasks for run %s", - dag.dag_id, - len(currently_active_runs), - dag_run.execution_date, - ) - return 0 - self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) diff --git a/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py new file mode 100644 index 0000000..03caebc --- /dev/null +++ b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add queued_at column to dagrun table + +Revision ID: 97cdd93827b8 +Revises: a13f7613ad25 +Create Date: 2021-06-29 21:53:48.059438 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import mssql + +# revision identifiers, used by Alembic. 
+revision = '97cdd93827b8' +down_revision = 'a13f7613ad25' +branch_labels = None +depends_on = None + + +def upgrade(): + """Apply Add queued_at column to dagrun table""" + conn = op.get_bind() + if conn.dialect.name == "mssql": + op.add_column('dag_run', sa.Column('queued_at', mssql.DATETIME2(precision=6), nullable=True)) + else: + op.add_column('dag_run', sa.Column('queued_at', sa.DateTime(), nullable=True)) + + +def downgrade(): + """Unapply Add queued_at column to dagrun table""" + op.drop_column('dag_run', "queued_at") diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 13d69c2..a3d06db 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1153,7 +1153,7 @@ class DAG(LoggingMixin): confirm_prompt=False, include_subdags=True, include_parentdag=True, - dag_run_state: str = State.RUNNING, + dag_run_state: str = State.QUEUED, dry_run=False, session=None, get_tis=False, @@ -1369,7 +1369,7 @@ class DAG(LoggingMixin): confirm_prompt=False, include_subdags=True, include_parentdag=False, - dag_run_state=State.RUNNING, + dag_run_state=State.QUEUED, dry_run=False, ): all_tis = [] diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 5b8ac0c..c503ac4 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -61,12 +61,15 @@ class DagRun(Base, LoggingMixin): __tablename__ = "dag_run" + __NO_VALUE = object() + id = Column(Integer, primary_key=True) dag_id = Column(String(ID_LEN)) + queued_at = Column(UtcDateTime) execution_date = Column(UtcDateTime, default=timezone.utcnow) - start_date = Column(UtcDateTime, default=timezone.utcnow) + start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) - _state = Column('state', String(50), default=State.RUNNING) + _state = Column('state', String(50), default=State.QUEUED) run_id = Column(String(ID_LEN)) creating_job_id = Column(Integer) external_trigger = Column(Boolean, default=True) @@ -102,6 +105,7 @@ class DagRun(Base, LoggingMixin): self, dag_id: Optional[str] = None, run_id: Optional[str] = None, + queued_at: Optional[datetime] = __NO_VALUE, execution_date: Optional[datetime] = None, start_date: Optional[datetime] = None, external_trigger: Optional[bool] = None, @@ -118,6 +122,10 @@ class DagRun(Base, LoggingMixin): self.external_trigger = external_trigger self.conf = conf or {} self.state = state + if queued_at is self.__NO_VALUE: + self.queued_at = timezone.utcnow() if state == State.QUEUED else None + else: + self.queued_at = queued_at self.run_type = run_type self.dag_hash = dag_hash self.creating_job_id = creating_job_id @@ -140,6 +148,8 @@ class DagRun(Base, LoggingMixin): if self._state != state: self._state = state self.end_date = timezone.utcnow() if self._state in State.finished else None + if state == State.QUEUED: + self.queued_at = timezone.utcnow() @declared_attr def state(self): @@ -160,6 +170,7 @@ class DagRun(Base, LoggingMixin): @classmethod def next_dagruns_to_examine( cls, + state: str, session: Session, max_number: Optional[int] = None, ): @@ -180,7 +191,7 @@ class DagRun(Base, LoggingMixin): # TODO: Bake this query, it is run _A lot_ query = ( session.query(cls) - .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB) + .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB) .join( DagModel, DagModel.dag_id == cls.dag_id, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b99fa34..8d2578f 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -138,7 +138,7 @@ def 
clear_task_instances( session, activate_dag_runs=None, dag=None, - dag_run_state: Union[str, Literal[False]] = State.RUNNING, + dag_run_state: Union[str, Literal[False]] = State.QUEUED, ): """ Clears a set of task instances, but makes sure the running ones @@ -240,7 +240,9 @@ def clear_task_instances( ) for dr in drs: dr.state = dag_run_state - dr.start_date = timezone.utcnow() + dr.start_date = None + if dag_run_state == State.QUEUED: + dr.last_scheduling_decision = None class TaskInstanceKey(NamedTuple): diff --git a/airflow/www/static/js/tree.js b/airflow/www/static/js/tree.js index 4bf366a..d45c880 100644 --- a/airflow/www/static/js/tree.js +++ b/airflow/www/static/js/tree.js @@ -58,7 +58,9 @@ document.addEventListener('DOMContentLoaded', () => { const tree = d3.layout.tree().nodeSize([0, 25]); let nodes = tree.nodes(data); const nodeobj = {}; - const getActiveRuns = () => data.instances.filter((run) => run.state === 'running').length > 0; + const runActiveStates = ['queued', 'running']; + const getActiveRuns = () => data.instances + .filter((run) => runActiveStates.includes(run.state)).length > 0; const now = Date.now() / 1000; const devicePixelRatio = window.devicePixelRatio || 1; diff --git a/airflow/www/views.py b/airflow/www/views.py index 09d27e0..9af5e1b 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1526,7 +1526,7 @@ class Airflow(AirflowBaseView): dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=execution_date, - state=State.RUNNING, + state=State.QUEUED, conf=run_conf, external_trigger=True, dag_hash=current_app.dag_bag.dags_hash.get(dag_id), @@ -3451,6 +3451,7 @@ class DagRunModelView(AirflowModelView): 'execution_date', 'run_id', 'run_type', + 'queued_at', 'start_date', 'end_date', 'external_trigger', @@ -3786,7 +3787,7 @@ class TaskInstanceModelView(AirflowModelView): lazy_gettext('Clear'), lazy_gettext( 'Are you sure you want to clear the state of the selected task' - ' instance(s) and set their dagruns to the running state?' + ' instance(s) and set their dagruns to the QUEUED state?' 
), single=False, ) diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 0c663da..0af143f 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -23,9 +23,7 @@ Here's the list of all the Database Migrations that are executed via when you ru +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ -| ``e9304a3141f0`` (head) | ``83f031fd9f1c`` | | Make XCom primary key columns non-nullable | -+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ -| ``83f031fd9f1c`` | ``a13f7613ad25`` | | Improve MSSQL compatibility | +| ``97cdd93827b8`` (head) | ``a13f7613ad25`` | ``2.1.3`` | Add ``queued_at`` column in ``dag_run`` table | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ | ``a13f7613ad25`` | ``e165e7455d70`` | ``2.1.0`` | Resource based permissions for default ``Flask-AppBuilder`` views | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py index 4dab57e..49008d3 100644 --- a/tests/api/common/experimental/test_mark_tasks.py +++ b/tests/api/common/experimental/test_mark_tasks.py @@ -414,7 +414,9 @@ class TestMarkDAGRun(unittest.TestCase): assert ti.state == state def _create_test_dag_run(self, state, date): - return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date) + return self.dag1.create_dagrun( + run_type=DagRunType.MANUAL, state=state, start_date=date, execution_date=date + ) def _verify_dag_run_state(self, dag, date, state): drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date) diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index e51eca8..0aa13b2 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -90,7 +90,7 @@ class TestDagRunEndpoint: def teardown_method(self) -> None: clear_db_runs() - # clear_db_dags() + clear_db_dags() def _create_dag(self, dag_id): dag_instance = DagModel(dag_id=dag_id) @@ -118,6 +118,7 @@ class TestDagRunEndpoint: execution_date=timezone.parse(self.default_time_2), start_date=timezone.parse(self.default_time), external_trigger=True, + state=state, ) dag_runs.append(dagrun_model_2) if extra_dag: @@ -131,6 +132,7 @@ class TestDagRunEndpoint: execution_date=timezone.parse(self.default_time_2), start_date=timezone.parse(self.default_time), external_trigger=True, + state=state, ) ) if commit: @@ -193,6 +195,7 @@ class TestGetDagRun(TestDagRunEndpoint): execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), external_trigger=True, + state='running', ) session.add(dagrun_model) session.commit() @@ -532,7 +535,7 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint): ( 
f"api/v1/dags/TEST_DAG_ID/dagRuns?end_date_lte=" f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}", - ["TEST_DAG_RUN_ID_1"], + ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"], ), ] ) @@ -750,6 +753,7 @@ class TestGetDagRunBatchPagination(TestDagRunEndpoint): DagRun( dag_id="TEST_DAG_ID", run_id="TEST_DAG_RUN_ID" + str(i), + state='running', run_type=DagRunType.MANUAL, execution_date=timezone.parse(self.default_time) + timedelta(minutes=i), start_date=timezone.parse(self.default_time), @@ -884,7 +888,7 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint): ), ( {"end_date_lte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"}, - ["TEST_DAG_RUN_ID_1"], + ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"], ), ] ) @@ -927,8 +931,8 @@ class TestPostDagRun(TestDagRunEndpoint): "end_date": None, "execution_date": response.json["execution_date"], "external_trigger": True, - "start_date": response.json["start_date"], - "state": "running", + "start_date": None, + "state": "queued", } == response.json @parameterized.expand( diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py index 3e6bf2e..9e4a9e8 100644 --- a/tests/api_connexion/schemas/test_dag_run_schema.py +++ b/tests/api_connexion/schemas/test_dag_run_schema.py @@ -49,6 +49,7 @@ class TestDAGRunSchema(TestDAGRunBase): def test_serialize(self, session): dagrun_model = DagRun( run_id="my-dag-run", + state='running', run_type=DagRunType.MANUAL.value, execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), @@ -124,6 +125,7 @@ class TestDagRunCollection(TestDAGRunBase): def test_serialize(self, session): dagrun_model_1 = DagRun( run_id="my-dag-run", + state='running', execution_date=timezone.parse(self.default_time), run_type=DagRunType.MANUAL.value, start_date=timezone.parse(self.default_time), @@ -131,6 +133,7 @@ class TestDagRunCollection(TestDAGRunBase): ) dagrun_model_2 = DagRun( run_id="my-dag-run-2", + state='running', execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), run_type=DagRunType.MANUAL.value, diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py new file mode 100644 index 0000000..9425bbb --- /dev/null +++ b/tests/dag_processing/test_processor.py @@ -0,0 +1,746 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import datetime +import os +import unittest +from datetime import timedelta +from tempfile import NamedTemporaryFile +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest +from parameterized import parameterized + +from airflow import settings +from airflow.configuration import conf +from airflow.dag_processing.processor import DagFileProcessor +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance +from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.taskinstance import SimpleTaskInstance +from airflow.operators.bash import BashOperator +from airflow.operators.dummy import DummyOperator +from airflow.serialization.serialized_objects import SerializedDAG +from airflow.utils import timezone +from airflow.utils.callback_requests import TaskCallbackRequest +from airflow.utils.dates import days_ago +from airflow.utils.session import create_session +from airflow.utils.state import State +from airflow.utils.types import DagRunType +from tests.test_utils.config import conf_vars, env_vars +from tests.test_utils.db import ( + clear_db_dags, + clear_db_import_errors, + clear_db_jobs, + clear_db_pools, + clear_db_runs, + clear_db_serialized_dags, + clear_db_sla_miss, +) +from tests.test_utils.mock_executor import MockExecutor + +DEFAULT_DATE = timezone.datetime(2016, 1, 1) + + +@pytest.fixture(scope="class") +def disable_load_example(): + with conf_vars({('core', 'load_examples'): 'false'}): + with env_vars({('core', 'load_examples'): 'false'}): + yield + + +@pytest.mark.usefixtures("disable_load_example") +class TestDagFileProcessor(unittest.TestCase): + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_pools() + clear_db_dags() + clear_db_sla_miss() + clear_db_import_errors() + clear_db_jobs() + clear_db_serialized_dags() + + def setUp(self): + self.clean_db() + + # Speed up some tests by not running the tasks, just look at what we + # enqueue! + self.null_exec = MockExecutor() + self.scheduler_job = None + + def tearDown(self) -> None: + if self.scheduler_job and self.scheduler_job.processor_agent: + self.scheduler_job.processor_agent.end() + self.scheduler_job = None + self.clean_db() + + def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs): + dag = DAG( + dag_id='test_scheduler_reschedule', + start_date=start_date, + # Make sure it only creates a single DAG Run + end_date=end_date, + ) + dag.clear() + dag.is_subdag = False + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) + session.merge(orm_dag) + session.commit() + return dag + + @classmethod + def setUpClass(cls): + # Ensure the DAGs we are looking at from the DB are up-to-date + non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False) + non_serialized_dagbag.sync_to_db() + cls.dagbag = DagBag(read_dags_from_db=True) + + def test_dag_file_processor_sla_miss_callback(self): + """ + Test that the dag file processor calls the sla miss callback + """ + session = settings.Session() + + sla_callback = MagicMock() + + # Create dag with a start of 1 day ago, but an sla of 0 + # so we'll already have an sla_miss on the books. 
+ test_start_date = days_ago(1) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta()}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + + assert sla_callback.called + + def test_dag_file_processor_sla_miss_callback_invalid_sla(self): + """ + Test that the dag file processor does not call the sla miss callback when + given an invalid sla + """ + session = settings.Session() + + sla_callback = MagicMock() + + # Create dag with a start of 1 day ago, but an sla of 0 + # so we'll already have an sla_miss on the books. + # Pass anything besides a timedelta object to the sla argument. + test_start_date = days_ago(1) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': None}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + sla_callback.assert_not_called() + + def test_dag_file_processor_sla_miss_callback_sent_notification(self): + """ + Test that the dag file processor does not call the sla_miss_callback when a + notification has already been sent + """ + session = settings.Session() + + # Mock the callback function so we can verify that it was not called + sla_callback = MagicMock() + + # Create dag with a start of 2 days ago, but an sla of 1 day + # ago so we'll already have an sla_miss on the books + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + # Create a TaskInstance for two days ago + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge( + SlaMiss( + task_id='dummy', + dag_id='test_sla_miss', + execution_date=test_start_date, + email_sent=False, + notification_sent=True, + ) + ) + + # Now call manage_slas and see if the sla_miss callback gets called + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + + sla_callback.assert_not_called() + + def test_dag_file_processor_sla_miss_callback_exception(self): + """ + Test that the dag file processor gracefully logs an exception if there is a problem + calling the sla_miss_callback + """ + session = settings.Session() + + sla_callback = MagicMock(side_effect=RuntimeError('Could not call function')) + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1)) + + 
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + # Now call manage_slas and see if the sla_miss callback gets called + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor.manage_slas(dag=dag, session=session) + assert sla_callback.called + mock_log.exception.assert_called_once_with( + 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss' + ) + + @mock.patch('airflow.dag_processing.processor.send_email') + def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email): + session = settings.Session() + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + email1 = 'te...@test.com' + task = DummyOperator( + task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + email2 = 'te...@test.com' + DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2) + + session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + + dag_file_processor.manage_slas(dag=dag, session=session) + + assert len(mock_send_email.call_args_list) == 1 + + send_email_to = mock_send_email.call_args_list[0][0][0] + assert email1 in send_email_to + assert email2 not in send_email_to + + @mock.patch('airflow.dag_processing.processor.Stats.incr') + @mock.patch("airflow.utils.email.send_email") + def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr): + """ + Test that the dag file processor gracefully logs an exception if there is a problem + sending an email + """ + session = settings.Session() + + # Mock the callback function so we can verify that it was not called + mock_send_email.side_effect = RuntimeError('Could not send an email') + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='t...@test.com', sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + + dag_file_processor.manage_slas(dag=dag, session=session) + mock_log.exception.assert_called_once_with( + 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss' + ) + mock_stats_incr.assert_called_once_with('sla_email_notification_failure') + + def test_dag_file_processor_sla_miss_deleted_task(self): + """ + Test that the dag file processor will not crash when trying to send + sla miss notification for a deleted task + """ + session = settings.Session() + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + 
task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='t...@test.com', sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge( + SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date) + ) + + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor.manage_slas(dag=dag, session=session) + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances(self, state, start_date, end_date): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. + """ + dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, session) + assert count == 1 + + session.refresh(ti) + assert ti.state == State.SCHEDULED + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances_with_task_concurrency( + self, + state, + start_date, + end_date, + ): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. 
+ """ + dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, session) + assert count == 1 + + session.refresh(ti) + assert ti.state == State.SCHEDULED + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. + """ + dag = DAG( + dag_id='test_scheduler_process_execute_task_depends_on_past', + start_date=DEFAULT_DATE, + default_args={ + 'depends_on_past': True, + }, + ) + BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi') + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + tis = dr.get_task_instances(session=session) + for ti in tis: + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, session) + assert count == 2 + + session.refresh(tis[0]) + session.refresh(tis[1]) + assert tis[0].state == State.SCHEDULED + assert tis[1].state == State.SCHEDULED + + def test_scheduler_job_add_new_task(self): + """ + Test if a task instance will be added if the dag is updated + """ + dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test') + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + + # Since we don't want to store the code for the DAG defined in this file + with mock.patch.object(settings, "STORE_DAG_CODE", False): + self.scheduler_job.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + if self.scheduler_job.processor_agent: + self.scheduler_job.processor_agent.end() + self.scheduler_job = 
SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session) + self.scheduler_job._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + tis = dr.get_task_instances() + assert len(tis) == 1 + + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') + SerializedDagModel.write_dag(dag=dag) + + scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) + session.flush() + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + tis = dr.get_task_instances() + assert len(tis) == 2 + + def test_runs_respected_after_clear(self): + """ + Test dag after dag.clear, max_active_runs is respected + """ + dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE) + dag.max_active_runs = 1 + + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi') + + session = settings.Session() + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + session.commit() + session.close() + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + # Write Dag to DB + dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False) + dagbag.bag_dag(dag, root_dag=dag) + dagbag.sync_to_db() + + dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag.dag_id) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + + date = DEFAULT_DATE + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.QUEUED, + ) + date = dag.following_schedule(date) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.QUEUED, + ) + date = dag.following_schedule(date) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.QUEUED, + ) + dag.clear() + + assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3 + + session = settings.Session() + self.scheduler_job._start_queued_dagruns(session) + session.commit() + # Assert that only 1 dagrun is active + assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 + # Assert that the other two are queued + assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 2 + + @patch.object(TaskInstance, 'handle_failure_with_callback') + def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): + dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + with create_session() as session: + session.query(TaskInstance).delete() + dag = dagbag.get_dag('example_branch_operator') + task = dag.get_task(task_id='run_this_first') + + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) + + session.add(ti) + session.commit() + + requests = [ + TaskCallbackRequest( + full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" + ) + ] + dag_file_processor.execute_callbacks(dagbag, requests) + mock_ti_handle_failure.assert_called_once_with( + error="Message", + test_mode=conf.getboolean('core', 'unit_test_mode'), + ) + + def test_process_file_should_failure_callback(self): + dag_file = 
os.path.join( + os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py' + ) + dagbag = DagBag(dag_folder=dag_file, include_examples=False) + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + with create_session() as session, NamedTemporaryFile(delete=False) as callback_file: + session.query(TaskInstance).delete() + dag = dagbag.get_dag('test_om_failure_callback_dag') + task = dag.get_task(task_id='test_om_failure_callback_task') + + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) + + session.add(ti) + session.commit() + + requests = [ + TaskCallbackRequest( + full_filepath=dag.full_filepath, + simple_task_instance=SimpleTaskInstance(ti), + msg="Message", + ) + ] + callback_file.close() + + with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}): + dag_file_processor.process_file(dag_file, requests) + with open(callback_file.name) as callback_file2: + content = callback_file2.read() + assert "Callback fired" == content + os.remove(callback_file.name) + + def test_should_mark_dummy_task_as_success(self): + dag_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' + ) + + # Write DAGs to dag and serialized_dag table + dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) + dagbag.sync_to_db() + + self.scheduler_job_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job_job.processor_agent = mock.MagicMock() + dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks") + + # Create DagRun + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + self.scheduler_job_job._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Schedule TaskInstances + self.scheduler_job_job._schedule_dag_run(dr, session) + with create_session() as session: + tis = session.query(TaskInstance).all() + + dags = self.scheduler_job_job.dagbag.dags.values() + assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags] + assert 5 == len(tis) + assert { + ('test_task_a', 'success'), + ('test_task_b', None), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + ('test_task_on_success', 'scheduled'), + } == {(ti.task_id, ti.state) for ti in tis} + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: + if state == 'success': + assert start_date is not None + assert end_date is not None + assert 0.0 == duration + else: + assert start_date is None + assert end_date is None + assert duration is None + + self.scheduler_job_job._schedule_dag_run(dr, session) + with create_session() as session: + tis = session.query(TaskInstance).all() + + assert 5 == len(tis) + assert { + ('test_task_a', 'success'), + ('test_task_b', 'success'), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + ('test_task_on_success', 'scheduled'), + } == {(ti.task_id, ti.state) for ti in tis} + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: + if state == 'success': + assert start_date is not None + assert end_date is not None + assert 0.0 == duration + else: + assert start_date is None + assert end_date is None + assert duration is None diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fe0b257..33c82d7 100644 --- a/tests/jobs/test_scheduler_job.py +++ 
b/tests/jobs/test_scheduler_job.py @@ -23,7 +23,6 @@ import shutil import unittest from datetime import timedelta from tempfile import NamedTemporaryFile, mkdtemp -from time import sleep from unittest import mock from unittest.mock import MagicMock, patch from zipfile import ZipFile @@ -1114,7 +1113,6 @@ class TestSchedulerJob(unittest.TestCase): session.flush() res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) - assert 2 == len(res) res_keys = map(lambda x: x.key, res) assert ti_no_dagrun.key in res_keys @@ -2259,15 +2257,16 @@ class TestSchedulerJob(unittest.TestCase): dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job._create_dag_runs([orm_dag], session) + self.scheduler_job._start_queued_dagruns(session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 dr = drs[0] - # Should not be able to create a new dag run, as we are at max active runs - assert orm_dag.next_dagrun_create_after is None + # This should have a value since we control max_active_runs + # by DagRun State. + assert orm_dag.next_dagrun_create_after # But we should record the date of _what run_ it would be assert isinstance(orm_dag.next_dagrun, datetime.datetime) @@ -2279,7 +2278,7 @@ class TestSchedulerJob(unittest.TestCase): self.scheduler_job.processor_agent = mock.Mock() self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock() - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) session.flush() session.refresh(dr) @@ -2336,7 +2335,7 @@ class TestSchedulerJob(unittest.TestCase): self.scheduler_job.processor_agent = mock.Mock() self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock() - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) session.flush() session.refresh(dr) @@ -2395,7 +2394,7 @@ class TestSchedulerJob(unittest.TestCase): ti = dr.get_task_instance('dummy') ti.set_state(state, session) - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) expected_callback = DagCallbackRequest( full_filepath=dr.dag.fileloc, @@ -2450,7 +2449,7 @@ class TestSchedulerJob(unittest.TestCase): ti = dr.get_task_instance('test_task') ti.set_state(state, session) - self.scheduler_job._schedule_dag_run(dr, set(), session) + self.scheduler_job._schedule_dag_run(dr, session) # Verify Callback is not set (i.e is None) when no callbacks are set on DAG self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None) @@ -2830,13 +2829,13 @@ class TestSchedulerJob(unittest.TestCase): execution_date=DEFAULT_DATE, state=State.RUNNING, ) - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) dr = dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr.execution_date), state=State.RUNNING, ) - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) task_instances_list = self.scheduler_job._executable_task_instances_to_queued( max_tis=32, session=session ) @@ -2887,7 +2886,7 @@ class TestSchedulerJob(unittest.TestCase): execution_date=date, state=State.RUNNING, ) - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) date = dag.following_schedule(date) task_instances_list = 
self.scheduler_job._executable_task_instances_to_queued( @@ -2950,7 +2949,7 @@ class TestSchedulerJob(unittest.TestCase): execution_date=date, state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, {}, session) + scheduler._schedule_dag_run(dr, session) date = dag_d1.following_schedule(date) date = DEFAULT_DATE @@ -2960,7 +2959,7 @@ class TestSchedulerJob(unittest.TestCase): execution_date=date, state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, {}, session) + scheduler._schedule_dag_run(dr, session) date = dag_d2.following_schedule(date) scheduler._executable_task_instances_to_queued(max_tis=2, session=session) @@ -3037,7 +3036,7 @@ class TestSchedulerJob(unittest.TestCase): execution_date=DEFAULT_DATE, state=State.RUNNING, ) - self.scheduler_job._schedule_dag_run(dr, {}, session) + self.scheduler_job._schedule_dag_run(dr, session) task_instances_list = self.scheduler_job._executable_task_instances_to_queued( max_tis=32, session=session @@ -3096,7 +3095,7 @@ class TestSchedulerJob(unittest.TestCase): # Verify that DagRun.verify_integrity is not called with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity: - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session) + scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) mock_verify_integrity.assert_not_called() session.flush() @@ -3159,7 +3158,7 @@ class TestSchedulerJob(unittest.TestCase): dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session) + scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) session.flush() assert scheduled_tis == 2 @@ -3871,14 +3870,13 @@ class TestSchedulerJob(unittest.TestCase): full_filepath=dag.fileloc, dag_id=dag_id ) - @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9)) - @mock.patch('airflow.jobs.scheduler_job.Stats.timing') - def test_create_dag_runs(self, stats_timing): + def test_create_dag_runs(self): """ Test various invariants of _create_dag_runs. 
- That the run created has the creating_job_id set - - That we emit the right DagRun metrics + - That the run created is on QUEUED State + - That dag_model has next_dagrun """ dag = DAG(dag_id='test_create_dag_runs', start_date=DEFAULT_DATE) @@ -3902,8 +3900,51 @@ class TestSchedulerJob(unittest.TestCase): with create_session() as session: self.scheduler_job._create_dag_runs([dag_model], session) + dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() + # Assert dr state is queued + assert dr.state == State.QUEUED + assert dr.start_date is None + + assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id + + @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9)) + @mock.patch('airflow.jobs.scheduler_job.Stats.timing') + def test_start_dagruns(self, stats_timing): + """ + Test that _start_dagrun: + + - moves runs to RUNNING State + - emit the right DagRun metrics + """ + dag = DAG(dag_id='test_start_dag_runs', start_date=DEFAULT_DATE) + + DummyOperator( + task_id='dummy', + dag=dag, + ) + + dagbag = DagBag( + dag_folder=os.devnull, + include_examples=False, + read_dags_from_db=True, + ) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + dag_model = DagModel.get_dagmodel(dag.dag_id) + + self.scheduler_job = SchedulerJob(executor=self.null_exec) + self.scheduler_job.processor_agent = mock.MagicMock() + + with create_session() as session: + self.scheduler_job._create_dag_runs([dag_model], session) + self.scheduler_job._start_queued_dagruns(session) + + dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() + # Assert dr state is running + assert dr.state == State.RUNNING + stats_timing.assert_called_once_with( - "dagrun.schedule_delay.test_create_dag_runs", datetime.timedelta(seconds=9) + "dagrun.schedule_delay.test_start_dag_runs", datetime.timedelta(seconds=9) ) assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id @@ -4102,61 +4143,7 @@ class TestSchedulerJob(unittest.TestCase): assert dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1) session.rollback() - def test_do_schedule_max_active_runs_upstream_failed(self): - """ - Test that tasks in upstream failed don't count as actively running. - - This test can be removed when adding a queued state to DagRuns. 
- """ - - with DAG( - dag_id='test_max_active_run_with_upstream_failed', - start_date=DEFAULT_DATE, - schedule_interval='@once', - max_active_runs=1, - ) as dag: - # Can't use DummyOperator as that goes straight to success - task1 = BashOperator(task_id='dummy1', bash_command='true') - - session = settings.Session() - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=True, - ) - - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db(session=session) - - run1 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.RUNNING, - session=session, - ) - - ti = run1.get_task_instance(task1.task_id, session) - ti.state = State.UPSTREAM_FAILED - - run2 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE + timedelta(hours=1), - state=State.RUNNING, - session=session, - ) - - dag.sync_to_db(session=session) # Update the date fields - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) - - num_queued = self.scheduler_job._do_scheduling(session) - - assert num_queued == 1 - ti = run2.get_task_instance(task1.task_id, session) - assert ti.state == State.QUEUED - + @conf_vars({('scheduler', 'use_job_schedule'): "false"}) def test_do_schedule_max_active_runs_dag_timed_out(self): """Test that tasks are set to a finished state when their DAG times out""" @@ -4189,33 +4176,36 @@ class TestSchedulerJob(unittest.TestCase): run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, state=State.RUNNING, + start_date=timezone.utcnow() - timedelta(seconds=2), session=session, ) + run1_ti = run1.get_task_instance(task1.task_id, session) run1_ti.state = State.RUNNING - sleep(1) - run2 = dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE + timedelta(seconds=10), - state=State.RUNNING, + state=State.QUEUED, session=session, ) dag.sync_to_db(session=session) - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.executor = MockExecutor() self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) - _ = self.scheduler_job._do_scheduling(session) - + self.scheduler_job._do_scheduling(session) + session.add(run1) + session.refresh(run1) assert run1.state == State.FAILED assert run1_ti.state == State.SKIPPED - assert run2.state == State.RUNNING - _ = self.scheduler_job._do_scheduling(session) + # Run scheduling again to assert run2 has started + self.scheduler_job._do_scheduling(session) + session.add(run2) + session.refresh(run2) + assert run2.state == State.RUNNING run2_ti = run2.get_task_instance(task1.task_id, session) assert run2_ti.state == State.QUEUED @@ -4265,8 +4255,8 @@ class TestSchedulerJob(unittest.TestCase): def test_do_schedule_max_active_runs_and_manual_trigger(self): """ - Make sure that when a DAG is already at max_active_runs, that manually triggering a run doesn't cause - the dag to "stall". + Make sure that when a DAG is already at max_active_runs, that manually triggered + dagruns don't start running. 
""" with DAG( @@ -4281,7 +4271,7 @@ class TestSchedulerJob(unittest.TestCase): task1 >> task2 - task3 = BashOperator(task_id='dummy3', bash_command='true') + BashOperator(task_id='dummy3', bash_command='true') session = settings.Session() dagbag = DagBag( @@ -4296,7 +4286,7 @@ class TestSchedulerJob(unittest.TestCase): dag_run = dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, - state=State.RUNNING, + state=State.QUEUED, session=session, ) @@ -4314,47 +4304,23 @@ class TestSchedulerJob(unittest.TestCase): assert num_queued == 2 assert dag_run.state == State.RUNNING - ti1 = dag_run.get_task_instance(task1.task_id, session) - assert ti1.state == State.QUEUED - - # Set task1 to success (so task2 can run) but keep task3 as "running" - ti1.state = State.SUCCESS - - ti3 = dag_run.get_task_instance(task3.task_id, session) - ti3.state = State.RUNNING - - session.flush() - - # At this point, ti2 and ti3 of the scheduled dag run should be running - num_queued = self.scheduler_job._do_scheduling(session) - - assert num_queued == 1 - # Should have queued task2 - ti2 = dag_run.get_task_instance(task2.task_id, session) - assert ti2.state == State.QUEUED - - ti2.state = None - session.flush() # Now that this one is running, manually trigger a dag. - manual_run = dag.create_dagrun( + dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE + timedelta(hours=1), - state=State.RUNNING, + state=State.QUEUED, session=session, ) session.flush() - num_queued = self.scheduler_job._do_scheduling(session) + self.scheduler_job._do_scheduling(session) - assert num_queued == 1 - # Should have queued task2 again. - ti2 = dag_run.get_task_instance(task2.task_id, session) - assert ti2.state == State.QUEUED - # Manual run shouldn't have been started, because we're at max_active_runs with DR1 - ti1 = manual_run.get_task_instance(task1.task_id, session) - assert ti1.state is None + # Assert that only 1 dagrun is active + assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 + # Assert that the other one is queued + assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 1 @pytest.mark.xfail(reason="Work out where this goes") diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index 9b8fbd0..4f64347 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -19,6 +19,8 @@ import datetime import unittest +from parameterized import parameterized + from airflow import settings from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances from airflow.operators.dummy import DummyOperator @@ -92,6 +94,41 @@ class TestClearTasks(unittest.TestCase): assert ti0.state is None assert ti0.external_executor_id is None + @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)]) + def test_clear_task_instances_dr_state(self, state, last_scheduling): + """Test that DR state is set to None after clear. + And that DR.last_scheduling_decision is handled OK. 
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index 9b8fbd0..4f64347 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -19,6 +19,8 @@
 import datetime
 import unittest
 
+from parameterized import parameterized
+
 from airflow import settings
 from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
 from airflow.operators.dummy import DummyOperator
@@ -92,6 +94,41 @@ class TestClearTasks(unittest.TestCase):
         assert ti0.state is None
         assert ti0.external_executor_id is None
 
+    @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)])
+    def test_clear_task_instances_dr_state(self, state, last_scheduling):
+        """Test that the DagRun state is set to the given dag_run_state after clear,
+        and that DR.last_scheduling_decision is handled correctly.
+        start_date is also set to None
+        """
+        dag = DAG(
+            'test_clear_task_instances',
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        )
+        task0 = DummyOperator(task_id='0', owner='test', dag=dag)
+        task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
+        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+        session = settings.Session()
+        dr = dag.create_dagrun(
+            execution_date=ti0.execution_date,
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+        dr.last_scheduling_decision = DEFAULT_DATE
+        session.add(dr)
+        session.commit()
+
+        ti0.run()
+        ti1.run()
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        clear_task_instances(qry, session, dag_run_state=state, dag=dag)
+
+        dr = ti0.get_dagrun()
+        assert dr.state == state
+        assert dr.start_date is None
+        assert dr.last_scheduling_decision == last_scheduling
+
     def test_clear_task_instances_without_task(self):
         dag = DAG(
             'test_clear_task_instances_without_task',
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 7899199..fac38bf 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -39,7 +39,7 @@
 from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
-from tests.test_utils.db import clear_db_jobs, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_jobs, clear_db_pools, clear_db_runs
@@ -50,6 +50,12 @@ class TestDagRun(unittest.TestCase):
     def setUp(self):
         clear_db_runs()
         clear_db_pools()
+        clear_db_dags()
+
+    def tearDown(self) -> None:
+        clear_db_runs()
+        clear_db_pools()
+        clear_db_dags()
 
     def create_dag_run(
         self,
@@ -102,7 +108,7 @@ class TestDagRun(unittest.TestCase):
         session.commit()
         ti0.refresh_from_db()
         dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
-        assert dr0.state == State.RUNNING
+        assert dr0.state == State.QUEUED
 
     def test_dagrun_find(self):
         session = settings.Session()
@@ -692,9 +698,11 @@ class TestDagRun(unittest.TestCase):
             ti.run()
             assert (ti.state == State.SUCCESS) == is_ti_success
 
-    def test_next_dagruns_to_examine_only_unpaused(self):
+    @parameterized.expand([(State.QUEUED,), (State.RUNNING,)])
+    def test_next_dagruns_to_examine_only_unpaused(self, state):
         """
         Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs
+        and only returns dagruns in the given state
         """
         dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
@@ -712,25 +720,22 @@ class TestDagRun(unittest.TestCase):
         session.flush()
         dr = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
-            state=State.RUNNING,
+            state=state,
             execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE if state == State.RUNNING else None,
             session=session,
         )
 
-        runs = DagRun.next_dagruns_to_examine(session).all()
+        runs = DagRun.next_dagruns_to_examine(state, session).all()
 
         assert runs == [dr]
 
         orm_dag.is_paused = True
         session.flush()
 
-        runs = DagRun.next_dagruns_to_examine(session).all()
+        runs = DagRun.next_dagruns_to_examine(state, session).all()
         assert runs == []
 
-        session.rollback()
-        session.close()
-
     @mock.patch.object(Stats, 'timing')
     def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock):
         """
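The test_next_dagruns_to_examine_only_unpaused changes above reflect the new per-state signature of DagRun.next_dagruns_to_examine. A minimal usage sketch, assuming a configured Airflow metadata database:

    from airflow.models.dagrun import DagRun
    from airflow.utils.session import create_session
    from airflow.utils.state import State

    with create_session() as session:
        # Queued runs are the candidates for promotion to RUNNING ...
        queued_runs = DagRun.next_dagruns_to_examine(State.QUEUED, session).all()
        # ... while running runs are the candidates for task scheduling.
        running_runs = DagRun.next_dagruns_to_examine(State.RUNNING, session).all()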
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 187fdb2..274766c 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -545,16 +545,16 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child):
     task_0 = dag_0.get_task("task_0")
     clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2)
 
-    # Assert that dagruns of all the affected dags are set to RUNNING after tasks are cleared.
+    # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
     # Unaffected dagruns should be left as SUCCESS.
     dagrun_0_1 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_1)
     dagrun_0_2 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_2)
     dagrun_1_1 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_1)
     dagrun_1_2 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_2)
 
-    assert dagrun_0_1.state == State.RUNNING
-    assert dagrun_0_2.state == State.RUNNING
-    assert dagrun_1_1.state == State.RUNNING
+    assert dagrun_0_1.state == State.QUEUED
+    assert dagrun_0_2.state == State.QUEUED
+    assert dagrun_1_1.state == State.QUEUED
     assert dagrun_1_2.state == State.SUCCESS
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index 58ad010..e38c184 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -35,7 +35,7 @@ from freezegun import freeze_time
 
 from airflow.configuration import conf
 from airflow.jobs.local_task_job import LocalTaskJob as LJ
-from airflow.jobs.scheduler_job import DagFileProcessorProcess
+from airflow.jobs.scheduler_job import DagFileProcessorProcess, SchedulerJob
 from airflow.models import DagBag, DagModel, TaskInstance as TI
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance
@@ -508,8 +508,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
         child_pipe.close()
         parent_pipe.close()
 
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.kill")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill")
     def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid):
         mock_pid.return_value = 1234
         manager = DagFileProcessorManager(
@@ -529,8 +529,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
         manager._kill_timed_out_processors()
         mock_kill.assert_called_once_with()
 
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess")
     def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid):
         mock_pid.return_value = 1234
         manager = DagFileProcessorManager(
@@ -560,7 +560,6 @@ class TestDagFileProcessorManager(unittest.TestCase):
         # We need to _actually_ parse the files here to test the behaviour.
         # Right now the parsing code lives in SchedulerJob, even though it's
         # called via utils.dag_processing.
-        from airflow.jobs.scheduler_job import SchedulerJob
 
         dag_id = 'exit_test_dag'
         dag_directory = TEST_DAG_FOLDER.parent / 'dags_with_system_exit'