This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d5f81a4e2d Switch AzureDataLakeStorageV2Hook to use DefaultAzureCredential for managed identity/workload auth (#38497)
d5f81a4e2d is described below

commit d5f81a4e2de0d4236cffcf2e2d3c682b4c6ec355
Author: Tamara Janina Fingerlin <90063506+tja...@users.noreply.github.com>
AuthorDate: Mon May 27 02:28:39 2024 +0200

    Switch AzureDataLakeStorageV2Hook to use DefaultAzureCredential for managed identity/workload auth (#38497)
---
 .../providers/microsoft/azure/hooks/data_lake.py   |  7 +++---
 .../microsoft/azure/hooks/test_data_factory.py     | 29 +++++++++++++++++++++-
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 054eda087e..b2d9c5aafa 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -22,7 +22,7 @@ from typing import Any, Union
 
 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
 from azure.datalake.store import core, lib, multithread
-from azure.identity import ClientSecretCredential
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.storage.filedatalake import (
     DataLakeDirectoryClient,
     DataLakeFileClient,
@@ -38,9 +38,10 @@ from airflow.providers.microsoft.azure.utils import (
     AzureIdentityCredentialAdapter,
     add_managed_identity_connection_widgets,
     get_field,
+    get_sync_default_azure_credential,
 )
 
-Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
+Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter, DefaultAzureCredential]
 
 
 class AzureDataLakeHook(BaseHook):
@@ -358,7 +359,7 @@ class AzureDataLakeStorageV2Hook(BaseHook):
         else:
             managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
-            credential = AzureIdentityCredentialAdapter(
+            credential = get_sync_default_azure_credential(
                 managed_identity_client_id=managed_identity_client_id,
                 workload_identity_tenant_id=workload_identity_tenant_id,
             )
diff --git a/tests/providers/microsoft/azure/hooks/test_data_factory.py b/tests/providers/microsoft/azure/hooks/test_data_factory.py
index 1ee77ad3af..a7d8786fd8 100644
--- a/tests/providers/microsoft/azure/hooks/test_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_data_factory.py
@@ -86,8 +86,8 @@ def setup_connections(create_mock_connections):
                 "factory_name": DEFAULT_FACTORY,
             },
         ),
+        # connection_missing_subscription_id
         Connection(
-            # connection_missing_subscription_id
             conn_id="azure_data_factory_missing_subscription_id",
             conn_type="azure_data_factory",
             login="clientId",
@@ -110,6 +110,18 @@ def setup_connections(create_mock_connections):
                 "factory_name": DEFAULT_FACTORY,
             },
         ),
+        # connection_workload_identity
+        Connection(
+            conn_id="azure_data_factory_workload_identity",
+            conn_type="azure_data_factory",
+            extra={
+                "subscriptionId": "subscriptionId",
+                "resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "factory_name": DEFAULT_FACTORY,
+                "workload_identity_tenant_id": "workload_tenant_id",
+                "managed_identity_client_id": "workload_client_id",
+            },
+        ),
     )
 
 
@@ -198,6 +210,21 @@ def test_get_conn_by_default_azure_credential(mock_credential):
         mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
 
 
+@mock.patch(f"{MODULE}.get_sync_default_azure_credential")
+def test_get_conn_with_workload_identity(mock_credential):
+    hook = AzureDataFactoryHook("azure_data_factory_workload_identity")
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+
+        connection = hook.get_conn()
+        assert connection is not None
+        mock_credential.assert_called_once_with(
+            managed_identity_client_id="workload_client_id",
+            workload_identity_tenant_id="workload_tenant_id",
+        )
+        mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
+
+
 def test_get_factory(hook: AzureDataFactoryHook):
     hook.get_factory(RESOURCE_GROUP, FACTORY)
 

Reply via email to