This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 74b67125f5c6116dc827988b43e10f619803b770
Author: Tanel Kiis <tan...@users.noreply.github.com>
AuthorDate: Tue Mar 22 19:30:37 2022 +0200

    Fix Tasks getting stuck in scheduled state (#19747)

    The scheduler_job can get stuck in a state where it is not able to queue new
    tasks. It will get out of this state on its own, but the time taken depends
    on the runtime of the current tasks - this could be several hours or even days.

    If the scheduler can't queue any tasks because of different concurrency limits
    (per pool, dag or task), then on the next iterations of the scheduler loop it
    will try to queue the same tasks. Meanwhile there could be some scheduled
    tasks with lower priority_weight that could be queued, but they will remain
    waiting.

    The proposed solution is to keep track of the dag and task ids that are
    concurrency limited and then repeat the query with these dags and tasks
    filtered out.

    Co-authored-by: Tanel Kiis <tanel.k...@reach-u.com>
    (cherry picked from commit cd68540ef19b36180fdd1ebe38435637586747d4)
---
 airflow/jobs/scheduler_job.py    | 339 ++++++++++++++++++++++++---------------
 tests/jobs/test_scheduler_job.py |  94 ++++++++++-
 2 files changed, 300 insertions(+), 133 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 490d507..391697c 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -27,7 +27,7 @@ import time import warnings from collections import defaultdict from datetime import timedelta -from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Tuple +from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple from sqlalchemy import and_, func, not_, or_, tuple_ from sqlalchemy.exc import OperationalError
@@ -259,54 +259,16 @@ class SchedulerJob(BaseJob): if pool_slots_free == 0: self.log.debug("All pools are full!") - return executable_tis + return [] max_tis = min(max_tis, pool_slots_free) - # Get all task instances associated with scheduled - # DagRuns which are not backfilled, in the given states, - # and the dag is not paused - query = ( - session.query(TI) - .join(TI.dag_run) - .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED) - .join(TI.dag_model) - .filter(not_(DM.is_paused)) - .filter(TI.state == State.SCHEDULED) - .options(selectinload('dag_model')) - .order_by(-TI.priority_weight, DR.execution_date) - ) - starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0] - if starved_pools: - query = query.filter(not_(TI.pool.in_(starved_pools))) - - query = query.limit(max_tis) - - task_instances_to_examine: List[TI] = with_row_locks( - query, - of=TI, - session=session, - **skip_locked(session=session), - ).all() - # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
- # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) - - if len(task_instances_to_examine) == 0: - self.log.debug("No tasks to consider for execution.") - return executable_tis - - # Put one task instance on each line - task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine) - self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str) - - pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list) - for task_instance in task_instances_to_examine: - pool_to_task_instances[task_instance.pool].append(task_instance) + starved_pools = {pool_name for pool_name, stats in pools.items() if stats['open'] <= 0} # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. - dag_max_active_tasks_map: DefaultDict[str, int] + dag_active_tasks_map: DefaultDict[str, int] task_concurrency_map: DefaultDict[Tuple[str, str], int] - dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps( + dag_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps( states=list(EXECUTION_STATES), session=session ) @@ -314,124 +276,237 @@ class SchedulerJob(BaseJob): # Number of tasks that cannot be scheduled because of no open slot in pool num_starving_tasks_total = 0 - # Go through each pool, and queue up a task for execution if there are - # any open slots in the pool. + # dag and task ids that can't be queued because of concurrency limits + starved_dags: Set[str] = set() + starved_tasks: Set[Tuple[str, str]] = set() - for pool, task_instances in pool_to_task_instances.items(): - pool_name = pool - if pool not in pools: - self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool) - continue + pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int) + + for loop_count in itertools.count(start=1): - open_slots = pools[pool]["open"] + num_starved_pools = len(starved_pools) + num_starved_dags = len(starved_dags) + num_starved_tasks = len(starved_tasks) - num_ready = len(task_instances) - self.log.info( - "Figuring out tasks to run in Pool(name=%s) with %s open slots " - "and %s task instances ready to be queued", - pool, - open_slots, - num_ready, + # Get task instances associated with scheduled + # DagRuns which are not backfilled, in the given states, + # and the dag is not paused + query = ( + session.query(TI) + .join(TI.dag_run) + .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING) + .join(TI.dag_model) + .filter(not_(DM.is_paused)) + .filter(TI.state == TaskInstanceState.SCHEDULED) + .options(selectinload('dag_model')) + .order_by(-TI.priority_weight, DR.execution_date) ) - priority_sorted_task_instances = sorted( - task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date) + if starved_pools: + query = query.filter(not_(TI.pool.in_(starved_pools))) + + if starved_dags: + query = query.filter(not_(TI.dag_id.in_(starved_dags))) + + if starved_tasks: + if settings.engine.dialect.name == 'mssql': + task_filter = or_( + and_( + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + ) + for (dag_id, task_id) in starved_tasks + ) + else: + task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks) + + query = query.filter(not_(task_filter)) + + query = query.limit(max_tis) + + task_instances_to_examine: List[TI] = with_row_locks( + query, + of=TI, + session=session, + **skip_locked(session=session), + ).all() + # TODO[HA]: This was wrong before anyway, as it only 
looked at a sub-set of dags, not everything. + # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) + + if len(task_instances_to_examine) == 0: + self.log.debug("No tasks to consider for execution.") + break + + # Put one task instance on each line + task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine) + self.log.info( + "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str ) - num_starving_tasks = 0 - for current_index, task_instance in enumerate(priority_sorted_task_instances): - if open_slots <= 0: - self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool) - # Can't schedule any more since there are no more open slots. - num_unhandled = len(priority_sorted_task_instances) - current_index - num_starving_tasks += num_unhandled - num_starving_tasks_total += num_unhandled - break - - # Check to make sure that the task max_active_tasks of the DAG hasn't been - # reached. - dag_id = task_instance.dag_id - - current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id] - max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks + pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list) + for task_instance in task_instances_to_examine: + pool_to_task_instances[task_instance.pool].append(task_instance) + + # Go through each pool, and queue up a task for execution if there are + # any open slots in the pool. + + for pool, task_instances in pool_to_task_instances.items(): + pool_name = pool + if pool not in pools: + self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool) + starved_pools.add(pool_name) + continue + + pool_total = pools[pool]["total"] + open_slots = pools[pool]["open"] + + num_ready = len(task_instances) self.log.info( - "DAG %s has %s/%s running and queued tasks", - dag_id, - current_max_active_tasks_per_dag, - max_active_tasks_per_dag_limit, + "Figuring out tasks to run in Pool(name=%s) with %s open slots " + "and %s task instances ready to be queued", + pool, + open_slots, + num_ready, ) - if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit: + + priority_sorted_task_instances = sorted( + task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date) + ) + + for current_index, task_instance in enumerate(priority_sorted_task_instances): + if open_slots <= 0: + self.log.info( + "Not scheduling since there are %s open slots in pool %s", open_slots, pool + ) + # Can't schedule any more since there are no more open slots. + num_unhandled = len(priority_sorted_task_instances) - current_index + pool_num_starving_tasks[pool_name] += num_unhandled + num_starving_tasks_total += num_unhandled + starved_pools.add(pool_name) + break + + if task_instance.pool_slots > pool_total: + self.log.warning( + "Not executing %s. 
Requested pool slots (%s) are greater than " + "total pool slots: '%s' for pool: %s.", + task_instance, + task_instance.pool_slots, + pool_total, + pool, + ) + + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) + continue + + if task_instance.pool_slots > open_slots: + self.log.info( + "Not executing %s since it requires %s slots " + "but there are %s open slots in the pool %s.", + task_instance, + task_instance.pool_slots, + open_slots, + pool, + ) + pool_num_starving_tasks[pool_name] += 1 + num_starving_tasks_total += 1 + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) + # Though we can execute tasks with lower priority if there's enough room + continue + + # Check to make sure that the task max_active_tasks of the DAG hasn't been + # reached. + dag_id = task_instance.dag_id + + current_active_tasks_per_dag = dag_active_tasks_map[dag_id] + max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks self.log.info( - "Not executing %s since the number of tasks running or queued " - "from DAG %s is >= to the DAG's max_active_tasks limit of %s", - task_instance, + "DAG %s has %s/%s running and queued tasks", dag_id, + current_active_tasks_per_dag, max_active_tasks_per_dag_limit, ) - continue - - task_concurrency_limit: Optional[int] = None - if task_instance.dag_model.has_task_concurrency_limits: - # Many dags don't have a task_concurrency, so where we can avoid loading the full - # serialized DAG the better. - serialized_dag = self.dagbag.get_dag(dag_id, session=session) - # If the dag is missing, fail the task and continue to the next task. - if not serialized_dag: - self.log.error( - "DAG '%s' for task instance %s not found in serialized_dag table", - dag_id, + if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit: + self.log.info( + "Not executing %s since the number of tasks running or queued " + "from DAG %s is >= to the DAG's max_active_tasks limit of %s", task_instance, + dag_id, + max_active_tasks_per_dag_limit, ) - session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( - {TI.state: State.FAILED}, synchronize_session='fetch' - ) + starved_dags.add(dag_id) continue - if serialized_dag.has_task(task_instance.task_id): - task_concurrency_limit = serialized_dag.get_task( - task_instance.task_id - ).max_active_tis_per_dag - - if task_concurrency_limit is not None: - current_task_concurrency = task_concurrency_map[ - (task_instance.dag_id, task_instance.task_id) - ] - - if current_task_concurrency >= task_concurrency_limit: - self.log.info( - "Not executing %s since the task concurrency for" - " this task has been reached.", + + if task_instance.dag_model.has_task_concurrency_limits: + # Many dags don't have a task_concurrency, so where we can avoid loading the full + # serialized DAG the better. + serialized_dag = self.dagbag.get_dag(dag_id, session=session) + # If the dag is missing, fail the task and continue to the next task. 
+ if not serialized_dag: + self.log.error( + "DAG '%s' for task instance %s not found in serialized_dag table", + dag_id, task_instance, ) + session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( + {TI.state: State.FAILED}, synchronize_session='fetch' + ) continue - if task_instance.pool_slots > open_slots: - self.log.info( - "Not executing %s since it requires %s slots " - "but there are %s open slots in the pool %s.", - task_instance, - task_instance.pool_slots, - open_slots, - pool, - ) - num_starving_tasks += 1 - num_starving_tasks_total += 1 - # Though we can execute tasks with lower priority if there's enough room - continue + task_concurrency_limit: Optional[int] = None + if serialized_dag.has_task(task_instance.task_id): + task_concurrency_limit = serialized_dag.get_task( + task_instance.task_id + ).max_active_tis_per_dag + + if task_concurrency_limit is not None: + current_task_concurrency = task_concurrency_map[ + (task_instance.dag_id, task_instance.task_id) + ] + + if current_task_concurrency >= task_concurrency_limit: + self.log.info( + "Not executing %s since the task concurrency for" + " this task has been reached.", + task_instance, + ) + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) + continue + + executable_tis.append(task_instance) + open_slots -= task_instance.pool_slots + dag_active_tasks_map[dag_id] += 1 + task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 + + pools[pool]["open"] = open_slots + + is_done = executable_tis or len(task_instances_to_examine) < max_tis + # Check this to avoid accidental infinite loops + found_new_filters = ( + len(starved_pools) > num_starved_pools + or len(starved_dags) > num_starved_dags + or len(starved_tasks) > num_starved_tasks + ) - executable_tis.append(task_instance) - open_slots -= task_instance.pool_slots - dag_max_active_tasks_map[dag_id] += 1 - task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 + if is_done or not found_new_filters: + break + + self.log.debug( + "Found no task instances to queue on the %s. 
iteration " + "but there could be more candidate task instances to check.", + loop_count, + ) + for pool_name, num_starving_tasks in pool_num_starving_tasks.items(): Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks) Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total) Stats.gauge('scheduler.tasks.running', num_tasks_in_executor) Stats.gauge('scheduler.tasks.executable', len(executable_tis)) - task_instance_str = "\n\t".join(repr(x) for x in executable_tis) - self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str) if len(executable_tis) > 0: + task_instance_str = "\n\t".join(repr(x) for x in executable_tis) + self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str) + # set TIs to queued state filter_for_tis = TI.filter_for_tis(executable_tis) session.query(TI).filter(filter_for_tis).update( diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 168d452..db39977 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -466,6 +466,7 @@ class TestSchedulerJob: dr2.get_task_instance(task_id_1, session=session), dr2.get_task_instance(task_id_2, session=session), ] + tis = sorted(tis, key=lambda ti: ti.key) for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -482,7 +483,7 @@ class TestSchedulerJob: for ti in res: res_keys.append(ti.key) assert tis[0].key in res_keys - assert tis[1].key in res_keys + assert tis[2].key in res_keys assert tis[3].key in res_keys session.rollback() @@ -899,6 +900,97 @@ class TestSchedulerJob: session.rollback() + def test_find_executable_task_instances_not_enough_pool_slots_for_first(self, dag_maker): + set_default_pool_slots(1) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_not_enough_pool_slots_for_first' + with dag_maker(dag_id=dag_id): + op1 = DummyOperator(task_id='dummy1', priority_weight=2, pool_slots=2) + op2 = DummyOperator(task_id='dummy2', priority_weight=1, pool_slots=1) + + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + ti1 = dr1.get_task_instance(op1.task_id, session) + ti2 = dr1.get_task_instance(op2.task_id, session) + ti1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + session.flush() + + # Schedule ti with lower priority, + # because the one with higher priority is limited by a concurrency limit + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) + assert 1 == len(res) + assert res[0].key == ti2.key + + session.rollback() + + def test_find_executable_task_instances_not_enough_dag_concurrency_for_first(self, dag_maker): + self.scheduler_job = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_id_1 = ( + 'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-a' + ) + dag_id_2 = ( + 'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-b' + ) + + with dag_maker(dag_id=dag_id_1, max_active_tasks=1): + op1a = DummyOperator(task_id='dummy1-a', priority_weight=2) + op1b = DummyOperator(task_id='dummy1-b', priority_weight=2) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + with dag_maker(dag_id=dag_id_2): + op2 = DummyOperator(task_id='dummy2', priority_weight=1) + dr2 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + ti1a = dr1.get_task_instance(op1a.task_id, session) + ti1b = 
dr1.get_task_instance(op1b.task_id, session) + ti2 = dr2.get_task_instance(op2.task_id, session) + ti1a.state = State.RUNNING + ti1b.state = State.SCHEDULED + ti2.state = State.SCHEDULED + session.flush() + + # Schedule ti with lower priority, + # because the one with higher priority is limited by a concurrency limit + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=1, session=session) + assert 1 == len(res) + assert res[0].key == ti2.key + + session.rollback() + + def test_find_executable_task_instances_not_enough_task_concurrency_for_first(self, dag_maker): + self.scheduler_job = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_not_enough_task_concurrency_for_first' + + with dag_maker(dag_id=dag_id): + op1a = DummyOperator(task_id='dummy1-a', priority_weight=2, max_active_tis_per_dag=1) + op1b = DummyOperator(task_id='dummy1-b', priority_weight=1) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) + + ti1a = dr1.get_task_instance(op1a.task_id, session) + ti1b = dr1.get_task_instance(op1b.task_id, session) + ti2a = dr2.get_task_instance(op1a.task_id, session) + ti1a.state = State.RUNNING + ti1b.state = State.SCHEDULED + ti2a.state = State.SCHEDULED + session.flush() + + # Schedule ti with lower priority, + # because the one with higher priority is limited by a concurrency limit + res = self.scheduler_job._executable_task_instances_to_queued(max_tis=1, session=session) + assert 1 == len(res) + assert res[0].key == ti1b.key + + session.rollback() + def test_enqueue_task_instances_with_queued_state(self, dag_maker): dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' task_id_1 = 'dummy'
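
The committed diff above is the authoritative implementation. As a reading aid only, the following is a minimal, self-contained sketch of the idea the commit message describes: remember which pools, dags and (dag_id, task_id) pairs turned out to be concurrency-starved, re-run the selection with those filtered out, and stop once something was queued, the batch was not truncated, or no new filters were discovered. It deliberately omits the SQLAlchemy query, row locking, per-task max_active_tis_per_dag handling and logging, and all names here (Candidate, select_executable, the dict parameters) are invented for this illustration rather than taken from Airflow.

    from dataclasses import dataclass
    from typing import Dict, List, Set, Tuple


    @dataclass(frozen=True)
    class Candidate:
        """A SCHEDULED task instance, reduced to the fields the selection loop cares about."""
        dag_id: str
        task_id: str
        pool: str
        priority_weight: int
        pool_slots: int = 1


    def select_executable(
        candidates: List[Candidate],
        open_pool_slots: Dict[str, int],
        running_per_dag: Dict[str, int],
        max_active_per_dag: Dict[str, int],
        max_tis: int,
    ) -> List[Candidate]:
        executable: List[Candidate] = []
        # Pools that are already full never need to be considered again this cycle.
        starved_pools: Set[str] = {p for p, slots in open_pool_slots.items() if slots <= 0}
        starved_dags: Set[str] = set()
        starved_tasks: Set[Tuple[str, str]] = set()

        while True:
            filters_before = (len(starved_pools), len(starved_dags), len(starved_tasks))

            # Stand-in for the re-issued database query: highest priority first,
            # with everything already known to be starved filtered out.
            batch = sorted(
                (
                    c
                    for c in candidates
                    if c not in executable
                    and c.pool not in starved_pools
                    and c.dag_id not in starved_dags
                    and (c.dag_id, c.task_id) not in starved_tasks
                ),
                key=lambda c: -c.priority_weight,
            )[:max_tis]
            if not batch:
                break

            for c in batch:
                slots_left = open_pool_slots.get(c.pool, 0)
                if slots_left <= 0:
                    starved_pools.add(c.pool)  # pool is now full; skip it on the next pass
                    continue
                if c.pool_slots > slots_left:
                    starved_tasks.add((c.dag_id, c.task_id))  # this task will not fit right now
                    continue
                if running_per_dag.get(c.dag_id, 0) >= max_active_per_dag.get(c.dag_id, 16):
                    starved_dags.add(c.dag_id)  # dag hit max_active_tasks; filter it out next pass
                    continue
                open_pool_slots[c.pool] = slots_left - c.pool_slots
                running_per_dag[c.dag_id] = running_per_dag.get(c.dag_id, 0) + 1
                executable.append(c)

            filters_after = (len(starved_pools), len(starved_dags), len(starved_tasks))
            # Stop once something was queued, the batch was not truncated by max_tis,
            # or no new filters were found (the guard against an infinite loop).
            if executable or len(batch) < max_tis or filters_after == filters_before:
                break

        return executable

Called with two candidates where the higher-priority one belongs to a DAG already at its max_active_tasks limit, this sketch returns the lower-priority candidate instead of nothing, which is the behaviour the new tests above (for example test_find_executable_task_instances_not_enough_dag_concurrency_for_first) assert against the real _executable_task_instances_to_queued.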