This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2933257b46f Allow SSM operators and sensors to run in deferrable mode (#55649)
2933257b46f is described below

commit 2933257b46fc45e0fd6a4be4f88ac3d56e2bdf3b
Author: Sean Ghaeli <[email protected]>
AuthorDate: Mon Sep 15 13:01:03 2025 -0700

    Allow SSM operators and sensors to run in deferrable mode (#55649)
    
    * Fix test and also update deprecated hook connection in ssm trigger
    
    * Remove the explicit forcing of the Ssm components to be in deferrable mode
    
    * Update async connection in ssm testing
---
 .../airflow/providers/amazon/aws/triggers/ssm.py   |  4 +--
 .../amazon/tests/system/amazon/aws/example_ssm.py  |  2 +-
 .../tests/unit/amazon/aws/triggers/test_ssm.py     | 30 ++++++++++++----------
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
index f7efc916c72..94d0697584a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
@@ -66,8 +66,8 @@ class SsmRunCommandTrigger(AwsBaseWaiterTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
-        async with hook.async_conn as client:
-            response = client.list_command_invocations(CommandId=self.command_id)
+        async with await hook.get_async_conn() as client:
+            response = await client.list_command_invocations(CommandId=self.command_id)
             instance_ids = [invocation["InstanceId"] for invocation in response.get("CommandInvocations", [])]
             waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
 
diff --git a/providers/amazon/tests/system/amazon/aws/example_ssm.py b/providers/amazon/tests/system/amazon/aws/example_ssm.py
index f0d6371e006..3af62dcb63e 100644
--- a/providers/amazon/tests/system/amazon/aws/example_ssm.py
+++ b/providers/amazon/tests/system/amazon/aws/example_ssm.py
@@ -203,7 +203,7 @@ with DAG(
 
     # [START howto_sensor_run_command]
     await_run_command = SsmRunCommandCompletedSensor(
-        task_id="await_run_command", command_id=run_command.output
+        task_id="await_run_command", command_id="{{ ti.xcom_pull(task_ids='run_command') }}"
     )
     # [END howto_sensor_run_command]
 
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
index 994a989e8a8..d19acc0de1b 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
@@ -37,15 +37,17 @@ INSTANCE_ID_2 = "i-1234567890abcdef1"
 
 @pytest.fixture
 def mock_ssm_list_invocations():
-    def _setup(mock_async_conn):
+    def _setup(mock_get_async_conn):
         mock_client = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = mock_client
-        mock_client.list_command_invocations.return_value = {
-            "CommandInvocations": [
-                {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
-                {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
-            ]
-        }
+        mock_get_async_conn.return_value.__aenter__.return_value = mock_client
+        mock_client.list_command_invocations = mock.AsyncMock(
+            return_value={
+                "CommandInvocations": [
+                    {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
+                    {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
+                ]
+            }
+        )
         return mock_client
 
     return _setup
@@ -60,10 +62,10 @@ class TestSsmRunCommandTrigger:
         assert kwargs.get("command_id") == COMMAND_ID
 
     @pytest.mark.asyncio
-    @mock.patch.object(SsmHook, "async_conn")
+    @mock.patch.object(SsmHook, "get_async_conn")
     @mock.patch.object(SsmHook, "get_waiter")
-    async def test_run_success(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations):
-        mock_client = mock_ssm_list_invocations(mock_async_conn)
+    async def test_run_success(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations):
+        mock_client = mock_ssm_list_invocations(mock_get_async_conn)
         mock_get_waiter().wait = mock.AsyncMock(name="wait")
 
         trigger = SsmRunCommandTrigger(command_id=COMMAND_ID)
@@ -82,10 +84,10 @@ class TestSsmRunCommandTrigger:
         mock_client.list_command_invocations.assert_called_once_with(CommandId=COMMAND_ID)
 
     @pytest.mark.asyncio
-    @mock.patch.object(SsmHook, "async_conn")
+    @mock.patch.object(SsmHook, "get_async_conn")
     @mock.patch.object(SsmHook, "get_waiter")
-    async def test_run_fails(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations):
-        mock_ssm_list_invocations(mock_async_conn)
+    async def test_run_fails(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations):
+        mock_ssm_list_invocations(mock_get_async_conn)
         mock_get_waiter().wait.side_effect = WaiterError(
             "name", "terminal failure", {"CommandInvocations": [{"CommandId": COMMAND_ID}]}
         )

Reply via email to