This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 947b504d5b Add DefaultAzureCredential support to AzureBatchHook (#33469)
947b504d5b is described below

commit 947b504d5ba5882b1d7d36251e24185e9f47b9e7
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 01:49:46 2023 +0800

    Add DefaultAzureCredential support to AzureBatchHook (#33469)
    
    * feat(providers/microsoft): add DefaultAzureCredential support to AzureBatchHook
    
    * fix(providers/microsoft): replace DefaultAzureCredential with AzureIdentityCredentialAdapter
    
    azure-batch does not directly support DefaultAzureCredential.
    
    * test(providers/microsoft): add test case test_fallback_to_azure_identity_credential_adppter_when_name_and_key_is_not_provided
---
 airflow/providers/microsoft/azure/hooks/batch.py     | 13 ++++++++++---
 .../microsoft/azure/hooks/test_azure_batch.py        | 20 ++++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py
index a5494b94ca..deca28216d 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -27,7 +27,7 @@ from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParamet
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.providers.microsoft.azure.utils import get_field
+from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
 from airflow.utils import timezone
 
 
@@ -96,7 +96,15 @@ class AzureBatchHook(BaseHook):
         if not batch_account_url:
             raise AirflowException("Batch Account URL parameter is missing.")
 
-        credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password)
+        credentials: batch_auth.SharedKeyCredentials | AzureIdentityCredentialAdapter
+        if all([conn.login, conn.password]):
+            credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password)
+        else:
+            credentials = AzureIdentityCredentialAdapter(
+                None, resource_id="https://batch.core.windows.net/.default"
+            )
+            # credentials = AzureIdentityCredentialAdapter()
+
         batch_client = BatchServiceClient(credentials, batch_url=batch_account_url)
         return batch_client
 
@@ -344,7 +352,6 @@ class AzureBatchHook(BaseHook):
         :param task: The task to add
         """
         try:
-
             self.connection.task.add(job_id=job_id, task=task)
         except batch_models.BatchErrorException as err:
             if not err.error or err.error.code != "TaskExists":
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 05a8864f50..a3a421f5a0 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -27,6 +27,8 @@ from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
 from airflow.utils import db
 
+MODULE = "airflow.providers.microsoft.azure.hooks.batch"
+
 
 class TestAzureBatchHook:
     # set up the test environment
@@ -67,6 +69,24 @@ class TestAzureBatchHook:
         assert isinstance(hook._connection(), Connection)
         assert isinstance(hook.get_conn(), BatchServiceClient)
 
+    @mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials")
+    @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+    def test_fallback_to_azure_identity_credential_adppter_when_name_and_key_is_not_provided(
+        self, mock_azure_identity_credential_adapter, mock_shared_key_credentials
+    ):
+        self.test_account_name = None
+        self.test_account_key = None
+
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        assert isinstance(hook.get_conn(), BatchServiceClient)
+        mock_azure_identity_credential_adapter.assert_called_with(
+            None, resource_id="https://batch.core.windows.net/.default"
+        )
+        assert not mock_shared_key_credentials.auth.called
+
+        self.test_account_name = "test_account_name"
+        self.test_account_key = "test_account_key"
+
     def test_configure_pool_with_vm_config(self):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
         pool = hook.configure_pool(

Reply via email to