This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a90a1ac6c62 Use common provider's get_async_connection in other providers (#56791)
a90a1ac6c62 is described below
commit a90a1ac6c625a91a2c444b88620edd193706a560
Author: Ramit Kataria <[email protected]>
AuthorDate: Mon Jan 26 12:52:48 2026 -0800
Use common provider's get_async_connection in other providers (#56791)
* Use common provider's get_async_connection in other providers
* Fix sftp and livy unit tests
---
.../amazon/src/airflow/providers/amazon/aws/hooks/s3.py | 5 +++--
.../livy/src/airflow/providers/apache/livy/hooks/livy.py | 4 ++--
.../livy/tests/unit/apache/livy/hooks/test_livy.py | 12 ++++++------
.../providers/cncf/kubernetes/hooks/kubernetes.py | 4 ++--
.../cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py | 3 ++-
providers/http/src/airflow/providers/http/hooks/http.py | 4 ++--
.../providers/microsoft/azure/hooks/data_factory.py | 6 +++---
.../src/airflow/providers/microsoft/azure/hooks/wasb.py | 4 ++--
.../providers/pagerduty/hooks/pagerduty_events.py | 4 ++--
providers/sftp/src/airflow/providers/sftp/hooks/sftp.py | 4 ++--
providers/sftp/tests/unit/sftp/hooks/test_sftp.py | 16 ++++++++--------
providers/ssh/src/airflow/providers/ssh/hooks/ssh.py | 4 ++--
12 files changed, 36 insertions(+), 34 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
index 909423c06d3..db5ab1fc8d8 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -42,6 +42,8 @@ from typing import TYPE_CHECKING, Any
from urllib.parse import urlsplit
from uuid import uuid4
+from airflow.providers.common.compat.connection import get_async_connection
+
if TYPE_CHECKING:
from aiobotocore.client import AioBaseClient
from mypy_boto3_s3.service_resource import (
@@ -52,7 +54,6 @@ if TYPE_CHECKING:
from airflow.providers.amazon.version_compat import ArgNotSet
-from asgiref.sync import sync_to_async
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError
@@ -90,7 +91,7 @@ def provide_bucket_name(func: Callable) -> Callable:
if not bound_args.arguments.get("bucket_name"):
self = args[0]
if self.aws_conn_id:
- connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
+ connection = await get_async_connection(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema
return bound_args
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
index 5844995cb05..9de3582de11 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
@@ -26,8 +26,8 @@ from typing import TYPE_CHECKING, Any
import aiohttp
import requests
from aiohttp import ClientResponseError
-from asgiref.sync import sync_to_async
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
@@ -526,7 +526,7 @@ class LivyAsyncHook(HttpAsyncHook):
auth = None
if self.http_conn_id:
- conn = await sync_to_async(self.get_connection)(self.http_conn_id)
+ conn = await get_async_connection(self.http_conn_id)
self.base_url = self._generate_base_url(conn) # type: ignore[arg-type]
if conn.login:
diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
index 8dac0bedf52..3353d796915 100644
--- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
@@ -611,7 +611,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for success response for POST method."""
@@ -634,7 +634,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for GET method."""
@@ -659,7 +659,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for PATCH method."""
@@ -684,7 +684,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for unexpected method error"""
GET_RUN_ENDPOINT = "api/jobs/runs/get"
@@ -700,7 +700,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for TypeError."""
@@ -719,7 +719,7 @@ class TestLivyAsyncHook:
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for Client Response Error."""
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index fa33fe44951..375ae9adb32 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -27,7 +27,6 @@ from typing import TYPE_CHECKING, Any, Protocol
import aiofiles
import requests
-from asgiref.sync import sync_to_async
from kubernetes import client, config, utils, watch
from kubernetes.client.models import V1Deployment
from kubernetes.config import ConfigException
@@ -46,6 +45,7 @@ from airflow.providers.cncf.kubernetes.utils.container import
(
container_is_completed,
container_is_running,
)
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException,
AirflowNotFoundException, BaseHook
from airflow.utils import yaml
@@ -899,7 +899,7 @@ class AsyncKubernetesHook(KubernetesHook):
async def get_conn_extras(self) -> dict:
if self._extras is None:
if self.conn_id:
- connection = await sync_to_async(self.get_connection)(self.conn_id)
+ connection = await get_async_connection(self.conn_id)
self._extras = connection.extra_dejson
else:
self._extras = {}
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index c4dd0572f6e..ca20480abd2 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -34,6 +34,7 @@ from requests.auth import AuthBase
from requests.sessions import Session
from tenacity import AsyncRetrying, RetryCallState, retry_if_exception,
stop_after_attempt, wait_exponential
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpHook
@@ -161,7 +162,7 @@ def provide_account_id(func: T) -> T:
if bound_args.arguments.get("account_id") is None:
self = args[0]
if self.dbt_cloud_conn_id:
- connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
+ connection = await get_async_connection(self.dbt_cloud_conn_id)
default_account_id = connection.login
if not default_account_id:
raise AirflowException("Could not determine the dbt Cloud
account.")
diff --git a/providers/http/src/airflow/providers/http/hooks/http.py
b/providers/http/src/airflow/providers/http/hooks/http.py
index 815b7ffadc9..ed137a651c4 100644
--- a/providers/http/src/airflow/providers/http/hooks/http.py
+++ b/providers/http/src/airflow/providers/http/hooks/http.py
@@ -25,13 +25,13 @@ from urllib.parse import urlparse
import aiohttp
import tenacity
from aiohttp import ClientResponseError
-from asgiref.sync import sync_to_async
from requests import PreparedRequest, Request, Response, Session
from requests.auth import HTTPBasicAuth
from requests.exceptions import ConnectionError, HTTPError
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.http.exceptions import HttpErrorException,
HttpMethodException
@@ -461,7 +461,7 @@ class HttpAsyncHook(BaseHook):
auth = None
if self.http_conn_id:
- conn = await sync_to_async(self.get_connection)(self.http_conn_id)
+ conn = await get_async_connection(self.http_conn_id)
if conn.host and "://" in conn.host:
self.base_url = conn.host
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py
index 93c28e33300..44248847794 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -39,7 +39,6 @@ from collections.abc import Callable
from functools import wraps
from typing import IO, TYPE_CHECKING, Any, TypeVar, cast
-from asgiref.sync import sync_to_async
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
@@ -48,6 +47,7 @@ from azure.identity.aio import (
from azure.mgmt.datafactory import DataFactoryManagementClient
from azure.mgmt.datafactory.aio import DataFactoryManagementClient as
AsyncDataFactoryManagementClient
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
@@ -1089,7 +1089,7 @@ def provide_targeted_factory_async(func: T) -> T:
# Check if arg was not included in the function signature or, if
it is, the value is not provided.
if arg not in bound_args.arguments or bound_args.arguments[arg] is
None:
self = args[0]
- conn = await sync_to_async(self.get_connection)(self.conn_id)
+ conn = await get_async_connection(self.conn_id)
extras = conn.extra_dejson
default_value = extras.get(default_key) or extras.get(
f"extra__azure_data_factory__{default_key}"
@@ -1126,7 +1126,7 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
if self._async_conn is not None:
return self._async_conn
- conn = await sync_to_async(self.get_connection)(self.conn_id)
+ conn = await get_async_connection(self.conn_id)
extras = conn.extra_dejson
tenant = get_field(extras, "tenantId")
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py
index ed92f48f1d3..dabd7280e3d 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -30,7 +30,6 @@ import logging
import os
from typing import TYPE_CHECKING, Any, cast
-from asgiref.sync import sync_to_async
from azure.core.exceptions import HttpResponseError, ResourceExistsError,
ResourceNotFoundError
from azure.identity import ClientSecretCredential
from azure.identity.aio import (
@@ -44,6 +43,7 @@ from azure.storage.blob.aio import (
ContainerClient as AsyncContainerClient,
)
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
@@ -620,7 +620,7 @@ class WasbAsyncHook(WasbHook):
self._blob_service_client = cast("AsyncBlobServiceClient",
self._blob_service_client)
return self._blob_service_client
- conn = await sync_to_async(self.get_connection)(self.conn_id)
+ conn = await get_async_connection(self.conn_id)
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})
diff --git
a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py
b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py
index 710189bde68..99a0efefe08 100644
---
a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py
+++
b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py
@@ -23,8 +23,8 @@ from typing import TYPE_CHECKING, Any
import aiohttp
import pagerduty
-from asgiref.sync import sync_to_async
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.http.hooks.http import HttpAsyncHook
@@ -285,7 +285,7 @@ class PagerdutyEventsAsyncHook(HttpAsyncHook):
return self.integration_key
if self.pagerduty_events_conn_id is not None:
- conn = await sync_to_async(self.get_connection)(self.pagerduty_events_conn_id)
+ conn = await get_async_connection(self.pagerduty_events_conn_id)
self.integration_key = conn.password
if self.integration_key:
return self.integration_key
diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
index bef8d725c33..80ceb729082 100644
--- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
@@ -33,10 +33,10 @@ from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, cast
import asyncssh
-from asgiref.sync import sync_to_async
from paramiko.config import SSH_PORT
from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook,
Connection
from airflow.providers.sftp.exceptions import ConnectionNotOpenedException
from airflow.providers.ssh.hooks.ssh import SSHHook
@@ -756,7 +756,7 @@ class SFTPHookAsync(BaseHook):
- known_hosts
- passphrase
"""
- conn = await sync_to_async(self.get_connection)(self.sftp_conn_id)
+ conn = await get_async_connection(self.sftp_conn_id)
if conn.extra is not None:
self._parse_extras(conn) # type: ignore[arg-type]
diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
index dc3838f1a87..835bfcc0221 100644
--- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
+++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
@@ -734,7 +734,7 @@ class MockAirflowConnectionWithPrivate:
class TestSFTPHookAsync:
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def
test_extra_dejson_fields_for_connection_building_known_hosts_none(
self, mock_get_connection, mock_connect, caplog
@@ -775,7 +775,7 @@ class TestSFTPHookAsync:
)
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("asyncssh.import_private_key")
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_with_host_key(
self,
@@ -799,7 +799,7 @@ class TestSFTPHookAsync:
assert hook.known_hosts == f"localhost {mock_host_key}".encode()
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_raises_valuerror(
self, mock_get_connection, mock_connect
@@ -820,7 +820,7 @@ class TestSFTPHookAsync:
@patch("paramiko.SSHClient.connect")
@patch("asyncssh.import_private_key")
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_no_host_key_check_set_logs_warning(
self, mock_get_connection, mock_connect, mock_import_pkey,
mock_ssh_connect, caplog
@@ -833,7 +833,7 @@ class TestSFTPHookAsync:
assert "No Host Key Verification. This won't protect against
Man-In-The-Middle attacks" in caplog.text
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_building(self,
mock_get_connection, mock_connect):
"""
@@ -861,7 +861,7 @@ class TestSFTPHookAsync:
@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("asyncssh.import_private_key")
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_connection_private(self, mock_get_connection,
mock_import_private_key, mock_connect):
"""
Assert that connection details with private key passed through the
extra field in the Airflow connection
@@ -888,7 +888,7 @@ class TestSFTPHookAsync:
@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_connection_port_default_to_22(self, mock_get_connection,
mock_connect):
from unittest.mock import Mock, call
@@ -917,7 +917,7 @@ class TestSFTPHookAsync:
@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+ @patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_init_argument_not_ignored(self, mock_get_connection,
mock_connect):
from unittest.mock import Mock, call
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index 493f3f92369..4814569e4cf 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -33,6 +33,7 @@ from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder
from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random
+from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.utils.platform import getuser
@@ -615,9 +616,8 @@ class SSHHookAsync(BaseHook):
Returns an asyncssh SSHClientConnection that can be used to run
commands.
"""
import asyncssh
- from asgiref.sync import sync_to_async
- conn = await sync_to_async(self.get_connection)(self.ssh_conn_id)
+ conn = await get_async_connection(self.ssh_conn_id)
if conn.extra is not None:
self._parse_extras(conn)