This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e4de653c33d Add sync and async helpers to resolve the dbt Cloud account ID from the Airflow connection (#61757)
e4de653c33d is described below

commit e4de653c33dc201fa90e87e131de482dcf3e9009
Author: SameerMesiah97 <[email protected]>
AuthorDate: Fri Feb 27 17:31:06 2026 +0000

    Add sync and async helpers to resolve the dbt Cloud account ID from the Airflow connection (#61757)
    
    The account ID is read from the configured Airflow connection and cached
    on the hook instance to avoid repeated metadata DB lookups.
    
    Introduce decorators to transparently fall back to the connection-based
    account_id when not explicitly provided by the caller.
    
    Add tests to verify caching behavior, including shared cache semantics
    between sync and async resolution paths.
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../src/airflow/providers/dbt/cloud/hooks/dbt.py   | 39 +++++++----
 .../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py   | 77 ++++++++++++++++++++++
 2 files changed, 105 insertions(+), 11 deletions(-)

diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index 8910f9c8ebc..0a7cd33d7c7 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -63,12 +63,7 @@ def fallback_to_default_account(func: Callable) -> Callable:
         # provided.
         if bound_args.arguments.get("account_id") is None:
             self = args[0]
-            default_account_id = self.connection.login
-            if not default_account_id:
-                raise AirflowException("Could not determine the dbt Cloud 
account.")
-
-            bound_args.arguments["account_id"] = int(default_account_id)
-
+            bound_args.arguments["account_id"] = self._resolve_account_id()
         return func(*bound_args.args, **bound_args.kwargs)
 
     return wrapper
@@ -162,11 +157,7 @@ def provide_account_id(func: T) -> T:
         if bound_args.arguments.get("account_id") is None:
             self = args[0]
             if self.dbt_cloud_conn_id:
-                connection = await get_async_connection(self.dbt_cloud_conn_id)
-                default_account_id = connection.login
-                if not default_account_id:
-                    raise AirflowException("Could not determine the dbt Cloud 
account.")
-                bound_args.arguments["account_id"] = int(default_account_id)
+                bound_args.arguments["account_id"] = await 
self._resolve_account_id_async()
 
         return await func(*bound_args.args, **bound_args.kwargs)
 
@@ -434,6 +425,32 @@ class DbtCloudHook(HttpHook):
             extra_options=extra_options or None,
         )
 
+    def _resolve_account_id(self) -> int:
+        """Resolve and cache the dbt Cloud account ID (sync)."""
+        # Lazily initialized; absence means "not resolved yet".
+        if not hasattr(self, "_cached_account_id"):
+            conn = self.get_connection(self.dbt_cloud_conn_id)
+            if not conn.login:
+                raise AirflowException("Could not determine the dbt Cloud 
account.")
+
+            # Cache is shared between sync and async resolution to avoid 
duplicate
+            # metadata DB lookups on the same hook instance.
+            self._cached_account_id = int(conn.login)
+        return self._cached_account_id
+
+    async def _resolve_account_id_async(self) -> int:
+        """Resolve and cache the dbt Cloud account ID (async)."""
+        # Lazily initialized; absence means "not resolved yet".
+        if not hasattr(self, "_cached_account_id"):
+            conn = await get_async_connection(self.dbt_cloud_conn_id)
+            if not conn.login:
+                raise AirflowException("Could not determine the dbt Cloud 
account.")
+
+            # Cache is shared between sync and async resolution to avoid 
duplicate
+            # metadata DB lookups on the same hook instance.
+            self._cached_account_id = int(conn.login)
+        return self._cached_account_id
+
     def list_accounts(self) -> list[Response]:
         """
         Retrieve all of the dbt Cloud accounts the configured API token is 
authorized to access.
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index 6cddddd4bf1..3ba9653bb0b 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -279,6 +279,83 @@ class TestDbtCloudHook:
         )
         hook._paginate.assert_not_called()
 
+    def test_resolve_account_id_cached_sync(self):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+        with patch.object(DbtCloudHook, "get_connection") as 
mock_get_connection:
+            mock_get_connection.return_value = Connection(
+                conn_id=ACCOUNT_ID_CONN,
+                conn_type=DbtCloudHook.conn_type,
+                login=str(DEFAULT_ACCOUNT_ID),
+                password=TOKEN,
+            )
+
+            first_call = hook._resolve_account_id()
+            second_call = hook._resolve_account_id()
+
+            assert first_call == DEFAULT_ACCOUNT_ID
+            assert second_call == DEFAULT_ACCOUNT_ID
+            assert mock_get_connection.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_resolve_account_id_cached_async(self):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+        with patch(
+            "airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
+            new=AsyncMock(
+                return_value=Connection(
+                    conn_id=ACCOUNT_ID_CONN,
+                    conn_type=DbtCloudHook.conn_type,
+                    login=str(DEFAULT_ACCOUNT_ID),
+                    password=TOKEN,
+                )
+            ),
+        ) as mock_get_async_connection:
+            first_call = await hook._resolve_account_id_async()
+            second_call = await hook._resolve_account_id_async()
+
+            assert first_call == DEFAULT_ACCOUNT_ID
+            assert second_call == DEFAULT_ACCOUNT_ID
+            assert mock_get_async_connection.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_account_id_cache_shared_between_sync_and_async(self):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN)
+
+        with (
+            patch.object(
+                DbtCloudHook,
+                "get_connection",
+                return_value=Connection(
+                    conn_id=ACCOUNT_ID_CONN,
+                    conn_type=DbtCloudHook.conn_type,
+                    login=str(DEFAULT_ACCOUNT_ID),
+                    password=TOKEN,
+                ),
+            ) as mock_get_connection,
+            patch(
+                "airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
+                new=AsyncMock(
+                    return_value=Connection(
+                        conn_id=ACCOUNT_ID_CONN,
+                        conn_type=DbtCloudHook.conn_type,
+                        login=str(DEFAULT_ACCOUNT_ID),
+                        password=TOKEN,
+                    )
+                ),
+            ) as mock_get_async_connection,
+        ):
+            sync_account_id = hook._resolve_account_id()
+            async_account_id = await hook._resolve_account_id_async()
+
+            assert sync_account_id == DEFAULT_ACCOUNT_ID
+            assert async_account_id == DEFAULT_ACCOUNT_ID
+
+            # Only one metadata DB lookup total.
+            assert mock_get_connection.call_count == 1
+            assert mock_get_async_connection.call_count == 0
+
     @pytest.mark.parametrize(
         argnames=("conn_id", "account_id"),
         argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],

Reply via email to