yrqls21 closed pull request #3836: [WIP][Airflow-2760] Decouple DAG parsing loop from scheduler loop URL: https://github.com/apache/incubator-airflow/pull/3836
This is a PR merged from a forked repository. Because GitHub hides the original diff of a foreign pull request (one opened from a fork) once it is merged, and it would not otherwise be shown, the diff is reproduced below for the sake of provenance: diff --git a/airflow/jobs.py b/airflow/jobs.py index 90dd293f1e..b04cd2fb5b 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -26,22 +26,19 @@ import logging import multiprocessing import os -import psutil import signal -import six import sys import threading import time -import datetime - from collections import defaultdict +from time import sleep + +import six from past.builtins import basestring from sqlalchemy import ( Column, Integer, String, func, Index, or_, and_, not_) from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from tabulate import tabulate -from time import sleep from airflow import configuration as conf from airflow import executors, models, settings @@ -53,7 +50,7 @@ from airflow.utils import asciiart, helpers, timezone from airflow.utils.configuration import tmp_configuration_copy from airflow.utils.dag_processing import (AbstractDagFileProcessor, - DagFileProcessorManager, + DagFileProcessorAgent, SimpleDag, SimpleDagBag, list_py_file_paths) @@ -61,8 +58,8 @@ from airflow.utils.email import send_email from airflow.utils.log.logging_mixin import LoggingMixin, set_context, StreamLogWriter from airflow.utils.net import get_hostname -from airflow.utils.state import State from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.state import State Base = models.Base ID_LEN = models.ID_LEN @@ -160,12 +157,14 @@ def heartbeat(self): heart rate. If you go over 60 seconds before calling it, it won't sleep at all. 
""" - with create_session() as session: - job = session.query(BaseJob).filter_by(id=self.id).one() - make_transient(job) - session.commit() + session = settings.Session() + # with create_session() as session: + job = session.query(BaseJob).filter_by(id=self.id).one() + make_transient(job) + session.commit() if job.state == State.SHUTDOWN: + session.close() self.kill() # Figure out how long to sleep for @@ -179,14 +178,15 @@ def heartbeat(self): sleep(sleep_for) # Update last heartbeat time - with create_session() as session: - job = session.query(BaseJob).filter(BaseJob.id == self.id).first() - job.latest_heartbeat = timezone.utcnow() - session.merge(job) - session.commit() + # with create_session() as session: + job = session.query(BaseJob).filter(BaseJob.id == self.id).first() + job.latest_heartbeat = timezone.utcnow() + session.merge(job) + session.commit() - self.heartbeat_callback(session=session) - self.log.debug('[heartbeat]') + self.heartbeat_callback(session=session) + session.close() + self.log.debug('[heartbeat]') def run(self): Stats.incr(self.__class__.__name__.lower() + '_start', 1, 1) @@ -304,7 +304,7 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin): # Counter that increments everytime an instance of this class is created class_creation_counter = 0 - def __init__(self, file_path, pickle_dags, dag_id_white_list): + def __init__(self, file_path, pickle_dags, dag_id_white_list, zombies): """ :param file_path: a Python file containing Airflow DAG definitions :type file_path: unicode @@ -312,6 +312,8 @@ def __init__(self, file_path, pickle_dags, dag_id_white_list): :type pickle_dags: bool :param dag_id_whitelist: If specified, only look at these DAG ID's :type dag_id_whitelist: list[unicode] + :param zombies: zombie task instances to kill + :type zombies: SimpleTaskInstance """ self._file_path = file_path # Queue that's used to pass results from the child process. 
@@ -320,6 +322,7 @@ def __init__(self, file_path, pickle_dags, dag_id_white_list): self._process = None self._dag_id_white_list = dag_id_white_list self._pickle_dags = pickle_dags + self._zombies = zombies # The result of Scheduler.process_file(file_path). self._result = None # Whether the process is done running. @@ -340,7 +343,8 @@ def _launch_process(result_queue, file_path, pickle_dags, dag_id_white_list, - thread_name): + thread_name, + zombies): """ Launch a process to process the given file. @@ -358,6 +362,8 @@ def _launch_process(result_queue, :type thread_name: unicode :return: the process that was launched :rtype: multiprocessing.Process + :param zombies: zombie task instances to kill + :type zombies: SimpleTaskInstance """ def helper(): # This helper runs in the newly created process @@ -386,6 +392,7 @@ def helper(): os.getpid(), file_path) scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log) result = scheduler_job.process_file(file_path, + zombies, pickle_dags) result_queue.put(result) end_time = time.time() @@ -418,7 +425,8 @@ def start(self): self.file_path, self._pickle_dags, self._dag_id_white_list, - "DagFileProcessor{}".format(self._instance_id)) + "DagFileProcessor{}".format(self._instance_id), + self._zombies) self._start_time = timezone.utcnow() def terminate(self, sigkill=False): @@ -472,7 +480,8 @@ def done(self): if self._done: return True - if not self._result_queue.empty(): + # In case result queue is corrupted. + if self._result_queue and not self._result_queue.empty(): self._result = self._result_queue.get_nowait() self._done = True self.log.debug("Waiting for %s", self._process) @@ -480,7 +489,7 @@ def done(self): return True # Potential error case when process dies - if not self._process.is_alive(): + if self._result_queue and not self._process.is_alive(): self._done = True # Get the object from the queue or else join() can hang. 
if not self._result_queue.empty(): @@ -531,8 +540,6 @@ def __init__( dag_ids=None, subdir=settings.DAGS_FOLDER, num_runs=-1, - file_process_interval=conf.getint('scheduler', - 'min_file_process_interval'), processor_poll_interval=1.0, run_duration=None, do_pickle=False, @@ -579,26 +586,27 @@ def __init__( self.using_sqlite = False if 'sqlite' in conf.get('core', 'sql_alchemy_conn'): - if self.max_threads > 1: - self.log.error("Cannot use more than 1 thread when using sqlite. Setting max_threads to 1") - self.max_threads = 1 self.using_sqlite = True - # How often to scan the DAGs directory for new files. Default to 5 minutes. - self.dag_dir_list_interval = conf.getint('scheduler', - 'dag_dir_list_interval') - # How often to print out DAG file processing stats to the log. Default to - # 30 seconds. - self.print_stats_interval = conf.getint('scheduler', - 'print_stats_interval') - - self.file_process_interval = file_process_interval - self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query') if run_duration is None: self.run_duration = conf.getint('scheduler', 'run_duration') + self.processor_agent = None + + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame): + """ + Helper method to clean up processor_agent to avoid leaving orphan processes. + """ + self.logger.info("Exiting gracefully with signal {}".format(signum)) + if self.processor_agent: + self.processor_agent.end() + sys.exit(os.EX_OK) + @provide_session def manage_slas(self, dag, session=None): """ @@ -734,25 +742,6 @@ def manage_slas(self, dag, session=None): session.merge(sla) session.commit() - @staticmethod - @provide_session - def clear_nonexistent_import_errors(session, known_file_paths): - """ - Clears import errors for files that no longer exist. 
- - :param session: session for ORM operations - :type session: sqlalchemy.orm.session.Session - :param known_file_paths: The list of existing files that are parsed for DAGs - :type known_file_paths: list[unicode] - """ - query = session.query(models.ImportError) - if known_file_paths: - query = query.filter( - ~models.ImportError.filename.in_(known_file_paths) - ) - query.delete(synchronize_session='fetch') - session.commit() - @staticmethod def update_import_errors(session, dagbag): """ @@ -1107,7 +1096,9 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): # Put one task instance on each line task_instance_str = "\n\t".join( ["{}".format(x) for x in task_instances_to_examine]) - self.log.info("Tasks up for execution:\n\t%s", task_instance_str) + self.log.info("{} tasks up for execution:\n\t{}" + .format(len(task_instances_to_examine), + task_instance_str)) # Get the pool settings pools = {p.pool: p for p in session.query(models.Pool).all()} @@ -1298,8 +1289,9 @@ def query(result, items): .all()) task_instance_str = "\n\t".join( ["{}".format(x) for x in tis_to_be_queued]) - self.log.info("Setting the follow tasks to queued state:\n\t%s", - task_instance_str) + self.log.info("Setting the follow {} tasks to queued state:\n\t{}" + .format(len(tis_to_be_queued), + task_instance_str)) return result + tis_to_be_queued tis_to_be_queued = helpers.reduce_in_chunks(query, @@ -1476,72 +1468,6 @@ def _process_executor_events(self, simple_dag_bag, session=None): session.merge(ti) session.commit() - def _log_file_processing_stats(self, - known_file_paths, - processor_manager): - """ - Print out stats about how files are getting processed. 
- - :param known_file_paths: a list of file paths that may contain Airflow - DAG definitions - :type known_file_paths: list[unicode] - :param processor_manager: manager for the file processors - :type stats: DagFileProcessorManager - :return: None - """ - - # File Path: Path to the file containing the DAG definition - # PID: PID associated with the process that's processing the file. May - # be empty. - # Runtime: If the process is currently running, how long it's been - # running for in seconds. - # Last Runtime: If the process ran before, how long did it take to - # finish in seconds - # Last Run: When the file finished processing in the previous run. - headers = ["File Path", - "PID", - "Runtime", - "Last Runtime", - "Last Run"] - - rows = [] - for file_path in known_file_paths: - last_runtime = processor_manager.get_last_runtime(file_path) - processor_pid = processor_manager.get_pid(file_path) - processor_start_time = processor_manager.get_start_time(file_path) - runtime = ((timezone.utcnow() - processor_start_time).total_seconds() - if processor_start_time else None) - last_run = processor_manager.get_last_finish_time(file_path) - - rows.append((file_path, - processor_pid, - runtime, - last_runtime, - last_run)) - - # Sort by longest last runtime. 
(Can't sort None values in python3) - rows = sorted(rows, key=lambda x: x[3] or 0.0) - - formatted_rows = [] - for file_path, pid, runtime, last_runtime, last_run in rows: - formatted_rows.append((file_path, - pid, - "{:.2f}s".format(runtime) - if runtime else None, - "{:.2f}s".format(last_runtime) - if last_runtime else None, - last_run.strftime("%Y-%m-%dT%H:%M:%S") - if last_run else None)) - log_str = ("\n" + - "=" * 80 + - "\n" + - "DAG File Processing Stats\n\n" + - tabulate(formatted_rows, headers=headers) + - "\n" + - "=" * 80) - - self.log.info(log_str) - def _execute(self): self.log.info("Starting the scheduler") @@ -1551,84 +1477,38 @@ def _execute(self): (executors.LocalExecutor, executors.SequentialExecutor): pickle_dags = True - # Use multiple processes to parse and generate tasks for the - # DAGs in parallel. By processing them in separate processes, - # we can get parallelism and isolation from potentially harmful - # user code. - self.log.info( - "Processing files using up to %s processes at a time", - self.max_threads) self.log.info("Running execute loop for %s seconds", self.run_duration) self.log.info("Processing each file at most %s times", self.num_runs) - self.log.info( - "Process each file at most once every %s seconds", - self.file_process_interval) - self.log.info( - "Checking for new files in %s every %s seconds", - self.subdir, - self.dag_dir_list_interval) # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s", self.subdir) known_file_paths = list_py_file_paths(self.subdir) self.log.info("There are %s files in %s", len(known_file_paths), self.subdir) - def processor_factory(file_path): + def processor_factory(file_path, zombies): return DagFileProcessor(file_path, pickle_dags, - self.dag_ids) + self.dag_ids, + zombies) + + # When using sqlite, we do not use async_mode + # so the scheduler job and DAG parser don't access the DB at the same time. 
+ async_mode = not self.using_sqlite - processor_manager = DagFileProcessorManager(self.subdir, - known_file_paths, - self.max_threads, - self.file_process_interval, - self.num_runs, - processor_factory) + self.processor_agent = DagFileProcessorAgent(self.subdir, + known_file_paths, + self.num_runs, + processor_factory, + async_mode) try: - self._execute_helper(processor_manager) + self._execute_helper() finally: - self.log.info("Exited execute loop") - - # Kill all child processes on exit since we don't want to leave - # them as orphaned. - pids_to_kill = processor_manager.get_all_pids() - if len(pids_to_kill) > 0: - # First try SIGTERM - this_process = psutil.Process(os.getpid()) - # Only check child processes to ensure that we don't have a case - # where we kill the wrong process because a child process died - # but the PID got reused. - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - for child in child_processes: - self.log.info("Terminating child PID: %s", child.pid) - child.terminate() - # TODO: Remove magic number - timeout = 5 - self.log.info( - "Waiting up to %s seconds for processes to exit...", timeout) - try: - psutil.wait_procs( - child_processes, timeout=timeout, - callback=lambda x: self.log.info('Terminated PID %s', x.pid)) - except psutil.TimeoutExpired: - self.log.debug("Ran out of time while waiting for processes to exit") - - # Then SIGKILL - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - if len(child_processes) > 0: - self.log.info("SIGKILL processes that did not terminate gracefully") - for child in child_processes: - self.log.info("Killing child PID: %s", child.pid) - child.kill() - child.wait() - - def _execute_helper(self, processor_manager): - """ - :param processor_manager: manager to use - :type processor_manager: DagFileProcessorManager + self.processor_agent.end() + self.logger.info("Exited execute 
loop") + + def _execute_helper(self): + """ :return: None """ self.executor.start() @@ -1636,17 +1516,13 @@ def _execute_helper(self, processor_manager): self.log.info("Resetting orphaned tasks for active dag runs") self.reset_state_for_orphaned_tasks() + # Start after resetting orphaned tasks to avoid stressing out DB. + self.processor_agent.start() + execute_start_time = timezone.utcnow() - # Last time stats were printed - last_stat_print_time = datetime.datetime(2000, 1, 1, tzinfo=timezone.utc) # Last time that self.heartbeat() was called. last_self_heartbeat_time = timezone.utcnow() - # Last time that the DAG dir was traversed to look for files - last_dag_dir_refresh_time = timezone.utcnow() - - # Use this value initially - known_file_paths = processor_manager.file_paths # For the execute duration, parse and schedule DAGs while (timezone.utcnow() - execute_start_time).total_seconds() < \ @@ -1654,60 +1530,47 @@ def _execute_helper(self, processor_manager): self.log.debug("Starting Loop...") loop_start_time = time.time() - # Traverse the DAG directory for Python files containing DAGs - # periodically - elapsed_time_since_refresh = (timezone.utcnow() - - last_dag_dir_refresh_time).total_seconds() - - if elapsed_time_since_refresh > self.dag_dir_list_interval: - # Build up a list of Python files that could contain DAGs - self.log.info("Searching for files in %s", self.subdir) - known_file_paths = list_py_file_paths(self.subdir) - last_dag_dir_refresh_time = timezone.utcnow() - self.log.info( - "There are %s files in %s", len(known_file_paths), self.subdir) - - processor_manager.set_file_paths(known_file_paths) - - self.log.debug("Removing old import errors") - self.clear_nonexistent_import_errors(known_file_paths=known_file_paths) - - # Kick of new processes and collect results from finished ones - self.log.debug("Heartbeating the process manager") - simple_dags = processor_manager.heartbeat() - if self.using_sqlite: + self.processor_agent.heartbeat() # For the 
sqlite case w/ 1 thread, wait until the processor # is finished to avoid concurrent access to the DB. self.log.debug( "Waiting for processors to finish since we're using sqlite") + self.processor_agent.wait_until_finished() - processor_manager.wait_until_finished() + self.logger.info("Harvesting DAG parsing results") + simple_dags = self.processor_agent.harvest_simple_dags() # Send tasks for execution if available simple_dag_bag = SimpleDagBag(simple_dags) if len(simple_dags) > 0: - - # Handle cases where a DAG run state is set (perhaps manually) to - # a non-running state. Handle task instances that belong to - # DAG runs in those states - - # If a task instance is up for retry but the corresponding DAG run - # isn't running, mark the task instance as FAILED so we don't try - # to re-run it. - self._change_state_for_tis_without_dagrun(simple_dag_bag, - [State.UP_FOR_RETRY], - State.FAILED) - # If a task instance is scheduled or queued, but the corresponding - # DAG run isn't running, set the state to NONE so we don't try to - # re-run it. - self._change_state_for_tis_without_dagrun(simple_dag_bag, - [State.QUEUED, - State.SCHEDULED], - State.NONE) - - self._execute_task_instances(simple_dag_bag, - (State.SCHEDULED,)) + try: + simple_dag_bag = SimpleDagBag(simple_dags) + + # Handle cases where a DAG run state is set (perhaps manually) to + # a non-running state. Handle task instances that belong to + # DAG runs in those states + + # If a task instance is up for retry but the corresponding DAG run + # isn't running, mark the task instance as FAILED so we don't try + # to re-run it. + self._change_state_for_tis_without_dagrun(simple_dag_bag, + [State.UP_FOR_RETRY], + State.FAILED) + # If a task instance is scheduled or queued, but the corresponding + # DAG run isn't running, set the state to NONE so we don't try to + # re-run it. 
+ self._change_state_for_tis_without_dagrun(simple_dag_bag, + [State.QUEUED, + State.SCHEDULED], + State.NONE) + + self._execute_task_instances(simple_dag_bag, + (State.SCHEDULED,)) + except Exception as e: + self.logger.error("Error queuing tasks") + self.logger.exception(e) + continue # Call heartbeats self.log.debug("Heartbeating the executor") @@ -1724,40 +1587,34 @@ def _execute_helper(self, processor_manager): self.heartbeat() last_self_heartbeat_time = timezone.utcnow() - # Occasionally print out stats about how fast the files are getting processed - if ((timezone.utcnow() - last_stat_print_time).total_seconds() > - self.print_stats_interval): - if len(known_file_paths) > 0: - self._log_file_processing_stats(known_file_paths, - processor_manager) - last_stat_print_time = timezone.utcnow() - loop_end_time = time.time() + loop_duration = loop_end_time - loop_start_time self.log.debug( "Ran scheduling loop in %.2f seconds", - loop_end_time - loop_start_time) + loop_duration) self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval) time.sleep(self._processor_poll_interval) # Exit early for a test mode - if processor_manager.max_runs_reached(): - self.log.info( - "Exiting loop as all files have been processed %s times", - self.num_runs) + if self.processor_agent.done: + self.logger.info("Exiting scheduler loop as all files" + " have been processed {} times".format(self.num_runs)) break + if loop_duration < 1: + sleep_length = 1 - loop_duration + self.logger.debug( + "Sleeping for {0:.2f} seconds to prevent excessive logging" + .format(sleep_length)) + sleep(sleep_length) + # Stop any processors - processor_manager.terminate() + self.processor_agent.terminate() # Verify that all files were processed, and if so, deactivate DAGs that # haven't been touched by the scheduler as they likely have been # deleted. 
- all_files_processed = True - for file_path in known_file_paths: - if processor_manager.get_last_finish_time(file_path) is None: - all_files_processed = False - break - if all_files_processed: + if self.processor_agent.all_files_processed: self.log.info( "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat() @@ -1769,7 +1626,7 @@ def _execute_helper(self, processor_manager): settings.Session.remove() @provide_session - def process_file(self, file_path, pickle_dags=False, session=None): + def process_file(self, file_path, zombies, pickle_dags=False, session=None): """ Process a Python file containing Airflow DAGs. @@ -1788,6 +1645,8 @@ def process_file(self, file_path, pickle_dags=False, session=None): :param file_path: the path to the Python file that should be executed :type file_path: unicode + :param zombies: zombie task instances to kill. + :type zombies: SimpleTaskInstance :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool @@ -1882,7 +1741,7 @@ def process_file(self, file_path, pickle_dags=False, session=None): except Exception: self.log.exception("Error logging import errors!") try: - dagbag.kill_zombies() + dagbag.kill_zombies(zombies) except Exception: self.log.exception("Error killing zombies!") diff --git a/airflow/models.py b/airflow/models.py index 3f8f6c6736..c4d9b1cd75 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -437,43 +437,31 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags @provide_session - def kill_zombies(self, session=None): - """ - Fails tasks that haven't had a heartbeat in too long - """ - from airflow.jobs import LocalTaskJob as LJ - self.log.info("Finding 'running' jobs without a recent heartbeat") - TI = TaskInstance - secs = configuration.conf.getint('scheduler', 'scheduler_zombie_task_threshold') - limit_dttm = timezone.utcnow() - timedelta(seconds=secs) - self.log.info("Failing jobs 
without heartbeat after %s", limit_dttm) - - tis = ( - session.query(TI) - .join(LJ, TI.job_id == LJ.id) - .filter(TI.state == State.RUNNING) - .filter( - or_( - LJ.state != State.RUNNING, - LJ.latest_heartbeat < limit_dttm, - )) - .all() - ) - - for ti in tis: - if ti and ti.dag_id in self.dags: - dag = self.dags[ti.dag_id] - if ti.task_id in dag.task_ids: - task = dag.get_task(ti.task_id) - - # now set non db backed vars on ti - ti.task = task + def kill_zombies(self, zombies, session=None): + """ + Fail given zombie tasks, which are tasks that haven't + had a heartbeat for too long, in the current DagBag. + + :param zombies: zombie task instances to kill. + :type zombies: SimpleTaskInstance + :param session: DB session. + :type Session. + """ + for zombie in zombies: + if zombie.dag_id in self.dags: + dag = self.dags[zombie.dag_id] + if zombie.task_id in dag.task_ids: + task = dag.get_task(zombie.task_id) + ti = TaskInstance(task, zombie.execution_date) + ti.start_date = zombie.start_date + ti.end_date = zombie.end_date + ti.try_number = zombie.try_number + ti.state = zombie.state ti.test_mode = configuration.getboolean('core', 'unit_test_mode') - - ti.handle_failure("{} detected as zombie".format(ti), + ti.handle_failure("{} killed as zombie".format(ti), ti.test_mode, ti.get_template_context()) - self.log.info( - 'Marked zombie job %s as %s', ti, ti.state) + self.logger.info( + 'Marked zombie job {} as failed'.format(ti)) Stats.incr('zombies_killed') session.commit() diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 89e5701cf1..b651bb0c11 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -22,17 +22,34 @@ from __future__ import print_function from __future__ import unicode_literals +import multiprocessing import os import re +import signal +import sys import time import zipfile from abc import ABCMeta, abstractmethod from collections import defaultdict +from datetime import timedelta +import 
psutil +from sqlalchemy import or_ +from tabulate import tabulate + +# To avoid circular imports +import airflow.models +from airflow import configuration as conf +from airflow import settings from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.exceptions import AirflowException from airflow.utils import timezone +from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import State + +if sys.version_info.major > 2: + xrange = range class SimpleDag(BaseDag): @@ -121,6 +138,49 @@ def get_task_special_arg(self, task_id, special_arg_name): return None +class SimpleTaskInstance(object): + def __init__(self, ti): + self._dag_id = ti.dag_id + self._task_id = ti.task_id + self._execution_date = ti.execution_date + self._start_date = ti.start_date + self._end_date = ti.end_date + self._try_number = ti.try_number + self._state = ti.state + + @property + def dag_id(self): + return self._dag_id + + @property + def task_id(self): + return self._task_id + + @property + def execution_date(self): + return self._execution_date + + @property + def start_date(self): + return self._start_date + + @property + def end_date(self): + return self._end_date + + @property + def try_number(self): + return self._try_number + + @property + def state(self): + return self._state + + @property + def run_as_user(self): + return self._run_as_user + + class SimpleDagBag(BaseDagBag): """ A collection of SimpleDag objects with some convenience methods. 
@@ -308,6 +368,243 @@ def file_path(self): raise NotImplementedError() +class DagParsingStat(object): + def __init__(self, + file_paths, + all_pids, + done, + all_files_processed, + result_count): + self.file_paths = file_paths + self.all_pids = all_pids + self.done = done + self.all_files_processed = all_files_processed + self.result_count = result_count + + +class DagParsingSignal(object): + AGENT_HEARTBEAT = "agent_heartbeat" + MANAGER_DONE = "manager_done" + TERMINATE_MANAGER = "terminate_manager" + END_MANAGER = "end_manager" + + +class DagFileProcessorAgent(LoggingMixin): + """ + Agent for DAG file processors. It is responsible for all DAG parsing + related jobs in scheduler process. Mainly it will collect DAG parsing + result from DAG file processor manager and communicate signal/DAG parsing + stat with DAG file processor manager. + """ + + def __init__(self, + dag_directory, + file_paths, + max_runs, + processor_factory, + async_mode): + """ + :param dag_directory: Directory where DAG definitions are kept. All + files in file_paths should be under this directory + :type dag_directory: unicode + :param file_paths: list of file paths that contain DAG definitions + :type file_paths: list[unicode] + :param max_runs: The number of times to parse and schedule each file. -1 + for unlimited. + :type max_runs: int + :param processor_factory: function that creates processors for DAG + definition files. 
Arguments are (dag_definition_path, log_file_path) + :type processor_factory: (unicode, unicode, list) -> (AbstractDagFileProcessor) + :param async_mode: Whether to start agent in async mode + :type async_mode: bool + """ + self._file_paths = file_paths + self._file_path_queue = [] + self._dag_directory = dag_directory + self._max_runs = max_runs + self._processor_factory = processor_factory + self._async_mode = async_mode + # Map from file path to the processor + self._processors = {} + # Map from file path to the last runtime + self._last_runtime = {} + # Map from file path to the last finish time + self._last_finish_time = {} + # Map from file path to the number of runs + self._run_count = defaultdict(int) + # Scheduler heartbeat key. + self._heart_beat_key = 'heart-beat' + # Pids of DAG parse + self._all_pids = [] + # Pipe for communicating signals + self._parent_signal_conn, self._child_signal_conn = multiprocessing.Pipe() + # Pipe for communicating DagParsingStat + self._stat_queue = multiprocessing.Queue() + self._result_queue = multiprocessing.Queue() + self._process = None + self._done = False + # Initialized as true so we do not deactivate w/o any actual DAG parsing. + self._all_files_processed = True + self._result_count = 0 + + def start(self): + """ + Launch DagFileProcessorManager processor and start DAG parsing loop in manager. + """ + self._process = self._launch_process(self._dag_directory, + self._file_paths, + self._max_runs, + self._processor_factory, + self._child_signal_conn, + self._stat_queue, + self._result_queue, + self._async_mode) + self.log.info("Launched DagFileProcessorManager with pid: {}" + .format(self._process.pid)) + + def heartbeat(self): + """ + Should only be used when launched DAG file processor manager in sync mode. + Send agent heartbeat signal to the manager. 
+ """ + self._parent_signal_conn.send(DagParsingSignal.AGENT_HEARTBEAT) + + def wait_until_finished(self): + """ + Should only be used when launched DAG file processor manager in sync mode. + Wait for done signal from the manager. + """ + while True: + if self._parent_signal_conn.poll() \ + and self._parent_signal_conn.recv() == DagParsingSignal.MANAGER_DONE: + break + time.sleep(0.1) + + @staticmethod + def _launch_process(dag_directory, + file_paths, + max_runs, + processor_factory, + signal_conn, + _stat_queue, + result_queue, + async_mode): + def helper(): + processor_manager = DagFileProcessorManager(dag_directory, + file_paths, + max_runs, + processor_factory, + signal_conn, + _stat_queue, + result_queue, + async_mode) + + processor_manager.start() + + p = multiprocessing.Process(target=helper, + args=(), + name="DagFileProcessorManager") + p.start() + return p + + def harvest_simple_dags(self): + """ + Harvest DAG parsing results from result queue and sync metadata from stat queue. + :return: List of parsing result in SimpleDag format. + """ + # Metadata and results to be harvested can be inconsistent, + # but it should not be a big problem. + self._sync_metadata() + # Heartbeating after syncing metadata so we do not restart manager + # if it processed all files for max_run times and exit normally. + self._heartbeat_manager() + simple_dags = [] + # multiprocessing.Queue().qsize will not work on MacOS. + if sys.platform == "darwin": + qsize = self._result_count + else: + qsize = self._result_queue.qsize() + for _ in xrange(qsize): + simple_dags.append(self._result_queue.get()) + + self._result_count = 0 + + return simple_dags + + def _heartbeat_manager(self): + """ + Heartbeat DAG file processor and start it if it is not alive. + :return: + """ + if self._process and not self._process.is_alive() and not self.done: + self.start() + + def _sync_metadata(self): + """ + Sync metadata from stat queue and only keep the latest stat. 
+ :return: + """ + while not self._stat_queue.empty(): + stat = self._stat_queue.get() + self._file_paths = stat.file_paths + self._all_pids = stat.all_pids + self._done = stat.done + self._all_files_processed = stat.all_files_processed + self._result_count += stat.result_count + + @property + def file_paths(self): + return self._file_paths + + @property + def done(self): + return self._done + + @property + def all_files_processed(self): + return self._all_files_processed + + def terminate(self): + """ + Send termination signal to DAG parsing processor manager + and expect it to terminate all DAG file processors. + """ + self.log.info("Sending termination signal to manager.") + self._child_signal_conn.send(DagParsingSignal.TERMINATE_MANAGER) + + def end(self): + """ + Terminate (and then kill) the manager process launched. + :return: + """ + if not self._process or not self._process.is_alive(): + self.log.warn('Ending without manager process.') + return + this_process = psutil.Process(os.getpid()) + manager_process = psutil.Process(self._process.pid) + # First try SIGTERM + if manager_process.is_running() \ + and manager_process.pid in [x.pid for x in this_process.children()]: + self.log.info( + "Terminating manager process: {}".format(manager_process.pid)) + manager_process.terminate() + timeout = 5 + self.log.info("Waiting up to {}s for manager process to exit..." 
+ .format(timeout)) + try: + psutil.wait_procs({manager_process}, timeout) + except psutil.TimeoutExpired: + self.log.debug("Ran out of time while waiting for " + "processes to exit") + + # Then SIGKILL + if manager_process.is_running() \ + and manager_process.pid in [x.pid for x in this_process.children()]: + self.log.info("Killing manager process: {}".format(manager_process.pid)) + manager_process.kill() + manager_process.wait() + + class DagFileProcessorManager(LoggingMixin): """ Given a list of DAG definition files, this kicks off several processors @@ -324,48 +621,318 @@ class DagFileProcessorManager(LoggingMixin): def __init__(self, dag_directory, file_paths, - parallelism, - process_file_interval, max_runs, - processor_factory): + processor_factory, + signal_conn, + stat_queue, + result_queue, + async_mode=True): """ :param dag_directory: Directory where DAG definitions are kept. All files in file_paths should be under this directory :type dag_directory: unicode :param file_paths: list of file paths that contain DAG definitions :type file_paths: list[unicode] - :param parallelism: maximum number of simultaneous process to run at once - :type parallelism: int - :param process_file_interval: process a file at most once every this - many seconds - :type process_file_interval: float :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :type max_runs: int - :type process_file_interval: float :param processor_factory: function that creates processors for DAG definition files. Arguments are (dag_definition_path) - :type processor_factory: (unicode, unicode) -> (AbstractDagFileProcessor) - + :type processor_factory: (unicode, unicode, list) -> (AbstractDagFileProcessor) + :param signal_conn: connection to communicate signal with processor agent. + :type signal_conn: Connection + :param stat_queue: the queue to use for passing back parsing stat to agent. 
+ :type stat_queue: multiprocessing.Queue + :param result_queue: the queue to use for passing back the result to agent. + :type result_queue: multiprocessing.Queue + :param async_mode: whether to start the manager in async mode + :type async_mode: bool """ self._file_paths = file_paths self._file_path_queue = [] - self._parallelism = parallelism self._dag_directory = dag_directory self._max_runs = max_runs - self._process_file_interval = process_file_interval self._processor_factory = processor_factory + self._signal_conn = signal_conn + self._stat_queue = stat_queue + self._result_queue = result_queue + self._async_mode = async_mode + + self._parallelism = conf.getint('scheduler', 'max_threads') + if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1: + self.log.error("Cannot use more than 1 thread when using sqlite. " + "Setting parallelism to 1") + self._parallelism = 1 + + # Parse and schedule each file no faster than this interval. + self._file_process_interval = conf.getint('scheduler', + 'min_file_process_interval') + # How often to print out DAG file processing stats to the log. Default to + # 30 seconds. + self.print_stats_interval = conf.getint('scheduler', + 'print_stats_interval') + # How many seconds do we wait for tasks to heartbeat before mark them as zombies. 
+ self._zombie_threshold_secs = ( + conf.getint('scheduler', 'scheduler_zombie_task_threshold')) # Map from file path to the processor self._processors = {} # Map from file path to the last runtime self._last_runtime = {} # Map from file path to the last finish time self._last_finish_time = {} + self._last_zombie_query_time = timezone.utcnow() + # Last time that the DAG dir was traversed to look for files + self.last_dag_dir_refresh_time = timezone.utcnow() + # Last time stats were printed + self.last_stat_print_time = timezone.datetime(2000, 1, 1) + self._zombie_query_interval = 10 # Map from file path to the number of runs self._run_count = defaultdict(int) # Scheduler heartbeat key. self._heart_beat_key = 'heart-beat' + # How often to scan the DAGs directory for new files. Default to 5 minutes. + self.dag_dir_list_interval = conf.getint('scheduler', + 'dag_dir_list_interval') + + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame): + """ + Helper method to clean up DAG file processors to avoid leaving orphan processes. + """ + self.log.info("Exiting gracefully with signal {}".format(signum)) + self.terminate() + self.end() + self.log.debug("Finished terminating DAG processors.") + sys.exit(os.EX_OK) + + def start(self): + """ + Use multiple processes to parse and generate tasks for the + DAGs in parallel. By processing them in separate processes, + we can get parallelism and isolation from potentially harmful + user code. 
+ :return: + """ + + self.log.info("Processing files using up to {} processes at a time " + .format(self._parallelism)) + self.log.info("Process each file at most once every {} seconds" + .format(self._file_process_interval)) + self.log.info("Checking for new files in {} every {} seconds" + .format(self._dag_directory, self.dag_dir_list_interval)) + if self._async_mode: + self.log.debug("Starting DagFileProcessorManager in async mode") + self.start_in_async() + else: + self.log.debug("Starting DagFileProcessorManager in sync mode") + self.start_in_sync() + + def start_in_async(self): + """ + Parse DAG files repeatedly in a standalone loop. + """ + while True: + loop_start_time = time.time() + + if self._signal_conn.poll(): + agent_signal = self._signal_conn.recv() + if agent_signal == DagParsingSignal.TERMINATE_MANAGER: + self.terminate() + break + elif agent_signal == DagParsingSignal.END_MANAGER: + self.end() + sys.exit(os.EX_OK) + + self._refresh_dag_dir() + + simple_dags = self.heartbeat() + for simple_dag in simple_dags: + self._result_queue.put(simple_dag) + + self._print_stat() + + all_files_processed = all(self.get_last_finish_time(x) is not None + for x in self.file_paths) + max_runs_reached = self.max_runs_reached() + + dag_parsing_stat = DagParsingStat(self._file_paths, + self.get_all_pids(), + max_runs_reached, + all_files_processed, + len(simple_dags)) + self._stat_queue.put(dag_parsing_stat) + + if max_runs_reached: + self.log.info("Exiting dag parsing loop as all files " + "have been processed %s times", self._max_runs) + break + + loop_duration = time.time() - loop_start_time + if loop_duration < 1: + sleep_length = 1 - loop_duration + self.log.debug("Sleeping for {0:.2f} seconds " + "to prevent excessive logging".format(sleep_length)) + time.sleep(sleep_length) + + def start_in_sync(self): + """ + Parse DAG files in a loop controlled by DagParsingSignal. 
+ Actual DAG parsing loop will run once every time an agent heartbeats + the manager and will report done when the loop has finished. + """ + while True: + if self._signal_conn.poll(): + agent_signal = self._signal_conn.recv() + if agent_signal == DagParsingSignal.TERMINATE_MANAGER: + self.terminate() + break + elif agent_signal == DagParsingSignal.END_MANAGER: + self.end() + sys.exit(os.EX_OK) + elif agent_signal == DagParsingSignal.AGENT_HEARTBEAT: + + self._refresh_dag_dir() + + simple_dags = self.heartbeat() + for simple_dag in simple_dags: + self._result_queue.put(simple_dag) + + self._print_stat() + + all_files_processed = all(self.get_last_finish_time(x) is not None + for x in self.file_paths) + max_runs_reached = self.max_runs_reached() + + dag_parsing_stat = DagParsingStat(self._file_paths, + self.get_all_pids(), + max_runs_reached, + all_files_processed, + len(simple_dags)) + self._stat_queue.put(dag_parsing_stat) + + self.wait_until_finished() + self._signal_conn.send(DagParsingSignal.MANAGER_DONE) + + if max_runs_reached: + self.log.info("Exiting dag parsing loop as all files " + "have been processed %s times", self._max_runs) + self._signal_conn.send(DagParsingSignal.MANAGER_DONE) + break + else: + time.sleep(0.01) + + def _refresh_dag_dir(self): + """ + Refresh file paths from dag dir if we haven't done it for too long. 
+ """ + elapsed_time_since_refresh = (timezone.utcnow() - + self.last_dag_dir_refresh_time).total_seconds() + if elapsed_time_since_refresh > self.dag_dir_list_interval: + # Build up a list of Python files that could contain DAGs + self.log.info( + "Searching for files in {}".format(self._dag_directory)) + self._file_paths = list_py_file_paths(self._dag_directory) + self.last_dag_dir_refresh_time = timezone.utcnow() + self.log.info("There are {} files in {}" + .format(len(self._file_paths), + self._dag_directory)) + self.set_file_paths(self._file_paths) + + self.log.debug("Removing old import errors") + + try: + self.clear_nonexistent_import_errors() + except Exception as e: + self.log.error("Error removing old import errors") + self.log.exception(e) + + def _print_stat(self): + """ + Occasionally print out stats about how fast the files are getting processed + :return: + """ + if ((timezone.utcnow() - self.last_stat_print_time).total_seconds() > + self.print_stats_interval): + if len(self._file_paths) > 0: + self._log_file_processing_stats(self._file_paths) + self.last_stat_print_time = timezone.utcnow() + + @provide_session + def clear_nonexistent_import_errors(self, session): + """ + Clears import errors for files that no longer exist. + :param session: session for ORM operations + :type session: sqlalchemy.orm.session.Session + """ + session.query(airflow.models.ImportError).filter( + ~airflow.models.ImportError.filename.in_(self._file_paths) + ).delete(synchronize_session='fetch') + session.commit() + + def _log_file_processing_stats(self, known_file_paths): + """ + Print out stats about how files are getting processed. + :param known_file_paths: a list of file paths that may contain Airflow + DAG definitions + :type known_file_paths: list[unicode] + :return: None + """ + + # File Path: Path to the file containing the DAG definition + # PID: PID associated with the process that's processing the file. May + # be empty. 
+ # Runtime: If the process is currently running, how long it's been + # running for in seconds. + # Last Runtime: If the process ran before, how long did it take to + # finish in seconds + # Last Run: When the file finished processing in the previous run. + headers = ["File Path", + "PID", + "Runtime", + "Last Runtime", + "Last Run"] + + rows = [] + for file_path in known_file_paths: + last_runtime = self.get_last_runtime(file_path) + processor_pid = self.get_pid(file_path) + processor_start_time = self.get_start_time(file_path) + runtime = ((timezone.utcnow() - processor_start_time).total_seconds() + if processor_start_time else None) + last_run = self.get_last_finish_time(file_path) + + rows.append((file_path, + processor_pid, + runtime, + last_runtime, + last_run)) + + # Sort by longest last runtime. (Can't sort None values in python3) + rows = sorted(rows, key=lambda x: x[3] or 0.0) + + formatted_rows = [] + for file_path, pid, runtime, last_runtime, last_run in rows: + formatted_rows.append((file_path, + pid, + "{:.2f}s".format(runtime) + if runtime else None, + "{:.2f}s".format(last_runtime) + if last_runtime else None, + last_run.strftime("%Y-%m-%dT%H:%M:%S") + if last_run else None)) + log_str = ("\n" + + "=" * 80 + + "\n" + + "DAG File Processing Stats\n\n" + + tabulate(formatted_rows, headers=headers) + + "\n" + + "=" * 80) + + self.log.info(log_str) + @property def file_paths(self): return self._file_paths @@ -472,7 +1039,7 @@ def wait_until_finished(self): def heartbeat(self): """ - This should be periodically called by the scheduler. This method will + This should be periodically called by the manager loop. This method will kick off new processes to process DAG definition files and read the results from the finished processors. 
@@ -528,7 +1095,7 @@ def heartbeat(self): last_finish_time = self.get_last_finish_time(file_path) if (last_finish_time is not None and (now - last_finish_time).total_seconds() < - self._process_file_interval): + self._file_process_interval): file_paths_recently_processed.append(file_path) files_paths_at_run_limit = [file_path @@ -553,11 +1120,13 @@ def heartbeat(self): self._file_path_queue.extend(files_paths_to_queue) + zombies = self._find_zombies() + # Start more processors if we have enough slots and files to process while (self._parallelism - len(self._processors) > 0 and len(self._file_path_queue) > 0): file_path = self._file_path_queue.pop(0) - processor = self._processor_factory(file_path) + processor = self._processor_factory(file_path, zombies) processor.start() self.log.info( @@ -566,11 +1135,50 @@ def heartbeat(self): ) self._processors[file_path] = processor - # Update scheduler heartbeat count. + # Update heartbeat count. self._run_count[self._heart_beat_key] += 1 return simple_dags + def _find_zombies(self): + """ + Find zombie task instances, which are tasks that haven't heartbeated for too long. + :return: Zombie task instances in SimpleTaskInstance format. 
+ """ + now = timezone.utcnow() + zombies = [] + if (now - self._last_zombie_query_time).total_seconds() \ + > self._zombie_query_interval: + # to avoid circular imports + from airflow.jobs import LocalTaskJob as LJ + session = settings.Session() + self.log.info("Finding 'running' jobs without a recent heartbeat") + TI = airflow.models.TaskInstance + limit_dttm = timezone.utcnow() - timedelta( + seconds=self._zombie_threshold_secs) + self.log.info( + "Failing jobs without heartbeat after {}".format(limit_dttm)) + + tis = ( + session.query(TI) + .join(LJ, TI.job_id == LJ.id) + .filter(TI.state == State.RUNNING) + .filter( + or_( + LJ.state != State.RUNNING, + LJ.latest_heartbeat < limit_dttm, + ) + ).all() + ) + # Record the query time so the next lookup waits a full interval. + self._last_zombie_query_time = timezone.utcnow() + for ti in tis: + zombies.append(SimpleTaskInstance(ti)) + + session.close() + + return zombies + def max_runs_reached(self): """ :return: whether all file paths have been processed max_runs times @@ -591,3 +1199,38 @@ def terminate(self): """ for processor in self._processors.values(): processor.terminate() + + def end(self): + """ + Kill all child processes on exit since we don't want to leave + them as orphaned. + """ + pids_to_kill = self.get_all_pids() + if len(pids_to_kill) > 0: + # First try SIGTERM + this_process = psutil.Process(os.getpid()) + # Only check child processes to ensure that we don't have a case + # where we kill the wrong process because a child process died + # but the PID got reused. + child_processes = [x for x in this_process.children(recursive=True) + if x.is_running() and x.pid in pids_to_kill] + for child in child_processes: + self.log.info("Terminating child PID: {}".format(child.pid)) + child.terminate() + timeout = 5 + self.log.info("Waiting up to {}s for processes to exit..." 
+ .format(timeout)) + try: + psutil.wait_procs(child_processes, timeout) + except psutil.TimeoutExpired: + self.log.debug("Ran out of time while waiting for " + "processes to exit") + + # Then SIGKILL + child_processes = [x for x in this_process.children(recursive=True) + if x.is_running() and x.pid in pids_to_kill] + if len(child_processes) > 0: + for child in child_processes: + self.log.info("Killing child PID: {}".format(child.pid)) + child.kill() + child.wait() diff --git a/tests/jobs.py b/tests/jobs.py index f9c07b96c9..ba4ac98063 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -28,37 +28,35 @@ import multiprocessing import os import shutil -import six import threading import time import unittest from tempfile import mkdtemp +import psutil +import six import sqlalchemy +from mock import Mock, patch, MagicMock, PropertyMock from airflow import AirflowException, settings, models +from airflow import configuration from airflow.bin import cli from airflow.executors import BaseExecutor, SequentialExecutor from airflow.jobs import BaseJob, BackfillJob, SchedulerJob, LocalTaskJob from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI -from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator +from airflow.operators.dummy_operator import DummyOperator from airflow.task.task_runner.base_task_runner import BaseTaskRunner from airflow.utils import timezone - +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths from airflow.utils.dates import days_ago from airflow.utils.db import provide_session +from airflow.utils.net import get_hostname from airflow.utils.state import State from airflow.utils.timeout import timeout -from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths -from airflow.utils.net import get_hostname - -from mock import Mock, patch, MagicMock, PropertyMock -from tests.executors.test_executor import 
TestExecutor - from tests.core import TEST_DAG_FOLDER +from tests.executors.test_executor import TestExecutor -from airflow import configuration configuration.load_test_config() logger = logging.getLogger(__name__) @@ -1213,6 +1211,7 @@ def setUp(self): session.query(models.DagRun).delete() session.query(models.ImportError).delete() session.commit() + session.close() @staticmethod def run_single_scheduler_loop_with_no_dags(dags_folder): @@ -1235,6 +1234,21 @@ def run_single_scheduler_loop_with_no_dags(dags_folder): def _make_simple_dag_bag(self, dags): return SimpleDagBag([SimpleDag(dag) for dag in dags]) + def test_no_orphan_process_will_be_left(self): + empty_dir = mkdtemp() + current_process = psutil.Process() + old_children = current_process.children(recursive=True) + scheduler = SchedulerJob(subdir=empty_dir, + num_runs=1) + scheduler.executor = TestExecutor() + scheduler.run() + shutil.rmtree(empty_dir) + + # Remove potential noise created by previous tests. + current_children = set(current_process.children(recursive=True)) - set( + old_children) + self.assertFalse(current_children) + def test_process_executor_events(self): dag_id = "test_process_executor_events" dag_id2 = "test_process_executor_events_2" @@ -2004,13 +2018,14 @@ def test_execute_helper_reset_orphaned_tasks(self): session.commit() processor = mock.MagicMock() - processor.get_last_finish_time.return_value = None + # processor.get_last_finish_time.return_value = None scheduler = SchedulerJob(num_runs=0, run_duration=0) executor = TestExecutor() scheduler.executor = executor + scheduler.processor_agent = processor - scheduler._execute_helper(processor_manager=processor) + scheduler._execute_helper() ti = dr.get_task_instance(task_id=op1.task_id, session=session) self.assertEqual(ti.state, State.NONE) @@ -2937,7 +2952,9 @@ def test_scheduler_run_duration(self): logging.info("Test ran in %.2fs, expected %.2fs", run_duration, expected_run_duration) - self.assertLess(run_duration - 
expected_run_duration, 5.0) + # 5s to wait for child process to exit and 1s dummy sleep + # in scheduler loop to prevent excessive logs. + self.assertLess(run_duration - expected_run_duration, 6.0) def test_dag_with_system_exit(self): """ @@ -2964,6 +2981,7 @@ def test_dag_with_system_exit(self): session = settings.Session() self.assertEqual( len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + session.close() def test_dag_get_active_runs(self): """ @@ -3097,10 +3115,9 @@ def setup_dag(dag_id, schedule_interval, start_date, catchup): dr = scheduler.create_dag_run(dag4) self.assertIsNotNone(dr) - def test_add_unparseable_file_before_sched_start_creates_import_error(self): + dags_folder = mkdtemp() try: - dags_folder = mkdtemp() unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) @@ -3110,6 +3127,7 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 1) import_error = import_errors[0] @@ -3119,8 +3137,9 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): "invalid syntax ({}, line 1)".format(TEMP_DAG_FILENAME)) def test_add_unparseable_file_after_sched_start_creates_import_error(self): + dags_folder = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'dags') try: - dags_folder = mkdtemp() unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) self.run_single_scheduler_loop_with_no_dags(dags_folder) @@ -3128,10 +3147,12 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self): unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) self.run_single_scheduler_loop_with_no_dags(dags_folder) finally: - shutil.rmtree(dags_folder) + # shutil.rmtree(dags_folder) + pass session = settings.Session() 
import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 1) import_error = import_errors[0] @@ -3153,6 +3174,7 @@ def test_no_import_errors_with_parseable_dag(self): session = settings.Session() import_errors = session.query(models.ImportError).all() + session.close() self.assertEqual(len(import_errors), 0) diff --git a/tests/models.py b/tests/models.py index a9c43dfd8e..36cafbb2e4 100644 --- a/tests/models.py +++ b/tests/models.py @@ -51,6 +51,7 @@ from airflow.operators.python_operator import ShortCircuitOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone +from airflow.utils.dag_processing import SimpleTaskInstance from airflow.utils.weight_rule import WeightRule from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule @@ -1589,7 +1590,8 @@ def test_kill_zombies(self, mock_ti): session.add(ti) session.commit() - dagbag.kill_zombies() + zombies = [SimpleTaskInstance(ti)] + dagbag.kill_zombies(zombies) mock_ti.assert_called_with(ANY, configuration.getboolean('core', 'unit_test_mode'), ANY) diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index f29e384b8c..cd19dab237 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -17,22 +17,31 @@ # specific language governing permissions and limitations # under the License. 
+import os import unittest from mock import MagicMock +from airflow import configuration as conf +from airflow.jobs import DagFileProcessor +from airflow.utils.dag_processing import DagFileProcessorAgent from airflow.utils.dag_processing import DagFileProcessorManager +TEST_DAG_FOLDER = os.path.join( + os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') + class TestDagFileProcessorManager(unittest.TestCase): def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): manager = DagFileProcessorManager( dag_directory='directory', file_paths=['abc.txt'], - parallelism=1, - process_file_interval=1, max_runs=1, - processor_factory=MagicMock().return_value) + processor_factory=MagicMock().return_value, + signal_conn=MagicMock(), + stat_queue=MagicMock(), + result_queue=MagicMock(), + async_mode=True) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError( @@ -48,10 +57,12 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): manager = DagFileProcessorManager( dag_directory='directory', file_paths=['abc.txt'], - parallelism=1, - process_file_interval=1, max_runs=1, - processor_factory=MagicMock().return_value) + processor_factory=MagicMock().return_value, + signal_conn=MagicMock(), + stat_queue=MagicMock(), + result_queue=MagicMock(), + async_mode=True) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError( @@ -62,3 +73,29 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): manager.set_file_paths(['abc.txt']) self.assertDictEqual(manager._processors, {'abc.txt': mock_processor}) + + def test_parse_once(self): + """Harvest results from a DagFileProcessorAgent until it reports done and check that DAG 'test_start_date_scheduling' is returned exactly once.""" + def processor_factory(file_path, zombies): + return DagFileProcessor(file_path, + False, + [], + zombies) + + test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py') + async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') + processor_agent = DagFileProcessorAgent(test_dag_path, + [test_dag_path], + 1, + 
processor_factory, + async_mode) + processor_agent.start() + parsing_result = [] + while not processor_agent.done: + if not async_mode: + processor_agent.heartbeat() + processor_agent.wait_until_finished() + parsing_result.extend(processor_agent.harvest_simple_dags()) + + dag_ids = [result.dag_id for result in parsing_result] + self.assertEqual(dag_ids.count('test_start_date_scheduling'), 1) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services