This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e229d8df5 Task adoption for hybrid executors (#39531)
3e229d8df5 is described below

commit 3e229d8df52748032e0c56503c9696f7f6d9eb62
Author: Niko Oliveira <oniko...@amazon.com>
AuthorDate: Mon May 13 11:10:11 2024 -0700

    Task adoption for hybrid executors (#39531)
    
    Sort the set of tasks that are up for adoption by the executor they're
    configured to run on (if any) and send them to the appropriate executor
    for adoption.
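    
    In practice, this replaces the scheduler's single call to
    try_adopt_task_instances() with a per-executor grouping step. A minimal
    sketch of the idea, simplified from the diff below (group_tis_by_executor
    and adopt_per_executor are illustrative names, not the committed ones;
    the real logic lives in SchedulerJobRunner._executor_to_tis):
    
        from collections import defaultdict
    
        from airflow.executors.executor_loader import ExecutorLoader
    
        def group_tis_by_executor(tis):
            grouped = defaultdict(list)
            for ti in tis:
                # load_executor(None) falls back to the default executor, so
                # tasks with no executor set keep the previous behaviour
                grouped[ExecutorLoader.load_executor(ti.executor or None)].append(ti)
            return grouped
    
        def adopt_per_executor(tis_to_adopt):
            # Adoption now happens once per executor instead of once overall
            to_reset = []
            for executor, grouped_tis in group_tis_by_executor(tis_to_adopt).items():
                to_reset.extend(executor.try_adopt_task_instances(grouped_tis))
            return to_reset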
---
 airflow/executors/executor_loader.py    |  8 ++--
 airflow/jobs/scheduler_job_runner.py    | 22 ++++++++++-
 tests/executors/test_executor_loader.py | 19 +++++++++
 tests/jobs/test_scheduler_job.py        | 69 +++++++++++++++++++++++++++++++++
 4 files changed, 113 insertions(+), 5 deletions(-)

diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index fb3ffce420..5fb9f90d4f 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -202,10 +202,10 @@ class ExecutorLoader:
         elif executor_name := _module_to_executors.get(executor_name_str):
             return executor_name
         else:
-            raise AirflowException(f"Unknown executor being loaded: {executor_name}")
+            raise AirflowException(f"Unknown executor being loaded: {executor_name_str}")
 
     @classmethod
-    def load_executor(cls, executor_name: ExecutorName | str) -> BaseExecutor:
+    def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor:
         """
         Load the executor.
 
@@ -217,7 +217,9 @@ class ExecutorLoader:
 
         :return: an instance of executor class via executor_name
         """
-        if isinstance(executor_name, str):
+        if not executor_name:
+            _executor_name = cls.get_default_executor_name()
+        elif isinstance(executor_name, str):
             _executor_name = cls.lookup_executor_name_by_str(executor_name)
         else:
             _executor_name = executor_name
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 49a065b5f5..f2333e8d5a 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,7 @@ import signal
 import sys
 import time
 import warnings
-from collections import Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass
 from datetime import timedelta
 from functools import lru_cache, partial
@@ -83,6 +83,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
+    from airflow.executors.base_executor import BaseExecutor
     from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.sqlalchemy import (
         CommitProhibitorGuard,
@@ -1651,7 +1652,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     # Lock these rows, so that another scheduler can't try and adopt these too
                     tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True)
                     tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
-                    to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
+
+                    to_reset: list[TaskInstance] = []
+                    exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset)
+                    for executor, tis in exec_to_tis.items():
+                        to_reset.extend(executor.try_adopt_task_instances(tis))
 
                     reset_tis_message = []
                     for ti in to_reset:
@@ -1831,3 +1836,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query)
         Stats.gauge("dataset.orphaned", updated_count)
+
+    def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
+        """Organize TIs into lists per their respective executor."""
+        _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list)
+        executor: str | None
+        for ti in tis:
+            if ti.executor:
+                executor = str(ti.executor)
+            else:
+                executor = None
+            _executor_to_tis[ExecutorLoader.load_executor(executor)].append(ti)
+
+        return _executor_to_tis
diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py
index 840e74a8fc..bb7da133b6 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -26,6 +26,7 @@ from airflow import plugins_manager
 from airflow.exceptions import AirflowConfigException
 from airflow.executors import executor_loader
 from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
+from airflow.executors.local_executor import LocalExecutor
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
 from tests.test_utils.config import conf_vars
 
@@ -301,3 +302,21 @@ class TestExecutorLoader:
         monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK")
         with expectation:
             ExecutorLoader.validate_database_executor_compatibility(executor)
+
+    def test_load_executor(self):
+        ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+        with conf_vars({("core", "executor"): "LocalExecutor"}):
+            ExecutorLoader.init_executors()
+            assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), 
LocalExecutor)
+            assert 
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), 
LocalExecutor)
+            assert isinstance(ExecutorLoader.load_executor(None), 
LocalExecutor)
+
+    def test_load_executor_alias(self):
+        ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+        with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}):
+            ExecutorLoader.init_executors()
+            assert isinstance(ExecutorLoader.load_executor("local_exec"), 
LocalExecutor)
+            assert isinstance(
+                
ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"), 
LocalExecutor
+            )
+            assert 
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), 
LocalExecutor)
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 491e345649..85399892ac 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -23,6 +23,7 @@ import logging
 import os
 from collections import deque
 from datetime import timedelta
+from importlib import reload
 from typing import Generator
 from unittest import mock
 from unittest.mock import MagicMock, PropertyMock, patch
@@ -165,6 +166,18 @@ class TestSchedulerJob:
         self.null_exec = None
         del self.dagbag
 
+    @pytest.fixture
+    def mock_executors(self):
+        default_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+        default_executor.name = MagicMock(alias="default_exec", module_path="default.exec.module.path")
+        second_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+        second_executor.name = MagicMock(alias="secondary_exec", module_path="secondary.exec.module.path")
+        with mock.patch("airflow.jobs.job.Job.executors", new_callable=PropertyMock) as executors_mock:
+            with mock.patch("airflow.jobs.job.Job.executor", new_callable=PropertyMock) as executor_mock:
+                executor_mock.return_value = default_executor
+                executors_mock.return_value = [default_executor, second_executor]
+                yield [default_executor, second_executor]
+
     @pytest.mark.parametrize(
         "configs",
         [
@@ -1740,6 +1753,62 @@ class TestSchedulerJob:
         ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
         assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset"
 
+    def test_adopt_or_reset_orphaned_tasks_multiple_executors(self, dag_maker, mock_executors):
+        """Test that with multiple executors configured, tasks are sorted correctly and handed off to the
+        correct executor for adoption."""
+        session = settings.Session()
+        with dag_maker("test_execute_helper_reset_orphaned_tasks_multiple_executors"):
+            op1 = EmptyOperator(task_id="op1")
+            op2 = EmptyOperator(task_id="op2", executor="default_exec")
+            op3 = EmptyOperator(task_id="op3", executor="secondary_exec")
+
+        dr = dag_maker.create_dagrun()
+        scheduler_job = Job()
+        session.add(scheduler_job)
+        session.commit()
+        ti1 = dr.get_task_instance(task_id=op1.task_id, session=session)
+        ti2 = dr.get_task_instance(task_id=op2.task_id, session=session)
+        ti3 = dr.get_task_instance(task_id=op3.task_id, session=session)
+        tis = [ti1, ti2, ti3]
+        for ti in tis:
+            ti.state = State.QUEUED
+            ti.queued_by_job_id = scheduler_job.id
+        session.commit()
+
+        with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
+            # reload the scheduler_job_runner module so that it loads a fresh executor_loader module which
+            # contains the mocked load_executor method.
+            from airflow.jobs import scheduler_job_runner
+
+            reload(scheduler_job_runner)
+
+            processor = mock.MagicMock()
+
+            new_scheduler_job = Job()
+            self.job_runner = SchedulerJobRunner(job=new_scheduler_job, num_runs=0)
+            self.job_runner.processor_agent = processor
+            # The executors are mocked, so cannot be loaded/imported. Mock load_executor and return the
+            # correct object for the given input executor name.
+            loader_mock.side_effect = lambda *x: {
+                ("default_exec",): mock_executors[0],
+                (None,): mock_executors[0],
+                ("secondary_exec",): mock_executors[1],
+            }[x]
+
+            self.job_runner.adopt_or_reset_orphaned_tasks()
+
+        # Default executor is called for ti1 (no explicit executor override, so it uses the default) and
+        # ti2 (which we explicitly marked for execution by the default executor)
+        try:
+            mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti1, ti2])
+        except AssertionError:
+            # The order of the TIs given to try_adopt_task_instances is not consistent, so check the other
+            # order first before allowing AssertionError to fail the test
+            mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti2, ti1])
+
+        # Second executor called for ti3
+        mock_executors[1].try_adopt_task_instances.assert_called_once_with([ti3])
+
     def test_fail_stuck_queued_tasks(self, dag_maker, session):
         with dag_maker("test_fail_stuck_queued_tasks"):
             op1 = EmptyOperator(task_id="op1")
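
For reference, the per-task executor override exercised by the new test is set
on the operator itself. A minimal usage sketch (the "secondary_exec" alias
comes from the test fixture above, not from any shipped executor):

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="hybrid_executor_example"):
        # No executor set: grouped under, and adopted by, the default executor
        op1 = EmptyOperator(task_id="op1")
        # Pinned by alias: grouped under the matching configured executor
        op3 = EmptyOperator(task_id="op3", executor="secondary_exec")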
