This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ba8a36d7e7d feat: add callback support to aws batch executor (#62984)
ba8a36d7e7d is described below

commit ba8a36d7e7de4cd76513731be949a2ce06440de9
Author: Sebastian Daum <[email protected]>
AuthorDate: Wed May 13 19:07:15 2026 +0200

    feat: add callback support to aws batch executor (#62984)
---
 .../src/airflow/executors/base_executor.py         |   4 +-
 .../amazon/aws/executors/batch/batch_executor.py   | 128 +++++++++++++--------
 .../providers/amazon/aws/executors/batch/utils.py  |  33 ++++--
 .../aws/executors/batch/test_batch_executor.py     | 117 +++++++++++++++++--
 .../unit/amazon/aws/executors/batch/test_utils.py  |  20 ++--
 .../providers/celery/executors/celery_executor.py  |   4 +-
 6 files changed, 218 insertions(+), 88 deletions(-)

diff --git a/airflow-core/src/airflow/executors/base_executor.py 
b/airflow-core/src/airflow/executors/base_executor.py
index 9c9487c5377..2f1f5e8adb9 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -297,7 +297,7 @@ class BaseExecutor(LoggingMixin):
 
         return workloads_to_schedule
 
-    def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> 
None:
+    def _process_workloads(self, workload_items: Sequence[ExecutorWorkload]) 
-> None:
         """
         Process the given workloads.
 
@@ -305,7 +305,7 @@ class BaseExecutor(LoggingMixin):
         the execution of workloads (e.g., queuing them to workers, submitting 
to
         external systems, etc.).
 
-        :param workloads: List of workloads to process
+        :param workload_items: List of workloads to process
         """
         raise NotImplementedError(f"{type(self).__name__} must implement 
_process_workloads()")
 
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 7058cbf5e60..835dfe2e1c4 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch 
Job."""
+"""AWS Batch Executor. Each Airflow workload gets delegated out to an AWS 
Batch Job."""
 
 from __future__ import annotations
 
@@ -33,7 +33,7 @@ from 
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_3_PLUS
 from airflow.providers.common.compat.sdk import AirflowException, Stats, 
timezone
 from airflow.utils.helpers import merge_dicts
 
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
 
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+
 from airflow.providers.amazon.aws.executors.batch.boto_schema import (
     BatchDescribeJobsResponseSchema,
     BatchSubmitJobResponseSchema,
@@ -88,6 +90,8 @@ class AwsBatchExecutor(BaseExecutor):
     """
 
     supports_multi_team: bool = True
+    if AIRFLOW_V_3_3_PLUS:
+        supports_callbacks: bool = True
 
     # AWS only allows a maximum number of JOBs in the describe_jobs function
     DESCRIBE_JOBS_BATCH_SIZE = 99
@@ -127,26 +131,44 @@ class AwsBatchExecutor(BaseExecutor):
     def queue_workload(self, workload: workloads.All, session: Session | None) 
-> None:
         from airflow.executors import workloads
 
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
-        ti = workload.ti
-        self.queued_tasks[ti.key] = workload
+        if isinstance(workload, workloads.ExecuteTask):
+            self.queued_tasks[workload.ti.key] = workload
+            return
+        if AIRFLOW_V_3_3_PLUS and isinstance(workload, 
workloads.ExecuteCallback):
+            self.queued_callbacks[workload.callback.key] = workload
+            return
+        raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
 
-    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        from airflow.executors.workloads import ExecuteTask
+    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> 
None:
+        from airflow.executors import workloads
 
-        # Airflow V3 version
-        for w in workloads:
-            if not isinstance(w, ExecuteTask):
+        for w in workload_items:
+            if isinstance(w, workloads.ExecuteTask):
+                task_command = [w]
+                task_key = w.ti.key
+                queue = w.ti.queue
+                executor_config = w.ti.executor_config or {}
+
+                del self.queued_tasks[task_key]
+                self.execute_async(
+                    key=task_key,
+                    command=task_command,  # type: ignore[arg-type]
+                    queue=queue,
+                    executor_config=executor_config,
+                )
+                self.running.add(task_key)
+            elif AIRFLOW_V_3_3_PLUS and isinstance(w, 
workloads.ExecuteCallback):
+                callback_command = [w]
+                callback_key = w.callback.key
+                queue = None
+                if isinstance(w.callback.data, dict) and "queue" in 
w.callback.data:
+                    queue = w.callback.data["queue"]
+
+                del self.queued_callbacks[callback_key]
+                self.execute_async(key=callback_key, command=callback_command, 
queue=queue)  # type: ignore[arg-type]
+                self.running.add(callback_key)
+            else:
                 raise RuntimeError(f"{type(self)} cannot handle workloads of 
type {type(w)}")
-            command = [w]
-            key = w.ti.key
-            queue = w.ti.queue
-            executor_config = w.ti.executor_config or {}
-
-            del self.queued_tasks[key]
-            self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)  # type: ignore[arg-type]
-            self.running.add(key)
 
     def check_health(self):
         """Make a test API call to check the health of the Batch Executor."""
@@ -235,7 +257,7 @@ class AwsBatchExecutor(BaseExecutor):
     def sync_running_jobs(self):
         all_job_ids = self.active_workers.get_all_jobs()
         if not all_job_ids:
-            self.log.debug("No active Airflow tasks, skipping sync")
+            self.log.debug("No active Airflow workloads, skipping sync")
             return
         describe_job_response = self._describe_jobs(all_job_ids)
 
@@ -245,8 +267,8 @@ class AwsBatchExecutor(BaseExecutor):
             if job.get_job_state() == State.FAILED:
                 self._handle_failed_job(job)
             elif job.get_job_state() == State.SUCCESS:
-                task_key = self.active_workers.pop_by_id(job.job_id)
-                self.success(task_key)
+                workload_key = self.active_workers.pop_by_id(job.job_id)
+                self.success(workload_key)
 
     def _handle_failed_job(self, job):
         """
@@ -263,15 +285,15 @@ class AwsBatchExecutor(BaseExecutor):
         # responsibility for ensuring the process started. Failures in the Dag 
will be caught by
         # Airflow, which will be handled separately.
         job_info = self.active_workers.id_to_job_info[job.job_id]
-        task_key = self.active_workers.id_to_key[job.job_id]
-        task_cmd = job_info.cmd
+        workload_key = self.active_workers.id_to_key[job.job_id]
+        workload_cmd = job_info.cmd
         queue = job_info.queue
         exec_info = job_info.config
         failure_count = 
self.active_workers.failure_count_by_id(job_id=job.job_id)
         if int(failure_count) < int(self.max_submit_job_attempts):
             self.log.warning(
-                "Airflow task %s failed due to %s. Failure %s out of %s 
occurred on %s. Rescheduling.",
-                task_key,
+                "Airflow workload %s failed due to %s. Failure %s out of %s 
occurred on %s. Rescheduling.",
+                workload_key,
                 job.status_reason,
                 failure_count,
                 self.max_submit_job_attempts,
@@ -281,8 +303,8 @@ class AwsBatchExecutor(BaseExecutor):
             self.active_workers.pop_by_id(job.job_id)
             self.pending_jobs.append(
                 BatchQueuedJob(
-                    task_key,
-                    task_cmd,
+                    workload_key,
+                    workload_cmd,
                     queue,
                     exec_info,
                     failure_count + 1,
@@ -291,12 +313,12 @@ class AwsBatchExecutor(BaseExecutor):
             )
         else:
             self.log.error(
-                "Airflow task %s has failed a maximum of %s times. Marking as 
failed",
-                task_key,
+                "Airflow workload %s has failed a maximum of %s times. Marking 
as failed",
+                workload_key,
                 failure_count,
             )
             self.active_workers.pop_by_id(job.job_id)
-            self.fail(task_key)
+            self.fail(workload_key)
 
     def attempt_submit_jobs(self):
         """
@@ -309,8 +331,8 @@ class AwsBatchExecutor(BaseExecutor):
         """
         for _ in range(len(self.pending_jobs)):
             batch_job = self.pending_jobs.popleft()
-            key = batch_job.key
-            cmd = batch_job.command
+            workload_key = batch_job.key
+            workload_cmd = batch_job.command
             queue = batch_job.queue
             exec_config = batch_job.executor_config
             attempt_number = batch_job.attempt_number
@@ -319,7 +341,7 @@ class AwsBatchExecutor(BaseExecutor):
                 self.pending_jobs.append(batch_job)
                 continue
             try:
-                submit_job_response = self._submit_job(key, cmd, queue, 
exec_config or {})
+                submit_job_response = self._submit_job(workload_key, 
workload_cmd, queue, exec_config or {})
             except NoCredentialsError:
                 self.pending_jobs.append(batch_job)
                 raise
@@ -337,7 +359,7 @@ class AwsBatchExecutor(BaseExecutor):
                     self.log.error(
                         (
                             "This job has been unsuccessfully attempted too 
many times (%s). "
-                            "Dropping the task. Reason: %s"
+                            "Dropping the workload. Reason: %s"
                         ),
                         attempt_number,
                         failure_reason,
@@ -345,10 +367,10 @@ class AwsBatchExecutor(BaseExecutor):
                     self.log_task_event(
                         event="batch job submit failure",
                         extra=f"This job has been unsuccessfully attempted too 
many times ({attempt_number}). "
-                        f"Dropping the task. Reason: {failure_reason}",
-                        ti_key=key,
+                        f"Dropping the workload. Reason: {failure_reason}",
+                        ti_key=workload_key,
                     )
-                    self.fail(key=key)
+                    self.fail(key=workload_key)
                 else:
                     batch_job.next_attempt_time = timezone.utcnow() + 
calculate_next_attempt_delay(
                         attempt_number
@@ -360,13 +382,13 @@ class AwsBatchExecutor(BaseExecutor):
                 job_id = submit_job_response["job_id"]
                 self.active_workers.add_job(
                     job_id=job_id,
-                    airflow_task_key=key,
-                    airflow_cmd=cmd,
+                    airflow_workload_key=workload_key,
+                    airflow_cmd=workload_cmd,
                     queue=queue,
                     exec_config=exec_config,
                     attempt_number=attempt_number,
                 )
-                self.running_state(key, job_id)
+                self.running_state(workload_key, job_id)
 
     def _describe_jobs(self, job_ids) -> list[BatchJob]:
         all_jobs = []
@@ -374,21 +396,25 @@ class AwsBatchExecutor(BaseExecutor):
             batched_job_ids = job_ids[i : i + 
self.__class__.DESCRIBE_JOBS_BATCH_SIZE]
             if not batched_job_ids:
                 continue
-            boto_describe_tasks = 
self.batch.describe_jobs(jobs=batched_job_ids)
+            boto_describe_workloads = 
self.batch.describe_jobs(jobs=batched_job_ids)
 
-            describe_tasks_response = 
BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
-            all_jobs.extend(describe_tasks_response["jobs"])
+            describe_workloads_response = 
BatchDescribeJobsResponseSchema().load(boto_describe_workloads)
+            all_jobs.extend(describe_workloads_response["jobs"])
         return all_jobs
 
-    def execute_async(self, key: TaskInstanceKey, command: CommandType, 
queue=None, executor_config=None):
-        """Save the task to be executed in the next sync using Boto3's RunTask 
API."""
+    def execute_async(
+        self, key: TaskInstanceKey | str, command: CommandType, queue=None, 
executor_config=None
+    ):
+        """Save the workload to be executed in the next sync using Boto3's 
RunTask API."""
         if executor_config and "command" in executor_config:
             raise ValueError('Executor Config should never override "command"')
 
         if len(command) == 1:
-            from airflow.executors.workloads import ExecuteTask
+            from airflow.executors import workloads
 
-            if isinstance(command[0], ExecuteTask):
+            if isinstance(command[0], workloads.ExecuteTask) or (
+                AIRFLOW_V_3_3_PLUS and isinstance(command[0], 
workloads.ExecuteCallback)
+            ):
                 workload = command[0]
                 ser_input = workload.model_dump_json()
                 command = [
@@ -433,7 +459,7 @@ class AwsBatchExecutor(BaseExecutor):
         self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: 
ExecutorConfigType
     ) -> dict:
         """
-        Override the Airflow command to update the container overrides so 
kwargs are specific to this task.
+        Override the Airflow command to update the container overrides so 
kwargs are specific to this workload.
 
         One last chance to modify Boto3's "submit_job" kwarg params before it 
gets passed into the Boto3
         client. For the latest kwarg parameters:
@@ -450,7 +476,7 @@ class AwsBatchExecutor(BaseExecutor):
         return submit_job_api
 
     def end(self, heartbeat_interval=10):
-        """Wait for all currently running tasks to end and prevent any new 
jobs from running."""
+        """Wait for all currently running workloads to end and prevent any new 
jobs from running."""
         try:
             while True:
                 self.sync()
@@ -500,7 +526,7 @@ class AwsBatchExecutor(BaseExecutor):
                     ti = next(ti for ti in tis if ti.external_executor_id == 
batch_job.job_id)
                     self.active_workers.add_job(
                         job_id=batch_job.job_id,
-                        airflow_task_key=ti.key,
+                        airflow_workload_key=ti.key,
                         airflow_cmd=ti.command_as_list(),
                         queue=ti.queue,
                         exec_config=ti.executor_config,
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py 
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
index abcc16ec321..64685184902 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
@@ -19,13 +19,22 @@ from __future__ import annotations
 import datetime
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeAlias
 
 from airflow.providers.amazon.aws.executors.utils.base_config_keys import 
BaseConfigKeys
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS
+
+    if AIRFLOW_V_3_3_PLUS:
+        from airflow.executors.workloads.types import WorkloadKey
+
+        BatchJobWorkloadKey: TypeAlias = WorkloadKey
+    else:
+        BatchJobWorkloadKey: TypeAlias = TaskInstanceKey  # type: 
ignore[no-redef, misc]
+
 
 CommandType = list[str]
 ExecutorConfigType = dict[str, Any]
@@ -43,9 +52,9 @@ CONFIG_DEFAULTS = {
 class BatchQueuedJob:
     """Represents a Batch job that is queued. The job will be run in the next 
heartbeat."""
 
-    key: TaskInstanceKey
+    key: BatchJobWorkloadKey
     command: CommandType
-    queue: str
+    queue: str | None
     executor_config: ExecutorConfigType
     attempt_number: int
     next_attempt_time: datetime.datetime
@@ -91,33 +100,33 @@ class BatchJobCollection:
     """A collection to manage running Batch Jobs."""
 
     def __init__(self):
-        self.key_to_id: dict[TaskInstanceKey, str] = {}
-        self.id_to_key: dict[str, TaskInstanceKey] = {}
+        self.key_to_id: dict[BatchJobWorkloadKey, str] = {}
+        self.id_to_key: dict[str, BatchJobWorkloadKey] = {}
         self.id_to_failure_counts: dict[str, int] = defaultdict(int)
         self.id_to_job_info: dict[str, BatchJobInfo] = {}
 
     def add_job(
         self,
         job_id: str,
-        airflow_task_key: TaskInstanceKey,
+        airflow_workload_key: BatchJobWorkloadKey,
         airflow_cmd: CommandType,
         queue: str,
         exec_config: ExecutorConfigType,
         attempt_number: int,
     ):
         """Add a job to the collection."""
-        self.key_to_id[airflow_task_key] = job_id
-        self.id_to_key[job_id] = airflow_task_key
+        self.key_to_id[airflow_workload_key] = job_id
+        self.id_to_key[job_id] = airflow_workload_key
         self.id_to_failure_counts[job_id] = attempt_number
         self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, 
queue=queue, config=exec_config)
 
-    def pop_by_id(self, job_id: str) -> TaskInstanceKey:
+    def pop_by_id(self, job_id: str) -> BatchJobWorkloadKey:
         """Delete job from collection based off of Batch Job ID."""
-        task_key = self.id_to_key[job_id]
-        del self.key_to_id[task_key]
+        workload_key = self.id_to_key[job_id]
+        del self.key_to_id[workload_key]
         del self.id_to_key[job_id]
         del self.id_to_failure_counts[job_id]
-        return task_key
+        return workload_key
 
     def failure_count_by_id(self, job_id: str) -> int:
         """Get the number of times a job has failed given a Batch Job Id."""
diff --git 
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py 
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
index 0e723894597..00af3ca330d 100644
--- 
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
+++ 
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
@@ -49,7 +49,7 @@ from airflow.version import version as airflow_version_str
 
 from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
 from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS
 
 airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
 ARN1 = "arn1"
@@ -111,7 +111,7 @@ class TestBatchJobCollection:
         self.first_airflow_key = mock.Mock(spec=tuple)
         self.collection.add_job(
             job_id=self.first_job_id,
-            airflow_task_key=self.first_airflow_key,
+            airflow_workload_key=self.first_airflow_key,
             airflow_cmd="command1",
             queue="queue1",
             exec_config={},
@@ -122,7 +122,7 @@ class TestBatchJobCollection:
         self.second_airflow_key = mock.Mock(spec=tuple)
         self.collection.add_job(
             job_id=self.second_job_id,
-            airflow_task_key=self.second_airflow_key,
+            airflow_workload_key=self.second_airflow_key,
             airflow_cmd="command2",
             queue="queue2",
             exec_config={},
@@ -210,6 +210,7 @@ class TestAwsBatchExecutor:
         workload = mock.Mock(spec=ExecuteTask)
         workload.ti = mock.Mock(spec=TaskInstance)
         workload.ti.key = mock_airflow_key()
+        workload.ti.queue = "some-job-queue"
         tags_exec_config = [{"key": "FOO", "value": "BAR"}]
         workload.ti.executor_config = {"tags": tags_exec_config}
         ser_workload = json.dumps({"test_key": "test_value"})
@@ -269,6 +270,98 @@ class TestAwsBatchExecutor:
         assert job_id == ARN1
         running_state_mock.assert_called_once_with(workload.ti.key, ARN1)
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 
3.3+")
+    
@mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state")
+    def test_task_sdk_callback(self, running_state_mock, mock_airflow_key, 
mock_executor, mock_cmd):
+        """Test task sdk execution for callbacks from end-to-end."""
+        from airflow.executors.workloads import ExecuteCallback
+
+        workload = mock.Mock(spec=ExecuteCallback)
+        workload.callback = mock.Mock()
+        workload.callback.key = mock_airflow_key()
+        ser_workload = json.dumps({"test_key": "test_value"})
+        workload.model_dump_json.return_value = ser_workload
+
+        mock_executor.queue_workload(workload, mock.Mock())
+
+        mock_executor.batch.submit_job.return_value = {"jobId": ARN1, 
"jobName": "some-job-name"}
+
+        assert mock_executor.queued_callbacks[workload.callback.key] == 
workload
+        assert len(mock_executor.pending_jobs) == 0
+        assert len(mock_executor.running) == 0
+        mock_executor._process_workloads([workload])
+        assert len(mock_executor.queued_callbacks) == 0
+        assert len(mock_executor.running) == 1
+        assert workload.callback.key in mock_executor.running
+        assert len(mock_executor.pending_jobs) == 1
+        assert mock_executor.pending_jobs[0].command == [
+            "python",
+            "-m",
+            "airflow.sdk.execution_time.execute_workload",
+            "--json-string",
+            '{"test_key": "test_value"}',
+        ]
+
+        mock_executor.attempt_submit_jobs()
+        mock_executor.batch.submit_job.assert_called_once()
+        assert len(mock_executor.pending_jobs) == 0
+        mock_executor.batch.submit_job.assert_called_once_with(
+            jobDefinition="some-job-def",
+            jobName="some-job-name",
+            jobQueue="some-job-queue",
+            containerOverrides={
+                "command": [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_workload,
+                ],
+                "environment": [
+                    {
+                        "name": "AIRFLOW_IS_EXECUTOR_CONTAINER",
+                        "value": "true",
+                    },
+                ],
+            },
+        )
+
+        # Task is stored in active worker.
+        assert len(mock_executor.active_workers) == 1
+        # Get the job_id for this task key
+        job_id = next(
+            job_id
+            for job_id, key in mock_executor.active_workers.id_to_key.items()
+            if key == workload.callback.key
+        )
+        assert job_id == ARN1
+        running_state_mock.assert_called_once_with(workload.callback.key, ARN1)
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 
3.3+")
+    
@mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state")
+    def test_task_sdk_callback_with_queue(self, mock_airflow_key, 
mock_executor):
+        """Test task sdk execution for callbacks with queue from end-to-end."""
+        from airflow.executors.workloads import ExecuteCallback
+
+        workload = mock.Mock(spec=ExecuteCallback)
+        workload.callback = mock.Mock()
+        workload.callback.key = mock_airflow_key()
+        workload.callback.data = {"queue": "fast-queue"}
+
+        mock_executor.queue_workload(workload, mock.Mock())
+
+        mock_executor.batch.submit_job.return_value = {"jobId": ARN1, 
"jobName": "some-job-name"}
+
+        assert mock_executor.queued_callbacks[workload.callback.key] == 
workload
+        assert len(mock_executor.pending_jobs) == 0
+        assert len(mock_executor.running) == 0
+        mock_executor._process_workloads([workload])
+        assert len(mock_executor.queued_callbacks) == 0
+        assert len(mock_executor.running) == 1
+        assert workload.callback.key in mock_executor.running
+        assert len(mock_executor.pending_jobs) == 1
+        assert mock_executor.pending_jobs[0].queue == "fast-queue"
+
     @mock.patch.object(batch_executor, "calculate_next_attempt_delay", 
return_value=dt.timedelta(seconds=0))
     def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor):
         """
@@ -445,7 +538,7 @@ class TestAwsBatchExecutor:
         mock_executor.sync_running_jobs()
         for i in range(2):
             assert (
-                f"Airflow task {airflow_keys[i]} failed due to 
{jobs[i]['statusReason']}. Failure 1 out of 
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. 
Rescheduling."
+                f"Airflow workload {airflow_keys[i]} failed due to 
{jobs[i]['statusReason']}. Failure 1 out of 
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. 
Rescheduling."
                 in caplog.messages[i]
             )
 
@@ -454,7 +547,7 @@ class TestAwsBatchExecutor:
         mock_executor.sync_running_jobs()
         for i in range(2):
             assert (
-                f"Airflow task {airflow_keys[i]} failed due to 
{jobs[i]['statusReason']}. Failure 2 out of 
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. 
Rescheduling."
+                f"Airflow workload {airflow_keys[i]} failed due to 
{jobs[i]['statusReason']}. Failure 2 out of 
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. 
Rescheduling."
                 in caplog.messages[i]
             )
 
@@ -463,7 +556,7 @@ class TestAwsBatchExecutor:
         mock_executor.sync_running_jobs()
         for i in range(2):
             assert (
-                f"Airflow task {airflow_keys[i]} has failed a maximum of 
{mock_executor.max_submit_job_attempts} times. Marking as failed"
+                f"Airflow workload {airflow_keys[i]} has failed a maximum of 
{mock_executor.max_submit_job_attempts} times. Marking as failed"
                 in caplog.text
             )
 
@@ -479,7 +572,7 @@ class TestAwsBatchExecutor:
         caplog.set_level("DEBUG")
         assert len(mock_executor.active_workers.get_all_jobs()) == 0
         mock_executor.sync_running_jobs()
-        assert "No active Airflow tasks, skipping sync" in caplog.messages[0]
+        assert "No active Airflow workloads, skipping sync" in 
caplog.messages[0]
 
     def test_sync_client_error(self, mock_executor, caplog):
         mock_executor.execute_async("airflow_key", "airflow_cmd")
@@ -498,7 +591,7 @@ class TestAwsBatchExecutor:
     def test_sync_exception(self, mock_executor, caplog):
         mock_executor.active_workers.add_job(
             job_id="job_id",
-            airflow_task_key="airflow_key",
+            airflow_workload_key="airflow_key",
             airflow_cmd="command",
             queue="queue",
             exec_config={},
@@ -616,7 +709,7 @@ class TestAwsBatchExecutor:
     def test_terminate_failure(self, mock_executor, caplog):
         mock_executor.active_workers.add_job(
             job_id="job_id",
-            airflow_task_key="airflow_key",
+            airflow_workload_key="airflow_key",
             airflow_cmd="command",
             queue="queue",
             exec_config={},
@@ -662,7 +755,7 @@ class TestAwsBatchExecutor:
         """
         executor.active_workers.add_job(
             job_id=job_id,
-            airflow_task_key=airflow_key,
+            airflow_workload_key=airflow_key,
             airflow_cmd="airflow_cmd",
             queue="queue",
             exec_config={},
@@ -965,7 +1058,9 @@ class TestBatchExecutorConfig:
 
         executor = AwsBatchExecutor()
 
-        final_run_task_kwargs = executor._submit_job_kwargs(mock_ti_key, 
command, "queue", exec_config)
+        final_run_task_kwargs = executor._submit_job_kwargs(
+            mock_ti_key, command, expected_result["jobQueue"], exec_config
+        )
 
         assert final_run_task_kwargs == expected_result
 
diff --git 
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py 
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
index c0b914bdb3b..3ee8dca2628 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
@@ -152,7 +152,7 @@ class TestBatchJobCollection:
         """Test adding a job to the collection."""
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -170,7 +170,7 @@ class TestBatchJobCollection:
         """Test adding multiple jobs to the collection."""
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -178,7 +178,7 @@ class TestBatchJobCollection:
         )
         self.collection.add_job(
             job_id=self.job_id2,
-            airflow_task_key=self.key2,
+            airflow_workload_key=self.key2,
             airflow_cmd=self.cmd2,
             queue=self.queue2,
             exec_config=self.config2,
@@ -194,7 +194,7 @@ class TestBatchJobCollection:
         """Test removing a job from the collection by its ID."""
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -222,7 +222,7 @@ class TestBatchJobCollection:
         attempt_number = 5
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -240,7 +240,7 @@ class TestBatchJobCollection:
         initial_attempt = 1
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -264,7 +264,7 @@ class TestBatchJobCollection:
         """Test getting all job IDs from a collection with jobs."""
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -272,7 +272,7 @@ class TestBatchJobCollection:
         )
         self.collection.add_job(
             job_id=self.job_id2,
-            airflow_task_key=self.key2,
+            airflow_workload_key=self.key2,
             airflow_cmd=self.cmd2,
             queue=self.queue2,
             exec_config=self.config2,
@@ -288,7 +288,7 @@ class TestBatchJobCollection:
         assert len(self.collection) == 0
         self.collection.add_job(
             job_id=self.job_id1,
-            airflow_task_key=self.key1,
+            airflow_workload_key=self.key1,
             airflow_cmd=self.cmd1,
             queue=self.queue1,
             exec_config=self.config1,
@@ -297,7 +297,7 @@ class TestBatchJobCollection:
         assert len(self.collection) == 1
         self.collection.add_job(
             job_id=self.job_id2,
-            airflow_task_key=self.key2,
+            airflow_workload_key=self.key2,
             airflow_cmd=self.cmd2,
             queue=self.queue2,
             exec_config=self.config2,
diff --git 
a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py 
b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index 16d634fbdef..7fc388c9608 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -173,7 +173,7 @@ class CeleryExecutor(BaseExecutor):
 
         self._send_workloads(task_tuples_to_send)
 
-    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> 
None:
         # Airflow V3 version -- have to delay imports until we know we are on 
v3.
         from airflow.executors.workloads import ExecuteTask
 
@@ -181,7 +181,7 @@ class CeleryExecutor(BaseExecutor):
             from airflow.executors.workloads import ExecuteCallback
 
         workloads_to_be_sent: list[WorkloadInCelery] = []
-        for workload in workloads:
+        for workload in workload_items:
             if isinstance(workload, ExecuteTask):
                 workloads_to_be_sent.append((workload.ti.key, workload, 
workload.ti.queue, self.team_name))
             elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):


Reply via email to