This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0d2011a3a52 Support user-assigned managed identity for Azure VM auth (#66072)
0d2011a3a52 is described below
commit 0d2011a3a5267d34ff06448533d2c4718227237c
Author: Shen YuDong <[email protected]>
AuthorDate: Wed May 6 17:33:38 2026 +0800
Support user-assigned managed identity for Azure VM auth (#66072)
* Databricks: support user-assigned managed identity for Azure VM auth
This fixes an issue where Databricks connections could not specify
a user-assigned managed identity when Airflow runs on Azure VM.
Ref: #65588
* fix static check
---
.../databricks/docs/connections/databricks.rst | 1 +
.../providers/databricks/hooks/databricks_base.py | 13 +++++++--
.../unit/databricks/hooks/test_databricks_base.py | 33 ++++++++++++++++++++++
3 files changed, 45 insertions(+), 2 deletions(-)
diff --git a/providers/databricks/docs/connections/databricks.rst b/providers/databricks/docs/connections/databricks.rst
index 634f61f2d4d..630526266be 100644
--- a/providers/databricks/docs/connections/databricks.rst
+++ b/providers/databricks/docs/connections/databricks.rst
@@ -99,6 +99,7 @@ Extra (optional)
* ``use_default_azure_credential``: required boolean flag to specify if the `DefaultAzureCredential` class should be used to retrieve a AAD token. For example, this can be used when authenticating with workload identity within an Azure Kubernetes Service cluster. Note that this option can't be set together with the `use_azure_managed_identity` parameter.
* ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if managed identity isn't a user inside workspace)
+ * ``azure_managed_identity_client_id``: optional client ID of the user-assigned managed identity. This parameter is only required if you're using a user-assigned managed identity. If not specified, the hook will attempt to authenticate using a system-assigned managed identity.
The following parameters are necessary if using authentication with Kubernetes OIDC token federation:
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index fe9d2245705..7ab51fc7dfa 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -111,6 +111,7 @@ class BaseDatabricksHook(BaseHook):
"host",
"use_azure_managed_identity",
DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
+ "azure_managed_identity_client_id",
"azure_ad_endpoint",
"azure_resource_id",
"azure_tenant_id",
@@ -340,7 +341,12 @@ class BaseDatabricksHook(BaseHook):
for attempt in self._get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
- token = ManagedIdentityCredential().get_token(f"{resource}/.default")
+ client_id = self.databricks_conn.extra_dejson.get(
+ "azure_managed_identity_client_id", None
+ )
+ token = ManagedIdentityCredential(client_id=client_id).get_token(
+ f"{resource}/.default"
+ )
else:
credential = ClientSecretCredential(
client_id=self._get_connection_attr("login"),
@@ -387,7 +393,10 @@ class BaseDatabricksHook(BaseHook):
async for attempt in self._a_get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
- async with AsyncManagedIdentityCredential() as credential:
+ client_id = self.databricks_conn.extra_dejson.get(
+ "azure_managed_identity_client_id", None
+ )
+ async with AsyncManagedIdentityCredential(client_id=client_id) as credential:
token = await credential.get_token(f"{resource}/.default")
else:
async with AsyncClientSecretCredential(
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
index 8c8794ebda9..088c3fc27bc 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import json
from datetime import datetime, timedelta
from unittest import mock
@@ -466,6 +467,38 @@ class TestBaseDatabricksHook:
mock_get_aad_token.assert_called_once_with(DEFAULT_DATABRICKS_SCOPE)
mock_log_debug.assert_called_once_with("Using AAD Token for managed identity.")
+ @mock.patch("azure.identity.ManagedIdentityCredential")
+ def test_get_aad_token_with_managed_identity_client_id(
+ self,
+ mock_credential,
+ ):
+ conn = Connection(
+ host="example.databricks.com",
+ extra=json.dumps(
+ {
+ "use_azure_managed_identity": True,
+ "azure_managed_identity_client_id": "cli-id-abc",
+ }
+ ),
+ )
+
+ token_mock = mock.Mock()
+ token_mock.token = "the-token"
+ token_mock.expires_on = 8888888888
+ mock_credential.return_value.get_token.return_value = token_mock
+
+ hook = BaseDatabricksHook()
+ hook.databricks_conn = conn
+ hook.oauth_tokens = {}
+ hook._get_retry_object = lambda: [
+ mock.Mock(__enter__=lambda s: None, __exit__=lambda s, a, b, c: None)
+ ]
+
+ token = hook._get_aad_token("https://databricks.azure.com")
+
+ assert token == "the-token"
+ mock_credential.assert_called_once_with(client_id="cli-id-abc")
+
@mock.patch(
"airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn",
new_callable=mock.PropertyMock,