This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 74b67125f5c6116dc827988b43e10f619803b770
Author: Tanel Kiis <tan...@users.noreply.github.com>
AuthorDate: Tue Mar 22 19:30:37 2022 +0200

    Fix Tasks getting stuck in scheduled state (#19747)
    
    The scheduler_job can get stuck in a state where it is not able to queue 
new tasks. It will get out of this state on its own, but the time taken depends 
on the runtime of the current tasks - this could be several hours or even days.
    
    If the scheduler can't queue any tasks because of different concurrency 
limits (per pool, dag or task), then on the next iterations of the scheduler 
loop it will try to queue the same tasks. Meanwhile, there could be some 
scheduled tasks with lower priority_weight that could be queued, but they will 
remain waiting.
    
    The proposed solution is to keep track of dag and task ids that are 
concurrency-limited and then repeat the query with these dags and tasks 
filtered out.
    
    Co-authored-by: Tanel Kiis <tanel.k...@reach-u.com>
    (cherry picked from commit cd68540ef19b36180fdd1ebe38435637586747d4)
---
 airflow/jobs/scheduler_job.py    | 339 ++++++++++++++++++++++++---------------
 tests/jobs/test_scheduler_job.py |  94 ++++++++++-
 2 files changed, 300 insertions(+), 133 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 490d507..391697c 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -27,7 +27,7 @@ import time
 import warnings
 from collections import defaultdict
 from datetime import timedelta
-from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, 
Tuple
+from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, 
Set, Tuple
 
 from sqlalchemy import and_, func, not_, or_, tuple_
 from sqlalchemy.exc import OperationalError
@@ -259,54 +259,16 @@ class SchedulerJob(BaseJob):
 
         if pool_slots_free == 0:
             self.log.debug("All pools are full!")
-            return executable_tis
+            return []
 
         max_tis = min(max_tis, pool_slots_free)
 
-        # Get all task instances associated with scheduled
-        # DagRuns which are not backfilled, in the given states,
-        # and the dag is not paused
-        query = (
-            session.query(TI)
-            .join(TI.dag_run)
-            .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != 
DagRunState.QUEUED)
-            .join(TI.dag_model)
-            .filter(not_(DM.is_paused))
-            .filter(TI.state == State.SCHEDULED)
-            .options(selectinload('dag_model'))
-            .order_by(-TI.priority_weight, DR.execution_date)
-        )
-        starved_pools = [pool_name for pool_name, stats in pools.items() if 
stats['open'] <= 0]
-        if starved_pools:
-            query = query.filter(not_(TI.pool.in_(starved_pools)))
-
-        query = query.limit(max_tis)
-
-        task_instances_to_examine: List[TI] = with_row_locks(
-            query,
-            of=TI,
-            session=session,
-            **skip_locked(session=session),
-        ).all()
-        # TODO[HA]: This was wrong before anyway, as it only looked at a 
sub-set of dags, not everything.
-        # Stats.gauge('scheduler.tasks.pending', 
len(task_instances_to_examine))
-
-        if len(task_instances_to_examine) == 0:
-            self.log.debug("No tasks to consider for execution.")
-            return executable_tis
-
-        # Put one task instance on each line
-        task_instance_str = "\n\t".join(repr(x) for x in 
task_instances_to_examine)
-        self.log.info("%s tasks up for execution:\n\t%s", 
len(task_instances_to_examine), task_instance_str)
-
-        pool_to_task_instances: DefaultDict[str, List[models.Pool]] = 
defaultdict(list)
-        for task_instance in task_instances_to_examine:
-            pool_to_task_instances[task_instance.pool].append(task_instance)
+        starved_pools = {pool_name for pool_name, stats in pools.items() if 
stats['open'] <= 0}
 
         # dag_id to # of running tasks and (dag_id, task_id) to # of running 
tasks.
-        dag_max_active_tasks_map: DefaultDict[str, int]
+        dag_active_tasks_map: DefaultDict[str, int]
         task_concurrency_map: DefaultDict[Tuple[str, str], int]
-        dag_max_active_tasks_map, task_concurrency_map = 
self.__get_concurrency_maps(
+        dag_active_tasks_map, task_concurrency_map = 
self.__get_concurrency_maps(
             states=list(EXECUTION_STATES), session=session
         )
 
@@ -314,124 +276,237 @@ class SchedulerJob(BaseJob):
         # Number of tasks that cannot be scheduled because of no open slot in 
pool
         num_starving_tasks_total = 0
 
-        # Go through each pool, and queue up a task for execution if there are
-        # any open slots in the pool.
+        # dag and task ids that can't be queued because of concurrency limits
+        starved_dags: Set[str] = set()
+        starved_tasks: Set[Tuple[str, str]] = set()
 
-        for pool, task_instances in pool_to_task_instances.items():
-            pool_name = pool
-            if pool not in pools:
-                self.log.warning("Tasks using non-existent pool '%s' will not 
be scheduled", pool)
-                continue
+        pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
+
+        for loop_count in itertools.count(start=1):
 
-            open_slots = pools[pool]["open"]
+            num_starved_pools = len(starved_pools)
+            num_starved_dags = len(starved_dags)
+            num_starved_tasks = len(starved_tasks)
 
-            num_ready = len(task_instances)
-            self.log.info(
-                "Figuring out tasks to run in Pool(name=%s) with %s open slots 
"
-                "and %s task instances ready to be queued",
-                pool,
-                open_slots,
-                num_ready,
+            # Get task instances associated with scheduled
+            # DagRuns which are not backfilled, in the given states,
+            # and the dag is not paused
+            query = (
+                session.query(TI)
+                .join(TI.dag_run)
+                .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == 
DagRunState.RUNNING)
+                .join(TI.dag_model)
+                .filter(not_(DM.is_paused))
+                .filter(TI.state == TaskInstanceState.SCHEDULED)
+                .options(selectinload('dag_model'))
+                .order_by(-TI.priority_weight, DR.execution_date)
             )
 
-            priority_sorted_task_instances = sorted(
-                task_instances, key=lambda ti: (-ti.priority_weight, 
ti.execution_date)
+            if starved_pools:
+                query = query.filter(not_(TI.pool.in_(starved_pools)))
+
+            if starved_dags:
+                query = query.filter(not_(TI.dag_id.in_(starved_dags)))
+
+            if starved_tasks:
+                if settings.engine.dialect.name == 'mssql':
+                    task_filter = or_(
+                        and_(
+                            TaskInstance.dag_id == dag_id,
+                            TaskInstance.task_id == task_id,
+                        )
+                        for (dag_id, task_id) in starved_tasks
+                    )
+                else:
+                    task_filter = tuple_(TaskInstance.dag_id, 
TaskInstance.task_id).in_(starved_tasks)
+
+                query = query.filter(not_(task_filter))
+
+            query = query.limit(max_tis)
+
+            task_instances_to_examine: List[TI] = with_row_locks(
+                query,
+                of=TI,
+                session=session,
+                **skip_locked(session=session),
+            ).all()
+            # TODO[HA]: This was wrong before anyway, as it only looked at a 
sub-set of dags, not everything.
+            # Stats.gauge('scheduler.tasks.pending', 
len(task_instances_to_examine))
+
+            if len(task_instances_to_examine) == 0:
+                self.log.debug("No tasks to consider for execution.")
+                break
+
+            # Put one task instance on each line
+            task_instance_str = "\n\t".join(repr(x) for x in 
task_instances_to_examine)
+            self.log.info(
+                "%s tasks up for execution:\n\t%s", 
len(task_instances_to_examine), task_instance_str
             )
 
-            num_starving_tasks = 0
-            for current_index, task_instance in 
enumerate(priority_sorted_task_instances):
-                if open_slots <= 0:
-                    self.log.info("Not scheduling since there are %s open 
slots in pool %s", open_slots, pool)
-                    # Can't schedule any more since there are no more open 
slots.
-                    num_unhandled = len(priority_sorted_task_instances) - 
current_index
-                    num_starving_tasks += num_unhandled
-                    num_starving_tasks_total += num_unhandled
-                    break
-
-                # Check to make sure that the task max_active_tasks of the DAG 
hasn't been
-                # reached.
-                dag_id = task_instance.dag_id
-
-                current_max_active_tasks_per_dag = 
dag_max_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
+            pool_to_task_instances: DefaultDict[str, List[TI]] = 
defaultdict(list)
+            for task_instance in task_instances_to_examine:
+                
pool_to_task_instances[task_instance.pool].append(task_instance)
+
+            # Go through each pool, and queue up a task for execution if there 
are
+            # any open slots in the pool.
+
+            for pool, task_instances in pool_to_task_instances.items():
+                pool_name = pool
+                if pool not in pools:
+                    self.log.warning("Tasks using non-existent pool '%s' will 
not be scheduled", pool)
+                    starved_pools.add(pool_name)
+                    continue
+
+                pool_total = pools[pool]["total"]
+                open_slots = pools[pool]["open"]
+
+                num_ready = len(task_instances)
                 self.log.info(
-                    "DAG %s has %s/%s running and queued tasks",
-                    dag_id,
-                    current_max_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
+                    "Figuring out tasks to run in Pool(name=%s) with %s open 
slots "
+                    "and %s task instances ready to be queued",
+                    pool,
+                    open_slots,
+                    num_ready,
                 )
-                if current_max_active_tasks_per_dag >= 
max_active_tasks_per_dag_limit:
+
+                priority_sorted_task_instances = sorted(
+                    task_instances, key=lambda ti: (-ti.priority_weight, 
ti.execution_date)
+                )
+
+                for current_index, task_instance in 
enumerate(priority_sorted_task_instances):
+                    if open_slots <= 0:
+                        self.log.info(
+                            "Not scheduling since there are %s open slots in 
pool %s", open_slots, pool
+                        )
+                        # Can't schedule any more since there are no more open 
slots.
+                        num_unhandled = len(priority_sorted_task_instances) - 
current_index
+                        pool_num_starving_tasks[pool_name] += num_unhandled
+                        num_starving_tasks_total += num_unhandled
+                        starved_pools.add(pool_name)
+                        break
+
+                    if task_instance.pool_slots > pool_total:
+                        self.log.warning(
+                            "Not executing %s. Requested pool slots (%s) are 
greater than "
+                            "total pool slots: '%s' for pool: %s.",
+                            task_instance,
+                            task_instance.pool_slots,
+                            pool_total,
+                            pool,
+                        )
+
+                        starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
+                        continue
+
+                    if task_instance.pool_slots > open_slots:
+                        self.log.info(
+                            "Not executing %s since it requires %s slots "
+                            "but there are %s open slots in the pool %s.",
+                            task_instance,
+                            task_instance.pool_slots,
+                            open_slots,
+                            pool,
+                        )
+                        pool_num_starving_tasks[pool_name] += 1
+                        num_starving_tasks_total += 1
+                        starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
+                        # Though we can execute tasks with lower priority if 
there's enough room
+                        continue
+
+                    # Check to make sure that the task max_active_tasks of the 
DAG hasn't been
+                    # reached.
+                    dag_id = task_instance.dag_id
+
+                    current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+                    max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
                     self.log.info(
-                        "Not executing %s since the number of tasks running or 
queued "
-                        "from DAG %s is >= to the DAG's max_active_tasks limit 
of %s",
-                        task_instance,
+                        "DAG %s has %s/%s running and queued tasks",
                         dag_id,
+                        current_active_tasks_per_dag,
                         max_active_tasks_per_dag_limit,
                     )
-                    continue
-
-                task_concurrency_limit: Optional[int] = None
-                if task_instance.dag_model.has_task_concurrency_limits:
-                    # Many dags don't have a task_concurrency, so where we can 
avoid loading the full
-                    # serialized DAG the better.
-                    serialized_dag = self.dagbag.get_dag(dag_id, 
session=session)
-                    # If the dag is missing, fail the task and continue to the 
next task.
-                    if not serialized_dag:
-                        self.log.error(
-                            "DAG '%s' for task instance %s not found in 
serialized_dag table",
-                            dag_id,
+                    if current_active_tasks_per_dag >= 
max_active_tasks_per_dag_limit:
+                        self.log.info(
+                            "Not executing %s since the number of tasks 
running or queued "
+                            "from DAG %s is >= to the DAG's max_active_tasks 
limit of %s",
                             task_instance,
+                            dag_id,
+                            max_active_tasks_per_dag_limit,
                         )
-                        session.query(TI).filter(TI.dag_id == dag_id, TI.state 
== State.SCHEDULED).update(
-                            {TI.state: State.FAILED}, 
synchronize_session='fetch'
-                        )
+                        starved_dags.add(dag_id)
                         continue
-                    if serialized_dag.has_task(task_instance.task_id):
-                        task_concurrency_limit = serialized_dag.get_task(
-                            task_instance.task_id
-                        ).max_active_tis_per_dag
-
-                    if task_concurrency_limit is not None:
-                        current_task_concurrency = task_concurrency_map[
-                            (task_instance.dag_id, task_instance.task_id)
-                        ]
-
-                        if current_task_concurrency >= task_concurrency_limit:
-                            self.log.info(
-                                "Not executing %s since the task concurrency 
for"
-                                " this task has been reached.",
+
+                    if task_instance.dag_model.has_task_concurrency_limits:
+                        # Many dags don't have a task_concurrency, so where we 
can avoid loading the full
+                        # serialized DAG the better.
+                        serialized_dag = self.dagbag.get_dag(dag_id, 
session=session)
+                        # If the dag is missing, fail the task and continue to 
the next task.
+                        if not serialized_dag:
+                            self.log.error(
+                                "DAG '%s' for task instance %s not found in 
serialized_dag table",
+                                dag_id,
                                 task_instance,
                             )
+                            session.query(TI).filter(TI.dag_id == dag_id, 
TI.state == State.SCHEDULED).update(
+                                {TI.state: State.FAILED}, 
synchronize_session='fetch'
+                            )
                             continue
 
-                if task_instance.pool_slots > open_slots:
-                    self.log.info(
-                        "Not executing %s since it requires %s slots "
-                        "but there are %s open slots in the pool %s.",
-                        task_instance,
-                        task_instance.pool_slots,
-                        open_slots,
-                        pool,
-                    )
-                    num_starving_tasks += 1
-                    num_starving_tasks_total += 1
-                    # Though we can execute tasks with lower priority if 
there's enough room
-                    continue
+                        task_concurrency_limit: Optional[int] = None
+                        if serialized_dag.has_task(task_instance.task_id):
+                            task_concurrency_limit = serialized_dag.get_task(
+                                task_instance.task_id
+                            ).max_active_tis_per_dag
+
+                        if task_concurrency_limit is not None:
+                            current_task_concurrency = task_concurrency_map[
+                                (task_instance.dag_id, task_instance.task_id)
+                            ]
+
+                            if current_task_concurrency >= 
task_concurrency_limit:
+                                self.log.info(
+                                    "Not executing %s since the task 
concurrency for"
+                                    " this task has been reached.",
+                                    task_instance,
+                                )
+                                starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
+                                continue
+
+                    executable_tis.append(task_instance)
+                    open_slots -= task_instance.pool_slots
+                    dag_active_tasks_map[dag_id] += 1
+                    task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
+
+                pools[pool]["open"] = open_slots
+
+            is_done = executable_tis or len(task_instances_to_examine) < 
max_tis
+            # Check this to avoid accidental infinite loops
+            found_new_filters = (
+                len(starved_pools) > num_starved_pools
+                or len(starved_dags) > num_starved_dags
+                or len(starved_tasks) > num_starved_tasks
+            )
 
-                executable_tis.append(task_instance)
-                open_slots -= task_instance.pool_slots
-                dag_max_active_tasks_map[dag_id] += 1
-                task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
+            if is_done or not found_new_filters:
+                break
+
+            self.log.debug(
+                "Found no task instances to queue on the %s. iteration "
+                "but there could be more candidate task instances to check.",
+                loop_count,
+            )
 
+        for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
             Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
 
         Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
         Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
         Stats.gauge('scheduler.tasks.executable', len(executable_tis))
 
-        task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
-        self.log.info("Setting the following tasks to queued state:\n\t%s", 
task_instance_str)
         if len(executable_tis) > 0:
+            task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
+            self.log.info("Setting the following tasks to queued 
state:\n\t%s", task_instance_str)
+
             # set TIs to queued state
             filter_for_tis = TI.filter_for_tis(executable_tis)
             session.query(TI).filter(filter_for_tis).update(
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 168d452..db39977 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -466,6 +466,7 @@ class TestSchedulerJob:
             dr2.get_task_instance(task_id_1, session=session),
             dr2.get_task_instance(task_id_2, session=session),
         ]
+        tis = sorted(tis, key=lambda ti: ti.key)
         for ti in tis:
             ti.state = State.SCHEDULED
             session.merge(ti)
@@ -482,7 +483,7 @@ class TestSchedulerJob:
         for ti in res:
             res_keys.append(ti.key)
         assert tis[0].key in res_keys
-        assert tis[1].key in res_keys
+        assert tis[2].key in res_keys
         assert tis[3].key in res_keys
         session.rollback()
 
@@ -899,6 +900,97 @@ class TestSchedulerJob:
 
         session.rollback()
 
+    def 
test_find_executable_task_instances_not_enough_pool_slots_for_first(self, 
dag_maker):
+        set_default_pool_slots(1)
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = 
'SchedulerJobTest.test_find_executable_task_instances_not_enough_pool_slots_for_first'
+        with dag_maker(dag_id=dag_id):
+            op1 = DummyOperator(task_id='dummy1', priority_weight=2, 
pool_slots=2)
+            op2 = DummyOperator(task_id='dummy2', priority_weight=1, 
pool_slots=1)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        ti1 = dr1.get_task_instance(op1.task_id, session)
+        ti2 = dr1.get_task_instance(op2.task_id, session)
+        ti1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        session.flush()
+
+        # Schedule ti with lower priority,
+        # because the one with higher priority is limited by a concurrency 
limit
+        res = 
self.scheduler_job._executable_task_instances_to_queued(max_tis=32, 
session=session)
+        assert 1 == len(res)
+        assert res[0].key == ti2.key
+
+        session.rollback()
+
+    def 
test_find_executable_task_instances_not_enough_dag_concurrency_for_first(self, 
dag_maker):
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id_1 = (
+            
'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-a'
+        )
+        dag_id_2 = (
+            
'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-b'
+        )
+
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=1):
+            op1a = DummyOperator(task_id='dummy1-a', priority_weight=2)
+            op1b = DummyOperator(task_id='dummy1-b', priority_weight=2)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        with dag_maker(dag_id=dag_id_2):
+            op2 = DummyOperator(task_id='dummy2', priority_weight=1)
+        dr2 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        ti1a = dr1.get_task_instance(op1a.task_id, session)
+        ti1b = dr1.get_task_instance(op1b.task_id, session)
+        ti2 = dr2.get_task_instance(op2.task_id, session)
+        ti1a.state = State.RUNNING
+        ti1b.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        session.flush()
+
+        # Schedule ti with lower priority,
+        # because the one with higher priority is limited by a concurrency 
limit
+        res = 
self.scheduler_job._executable_task_instances_to_queued(max_tis=1, 
session=session)
+        assert 1 == len(res)
+        assert res[0].key == ti2.key
+
+        session.rollback()
+
+    def 
test_find_executable_task_instances_not_enough_task_concurrency_for_first(self, 
dag_maker):
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = 
'SchedulerJobTest.test_find_executable_task_instances_not_enough_task_concurrency_for_first'
+
+        with dag_maker(dag_id=dag_id):
+            op1a = DummyOperator(task_id='dummy1-a', priority_weight=2, 
max_active_tis_per_dag=1)
+            op1b = DummyOperator(task_id='dummy1-b', priority_weight=1)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
+
+        ti1a = dr1.get_task_instance(op1a.task_id, session)
+        ti1b = dr1.get_task_instance(op1b.task_id, session)
+        ti2a = dr2.get_task_instance(op1a.task_id, session)
+        ti1a.state = State.RUNNING
+        ti1b.state = State.SCHEDULED
+        ti2a.state = State.SCHEDULED
+        session.flush()
+
+        # Schedule ti with lower priority,
+        # because the one with higher priority is limited by a concurrency 
limit
+        res = 
self.scheduler_job._executable_task_instances_to_queued(max_tis=1, 
session=session)
+        assert 1 == len(res)
+        assert res[0].key == ti1b.key
+
+        session.rollback()
+
     def test_enqueue_task_instances_with_queued_state(self, dag_maker):
         dag_id = 
'SchedulerJobTest.test_enqueue_task_instances_with_queued_state'
         task_id_1 = 'dummy'

Reply via email to