[ https://issues.apache.org/jira/browse/AIRFLOW-2761?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16617096#comment-16617096 ]
ASF GitHub Bot commented on AIRFLOW-2761: ----------------------------------------- yrqls21 closed pull request #3910: [WIP][AIRFLOW-2761] Parallelize enqueue in celery executor URL: https://github.com/apache/incubator-airflow/pull/3910 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/UPDATING.md b/UPDATING.md index 4fd45b2c0e..124da34690 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -31,6 +31,11 @@ some bugs. The new `sync_parallelism` config option will control how many processes CeleryExecutor will use to fetch celery task state in parallel. Default value is max(1, number of cores - 1) +### New `log_processor_manager_location` config option + +The DAG parsing manager log now by default will be log into a file, where its location is +controlled by the new `log_processor_manager_location` config option in core section. + ## Airflow 1.10 Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index 95150ab3bb..634b21747e 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -20,6 +20,7 @@ import os from airflow import configuration as conf +from airflow.utils.file import mkdirs # TODO: Logging format and level should be configured # in this file instead of from airflow.cfg. 
Currently @@ -38,7 +39,10 @@ PROCESSOR_LOG_FOLDER = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') +LOG_PROCESSOR_MANAGER_LOCATION = conf.get('core', 'LOG_PROCESSOR_MANAGER_LOCATION') + FILENAME_TEMPLATE = conf.get('core', 'LOG_FILENAME_TEMPLATE') + PROCESSOR_FILENAME_TEMPLATE = conf.get('core', 'LOG_PROCESSOR_FILENAME_TEMPLATE') # Storage bucket url for remote logging @@ -79,7 +83,7 @@ 'formatter': 'airflow', 'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER), 'filename_template': PROCESSOR_FILENAME_TEMPLATE, - }, + } }, 'loggers': { 'airflow.processor': { @@ -104,6 +108,26 @@ } } +DEFAULT_DAG_PARSING_LOGGING_CONFIG = { + 'handlers': { + 'processor_manager': { + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'airflow', + 'filename': LOG_PROCESSOR_MANAGER_LOCATION, + 'mode': 'a', + 'maxBytes': 104857600, # 100MB + 'backupCount': 5 + } + }, + 'loggers': { + 'airflow.processor_manager': { + 'handlers': ['processor_manager'], + 'level': LOG_LEVEL, + 'propagate': False, + } + } +} + REMOTE_HANDLERS = { 's3': { 'task': { @@ -172,6 +196,20 @@ REMOTE_LOGGING = conf.get('core', 'remote_logging') +if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True': + DEFAULT_LOGGING_CONFIG['handlers'] \ + .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) + DEFAULT_LOGGING_CONFIG['loggers'] \ + .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) + + # Manually create log directory for processor_manager handler as RotatingFileHandler + # will only create file but not the directory. 
+ processor_manager_handler_config = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ + 'processor_manager'] + directory = os.path.dirname(processor_manager_handler_config['filename']) + if not os.path.exists(directory): + mkdirs(directory, 0o777) + if REMOTE_LOGGING and REMOTE_BASE_LOG_FOLDER.startswith('s3://'): DEFAULT_LOGGING_CONFIG['handlers'].update(REMOTE_HANDLERS['s3']) elif REMOTE_LOGGING and REMOTE_BASE_LOG_FOLDER.startswith('gs://'): diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 000dd67a13..367527b1c1 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -70,6 +70,7 @@ simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s # we need to escape the curly braces by adding an additional curly brace log_filename_template = {{{{ ti.dag_id }}}}/{{{{ ti.task_id }}}}/{{{{ ts }}}}/{{{{ try_number }}}}.log log_processor_filename_template = {{{{ filename }}}}.log +log_processor_manager_location = {AIRFLOW_HOME}/logs/dag_processor_manager/dag_processor_manager.log # Hostname by providing a path to a callable, which will resolve the hostname hostname_callable = socket:getfqdn diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index f9279cce54..7f8f350970 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -39,6 +39,7 @@ logging_level = INFO fab_logging_level = WARN log_filename_template = {{{{ ti.dag_id }}}}/{{{{ ti.task_id }}}}/{{{{ ts }}}}/{{{{ try_number }}}}.log log_processor_filename_template = {{{{ filename }}}}.log +log_processor_manager_location = {AIRFLOW_HOME}/logs/dag_processor_manager/dag_processor_manager.log executor = SequentialExecutor sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db load_examples = True diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index a989dc4408..3564c8c4ef 
100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -19,11 +19,12 @@ from builtins import range +# To avoid circular imports +import airflow.utils.dag_processing from airflow import configuration from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State - PARALLELISM = configuration.conf.getint('core', 'PARALLELISM') @@ -50,11 +51,11 @@ def start(self): # pragma: no cover """ pass - def queue_command(self, task_instance, command, priority=1, queue=None): - key = task_instance.key + def queue_command(self, simple_task_instance, command, priority=1, queue=None): + key = simple_task_instance.key if key not in self.queued_tasks and key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[key] = (command, priority, queue, task_instance) + self.queued_tasks[key] = (command, priority, queue, simple_task_instance) else: self.log.info("could not queue task {}".format(key)) @@ -86,7 +87,7 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - task_instance, + airflow.utils.dag_processing.SimpleTaskInstance(task_instance), command, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) @@ -124,26 +125,13 @@ def heartbeat(self): key=lambda x: x[1][1], reverse=True) for i in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, queue, ti) = sorted_queue.pop(0) - # TODO(jlowin) without a way to know what Job ran which tasks, - # there is a danger that another Job started running a task - # that was also queued to this executor. This is the last chance - # to check if that happened. The most probable way is that a - # Scheduler tried to run a task that was originally queued by a - # Backfill. This fix reduces the probability of a collision but - # does NOT eliminate it. 
+ key, (command, _, queue, simple_ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) - ti.refresh_from_db() - if ti.state != State.RUNNING: - self.running[key] = command - self.execute_async(key=key, - command=command, - queue=queue, - executor_config=ti.executor_config) - else: - self.logger.info( - 'Task is already running, not sending to ' - 'executor: {}'.format(key)) + self.running[key] = command + self.execute_async(key=key, + command=command, + queue=queue, + executor_config=simple_ti.executor_config) # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) @@ -151,7 +139,7 @@ def heartbeat(self): def change_state(self, key, state): self.log.debug("Changing state: {}".format(key)) - self.running.pop(key) + self.running.pop(key, None) self.event_buffer[key] = state def fail(self, key): diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 0de48b4d39..b3f9ba0939 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -33,10 +33,13 @@ from airflow.executors.base_executor import BaseExecutor from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.timeout import timeout # Make it constant for unit test. 
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state' +CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task' + ''' To start the celery worker, run the command: airflow worker @@ -55,12 +58,12 @@ @app.task -def execute_command(command): +def execute_command(command_to_exec): log = LoggingMixin().log - log.info("Executing command in Celery: %s", command) + log.info("Executing command in Celery: %s", command_to_exec) env = os.environ.copy() try: - subprocess.check_call(command, stderr=subprocess.STDOUT, + subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env) except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') @@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task): """ try: - # Accessing state property of celery task will make actual network request - # to get the current state of the task. - res = (celery_task[0], celery_task[1].state) + with timeout(seconds=2): + # Accessing state property of celery task will make actual network request + # to get the current state of the task. + res = (celery_task[0], celery_task[1].state) except Exception as e: exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0], traceback.format_exc()) @@ -105,6 +109,19 @@ def fetch_celery_task_state(celery_task): return res +def send_task_to_executor(task_tuple): + key, simple_ti, command, queue, task = task_tuple + try: + with timeout(seconds=2): + result = task.apply_async(args=[command], queue=queue) + except Exception as e: + exception_traceback = "Celery Task ID: {}\n{}".format(key, + traceback.format_exc()) + result = ExceptionWithTraceback(e, exception_traceback) + + return key, command, result + + class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. 
It allows @@ -135,16 +152,16 @@ def start(self): 'Starting Celery Executor using {} processes for syncing'.format( self._sync_parallelism)) - def execute_async(self, key, command, - queue=DEFAULT_CELERY_CONFIG['task_default_queue'], - executor_config=None): - self.log.info("[celery] queuing {key} through celery, " - "queue={queue}".format(**locals())) - self.tasks[key] = execute_command.apply_async( - args=[command], queue=queue) - self.last_state[key] = celery_states.PENDING + def _num_tasks_per_send_process(self, to_send_count): + """ + How many Celery tasks should each worker process send. + :return: Number of tasks that should be sent per process + :rtype: int + """ + return max(1, + int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) - def _num_tasks_per_process(self): + def _num_tasks_per_fetch_process(self): """ How many Celery tasks should be sent to each worker process. :return: Number of tasks that should be used per process @@ -153,6 +170,82 @@ def _num_tasks_per_process(self): return max(1, int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism))) + def heartbeat(self): + # Triggering new jobs + if not self.parallelism: + open_slots = len(self.queued_tasks) + else: + open_slots = self.parallelism - len(self.running) + + self.logger.debug("{} running task instances".format(len(self.running))) + self.logger.debug("{} in queue".format(len(self.queued_tasks))) + self.logger.debug("{} open slots".format(open_slots)) + + sorted_queue = sorted( + [(k, v) for k, v in self.queued_tasks.items()], + key=lambda x: x[1][1], + reverse=True) + + task_tuples_to_send = [] + + for i in range(min((open_slots, len(self.queued_tasks)))): + key, (command, _, queue, simple_ti) = sorted_queue.pop(0) + task_tuples_to_send.append((key, simple_ti, command, queue, + execute_command)) + + cached_celery_backend = None + if task_tuples_to_send: + tasks = [t[4] for t in task_tuples_to_send] + if tasks[0].app is not None: + self.logger.debug( + 'before enqueue task app 
count: ' + str( + len(set([task.app for task in tasks]))) + ) + + self.logger.debug( + 'before enqueue task backend count: ' + str(len(set([task.backend + for task in + tasks])))) + + # Celery state queries will stuck if we do not use one same backend + # for all tasks. + cached_celery_backend = tasks[0].backend + + if task_tuples_to_send: + # Use chunking instead of a work queue to reduce context switching + # since tasks are roughly uniform in size + chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send)) + num_processes = min(len(task_tuples_to_send), self._sync_parallelism) + + send_pool = Pool(processes=num_processes) + self.logger.debug('send_pool created.') + key_and_async_results = send_pool.map( + send_task_to_executor, + task_tuples_to_send, + chunksize=chunksize) + + self.logger.debug('Sent all tasks.') + send_pool.close() + send_pool.join() + + for key, command, result in key_and_async_results: + if isinstance(result, ExceptionWithTraceback): + self.logger.error( + CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format( + result.exception, result.traceback)) + elif result is not None: + # Only pops when enqueued successfully, otherwise keep it + # and expect scheduler loop to deal with it. 
+ self.queued_tasks.pop(key) + result.backend = cached_celery_backend + self.running[key] = command + self.tasks[key] = result + self.last_state[key] = celery_states.PENDING + + # Calling child class sync method + self.logger.debug("Calling the {} sync method".format(self.__class__)) + self.sync() + def sync(self): num_processes = min(len(self.tasks), self._sync_parallelism) if num_processes == 0: @@ -167,7 +260,7 @@ def sync(self): # Use chunking instead of a work queue to reduce context switching since tasks are # roughly uniform in size - chunksize = self._num_tasks_per_process() + chunksize = self._num_tasks_per_fetch_process() self.log.debug("Waiting for inquiries to complete...") task_keys_to_states = self._sync_pool.map( diff --git a/airflow/jobs.py b/airflow/jobs.py index 916ec1f243..5f1d77c3a6 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -26,22 +26,19 @@ import logging import multiprocessing import os -import psutil import signal -import six import sys import threading import time -import datetime - from collections import defaultdict +from time import sleep + +import six from past.builtins import basestring from sqlalchemy import ( Column, Integer, String, func, Index, or_, and_, not_) from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from tabulate import tabulate -from time import sleep from airflow import configuration as conf from airflow import executors, models, settings @@ -53,16 +50,17 @@ from airflow.utils import asciiart, helpers, timezone from airflow.utils.configuration import tmp_configuration_copy from airflow.utils.dag_processing import (AbstractDagFileProcessor, - DagFileProcessorManager, + DagFileProcessorAgent, SimpleDag, SimpleDagBag, + SimpleTaskInstance, list_py_file_paths) from airflow.utils.db import create_session, provide_session from airflow.utils.email import send_email, get_email_address_list from airflow.utils.log.logging_mixin import LoggingMixin, set_context, 
StreamLogWriter from airflow.utils.net import get_hostname -from airflow.utils.state import State from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.state import State Base = models.Base ID_LEN = models.ID_LEN @@ -304,7 +302,7 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin): # Counter that increments everytime an instance of this class is created class_creation_counter = 0 - def __init__(self, file_path, pickle_dags, dag_id_white_list): + def __init__(self, file_path, pickle_dags, dag_id_white_list, zombies): """ :param file_path: a Python file containing Airflow DAG definitions :type file_path: unicode @@ -312,6 +310,8 @@ def __init__(self, file_path, pickle_dags, dag_id_white_list): :type pickle_dags: bool :param dag_id_whitelist: If specified, only look at these DAG ID's :type dag_id_whitelist: list[unicode] + :param zombies: zombie task instances to kill + :type zombies: SimpleTaskInstance """ self._file_path = file_path # Queue that's used to pass results from the child process. @@ -320,6 +320,7 @@ def __init__(self, file_path, pickle_dags, dag_id_white_list): self._process = None self._dag_id_white_list = dag_id_white_list self._pickle_dags = pickle_dags + self._zombies = zombies # The result of Scheduler.process_file(file_path). self._result = None # Whether the process is done running. @@ -340,7 +341,8 @@ def _launch_process(result_queue, file_path, pickle_dags, dag_id_white_list, - thread_name): + thread_name, + zombies): """ Launch a process to process the given file. 
@@ -358,6 +360,8 @@ def _launch_process(result_queue, :type thread_name: unicode :return: the process that was launched :rtype: multiprocessing.Process + :param zombies: zombie task instances to kill + :type zombies: SimpleTaskInstance """ def helper(): # This helper runs in the newly created process @@ -386,6 +390,7 @@ def helper(): os.getpid(), file_path) scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log) result = scheduler_job.process_file(file_path, + zombies, pickle_dags) result_queue.put(result) end_time = time.time() @@ -418,7 +423,8 @@ def start(self): self.file_path, self._pickle_dags, self._dag_id_white_list, - "DagFileProcessor{}".format(self._instance_id)) + "DagFileProcessor{}".format(self._instance_id), + self._zombies) self._start_time = timezone.utcnow() def terminate(self, sigkill=False): @@ -472,7 +478,8 @@ def done(self): if self._done: return True - if not self._result_queue.empty(): + # In case result queue is corrupted. + if self._result_queue and not self._result_queue.empty(): self._result = self._result_queue.get_nowait() self._done = True self.log.debug("Waiting for %s", self._process) @@ -480,7 +487,7 @@ def done(self): return True # Potential error case when process dies - if not self._process.is_alive(): + if self._result_queue and not self._process.is_alive(): self._done = True # Get the object from the queue or else join() can hang. if not self._result_queue.empty(): @@ -531,8 +538,6 @@ def __init__( dag_ids=None, subdir=settings.DAGS_FOLDER, num_runs=-1, - file_process_interval=conf.getint('scheduler', - 'min_file_process_interval'), processor_poll_interval=1.0, run_duration=None, do_pickle=False, @@ -579,26 +584,28 @@ def __init__( self.using_sqlite = False if 'sqlite' in conf.get('core', 'sql_alchemy_conn'): - if self.max_threads > 1: - self.log.error("Cannot use more than 1 thread when using sqlite. 
Setting max_threads to 1") - self.max_threads = 1 self.using_sqlite = True - # How often to scan the DAGs directory for new files. Default to 5 minutes. - self.dag_dir_list_interval = conf.getint('scheduler', - 'dag_dir_list_interval') - # How often to print out DAG file processing stats to the log. Default to - # 30 seconds. - self.print_stats_interval = conf.getint('scheduler', - 'print_stats_interval') - - self.file_process_interval = file_process_interval - self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query') if run_duration is None: self.run_duration = conf.getint('scheduler', 'run_duration') + self.processor_agent = None + self._last_loop = False + + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame): + """ + Helper method to clean up processor_agent to avoid leaving orphan processes. + """ + self.log.info("Exiting gracefully with signal {}".format(signum)) + if self.processor_agent: + self.processor_agent.end() + sys.exit(os.EX_OK) + @provide_session def manage_slas(self, dag, session=None): """ @@ -734,25 +741,6 @@ def manage_slas(self, dag, session=None): session.merge(sla) session.commit() - @staticmethod - @provide_session - def clear_nonexistent_import_errors(session, known_file_paths): - """ - Clears import errors for files that no longer exist. 
- - :param session: session for ORM operations - :type session: sqlalchemy.orm.session.Session - :param known_file_paths: The list of existing files that are parsed for DAGs - :type known_file_paths: list[unicode] - """ - query = session.query(models.ImportError) - if known_file_paths: - query = query.filter( - ~models.ImportError.filename.in_(known_file_paths) - ) - query.delete(synchronize_session='fetch') - session.commit() - @staticmethod def update_import_errors(session, dagbag): """ @@ -1107,7 +1095,9 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): # Put one task instance on each line task_instance_str = "\n\t".join( ["{}".format(x) for x in task_instances_to_examine]) - self.log.info("Tasks up for execution:\n\t%s", task_instance_str) + self.log.info("{} tasks up for execution:\n\t{}" + .format(len(task_instances_to_examine), + task_instance_str)) # Get the pool settings pools = {p.pool: p for p in session.query(models.Pool).all()} @@ -1233,13 +1223,13 @@ def _change_state_for_executable_task_instances(self, task_instances, acceptable_states, session=None): """ Changes the state of task instances in the list with one of the given states - to QUEUED atomically, and returns the TIs changed. + to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format. 
:param task_instances: TaskInstances to change the state of :type task_instances: List[TaskInstance] :param acceptable_states: Filters the TaskInstances updated to be in these states :type acceptable_states: Iterable[State] - :return: List[TaskInstance] + :return: List[SimpleTaskInstance] """ if len(task_instances) == 0: session.commit() @@ -1281,80 +1271,57 @@ def _change_state_for_executable_task_instances(self, task_instances, else task_instance.queued_dttm) session.merge(task_instance) - # save which TIs we set before session expires them - filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id, - TI.task_id == ti.task_id, - TI.execution_date == ti.execution_date) - for ti in tis_to_set_to_queued]) - session.commit() - - # requery in batches since above was expired by commit - - def query(result, items): - tis_to_be_queued = ( - session - .query(TI) - .filter(or_(*items)) - .all()) - task_instance_str = "\n\t".join( - ["{}".format(x) for x in tis_to_be_queued]) - self.log.info("Setting the follow tasks to queued state:\n\t%s", - task_instance_str) - return result + tis_to_be_queued + # Generate a list of SimpleTaskInstance for the use of queuing + # them in the executor. + simple_task_instances = [SimpleTaskInstance(ti) for ti in + tis_to_set_to_queued] - tis_to_be_queued = helpers.reduce_in_chunks(query, - filter_for_ti_enqueue, - [], - self.max_tis_per_query) + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_set_to_queued]) - return tis_to_be_queued + session.commit() + self.logger.info("Setting the follow {} tasks to queued state:\n\t{}" + .format(len(tis_to_set_to_queued), task_instance_str)) + return simple_task_instances - def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances): + def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, + simple_task_instances): """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. 
- :param task_instances: TaskInstances to enqueue - :type task_instances: List[TaskInstance] + :param simple_task_instances: TaskInstances to enqueue + :type simple_task_instances: List[SimpleTaskInstance] :param simple_dag_bag: Should contains all of the task_instances' dags :type simple_dag_bag: SimpleDagBag """ TI = models.TaskInstance # actually enqueue them - for task_instance in task_instances: - simple_dag = simple_dag_bag.get_dag(task_instance.dag_id) + for simple_task_instance in simple_task_instances: + simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id) command = TI.generate_command( - task_instance.dag_id, - task_instance.task_id, - task_instance.execution_date, + simple_task_instance.dag_id, + simple_task_instance.task_id, + simple_task_instance.execution_date, local=True, mark_success=False, ignore_all_deps=False, ignore_depends_on_past=False, ignore_task_deps=False, ignore_ti_state=False, - pool=task_instance.pool, + pool=simple_task_instance.pool, file_path=simple_dag.full_filepath, pickle_id=simple_dag.pickle_id) - priority = task_instance.priority_weight - queue = task_instance.queue + priority = simple_task_instance.priority_weight + queue = simple_task_instance.queue self.log.info( "Sending %s to executor with priority %s and queue %s", - task_instance.key, priority, queue + simple_task_instance.key, priority, queue ) - # save attributes so sqlalchemy doesnt expire them - copy_dag_id = task_instance.dag_id - copy_task_id = task_instance.task_id - copy_execution_date = task_instance.execution_date - make_transient(task_instance) - task_instance.dag_id = copy_dag_id - task_instance.task_id = copy_task_id - task_instance.execution_date = copy_execution_date - self.executor.queue_command( - task_instance, + simple_task_instance, command, priority=priority, queue=queue) @@ -1378,24 +1345,62 @@ def _execute_task_instances(self, :type simple_dag_bag: SimpleDagBag :param states: Execute TaskInstances in these states :type states: 
Tuple[State] - :return: None + :return: Number of task instance with state changed. """ executable_tis = self._find_executable_task_instances(simple_dag_bag, states, session=session) def query(result, items): - tis_with_state_changed = self._change_state_for_executable_task_instances( - items, - states, - session=session) + simple_tis_with_state_changed = \ + self._change_state_for_executable_task_instances(items, + states, + session=session) self._enqueue_task_instances_with_queued_state( simple_dag_bag, - tis_with_state_changed) + simple_tis_with_state_changed) session.commit() - return result + len(tis_with_state_changed) + return result + len(simple_tis_with_state_changed) return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) + @provide_session + def _change_state_for_tasks_failed_to_execute(self, session): + """ + If there are tasks left over in the executor, + we set them back to SCHEDULED to avoid creating hanging tasks. + :param session: + :return: + """ + if self.executor.queued_tasks: + TI = models.TaskInstance + filter_for_ti_state_change = ( + [and_( + TI.dag_id == dag_id, + TI.task_id == task_id, + TI.execution_date == execution_date, + TI.state == State.QUEUED) + for dag_id, task_id, execution_date + in self.executor.queued_tasks.keys()]) + ti_query = (session.query(TI) + .filter(or_(*filter_for_ti_state_change))) + tis_to_set_to_scheduled = (ti_query + .with_for_update() + .all()) + if len(tis_to_set_to_scheduled) == 0: + session.commit() + return + + # set TIs to queued state + for task_instance in tis_to_set_to_scheduled: + task_instance.state = State.SCHEDULED + + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_set_to_scheduled]) + + session.commit() + self.logger.info("Set the follow tasks to scheduled state:\n\t{}" + .format(task_instance_str)) + def _process_dags(self, dagbag, dags, tis_out): """ Iterates over the dags and processes them. 
Processing includes: @@ -1476,72 +1481,6 @@ def _process_executor_events(self, simple_dag_bag, session=None): session.merge(ti) session.commit() - def _log_file_processing_stats(self, - known_file_paths, - processor_manager): - """ - Print out stats about how files are getting processed. - - :param known_file_paths: a list of file paths that may contain Airflow - DAG definitions - :type known_file_paths: list[unicode] - :param processor_manager: manager for the file processors - :type stats: DagFileProcessorManager - :return: None - """ - - # File Path: Path to the file containing the DAG definition - # PID: PID associated with the process that's processing the file. May - # be empty. - # Runtime: If the process is currently running, how long it's been - # running for in seconds. - # Last Runtime: If the process ran before, how long did it take to - # finish in seconds - # Last Run: When the file finished processing in the previous run. - headers = ["File Path", - "PID", - "Runtime", - "Last Runtime", - "Last Run"] - - rows = [] - for file_path in known_file_paths: - last_runtime = processor_manager.get_last_runtime(file_path) - processor_pid = processor_manager.get_pid(file_path) - processor_start_time = processor_manager.get_start_time(file_path) - runtime = ((timezone.utcnow() - processor_start_time).total_seconds() - if processor_start_time else None) - last_run = processor_manager.get_last_finish_time(file_path) - - rows.append((file_path, - processor_pid, - runtime, - last_runtime, - last_run)) - - # Sort by longest last runtime. 
(Can't sort None values in python3) - rows = sorted(rows, key=lambda x: x[3] or 0.0) - - formatted_rows = [] - for file_path, pid, runtime, last_runtime, last_run in rows: - formatted_rows.append((file_path, - pid, - "{:.2f}s".format(runtime) - if runtime else None, - "{:.2f}s".format(last_runtime) - if last_runtime else None, - last_run.strftime("%Y-%m-%dT%H:%M:%S") - if last_run else None)) - log_str = ("\n" + - "=" * 80 + - "\n" + - "DAG File Processing Stats\n\n" + - tabulate(formatted_rows, headers=headers) + - "\n" + - "=" * 80) - - self.log.info(log_str) - def _execute(self): self.log.info("Starting the scheduler") @@ -1551,84 +1490,40 @@ def _execute(self): (executors.LocalExecutor, executors.SequentialExecutor): pickle_dags = True - # Use multiple processes to parse and generate tasks for the - # DAGs in parallel. By processing them in separate processes, - # we can get parallelism and isolation from potentially harmful - # user code. - self.log.info( - "Processing files using up to %s processes at a time", - self.max_threads) self.log.info("Running execute loop for %s seconds", self.run_duration) self.log.info("Processing each file at most %s times", self.num_runs) - self.log.info( - "Process each file at most once every %s seconds", - self.file_process_interval) - self.log.info( - "Checking for new files in %s every %s seconds", - self.subdir, - self.dag_dir_list_interval) # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s", self.subdir) known_file_paths = list_py_file_paths(self.subdir) self.log.info("There are %s files in %s", len(known_file_paths), self.subdir) - def processor_factory(file_path): + def processor_factory(file_path, zombies): return DagFileProcessor(file_path, pickle_dags, - self.dag_ids) + self.dag_ids, + zombies) + + # When using sqlite, we do not use async_mode + # so the scheduler job and DAG parser don't access the DB at the same time. 
+ async_mode = not self.using_sqlite - processor_manager = DagFileProcessorManager(self.subdir, - known_file_paths, - self.max_threads, - self.file_process_interval, - self.num_runs, - processor_factory) + self.processor_agent = DagFileProcessorAgent(self.subdir, + known_file_paths, + self.num_runs, + processor_factory, + async_mode) try: - self._execute_helper(processor_manager) + self._execute_helper() + except Exception: + self.log.exception("Exception when executing execute_helper") finally: + self.processor_agent.end() self.log.info("Exited execute loop") - # Kill all child processes on exit since we don't want to leave - # them as orphaned. - pids_to_kill = processor_manager.get_all_pids() - if len(pids_to_kill) > 0: - # First try SIGTERM - this_process = psutil.Process(os.getpid()) - # Only check child processes to ensure that we don't have a case - # where we kill the wrong process because a child process died - # but the PID got reused. - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - for child in child_processes: - self.log.info("Terminating child PID: %s", child.pid) - child.terminate() - # TODO: Remove magic number - timeout = 5 - self.log.info( - "Waiting up to %s seconds for processes to exit...", timeout) - try: - psutil.wait_procs( - child_processes, timeout=timeout, - callback=lambda x: self.log.info('Terminated PID %s', x.pid)) - except psutil.TimeoutExpired: - self.log.debug("Ran out of time while waiting for processes to exit") - - # Then SIGKILL - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - if len(child_processes) > 0: - self.log.info("SIGKILL processes that did not terminate gracefully") - for child in child_processes: - self.log.info("Killing child PID: %s", child.pid) - child.kill() - child.wait() - - def _execute_helper(self, processor_manager): - """ - :param processor_manager: manager to use - 
:type processor_manager: DagFileProcessorManager + def _execute_helper(self): + """ :return: None """ self.executor.start() @@ -1636,17 +1531,13 @@ def _execute_helper(self, processor_manager): self.log.info("Resetting orphaned tasks for active dag runs") self.reset_state_for_orphaned_tasks() + # Start after resetting orphaned tasks to avoid stressing out DB. + self.processor_agent.start() + execute_start_time = timezone.utcnow() - # Last time stats were printed - last_stat_print_time = datetime.datetime(2000, 1, 1, tzinfo=timezone.utc) # Last time that self.heartbeat() was called. last_self_heartbeat_time = timezone.utcnow() - # Last time that the DAG dir was traversed to look for files - last_dag_dir_refresh_time = timezone.utcnow() - - # Use this value initially - known_file_paths = processor_manager.file_paths # For the execute duration, parse and schedule DAGs while (timezone.utcnow() - execute_start_time).total_seconds() < \ @@ -1654,65 +1545,55 @@ def _execute_helper(self, processor_manager): self.log.debug("Starting Loop...") loop_start_time = time.time() - # Traverse the DAG directory for Python files containing DAGs - # periodically - elapsed_time_since_refresh = (timezone.utcnow() - - last_dag_dir_refresh_time).total_seconds() - - if elapsed_time_since_refresh > self.dag_dir_list_interval: - # Build up a list of Python files that could contain DAGs - self.log.info("Searching for files in %s", self.subdir) - known_file_paths = list_py_file_paths(self.subdir) - last_dag_dir_refresh_time = timezone.utcnow() - self.log.info( - "There are %s files in %s", len(known_file_paths), self.subdir) - - processor_manager.set_file_paths(known_file_paths) - - self.log.debug("Removing old import errors") - self.clear_nonexistent_import_errors(known_file_paths=known_file_paths) - - # Kick of new processes and collect results from finished ones - self.log.debug("Heartbeating the process manager") - simple_dags = processor_manager.heartbeat() - if self.using_sqlite: + 
self.processor_agent.heartbeat() # For the sqlite case w/ 1 thread, wait until the processor # is finished to avoid concurrent access to the DB. self.log.debug( "Waiting for processors to finish since we're using sqlite") + self.processor_agent.wait_until_finished() - processor_manager.wait_until_finished() + self.log.info("Harvesting DAG parsing results") + simple_dags = self.processor_agent.harvest_simple_dags() + self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags))) # Send tasks for execution if available simple_dag_bag = SimpleDagBag(simple_dags) if len(simple_dags) > 0: - - # Handle cases where a DAG run state is set (perhaps manually) to - # a non-running state. Handle task instances that belong to - # DAG runs in those states - - # If a task instance is up for retry but the corresponding DAG run - # isn't running, mark the task instance as FAILED so we don't try - # to re-run it. - self._change_state_for_tis_without_dagrun(simple_dag_bag, - [State.UP_FOR_RETRY], - State.FAILED) - # If a task instance is scheduled or queued, but the corresponding - # DAG run isn't running, set the state to NONE so we don't try to - # re-run it. - self._change_state_for_tis_without_dagrun(simple_dag_bag, - [State.QUEUED, - State.SCHEDULED], - State.NONE) - - self._execute_task_instances(simple_dag_bag, - (State.SCHEDULED,)) + try: + simple_dag_bag = SimpleDagBag(simple_dags) + + # Handle cases where a DAG run state is set (perhaps manually) to + # a non-running state. Handle task instances that belong to + # DAG runs in those states + + # If a task instance is up for retry but the corresponding DAG run + # isn't running, mark the task instance as FAILED so we don't try + # to re-run it. + self._change_state_for_tis_without_dagrun(simple_dag_bag, + [State.UP_FOR_RETRY], + State.FAILED) + # If a task instance is scheduled or queued, but the corresponding + # DAG run isn't running, set the state to NONE so we don't try to + # re-run it. 
+ self._change_state_for_tis_without_dagrun(simple_dag_bag, + [State.QUEUED, + State.SCHEDULED], + State.NONE) + + self._execute_task_instances(simple_dag_bag, + (State.SCHEDULED,)) + except Exception as e: + self.log.error("Error queuing tasks") + self.log.exception(e) + continue # Call heartbeats self.log.debug("Heartbeating the executor") self.executor.heartbeat() + self._change_state_for_tasks_failed_to_execute() + # Process events from the executor self._process_executor_events(simple_dag_bag) @@ -1724,40 +1605,39 @@ def _execute_helper(self, processor_manager): self.heartbeat() last_self_heartbeat_time = timezone.utcnow() - # Occasionally print out stats about how fast the files are getting processed - if ((timezone.utcnow() - last_stat_print_time).total_seconds() > - self.print_stats_interval): - if len(known_file_paths) > 0: - self._log_file_processing_stats(known_file_paths, - processor_manager) - last_stat_print_time = timezone.utcnow() - loop_end_time = time.time() + loop_duration = loop_end_time - loop_start_time self.log.debug( "Ran scheduling loop in %.2f seconds", - loop_end_time - loop_start_time) + loop_duration) self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval) time.sleep(self._processor_poll_interval) - # Exit early for a test mode - if processor_manager.max_runs_reached(): - self.log.info( - "Exiting loop as all files have been processed %s times", - self.num_runs) + # Exit early for a test mode, run one additional scheduler loop + # to reduce the possibility that parsed DAG was put into the queue + # by the DAG manager but not yet received by DAG agent. 
+ if self.processor_agent.done: + self._last_loop = True + + if self._last_loop: + self.log.info("Exiting scheduler loop as all files" + " have been processed {} times".format(self.num_runs)) break + if loop_duration < 1: + sleep_length = 1 - loop_duration + self.log.debug( + "Sleeping for {0:.2f} seconds to prevent excessive logging" + .format(sleep_length)) + sleep(sleep_length) + # Stop any processors - processor_manager.terminate() + self.processor_agent.terminate() # Verify that all files were processed, and if so, deactivate DAGs that # haven't been touched by the scheduler as they likely have been # deleted. - all_files_processed = True - for file_path in known_file_paths: - if processor_manager.get_last_finish_time(file_path) is None: - all_files_processed = False - break - if all_files_processed: + if self.processor_agent.all_files_processed: self.log.info( "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat() @@ -1769,7 +1649,7 @@ def _execute_helper(self, processor_manager): settings.Session.remove() @provide_session - def process_file(self, file_path, pickle_dags=False, session=None): + def process_file(self, file_path, zombies, pickle_dags=False, session=None): """ Process a Python file containing Airflow DAGs. @@ -1788,6 +1668,8 @@ def process_file(self, file_path, pickle_dags=False, session=None): :param file_path: the path to the Python file that should be executed :type file_path: unicode + :param zombies: zombie task instances to kill. 
+ :type zombies: SimpleTaskInstance :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool @@ -1882,7 +1764,7 @@ def process_file(self, file_path, pickle_dags=False, session=None): except Exception: self.log.exception("Error logging import errors!") try: - dagbag.kill_zombies() + dagbag.kill_zombies(zombies) except Exception: self.log.exception("Error killing zombies!") diff --git a/airflow/models.py b/airflow/models.py index d703810a77..aa624352bb 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -437,39 +437,28 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags @provide_session - def kill_zombies(self, session=None): - """ - Fails tasks that haven't had a heartbeat in too long - """ - from airflow.jobs import LocalTaskJob as LJ - self.log.info("Finding 'running' jobs without a recent heartbeat") - TI = TaskInstance - secs = configuration.conf.getint('scheduler', 'scheduler_zombie_task_threshold') - limit_dttm = timezone.utcnow() - timedelta(seconds=secs) - self.log.info("Failing jobs without heartbeat after %s", limit_dttm) - - tis = ( - session.query(TI) - .join(LJ, TI.job_id == LJ.id) - .filter(TI.state == State.RUNNING) - .filter( - or_( - LJ.state != State.RUNNING, - LJ.latest_heartbeat < limit_dttm, - )) - .all() - ) - - for ti in tis: - if ti and ti.dag_id in self.dags: - dag = self.dags[ti.dag_id] - if ti.task_id in dag.task_ids: - task = dag.get_task(ti.task_id) - - # now set non db backed vars on ti - ti.task = task + def kill_zombies(self, zombies, session=None): + """ + Fail given zombie tasks, which are tasks that haven't + had a heartbeat for too long, in the current DagBag. + + :param zombies: zombie task instances to kill. + :type zombies: SimpleTaskInstance + :param session: DB session. + :type Session. 
+ """ + for zombie in zombies: + if zombie.dag_id in self.dags: + dag = self.dags[zombie.dag_id] + if zombie.task_id in dag.task_ids: + task = dag.get_task(zombie.task_id) + ti = TaskInstance(task, zombie.execution_date) + # Get properties needed for failure handling from SimpleTaskInstance. + ti.start_date = zombie.start_date + ti.end_date = zombie.end_date + ti.try_number = zombie.try_number + ti.state = zombie.state ti.test_mode = configuration.getboolean('core', 'unit_test_mode') - ti.handle_failure("{} detected as zombie".format(ti), ti.test_mode, ti.get_template_context()) self.log.info( diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 89e5701cf1..acff607702 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -22,17 +22,39 @@ from __future__ import print_function from __future__ import unicode_literals +import logging +import multiprocessing import os import re +import signal +import sys import time import zipfile from abc import ABCMeta, abstractmethod from collections import defaultdict +from datetime import timedelta +import psutil +from sqlalchemy import or_ +from tabulate import tabulate + +# To avoid circular imports +import airflow.models +from airflow import configuration as conf from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.exceptions import AirflowException from airflow.utils import timezone +from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import State + +python_version_info = sys.version_info +if python_version_info.major > 2: + xrange = range + if python_version_info.minor < 4: + from imp import reload + else: + from importlib import reload class SimpleDag(BaseDag): @@ -121,6 +143,97 @@ def get_task_special_arg(self, task_id, special_arg_name): return None +class SimpleTaskInstance(object): + def __init__(self, ti): + self._dag_id = ti.dag_id + self._task_id = ti.task_id + 
self._execution_date = ti.execution_date + self._start_date = ti.start_date + self._end_date = ti.end_date + self._try_number = ti.try_number + self._state = ti.state + if hasattr(ti, 'run_as_user'): + self._run_as_user = ti.run_as_user + else: + self._run_as_user = None + if hasattr(ti, 'pool'): + self._pool = ti.pool + else: + self._pool = None + if hasattr(ti, 'priority_weight'): + self._priority_weight = ti.priority_weight + else: + self._priority_weight = None + self._queue = ti.queue + self._key = ti.key + + @property + def dag_id(self): + return self._dag_id + + @property + def task_id(self): + return self._task_id + + @property + def execution_date(self): + return self._execution_date + + @property + def start_date(self): + return self._start_date + + @property + def end_date(self): + return self._end_date + + @property + def try_number(self): + return self._try_number + + @property + def state(self): + return self._state + + @property + def pool(self): + return self._pool + + @property + def priority_weight(self): + return self._priority_weight + + @property + def queue(self): + return self._queue + + @property + def key(self): + return self._key + + @provide_session + def construct_task_instance(self, session=None, lock_for_update=False): + """ + Construct a TaskInstance from the database based on the primary key + :param session: DB session. + :param lock_for_update: if True, indicates that the database should + lock the TaskInstance (issuing a FOR UPDATE clause) until the + session is committed. + """ + TI = airflow.models.TaskInstance + + qry = session.query(TI).filter( + TI.dag_id == self._dag_id, + TI.task_id == self._task_id, + TI.execution_date == self._execution_date) + + if lock_for_update: + ti = qry.with_for_update().first() + else: + ti = qry.first() + return ti + + class SimpleDagBag(BaseDagBag): """ A collection of SimpleDag objects with some convenience methods. 
@@ -308,6 +421,254 @@ def file_path(self): raise NotImplementedError() +class DagParsingStat(object): + def __init__(self, + file_paths, + all_pids, + done, + all_files_processed, + result_count): + self.file_paths = file_paths + self.all_pids = all_pids + self.done = done + self.all_files_processed = all_files_processed + self.result_count = result_count + + +class DagParsingSignal(object): + AGENT_HEARTBEAT = "agent_heartbeat" + MANAGER_DONE = "manager_done" + TERMINATE_MANAGER = "terminate_manager" + END_MANAGER = "end_manager" + + +class DagFileProcessorAgent(LoggingMixin): + """ + Agent for DAG file processors. It is responsible for all DAG parsing + related jobs in scheduler process. Mainly it will collect DAG parsing + result from DAG file processor manager and communicate signal/DAG parsing + stat with DAG file processor manager. + """ + + def __init__(self, + dag_directory, + file_paths, + max_runs, + processor_factory, + async_mode): + """ + :param dag_directory: Directory where DAG definitions are kept. All + files in file_paths should be under this directory + :type dag_directory: unicode + :param file_paths: list of file paths that contain DAG definitions + :type file_paths: list[unicode] + :param max_runs: The number of times to parse and schedule each file. -1 + for unlimited. + :type max_runs: int + :param processor_factory: function that creates processors for DAG + definition files. 
Arguments are (dag_definition_path, log_file_path) + :type processor_factory: (unicode, unicode, list) -> (AbstractDagFileProcessor) + :param async_mode: Whether to start agent in async mode + :type async_mode: bool + """ + self._file_paths = file_paths + self._file_path_queue = [] + self._dag_directory = dag_directory + self._max_runs = max_runs + self._processor_factory = processor_factory + self._async_mode = async_mode + # Map from file path to the processor + self._processors = {} + # Map from file path to the last runtime + self._last_runtime = {} + # Map from file path to the last finish time + self._last_finish_time = {} + # Map from file path to the number of runs + self._run_count = defaultdict(int) + # Pids of DAG parse + self._all_pids = [] + # Pipe for communicating signals + self._parent_signal_conn, self._child_signal_conn = multiprocessing.Pipe() + # Pipe for communicating DagParsingStat + self._stat_queue = multiprocessing.Queue() + self._result_queue = multiprocessing.Queue() + self._process = None + self._done = False + # Initialized as true so we do not deactivate w/o any actual DAG parsing. + self._all_files_processed = True + self._result_count = 0 + + def start(self): + """ + Launch DagFileProcessorManager processor and start DAG parsing loop in manager. + """ + self._process = self._launch_process(self._dag_directory, + self._file_paths, + self._max_runs, + self._processor_factory, + self._child_signal_conn, + self._stat_queue, + self._result_queue, + self._async_mode) + self.log.info("Launched DagFileProcessorManager with pid: {}" + .format(self._process.pid)) + + def heartbeat(self): + """ + Should only be used when launched DAG file processor manager in sync mode. + Send agent heartbeat signal to the manager. + """ + self._parent_signal_conn.send(DagParsingSignal.AGENT_HEARTBEAT) + + def wait_until_finished(self): + """ + Should only be used when launched DAG file processor manager in sync mode. + Wait for done signal from the manager. 
+ """ + while True: + if self._parent_signal_conn.poll() \ + and self._parent_signal_conn.recv() == DagParsingSignal.MANAGER_DONE: + break + time.sleep(0.1) + + @staticmethod + def _launch_process(dag_directory, + file_paths, + max_runs, + processor_factory, + signal_conn, + _stat_queue, + result_queue, + async_mode): + def helper(): + # Reload configurations and settings to avoid collision with parent process. + # Because this process may need custom configurations that cannot be shared, + # e.g. RotatingFileHandler. And it can cause connection corruption if we + # do not recreate the SQLA connection pool. + os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] = 'True' + reload(airflow.config_templates.airflow_local_settings) + reload(airflow.settings) + del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] + processor_manager = DagFileProcessorManager(dag_directory, + file_paths, + max_runs, + processor_factory, + signal_conn, + _stat_queue, + result_queue, + async_mode) + + processor_manager.start() + + p = multiprocessing.Process(target=helper, + args=(), + name="DagFileProcessorManager") + p.start() + return p + + def harvest_simple_dags(self): + """ + Harvest DAG parsing results from result queue and sync metadata from stat queue. + :return: List of parsing result in SimpleDag format. + """ + # Metadata and results to be harvested can be inconsistent, + # but it should not be a big problem. + self._sync_metadata() + # Heartbeating after syncing metadata so we do not restart manager + # if it processed all files for max_run times and exit normally. + self._heartbeat_manager() + simple_dags = [] + # multiprocessing.Queue().qsize will not work on MacOS. 
+ if sys.platform == "darwin": + qsize = self._result_count + else: + qsize = self._result_queue.qsize() + for _ in xrange(qsize): + simple_dags.append(self._result_queue.get()) + + self._result_count = 0 + + return simple_dags + + def _heartbeat_manager(self): + """ + Heartbeat DAG file processor and start it if it is not alive. + :return: + """ + if self._process and not self._process.is_alive() and not self.done: + self.start() + + def _sync_metadata(self): + """ + Sync metadata from stat queue and only keep the latest stat. + :return: + """ + while not self._stat_queue.empty(): + stat = self._stat_queue.get() + self._file_paths = stat.file_paths + self._all_pids = stat.all_pids + self._done = stat.done + self._all_files_processed = stat.all_files_processed + self._result_count += stat.result_count + + @property + def file_paths(self): + return self._file_paths + + @property + def done(self): + return self._done + + @property + def all_files_processed(self): + return self._all_files_processed + + def terminate(self): + """ + Send termination signal to DAG parsing processor manager + and expect it to terminate all DAG file processors. + """ + self.log.info("Sending termination signal to manager.") + self._child_signal_conn.send(DagParsingSignal.TERMINATE_MANAGER) + + def end(self): + """ + Terminate (and then kill) the manager process launched. + :return: + """ + if not self._process: + self.log.warn('Ending without manager process.') + return + this_process = psutil.Process(os.getpid()) + try: + manager_process = psutil.Process(self._process.pid) + except psutil.NoSuchProcess: + self.log.info("Manager process not running.") + return + + # First try SIGTERM + if manager_process.is_running() \ + and manager_process.pid in [x.pid for x in this_process.children()]: + self.log.info( + "Terminating manager process: {}".format(manager_process.pid)) + manager_process.terminate() + timeout = 5 + self.log.info("Waiting up to {}s for manager process to exit..." 
+ .format(timeout)) + try: + psutil.wait_procs({manager_process}, timeout) + except psutil.TimeoutExpired: + self.log.debug("Ran out of time while waiting for " + "processes to exit") + + # Then SIGKILL + if manager_process.is_running() \ + and manager_process.pid in [x.pid for x in this_process.children()]: + self.log.info("Killing manager process: {}".format(manager_process.pid)) + manager_process.kill() + manager_process.wait() + + class DagFileProcessorManager(LoggingMixin): """ Given a list of DAG definition files, this kicks off several processors @@ -324,48 +685,322 @@ class DagFileProcessorManager(LoggingMixin): def __init__(self, dag_directory, file_paths, - parallelism, - process_file_interval, max_runs, - processor_factory): + processor_factory, + signal_conn, + stat_queue, + result_queue, + async_mode=True): """ :param dag_directory: Directory where DAG definitions are kept. All files in file_paths should be under this directory :type dag_directory: unicode :param file_paths: list of file paths that contain DAG definitions :type file_paths: list[unicode] - :param parallelism: maximum number of simultaneous process to run at once - :type parallelism: int - :param process_file_interval: process a file at most once every this - many seconds - :type process_file_interval: float :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :type max_runs: int - :type process_file_interval: float :param processor_factory: function that creates processors for DAG definition files. Arguments are (dag_definition_path) - :type processor_factory: (unicode, unicode) -> (AbstractDagFileProcessor) - + :type processor_factory: (unicode, unicode, list) -> (AbstractDagFileProcessor) + :param signal_conn: connection to communicate signal with processor agent. + :type signal_conn: Connection + :param stat_queue: the queue to use for passing back parsing stat to agent. 
+ :type stat_queue: multiprocessing.Queue + :param result_queue: the queue to use for passing back the result to agent. + :type result_queue: multiprocessing.Queue + :param async_mode: whether to start the manager in async mode + :type async_mode: bool """ self._file_paths = file_paths self._file_path_queue = [] - self._parallelism = parallelism self._dag_directory = dag_directory self._max_runs = max_runs - self._process_file_interval = process_file_interval self._processor_factory = processor_factory + self._signal_conn = signal_conn + self._stat_queue = stat_queue + self._result_queue = result_queue + self._async_mode = async_mode + + self._parallelism = conf.getint('scheduler', 'max_threads') + if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1: + self.log.error("Cannot use more than 1 thread when using sqlite. " + "Setting parallelism to 1") + self._parallelism = 1 + + # Parse and schedule each file no faster than this interval. + self._file_process_interval = conf.getint('scheduler', + 'min_file_process_interval') + # How often to print out DAG file processing stats to the log. Default to + # 30 seconds. + self.print_stats_interval = conf.getint('scheduler', + 'print_stats_interval') + # How many seconds do we wait for tasks to heartbeat before mark them as zombies. 
+ self._zombie_threshold_secs = ( + conf.getint('scheduler', 'scheduler_zombie_task_threshold')) # Map from file path to the processor self._processors = {} # Map from file path to the last runtime self._last_runtime = {} # Map from file path to the last finish time self._last_finish_time = {} + self._last_zombie_query_time = timezone.utcnow() + # Last time that the DAG dir was traversed to look for files + self.last_dag_dir_refresh_time = timezone.utcnow() + # Last time stats were printed + self.last_stat_print_time = timezone.datetime(2000, 1, 1) + self._zombie_query_interval = 10 # Map from file path to the number of runs self._run_count = defaultdict(int) - # Scheduler heartbeat key. + # Manager heartbeat key. self._heart_beat_key = 'heart-beat' + # How often to scan the DAGs directory for new files. Default to 5 minutes. + self.dag_dir_list_interval = conf.getint('scheduler', + 'dag_dir_list_interval') + + self._log = logging.getLogger('airflow.processor_manager') + + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame): + """ + Helper method to clean up DAG file processors to avoid leaving orphan processes. + """ + self.log.info("Exiting gracefully with signal {}".format(signum)) + self.terminate() + self.end() + self.log.debug("Finished terminating DAG processors.") + sys.exit(os.EX_OK) + + def start(self): + """ + Use multiple processes to parse and generate tasks for the + DAGs in parallel. By processing them in separate processes, + we can get parallelism and isolation from potentially harmful + user code. 
+ :return: + """ + + self.log.info("Processing files using up to {} processes at a time " + .format(self._parallelism)) + self.log.info("Process each file at most once every {} seconds" + .format(self._file_process_interval)) + self.log.info("Checking for new files in {} every {} seconds" + .format(self._dag_directory, self.dag_dir_list_interval)) + + if self._async_mode: + self.log.debug("Starting DagFileProcessorManager in async mode") + self.start_in_async() + else: + self.log.debug("Starting DagFileProcessorManager in sync mode") + self.start_in_sync() + + def start_in_async(self): + """ + Parse DAG files repeatedly in a standalone loop. + """ + while True: + loop_start_time = time.time() + + if self._signal_conn.poll(): + agent_signal = self._signal_conn.recv() + if agent_signal == DagParsingSignal.TERMINATE_MANAGER: + self.terminate() + break + elif agent_signal == DagParsingSignal.END_MANAGER: + self.end() + sys.exit(os.EX_OK) + + self._refresh_dag_dir() + + simple_dags = self.heartbeat() + for simple_dag in simple_dags: + self._result_queue.put(simple_dag) + + self._print_stat() + + all_files_processed = all(self.get_last_finish_time(x) is not None + for x in self.file_paths) + max_runs_reached = self.max_runs_reached() + + dag_parsing_stat = DagParsingStat(self._file_paths, + self.get_all_pids(), + max_runs_reached, + all_files_processed, + len(simple_dags)) + self._stat_queue.put(dag_parsing_stat) + + if max_runs_reached: + self.log.info("Exiting dag parsing loop as all files " + "have been processed %s times", self._max_runs) + break + + loop_duration = time.time() - loop_start_time + if loop_duration < 1: + sleep_length = 1 - loop_duration + self.log.debug("Sleeping for {0:.2f} seconds " + "to prevent excessive logging".format(sleep_length)) + time.sleep(sleep_length) + + def start_in_sync(self): + """ + Parse DAG files in a loop controlled by DagParsingSignal. 
+ Actual DAG parsing loop will run once everyone an agent heartbeats + the manager and will report done when finished the loop. + """ + while True: + if self._signal_conn.poll(): + agent_signal = self._signal_conn.recv() + if agent_signal == DagParsingSignal.TERMINATE_MANAGER: + self.terminate() + break + elif agent_signal == DagParsingSignal.END_MANAGER: + self.end() + sys.exit(os.EX_OK) + elif agent_signal == DagParsingSignal.AGENT_HEARTBEAT: + + self._refresh_dag_dir() + + simple_dags = self.heartbeat() + for simple_dag in simple_dags: + self._result_queue.put(simple_dag) + + self._print_stat() + + all_files_processed = all(self.get_last_finish_time(x) is not None + for x in self.file_paths) + max_runs_reached = self.max_runs_reached() + + dag_parsing_stat = DagParsingStat(self._file_paths, + self.get_all_pids(), + self.max_runs_reached(), + all_files_processed, + len(simple_dags)) + self._stat_queue.put(dag_parsing_stat) + + self.wait_until_finished() + self._signal_conn.send(DagParsingSignal.MANAGER_DONE) + + if max_runs_reached: + self.log.info("Exiting dag parsing loop as all files " + "have been processed %s times", self._max_runs) + self._signal_conn.send(DagParsingSignal.MANAGER_DONE) + break + else: + time.sleep(0.01) + + def _refresh_dag_dir(self): + """ + Refresh file paths from dag dir if we haven't done it for too long. 
+ """ + elapsed_time_since_refresh = (timezone.utcnow() - + self.last_dag_dir_refresh_time).total_seconds() + if elapsed_time_since_refresh > self.dag_dir_list_interval: + # Build up a list of Python files that could contain DAGs + self.log.info( + "Searching for files in {}".format(self._dag_directory)) + self._file_paths = list_py_file_paths(self._dag_directory) + self.last_dag_dir_refresh_time = timezone.utcnow() + self.log.info("There are {} files in {}" + .format(len(self._file_paths), + self._dag_directory)) + self.set_file_paths(self._file_paths) + + try: + self.log.debug("Removing old import errors") + self.clear_nonexistent_import_errors() + except Exception: + self.log.exception("Error removing old import errors") + + def _print_stat(self): + """ + Occasionally print out stats about how fast the files are getting processed + :return: + """ + if ((timezone.utcnow() - self.last_stat_print_time).total_seconds() > + self.print_stats_interval): + if len(self._file_paths) > 0: + self._log_file_processing_stats(self._file_paths) + self.last_stat_print_time = timezone.utcnow() + + @provide_session + def clear_nonexistent_import_errors(self, session): + """ + Clears import errors for files that no longer exist. + :param session: session for ORM operations + :type session: sqlalchemy.orm.session.Session + """ + query = session.query(airflow.models.ImportError) + if self._file_paths: + query = query.filter( + ~airflow.models.ImportError.filename.in_(self._file_paths) + ) + query.delete(synchronize_session='fetch') + session.commit() + + def _log_file_processing_stats(self, known_file_paths): + """ + Print out stats about how files are getting processed. + :param known_file_paths: a list of file paths that may contain Airflow + DAG definitions + :type known_file_paths: list[unicode] + :return: None + """ + + # File Path: Path to the file containing the DAG definition + # PID: PID associated with the process that's processing the file. May + # be empty. 
+ # Runtime: If the process is currently running, how long it's been + # running for in seconds. + # Last Runtime: If the process ran before, how long did it take to + # finish in seconds + # Last Run: When the file finished processing in the previous run. + headers = ["File Path", + "PID", + "Runtime", + "Last Runtime", + "Last Run"] + + rows = [] + for file_path in known_file_paths: + last_runtime = self.get_last_runtime(file_path) + processor_pid = self.get_pid(file_path) + processor_start_time = self.get_start_time(file_path) + runtime = ((timezone.utcnow() - processor_start_time).total_seconds() + if processor_start_time else None) + last_run = self.get_last_finish_time(file_path) + + rows.append((file_path, + processor_pid, + runtime, + last_runtime, + last_run)) + + # Sort by longest last runtime. (Can't sort None values in python3) + rows = sorted(rows, key=lambda x: x[3] or 0.0) + + formatted_rows = [] + for file_path, pid, runtime, last_runtime, last_run in rows: + formatted_rows.append((file_path, + pid, + "{:.2f}s".format(runtime) + if runtime else None, + "{:.2f}s".format(last_runtime) + if last_runtime else None, + last_run.strftime("%Y-%m-%dT%H:%M:%S") + if last_run else None)) + log_str = ("\n" + + "=" * 80 + + "\n" + + "DAG File Processing Stats\n\n" + + tabulate(formatted_rows, headers=headers) + + "\n" + + "=" * 80) + + self.log.info(log_str) + @property def file_paths(self): return self._file_paths @@ -472,7 +1107,7 @@ def wait_until_finished(self): def heartbeat(self): """ - This should be periodically called by the scheduler. This method will + This should be periodically called by the manager loop. This method will kick off new processes to process DAG definition files and read the results from the finished processors. 
@@ -498,7 +1133,7 @@ def heartbeat(self): running_processors[file_path] = processor self._processors = running_processors - self.log.debug("%s/%s scheduler processes running", + self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism) self.log.debug("%s file paths queued for processing", @@ -528,7 +1163,7 @@ def heartbeat(self): last_finish_time = self.get_last_finish_time(file_path) if (last_finish_time is not None and (now - last_finish_time).total_seconds() < - self._process_file_interval): + self._file_process_interval): file_paths_recently_processed.append(file_path) files_paths_at_run_limit = [file_path @@ -553,11 +1188,13 @@ def heartbeat(self): self._file_path_queue.extend(files_paths_to_queue) + zombies = self._find_zombies() + # Start more processors if we have enough slots and files to process while (self._parallelism - len(self._processors) > 0 and len(self._file_path_queue) > 0): file_path = self._file_path_queue.pop(0) - processor = self._processor_factory(file_path) + processor = self._processor_factory(file_path, zombies) processor.start() self.log.info( @@ -566,11 +1203,47 @@ def heartbeat(self): ) self._processors[file_path] = processor - # Update scheduler heartbeat count. + # Update heartbeat count. self._run_count[self._heart_beat_key] += 1 return simple_dags + @provide_session + def _find_zombies(self, session): + """ + Find zombie task instances, which are tasks haven't heartbeated for too long. + :return: Zombie task instances in SimpleTaskInstance format. 
+ """ + now = timezone.utcnow() + zombies = [] + if (now - self._last_zombie_query_time).total_seconds() \ + > self._zombie_query_interval: + # to avoid circular imports + from airflow.jobs import LocalTaskJob as LJ + self.log.info("Finding 'running' jobs without a recent heartbeat") + TI = airflow.models.TaskInstance + limit_dttm = timezone.utcnow() - timedelta( + seconds=self._zombie_threshold_secs) + self.log.info( + "Failing jobs without heartbeat after {}".format(limit_dttm)) + + tis = ( + session.query(TI) + .join(LJ, TI.job_id == LJ.id) + .filter(TI.state == State.RUNNING) + .filter( + or_( + LJ.state != State.RUNNING, + LJ.latest_heartbeat < limit_dttm, + ) + ).all() + ) + self._last_zombie_query_time = timezone.utcnow() + for ti in tis: + zombies.append(SimpleTaskInstance(ti)) + + return zombies + def max_runs_reached(self): """ :return: whether all file paths have been processed max_runs times @@ -591,3 +1264,41 @@ def terminate(self): """ for processor in self._processors.values(): processor.terminate() + + def end(self): + """ + Kill all child processes on exit since we don't want to leave + them as orphaned. + """ + pids_to_kill = self.get_all_pids() + if len(pids_to_kill) > 0: + # First try SIGTERM + this_process = psutil.Process(os.getpid()) + # Only check child processes to ensure that we don't have a case + # where we kill the wrong process because a child process died + # but the PID got reused. 
+ child_processes = [x for x in this_process.children(recursive=True) + if x.is_running() and x.pid in pids_to_kill] + for child in child_processes: + self.log.info("Terminating child PID: {}".format(child.pid)) + child.terminate() + # TODO: Remove magic number + timeout = 5 + self.log.info( + "Waiting up to %s seconds for processes to exit...", timeout) + try: + psutil.wait_procs( + child_processes, timeout=timeout, + callback=lambda x: self.log.info('Terminated PID %s', x.pid)) + except psutil.TimeoutExpired: + self.log.debug("Ran out of time while waiting for processes to exit") + + # Then SIGKILL + child_processes = [x for x in this_process.children(recursive=True) + if x.is_running() and x.pid in pids_to_kill] + if len(child_processes) > 0: + self.log.info("SIGKILL processes that did not terminate gracefully") + for child in child_processes: + self.log.info("Killing child PID: {}".format(child.pid)) + child.kill() + child.wait() diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py index a86b9d357b..f648005877 100644 --- a/airflow/utils/timeout.py +++ b/airflow/utils/timeout.py @@ -23,6 +23,7 @@ from __future__ import unicode_literals import signal +import os from airflow.exceptions import AirflowTaskTimeout from airflow.utils.log.logging_mixin import LoggingMixin @@ -35,10 +36,10 @@ class timeout(LoggingMixin): def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds - self.error_message = error_message + self.error_message = error_message + ', PID: ' + str(os.getpid()) def handle_timeout(self, signum, frame): - self.log.error("Process timed out") + self.log.error("Process timed out, PID: " + str(os.getpid())) raise AirflowTaskTimeout(self.error_message) def __enter__(self): diff --git a/docs/scheduler.rst b/docs/scheduler.rst index 9c5cb5c90c..dc43c111fe 100644 --- a/docs/scheduler.rst +++ b/docs/scheduler.rst @@ -3,8 +3,9 @@ Scheduling & Triggers The Airflow scheduler monitors all tasks and all DAGs, and triggers the task 
instances whose dependencies have been met. Behind the scenes, -it monitors and stays in sync with a folder for all DAG objects it may contain, -and periodically (every minute or so) inspects active tasks to see whether +it spins up a subprocess, which monitors and stays in sync with a folder +for all DAG objects it may contain, and periodically (every minute or so) +collects DAG parsing results and inspects active tasks to see whether they can be triggered. The Airflow scheduler is designed to run as a persistent service in an diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 2ebcfd7b63..b6ce949854 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -21,6 +21,7 @@ import mock from celery.contrib.testing.worker import start_worker +from airflow.executors import celery_executor from airflow.executors.celery_executor import CeleryExecutor from airflow.executors.celery_executor import app from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER @@ -59,6 +60,21 @@ def test_celery_integration(self): self.assertNotIn('success', executor.last_state) self.assertNotIn('fail', executor.last_state) + def test_error_sending_task(self): + @app.task + def fake_execute_command(): + pass + + # fake_execute_command takes no arguments while execute_command takes 1, + # which will cause TypeError when calling task.apply_async() + celery_executor.execute_command = fake_execute_command + executor = CeleryExecutor() + value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti' + executor.queued_tasks['key'] = value_tuple + executor.heartbeat() + self.assertEquals(1, len(executor.queued_tasks)) + self.assertEquals(executor.queued_tasks['key'], value_tuple) + def test_exception_propagation(self): @app.task def fake_celery_task(): diff --git a/tests/jobs.py b/tests/jobs.py index dc3381e8e0..2d971131ff 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -28,37 +28,35 @@ 
import multiprocessing import os import shutil -import six import threading import time import unittest from tempfile import mkdtemp +import psutil +import six import sqlalchemy +from mock import Mock, patch, MagicMock, PropertyMock from airflow import AirflowException, settings, models +from airflow import configuration from airflow.bin import cli from airflow.executors import BaseExecutor, SequentialExecutor from airflow.jobs import BaseJob, BackfillJob, SchedulerJob, LocalTaskJob from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI -from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator +from airflow.operators.dummy_operator import DummyOperator from airflow.task.task_runner.base_task_runner import BaseTaskRunner from airflow.utils import timezone - +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths from airflow.utils.dates import days_ago from airflow.utils.db import provide_session +from airflow.utils.net import get_hostname from airflow.utils.state import State from airflow.utils.timeout import timeout -from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths -from airflow.utils.net import get_hostname - -from mock import Mock, patch, MagicMock, PropertyMock -from tests.executors.test_executor import TestExecutor - from tests.core import TEST_DAG_FOLDER +from tests.executors.test_executor import TestExecutor -from airflow import configuration configuration.load_test_config() logger = logging.getLogger(__name__) @@ -1211,6 +1209,7 @@ def setUp(self): session.query(models.DagRun).delete() session.query(models.ImportError).delete() session.commit() + session.close() @staticmethod def run_single_scheduler_loop_with_no_dags(dags_folder): @@ -1233,6 +1232,21 @@ def run_single_scheduler_loop_with_no_dags(dags_folder): def _make_simple_dag_bag(self, dags): return SimpleDagBag([SimpleDag(dag) for dag in dags]) + def 
test_no_orphan_process_will_be_left(self): + empty_dir = mkdtemp() + current_process = psutil.Process() + old_children = current_process.children(recursive=True) + scheduler = SchedulerJob(subdir=empty_dir, + num_runs=1) + scheduler.executor = TestExecutor() + scheduler.run() + shutil.rmtree(empty_dir) + + # Remove potential noise created by previous tests. + current_children = set(current_process.children(recursive=True)) - set( + old_children) + self.assertFalse(current_children) + def test_process_executor_events(self): dag_id = "test_process_executor_events" dag_id2 = "test_process_executor_events_2" @@ -1974,6 +1988,54 @@ def test_change_state_for_tis_without_dagrun(self): ti2.refresh_from_db(session=session) self.assertEqual(ti2.state, State.SCHEDULED) + def test_change_state_for_tasks_failed_to_execute(self): + dag = DAG( + dag_id='dag_id', + start_date=DEFAULT_DATE) + + task = DummyOperator( + task_id='task_id', + dag=dag, + owner='airflow') + + # If there's no left over task in executor.queued_tasks, nothing happens + session = settings.Session() + scheduler_job = SchedulerJob() + mock_logger = mock.MagicMock() + test_executor = TestExecutor() + scheduler_job.executor = test_executor + scheduler_job._logger = mock_logger + scheduler_job._change_state_for_tasks_failed_to_execute() + mock_logger.info.assert_not_called() + + # Tasks failed to execute with QUEUED state will be set to SCHEDULED state. + session.query(TI).delete() + session.commit() + key = 'dag_id', 'task_id', DEFAULT_DATE + test_executor.queued_tasks[key] = 'value' + ti = TI(task, DEFAULT_DATE) + ti.state = State.QUEUED + session.merge(ti) + session.commit() + + scheduler_job._change_state_for_tasks_failed_to_execute() + + ti.refresh_from_db() + self.assertEquals(State.SCHEDULED, ti.state) + + # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state. 
+ session.query(TI).delete() + session.commit() + ti.state = State.RUNNING + + session.merge(ti) + session.commit() + + scheduler_job._change_state_for_tasks_failed_to_execute() + + ti.refresh_from_db() + self.assertEquals(State.RUNNING, ti.state) + def test_execute_helper_reset_orphaned_tasks(self): session = settings.Session() dag = DAG( @@ -2002,13 +2064,13 @@ def test_execute_helper_reset_orphaned_tasks(self): session.commit() processor = mock.MagicMock() - processor.get_last_finish_time.return_value = None scheduler = SchedulerJob(num_runs=0, run_duration=0) executor = TestExecutor() scheduler.executor = executor + scheduler.processor_agent = processor - scheduler._execute_helper(processor_manager=processor) + scheduler._execute_helper() ti = dr.get_task_instance(task_id=op1.task_id, session=session) self.assertEqual(ti.state, State.NONE) @@ -2160,7 +2222,7 @@ def test_scheduler_start_date(self): """ Test that the scheduler respects start_dates, even when DAGS have run """ - + session = settings.Session() dag_id = 'test_start_date_scheduling' dag = self.dagbag.get_dag(dag_id) dag.clear() @@ -2171,9 +2233,9 @@ def test_scheduler_start_date(self): scheduler.run() # zero tasks ran - session = settings.Session() self.assertEqual( len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0) + session.commit() # previously, running this backfill would kick off the Scheduler # because it would take the most recent run and start from there @@ -2186,18 +2248,19 @@ def test_scheduler_start_date(self): backfill.run() # one task ran - session = settings.Session() self.assertEqual( len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + session.commit() scheduler = SchedulerJob(dag_id, num_runs=2) scheduler.run() # still one task - session = settings.Session() self.assertEqual( len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + session.commit() + session.close() def test_scheduler_multiprocessing(self): """ @@ -2935,7 +2998,9 @@ def 
test_scheduler_run_duration(self): logging.info("Test ran in %.2fs, expected %.2fs", run_duration, expected_run_duration) - self.assertLess(run_duration - expected_run_duration, 5.0) + # 5s to wait for child process to exit and 1s dummy sleep + # in scheduler loop to prevent excessive logs. + self.assertLess(run_duration - expected_run_duration, 6.0) def test_dag_with_system_exit(self): """ @@ -2962,6 +3027,7 @@ def test_dag_with_system_exit(self): session = settings.Session() self.assertEqual( len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + session.close() def test_dag_get_active_runs(self): """ @@ -3095,10 +3161,9 @@ def setup_dag(dag_id, schedule_interval, start_date, catchup): dr = scheduler.create_dag_run(dag4) self.assertIsNotNone(dr) - def test_add_unparseable_file_before_sched_start_creates_import_error(self): + dags_folder = mkdtemp() try: - dags_folder = mkdtemp() unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) @@ -3108,6 +3173,7 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 1) import_error = import_errors[0] @@ -3117,8 +3183,8 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): "invalid syntax ({}, line 1)".format(TEMP_DAG_FILENAME)) def test_add_unparseable_file_after_sched_start_creates_import_error(self): + dags_folder = mkdtemp() try: - dags_folder = mkdtemp() unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) self.run_single_scheduler_loop_with_no_dags(dags_folder) @@ -3130,6 +3196,7 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() 
self.assertEqual(len(import_errors), 1) import_error = import_errors[0] @@ -3151,6 +3218,7 @@ def test_no_import_errors_with_parseable_dag(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 0) @@ -3224,6 +3292,7 @@ def test_remove_file_clears_import_error(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 0) diff --git a/tests/models.py b/tests/models.py index a317d4c5e2..cdd14bbc0c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -51,6 +51,7 @@ from airflow.operators.python_operator import ShortCircuitOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone +from airflow.utils.dag_processing import SimpleTaskInstance from airflow.utils.weight_rule import WeightRule from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule @@ -1578,7 +1579,7 @@ def test_process_file_with_none(self): self.assertEqual([], dagbag.process_file(None)) @patch.object(TI, 'handle_failure') - def test_kill_zombies(self, mock_ti): + def test_kill_zombies(self, mock_ti_handle_failure): """ Test that kill zombies call TIs failure handler with proper context """ @@ -1587,22 +1588,18 @@ def test_kill_zombies(self, mock_ti): dag = dagbag.get_dag('example_branch_operator') task = dag.get_task(task_id='run_this_first') - ti = TI(task, datetime.datetime.now() - datetime.timedelta(1), 'running') - lj = LocalTaskJob(ti) - lj.state = State.SHUTDOWN - - session.add(lj) - session.commit() - - ti.job_id = lj.id + ti = TI(task, DEFAULT_DATE, State.RUNNING) session.add(ti) session.commit() - dagbag.kill_zombies() - mock_ti.assert_called_with(ANY, - configuration.getboolean('core', 'unit_test_mode'), - ANY) + zombies = [SimpleTaskInstance(ti)] + dagbag.kill_zombies(zombies) + mock_ti_handle_failure \ + .assert_called_with(ANY, + 
configuration.getboolean('core', + 'unit_test_mode'), + ANY) class TaskInstanceTest(unittest.TestCase): diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index f29e384b8c..8a6d4d2fd9 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,11 +17,26 @@ # specific language governing permissions and limitations # under the License. +import os import unittest +from datetime import timedelta from mock import MagicMock -from airflow.utils.dag_processing import DagFileProcessorManager +from airflow import configuration as conf +from airflow.jobs import DagFileProcessor +from airflow.jobs import LocalTaskJob as LJ +from airflow.models import DagBag, TaskInstance as TI +from airflow.settings import Session +from airflow.utils import timezone +from airflow.utils.dag_processing import (DagFileProcessorAgent, DagFileProcessorManager, + SimpleTaskInstance) +from airflow.utils.state import State + +TEST_DAG_FOLDER = os.path.join( + os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') + +DEFAULT_DATE = timezone.datetime(2016, 1, 1) class TestDagFileProcessorManager(unittest.TestCase): @@ -29,10 +44,12 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): manager = DagFileProcessorManager( dag_directory='directory', file_paths=['abc.txt'], - parallelism=1, - process_file_interval=1, max_runs=1, - processor_factory=MagicMock().return_value) + processor_factory=MagicMock().return_value, + signal_conn=MagicMock(), + stat_queue=MagicMock(), + 
result_queue=MagicMock, + async_mode=True) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError( @@ -48,10 +65,12 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): manager = DagFileProcessorManager( dag_directory='directory', file_paths=['abc.txt'], - parallelism=1, - process_file_interval=1, max_runs=1, - processor_factory=MagicMock().return_value) + processor_factory=MagicMock().return_value, + signal_conn=MagicMock(), + stat_queue=MagicMock(), + result_queue=MagicMock, + async_mode=True) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError( @@ -62,3 +81,66 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): manager.set_file_paths(['abc.txt']) self.assertDictEqual(manager._processors, {'abc.txt': mock_processor}) + + def test_find_zombies(self): + manager = DagFileProcessorManager( + dag_directory='directory', + file_paths=['abc.txt'], + max_runs=1, + processor_factory=MagicMock().return_value, + signal_conn=MagicMock(), + stat_queue=MagicMock(), + result_queue=MagicMock, + async_mode=True) + + dagbag = DagBag(TEST_DAG_FOLDER) + session = Session() + dag = dagbag.get_dag('example_branch_operator') + task = dag.get_task(task_id='run_this_first') + + ti = TI(task, DEFAULT_DATE, State.RUNNING) + lj = LJ(ti) + lj.state = State.SHUTDOWN + lj.id = 1 + ti.job_id = lj.id + + session.add(lj) + session.add(ti) + session.commit() + session.close() + + manager._last_zombie_query_time = timezone.utcnow() - timedelta( + seconds=manager._zombie_threshold_secs + 1) + zombies = manager._find_zombies() + self.assertEquals(1, len(zombies)) + self.assertIsInstance(zombies[0], SimpleTaskInstance) + self.assertEquals(ti.dag_id, zombies[0].dag_id) + self.assertEquals(ti.task_id, zombies[0].task_id) + self.assertEquals(ti.execution_date, zombies[0].execution_date) + + +class TestDagFileProcessorAgent(unittest.TestCase): + def test_parse_once(self): + def 
+ processor_factory(file_path, zombies): + return DagFileProcessor(file_path, + False, + [], + zombies) + + test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py') + async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') + processor_agent = DagFileProcessorAgent(test_dag_path, + [test_dag_path], + 1, + processor_factory, + async_mode) + processor_agent.start() + parsing_result = [] + while not processor_agent.done: + if not async_mode: + processor_agent.heartbeat() + processor_agent.wait_until_finished() + parsing_result.extend(processor_agent.harvest_simple_dags()) + + dag_ids = [result.dag_id for result in parsing_result] + self.assertEqual(dag_ids.count('test_start_date_scheduling'), 1) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Parallelize Celery Executor enqueuing > ------------------------------------- > > Key: AIRFLOW-2761 > URL: https://issues.apache.org/jira/browse/AIRFLOW-2761 > Project: Apache Airflow > Issue Type: Improvement > Reporter: Kevin Yang > Priority: Major > > Currently the Celery executor enqueues tasks asynchronously, but it still does so > in a single-process loop. This can slow down the scheduler loop and create > scheduling delays when a large number of tasks must be scheduled in a short time, e.g. > at UTC midnight, when a large number of sensors must be scheduled in a short period. -- This message was sent by Atlassian JIRA (v7.6.3#76005)