This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f44d950365d0248a74a2c034dc3ae98757a8ff4b
Author: Malthe Borch <mbo...@gmail.com>
AuthorDate: Tue Feb 15 13:12:51 2022 +0000

    Fix race condition between triggerer and scheduler (#21316)

    (cherry picked from commit 2a6792d94d153c6f2dd116843a43ee63cd296c8d)
---
 airflow/executors/base_executor.py    | 36 ++++++++++++++++++---
 tests/executors/test_base_executor.py | 60 ++++++++++++++++++++++++++++++++---
 2 files changed, 87 insertions(+), 9 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index f7ad45a..1d993bb 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -17,7 +17,7 @@
 """Base executor - this is the base class for all the implemented executors."""
 import sys
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Counter, Dict, List, Optional, Set, Tuple
 
 from airflow.configuration import conf
 from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
@@ -29,6 +29,8 @@ PARALLELISM: int = conf.getint('core', 'PARALLELISM')
 
 NOT_STARTED_MESSAGE = "The executor should be started first!"
 
+QUEUEING_ATTEMPTS = 5
+
 # Command to execute - list of strings
 # the first element is always "airflow".
 # It should be result of TaskInstance.generate_command method.
@@ -63,6 +65,7 @@ class BaseExecutor(LoggingMixin):
         self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
         self.running: Set[TaskInstanceKey] = set()
         self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}
+        self.attempts: Counter[TaskInstanceKey] = Counter()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -78,7 +81,7 @@ class BaseExecutor(LoggingMixin):
         queue: Optional[str] = None,
     ):
         """Queues command to task"""
-        if task_instance.key not in self.queued_tasks and task_instance.key not in self.running:
+        if task_instance.key not in self.queued_tasks:
             self.log.info("Adding to queue: %s", command)
             self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
         else:
@@ -183,9 +186,32 @@ class BaseExecutor(LoggingMixin):
 
         for _ in range(min((open_slots, len(self.queued_tasks)))):
             key, (command, _, queue, ti) = sorted_queue.pop(0)
-            self.queued_tasks.pop(key)
-            self.running.add(key)
-            self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config)
+
+            # If a task makes it here but is still understood by the executor
+            # to be running, it generally means that the task has been killed
+            # externally and not yet been marked as failed.
+            #
+            # However, when a task is deferred, there is also a possibility of
+            # a race condition where a task might be scheduled again during
+            # trigger processing, even before we are able to register that the
+            # deferred task has completed. In this case and for this reason,
+            # we make a small number of attempts to see if the task has been
+            # removed from the running set in the meantime.
+            if key in self.running:
+                attempt = self.attempts[key]
+                if attempt < QUEUEING_ATTEMPTS - 1:
+                    self.attempts[key] = attempt + 1
+                    self.log.info("task %s is still running", key)
+                    continue
+
+                # We give up and remove the task from the queue.
+ self.log.error("could not queue task %s (still running after %d attempts)", key, attempt) + del self.attempts[key] + del self.queued_tasks[key] + else: + del self.queued_tasks[key] + self.running.add(key) + self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config) def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: """ diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 49d6c01..40bf8eb 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -18,7 +18,9 @@ from datetime import timedelta from unittest import mock -from airflow.executors.base_executor import BaseExecutor +from pytest import mark + +from airflow.executors.base_executor import QUEUEING_ATTEMPTS, BaseExecutor from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstance import TaskInstanceKey from airflow.utils import timezone @@ -57,7 +59,7 @@ def test_gauge_executor_metrics(mock_stats_gauge, mock_trigger_tasks, mock_sync) mock_stats_gauge.assert_has_calls(calls) -def test_try_adopt_task_instances(dag_maker): +def setup_dagrun(dag_maker): date = timezone.utcnow() start_date = date - timedelta(days=2) @@ -66,8 +68,58 @@ def test_try_adopt_task_instances(dag_maker): BaseOperator(task_id="task_2", start_date=start_date) BaseOperator(task_id="task_3", start_date=start_date) - dagrun = dag_maker.create_dagrun(execution_date=date) - tis = dagrun.task_instances + return dag_maker.create_dagrun(execution_date=date) + +def test_try_adopt_task_instances(dag_maker): + dagrun = setup_dagrun(dag_maker) + tis = dagrun.task_instances assert {ti.task_id for ti in tis} == {"task_1", "task_2", "task_3"} assert BaseExecutor().try_adopt_task_instances(tis) == tis + + +def enqueue_tasks(executor, dagrun): + for task_instance in dagrun.task_instances: + executor.queue_command(task_instance, ["airflow"]) + + +def setup_trigger_tasks(dag_maker): + dagrun = setup_dagrun(dag_maker) + executor = BaseExecutor() + executor.execute_async = mock.Mock() + enqueue_tasks(executor, dagrun) + return executor, dagrun + + +@mark.parametrize("open_slots", [1, 2, 3]) +def test_trigger_queued_tasks(dag_maker, open_slots): + executor, _ = setup_trigger_tasks(dag_maker) + executor.trigger_tasks(open_slots) + assert len(executor.execute_async.mock_calls) == open_slots + + +@mark.parametrize("change_state_attempt", range(QUEUEING_ATTEMPTS + 2)) +def test_trigger_running_tasks(dag_maker, change_state_attempt): + executor, dagrun = setup_trigger_tasks(dag_maker) + open_slots = 100 + executor.trigger_tasks(open_slots) + expected_calls = len(dagrun.task_instances) # initially `execute_async` called for each task + assert len(executor.execute_async.mock_calls) == expected_calls + + # All the tasks are now "running", so while we enqueue them again here, + # they won't be executed again until the executor has been notified of a state change. + enqueue_tasks(executor, dagrun) + + for attempt in range(QUEUEING_ATTEMPTS + 2): + # On the configured attempt, we notify the executor that the task has succeeded. 
+        if attempt == change_state_attempt:
+            executor.change_state(dagrun.task_instances[0].key, State.SUCCESS)
+        # If we have not exceeded QUEUEING_ATTEMPTS, we should expect an additional "execute" call
+        if attempt < QUEUEING_ATTEMPTS:
+            expected_calls += 1
+        executor.trigger_tasks(open_slots)
+        assert len(executor.execute_async.mock_calls) == expected_calls
+    if change_state_attempt < QUEUEING_ATTEMPTS:
+        assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances) + 1
+    else:
+        assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances)
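
For readers tracing the fix, below is a minimal, self-contained sketch of the
queueing-attempts logic that trigger_tasks() gains in this commit. It is
illustrative only: ToyExecutor mirrors the executor fields touched by the diff
but is not part of the Airflow API, and the task key "t1" is hypothetical.

# Standalone sketch of the retry logic added to trigger_tasks() above.
# ToyExecutor is NOT Airflow code; it only mirrors the relevant fields.
from collections import Counter, OrderedDict

QUEUEING_ATTEMPTS = 5  # same constant the commit introduces


class ToyExecutor:
    def __init__(self):
        self.queued_tasks = OrderedDict()  # key -> command
        self.running = set()
        self.attempts = Counter()
        self.executed = []

    def trigger_tasks(self):
        for key in list(self.queued_tasks):
            if key in self.running:
                # Deferred-task race: the key may still be in the running
                # set only because its completion has not registered yet,
                # so retry for a few scheduling loops before giving up.
                attempt = self.attempts[key]
                if attempt < QUEUEING_ATTEMPTS - 1:
                    self.attempts[key] = attempt + 1
                    continue
                del self.attempts[key]
                del self.queued_tasks[key]
            else:
                del self.queued_tasks[key]
                self.running.add(key)
                self.executed.append(key)


# The race: "t1" is queued again while still in the running set.
ex = ToyExecutor()
ex.running.add("t1")
ex.queued_tasks["t1"] = ["airflow"]

for _ in range(QUEUEING_ATTEMPTS):
    ex.trigger_tasks()
assert ex.executed == []  # never launched twice while still "running"

# Once the executor learns the task finished, a re-queued copy runs.
ex.running.discard("t1")
ex.queued_tasks["t1"] = ["airflow"]
ex.trigger_tasks()
assert ex.executed == ["t1"]

This also explains the companion change to queue_command() in the diff: the
"not in self.running" guard moves out of queue_command and into trigger_tasks,
where a still-running key can be retried for up to QUEUEING_ATTEMPTS scheduling
loops instead of being rejected outright at enqueue time.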