This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 824722219b4 Fix provider executor tests broken in main (#67268)
824722219b4 is described below
commit 824722219b471020dbc898da37f6e4d51bdad9b1
Author: Anish Giri <[email protected]>
AuthorDate: Thu May 21 11:31:39 2026 -0500
Fix provider executor tests broken in main (#67268)
---
.../aws/executors/aws_lambda/lambda_executor.py | 5 +++--
.../amazon/aws/executors/batch/batch_executor.py | 5 ++---
.../executors/aws_lambda/test_lambda_executor.py | 10 ++++++----
.../aws/executors/batch/test_batch_executor.py | 19 +++++++++++--------
.../amazon/aws/executors/ecs/test_ecs_executor.py | 22 ++++++++++++++++------
.../unit/celery/executors/test_celery_executor.py | 2 +-
.../executors/test_kubernetes_executor.py | 2 +-
7 files changed, 40 insertions(+), 25 deletions(-)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
index d4b3bc30a5f..72455165e83 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -345,8 +345,9 @@ class AwsLambdaExecutor(BaseExecutor):
try:
ser_workload_key = json.dumps(workload_key._asdict())
except AttributeError:
- # Callback workloads use string id.
- ser_workload_key = workload_key
+ # Callback workloads use CallbackKey (or legacy string id);
both have a
+ # str() representation that round-trips through JSON.
+ ser_workload_key = str(workload_key)
payload = {
"task_key": ser_workload_key,
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 835dfe2e1c4..e04464883c0 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ from airflow.providers.amazon.aws.executors.batch.utils import
BatchJobWorkloadKey
from airflow.providers.amazon.aws.executors.batch.boto_schema import (
@@ -402,9 +403,7 @@ class AwsBatchExecutor(BaseExecutor):
all_jobs.extend(describe_workloads_response["jobs"])
return all_jobs
- def execute_async(
- self, key: TaskInstanceKey | str, command: CommandType, queue=None,
executor_config=None
- ):
+ def execute_async(self, key: BatchJobWorkloadKey, command: CommandType,
queue=None, executor_config=None):
"""Save the workload to be executed in the next sync using Boto3's
RunTask API."""
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
index d98ceaa5a89..284865285f1 100644
---
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
+++
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -64,7 +64,7 @@ def set_env_vars():
@pytest.fixture
def mock_airflow_key():
def _key():
- key_mock = mock.Mock()
+ key_mock = mock.Mock(spec=TaskInstanceKey)
# Use a "random" value (memory id of the mock obj) so each key
serializes uniquely
key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)})
return key_mock
@@ -180,10 +180,12 @@ class TestAwsLambdaExecutor:
def test_task_sdk_callback(self, mock_executor):
"""Test task sdk callback execution end-to-end."""
from airflow.executors.workloads import ExecuteCallback
+ from airflow.models.callback import CallbackKey
- callback_id = "callback_123"
+ callback_id = CallbackKey("callback_123")
workload = mock.Mock(spec=ExecuteCallback)
+ workload.key = callback_id
workload.callback = mock.Mock()
workload.callback.key = callback_id
workload.callback.data = {}
@@ -212,7 +214,7 @@ class TestAwsLambdaExecutor:
mock_executor.attempt_workload_runs()
mock_executor.lambda_client.invoke.assert_called_once()
payload =
json.loads(mock_executor.lambda_client.invoke.call_args.kwargs["Payload"])
- assert payload["task_key"] == callback_id
+ assert payload["task_key"] == str(callback_id)
assert payload["command"] == [
"python",
"-m",
@@ -223,7 +225,7 @@ class TestAwsLambdaExecutor:
# Callback is stored in running workloads.
assert len(mock_executor.running_workloads) == 1
- assert callback_id in mock_executor.running_workloads
+ assert str(callback_id) in mock_executor.running_workloads
@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
def test_task_sdk_callback_with_queue(self, mock_airflow_key,
mock_executor):
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
index 00c3d622dcb..4695ca1d47a 100644
---
a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
+++
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
@@ -88,7 +88,7 @@ def mock_executor(set_env_vars) -> AwsBatchExecutor:
@pytest.fixture(autouse=True)
def mock_airflow_key():
- return mock.Mock(spec=list)
+ return mock.Mock(spec=TaskInstanceKey)
@pytest.fixture(autouse=True)
@@ -108,7 +108,7 @@ class TestBatchJobCollection:
self.collection = BatchJobCollection()
# Add first task
self.first_job_id = "001"
- self.first_airflow_key = mock.Mock(spec=tuple)
+ self.first_airflow_key = mock.Mock(spec=TaskInstanceKey)
self.collection.add_job(
job_id=self.first_job_id,
airflow_workload_key=self.first_airflow_key,
@@ -119,7 +119,7 @@ class TestBatchJobCollection:
)
# Add second task
self.second_job_id = "002"
- self.second_airflow_key = mock.Mock(spec=tuple)
+ self.second_airflow_key = mock.Mock(spec=TaskInstanceKey)
self.collection.add_job(
job_id=self.second_job_id,
airflow_workload_key=self.second_airflow_key,
@@ -190,7 +190,7 @@ class TestAwsBatchExecutor:
def test_execute(self, mock_executor):
"""Test execution from end-to-end"""
- airflow_key = mock.Mock(spec=tuple)
+ airflow_key = mock.Mock(spec=TaskInstanceKey)
airflow_cmd = ["1", "2"]
mock_executor.batch.submit_job.return_value = {"jobId": MOCK_JOB_ID,
"jobName": "some-job-name"}
@@ -480,7 +480,7 @@ class TestAwsBatchExecutor:
def test_attempt_submit_jobs_failure(self, mock_executor):
mock_executor.batch.submit_job.side_effect = NoCredentialsError()
- mock_executor.execute_async("airflow_key", "airflow_cmd")
+ mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey),
"airflow_cmd")
assert len(mock_executor.pending_jobs) == 1
with pytest.raises(NoCredentialsError, match="Unable to locate
credentials"):
mock_executor.attempt_submit_jobs()
@@ -501,7 +501,10 @@ class TestAwsBatchExecutor:
@mock.patch.object(batch_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
"""Test API failure retries"""
- airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
+ airflow_keys = [
+ TaskInstanceKey("dag", "task1", "run", 1, -1),
+ TaskInstanceKey("dag", "task2", "run", 1, -1),
+ ]
airflow_cmds = [["1", "2"], ["3", "4"]]
mock_executor.execute_async(airflow_keys[0], airflow_cmds[0])
@@ -575,7 +578,7 @@ class TestAwsBatchExecutor:
assert "No active Airflow workloads, skipping sync" in
caplog.messages[0]
def test_sync_client_error(self, mock_executor, caplog):
- mock_executor.execute_async("airflow_key", "airflow_cmd")
+ mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey),
"airflow_cmd")
assert len(mock_executor.pending_jobs) == 1
mock_resp = {
"Error": {
@@ -1053,7 +1056,7 @@ class TestBatchExecutorConfig:
)
os.environ[submit_job_kwargs_env_key] = json.dumps(submit_job_kwargs)
- mock_ti_key = mock.Mock(spec=tuple)
+ mock_ti_key = mock.Mock(spec=TaskInstanceKey)
command = ["command"]
executor = AwsBatchExecutor()
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index f57cd13d28a..ca2d1e255b2 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -110,7 +110,7 @@ def mock_task(arn=ARN1, state=State.RUNNING):
@pytest.fixture(autouse=True)
def mock_airflow_key():
def _key():
- return mock.Mock(spec=tuple)
+ return mock.Mock(spec=TaskInstanceKey)
return _key
@@ -519,7 +519,7 @@ class TestAwsEcsExecutor:
"failures": [],
}
mock_executor.ecs.run_task.side_effect = [run_task_exception,
run_task_exception, run_task_success]
- mock_executor.execute_async(mock_airflow_key, mock_cmd)
+ mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), mock_cmd)
expected_retry_count = 2
# Fail 2 times
@@ -669,7 +669,7 @@ class TestAwsEcsExecutor:
mock_executor.ecs.run_task.call_args_list.clear()
# queue new task
- airflow_keys[1] = mock.Mock(spec=tuple)
+ airflow_keys[1] = mock.Mock(spec=TaskInstanceKey)
airflow_commands[1] = _generate_mock_cmd()
mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
@@ -710,7 +710,10 @@ class TestAwsEcsExecutor:
Test API failure retries.
"""
mock_executor.max_run_task_attempts = "2"
- airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
+ airflow_keys = [
+ TaskInstanceKey("dag", "task1", "run", 1, -1),
+ TaskInstanceKey("dag", "task2", "run", 1, -1),
+ ]
airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]
mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
@@ -955,7 +958,7 @@ class TestAwsEcsExecutor:
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor,
mock_cmd):
"""Test what happens when ECS sync fails for certain tasks
repeatedly."""
- airflow_key = "test-key"
+ airflow_key = TaskInstanceKey("dag", "task", "run", 1, -1)
mock_executor.execute_async(airflow_key, mock_cmd)
assert len(mock_executor.pending_tasks) == 1
@@ -1148,7 +1151,14 @@ class TestAwsEcsExecutor:
@staticmethod
def _add_mock_task(executor: AwsEcsExecutor, arn: str,
state=TaskInstanceState.RUNNING):
task = mock_task(arn, state)
- executor.active_workers.add_task(task, mock.Mock(spec=tuple),
mock_queue, mock_cmd, mock_config, 1) # type:ignore[arg-type]
+ executor.active_workers.add_task(
+ task,
+ mock.Mock(spec=TaskInstanceKey),
+ mock_queue, # type:ignore[arg-type]
+ mock_cmd, # type:ignore[arg-type]
+ mock_config, # type:ignore[arg-type]
+ 1,
+ )
def _sync_mock_with_call_counts(self, sync_func: Callable):
"""Mock won't work here, because we actually want to call the 'sync'
func."""
diff --git
a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
index 3dcd32e84ef..c11ea80a5ba 100644
--- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
+++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
@@ -919,7 +919,7 @@ def
test_process_workloads_routes_execute_callback(mock_send_workloads, callback
executor = celery_executor.CeleryExecutor()
executor._process_workloads([workload])
- mock_send_workloads.assert_called_once_with([(callback_id, workload,
expected_queue, None)])
+ mock_send_workloads.assert_called_once_with([(workload.callback.key,
workload, expected_queue, None)])
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="execute_workload is only
used for Airflow 3+")
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
index 9269c4debd0..2930eb7f2c4 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
@@ -808,7 +808,7 @@ class TestKubernetesExecutor:
try:
assert executor.event_buffer == {}
executor.execute_async(
- key=("dag", "task", timezone.utcnow(), 1),
+ key=TaskInstanceKey("dag", "task", "run_id", 1, -1),
queue=None,
command=["airflow", "tasks", "run", "true", "some_parameter"],
executor_config=k8s.V1Pod(