This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 947b504d5b Add DefaultAzureCredential support to AzureBatchHook
(#33469)
947b504d5b is described below
commit 947b504d5ba5882b1d7d36251e24185e9f47b9e7
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 01:49:46 2023 +0800
Add DefaultAzureCredential support to AzureBatchHook (#33469)
* feat(providers/microsoft): add DefaultAzureCredential support to
AzureBatchHook
* fix(providers/microsoft): replace DefaultAzureCredential with
AzureIdentityCredentialAdapter
azure-batch does not directly support DefaultAzureCredential
* test(providers/microsoft): add test case
test_fallback_to_azure_identity_credential_adppter_when_name_and_key_is_not_provided
---
airflow/providers/microsoft/azure/hooks/batch.py | 13 ++++++++++---
.../microsoft/azure/hooks/test_azure_batch.py | 20 ++++++++++++++++++++
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/batch.py
b/airflow/providers/microsoft/azure/hooks/batch.py
index a5494b94ca..deca28216d 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -27,7 +27,7 @@ from azure.batch.models import JobAddParameter,
PoolAddParameter, TaskAddParamet
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
-from airflow.providers.microsoft.azure.utils import get_field
+from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter, get_field
from airflow.utils import timezone
@@ -96,7 +96,15 @@ class AzureBatchHook(BaseHook):
if not batch_account_url:
raise AirflowException("Batch Account URL parameter is missing.")
- credentials = batch_auth.SharedKeyCredentials(conn.login,
conn.password)
+ credentials: batch_auth.SharedKeyCredentials |
AzureIdentityCredentialAdapter
+ if all([conn.login, conn.password]):
+ credentials = batch_auth.SharedKeyCredentials(conn.login,
conn.password)
+ else:
+ credentials = AzureIdentityCredentialAdapter(
+ None, resource_id="https://batch.core.windows.net/.default"
+ )
+ # credentials = AzureIdentityCredentialAdapter()
+
batch_client = BatchServiceClient(credentials,
batch_url=batch_account_url)
return batch_client
@@ -344,7 +352,6 @@ class AzureBatchHook(BaseHook):
:param task: The task to add
"""
try:
-
self.connection.task.add(job_id=job_id, task=task)
except batch_models.BatchErrorException as err:
if not err.error or err.error.code != "TaskExists":
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 05a8864f50..a3a421f5a0 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -27,6 +27,8 @@ from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
from airflow.utils import db
+MODULE = "airflow.providers.microsoft.azure.hooks.batch"
+
class TestAzureBatchHook:
# set up the test environment
@@ -67,6 +69,24 @@ class TestAzureBatchHook:
assert isinstance(hook._connection(), Connection)
assert isinstance(hook.get_conn(), BatchServiceClient)
+ @mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials")
+ @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+ def
test_fallback_to_azure_identity_credential_adppter_when_name_and_key_is_not_provided(
+ self, mock_azure_identity_credential_adapter,
mock_shared_key_credentials
+ ):
+ self.test_account_name = None
+ self.test_account_key = None
+
+ hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+ assert isinstance(hook.get_conn(), BatchServiceClient)
+ mock_azure_identity_credential_adapter.assert_called_with(
+ None, resource_id="https://batch.core.windows.net/.default"
+ )
+ assert not mock_shared_key_credentials.auth.called
+
+ self.test_account_name = "test_account_name"
+ self.test_account_key = "test_account_key"
+
def test_configure_pool_with_vm_config(self):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
pool = hook.configure_pool(