This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 3e229d8df5 Task adoption for hybrid executors (#39531) 3e229d8df5 is described below commit 3e229d8df52748032e0c56503c9696f7f6d9eb62 Author: Niko Oliveira <oniko...@amazon.com> AuthorDate: Mon May 13 11:10:11 2024 -0700 Task adoption for hybrid executors (#39531) Sort the set of tasks that are up for adoption by the executor they're configured to run on (if any) and send them to the appropriate executor for adoption. --- airflow/executors/executor_loader.py | 8 ++-- airflow/jobs/scheduler_job_runner.py | 22 ++++++++++- tests/executors/test_executor_loader.py | 19 +++++++++ tests/jobs/test_scheduler_job.py | 69 +++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 5 deletions(-) diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index fb3ffce420..5fb9f90d4f 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -202,10 +202,10 @@ class ExecutorLoader: elif executor_name := _module_to_executors.get(executor_name_str): return executor_name else: - raise AirflowException(f"Unknown executor being loaded: {executor_name}") + raise AirflowException(f"Unknown executor being loaded: {executor_name_str}") @classmethod - def load_executor(cls, executor_name: ExecutorName | str) -> BaseExecutor: + def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor: """ Load the executor. @@ -217,7 +217,9 @@ class ExecutorLoader: :return: an instance of executor class via executor_name """ - if isinstance(executor_name, str): + if not executor_name: + _executor_name = cls.get_default_executor_name() + elif isinstance(executor_name, str): _executor_name = cls.lookup_executor_name_by_str(executor_name) else: _executor_name = executor_name diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 49a065b5f5..f2333e8d5a 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -24,7 +24,7 @@ import signal import sys import time import warnings -from collections import Counter +from collections import Counter, defaultdict from dataclasses import dataclass from datetime import timedelta from functools import lru_cache, partial @@ -83,6 +83,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Query, Session from airflow.dag_processing.manager import DagFileProcessorAgent + from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.sqlalchemy import ( CommitProhibitorGuard, @@ -1651,7 +1652,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): # Lock these rows, so that another scheduler can't try and adopt these too tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True) tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all() - to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset) + + to_reset: list[TaskInstance] = [] + exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset) + for executor, tis in exec_to_tis.items(): + to_reset.extend(executor.try_adopt_task_instances(tis)) reset_tis_message = [] for ti in to_reset: @@ -1831,3 +1836,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query) Stats.gauge("dataset.orphaned", updated_count) + + def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]: + """Organize TIs into lists per their respective executor.""" + _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list) + executor: str | None + for ti in tis: + if ti.executor: + executor = str(ti.executor) + else: + executor = None + _executor_to_tis[ExecutorLoader.load_executor(executor)].append(ti) + + return _executor_to_tis diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index 840e74a8fc..bb7da133b6 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -26,6 +26,7 @@ from airflow import plugins_manager from airflow.exceptions import AirflowConfigException from airflow.executors import executor_loader from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName +from airflow.executors.local_executor import LocalExecutor from airflow.providers.celery.executors.celery_executor import CeleryExecutor from tests.test_utils.config import conf_vars @@ -301,3 +302,21 @@ class TestExecutorLoader: monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK") with expectation: ExecutorLoader.validate_database_executor_compatibility(executor) + + def test_load_executor(self): + ExecutorLoader.block_use_of_hybrid_exec = mock.Mock() + with conf_vars({("core", "executor"): "LocalExecutor"}): + ExecutorLoader.init_executors() + assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor) + assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor) + assert isinstance(ExecutorLoader.load_executor(None), LocalExecutor) + + def test_load_executor_alias(self): + ExecutorLoader.block_use_of_hybrid_exec = mock.Mock() + with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}): + ExecutorLoader.init_executors() + assert isinstance(ExecutorLoader.load_executor("local_exec"), LocalExecutor) + assert isinstance( + ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"), LocalExecutor + ) + assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 491e345649..85399892ac 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -23,6 +23,7 @@ import logging import os from collections import deque from datetime import timedelta +from importlib import reload from typing import Generator from unittest import mock from unittest.mock import MagicMock, PropertyMock, patch @@ -165,6 +166,18 @@ class TestSchedulerJob: self.null_exec = None del self.dagbag + @pytest.fixture + def mock_executors(self): + default_executor = mock.MagicMock(slots_available=8, slots_occupied=0) + default_executor.name = MagicMock(alias="default_exec", module_path="default.exec.module.path") + second_executor = mock.MagicMock(slots_available=8, slots_occupied=0) + second_executor.name = MagicMock(alias="secondary_exec", module_path="secondary.exec.module.path") + with mock.patch("airflow.jobs.job.Job.executors", new_callable=PropertyMock) as executors_mock: + with mock.patch("airflow.jobs.job.Job.executor", new_callable=PropertyMock) as executor_mock: + executor_mock.return_value = default_executor + executors_mock.return_value = [default_executor, second_executor] + yield [default_executor, second_executor] + @pytest.mark.parametrize( "configs", [ @@ -1740,6 +1753,62 @@ class TestSchedulerJob: ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset" + def test_adopt_or_reset_orphaned_tasks_multiple_executors(self, dag_maker, mock_executors): + """Test that with multiple executors configured tasks are sorted correctly and handed off to the + correct executor for adoption.""" + session = settings.Session() + with dag_maker("test_execute_helper_reset_orphaned_tasks_multiple_executors"): + op1 = EmptyOperator(task_id="op1") + op2 = EmptyOperator(task_id="op2", executor="default_exec") + op3 = EmptyOperator(task_id="op3", executor="secondary_exec") + + dr = dag_maker.create_dagrun() + scheduler_job = Job() + session.add(scheduler_job) + session.commit() + ti1 = dr.get_task_instance(task_id=op1.task_id, session=session) + ti2 = dr.get_task_instance(task_id=op2.task_id, session=session) + ti3 = dr.get_task_instance(task_id=op3.task_id, session=session) + tis = [ti1, ti2, ti3] + for ti in tis: + ti.state = State.QUEUED + ti.queued_by_job_id = scheduler_job.id + session.commit() + + with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock: + # reload the scheduler_job_runner module so that it loads a fresh executor_loader module which + # contains the mocked load_executor method. + from airflow.jobs import scheduler_job_runner + + reload(scheduler_job_runner) + + processor = mock.MagicMock() + + new_scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=new_scheduler_job, num_runs=0) + self.job_runner.processor_agent = processor + # The executors are mocked, so cannot be loaded/imported. Mock load_executor and return the + # correct object for the given input executor name. + loader_mock.side_effect = lambda *x: { + ("default_exec",): mock_executors[0], + (None,): mock_executors[0], + ("secondary_exec",): mock_executors[1], + }[x] + + self.job_runner.adopt_or_reset_orphaned_tasks() + + # Default executor is called for ti1 (no explicit executor override uses default) and ti2 (where we + # explicitly marked that for execution by the default executor) + try: + mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti1, ti2]) + except AssertionError: + # The order of the TIs given to try_adopt_task_instances is not consistent, so check the other + # order first before allowing AssertionError to fail the test + mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti2, ti1]) + + # Second executor called for ti3 + mock_executors[1].try_adopt_task_instances.assert_called_once_with([ti3]) + def test_fail_stuck_queued_tasks(self, dag_maker, session): with dag_maker("test_fail_stuck_queued_tasks"): op1 = EmptyOperator(task_id="op1")