SameerMesiah97 commented on code in PR #62772:
URL: https://github.com/apache/airflow/pull/62772#discussion_r2893208264


##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py:
##########
@@ -170,3 +177,93 @@ def test_connection(self):
             return False, str(e)
 
         return True, "Successfully connected to Azure Container Instance."
+
+
+class AzureContainerInstanceAsyncHook(AzureContainerInstanceHook):
+    """
+    An async hook for communicating with Azure Container Instances.
+
+    :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of
+        a service principal which will be used to start the container instance.
+    """
+
+    def __init__(self, azure_conn_id: str = 
AzureContainerInstanceHook.default_conn_name) -> None:
+        self._async_conn: AsyncContainerInstanceManagementClient | None = None
+        super().__init__(azure_conn_id=azure_conn_id)
+
+    async def __aenter__(self) -> AzureContainerInstanceAsyncHook:
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> 
None:
+        await self.close()
+
+    async def close(self) -> None:
+        """Close the async connection."""
+        if self._async_conn is not None:
+            await self._async_conn.close()
+            self._async_conn = None

Review Comment:
   The async credential created in `get_async_conn()` isn't stored or closed in 
`close()`. Some Azure async credentials support `close()`. Would it make sense 
to keep a reference and close it here too?



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -193,6 +201,8 @@ def __init__(
         diagnostics: ContainerGroupDiagnostics | None = None,
         priority: str | None = "Regular",
         identity: ContainerGroupIdentity | dict | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        polling_interval: float = 5.0,

Review Comment:
   Why 5 seconds? How long do you expect Azure Container Instances to run? If it's 
seconds to minutes, this is fine. But if they can run for hours, I think 5 
seconds is a bit too aggressive. 10-30 seconds would be more reasonable.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py:
##########
@@ -170,3 +177,93 @@ def test_connection(self):
             return False, str(e)
 
         return True, "Successfully connected to Azure Container Instance."
+
+
+class AzureContainerInstanceAsyncHook(AzureContainerInstanceHook):
+    """
+    An async hook for communicating with Azure Container Instances.
+
+    :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of
+        a service principal which will be used to start the container instance.
+    """
+
+    def __init__(self, azure_conn_id: str = 
AzureContainerInstanceHook.default_conn_name) -> None:
+        self._async_conn: AsyncContainerInstanceManagementClient | None = None
+        super().__init__(azure_conn_id=azure_conn_id)
+
+    async def __aenter__(self) -> AzureContainerInstanceAsyncHook:
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> 
None:
+        await self.close()
+
+    async def close(self) -> None:
+        """Close the async connection."""
+        if self._async_conn is not None:
+            await self._async_conn.close()
+            self._async_conn = None
+
+    async def get_async_conn(self) -> AsyncContainerInstanceManagementClient:
+        """Return (or create) the async management client."""
+        if self._async_conn is not None:
+            return self._async_conn
+
+        conn = self.get_connection(self.conn_id)
+        tenant = conn.extra_dejson.get("tenantId")
+        subscription_id = cast("str", conn.extra_dejson.get("subscriptionId"))
+
+        if all([conn.login, conn.password, tenant]):
+            credential: Any = AsyncClientSecretCredential(
+                client_id=cast("str", conn.login),
+                client_secret=cast("str", conn.password),
+                tenant_id=cast("str", tenant),
+            )
+        else:
+            managed_identity_client_id = 
conn.extra_dejson.get("managed_identity_client_id")
+            workload_identity_tenant_id = 
conn.extra_dejson.get("workload_identity_tenant_id")
+            credential = get_async_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
+
+        self._async_conn = AsyncContainerInstanceManagementClient(
+            credential=credential,
+            subscription_id=subscription_id,
+        )
+        return self._async_conn
+
+    async def get_state(self, resource_group: str, name: str) -> 
ContainerGroup:  # type: ignore[override]
+        """
+        Get the state of a container group asynchronously.
+
+        :param resource_group: the name of the resource group
+        :param name: the name of the container group
+        :return: ContainerGroup
+        """
+        client = await self.get_async_conn()
+        return await client.container_groups.get(resource_group, name)
+
+    async def get_logs(self, resource_group: str, name: str, tail: int = 1000) 
-> list:  # type: ignore[override]
+        """
+        Get the tail from logs of a container group asynchronously.
+
+        :param resource_group: the name of the resource group
+        :param name: the name of the container group
+        :param tail: the size of the tail
+        :return: A list of log messages
+        """
+        client = await self.get_async_conn()
+        logs = await client.containers.list_logs(resource_group, name, name, 
tail=tail)
+        if logs.content is None:
+            return [None]
+        return logs.content.splitlines(True)
+
+    async def delete(self, resource_group: str, name: str) -> None:  # type: 
ignore[override]
+        """
+        Delete a container group asynchronously.
+
+        :param resource_group: the name of the resource group
+        :param name: the name of the container group
+        """
+        client = await self.get_async_conn()
+        await client.container_groups.begin_delete(resource_group, name)

Review Comment:
   Should the async hook mirror the existing sync behavior (fire-and-forget), 
or should it await the poller to ensure deletion completes and surface errors? 
Did you consider something like this:
   
   ```
   poller = await client.container_groups.begin_delete(
           resource_group,
           name
       )
   
   await poller.result()
   ```
   
   The key thing to understand here is that `begin_delete` does not return the 
result but an LRO poller object, whose `result()` waits for the operation to complete. I don't think it's 
necessarily wrong to replicate the mechanics of the sync hook here, but 
I was just wondering if you had thought of this approach.



##########
providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_container_instance.py:
##########
@@ -160,3 +167,177 @@ def test_get_conn_fallback_to_default_azure_credential(
             credential=mock_credential,
             subscription_id="subscription_id",
         )
+
+
+@pytest.fixture
+def async_conn_with_credentials(create_mock_connection):
+    return create_mock_connection(
+        Connection(
+            conn_id="azure_aci_async_test",
+            conn_type="azure_container_instance",
+            login="client-id",
+            password="client-secret",
+            extra={
+                "tenantId": "tenant-id",
+                "subscriptionId": "subscription-id",
+            },
+        )
+    )
+
+
+@pytest.fixture
+def async_conn_without_credentials(create_mock_connection):
+    return create_mock_connection(
+        Connection(
+            conn_id="azure_aci_async_no_creds",
+            conn_type="azure_container_instance",
+            extra={"subscriptionId": "subscription-id"},
+        )
+    )
+
+
+class TestAzureContainerInstanceAsyncHook:
+    @patch(
+        
"airflow.providers.microsoft.azure.hooks.container_instance.AsyncContainerInstanceManagementClient"
+    )
+    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.AsyncClientSecretCredential")
+    @pytest.mark.asyncio
+    async def test_get_async_conn_with_client_secret(
+        self,
+        mock_credential_cls,
+        mock_client_cls,
+        async_conn_with_credentials,
+    ):
+        mock_credential = MagicMock(spec=AsyncClientSecretCredential)
+        mock_credential_cls.return_value = mock_credential
+        mock_client_instance = 
MagicMock(spec=AsyncContainerInstanceManagementClient)
+        mock_client_cls.return_value = mock_client_instance
+
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        conn = await hook.get_async_conn()
+
+        mock_credential_cls.assert_called_once_with(
+            client_id="client-id",
+            client_secret="client-secret",
+            tenant_id="tenant-id",
+        )
+        mock_client_cls.assert_called_once_with(
+            credential=mock_credential,
+            subscription_id="subscription-id",
+        )
+        assert conn == mock_client_instance
+
+    @patch(
+        
"airflow.providers.microsoft.azure.hooks.container_instance.AsyncContainerInstanceManagementClient"
+    )
+    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_async_default_azure_credential")
+    @pytest.mark.asyncio
+    async def test_get_async_conn_with_default_credential(
+        self,
+        mock_default_cred,
+        mock_client_cls,
+        async_conn_without_credentials,
+    ):
+        mock_credential = MagicMock(spec=AsyncClientSecretCredential)
+        mock_default_cred.return_value = mock_credential
+        mock_client_instance = 
MagicMock(spec=AsyncContainerInstanceManagementClient)
+        mock_client_cls.return_value = mock_client_instance
+
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_without_credentials.conn_id)
+        conn = await hook.get_async_conn()
+
+        mock_default_cred.assert_called_once_with(
+            managed_identity_client_id=None,
+            workload_identity_tenant_id=None,
+        )
+        assert conn == mock_client_instance
+
+    @pytest.mark.asyncio
+    async def test_get_async_conn_returns_cached_connection(self, 
async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_conn = MagicMock(spec=AsyncContainerInstanceManagementClient)
+        hook._async_conn = mock_conn
+
+        conn = await hook.get_async_conn()
+        assert conn is mock_conn
+
+    @pytest.mark.asyncio
+    async def test_get_state(self, async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_client = MagicMock()
+        mock_cg = MagicMock(spec=ContainerGroup)
+        mock_client.container_groups.get = AsyncMock(return_value=mock_cg)
+        hook._async_conn = mock_client
+
+        result = await hook.get_state("my-rg", "my-container")
+
+        mock_client.container_groups.get.assert_called_once_with("my-rg", 
"my-container")
+        assert result is mock_cg
+
+    @pytest.mark.asyncio
+    async def test_get_logs(self, async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_client = MagicMock()
+        mock_logs = MagicMock(spec=Logs)
+        mock_logs.content = "line1\nline2\n"
+        mock_client.containers.list_logs = AsyncMock(return_value=mock_logs)
+        hook._async_conn = mock_client
+
+        result = await hook.get_logs("my-rg", "my-container")
+
+        mock_client.containers.list_logs.assert_called_once_with(
+            "my-rg", "my-container", "my-container", tail=1000
+        )
+        assert result == ["line1\n", "line2\n"]
+
+    @pytest.mark.asyncio
+    async def test_get_logs_returns_none_sentinel_when_content_is_none(self, 
async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_client = MagicMock()
+        mock_logs = MagicMock(spec=Logs)
+        mock_logs.content = None
+        mock_client.containers.list_logs = AsyncMock(return_value=mock_logs)
+        hook._async_conn = mock_client
+
+        result = await hook.get_logs("my-rg", "my-container")
+        assert result == [None]
+
+    @pytest.mark.asyncio
+    async def test_delete(self, async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_client = MagicMock()
+        mock_client.container_groups.begin_delete = AsyncMock()

Review Comment:
   If you adjust the implementation for deleting containers, you will have to 
change this too.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -407,8 +445,12 @@ def execute(self, context: Context) -> int:
             raise AirflowException("Could not start container group")
 
         finally:
-            if exit_code == 0 or self.remove_on_error:
-                self.on_kill()
+            if _cleanup:
+                if exit_code == 0:
+                    if self.remove_on_success:
+                        self.on_kill()
+                elif self.remove_on_error:
+                    self.on_kill()

Review Comment:
   This can be made clearer like this:
   
   ```
   if _cleanup:
       if exit_code == 0 and self.remove_on_success:
           self.on_kill()
       elif exit_code != 0 and self.remove_on_error:
           self.on_kill()
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/container_instance.py:
##########
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.providers.microsoft.azure.hooks.container_instance import 
AzureContainerInstanceAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+TERMINAL_STATES = frozenset({"Terminated", "Succeeded", "Failed", "Unhealthy"})
+SUCCESS_STATES = frozenset({"Terminated", "Succeeded"})
+
+
+class AzureContainerInstanceTrigger(BaseTrigger):
+    """
+    Poll an Azure Container Instance until it reaches a terminal state.
+
+    :param resource_group: the name of the resource group
+    :param name: the name of the container group
+    :param ci_conn_id: connection id of the Azure service principal
+    :param polling_interval: time in seconds between state polls
+    """
+
+    def __init__(
+        self,
+        resource_group: str,
+        name: str,
+        ci_conn_id: str,
+        polling_interval: float = 5.0,
+    ) -> None:
+        super().__init__()
+        self.resource_group = resource_group
+        self.name = name
+        self.ci_conn_id = ci_conn_id
+        self.polling_interval = polling_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.container_instance.AzureContainerInstanceTrigger",
+            {
+                "resource_group": self.resource_group,
+                "name": self.name,
+                "ci_conn_id": self.ci_conn_id,
+                "polling_interval": self.polling_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll ACI until a terminal state is reached, then yield a 
TriggerEvent."""
+        try:
+            async with 
AzureContainerInstanceAsyncHook(azure_conn_id=self.ci_conn_id) as hook:
+                while True:
+                    cg_state = await hook.get_state(self.resource_group, 
self.name)
+                    instance_view = cg_state.containers[0].instance_view
+
+                    if instance_view is not None:
+                        c_state = instance_view.current_state
+                        state = c_state.state
+                        exit_code = c_state.exit_code
+                        detail_status = c_state.detail_status
+                    else:
+                        state = cg_state.provisioning_state
+                        exit_code = 0
+                        detail_status = "Provisioning"
+
+                    self.log.info("Container group %s/%s state: %s", 
self.resource_group, self.name, state)

Review Comment:
   This does not need to be logged in every iteration. I think you should only log the 
container state when it changes. With large numbers of concurrent deferrable 
tasks, this will cause extreme log pollution in the triggerer.



##########
providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_container_instance.py:
##########
@@ -160,3 +167,177 @@ def test_get_conn_fallback_to_default_azure_credential(
             credential=mock_credential,
             subscription_id="subscription_id",
         )
+
+
+@pytest.fixture
+def async_conn_with_credentials(create_mock_connection):
+    return create_mock_connection(
+        Connection(
+            conn_id="azure_aci_async_test",
+            conn_type="azure_container_instance",
+            login="client-id",
+            password="client-secret",
+            extra={
+                "tenantId": "tenant-id",
+                "subscriptionId": "subscription-id",
+            },
+        )
+    )
+
+
+@pytest.fixture
+def async_conn_without_credentials(create_mock_connection):
+    return create_mock_connection(
+        Connection(
+            conn_id="azure_aci_async_no_creds",
+            conn_type="azure_container_instance",
+            extra={"subscriptionId": "subscription-id"},
+        )
+    )
+
+
+class TestAzureContainerInstanceAsyncHook:
+    @patch(
+        
"airflow.providers.microsoft.azure.hooks.container_instance.AsyncContainerInstanceManagementClient"
+    )
+    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.AsyncClientSecretCredential")
+    @pytest.mark.asyncio
+    async def test_get_async_conn_with_client_secret(
+        self,
+        mock_credential_cls,
+        mock_client_cls,
+        async_conn_with_credentials,
+    ):
+        mock_credential = MagicMock(spec=AsyncClientSecretCredential)
+        mock_credential_cls.return_value = mock_credential
+        mock_client_instance = 
MagicMock(spec=AsyncContainerInstanceManagementClient)
+        mock_client_cls.return_value = mock_client_instance
+
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        conn = await hook.get_async_conn()
+
+        mock_credential_cls.assert_called_once_with(
+            client_id="client-id",
+            client_secret="client-secret",
+            tenant_id="tenant-id",
+        )
+        mock_client_cls.assert_called_once_with(
+            credential=mock_credential,
+            subscription_id="subscription-id",
+        )
+        assert conn == mock_client_instance
+
+    @patch(
+        
"airflow.providers.microsoft.azure.hooks.container_instance.AsyncContainerInstanceManagementClient"
+    )
+    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_async_default_azure_credential")
+    @pytest.mark.asyncio
+    async def test_get_async_conn_with_default_credential(
+        self,
+        mock_default_cred,
+        mock_client_cls,
+        async_conn_without_credentials,
+    ):
+        mock_credential = MagicMock(spec=AsyncClientSecretCredential)
+        mock_default_cred.return_value = mock_credential
+        mock_client_instance = 
MagicMock(spec=AsyncContainerInstanceManagementClient)
+        mock_client_cls.return_value = mock_client_instance
+
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_without_credentials.conn_id)
+        conn = await hook.get_async_conn()
+
+        mock_default_cred.assert_called_once_with(
+            managed_identity_client_id=None,
+            workload_identity_tenant_id=None,
+        )
+        assert conn == mock_client_instance
+
+    @pytest.mark.asyncio
+    async def test_get_async_conn_returns_cached_connection(self, 
async_conn_with_credentials):
+        hook = 
AzureContainerInstanceAsyncHook(azure_conn_id=async_conn_with_credentials.conn_id)
+        mock_conn = MagicMock(spec=AsyncContainerInstanceManagementClient)
+        hook._async_conn = mock_conn
+
+        conn = await hook.get_async_conn()
+        assert conn is mock_conn

Review Comment:
   It might be worth adding a test that calls `get_async_conn()` twice and 
verifies that the same client instance is returned, to confirm the caching 
behavior.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -303,12 +315,14 @@ def _ensure_identity(identity: ContainerGroupIdentity | 
dict | None) -> Containe
             )
         return identity
 
+    @cached_property
+    def _ci_hook(self) -> AzureContainerInstanceHook:
+        return AzureContainerInstanceHook(azure_conn_id=self.ci_conn_id)
+

Review Comment:
   Can you explain why you turned this into a cached property?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to