This is an automated email from the ASF dual-hosted git repository.
dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9928bb3292e Use async versions of CertificateCredential and
ClientSecretCredential in KiotaRequestAdapterHook (#68375)
9928bb3292e is described below
commit 9928bb3292e9cda8e54f0cf624061b0798f521d4
Author: David Blain <[email protected]>
AuthorDate: Thu Jun 18 08:54:36 2026 +0200
Use async versions of CertificateCredential and ClientSecretCredential in
KiotaRequestAdapterHook (#68375)
* refactor: Use async versions of CertificateCredential and
ClientSecretCredential to avoid blocking the event loop, especially when the
hook is used by multiple concurrent requests
* refactor: Fix stale cached request adapter causing "HTTP transport has
already been closed" errors
* refactor: Invalidate cached request adapters whose HTTP client session has
been closed in KiotaRequestAdapterHook
---
.../providers/microsoft/azure/hooks/msgraph.py | 52 +++++--
.../unit/microsoft/azure/hooks/test_msgraph.py | 160 +++++++++++++++++----
2 files changed, 166 insertions(+), 46 deletions(-)
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 32e21f61ebe..2f3bf3a4030 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, cast
from urllib.parse import quote, urljoin, urlparse
import httpx
-from azure.identity import CertificateCredential, ClientSecretCredential
+from azure.identity.aio import CertificateCredential, ClientSecretCredential
from httpx import AsyncHTTPTransport, Response, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
@@ -50,18 +50,16 @@ from msgraph_core._enums import NationalClouds
from airflow.exceptions import AirflowBadRequest, AirflowConfigException,
AirflowProviderDeprecationWarning
from airflow.providers.common.compat.connection import get_async_connection
-from airflow.providers.common.compat.sdk import AirflowException,
AirflowNotFoundException, BaseHook
+from airflow.providers.common.compat.sdk import AirflowException,
AirflowNotFoundException, BaseHook, redact
if TYPE_CHECKING:
- from azure.identity._internal.client_credential_base import
ClientCredentialBase
+ from azure.core.credentials_async import AsyncTokenCredential
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from airflow.providers.common.compat.sdk import Connection
-from airflow.providers.common.compat.sdk import redact
-
PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]]
@@ -366,7 +364,6 @@ class KiotaRequestAdapterHook(BaseHook):
http_client=http_client,
base_url=base_url,
)
- self.cached_request_adapters[self.conn_id] = (api_version,
request_adapter)
return api_version, request_adapter
def get_conn(self) -> RequestAdapter:
@@ -374,7 +371,7 @@ class KiotaRequestAdapterHook(BaseHook):
Initiate a new RequestAdapter connection.
.. warning::
- This method is deprecated.
+ This method is deprecated. Use :meth:`get_async_conn` instead.
"""
if not self.conn_id:
raise AirflowException("Failed to create the
KiotaRequestAdapterHook. No conn_id provided!")
@@ -390,9 +387,15 @@ class KiotaRequestAdapterHook(BaseHook):
if not request_adapter:
connection = self.get_connection(conn_id=self.conn_id)
api_version, request_adapter =
self._build_request_adapter(connection)
+ self.cached_request_adapters[self.conn_id] = (api_version,
request_adapter)
self.api_version = api_version
return request_adapter
+ @staticmethod
+ def _is_http_client_closed(request_adapter: RequestAdapter) -> bool:
+ """Return True when the underlying httpx AsyncClient has been
closed."""
+ return cast("HttpxRequestAdapter",
request_adapter)._http_client.is_closed
+
async def get_async_conn(self) -> RequestAdapter:
"""Initiate a new RequestAdapter connection asynchronously."""
if not self.conn_id:
@@ -400,9 +403,19 @@ class KiotaRequestAdapterHook(BaseHook):
api_version, request_adapter =
self.cached_request_adapters.get(self.conn_id, (None, None))
+ if request_adapter and self._is_http_client_closed(request_adapter):
+ self.log.warning(
+ "Cached request adapter for conn_id '%s' has a closed HTTP
client. Rebuilding.",
+ self.conn_id,
+ )
+ self.cached_request_adapters.pop(self.conn_id, None)
+ request_adapter = None
+
if not request_adapter:
connection = await get_async_connection(conn_id=self.conn_id)
api_version, request_adapter =
self._build_request_adapter(connection)
+ self.cached_request_adapters[self.conn_id] = (api_version,
request_adapter)
+
self.api_version = api_version
return request_adapter
@@ -433,7 +446,7 @@ class KiotaRequestAdapterHook(BaseHook):
authority: str | None,
verify: bool,
proxies: dict | None,
- ) -> ClientCredentialBase:
+ ) -> AsyncTokenCredential:
tenant_id = config.get("tenant_id") or config.get("tenantId")
certificate_path = config.get("certificate_path")
certificate_data = config.get("certificate_data")
@@ -582,16 +595,25 @@ class KiotaRequestAdapterHook(BaseHook):
async def send_request(self, request_info: RequestInformation,
response_type: str | None = None):
conn = await self.get_async_conn()
- if response_type:
- return await conn.send_primitive_async(
+ try:
+ if response_type:
+ return await conn.send_primitive_async(
+ request_info=request_info,
+ response_type=response_type,
+ error_map=self.error_mapping(),
+ )
+ return await conn.send_no_response_content_async(
request_info=request_info,
- response_type=response_type,
error_map=self.error_mapping(),
)
- return await conn.send_no_response_content_async(
- request_info=request_info,
- error_map=self.error_mapping(),
- )
+ except Exception as e:
+ self.log.warning(
+ "Request failed for conn_id '%s': %s. Invalidating cached
request adapter.",
+ self.conn_id,
+ e,
+ )
+ self.cached_request_adapters.pop(self.conn_id, None)
+ raise
def request_information(
self,
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 96443b5a67d..dab96656d43 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -18,10 +18,11 @@ from __future__ import annotations
import asyncio
import inspect
+from contextlib import AbstractAsyncContextManager
from json import JSONDecodeError
from os.path import dirname
-from typing import TYPE_CHECKING, cast
-from unittest.mock import Mock, patch
+from typing import cast
+from unittest.mock import AsyncMock, Mock, patch
import pytest
from httpx import Response
@@ -52,31 +53,8 @@ from unit.microsoft.azure.test_utils import (
patch_hook_and_request_adapter,
)
-if TYPE_CHECKING:
- from azure.identity._internal.msal_credentials import MsalCredential
- from kiota_abstractions.authentication import
BaseBearerTokenAuthenticationProvider
- from kiota_abstractions.request_adapter import RequestAdapter
- from kiota_authentication_azure.azure_identity_access_token_provider
import (
- AzureIdentityAccessTokenProvider,
- )
-
class TestKiotaRequestAdapterHook:
- @staticmethod
- def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id:
str):
- adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter",
request_adapter)
- auth_provider: BaseBearerTokenAuthenticationProvider = cast(
- "BaseBearerTokenAuthenticationProvider",
- adapter._authentication_provider,
- )
- access_token_provider: AzureIdentityAccessTokenProvider = cast(
- "AzureIdentityAccessTokenProvider",
- auth_provider.access_token_provider,
- )
- credentials: MsalCredential = cast("MsalCredential",
access_token_provider._credentials)
- tenant_id = credentials._tenant_id
- assert tenant_id == expected_tenant_id
-
def test_get_conn(self):
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
@@ -276,10 +254,15 @@ class TestKiotaRequestAdapterHook:
@pytest.mark.asyncio
async def test_tenant_id(self):
with patch_hook():
- hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
- actual = await hook.get_async_conn()
+ with patch(
+
"airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
+ autospec=True,
+ ) as mock_credential_cls:
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ await hook.get_async_conn()
- self.assert_tenant_id(actual, "tenant-id")
+ mock_credential_cls.assert_called_once()
+ assert mock_credential_cls.call_args.kwargs.get("tenant_id")
== "tenant-id"
@pytest.mark.asyncio
async def test_azure_tenant_id(self):
@@ -289,10 +272,15 @@ class TestKiotaRequestAdapterHook:
azure_tenant_id="azure-tenant-id",
)
):
- hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
- actual = await hook.get_async_conn()
+ with patch(
+
"airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
+ autospec=True,
+ ) as mock_credential_cls:
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ await hook.get_async_conn()
- self.assert_tenant_id(actual, "azure-tenant-id")
+ mock_credential_cls.assert_called_once()
+ assert mock_credential_cls.call_args.kwargs.get("tenant_id")
== "azure-tenant-id"
@pytest.mark.asyncio
async def test_proxies(self):
@@ -472,6 +460,116 @@ class TestKiotaRequestAdapterHook:
assert result == proxies
+ def test_get_credentials_returns_async_client_secret_credential(self):
+ """get_credentials must return an async context manager
(azure.identity.aio credential)."""
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ config = {"tenant_id": "tenant-id"}
+
+ credentials = hook.get_credentials(
+ login="client_id",
+ password="client_secret",
+ config=config,
+ authority=None,
+ verify=True,
+ proxies=None,
+ )
+
+ assert isinstance(credentials, AbstractAsyncContextManager)
+
+ def test_get_credentials_returns_async_certificate_credential(self):
+ """get_credentials must return an async context manager when
certificate_data is set."""
+ import datetime
+
+ from cryptography import x509
+ from cryptography.hazmat.primitives import hashes, serialization
+ from cryptography.hazmat.primitives.asymmetric import rsa
+ from cryptography.x509.oid import NameOID
+
+ private_key = rsa.generate_private_key(public_exponent=65537,
key_size=2048)
+ name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
+ cert = (
+ x509.CertificateBuilder()
+ .subject_name(name)
+ .issuer_name(name)
+ .public_key(private_key.public_key())
+ .serial_number(x509.random_serial_number())
+ .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
+ .not_valid_after(datetime.datetime.now(datetime.timezone.utc) +
datetime.timedelta(days=1))
+ .sign(private_key, hashes.SHA256())
+ )
+ pem = private_key.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ serialization.NoEncryption(),
+ ) + cert.public_bytes(serialization.Encoding.PEM)
+
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ config = {
+ "tenant_id": "tenant-id",
+ "certificate_data": pem.decode(),
+ }
+
+ credentials = hook.get_credentials(
+ login="client_id",
+ password=None,
+ config=config,
+ authority=None,
+ verify=True,
+ proxies=None,
+ )
+
+ assert isinstance(credentials, AbstractAsyncContextManager)
+
+ @pytest.mark.asyncio
+ async def test_get_async_conn_uses_async_credentials(self):
+ """get_async_conn must build a request adapter backed by async
credentials."""
+ with patch_hook():
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ request_adapter = await hook.get_async_conn()
+
+ adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter",
request_adapter)
+ # Reach into the auth provider chain to retrieve the underlying
credential object.
+ access_token_provider =
adapter._authentication_provider.access_token_provider
+ credentials = access_token_provider._credentials
+
+ assert isinstance(credentials, AbstractAsyncContextManager)
+
+ @pytest.mark.asyncio
+ async def
test_get_async_conn_rebuilds_adapter_when_http_client_is_closed(self):
+ """get_async_conn evicts and rebuilds the adapter when the cached HTTP
client is already closed."""
+ with patch_hook():
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+ stale_adapter = Mock(spec=HttpxRequestAdapter)
+ stale_adapter._http_client = Mock(is_closed=True)
+ hook.cached_request_adapters[hook.conn_id] = (hook.api_version,
stale_adapter)
+
+ fresh_adapter = Mock(spec=HttpxRequestAdapter)
+ fresh_adapter._http_client = Mock(is_closed=False)
+
+ with patch.object(hook, "_build_request_adapter",
return_value=("v1.0", fresh_adapter)):
+ result = await hook.get_async_conn()
+
+ assert result is fresh_adapter
+ assert hook.cached_request_adapters[hook.conn_id] == ("v1.0",
fresh_adapter)
+
+ @pytest.mark.asyncio
+ async def
test_send_request_invalidates_cache_and_raises_on_any_error(self):
+ """send_request evicts the cached adapter and re-raises on any request
error."""
+ with patch_hook():
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+ adapter = Mock(spec=HttpxRequestAdapter)
+ adapter._http_client = Mock(is_closed=False)
+ adapter.send_no_response_content_async =
AsyncMock(side_effect=RuntimeError("some error"))
+ hook.cached_request_adapters[hook.conn_id] = (hook.api_version,
adapter)
+
+ with pytest.raises(RuntimeError, match="some error"):
+ await hook.run(url="users")
+
+ adapter.send_no_response_content_async.assert_called_once()
+ assert hook.conn_id not in hook.cached_request_adapters
+
class TestKiotaRequestAdapterHookProtocol:
"""Test protocol handling in KiotaRequestAdapterHook."""