This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e4de653c33d Add sync and async helpers to resolve the dbt Cloud
account ID from the configured Airflow connection (#61757)
e4de653c33d is described below
commit e4de653c33dc201fa90e87e131de482dcf3e9009
Author: SameerMesiah97 <[email protected]>
AuthorDate: Fri Feb 27 17:31:06 2026 +0000
Add sync and async helpers to resolve the dbt Cloud account ID from the
configured Airflow connection and cache the value on the hook instance
to avoid repeated metadata DB lookups. (#61757)
Introduce decorators to transparently fall back to the connection-based
account_id when not explicitly provided by the caller.
Add tests to verify caching behavior, including shared cache semantics
between sync and async resolution paths.
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../src/airflow/providers/dbt/cloud/hooks/dbt.py | 39 +++++++----
.../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py | 77 ++++++++++++++++++++++
2 files changed, 105 insertions(+), 11 deletions(-)
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index 8910f9c8ebc..0a7cd33d7c7 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -63,12 +63,7 @@ def fallback_to_default_account(func: Callable) -> Callable:
# provided.
if bound_args.arguments.get("account_id") is None:
self = args[0]
- default_account_id = self.connection.login
- if not default_account_id:
- raise AirflowException("Could not determine the dbt Cloud
account.")
-
- bound_args.arguments["account_id"] = int(default_account_id)
-
+ bound_args.arguments["account_id"] = self._resolve_account_id()
return func(*bound_args.args, **bound_args.kwargs)
return wrapper
@@ -162,11 +157,7 @@ def provide_account_id(func: T) -> T:
if bound_args.arguments.get("account_id") is None:
self = args[0]
if self.dbt_cloud_conn_id:
- connection = await get_async_connection(self.dbt_cloud_conn_id)
- default_account_id = connection.login
- if not default_account_id:
- raise AirflowException("Could not determine the dbt Cloud
account.")
- bound_args.arguments["account_id"] = int(default_account_id)
+ bound_args.arguments["account_id"] = await
self._resolve_account_id_async()
return await func(*bound_args.args, **bound_args.kwargs)
@@ -434,6 +425,32 @@ class DbtCloudHook(HttpHook):
extra_options=extra_options or None,
)
+ def _resolve_account_id(self) -> int:
+ """Resolve and cache the dbt Cloud account ID (sync)."""
+ # Lazily initialized; absence means "not resolved yet".
+ if not hasattr(self, "_cached_account_id"):
+ conn = self.get_connection(self.dbt_cloud_conn_id)
+ if not conn.login:
+ raise AirflowException("Could not determine the dbt Cloud
account.")
+
+ # Cache is shared between sync and async resolution to avoid
duplicate
+ # metadata DB lookups on the same hook instance.
+ self._cached_account_id = int(conn.login)
+ return self._cached_account_id
+
+ async def _resolve_account_id_async(self) -> int:
+ """Resolve and cache the dbt Cloud account ID (async)."""
+ # Lazily initialized; absence means "not resolved yet".
+ if not hasattr(self, "_cached_account_id"):
+ conn = await get_async_connection(self.dbt_cloud_conn_id)
+ if not conn.login:
+ raise AirflowException("Could not determine the dbt Cloud
account.")
+
+ # Cache is shared between sync and async resolution to avoid
duplicate
+ # metadata DB lookups on the same hook instance.
+ self._cached_account_id = int(conn.login)
+ return self._cached_account_id
+
def list_accounts(self) -> list[Response]:
"""
Retrieve all of the dbt Cloud accounts the configured API token is
authorized to access.
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index 6cddddd4bf1..3ba9653bb0b 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -279,6 +279,83 @@ class TestDbtCloudHook:
)
hook._paginate.assert_not_called()
+ def test_resolve_account_id_cached_sync(self):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+ with patch.object(DbtCloudHook, "get_connection") as
mock_get_connection:
+ mock_get_connection.return_value = Connection(
+ conn_id=ACCOUNT_ID_CONN,
+ conn_type=DbtCloudHook.conn_type,
+ login=str(DEFAULT_ACCOUNT_ID),
+ password=TOKEN,
+ )
+
+ first_call = hook._resolve_account_id()
+ second_call = hook._resolve_account_id()
+
+ assert first_call == DEFAULT_ACCOUNT_ID
+ assert second_call == DEFAULT_ACCOUNT_ID
+ assert mock_get_connection.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_resolve_account_id_cached_async(self):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+ with patch(
+ "airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
+ new=AsyncMock(
+ return_value=Connection(
+ conn_id=ACCOUNT_ID_CONN,
+ conn_type=DbtCloudHook.conn_type,
+ login=str(DEFAULT_ACCOUNT_ID),
+ password=TOKEN,
+ )
+ ),
+ ) as mock_get_async_connection:
+ first_call = await hook._resolve_account_id_async()
+ second_call = await hook._resolve_account_id_async()
+
+ assert first_call == DEFAULT_ACCOUNT_ID
+ assert second_call == DEFAULT_ACCOUNT_ID
+ assert mock_get_async_connection.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_account_id_cache_shared_between_sync_and_async(self):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+ with (
+ patch.object(
+ DbtCloudHook,
+ "get_connection",
+ return_value=Connection(
+ conn_id=ACCOUNT_ID_CONN,
+ conn_type=DbtCloudHook.conn_type,
+ login=str(DEFAULT_ACCOUNT_ID),
+ password=TOKEN,
+ ),
+ ) as mock_get_connection,
+ patch(
+ "airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
+ new=AsyncMock(
+ return_value=Connection(
+ conn_id=ACCOUNT_ID_CONN,
+ conn_type=DbtCloudHook.conn_type,
+ login=str(DEFAULT_ACCOUNT_ID),
+ password=TOKEN,
+ )
+ ),
+ ) as mock_get_async_connection,
+ ):
+ sync_account_id = hook._resolve_account_id()
+ async_account_id = await hook._resolve_account_id_async()
+
+ assert sync_account_id == DEFAULT_ACCOUNT_ID
+ assert async_account_id == DEFAULT_ACCOUNT_ID
+
+ # Only one metadata DB lookup total.
+ assert mock_get_connection.call_count == 1
+ assert mock_get_async_connection.call_count == 0
+
@pytest.mark.parametrize(
argnames=("conn_id", "account_id"),
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],