KevinYang21 closed pull request #4216: [WIP][AIRFLOW-2761] Parallelize enqueue in celery executor
URL: https://github.com/apache/incubator-airflow/pull/4216
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 903c276339..934d97c19b 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -19,11 +19,12 @@
 from builtins import range

+# To avoid circular imports
+import airflow.utils.dag_processing
 from airflow import configuration
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State

-
 PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
@@ -50,11 +51,11 @@ def start(self):  # pragma: no cover
         """
         pass

-    def queue_command(self, task_instance, command, priority=1, queue=None):
-        key = task_instance.key
+    def queue_command(self, simple_task_instance, command, priority=1, queue=None):
+        key = simple_task_instance.key
         if key not in self.queued_tasks and key not in self.running:
             self.log.info("Adding to queue: %s", command)
-            self.queued_tasks[key] = (command, priority, queue, task_instance)
+            self.queued_tasks[key] = (command, priority, queue, simple_task_instance)
         else:
             self.log.info("could not queue task {}".format(key))
@@ -86,7 +87,7 @@ def queue_task_instance(
             pickle_id=pickle_id,
             cfg_path=cfg_path)
         self.queue_command(
-            task_instance,
+            airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
             command,
             priority=task_instance.task.priority_weight_total,
             queue=task_instance.task.queue)
@@ -124,26 +125,13 @@ def heartbeat(self):
             key=lambda x: x[1][1],
             reverse=True)
         for i in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, _, queue, ti) = sorted_queue.pop(0)
-            # TODO(jlowin) without a way to know what Job ran which tasks,
-            # there is a danger that another Job started running a task
-            # that was also queued to this executor. This is the last chance
-            # to check if that happened. The most probable way is that a
-            # Scheduler tried to run a task that was originally queued by a
-            # Backfill. This fix reduces the probability of a collision but
-            # does NOT eliminate it.
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
             self.queued_tasks.pop(key)
-            ti.refresh_from_db()
-            if ti.state != State.RUNNING:
-                self.running[key] = command
-                self.execute_async(key=key,
-                                   command=command,
-                                   queue=queue,
-                                   executor_config=ti.executor_config)
-            else:
-                self.log.info(
-                    'Task is already running, not sending to '
-                    'executor: {}'.format(key))
+            self.running[key] = command
+            self.execute_async(key=key,
+                               command=command,
+                               queue=queue,
+                               executor_config=simple_ti.executor_config)

         # Calling child class sync method
         self.log.debug("Calling the %s sync method", self.__class__)
@@ -151,7 +139,7 @@ def heartbeat(self):

     def change_state(self, key, state):
         self.log.debug("Changing state: {}".format(key))
-        self.running.pop(key)
+        self.running.pop(key, None)
         self.event_buffer[key] = state

     def fail(self, key):
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 0de48b4d39..0e71778ecc 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -33,10 +33,13 @@
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.timeout import timeout

 # Make it constant for unit test.
 CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

+CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
+
 '''
 To start the celery worker, run the command:
 airflow worker
@@ -55,12 +58,12 @@
 @app.task
-def execute_command(command):
+def execute_command(command_to_exec):
     log = LoggingMixin().log
-    log.info("Executing command in Celery: %s", command)
+    log.info("Executing command in Celery: %s", command_to_exec)
     env = os.environ.copy()
     try:
-        subprocess.check_call(command, stderr=subprocess.STDOUT,
+        subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT,
                               close_fds=True, env=env)
     except subprocess.CalledProcessError as e:
         log.exception('execute_command encountered a CalledProcessError')
@@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task):
     """
     try:
-        # Accessing state property of celery task will make actual network request
-        # to get the current state of the task.
-        res = (celery_task[0], celery_task[1].state)
+        with timeout(seconds=2):
+            # Accessing state property of celery task will make actual network request
+            # to get the current state of the task.
+            res = (celery_task[0], celery_task[1].state)
     except Exception as e:
         exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
                                                               traceback.format_exc())
@@ -105,6 +109,19 @@
     return res


+def send_task_to_executor(task_tuple):
+    key, simple_ti, command, queue, task = task_tuple
+    try:
+        with timeout(seconds=2):
+            result = task.apply_async(args=[command], queue=queue)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(key,
+                                                              traceback.format_exc())
+        result = ExceptionWithTraceback(e, exception_traceback)
+
+    return key, command, result
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow.
     It allows
@@ -135,16 +152,16 @@ def start(self):
             'Starting Celery Executor using {} processes for syncing'.format(
                 self._sync_parallelism))

-    def execute_async(self, key, command,
-                      queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
-                      executor_config=None):
-        self.log.info("[celery] queuing {key} through celery, "
-                      "queue={queue}".format(**locals()))
-        self.tasks[key] = execute_command.apply_async(
-            args=[command], queue=queue)
-        self.last_state[key] = celery_states.PENDING
+    def _num_tasks_per_send_process(self, to_send_count):
+        """
+        How many Celery tasks should each worker process send.
+        :return: Number of tasks that should be sent per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))

-    def _num_tasks_per_process(self):
+    def _num_tasks_per_fetch_process(self):
         """
         How many Celery tasks should be sent to each worker process.
         :return: Number of tasks that should be used per process
@@ -153,6 +170,71 @@ def _num_tasks_per_process(self):
         return max(1,
                    int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))

+    def heartbeat(self):
+        # Triggering new jobs
+        if not self.parallelism:
+            open_slots = len(self.queued_tasks)
+        else:
+            open_slots = self.parallelism - len(self.running)
+
+        self.log.debug("{} running task instances".format(len(self.running)))
+        self.log.debug("{} in queue".format(len(self.queued_tasks)))
+        self.log.debug("{} open slots".format(open_slots))
+
+        sorted_queue = sorted(
+            [(k, v) for k, v in self.queued_tasks.items()],
+            key=lambda x: x[1][1],
+            reverse=True)
+
+        task_tuples_to_send = []
+
+        for i in range(min((open_slots, len(self.queued_tasks)))):
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
+            task_tuples_to_send.append((key, simple_ti, command, queue,
+                                        execute_command))
+
+        cached_celery_backend = None
+        if task_tuples_to_send:
+            tasks = [t[4] for t in task_tuples_to_send]
+
+            # Celery state queries will stuck if we do not use one same backend
+            # for all tasks.
+            cached_celery_backend = tasks[0].backend
+
+        if task_tuples_to_send:
+            # Use chunking instead of a work queue to reduce context switching
+            # since tasks are roughly uniform in size
+            chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+            self.log.debug('Sent all tasks.')
+
+            for key, command, result in key_and_async_results:
+                if isinstance(result, ExceptionWithTraceback):
+                    self.log.error(
+                        CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
+                            result.exception, result.traceback))
+                elif result is not None:
+                    # Only pops when enqueued successfully, otherwise keep it
+                    # and expect scheduler loop to deal with it.
+                    self.queued_tasks.pop(key)
+                    result.backend = cached_celery_backend
+                    self.running[key] = command
+                    self.tasks[key] = result
+                    self.last_state[key] = celery_states.PENDING
+
+        # Calling child class sync method
+        self.log.debug("Calling the {} sync method".format(self.__class__))
+        self.sync()
+
     def sync(self):
         num_processes = min(len(self.tasks), self._sync_parallelism)
         if num_processes == 0:
@@ -167,7 +249,7 @@ def sync(self):

         # Use chunking instead of a work queue to reduce context switching since tasks are
         # roughly uniform in size
-        chunksize = self._num_tasks_per_process()
+        chunksize = self._num_tasks_per_fetch_process()

         self.log.debug("Waiting for inquiries to complete...")
         task_keys_to_states = self._sync_pool.map(
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 9e68fad797..ca124bf1f1 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -52,6 +52,7 @@
                                           DagFileProcessorAgent,
                                           SimpleDag,
                                           SimpleDagBag,
+                                          SimpleTaskInstance,
                                           list_py_file_paths)
 from airflow.utils.db import create_session, provide_session
 from airflow.utils.email import get_email_address_list, send_email
@@ -598,6 +599,7 @@ def __init__(
                 'run_duration')

         self.processor_agent = None
+        self._last_loop = False

         signal.signal(signal.SIGINT, self._exit_gracefully)
         signal.signal(signal.SIGTERM, self._exit_gracefully)
@@ -1228,13 +1230,13 @@ def _change_state_for_executable_task_instances(self, task_instances,
                                                     acceptable_states, session=None):
         """
         Changes the state of task instances in the list with one of the given states
-        to QUEUED atomically, and returns the TIs changed.
+        to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format.

         :param task_instances: TaskInstances to change the state of
         :type task_instances: List[TaskInstance]
         :param acceptable_states: Filters the TaskInstances updated to be in these states
         :type acceptable_states: Iterable[State]
-        :return: List[TaskInstance]
+        :return: List[SimpleTaskInstance]
         """
         if len(task_instances) == 0:
             session.commit()
@@ -1276,81 +1278,57 @@ def _change_state_for_executable_task_instances(self, task_instances,
                                          else task_instance.queued_dttm)
             session.merge(task_instance)

-        # save which TIs we set before session expires them
-        filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id,
-                                       TI.task_id == ti.task_id,
-                                       TI.execution_date == ti.execution_date)
-                                  for ti in tis_to_set_to_queued])
-        session.commit()
-
-        # requery in batches since above was expired by commit
+        # Generate a list of SimpleTaskInstance for the use of queuing
+        # them in the executor.
+        simple_task_instances = [SimpleTaskInstance(ti) for ti in
+                                 tis_to_set_to_queued]

-        def query(result, items):
-            tis_to_be_queued = (
-                session
-                .query(TI)
-                .filter(or_(*items))
-                .all())
-            task_instance_str = "\n\t".join(
-                ["{}".format(x) for x in tis_to_be_queued])
-            self.log.info("Setting the following {} tasks to queued state:\n\t{}"
-                          .format(len(tis_to_be_queued),
-                                  task_instance_str))
-            return result + tis_to_be_queued
-
-        tis_to_be_queued = helpers.reduce_in_chunks(query,
-                                                    filter_for_ti_enqueue,
-                                                    [],
-                                                    self.max_tis_per_query)
+        task_instance_str = "\n\t".join(
+            ["{}".format(x) for x in tis_to_set_to_queued])

-        return tis_to_be_queued
+        session.commit()
+        self.logger.info("Setting the following {} tasks to queued state:\n\t{}"
+                         .format(len(tis_to_set_to_queued), task_instance_str))
+        return simple_task_instances

-    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances):
+    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag,
+                                                  simple_task_instances):
         """
         Takes task_instances, which should have been set to queued, and enqueues them
         with the executor.

-        :param task_instances: TaskInstances to enqueue
-        :type task_instances: List[TaskInstance]
+        :param simple_task_instances: TaskInstances to enqueue
+        :type simple_task_instances: List[SimpleTaskInstance]
         :param simple_dag_bag: Should contains all of the task_instances' dags
         :type simple_dag_bag: SimpleDagBag
         """
         TI = models.TaskInstance
         # actually enqueue them
-        for task_instance in task_instances:
-            simple_dag = simple_dag_bag.get_dag(task_instance.dag_id)
+        for simple_task_instance in simple_task_instances:
+            simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id)
             command = TI.generate_command(
-                task_instance.dag_id,
-                task_instance.task_id,
-                task_instance.execution_date,
+                simple_task_instance.dag_id,
+                simple_task_instance.task_id,
+                simple_task_instance.execution_date,
                 local=True,
                 mark_success=False,
                 ignore_all_deps=False,
                 ignore_depends_on_past=False,
                 ignore_task_deps=False,
                 ignore_ti_state=False,
-                pool=task_instance.pool,
+                pool=simple_task_instance.pool,
                 file_path=simple_dag.full_filepath,
                 pickle_id=simple_dag.pickle_id)

-            priority = task_instance.priority_weight
-            queue = task_instance.queue
+            priority = simple_task_instance.priority_weight
+            queue = simple_task_instance.queue
             self.log.info(
                 "Sending %s to executor with priority %s and queue %s",
-                task_instance.key, priority, queue
+                simple_task_instance.key, priority, queue
             )

-            # save attributes so sqlalchemy doesnt expire them
-            copy_dag_id = task_instance.dag_id
-            copy_task_id = task_instance.task_id
-            copy_execution_date = task_instance.execution_date
-            make_transient(task_instance)
-            task_instance.dag_id = copy_dag_id
-            task_instance.task_id = copy_task_id
-            task_instance.execution_date = copy_execution_date
-
             self.executor.queue_command(
-                task_instance,
+                simple_task_instance,
                 command,
                 priority=priority,
                 queue=queue)
@@ -1374,24 +1352,65 @@ def _execute_task_instances(self,
         :type simple_dag_bag: SimpleDagBag
         :param states: Execute TaskInstances in these states
         :type states: Tuple[State]
-        :return: None
+        :return: Number of task instance with state changed.
""" executable_tis = self._find_executable_task_instances(simple_dag_bag, states, session=session) def query(result, items): - tis_with_state_changed = self._change_state_for_executable_task_instances( - items, - states, - session=session) + simple_tis_with_state_changed = \ + self._change_state_for_executable_task_instances(items, + states, + session=session) self._enqueue_task_instances_with_queued_state( simple_dag_bag, - tis_with_state_changed) + simple_tis_with_state_changed) session.commit() - return result + len(tis_with_state_changed) + return result + len(simple_tis_with_state_changed) return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) + @provide_session + def _change_state_for_tasks_failed_to_execute(self, session): + """ + If there are tasks left over in the executor, + we set them back to SCHEDULED to avoid creating hanging tasks. + :param session: + :return: + """ + if self.executor.queued_tasks: + TI = models.TaskInstance + filter_for_ti_state_change = ( + [and_( + TI.dag_id == dag_id, + TI.task_id == task_id, + TI.execution_date == execution_date, + # The TI.try_number will return raw try_number+1 since the + # ti is not running. And we need to -1 to match the DB record. + TI._try_number == try_number test_change_state_for_tasks_failed_to_execute- 1, + TI.state == State.QUEUED) + for dag_id, task_id, execution_date, try_number + in self.executor.queued_tasks.keys()]) + ti_query = (session.query(TI) + .filter(or_(*filter_for_ti_state_change))) + tis_to_set_to_scheduled = (ti_query + .with_for_update() + .all()) + if len(tis_to_set_to_scheduled) == 0: + session.commit() + return + + # set TIs to queued state + for task_instance in tis_to_set_to_scheduled: + task_instance.state = State.SCHEDULED + + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_set_to_scheduled]) + + session.commit() + self.logger.info("Set the follow tasks to scheduled state:\n\t{}" + .format(task_instance_str)) + def _process_dags(self, dagbag, dags, tis_out): """ Iterates over the dags and processes them. Processing includes: @@ -1507,6 +1526,8 @@ def processor_factory(file_path, zombies): try: self._execute_helper() + except Exception: + self.log.exception("Exception when executing execute_helper") finally: self.processor_agent.end() self.log.info("Exited execute loop") @@ -1557,6 +1578,7 @@ def _execute_helper(self): self.log.info("Harvesting DAG parsing results") simple_dags = self.processor_agent.harvest_simple_dags() + self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags))) # Send tasks for execution if available simple_dag_bag = SimpleDagBag(simple_dags) @@ -1593,6 +1615,8 @@ def _execute_helper(self): self.log.debug("Heartbeating the executor") self.executor.heartbeat() + self._change_state_for_tasks_failed_to_execute() + # Process events from the executor self._process_executor_events(simple_dag_bag) @@ -1612,8 +1636,13 @@ def _execute_helper(self): self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval) time.sleep(self._processor_poll_interval) - # Exit early for a test mode + # Exit early for a test mode, run one additional scheduler loop + # to reduce the possibility that parsed DAG was put into the queue + # by the DAG manager but not yet received by DAG agent. 
             if self.processor_agent.done:
+                self._last_loop = True
+
+            if self._last_loop:
                 self.log.info("Exiting scheduler loop as all files"
                               " have been processed {} times".format(self.num_runs))
                 break
diff --git a/airflow/models.py b/airflow/models.py
index 9ab2348cc2..8870a2921e 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -22,12 +22,12 @@
 from __future__ import print_function
 from __future__ import unicode_literals

+import copy
+from collections import defaultdict, namedtuple
+from builtins import ImportError as BuiltinImportError, bytes, object, str
 from future.standard_library import install_aliases

-from builtins import str, object, bytes, ImportError as BuiltinImportError
-import copy
-from collections import namedtuple, defaultdict

 try:
     # Fix Python > 3.7 deprecation
     from collections.abc import Hashable
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index 47f473e9aa..62c4c91968 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -146,6 +146,21 @@ def __init__(self, ti):
         self._end_date = ti.end_date
         self._try_number = ti.try_number
         self._state = ti.state
+        self._executor_config = ti.executor_config
+        if hasattr(ti, 'run_as_user'):
+            self._run_as_user = ti.run_as_user
+        else:
+            self._run_as_user = None
+        if hasattr(ti, 'pool'):
+            self._pool = ti.pool
+        else:
+            self._pool = None
+        if hasattr(ti, 'priority_weight'):
+            self._priority_weight = ti.priority_weight
+        else:
+            self._priority_weight = None
+        self._queue = ti.queue
+        self._key = ti.key

     @property
     def dag_id(self):
@@ -175,6 +190,48 @@ def try_number(self):
     def state(self):
         return self._state

+    @property
+    def pool(self):
+        return self._pool
+
+    @property
+    def priority_weight(self):
+        return self._priority_weight
+
+    @property
+    def queue(self):
+        return self._queue
+
+    @property
+    def key(self):
+        return self._key
+
+    @property
+    def executor_config(self):
+        return self._executor_config
+
+    @provide_session
+    def construct_task_instance(self, session=None, lock_for_update=False):
+        """
+        Construct a TaskInstance from the database based on the primary key
+        :param session: DB session.
+        :param lock_for_update: if True, indicates that the database should
+            lock the TaskInstance (issuing a FOR UPDATE clause) until the
+            session is committed.
+        """
+        TI = airflow.models.TaskInstance
+
+        qry = session.query(TI).filter(
+            TI.dag_id == self._dag_id,
+            TI.task_id == self._task_id,
+            TI.execution_date == self._execution_date)
+
+        if lock_for_update:
+            ti = qry.with_for_update().first()
+        else:
+            ti = qry.first()
+        return ti
+

 class SimpleDagBag(BaseDagBag):
     """
@@ -566,11 +623,16 @@ def end(self):
         Terminate (and then kill) the manager process launched.
         :return:
         """
-        if not self._process or not self._process.is_alive():
+        if not self._process:
             self.log.warn('Ending without manager process.')
             return
         this_process = psutil.Process(os.getpid())
-        manager_process = psutil.Process(self._process.pid)
+        try:
+            manager_process = psutil.Process(self._process.pid)
+        except psutil.NoSuchProcess:
+            self.log.info("Manager process not running.")
+            return
+
         # First try SIGTERM
         if manager_process.is_running() \
                 and manager_process.pid in [x.pid for x in this_process.children()]:
diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py
index a86b9d357b..f648005877 100644
--- a/airflow/utils/timeout.py
+++ b/airflow/utils/timeout.py
@@ -23,6 +23,7 @@
 from __future__ import unicode_literals

 import signal
+import os

 from airflow.exceptions import AirflowTaskTimeout
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -35,10 +36,10 @@ class timeout(LoggingMixin):

     def __init__(self, seconds=1, error_message='Timeout'):
         self.seconds = seconds
-        self.error_message = error_message
+        self.error_message = error_message + ', PID: ' + str(os.getpid())

     def handle_timeout(self, signum, frame):
-        self.log.error("Process timed out")
+        self.log.error("Process timed out, PID: " + str(os.getpid()))
         raise AirflowTaskTimeout(self.error_message)

     def __enter__(self):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 380201d30a..5a7d6e984c 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -18,12 +18,17 @@
 # under the License.
 import sys
 import unittest
+from multiprocessing import Pool
+
 import mock
 from celery.contrib.testing.worker import start_worker

-from airflow.executors.celery_executor import CeleryExecutor
-from airflow.executors.celery_executor import app
+from airflow.executors import celery_executor
 from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
+from airflow.executors.celery_executor import (CeleryExecutor, celery_configuration,
+                                               send_task_to_executor, execute_command)
+from airflow.executors.celery_executor import app
+from celery import states as celery_states
 from airflow.utils.state import State
 from airflow import configuration
@@ -40,16 +45,37 @@ def test_celery_integration(self):
         executor = CeleryExecutor()
         executor.start()
         with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
-
             success_command = ['true', 'some_parameter']
             fail_command = ['false', 'some_parameter']
-            executor.execute_async(key='success', command=success_command)
-            # errors are propagated for some reason
-            try:
-                executor.execute_async(key='fail', command=fail_command)
-            except Exception:
-                pass
+            cached_celery_backend = execute_command.backend
+            task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command),
+                                   ('fail', 'fake_simple_ti', fail_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command)]
+
+            chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), executor._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+
+            for key, command, result in key_and_async_results:
+                # Only pops when enqueued successfully, otherwise keep it
+                # and expect scheduler loop to deal with it.
+                result.backend = cached_celery_backend
+                executor.running[key] = command
+                executor.tasks[key] = result
+                executor.last_state[key] = celery_states.PENDING
+
             executor.running['success'] = True
             executor.running['fail'] = True
@@ -64,6 +90,21 @@ def test_celery_integration(self):
         self.assertNotIn('success', executor.last_state)
         self.assertNotIn('fail', executor.last_state)

+    def test_error_sending_task(self):
+        @app.task
+        def fake_execute_command():
+            pass
+
+        # fake_execute_command takes no arguments while execute_command takes 1,
+        # which will cause TypeError when calling task.apply_async()
+        celery_executor.execute_command = fake_execute_command
+        executor = CeleryExecutor()
+        value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti'
+        executor.queued_tasks['key'] = value_tuple
+        executor.heartbeat()
+        self.assertEquals(1, len(executor.queued_tasks))
+        self.assertEquals(executor.queued_tasks['key'], value_tuple)
+
     def test_exception_propagation(self):
         @app.task
         def fake_celery_task():
diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py
index aab66644b8..366ea8c967 100644
--- a/tests/executors/test_executor.py
+++ b/tests/executors/test_executor.py
@@ -46,7 +46,8 @@ def heartbeat(self):
                 ti = self._running.pop()
                 ti.set_state(State.SUCCESS, session)
             for key, val in list(self.queued_tasks.items()):
-                (command, priority, queue, ti) = val
+                (command, priority, queue, simple_ti) = val
+                ti = simple_ti.construct_task_instance()
                 ti.set_state(State.RUNNING, session)
                 self._running.append(ti)
                 self.queued_tasks.pop(key)
diff --git a/tests/test_jobs.py b/tests/test_jobs.py
index af8ccc6c2e..fb161c1b69 100644
--- a/tests/test_jobs.py
+++ b/tests/test_jobs.py
@@ -2032,6 +2032,54 @@ def test_change_state_for_tis_without_dagrun(self):
         ti2.refresh_from_db(session=session)
         self.assertEqual(ti2.state, State.SCHEDULED)

+    def test_change_state_for_tasks_failed_to_execute(self):
+        dag = DAG(
+            dag_id='dag_id',
+            start_date=DEFAULT_DATE)
+
+        task = DummyOperator(
+            task_id='task_id',
+            dag=dag,
+            owner='airflow')
+
+        # If there's no left over task in executor.queued_tasks, nothing happens
+        session = settings.Session()
+        scheduler_job = SchedulerJob()
+        mock_logger = mock.MagicMock()
+        test_executor = TestExecutor()
+        scheduler_job.executor = test_executor
+        scheduler_job._logger = mock_logger
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+        mock_logger.info.assert_not_called()
+
+        # Tasks failed to execute with QUEUED state will be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        key = 'dag_id', 'task_id', DEFAULT_DATE, 1
+        test_executor.queued_tasks[key] = 'value'
+        ti = TI(task, DEFAULT_DATE)
+        ti.state = State.QUEUED
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.SCHEDULED, ti.state)
+
+        # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        ti.state = State.RUNNING
+
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.RUNNING, ti.state)
+
     def test_execute_helper_reset_orphaned_tasks(self):
         session = settings.Session()
         dag = DAG(
@@ -2949,7 +2997,8 @@ def run_with_error(task):
                 pass

         ti_tuple = six.next(six.itervalues(executor.queued_tasks))
-        (command, priority, queue, ti) = ti_tuple
+        (command, priority, queue, simple_ti) = ti_tuple
+        ti = simple_ti.construct_task_instance()
         ti.task = dag_task1

         self.assertEqual(ti.try_number, 1)
@@ -2970,15 +3019,21 @@ def run_with_error(task):
         # removing self.assertEqual(ti.state, State.SCHEDULED)
         # as scheduler will move state from SCHEDULED to QUEUED
-        # now the executor has cleared and it should be allowed the re-queue
+        # now the executor has cleared and it should be allowed the re-queue,
+        # but tasks stay in the executor.queued_tasks after executor.heartbeat()
+        # will be set back to SCHEDULED state
         executor.queued_tasks.clear()
         do_schedule()
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.QUEUED)
-        # calling below again in order to ensure with try_number 2,
-        # scheduler doesn't put task in queue
+
+        self.assertEqual(ti.state, State.SCHEDULED)
+
+        # To verify that task does get re-queued.
+        executor.queued_tasks.clear()
+        executor.do_update = True
         do_schedule()
-        self.assertEquals(1, len(executor.queued_tasks))
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.RUNNING)

     @unittest.skipUnless("INTEGRATION" in os.environ,
                          "Can only run end to end")
     def test_retry_handling_job(self):
@@ -3023,8 +3078,8 @@ def test_scheduler_run_duration(self):
         logging.info("Test ran in %.2fs, expected %.2fs",
                      run_duration, expected_run_duration)

-        # 5s to wait for child process to exit and 1s dummy sleep
-        # in scheduler loop to prevent excessive logs.
+        # 5s to wait for child process to exit, 1s dummy sleep
+        # in scheduler loop to prevent excessive logs and 1s for last loop to finish.
         self.assertLess(run_duration - expected_run_duration, 6.0)

     def test_dag_with_system_exit(self):

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
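
The central change in celery_executor.py above is that queued commands are no longer sent to Celery one at a time from the scheduler loop; they are batched and dispatched from a multiprocessing pool, chunked so each worker process gets a roughly equal share. Below is a minimal, self-contained sketch of that send-pool pattern. It assumes only the Python standard library, and the helper names (send_one, send_all) are illustrative, not taken from the PR.

    # Sketch of the chunked send-pool pattern used by the new
    # CeleryExecutor.heartbeat(). Illustrative only: send_one() stands in
    # for send_task_to_executor(), and the task tuples are fake.
    import math
    from multiprocessing import Pool


    def send_one(task_tuple):
        # Return the key plus either a result or the exception that occurred,
        # mirroring how send_task_to_executor() reports failures back.
        key, command = task_tuple
        try:
            result = 'async-result-for-{}'.format(key)  # stand-in for apply_async()
        except Exception as e:
            result = e
        return key, result


    def send_all(task_tuples, parallelism=4):
        if not task_tuples:
            return []
        # Chunk the work instead of using a shared queue: the items are
        # roughly uniform in size, so chunking reduces context switching.
        chunksize = max(1, int(math.ceil(1.0 * len(task_tuples) / parallelism)))
        pool = Pool(processes=min(len(task_tuples), parallelism))
        try:
            return pool.map(send_one, task_tuples, chunksize=chunksize)
        finally:
            pool.close()
            pool.join()


    if __name__ == '__main__':
        print(send_all([('key-1', ['true']), ('key-2', ['false'])]))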
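
Both send_task_to_executor and fetch_celery_task_state wrap their Celery calls in airflow.utils.timeout.timeout(seconds=2) so a hung broker or result backend cannot stall a pool worker indefinitely. That utility is a SIGALRM-based context manager; the simplified sketch below shows the idea. It only works in the main thread on Unix, and TimeoutGuardError here is a stand-in, not Airflow's AirflowTaskTimeout.

    # Simplified SIGALRM-based timeout, similar in spirit to
    # airflow.utils.timeout.timeout.
    import signal


    class TimeoutGuardError(Exception):
        pass


    class timeout(object):
        def __init__(self, seconds=1, error_message='Timeout'):
            self.seconds = seconds
            self.error_message = error_message

        def handle_timeout(self, signum, frame):
            raise TimeoutGuardError(self.error_message)

        def __enter__(self):
            signal.signal(signal.SIGALRM, self.handle_timeout)
            signal.alarm(self.seconds)

        def __exit__(self, exc_type, exc_value, tb):
            signal.alarm(0)  # cancel the pending alarm on the way out


    # Usage, mirroring the guarded Celery calls in the diff:
    try:
        with timeout(seconds=2):
            pass  # e.g. task.apply_async(...) or celery_async_result.state
    except TimeoutGuardError:
        print('Celery call did not return within 2 seconds')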
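
The reason queue_command now receives a SimpleTaskInstance is that the task metadata has to cross process boundaries (the send pool) and therefore must not be a session-bound SQLAlchemy object. A rough sketch of that "lightweight snapshot" idea follows; TaskSnapshot and snapshot_from_orm are hypothetical names, not the PR's class.

    # Copy the fields the executor needs into a plain, pickleable record
    # while the DB session is still open; re-query the row by primary key
    # later if the full ORM object is needed (as construct_task_instance()
    # does in the PR).
    from collections import namedtuple

    TaskSnapshot = namedtuple(
        'TaskSnapshot',
        ['dag_id', 'task_id', 'execution_date', 'queue',
         'priority_weight', 'executor_config'])


    def snapshot_from_orm(ti):
        return TaskSnapshot(ti.dag_id, ti.task_id, ti.execution_date,
                            ti.queue, ti.priority_weight, ti.executor_config)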
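
Because the base heartbeat no longer re-checks the database before dispatch, the scheduler gains _change_state_for_tasks_failed_to_execute(): anything still sitting in executor.queued_tasks after a heartbeat was never handed to Celery, so its QUEUED database state is flipped back to SCHEDULED for the next loop to retry. A toy, in-memory illustration of that reconciliation (not the SQLAlchemy query from the PR):

    # Keys still present in the executor's queued_tasks after a heartbeat
    # were never sent, so hand them back to the scheduler.
    def requeue_unsent(queued_tasks, db_states):
        for key in queued_tasks:
            if db_states.get(key) == 'queued':
                db_states[key] = 'scheduled'
        return db_states


    key = ('dag_id', 'task_id', '2018-01-01T00:00:00', 1)
    print(requeue_unsent({key: 'command'}, {key: 'queued'}))
    # {('dag_id', 'task_id', '2018-01-01T00:00:00', 1): 'scheduled'}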