This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 824722219b4 Fix provider executor tests broken in main (#67268)
824722219b4 is described below

commit 824722219b471020dbc898da37f6e4d51bdad9b1
Author: Anish Giri <[email protected]>
AuthorDate: Thu May 21 11:31:39 2026 -0500

    Fix provider executor tests broken in main (#67268)
---
 .../aws/executors/aws_lambda/lambda_executor.py    |  5 +++--
 .../amazon/aws/executors/batch/batch_executor.py   |  5 ++---
 .../executors/aws_lambda/test_lambda_executor.py   | 10 ++++++----
 .../aws/executors/batch/test_batch_executor.py     | 19 +++++++++++--------
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 22 ++++++++++++++++------
 .../unit/celery/executors/test_celery_executor.py  |  2 +-
 .../executors/test_kubernetes_executor.py          |  2 +-
 7 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
index d4b3bc30a5f..72455165e83 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -345,8 +345,9 @@ class AwsLambdaExecutor(BaseExecutor):
             try:
                 ser_workload_key = json.dumps(workload_key._asdict())
             except AttributeError:
-                # Callback workloads use string id.
-                ser_workload_key = workload_key
+                # Callback workloads use CallbackKey (or legacy string id); both have a
+                # str() representation that round-trips through JSON.
+                ser_workload_key = str(workload_key)
 
             payload = {
                 "task_key": ser_workload_key,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 835dfe2e1c4..e04464883c0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
 
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.providers.amazon.aws.executors.batch.utils import 
BatchJobWorkloadKey
 
 
 from airflow.providers.amazon.aws.executors.batch.boto_schema import (
@@ -402,9 +403,7 @@ class AwsBatchExecutor(BaseExecutor):
             all_jobs.extend(describe_workloads_response["jobs"])
         return all_jobs
 
-    def execute_async(
-        self, key: TaskInstanceKey | str, command: CommandType, queue=None, 
executor_config=None
-    ):
+    def execute_async(self, key: BatchJobWorkloadKey, command: CommandType, 
queue=None, executor_config=None):
         """Save the workload to be executed in the next sync using Boto3's 
RunTask API."""
         if executor_config and "command" in executor_config:
             raise ValueError('Executor Config should never override "command"')
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
index d98ceaa5a89..284865285f1 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -64,7 +64,7 @@ def set_env_vars():
 @pytest.fixture
 def mock_airflow_key():
     def _key():
-        key_mock = mock.Mock()
+        key_mock = mock.Mock(spec=TaskInstanceKey)
         # Use a "random" value (memory id of the mock obj) so each key 
serializes uniquely
         key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)})
         return key_mock
@@ -180,10 +180,12 @@ class TestAwsLambdaExecutor:
     def test_task_sdk_callback(self, mock_executor):
         """Test task sdk callback execution end-to-end."""
         from airflow.executors.workloads import ExecuteCallback
+        from airflow.models.callback import CallbackKey
 
-        callback_id = "callback_123"
+        callback_id = CallbackKey("callback_123")
 
         workload = mock.Mock(spec=ExecuteCallback)
+        workload.key = callback_id
         workload.callback = mock.Mock()
         workload.callback.key = callback_id
         workload.callback.data = {}
@@ -212,7 +214,7 @@ class TestAwsLambdaExecutor:
         mock_executor.attempt_workload_runs()
         mock_executor.lambda_client.invoke.assert_called_once()
         payload = 
json.loads(mock_executor.lambda_client.invoke.call_args.kwargs["Payload"])
-        assert payload["task_key"] == callback_id
+        assert payload["task_key"] == str(callback_id)
         assert payload["command"] == [
             "python",
             "-m",
@@ -223,7 +225,7 @@ class TestAwsLambdaExecutor:
 
         # Callback is stored in running workloads.
         assert len(mock_executor.running_workloads) == 1
-        assert callback_id in mock_executor.running_workloads
+        assert str(callback_id) in mock_executor.running_workloads
 
     @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 
3.3+")
     def test_task_sdk_callback_with_queue(self, mock_airflow_key, 
mock_executor):
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
index 00c3d622dcb..4695ca1d47a 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py
@@ -88,7 +88,7 @@ def mock_executor(set_env_vars) -> AwsBatchExecutor:
 
 @pytest.fixture(autouse=True)
 def mock_airflow_key():
-    return mock.Mock(spec=list)
+    return mock.Mock(spec=TaskInstanceKey)
 
 
 @pytest.fixture(autouse=True)
@@ -108,7 +108,7 @@ class TestBatchJobCollection:
         self.collection = BatchJobCollection()
         # Add first task
         self.first_job_id = "001"
-        self.first_airflow_key = mock.Mock(spec=tuple)
+        self.first_airflow_key = mock.Mock(spec=TaskInstanceKey)
         self.collection.add_job(
             job_id=self.first_job_id,
             airflow_workload_key=self.first_airflow_key,
@@ -119,7 +119,7 @@ class TestBatchJobCollection:
         )
         # Add second task
         self.second_job_id = "002"
-        self.second_airflow_key = mock.Mock(spec=tuple)
+        self.second_airflow_key = mock.Mock(spec=TaskInstanceKey)
         self.collection.add_job(
             job_id=self.second_job_id,
             airflow_workload_key=self.second_airflow_key,
@@ -190,7 +190,7 @@ class TestAwsBatchExecutor:
 
     def test_execute(self, mock_executor):
         """Test execution from end-to-end"""
-        airflow_key = mock.Mock(spec=tuple)
+        airflow_key = mock.Mock(spec=TaskInstanceKey)
         airflow_cmd = ["1", "2"]
 
         mock_executor.batch.submit_job.return_value = {"jobId": MOCK_JOB_ID, 
"jobName": "some-job-name"}
@@ -480,7 +480,7 @@ class TestAwsBatchExecutor:
 
     def test_attempt_submit_jobs_failure(self, mock_executor):
         mock_executor.batch.submit_job.side_effect = NoCredentialsError()
-        mock_executor.execute_async("airflow_key", "airflow_cmd")
+        mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), 
"airflow_cmd")
         assert len(mock_executor.pending_jobs) == 1
         with pytest.raises(NoCredentialsError, match="Unable to locate 
credentials"):
             mock_executor.attempt_submit_jobs()
@@ -501,7 +501,10 @@ class TestAwsBatchExecutor:
     @mock.patch.object(batch_executor, "calculate_next_attempt_delay", 
return_value=dt.timedelta(seconds=0))
     def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
         """Test API failure retries"""
-        airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
+        airflow_keys = [
+            TaskInstanceKey("dag", "task1", "run", 1, -1),
+            TaskInstanceKey("dag", "task2", "run", 1, -1),
+        ]
         airflow_cmds = [["1", "2"], ["3", "4"]]
 
         mock_executor.execute_async(airflow_keys[0], airflow_cmds[0])
@@ -575,7 +578,7 @@ class TestAwsBatchExecutor:
         assert "No active Airflow workloads, skipping sync" in 
caplog.messages[0]
 
     def test_sync_client_error(self, mock_executor, caplog):
-        mock_executor.execute_async("airflow_key", "airflow_cmd")
+        mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), 
"airflow_cmd")
         assert len(mock_executor.pending_jobs) == 1
         mock_resp = {
             "Error": {
@@ -1053,7 +1056,7 @@ class TestBatchExecutorConfig:
         )
         os.environ[submit_job_kwargs_env_key] = json.dumps(submit_job_kwargs)
 
-        mock_ti_key = mock.Mock(spec=tuple)
+        mock_ti_key = mock.Mock(spec=TaskInstanceKey)
         command = ["command"]
 
         executor = AwsBatchExecutor()
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index f57cd13d28a..ca2d1e255b2 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -110,7 +110,7 @@ def mock_task(arn=ARN1, state=State.RUNNING):
 @pytest.fixture(autouse=True)
 def mock_airflow_key():
     def _key():
-        return mock.Mock(spec=tuple)
+        return mock.Mock(spec=TaskInstanceKey)
 
     return _key
 
@@ -519,7 +519,7 @@ class TestAwsEcsExecutor:
             "failures": [],
         }
         mock_executor.ecs.run_task.side_effect = [run_task_exception, 
run_task_exception, run_task_success]
-        mock_executor.execute_async(mock_airflow_key, mock_cmd)
+        mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), mock_cmd)
         expected_retry_count = 2
 
         # Fail 2 times
@@ -669,7 +669,7 @@ class TestAwsEcsExecutor:
         mock_executor.ecs.run_task.call_args_list.clear()
 
         # queue new task
-        airflow_keys[1] = mock.Mock(spec=tuple)
+        airflow_keys[1] = mock.Mock(spec=TaskInstanceKey)
         airflow_commands[1] = _generate_mock_cmd()
         mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
 
@@ -710,7 +710,10 @@ class TestAwsEcsExecutor:
         Test API failure retries.
         """
         mock_executor.max_run_task_attempts = "2"
-        airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
+        airflow_keys = [
+            TaskInstanceKey("dag", "task1", "run", 1, -1),
+            TaskInstanceKey("dag", "task2", "run", 1, -1),
+        ]
         airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]
 
         mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
@@ -955,7 +958,7 @@ class TestAwsEcsExecutor:
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", 
return_value=dt.timedelta(seconds=0))
     def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor, 
mock_cmd):
         """Test what happens when ECS sync fails for certain tasks 
repeatedly."""
-        airflow_key = "test-key"
+        airflow_key = TaskInstanceKey("dag", "task", "run", 1, -1)
         mock_executor.execute_async(airflow_key, mock_cmd)
         assert len(mock_executor.pending_tasks) == 1
 
@@ -1148,7 +1151,14 @@ class TestAwsEcsExecutor:
     @staticmethod
     def _add_mock_task(executor: AwsEcsExecutor, arn: str, 
state=TaskInstanceState.RUNNING):
         task = mock_task(arn, state)
-        executor.active_workers.add_task(task, mock.Mock(spec=tuple), 
mock_queue, mock_cmd, mock_config, 1)  # type:ignore[arg-type]
+        executor.active_workers.add_task(
+            task,
+            mock.Mock(spec=TaskInstanceKey),
+            mock_queue,  # type:ignore[arg-type]
+            mock_cmd,  # type:ignore[arg-type]
+            mock_config,  # type:ignore[arg-type]
+            1,
+        )
 
     def _sync_mock_with_call_counts(self, sync_func: Callable):
         """Mock won't work here, because we actually want to call the 'sync' 
func."""
diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
index 3dcd32e84ef..c11ea80a5ba 100644
--- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
+++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
@@ -919,7 +919,7 @@ def test_process_workloads_routes_execute_callback(mock_send_workloads, callback
     executor = celery_executor.CeleryExecutor()
     executor._process_workloads([workload])
 
-    mock_send_workloads.assert_called_once_with([(callback_id, workload, 
expected_queue, None)])
+    mock_send_workloads.assert_called_once_with([(workload.callback.key, 
workload, expected_queue, None)])
 
 
 @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="execute_workload is only 
used for Airflow 3+")
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
index 9269c4debd0..2930eb7f2c4 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py
@@ -808,7 +808,7 @@ class TestKubernetesExecutor:
         try:
             assert executor.event_buffer == {}
             executor.execute_async(
-                key=("dag", "task", timezone.utcnow(), 1),
+                key=TaskInstanceKey("dag", "task", "run_id", 1, -1),
                 queue=None,
                 command=["airflow", "tasks", "run", "true", "some_parameter"],
                 executor_config=k8s.V1Pod(

Reply via email to