This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2933257b46f Allow SSM operators and sensors to run in deferrable mode
(#55649)
2933257b46f is described below
commit 2933257b46fc45e0fd6a4be4f88ac3d56e2bdf3b
Author: Sean Ghaeli <[email protected]>
AuthorDate: Mon Sep 15 13:01:03 2025 -0700
Allow SSM operators and sensors to run in deferrable mode (#55649)
* Fix the test and update the deprecated hook connection in the SSM trigger
* Stop explicitly forcing the SSM components into deferrable mode
* Update the async connection in the SSM tests
---
.../airflow/providers/amazon/aws/triggers/ssm.py | 4 +--
.../amazon/tests/system/amazon/aws/example_ssm.py | 2 +-
.../tests/unit/amazon/aws/triggers/test_ssm.py | 30 ++++++++++++----------
3 files changed, 19 insertions(+), 17 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
index f7efc916c72..94d0697584a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
@@ -66,8 +66,8 @@ class SsmRunCommandTrigger(AwsBaseWaiterTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.hook()
- async with hook.async_conn as client:
- response =
client.list_command_invocations(CommandId=self.command_id)
+ async with await hook.get_async_conn() as client:
+ response = await
client.list_command_invocations(CommandId=self.command_id)
instance_ids = [invocation["InstanceId"] for invocation in
response.get("CommandInvocations", [])]
waiter = hook.get_waiter(self.waiter_name, deferrable=True,
client=client)
diff --git a/providers/amazon/tests/system/amazon/aws/example_ssm.py
b/providers/amazon/tests/system/amazon/aws/example_ssm.py
index f0d6371e006..3af62dcb63e 100644
--- a/providers/amazon/tests/system/amazon/aws/example_ssm.py
+++ b/providers/amazon/tests/system/amazon/aws/example_ssm.py
@@ -203,7 +203,7 @@ with DAG(
# [START howto_sensor_run_command]
await_run_command = SsmRunCommandCompletedSensor(
- task_id="await_run_command", command_id=run_command.output
+ task_id="await_run_command", command_id="{{
ti.xcom_pull(task_ids='run_command') }}"
)
# [END howto_sensor_run_command]
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
index 994a989e8a8..d19acc0de1b 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
@@ -37,15 +37,17 @@ INSTANCE_ID_2 = "i-1234567890abcdef1"
@pytest.fixture
def mock_ssm_list_invocations():
- def _setup(mock_async_conn):
+ def _setup(mock_get_async_conn):
mock_client = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = mock_client
- mock_client.list_command_invocations.return_value = {
- "CommandInvocations": [
- {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
- {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
- ]
- }
+ mock_get_async_conn.return_value.__aenter__.return_value = mock_client
+ mock_client.list_command_invocations = mock.AsyncMock(
+ return_value={
+ "CommandInvocations": [
+ {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
+ {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
+ ]
+ }
+ )
return mock_client
return _setup
@@ -60,10 +62,10 @@ class TestSsmRunCommandTrigger:
assert kwargs.get("command_id") == COMMAND_ID
@pytest.mark.asyncio
- @mock.patch.object(SsmHook, "async_conn")
+ @mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
- async def test_run_success(self, mock_get_waiter, mock_async_conn,
mock_ssm_list_invocations):
- mock_client = mock_ssm_list_invocations(mock_async_conn)
+ async def test_run_success(self, mock_get_waiter, mock_get_async_conn,
mock_ssm_list_invocations):
+ mock_client = mock_ssm_list_invocations(mock_get_async_conn)
mock_get_waiter().wait = mock.AsyncMock(name="wait")
trigger = SsmRunCommandTrigger(command_id=COMMAND_ID)
@@ -82,10 +84,10 @@ class TestSsmRunCommandTrigger:
mock_client.list_command_invocations.assert_called_once_with(CommandId=COMMAND_ID)
@pytest.mark.asyncio
- @mock.patch.object(SsmHook, "async_conn")
+ @mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
- async def test_run_fails(self, mock_get_waiter, mock_async_conn,
mock_ssm_list_invocations):
- mock_ssm_list_invocations(mock_async_conn)
+ async def test_run_fails(self, mock_get_waiter, mock_get_async_conn,
mock_ssm_list_invocations):
+ mock_ssm_list_invocations(mock_get_async_conn)
mock_get_waiter().wait.side_effect = WaiterError(
"name", "terminal failure", {"CommandInvocations": [{"CommandId":
COMMAND_ID}]}
)