[AIRFLOW-111] Include queued tasks in scheduler concurrency check

The concurrency argument on a DAG is effectively not obeyed because
the scheduler does not account for queued task instances when it
checks the limit, and compares with a strict greater-than. The excess
tasks do not actually run, but the scheduler keeps picking them up
and rejecting them, which causes a lot of scheduler churn.

Closes #2214 from saguziel/aguziel-fix-concurrency

(cherry picked from commit 3ff5abee3f9d29e545e021c2c060e9c9f3045236)
Signed-off-by: Bolke de Bruin <bo...@xs4all.nl>
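
In rough terms, the per-DAG guard changes as follows (a simplified
sketch with illustrative names, not the literal code; see the
airflow/jobs.py hunk below):

    # Before: only RUNNING task instances were counted, and the
    # comparison was strictly greater-than, so a DAG sitting exactly
    # at its limit still looked schedulable, and tasks parked in
    # QUEUED never counted toward the limit at all.
    if running_count > dag.concurrency:
        skip(task_instance)

    # After: QUEUED task instances occupy a slot and the comparison
    # is >=, so a DAG at its limit is skipped outright.
    if running_or_queued_count >= dag.concurrency:
        skip(task_instance)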


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/9070a827
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/9070a827
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/9070a827

Branch: refs/heads/v1-8-stable
Commit: 9070a82775691e08fb1b95c28fbc2cc5ee7b843d
Parents: 4db53f3
Author: Alex Guziel <alex.guz...@airbnb.com>
Authored: Wed Apr 5 09:59:53 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Wed Apr 5 10:00:06 2017 +0200

----------------------------------------------------------------------
 airflow/jobs.py   | 25 +++++++++++---------
 airflow/models.py | 48 ++++++++++++++++++++++----------------
 tests/jobs.py     | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/models.py   | 38 +++++++++++++++++++++++++++++++
 4 files changed, 142 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9070a827/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 7db9b9c..ce45e05 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -43,7 +43,7 @@ from tabulate import tabulate
 from airflow import executors, models, settings
 from airflow import configuration as conf
 from airflow.exceptions import AirflowException
-from airflow.models import DagRun
+from airflow.models import DAG, DagRun
 from airflow.settings import Stats
 from airflow.task_runner import get_task_runner
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
@@ -1036,7 +1036,7 @@ class SchedulerJob(BaseJob):
                 task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date))
 
             # DAG IDs with running tasks that equal the concurrency limit of the dag
-            dag_id_to_running_task_count = {}
+            dag_id_to_possibly_running_task_count = {}
 
             for task_instance in priority_sorted_task_instances:
                 if open_slots <= 0:
@@ -1063,22 +1063,24 @@ class SchedulerJob(BaseJob):
                 # reached.
                 dag_id = task_instance.dag_id
 
-                if dag_id not in dag_id_to_running_task_count:
-                    dag_id_to_running_task_count[dag_id] = \
-                        DagRun.get_running_tasks(
-                            session,
+                if dag_id not in dag_id_to_possibly_running_task_count:
+                    dag_id_to_possibly_running_task_count[dag_id] = \
+                        DAG.get_num_task_instances(
                             dag_id,
-                            simple_dag_bag.get_dag(dag_id).task_ids)
+                            simple_dag_bag.get_dag(dag_id).task_ids,
+                            states=[State.RUNNING, State.QUEUED],
+                            session=session)
 
-                current_task_concurrency = dag_id_to_running_task_count[dag_id]
+                current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
                 task_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency
-                self.logger.info("DAG {} has {}/{} running tasks"
+                self.logger.info("DAG {} has {}/{} running and queued tasks"
                                  .format(dag_id,
                                          current_task_concurrency,
                                          task_concurrency_limit))
-                if current_task_concurrency > task_concurrency_limit:
+                if current_task_concurrency >= task_concurrency_limit:
                     self.logger.info("Not executing {} since the number "
-                                     "of tasks running from DAG {} is >= to the "
+                                     "of tasks running or queued from DAG {}"
+                                     " is >= to the "
                                      "DAG's task concurrency limit of {}"
                                      .format(task_instance,
                                              dag_id,
@@ -1137,6 +1139,7 @@ class SchedulerJob(BaseJob):
                     queue=queue)
 
                 open_slots -= 1
+                dag_id_to_possibly_running_task_count[dag_id] += 1
 
     def _process_dags(self, dagbag, dags, tis_out):
         """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9070a827/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 6828ab6..47413e0 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -3466,6 +3466,34 @@ class DAG(BaseDag, LoggingMixin):
             session.merge(dag)
             session.commit()
 
+    @staticmethod
+    @provide_session
+    def get_num_task_instances(dag_id, task_ids, states=None, session=None):
+        """
+        Returns the number of task instances in the given DAG.
+
+        :param session: ORM session
+        :param dag_id: ID of the DAG to get the task concurrency of
+        :type dag_id: unicode
+        :param task_ids: A list of valid task IDs for the given DAG
+        :type task_ids: list[unicode]
+        :param states: A list of states to filter by if supplied
+        :type states: list[state]
+        :return: The number of matching task instances
+        :rtype: int
+        """
+        qry = session.query(func.count(TaskInstance.task_id)).filter(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.task_id.in_(task_ids))
+        if states is not None:
+            if None in states:
+                qry = qry.filter(or_(
+                    TaskInstance.state.in_(states),
+                    TaskInstance.state.is_(None)))
+            else:
+                qry = qry.filter(TaskInstance.state.in_(states))
+        return qry.scalar()
+
 
 class Chart(Base):
     __tablename__ = "chart"
@@ -4118,26 +4146,6 @@ class DagRun(Base):
         session.commit()
 
     @staticmethod
-    def get_running_tasks(session, dag_id, task_ids):
-        """
-        Returns the number of tasks running in the given DAG.
-
-        :param session: ORM session
-        :param dag_id: ID of the DAG to get the task concurrency of
-        :type dag_id: unicode
-        :param task_ids: A list of valid task IDs for the given DAG
-        :type task_ids: list[unicode]
-        :return: The number of running tasks
-        :rtype: int
-        """
-        qry = session.query(func.count(TaskInstance.task_id)).filter(
-            TaskInstance.dag_id == dag_id,
-            TaskInstance.task_id.in_(task_ids),
-            TaskInstance.state == State.RUNNING,
-        )
-        return qry.scalar()
-
-    @staticmethod
     def get_run(session, dag_id, execution_date):
         """
         :param dag_id: DAG ID
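
For reference, a hedged usage sketch of the new helper (the DAG ID,
task IDs, and session setup here are placeholders; the imports follow
the 1.8-era module layout):

    from airflow import settings
    from airflow.models import DAG
    from airflow.utils.state import State

    session = settings.Session()

    # Count the task instances that occupy a concurrency slot.
    n_busy = DAG.get_num_task_instances(
        'example_dag', ['task_a', 'task_b'],
        states=[State.RUNNING, State.QUEUED], session=session)

    # Including None in states also counts rows whose state column is
    # NULL: the helper adds an explicit IS NULL filter for that case,
    # since a SQL IN (...) comparison never matches NULL.
    n_unset = DAG.get_num_task_instances(
        'example_dag', ['task_a'], states=[None], session=session)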

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9070a827/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index f9ede68..9b245ae 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -481,6 +481,68 @@ class SchedulerJobTest(unittest.TestCase):
         scheduler.heartrate = 0
         scheduler.run()
 
+    def test_concurrency(self):
+        dag_id = 'SchedulerJobTest.test_concurrency'
+        task_id_1 = 'dummy_task'
+        task_id_2 = 'dummy_task_nonexistent_queue'
+        # important that len(tasks) is less than concurrency
+        # because before scheduler._execute_task_instances would only
+        # check the num tasks once so if concurrency was 3,
+        # we could execute arbitrarily many tasks in the second run
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        task2 = DummyOperator(dag=dag, task_id=task_id_2)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        # create first dag run with 1 running and 1 queued
+        dr1 = scheduler.create_dag_run(dag)
+        ti1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task2, dr1.execution_date)
+        ti1.refresh_from_db()
+        ti2.refresh_from_db()
+        ti1.state = State.RUNNING
+        ti2.state = State.QUEUED
+        session.merge(ti1)
+        session.merge(ti2)
+        session.commit()
+
+        self.assertEqual(State.RUNNING, dr1.state)
+        self.assertEqual(2, DAG.get_num_task_instances(dag_id, dag.task_ids,
+            states=[State.RUNNING, State.QUEUED], session=session))
+
+        # create second dag run
+        dr2 = scheduler.create_dag_run(dag)
+        ti3 = TI(task1, dr2.execution_date)
+        ti4 = TI(task2, dr2.execution_date)
+        ti3.refresh_from_db()
+        ti4.refresh_from_db()
+        # manually set to scheduled so we can pick them up
+        ti3.state = State.SCHEDULED
+        ti4.state = State.SCHEDULED
+        session.merge(ti3)
+        session.merge(ti4)
+        session.commit()
+
+        self.assertEqual(State.RUNNING, dr2.state)
+
+        scheduler._execute_task_instances(dagbag, [State.SCHEDULED])
+
+        # check that concurrency is respected
+        ti1.refresh_from_db()
+        ti2.refresh_from_db()
+        ti3.refresh_from_db()
+        ti4.refresh_from_db()
+        self.assertEqual(3, DAG.get_num_task_instances(dag_id, dag.task_ids,
+            states=[State.RUNNING, State.QUEUED], session=session))
+        self.assertEqual(State.RUNNING, ti1.state)
+        self.assertEqual(State.QUEUED, ti2.state)
+        six.assertCountEqual(self, [State.QUEUED, State.SCHEDULED], [ti3.state, ti4.state])
+
+        session.close()
+
     @provide_session
     def evaluate_dagrun(
             self,
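
The test above exercises the invariant end to end: with concurrency=3
and two slots already taken (ti1 RUNNING, ti2 QUEUED), one call to
_execute_task_instances may queue at most one of the two SCHEDULED
tasks, leaving the combined RUNNING+QUEUED count at exactly 3. The
six.assertCountEqual call deliberately avoids assuming which of
ti3/ti4 gets picked, since both share the same priority weight and
execution date. Assuming the nose-based runner the project used at
the time (the exact invocation may vary by environment), the test can
be run in isolation with:

    nosetests tests.jobs:SchedulerJobTest.test_concurrency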

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9070a827/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 6673c04..83183f8 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -194,6 +194,44 @@ class DagTest(unittest.TestCase):
 
         self.assertEquals(tuple(), dag.topological_sort())
 
+    def test_get_num_task_instances(self):
+        test_dag_id = 'test_get_num_task_instances_dag'
+        test_task_id = 'task_1'
+
+        test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
+        test_task = DummyOperator(task_id=test_task_id, dag=test_dag)
+
+        ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
+        ti1.state = None
+        ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti2.state = State.RUNNING
+        ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
+        ti3.state = State.QUEUED
+        ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
+        ti4.state = State.RUNNING
+        session = settings.Session()
+        session.merge(ti1)
+        session.merge(ti2)
+        session.merge(ti3)
+        session.merge(ti4)
+        session.commit()
+
+        self.assertEqual(0, DAG.get_num_task_instances(test_dag_id, ['fakename'],
+            session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id,
+            ['fakename', test_task_id], session=session))
+        self.assertEqual(1, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None], session=session))
+        self.assertEqual(2, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[State.RUNNING], session=session))
+        self.assertEqual(3, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None, State.RUNNING], session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None, State.QUEUED, State.RUNNING], session=session))
+        session.close()
+
 
 class DagRunTest(unittest.TestCase):
 

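The tests/models.py fixture pins down the states= semantics with four
instances of one task (states None, RUNNING, QUEUED, RUNNING): no
filter counts all 4, [None] matches only the NULL-state row,
[State.RUNNING] matches 2, and mixed lists add up accordingly (3 for
[None, State.RUNNING], 4 for [None, State.QUEUED, State.RUNNING]).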