KevinYang21 closed pull request #4216: [WIP][AIRFLOW-2761] Parallelize enqueue 
in celery executor
URL: https://github.com/apache/incubator-airflow/pull/4216
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index 903c276339..934d97c19b 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -19,11 +19,12 @@
 
 from builtins import range
 
+# To avoid circular imports
+import airflow.utils.dag_processing
 from airflow import configuration
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
 
-
 PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
 
 
@@ -50,11 +51,11 @@ def start(self):  # pragma: no cover
         """
         pass
 
-    def queue_command(self, task_instance, command, priority=1, queue=None):
-        key = task_instance.key
+    def queue_command(self, simple_task_instance, command, priority=1, 
queue=None):
+        key = simple_task_instance.key
         if key not in self.queued_tasks and key not in self.running:
             self.log.info("Adding to queue: %s", command)
-            self.queued_tasks[key] = (command, priority, queue, task_instance)
+            self.queued_tasks[key] = (command, priority, queue, 
simple_task_instance)
         else:
             self.log.info("could not queue task {}".format(key))
 
@@ -86,7 +87,7 @@ def queue_task_instance(
             pickle_id=pickle_id,
             cfg_path=cfg_path)
         self.queue_command(
-            task_instance,
+            airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
             command,
             priority=task_instance.task.priority_weight_total,
             queue=task_instance.task.queue)
@@ -124,26 +125,13 @@ def heartbeat(self):
             key=lambda x: x[1][1],
             reverse=True)
         for i in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, _, queue, ti) = sorted_queue.pop(0)
-            # TODO(jlowin) without a way to know what Job ran which tasks,
-            # there is a danger that another Job started running a task
-            # that was also queued to this executor. This is the last chance
-            # to check if that happened. The most probable way is that a
-            # Scheduler tried to run a task that was originally queued by a
-            # Backfill. This fix reduces the probability of a collision but
-            # does NOT eliminate it.
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
             self.queued_tasks.pop(key)
-            ti.refresh_from_db()
-            if ti.state != State.RUNNING:
-                self.running[key] = command
-                self.execute_async(key=key,
-                                   command=command,
-                                   queue=queue,
-                                   executor_config=ti.executor_config)
-            else:
-                self.log.info(
-                    'Task is already running, not sending to '
-                    'executor: {}'.format(key))
+            self.running[key] = command
+            self.execute_async(key=key,
+                               command=command,
+                               queue=queue,
+                               executor_config=simple_ti.executor_config)
 
         # Calling child class sync method
         self.log.debug("Calling the %s sync method", self.__class__)
@@ -151,7 +139,7 @@ def heartbeat(self):
 
     def change_state(self, key, state):
         self.log.debug("Changing state: {}".format(key))
-        self.running.pop(key)
+        self.running.pop(key, None)
         self.event_buffer[key] = state
 
     def fail(self, key):
diff --git a/airflow/executors/celery_executor.py 
b/airflow/executors/celery_executor.py
index 0de48b4d39..0e71778ecc 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -33,10 +33,13 @@
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.timeout import timeout
 
 # Make it constant for unit test.
 CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
 
+CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
+
 '''
 To start the celery worker, run the command:
 airflow worker
@@ -55,12 +58,12 @@
 
 
 @app.task
-def execute_command(command):
+def execute_command(command_to_exec):
     log = LoggingMixin().log
-    log.info("Executing command in Celery: %s", command)
+    log.info("Executing command in Celery: %s", command_to_exec)
     env = os.environ.copy()
     try:
-        subprocess.check_call(command, stderr=subprocess.STDOUT,
+        subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT,
                               close_fds=True, env=env)
     except subprocess.CalledProcessError as e:
         log.exception('execute_command encountered a CalledProcessError')
@@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task):
     """
 
     try:
-        # Accessing state property of celery task will make actual network 
request
-        # to get the current state of the task.
-        res = (celery_task[0], celery_task[1].state)
+        with timeout(seconds=2):
+            # Accessing state property of celery task will make actual network 
request
+            # to get the current state of the task.
+            res = (celery_task[0], celery_task[1].state)
     except Exception as e:
         exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
                                                               
traceback.format_exc())
@@ -105,6 +109,19 @@ def fetch_celery_task_state(celery_task):
     return res
 
 
+def send_task_to_executor(task_tuple):
+    key, simple_ti, command, queue, task = task_tuple
+    try:
+        with timeout(seconds=2):
+            result = task.apply_async(args=[command], queue=queue)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(key,
+                                                              
traceback.format_exc())
+        result = ExceptionWithTraceback(e, exception_traceback)
+
+    return key, command, result
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow. It allows
@@ -135,16 +152,16 @@ def start(self):
             'Starting Celery Executor using {} processes for syncing'.format(
                 self._sync_parallelism))
 
-    def execute_async(self, key, command,
-                      queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
-                      executor_config=None):
-        self.log.info("[celery] queuing {key} through celery, "
-                      "queue={queue}".format(**locals()))
-        self.tasks[key] = execute_command.apply_async(
-            args=[command], queue=queue)
-        self.last_state[key] = celery_states.PENDING
+    def _num_tasks_per_send_process(self, to_send_count):
+        """
+        How many Celery tasks should each worker process send.
+        :return: Number of tasks that should be sent per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * to_send_count / 
self._sync_parallelism)))
 
-    def _num_tasks_per_process(self):
+    def _num_tasks_per_fetch_process(self):
         """
         How many Celery tasks should be sent to each worker process.
         :return: Number of tasks that should be used per process
@@ -153,6 +170,71 @@ def _num_tasks_per_process(self):
         return max(1,
                    int(math.ceil(1.0 * len(self.tasks) / 
self._sync_parallelism)))
 
+    def heartbeat(self):
+        # Triggering new jobs
+        if not self.parallelism:
+            open_slots = len(self.queued_tasks)
+        else:
+            open_slots = self.parallelism - len(self.running)
+
+        self.log.debug("{} running task instances".format(len(self.running)))
+        self.log.debug("{} in queue".format(len(self.queued_tasks)))
+        self.log.debug("{} open slots".format(open_slots))
+
+        sorted_queue = sorted(
+            [(k, v) for k, v in self.queued_tasks.items()],
+            key=lambda x: x[1][1],
+            reverse=True)
+
+        task_tuples_to_send = []
+
+        for i in range(min((open_slots, len(self.queued_tasks)))):
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
+            task_tuples_to_send.append((key, simple_ti, command, queue,
+                                        execute_command))
+
+        cached_celery_backend = None
+        if task_tuples_to_send:
+            tasks = [t[4] for t in task_tuples_to_send]
+
+            # Celery state queries will get stuck if we do not use the same
+            # backend for all tasks.
+            cached_celery_backend = tasks[0].backend
+
+        if task_tuples_to_send:
+            # Use chunking instead of a work queue to reduce context switching
+            # since tasks are roughly uniform in size
+            chunksize = 
self._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), 
self._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+            self.log.debug('Sent all tasks.')
+
+            for key, command, result in key_and_async_results:
+                if isinstance(result, ExceptionWithTraceback):
+                    self.log.error(
+                        CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
+                            result.exception, result.traceback))
+                elif result is not None:
+                    # Only pops when enqueued successfully, otherwise keep it
+                    # and expect scheduler loop to deal with it.
+                    self.queued_tasks.pop(key)
+                    result.backend = cached_celery_backend
+                    self.running[key] = command
+                    self.tasks[key] = result
+                    self.last_state[key] = celery_states.PENDING
+
+        # Calling child class sync method
+        self.log.debug("Calling the {} sync method".format(self.__class__))
+        self.sync()
+
     def sync(self):
         num_processes = min(len(self.tasks), self._sync_parallelism)
         if num_processes == 0:
@@ -167,7 +249,7 @@ def sync(self):
 
         # Use chunking instead of a work queue to reduce context switching 
since tasks are
         # roughly uniform in size
-        chunksize = self._num_tasks_per_process()
+        chunksize = self._num_tasks_per_fetch_process()
 
         self.log.debug("Waiting for inquiries to complete...")
         task_keys_to_states = self._sync_pool.map(
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 9e68fad797..ca124bf1f1 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -52,6 +52,7 @@
                                           DagFileProcessorAgent,
                                           SimpleDag,
                                           SimpleDagBag,
+                                          SimpleTaskInstance,
                                           list_py_file_paths)
 from airflow.utils.db import create_session, provide_session
 from airflow.utils.email import get_email_address_list, send_email
@@ -598,6 +599,7 @@ def __init__(
                                             'run_duration')
 
         self.processor_agent = None
+        self._last_loop = False
 
         signal.signal(signal.SIGINT, self._exit_gracefully)
         signal.signal(signal.SIGTERM, self._exit_gracefully)
@@ -1228,13 +1230,13 @@ def _change_state_for_executable_task_instances(self, 
task_instances,
                                                     acceptable_states, 
session=None):
         """
         Changes the state of task instances in the list with one of the given 
states
-        to QUEUED atomically, and returns the TIs changed.
+        to QUEUED atomically, and returns the TIs changed in 
SimpleTaskInstance format.
 
         :param task_instances: TaskInstances to change the state of
         :type task_instances: List[TaskInstance]
         :param acceptable_states: Filters the TaskInstances updated to be in 
these states
         :type acceptable_states: Iterable[State]
-        :return: List[TaskInstance]
+        :return: List[SimpleTaskInstance]
         """
         if len(task_instances) == 0:
             session.commit()
@@ -1276,81 +1278,57 @@ def _change_state_for_executable_task_instances(self, 
task_instances,
                                          else task_instance.queued_dttm)
             session.merge(task_instance)
 
-        # save which TIs we set before session expires them
-        filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id,
-                                  TI.task_id == ti.task_id,
-                                  TI.execution_date == ti.execution_date)
-                                  for ti in tis_to_set_to_queued])
-        session.commit()
-
-        # requery in batches since above was expired by commit
+        # Generate a list of SimpleTaskInstance for the use of queuing
+        # them in the executor.
+        simple_task_instances = [SimpleTaskInstance(ti) for ti in
+                                 tis_to_set_to_queued]
 
-        def query(result, items):
-            tis_to_be_queued = (
-                session
-                .query(TI)
-                .filter(or_(*items))
-                .all())
-            task_instance_str = "\n\t".join(
-                ["{}".format(x) for x in tis_to_be_queued])
-            self.log.info("Setting the following {} tasks to queued 
state:\n\t{}"
-                          .format(len(tis_to_be_queued),
-                                  task_instance_str))
-            return result + tis_to_be_queued
-
-        tis_to_be_queued = helpers.reduce_in_chunks(query,
-                                                    filter_for_ti_enqueue,
-                                                    [],
-                                                    self.max_tis_per_query)
+        task_instance_str = "\n\t".join(
+            ["{}".format(x) for x in tis_to_set_to_queued])
 
-        return tis_to_be_queued
+        session.commit()
+        self.logger.info("Setting the following {} tasks to queued 
state:\n\t{}"
+                         .format(len(tis_to_set_to_queued), task_instance_str))
+        return simple_task_instances
 
-    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, 
task_instances):
+    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag,
+                                                  simple_task_instances):
         """
         Takes task_instances, which should have been set to queued, and 
enqueues them
         with the executor.
 
-        :param task_instances: TaskInstances to enqueue
-        :type task_instances: List[TaskInstance]
+        :param simple_task_instances: TaskInstances to enqueue
+        :type simple_task_instances: List[SimpleTaskInstance]
         :param simple_dag_bag: Should contains all of the task_instances' dags
         :type simple_dag_bag: SimpleDagBag
         """
         TI = models.TaskInstance
         # actually enqueue them
-        for task_instance in task_instances:
-            simple_dag = simple_dag_bag.get_dag(task_instance.dag_id)
+        for simple_task_instance in simple_task_instances:
+            simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id)
             command = TI.generate_command(
-                task_instance.dag_id,
-                task_instance.task_id,
-                task_instance.execution_date,
+                simple_task_instance.dag_id,
+                simple_task_instance.task_id,
+                simple_task_instance.execution_date,
                 local=True,
                 mark_success=False,
                 ignore_all_deps=False,
                 ignore_depends_on_past=False,
                 ignore_task_deps=False,
                 ignore_ti_state=False,
-                pool=task_instance.pool,
+                pool=simple_task_instance.pool,
                 file_path=simple_dag.full_filepath,
                 pickle_id=simple_dag.pickle_id)
 
-            priority = task_instance.priority_weight
-            queue = task_instance.queue
+            priority = simple_task_instance.priority_weight
+            queue = simple_task_instance.queue
             self.log.info(
                 "Sending %s to executor with priority %s and queue %s",
-                task_instance.key, priority, queue
+                simple_task_instance.key, priority, queue
             )
 
-            # save attributes so sqlalchemy doesnt expire them
-            copy_dag_id = task_instance.dag_id
-            copy_task_id = task_instance.task_id
-            copy_execution_date = task_instance.execution_date
-            make_transient(task_instance)
-            task_instance.dag_id = copy_dag_id
-            task_instance.task_id = copy_task_id
-            task_instance.execution_date = copy_execution_date
-
             self.executor.queue_command(
-                task_instance,
+                simple_task_instance,
                 command,
                 priority=priority,
                 queue=queue)
@@ -1374,24 +1352,65 @@ def _execute_task_instances(self,
         :type simple_dag_bag: SimpleDagBag
         :param states: Execute TaskInstances in these states
         :type states: Tuple[State]
-        :return: None
+        :return: Number of task instances whose state was changed.
         """
         executable_tis = self._find_executable_task_instances(simple_dag_bag, 
states,
                                                               session=session)
 
         def query(result, items):
-            tis_with_state_changed = 
self._change_state_for_executable_task_instances(
-                items,
-                states,
-                session=session)
+            simple_tis_with_state_changed = \
+                self._change_state_for_executable_task_instances(items,
+                                                                 states,
+                                                                 
session=session)
             self._enqueue_task_instances_with_queued_state(
                 simple_dag_bag,
-                tis_with_state_changed)
+                simple_tis_with_state_changed)
             session.commit()
-            return result + len(tis_with_state_changed)
+            return result + len(simple_tis_with_state_changed)
 
         return helpers.reduce_in_chunks(query, executable_tis, 0, 
self.max_tis_per_query)
 
+    @provide_session
+    def _change_state_for_tasks_failed_to_execute(self, session):
+        """
+        If there are tasks left over in the executor,
+        we set them back to SCHEDULED to avoid creating hanging tasks.
+        :param session:
+        :return:
+        """
+        if self.executor.queued_tasks:
+            TI = models.TaskInstance
+            filter_for_ti_state_change = (
+                [and_(
+                    TI.dag_id == dag_id,
+                    TI.task_id == task_id,
+                    TI.execution_date == execution_date,
+                    # The TI.try_number will return raw try_number+1 since the
+                    # ti is not running. And we need to -1 to match the DB 
record.
+                    TI._try_number == try_number - 1,
+                    TI.state == State.QUEUED)
+                    for dag_id, task_id, execution_date, try_number
+                    in self.executor.queued_tasks.keys()])
+            ti_query = (session.query(TI)
+                        .filter(or_(*filter_for_ti_state_change)))
+            tis_to_set_to_scheduled = (ti_query
+                                       .with_for_update()
+                                       .all())
+            if len(tis_to_set_to_scheduled) == 0:
+                session.commit()
+                return
+
+            # set TIs to scheduled state
+            for task_instance in tis_to_set_to_scheduled:
+                task_instance.state = State.SCHEDULED
+
+            task_instance_str = "\n\t".join(
+                ["{}".format(x) for x in tis_to_set_to_scheduled])
+
+            session.commit()
+            self.logger.info("Set the follow tasks to scheduled state:\n\t{}"
+                             .format(task_instance_str))
+
     def _process_dags(self, dagbag, dags, tis_out):
         """
         Iterates over the dags and processes them. Processing includes:
@@ -1507,6 +1526,8 @@ def processor_factory(file_path, zombies):
 
         try:
             self._execute_helper()
+        except Exception:
+            self.log.exception("Exception when executing execute_helper")
         finally:
             self.processor_agent.end()
             self.log.info("Exited execute loop")
@@ -1557,6 +1578,7 @@ def _execute_helper(self):
 
             self.log.info("Harvesting DAG parsing results")
             simple_dags = self.processor_agent.harvest_simple_dags()
+            self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags)))
 
             # Send tasks for execution if available
             simple_dag_bag = SimpleDagBag(simple_dags)
@@ -1593,6 +1615,8 @@ def _execute_helper(self):
             self.log.debug("Heartbeating the executor")
             self.executor.heartbeat()
 
+            self._change_state_for_tasks_failed_to_execute()
+
             # Process events from the executor
             self._process_executor_events(simple_dag_bag)
 
@@ -1612,8 +1636,13 @@ def _execute_helper(self):
             self.log.debug("Sleeping for %.2f seconds", 
self._processor_poll_interval)
             time.sleep(self._processor_poll_interval)
 
-            # Exit early for a test mode
+            # Exit early for a test mode, run one additional scheduler loop
+            # to reduce the possibility that parsed DAG was put into the queue
+            # by the DAG manager but not yet received by DAG agent.
             if self.processor_agent.done:
+                self._last_loop = True
+
+            if self._last_loop:
                 self.log.info("Exiting scheduler loop as all files"
                               " have been processed {} 
times".format(self.num_runs))
                 break
diff --git a/airflow/models.py b/airflow/models.py
index 9ab2348cc2..8870a2921e 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -22,12 +22,12 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import copy
+from collections import defaultdict, namedtuple
 
+from builtins import ImportError as BuiltinImportError, bytes, object, str
 from future.standard_library import install_aliases
 
-from builtins import str, object, bytes, ImportError as BuiltinImportError
-import copy
-from collections import namedtuple, defaultdict
 try:
     # Fix Python > 3.7 deprecation
     from collections.abc import Hashable
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index 47f473e9aa..62c4c91968 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -146,6 +146,21 @@ def __init__(self, ti):
         self._end_date = ti.end_date
         self._try_number = ti.try_number
         self._state = ti.state
+        self._executor_config = ti.executor_config
+        if hasattr(ti, 'run_as_user'):
+            self._run_as_user = ti.run_as_user
+        else:
+            self._run_as_user = None
+        if hasattr(ti, 'pool'):
+            self._pool = ti.pool
+        else:
+            self._pool = None
+        if hasattr(ti, 'priority_weight'):
+            self._priority_weight = ti.priority_weight
+        else:
+            self._priority_weight = None
+        self._queue = ti.queue
+        self._key = ti.key
 
     @property
     def dag_id(self):
@@ -175,6 +190,48 @@ def try_number(self):
     def state(self):
         return self._state
 
+    @property
+    def pool(self):
+        return self._pool
+
+    @property
+    def priority_weight(self):
+        return self._priority_weight
+
+    @property
+    def queue(self):
+        return self._queue
+
+    @property
+    def key(self):
+        return self._key
+
+    @property
+    def executor_config(self):
+        return self._executor_config
+
+    @provide_session
+    def construct_task_instance(self, session=None, lock_for_update=False):
+        """
+        Construct a TaskInstance from the database based on the primary key
+        :param session: DB session.
+        :param lock_for_update: if True, indicates that the database should
+            lock the TaskInstance (issuing a FOR UPDATE clause) until the
+            session is committed.
+        """
+        TI = airflow.models.TaskInstance
+
+        qry = session.query(TI).filter(
+            TI.dag_id == self._dag_id,
+            TI.task_id == self._task_id,
+            TI.execution_date == self._execution_date)
+
+        if lock_for_update:
+            ti = qry.with_for_update().first()
+        else:
+            ti = qry.first()
+        return ti
+
 
 class SimpleDagBag(BaseDagBag):
     """
@@ -566,11 +623,16 @@ def end(self):
         Terminate (and then kill) the manager process launched.
         :return:
         """
-        if not self._process or not self._process.is_alive():
+        if not self._process:
             self.log.warn('Ending without manager process.')
             return
         this_process = psutil.Process(os.getpid())
-        manager_process = psutil.Process(self._process.pid)
+        try:
+            manager_process = psutil.Process(self._process.pid)
+        except psutil.NoSuchProcess:
+            self.log.info("Manager process not running.")
+            return
+
         # First try SIGTERM
         if manager_process.is_running() \
                 and manager_process.pid in [x.pid for x in 
this_process.children()]:
diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py
index a86b9d357b..f648005877 100644
--- a/airflow/utils/timeout.py
+++ b/airflow/utils/timeout.py
@@ -23,6 +23,7 @@
 from __future__ import unicode_literals
 
 import signal
+import os
 
 from airflow.exceptions import AirflowTaskTimeout
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -35,10 +36,10 @@ class timeout(LoggingMixin):
 
     def __init__(self, seconds=1, error_message='Timeout'):
         self.seconds = seconds
-        self.error_message = error_message
+        self.error_message = error_message + ', PID: ' + str(os.getpid())
 
     def handle_timeout(self, signum, frame):
-        self.log.error("Process timed out")
+        self.log.error("Process timed out, PID: " + str(os.getpid()))
         raise AirflowTaskTimeout(self.error_message)
 
     def __enter__(self):
diff --git a/tests/executors/test_celery_executor.py 
b/tests/executors/test_celery_executor.py
index 380201d30a..5a7d6e984c 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -18,12 +18,17 @@
 # under the License.
 import sys
 import unittest
+from multiprocessing import Pool
+
 import mock
 from celery.contrib.testing.worker import start_worker
 
-from airflow.executors.celery_executor import CeleryExecutor
-from airflow.executors.celery_executor import app
+from airflow.executors import celery_executor
 from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
+from airflow.executors.celery_executor import (CeleryExecutor, 
celery_configuration,
+                                               send_task_to_executor, 
execute_command)
+from airflow.executors.celery_executor import app
+from celery import states as celery_states
 from airflow.utils.state import State
 
 from airflow import configuration
@@ -40,16 +45,37 @@ def test_celery_integration(self):
         executor = CeleryExecutor()
         executor.start()
         with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
-
             success_command = ['true', 'some_parameter']
             fail_command = ['false', 'some_parameter']
 
-            executor.execute_async(key='success', command=success_command)
-            # errors are propagated for some reason
-            try:
-                executor.execute_async(key='fail', command=fail_command)
-            except Exception:
-                pass
+            cached_celery_backend = execute_command.backend
+            task_tuples_to_send = [('success', 'fake_simple_ti', 
success_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command),
+                                   ('fail', 'fake_simple_ti', fail_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command)]
+
+            chunksize = 
executor._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), 
executor._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+
+            for key, command, result in key_and_async_results:
+                # Only pops when enqueued successfully, otherwise keep it
+                # and expect scheduler loop to deal with it.
+                result.backend = cached_celery_backend
+                executor.running[key] = command
+                executor.tasks[key] = result
+                executor.last_state[key] = celery_states.PENDING
+
             executor.running['success'] = True
             executor.running['fail'] = True
 
@@ -64,6 +90,21 @@ def test_celery_integration(self):
         self.assertNotIn('success', executor.last_state)
         self.assertNotIn('fail', executor.last_state)
 
+    def test_error_sending_task(self):
+        @app.task
+        def fake_execute_command():
+            pass
+
+        # fake_execute_command takes no arguments while execute_command takes 
1,
+        # which will cause TypeError when calling task.apply_async()
+        celery_executor.execute_command = fake_execute_command
+        executor = CeleryExecutor()
+        value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti'
+        executor.queued_tasks['key'] = value_tuple
+        executor.heartbeat()
+        self.assertEquals(1, len(executor.queued_tasks))
+        self.assertEquals(executor.queued_tasks['key'], value_tuple)
+
     def test_exception_propagation(self):
         @app.task
         def fake_celery_task():
diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py
index aab66644b8..366ea8c967 100644
--- a/tests/executors/test_executor.py
+++ b/tests/executors/test_executor.py
@@ -46,7 +46,8 @@ def heartbeat(self):
                 ti = self._running.pop()
                 ti.set_state(State.SUCCESS, session)
             for key, val in list(self.queued_tasks.items()):
-                (command, priority, queue, ti) = val
+                (command, priority, queue, simple_ti) = val
+                ti = simple_ti.construct_task_instance()
                 ti.set_state(State.RUNNING, session)
                 self._running.append(ti)
                 self.queued_tasks.pop(key)
diff --git a/tests/test_jobs.py b/tests/test_jobs.py
index af8ccc6c2e..fb161c1b69 100644
--- a/tests/test_jobs.py
+++ b/tests/test_jobs.py
@@ -2032,6 +2032,54 @@ def test_change_state_for_tis_without_dagrun(self):
         ti2.refresh_from_db(session=session)
         self.assertEqual(ti2.state, State.SCHEDULED)
 
+    def test_change_state_for_tasks_failed_to_execute(self):
+        dag = DAG(
+            dag_id='dag_id',
+            start_date=DEFAULT_DATE)
+
+        task = DummyOperator(
+            task_id='task_id',
+            dag=dag,
+            owner='airflow')
+
+        # If there's no left over task in executor.queued_tasks, nothing happens
+        session = settings.Session()
+        scheduler_job = SchedulerJob()
+        mock_logger = mock.MagicMock()
+        test_executor = TestExecutor()
+        scheduler_job.executor = test_executor
+        scheduler_job._logger = mock_logger
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+        mock_logger.info.assert_not_called()
+
+        # Tasks failed to execute with QUEUED state will be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        key = 'dag_id', 'task_id', DEFAULT_DATE, 1
+        test_executor.queued_tasks[key] = 'value'
+        ti = TI(task, DEFAULT_DATE)
+        ti.state = State.QUEUED
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.SCHEDULED, ti.state)
+
+        # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        ti.state = State.RUNNING
+
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.RUNNING, ti.state)
+
     def test_execute_helper_reset_orphaned_tasks(self):
         session = settings.Session()
         dag = DAG(
@@ -2949,7 +2997,8 @@ def run_with_error(task):
                 pass
 
         ti_tuple = six.next(six.itervalues(executor.queued_tasks))
-        (command, priority, queue, ti) = ti_tuple
+        (command, priority, queue, simple_ti) = ti_tuple
+        ti = simple_ti.construct_task_instance()
         ti.task = dag_task1
 
         self.assertEqual(ti.try_number, 1)
@@ -2970,15 +3019,21 @@ def run_with_error(task):
         # removing self.assertEqual(ti.state, State.SCHEDULED)
         # as scheduler will move state from SCHEDULED to QUEUED
 
-        # now the executor has cleared and it should be allowed the re-queue
+        # now the executor has cleared and it should be allowed the re-queue,
+        # but tasks stay in the executor.queued_tasks after executor.heartbeat()
+        # will be set back to SCHEDULED state
         executor.queued_tasks.clear()
         do_schedule()
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.QUEUED)
-        # calling below again in order to ensure with try_number 2,
-        # scheduler doesn't put task in queue
+
+        self.assertEqual(ti.state, State.SCHEDULED)
+
+        # To verify that task does get re-queued.
+        executor.queued_tasks.clear()
+        executor.do_update = True
         do_schedule()
-        self.assertEquals(1, len(executor.queued_tasks))
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.RUNNING)
 
     @unittest.skipUnless("INTEGRATION" in os.environ, "Can only run end to end")
     def test_retry_handling_job(self):
@@ -3023,8 +3078,8 @@ def test_scheduler_run_duration(self):
         logging.info("Test ran in %.2fs, expected %.2fs",
                      run_duration,
                      expected_run_duration)
-        # 5s to wait for child process to exit and 1s dummy sleep
-        # in scheduler loop to prevent excessive logs.
+        # 5s to wait for child process to exit, 1s dummy sleep
+        # in scheduler loop to prevent excessive logs and 1s for last loop to finish.
         self.assertLess(run_duration - expected_run_duration, 6.0)
 
     def test_dag_with_system_exit(self):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to