This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ba8a36d7e7d feat: add callback support to aws batch executor (#62984)
ba8a36d7e7d is described below
commit ba8a36d7e7de4cd76513731be949a2ce06440de9
Author: Sebastian Daum <[email protected]>
AuthorDate: Wed May 13 19:07:15 2026 +0200
feat: add callback support to aws batch executor (#62984)
---
.../src/airflow/executors/base_executor.py | 4 +-
.../amazon/aws/executors/batch/batch_executor.py | 128 +++++++++++++--------
.../providers/amazon/aws/executors/batch/utils.py | 33 ++++--
.../aws/executors/batch/test_batch_executor.py | 117 +++++++++++++++++--
.../unit/amazon/aws/executors/batch/test_utils.py | 20 ++--
.../providers/celery/executors/celery_executor.py | 4 +-
6 files changed, 218 insertions(+), 88 deletions(-)
diff --git a/airflow-core/src/airflow/executors/base_executor.py
b/airflow-core/src/airflow/executors/base_executor.py
index 9c9487c5377..2f1f5e8adb9 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -297,7 +297,7 @@ class BaseExecutor(LoggingMixin):
return workloads_to_schedule
- def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) ->
None:
+ def _process_workloads(self, workload_items: Sequence[ExecutorWorkload])
-> None:
"""
Process the given workloads.
@@ -305,7 +305,7 @@ class BaseExecutor(LoggingMixin):
the execution of workloads (e.g., queuing them to workers, submitting
to
external systems, etc.).
- :param workloads: List of workloads to process
+ :param workload_items: List of workloads to process
"""
raise NotImplementedError(f"{type(self).__name__} must implement
_process_workloads()")
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 7058cbf5e60..835dfe2e1c4 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch
Job."""
+"""AWS Batch Executor. Each Airflow workload gets delegated out to an AWS
Batch Job."""
from __future__ import annotations
@@ -33,7 +33,7 @@ from
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats,
timezone
from airflow.utils.helpers import merge_dicts
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+
from airflow.providers.amazon.aws.executors.batch.boto_schema import (
BatchDescribeJobsResponseSchema,
BatchSubmitJobResponseSchema,
@@ -88,6 +90,8 @@ class AwsBatchExecutor(BaseExecutor):
"""
supports_multi_team: bool = True
+ if AIRFLOW_V_3_3_PLUS:
+ supports_callbacks: bool = True
# AWS only allows a maximum number of JOBs in the describe_jobs function
DESCRIBE_JOBS_BATCH_SIZE = 99
@@ -127,26 +131,44 @@ class AwsBatchExecutor(BaseExecutor):
def queue_workload(self, workload: workloads.All, session: Session | None)
-> None:
from airflow.executors import workloads
- if not isinstance(workload, workloads.ExecuteTask):
- raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
- ti = workload.ti
- self.queued_tasks[ti.key] = workload
+ if isinstance(workload, workloads.ExecuteTask):
+ self.queued_tasks[workload.ti.key] = workload
+ return
+ if AIRFLOW_V_3_3_PLUS and isinstance(workload,
workloads.ExecuteCallback):
+ self.queued_callbacks[workload.callback.key] = workload
+ return
+ raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
- def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
- from airflow.executors.workloads import ExecuteTask
+ def _process_workloads(self, workload_items: Sequence[workloads.All]) ->
None:
+ from airflow.executors import workloads
- # Airflow V3 version
- for w in workloads:
- if not isinstance(w, ExecuteTask):
+ for w in workload_items:
+ if isinstance(w, workloads.ExecuteTask):
+ task_command = [w]
+ task_key = w.ti.key
+ queue = w.ti.queue
+ executor_config = w.ti.executor_config or {}
+
+ del self.queued_tasks[task_key]
+ self.execute_async(
+ key=task_key,
+ command=task_command, # type: ignore[arg-type]
+ queue=queue,
+ executor_config=executor_config,
+ )
+ self.running.add(task_key)
+ elif AIRFLOW_V_3_3_PLUS and isinstance(w,
workloads.ExecuteCallback):
+ callback_command = [w]
+ callback_key = w.callback.key
+ queue = None
+ if isinstance(w.callback.data, dict) and "queue" in
w.callback.data:
+ queue = w.callback.data["queue"]
+
+ del self.queued_callbacks[callback_key]
+ self.execute_async(key=callback_key, command=callback_command,
queue=queue) # type: ignore[arg-type]
+ self.running.add(callback_key)
+ else:
raise RuntimeError(f"{type(self)} cannot handle workloads of
type {type(w)}")
- command = [w]
- key = w.ti.key
- queue = w.ti.queue
- executor_config = w.ti.executor_config or {}
-
- del self.queued_tasks[key]
- self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config) # type: ignore[arg-type]
- self.running.add(key)
def check_health(self):
"""Make a test API call to check the health of the Batch Executor."""
@@ -235,7 +257,7 @@ class AwsBatchExecutor(BaseExecutor):
def sync_running_jobs(self):
all_job_ids = self.active_workers.get_all_jobs()
if not all_job_ids:
- self.log.debug("No active Airflow tasks, skipping sync")
+ self.log.debug("No active Airflow workloads, skipping sync")
return
describe_job_response = self._describe_jobs(all_job_ids)
@@ -245,8 +267,8 @@ class AwsBatchExecutor(BaseExecutor):
if job.get_job_state() == State.FAILED:
self._handle_failed_job(job)
elif job.get_job_state() == State.SUCCESS:
- task_key = self.active_workers.pop_by_id(job.job_id)
- self.success(task_key)
+ workload_key = self.active_workers.pop_by_id(job.job_id)
+ self.success(workload_key)
def _handle_failed_job(self, job):
"""
@@ -263,15 +285,15 @@ class AwsBatchExecutor(BaseExecutor):
# responsibility for ensuring the process started. Failures in the Dag
will be caught by
# Airflow, which will be handled separately.
job_info = self.active_workers.id_to_job_info[job.job_id]
- task_key = self.active_workers.id_to_key[job.job_id]
- task_cmd = job_info.cmd
+ workload_key = self.active_workers.id_to_key[job.job_id]
+ workload_cmd = job_info.cmd
queue = job_info.queue
exec_info = job_info.config
failure_count =
self.active_workers.failure_count_by_id(job_id=job.job_id)
if int(failure_count) < int(self.max_submit_job_attempts):
self.log.warning(
- "Airflow task %s failed due to %s. Failure %s out of %s
occurred on %s. Rescheduling.",
- task_key,
+ "Airflow workload %s failed due to %s. Failure %s out of %s
occurred on %s. Rescheduling.",
+ workload_key,
job.status_reason,
failure_count,
self.max_submit_job_attempts,
@@ -281,8 +303,8 @@ class AwsBatchExecutor(BaseExecutor):
self.active_workers.pop_by_id(job.job_id)
self.pending_jobs.append(
BatchQueuedJob(
- task_key,
- task_cmd,
+ workload_key,
+ workload_cmd,
queue,
exec_info,
failure_count + 1,
@@ -291,12 +313,12 @@ class AwsBatchExecutor(BaseExecutor):
)
else:
self.log.error(
- "Airflow task %s has failed a maximum of %s times. Marking as
failed",
- task_key,
+ "Airflow workload %s has failed a maximum of %s times. Marking
as failed",
+ workload_key,
failure_count,
)
self.active_workers.pop_by_id(job.job_id)
- self.fail(task_key)
+ self.fail(workload_key)
def attempt_submit_jobs(self):
"""
@@ -309,8 +331,8 @@ class AwsBatchExecutor(BaseExecutor):
"""
for _ in range(len(self.pending_jobs)):
batch_job = self.pending_jobs.popleft()
- key = batch_job.key
- cmd = batch_job.command
+ workload_key = batch_job.key
+ workload_cmd = batch_job.command
queue = batch_job.queue
exec_config = batch_job.executor_config
attempt_number = batch_job.attempt_number
@@ -319,7 +341,7 @@ class AwsBatchExecutor(BaseExecutor):
self.pending_jobs.append(batch_job)
continue
try:
- submit_job_response = self._submit_job(key, cmd, queue,
exec_config or {})
+ submit_job_response = self._submit_job(workload_key,
workload_cmd, queue, exec_config or {})
except NoCredentialsError:
self.pending_jobs.append(batch_job)
raise
@@ -337,7 +359,7 @@ class AwsBatchExecutor(BaseExecutor):
self.log.error(
(
"This job has been unsuccessfully attempted too
many times (%s). "
- "Dropping the task. Reason: %s"
+ "Dropping the workload. Reason: %s"
),
attempt_number,
failure_reason,
@@ -345,10 +367,10 @@ class AwsBatchExecutor(BaseExecutor):
self.log_task_event(
event="batch job submit failure",
extra=f"This job has been unsuccessfully attempted too
many times ({attempt_number}). "
- f"Dropping the task. Reason: {failure_reason}",
- ti_key=key,
+ f"Dropping the workload. Reason: {failure_reason}",
+ ti_key=workload_key,
)
- self.fail(key=key)
+ self.fail(key=workload_key)
else:
batch_job.next_attempt_time = timezone.utcnow() +
calculate_next_attempt_delay(
attempt_number
@@ -360,13 +382,13 @@ class AwsBatchExecutor(BaseExecutor):
job_id = submit_job_response["job_id"]
self.active_workers.add_job(
job_id=job_id,
- airflow_task_key=key,
- airflow_cmd=cmd,
+ airflow_workload_key=workload_key,
+ airflow_cmd=workload_cmd,
queue=queue,
exec_config=exec_config,
attempt_number=attempt_number,
)
- self.running_state(key, job_id)
+ self.running_state(workload_key, job_id)
def _describe_jobs(self, job_ids) -> list[BatchJob]:
all_jobs = []
@@ -374,21 +396,25 @@ class AwsBatchExecutor(BaseExecutor):
batched_job_ids = job_ids[i : i +
self.__class__.DESCRIBE_JOBS_BATCH_SIZE]
if not batched_job_ids:
continue
- boto_describe_tasks =
self.batch.describe_jobs(jobs=batched_job_ids)
+ boto_describe_workloads =
self.batch.describe_jobs(jobs=batched_job_ids)
- describe_tasks_response =
BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
- all_jobs.extend(describe_tasks_response["jobs"])
+ describe_workloads_response =
BatchDescribeJobsResponseSchema().load(boto_describe_workloads)
+ all_jobs.extend(describe_workloads_response["jobs"])
return all_jobs
- def execute_async(self, key: TaskInstanceKey, command: CommandType,
queue=None, executor_config=None):
- """Save the task to be executed in the next sync using Boto3's RunTask
API."""
+ def execute_async(
+ self, key: TaskInstanceKey | str, command: CommandType, queue=None,
executor_config=None
+ ):
+ """Save the workload to be executed in the next sync using Boto3's
RunTask API."""
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')
if len(command) == 1:
- from airflow.executors.workloads import ExecuteTask
+ from airflow.executors import workloads
- if isinstance(command[0], ExecuteTask):
+ if isinstance(command[0], workloads.ExecuteTask) or (
+ AIRFLOW_V_3_3_PLUS and isinstance(command[0],
workloads.ExecuteCallback)
+ ):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
@@ -433,7 +459,7 @@ class AwsBatchExecutor(BaseExecutor):
self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config:
ExecutorConfigType
) -> dict:
"""
- Override the Airflow command to update the container overrides so
kwargs are specific to this task.
+ Override the Airflow command to update the container overrides so
kwargs are specific to this workload.
One last chance to modify Boto3's "submit_job" kwarg params before it
gets passed into the Boto3
client. For the latest kwarg parameters:
@@ -450,7 +476,7 @@ class AwsBatchExecutor(BaseExecutor):
return submit_job_api
def end(self, heartbeat_interval=10):
- """Wait for all currently running tasks to end and prevent any new
jobs from running."""
+ """Wait for all currently running workloads to end and prevent any new
jobs from running."""
try:
while True:
self.sync()
@@ -500,7 +526,7 @@ class AwsBatchExecutor(BaseExecutor):
ti = next(ti for ti in tis if ti.external_executor_id ==
batch_job.job_id)
self.active_workers.add_job(
job_id=batch_job.job_id,
- airflow_task_key=ti.key,
+ airflow_workload_key=ti.key,
airflow_cmd=ti.command_as_list(),
queue=ti.queue,
exec_config=ti.executor_config,
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
index abcc16ec321..64685184902 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py
@@ -19,13 +19,22 @@ from __future__ import annotations
import datetime
from collections import defaultdict
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeAlias
from airflow.providers.amazon.aws.executors.utils.base_config_keys import
BaseConfigKeys
from airflow.utils.state import State
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS
+
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads.types import WorkloadKey
+
+ BatchJobWorkloadKey: TypeAlias = WorkloadKey
+ else:
+ BatchJobWorkloadKey: TypeAlias = TaskInstanceKey # type:
ignore[no-redef, misc]
+
CommandType = list[str]
ExecutorConfigType = dict[str, Any]
@@ -43,9 +52,9 @@ CONFIG_DEFAULTS = {
class BatchQueuedJob:
"""Represents a Batch job that is queued. The job will be run in the next
heartbeat."""
- key: TaskInstanceKey
+ key: BatchJobWorkloadKey
command: CommandType
- queue: str
+ queue: str | None
executor_config: ExecutorConfigType
attempt_number: int
next_attempt_time: datetime.datetime
@@ -91,33 +100,33 @@ class BatchJobCollection:
"""A collection to manage running Batch Jobs."""
def __init__(self):
- self.key_to_id: dict[TaskInstanceKey, str] = {}
- self.id_to_key: dict[str, TaskInstanceKey] = {}
+ self.key_to_id: dict[BatchJobWorkloadKey, str] = {}
+ self.id_to_key: dict[str, BatchJobWorkloadKey] = {}
self.id_to_failure_counts: dict[str, int] = defaultdict(int)
self.id_to_job_info: dict[str, BatchJobInfo] = {}
def add_job(
self,
job_id: str,
- airflow_task_key: TaskInstanceKey,
+ airflow_workload_key: BatchJobWorkloadKey,
airflow_cmd: CommandType,
queue: str,
exec_config: ExecutorConfigType,
attempt_number: int,
):
"""Add a job to the collection."""
- self.key_to_id[airflow_task_key] = job_id
- self.id_to_key[job_id] = airflow_task_key
+ self.key_to_id[airflow_workload_key] = job_id
+ self.id_to_key[job_id] = airflow_workload_key
self.id_to_failure_counts[job_id] = attempt_number
self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd,
queue=queue, config=exec_config)
- def pop_by_id(self, job_id: str) -> TaskInstanceKey:
+ def pop_by_id(self, job_id: str) -> BatchJobWorkloadKey:
"""Delete job from collection based off of Batch Job ID."""
- task_key = self.id_to_key[job_id]
- del self.key_to_id[task_key]
+ workload_key = self.id_to_key[job_id]
+ del self.key_to_id[workload_key]
del self.id_to_key[job_id]
del self.id_to_failure_counts[job_id]
- return task_key
+ return workload_key
def failure_count_by_id(self, job_id: str) -> int:
"""Get the number of times a job has failed given a Batch Job Id."""
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
index 0e723894597..00af3ca330d 100644
---
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
+++
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
@@ -49,7 +49,7 @@ from airflow.version import version as airflow_version_str
from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS
airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
ARN1 = "arn1"
@@ -111,7 +111,7 @@ class TestBatchJobCollection:
self.first_airflow_key = mock.Mock(spec=tuple)
self.collection.add_job(
job_id=self.first_job_id,
- airflow_task_key=self.first_airflow_key,
+ airflow_workload_key=self.first_airflow_key,
airflow_cmd="command1",
queue="queue1",
exec_config={},
@@ -122,7 +122,7 @@ class TestBatchJobCollection:
self.second_airflow_key = mock.Mock(spec=tuple)
self.collection.add_job(
job_id=self.second_job_id,
- airflow_task_key=self.second_airflow_key,
+ airflow_workload_key=self.second_airflow_key,
airflow_cmd="command2",
queue="queue2",
exec_config={},
@@ -210,6 +210,7 @@ class TestAwsBatchExecutor:
workload = mock.Mock(spec=ExecuteTask)
workload.ti = mock.Mock(spec=TaskInstance)
workload.ti.key = mock_airflow_key()
+ workload.ti.queue = "some-job-queue"
tags_exec_config = [{"key": "FOO", "value": "BAR"}]
workload.ti.executor_config = {"tags": tags_exec_config}
ser_workload = json.dumps({"test_key": "test_value"})
@@ -269,6 +270,98 @@ class TestAwsBatchExecutor:
assert job_id == ARN1
running_state_mock.assert_called_once_with(workload.ti.key, ARN1)
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+
@mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state")
+ def test_task_sdk_callback(self, running_state_mock, mock_airflow_key,
mock_executor, mock_cmd):
+ """Test task sdk execution for callbacks from end-to-end."""
+ from airflow.executors.workloads import ExecuteCallback
+
+ workload = mock.Mock(spec=ExecuteCallback)
+ workload.callback = mock.Mock()
+ workload.callback.key = mock_airflow_key()
+ ser_workload = json.dumps({"test_key": "test_value"})
+ workload.model_dump_json.return_value = ser_workload
+
+ mock_executor.queue_workload(workload, mock.Mock())
+
+ mock_executor.batch.submit_job.return_value = {"jobId": ARN1,
"jobName": "some-job-name"}
+
+ assert mock_executor.queued_callbacks[workload.callback.key] ==
workload
+ assert len(mock_executor.pending_jobs) == 0
+ assert len(mock_executor.running) == 0
+ mock_executor._process_workloads([workload])
+ assert len(mock_executor.queued_callbacks) == 0
+ assert len(mock_executor.running) == 1
+ assert workload.callback.key in mock_executor.running
+ assert len(mock_executor.pending_jobs) == 1
+ assert mock_executor.pending_jobs[0].command == [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ '{"test_key": "test_value"}',
+ ]
+
+ mock_executor.attempt_submit_jobs()
+ mock_executor.batch.submit_job.assert_called_once()
+ assert len(mock_executor.pending_jobs) == 0
+ mock_executor.batch.submit_job.assert_called_once_with(
+ jobDefinition="some-job-def",
+ jobName="some-job-name",
+ jobQueue="some-job-queue",
+ containerOverrides={
+ "command": [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ ser_workload,
+ ],
+ "environment": [
+ {
+ "name": "AIRFLOW_IS_EXECUTOR_CONTAINER",
+ "value": "true",
+ },
+ ],
+ },
+ )
+
+ # Task is stored in active worker.
+ assert len(mock_executor.active_workers) == 1
+ # Get the job_id for this task key
+ job_id = next(
+ job_id
+ for job_id, key in mock_executor.active_workers.id_to_key.items()
+ if key == workload.callback.key
+ )
+ assert job_id == ARN1
+ running_state_mock.assert_called_once_with(workload.callback.key, ARN1)
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+
@mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state")
+ def test_task_sdk_callback_with_queue(self, mock_airflow_key,
mock_executor):
+ """Test task sdk execution for callbacks with queue from end-to-end."""
+ from airflow.executors.workloads import ExecuteCallback
+
+ workload = mock.Mock(spec=ExecuteCallback)
+ workload.callback = mock.Mock()
+ workload.callback.key = mock_airflow_key()
+ workload.callback.data = {"queue": "fast-queue"}
+
+ mock_executor.queue_workload(workload, mock.Mock())
+
+ mock_executor.batch.submit_job.return_value = {"jobId": ARN1,
"jobName": "some-job-name"}
+
+ assert mock_executor.queued_callbacks[workload.callback.key] ==
workload
+ assert len(mock_executor.pending_jobs) == 0
+ assert len(mock_executor.running) == 0
+ mock_executor._process_workloads([workload])
+ assert len(mock_executor.queued_callbacks) == 0
+ assert len(mock_executor.running) == 1
+ assert workload.callback.key in mock_executor.running
+ assert len(mock_executor.pending_jobs) == 1
+ assert mock_executor.pending_jobs[0].queue == "fast-queue"
+
@mock.patch.object(batch_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor):
"""
@@ -445,7 +538,7 @@ class TestAwsBatchExecutor:
mock_executor.sync_running_jobs()
for i in range(2):
assert (
- f"Airflow task {airflow_keys[i]} failed due to
{jobs[i]['statusReason']}. Failure 1 out of
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}.
Rescheduling."
+ f"Airflow workload {airflow_keys[i]} failed due to
{jobs[i]['statusReason']}. Failure 1 out of
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}.
Rescheduling."
in caplog.messages[i]
)
@@ -454,7 +547,7 @@ class TestAwsBatchExecutor:
mock_executor.sync_running_jobs()
for i in range(2):
assert (
- f"Airflow task {airflow_keys[i]} failed due to
{jobs[i]['statusReason']}. Failure 2 out of
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}.
Rescheduling."
+ f"Airflow workload {airflow_keys[i]} failed due to
{jobs[i]['statusReason']}. Failure 2 out of
{mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}.
Rescheduling."
in caplog.messages[i]
)
@@ -463,7 +556,7 @@ class TestAwsBatchExecutor:
mock_executor.sync_running_jobs()
for i in range(2):
assert (
- f"Airflow task {airflow_keys[i]} has failed a maximum of
{mock_executor.max_submit_job_attempts} times. Marking as failed"
+ f"Airflow workload {airflow_keys[i]} has failed a maximum of
{mock_executor.max_submit_job_attempts} times. Marking as failed"
in caplog.text
)
@@ -479,7 +572,7 @@ class TestAwsBatchExecutor:
caplog.set_level("DEBUG")
assert len(mock_executor.active_workers.get_all_jobs()) == 0
mock_executor.sync_running_jobs()
- assert "No active Airflow tasks, skipping sync" in caplog.messages[0]
+ assert "No active Airflow workloads, skipping sync" in
caplog.messages[0]
def test_sync_client_error(self, mock_executor, caplog):
mock_executor.execute_async("airflow_key", "airflow_cmd")
@@ -498,7 +591,7 @@ class TestAwsBatchExecutor:
def test_sync_exception(self, mock_executor, caplog):
mock_executor.active_workers.add_job(
job_id="job_id",
- airflow_task_key="airflow_key",
+ airflow_workload_key="airflow_key",
airflow_cmd="command",
queue="queue",
exec_config={},
@@ -616,7 +709,7 @@ class TestAwsBatchExecutor:
def test_terminate_failure(self, mock_executor, caplog):
mock_executor.active_workers.add_job(
job_id="job_id",
- airflow_task_key="airflow_key",
+ airflow_workload_key="airflow_key",
airflow_cmd="command",
queue="queue",
exec_config={},
@@ -662,7 +755,7 @@ class TestAwsBatchExecutor:
"""
executor.active_workers.add_job(
job_id=job_id,
- airflow_task_key=airflow_key,
+ airflow_workload_key=airflow_key,
airflow_cmd="airflow_cmd",
queue="queue",
exec_config={},
@@ -965,7 +1058,9 @@ class TestBatchExecutorConfig:
executor = AwsBatchExecutor()
- final_run_task_kwargs = executor._submit_job_kwargs(mock_ti_key,
command, "queue", exec_config)
+ final_run_task_kwargs = executor._submit_job_kwargs(
+ mock_ti_key, command, expected_result["jobQueue"], exec_config
+ )
assert final_run_task_kwargs == expected_result
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
index c0b914bdb3b..3ee8dca2628 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py
@@ -152,7 +152,7 @@ class TestBatchJobCollection:
"""Test adding a job to the collection."""
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -170,7 +170,7 @@ class TestBatchJobCollection:
"""Test adding multiple jobs to the collection."""
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -178,7 +178,7 @@ class TestBatchJobCollection:
)
self.collection.add_job(
job_id=self.job_id2,
- airflow_task_key=self.key2,
+ airflow_workload_key=self.key2,
airflow_cmd=self.cmd2,
queue=self.queue2,
exec_config=self.config2,
@@ -194,7 +194,7 @@ class TestBatchJobCollection:
"""Test removing a job from the collection by its ID."""
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -222,7 +222,7 @@ class TestBatchJobCollection:
attempt_number = 5
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -240,7 +240,7 @@ class TestBatchJobCollection:
initial_attempt = 1
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -264,7 +264,7 @@ class TestBatchJobCollection:
"""Test getting all job IDs from a collection with jobs."""
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -272,7 +272,7 @@ class TestBatchJobCollection:
)
self.collection.add_job(
job_id=self.job_id2,
- airflow_task_key=self.key2,
+ airflow_workload_key=self.key2,
airflow_cmd=self.cmd2,
queue=self.queue2,
exec_config=self.config2,
@@ -288,7 +288,7 @@ class TestBatchJobCollection:
assert len(self.collection) == 0
self.collection.add_job(
job_id=self.job_id1,
- airflow_task_key=self.key1,
+ airflow_workload_key=self.key1,
airflow_cmd=self.cmd1,
queue=self.queue1,
exec_config=self.config1,
@@ -297,7 +297,7 @@ class TestBatchJobCollection:
assert len(self.collection) == 1
self.collection.add_job(
job_id=self.job_id2,
- airflow_task_key=self.key2,
+ airflow_workload_key=self.key2,
airflow_cmd=self.cmd2,
queue=self.queue2,
exec_config=self.config2,
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index 16d634fbdef..7fc388c9608 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -173,7 +173,7 @@ class CeleryExecutor(BaseExecutor):
self._send_workloads(task_tuples_to_send)
- def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+ def _process_workloads(self, workload_items: Sequence[workloads.All]) ->
None:
# Airflow V3 version -- have to delay imports until we know we are on
v3.
from airflow.executors.workloads import ExecuteTask
@@ -181,7 +181,7 @@ class CeleryExecutor(BaseExecutor):
from airflow.executors.workloads import ExecuteCallback
workloads_to_be_sent: list[WorkloadInCelery] = []
- for workload in workloads:
+ for workload in workload_items:
if isinstance(workload, ExecuteTask):
workloads_to_be_sent.append((workload.ti.key, workload,
workload.ti.queue, self.team_name))
elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):