This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 94405c425ac SnowflakeHook: extract OAuth token lifecycle management into dedicated helper (#68549)
94405c425ac is described below

commit 94405c425acbd484edea6f43aa2d1e57466adc0c
Author: SameerMesiah97 <[email protected]>
AuthorDate: Wed Jun 24 14:32:47 2026 +0100

    SnowflakeHook: extract OAuth token lifecycle management into dedicated helper (#68549)
    
    * Introduce _SnowflakeOAuthManager to own OAuth token lifecycle
    management.
    
    Move grant type validation, token requests, token caching and token
    expiration tracking out of SnowflakeHook while preserving the existing
    public API.
    
    SnowflakeHook continues to expose get_oauth_token(), but now delegates
    OAuth token management to the dedicated helper.
    
    * Adjust comment in the test file.
    
    ---------
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../airflow/providers/snowflake/hooks/snowflake.py | 232 +++++++++++----------
 .../tests/unit/snowflake/hooks/test_snowflake.py   |   6 +-
 2 files changed, 121 insertions(+), 117 deletions(-)

diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index e25b8605241..56bd9afd2f3 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -86,6 +86,119 @@ def _is_retryable_oauth_error(exception: BaseException) -> bool:
     return False
 
 
+class _SnowflakeOAuthManager:
+    """Encapsulates OAuth token lifecycle management for Snowflake authentication."""
+
+    def __init__(self):
+        self._oauth_token: str | None = None
+        self._oauth_token_expires_at: datetime | None = None
+
+    def validate_grant_type(self, grant_type: str | None) -> str:
+        """Validate OAuth grant_type."""
+        if not grant_type:
+            raise ValueError("Grant type must be provided for OAuth authentication.")
+
+        if grant_type not in SUPPORTED_GRANT_TYPES:
+            supported = ", ".join(sorted(SUPPORTED_GRANT_TYPES))
+
+            raise ValueError(f"Unsupported grant_type '{grant_type}'. Supported values: {supported}")
+
+        return grant_type
+
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_exponential(multiplier=1, min=0, max=10),
+        retry=tenacity.retry_if_exception(_is_retryable_oauth_error),
+        reraise=True,
+    )
+    def request_oauth_token(
+        self,
+        *,
+        url: str,
+        data: dict[str, Any],
+        client_id: str,
+        client_secret: str,
+    ):
+        """
+        Execute an OAuth token request.
+
+        Retries automatically on transient failures (ConnectionError, Timeout, 5xx)
+        via the @tenacity.retry decorator above.
+        """
+        response = requests.post(
+            url,
+            data=data,
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            auth=HTTPBasicAuth(client_id, client_secret),
+            timeout=OAUTH_REQUEST_TIMEOUT,
+        )
+
+        # Raise HTTPError for non-success responses so retry logic can decide
+        # whether the failure is retryable.
+        response.raise_for_status()
+        return response
+
+    def get_valid_oauth_token(
+        self,
+        *,
+        conn_config: dict[str, Any],
+        token_endpoint: str | None,
+        grant_type: str,
+    ) -> str:
+        """
+        Return a valid OAuth access token.
+
+        This also updates the internal OAuth token cache and token expiry timestamp.
+        """
+        # Check validity using current timestamp.
+        now = timezone.utcnow()
+
+        if (
+            self._oauth_token is not None
+            and self._oauth_token_expires_at is not None
+            and now < self._oauth_token_expires_at
+        ):
+            return self._oauth_token
+
+        url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+
+        data = {
+            "grant_type": grant_type,
+            "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
+        }
+
+        scope = conn_config.get("scope")
+
+        if scope:
+            data["scope"] = scope
+
+        grant_type = self.validate_grant_type(grant_type)
+
+        if grant_type == "refresh_token":
+            data |= {
+                "refresh_token": conn_config["refresh_token"],
+            }
+
+        response = self.request_oauth_token(
+            url=url,
+            data=data,
+            client_id=conn_config["client_id"],
+            client_secret=conn_config["client_secret"],
+        )
+
+        token = response.json()["access_token"]
+        expires_in = int(response.json()["expires_in"])
+
+        # Capture issue timestamp after access token is retrieved.
+        issued_at = timezone.utcnow()
+
+        # Persist retrieved access token and expiration timestamp.
+        self._oauth_token = token
+        self._oauth_token_expires_at = issued_at + timedelta(seconds=max(expires_in - OAUTH_EXPIRY_BUFFER, 0))
+
+        return token
+
+
 class SnowflakeHook(DbApiHook):
     """
     A client to interact with Snowflake.
@@ -225,10 +338,7 @@ class SnowflakeHook(DbApiHook):
         self.client_store_temporary_credential = kwargs.pop("client_store_temporary_credential", None)
         self.query_ids: list[str] = []
 
-        # Access token and expiration timestamp persisted
-        # to handle premature expiry.
-        self._oauth_token: str | None = None
-        self._oauth_token_expires_at: datetime | None = None
+        self._oauth = _SnowflakeOAuthManager()
 
     def _get_field(self, extra_dict, field_name):
         backcompat_prefix = "extra__snowflake__"
@@ -252,18 +362,6 @@ class SnowflakeHook(DbApiHook):
             return extra_dict[field_name] or None
         return extra_dict.get(backcompat_key) or None
 
-    def _validate_grant_type(self, grant_type: str | None) -> str:
-        """Validate OAuth grant_type."""
-        if not grant_type:
-            raise ValueError("Grant type must be provided for OAuth authentication.")
-
-        if grant_type not in SUPPORTED_GRANT_TYPES:
-            supported = ", ".join(sorted(SUPPORTED_GRANT_TYPES))
-
-            raise ValueError(f"Unsupported grant_type '{grant_type}'. Supported values: {supported}")
-
-        return grant_type
-
     @property
     def account_identifier(self) -> str:
         """Get snowflake account identifier."""
@@ -292,7 +390,7 @@ class SnowflakeHook(DbApiHook):
         if token_endpoint is None:
             token_endpoint = conn_config.get("token_endpoint")
 
-        return self._get_valid_oauth_token(
+        return self._oauth.get_valid_oauth_token(
             conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
         )
 
@@ -342,9 +440,9 @@ class SnowflakeHook(DbApiHook):
             if azure_conn_id:
                 conn_config["token"] = self.get_azure_oauth_token(azure_conn_id)
             else:
-                grant_type = self._validate_grant_type(conn_config.get("grant_type"))
+                grant_type = self._oauth.validate_grant_type(conn_config.get("grant_type"))
 
-                conn_config["token"] = self._get_valid_oauth_token(
+                conn_config["token"] = self._oauth.get_valid_oauth_token(
                     conn_config=conn_config,
                     token_endpoint=conn_config.get("token_endpoint"),
                     grant_type=grant_type,
@@ -361,8 +459,7 @@ class SnowflakeHook(DbApiHook):
         Return static Snowflake connection parameters.
 
         These parameters are cached for the lifetime of the hook and exclude
-        time-sensitive values such as OAuth access tokens. This is used in
-        ``_get_valid_oauth_token()`` and ``get_conn_params()``.
+        time-sensitive values such as OAuth access tokens.
         """
         conn = self.get_connection(self.get_conn_id())
         extra_dict = conn.extra_dejson
@@ -491,99 +588,6 @@ class SnowflakeHook(DbApiHook):
 
         return conn_config
 
-    def _get_valid_oauth_token(
-        self,
-        *,
-        conn_config: dict[str, Any],
-        token_endpoint: str | None,
-        grant_type: str,
-    ) -> str:
-        """
-        Return a valid OAuth access token.
-
-        This also updates the internal OAuth token cache and token expiry timestamp.
-        """
-        # Check validity using current timestamp.
-        now = timezone.utcnow()
-
-        if (
-            self._oauth_token is not None
-            and self._oauth_token_expires_at is not None
-            and now < self._oauth_token_expires_at
-        ):
-            return self._oauth_token
-
-        url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
-
-        data = {
-            "grant_type": grant_type,
-            "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
-        }
-
-        scope = conn_config.get("scope")
-
-        if scope:
-            data["scope"] = scope
-
-        grant_type = self._validate_grant_type(grant_type)
-
-        if grant_type == "refresh_token":
-            data |= {
-                "refresh_token": conn_config["refresh_token"],
-            }
-
-        response = self._request_oauth_token(
-            url=url,
-            data=data,
-            client_id=conn_config["client_id"],
-            client_secret=conn_config["client_secret"],
-        )
-
-        token = response.json()["access_token"]
-        expires_in = int(response.json()["expires_in"])
-
-        # Capture issue timestamp after access token is retrieved.
-        issued_at = timezone.utcnow()
-
-        # Persist retrieved access token and expiration timestamp.
-        self._oauth_token = token
-        self._oauth_token_expires_at = issued_at + timedelta(seconds=max(expires_in - OAUTH_EXPIRY_BUFFER, 0))
-
-        return token
-
-    @tenacity.retry(
-        stop=tenacity.stop_after_attempt(3),
-        wait=tenacity.wait_exponential(multiplier=1, min=0, max=10),
-        retry=tenacity.retry_if_exception(_is_retryable_oauth_error),
-        reraise=True,
-    )
-    def _request_oauth_token(
-        self,
-        *,
-        url: str,
-        data: dict[str, Any],
-        client_id: str,
-        client_secret: str,
-    ):
-        """
-        Execute a single OAuth token request.
-
-        Performs one HTTP call and raises ``HTTPError`` for 4xx and 5xx responses.
-        Retry behavior is handled by the caller.
-        """
-        response = requests.post(
-            url,
-            data=data,
-            headers={"Content-Type": "application/x-www-form-urlencoded"},
-            auth=HTTPBasicAuth(client_id, client_secret),
-            timeout=OAUTH_REQUEST_TIMEOUT,
-        )
-
-        # Raise HTTPError for non-success responses so retry logic can decide
-        # whether the failure is retryable.
-        response.raise_for_status()
-        return response
-
     def get_private_key(self) -> PrivateKeyTypes | None:
         """Get the private key from snowflake connection."""
         conn = self.get_connection(self.get_conn_id())
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 5b96484933e..d0fdb49b5b5 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -1314,9 +1314,9 @@ class TestPytestSnowflakeHook:
 
         if expected is ValueError:
             with pytest.raises(ValueError, match=match):
-                hook._validate_grant_type(grant_type)
+                hook._oauth.validate_grant_type(grant_type)
         else:
-            assert hook._validate_grant_type(grant_type) == expected
+            assert hook._oauth.validate_grant_type(grant_type) == expected
 
     @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
     @mock.patch("requests.post")
@@ -1599,7 +1599,7 @@ class TestPytestSnowflakeHook:
 
         t0 = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
 
-        # _get_valid_oauth_token calls utcnow twice per refresh:
+        # get_valid_oauth_token from _SnowflakeOAuthManager calls utcnow twice per refresh:
         #   1) validity check
         #   2) issued_at
         mock_timezone_utcnow.side_effect = [

Reply via email to