This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 203c044a638 Added retry logic for Snowflake OAuth token requests
(#61796)
203c044a638 is described below
commit 203c044a6380bf23b0fdca35f39b8f6b06e8652b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Sun Feb 15 23:45:05 2026 +0000
Added retry logic for Snowflake OAuth token requests (#61796)
Introduced retry handling for OAuth token acquisition in SnowflakeHook
using tenacity. Extracted the HTTP call into _request_oauth_token and added
retry classification via _is_retryable_oauth_error. Retries apply only to
connection errors, timeouts, and HTTP 5xx responses, while HTTP 4xx errors fail fast.
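Updated unit tests to cover retry behavior, non-retryable errors, and retry
exhaustion. Updated the get_oauth_token docstring to reflect retry semantics.
Note that failed token requests now raise requests' HTTPError directly instead
of being wrapped in AirflowException.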
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../airflow/providers/snowflake/hooks/snowflake.py | 74 +++++++++++---
.../tests/unit/snowflake/hooks/test_snowflake.py | 106 +++++++++++++++++++++
2 files changed, 166 insertions(+), 14 deletions(-)
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 96725472493..d077b072330 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -29,9 +29,11 @@ from typing import TYPE_CHECKING, Any, TypeVar, overload
from urllib.parse import urlparse
import requests
+import tenacity
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth
+from requests.exceptions import ConnectionError, HTTPError, Timeout
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
@@ -65,6 +67,22 @@ def _try_to_boolean(value: Any):
return value
+def _is_retryable_oauth_error(exception: BaseException) -> bool:
+ """Return True if exception is retryable for OAuth token request."""
+ if isinstance(exception, (ConnectionError, Timeout)):
+ return True
+
+ # Retry only on server-side HTTP errors (5xx).
+ # Client-side errors (4xx) indicate misconfiguration or invalid credentials
+ # and should fail fast without retrying.
+ if isinstance(exception, HTTPError):
+ response = exception.response
+ if response is not None and 500 <= response.status_code < 600:
+ return True
+
+ return False
+
+
class SnowflakeHook(DbApiHook):
"""
A client to interact with Snowflake.
@@ -239,7 +257,11 @@ class SnowflakeHook(DbApiHook):
token_endpoint: str | None = None,
grant_type: str = "refresh_token",
) -> str:
- """Generate temporary OAuth access token using refresh token in
connection details."""
+ """
+ Generate temporary OAuth access token using refresh token in
connection details.
+
+ Transient network and server-side errors are retried automatically.
+ """
if conn_config is None:
conn_config = self._get_static_conn_params
@@ -503,22 +525,13 @@ class SnowflakeHook(DbApiHook):
else:
raise ValueError(f"Unknown grant_type: {grant_type}")
- response = requests.post(
- url,
+ response = self._request_oauth_token(
+ url=url,
data=data,
- headers={
- "Content-Type": "application/x-www-form-urlencoded",
- },
- auth=HTTPBasicAuth(conn_config["client_id"],
conn_config["client_secret"]), # type: ignore[arg-type]
- timeout=OAUTH_REQUEST_TIMEOUT,
+ client_id=conn_config["client_id"],
+ client_secret=conn_config["client_secret"],
)
- try:
- response.raise_for_status()
- except requests.exceptions.HTTPError as e: # pragma: no cover
- msg = f"Response: {e.response.content.decode()} Status Code:
{e.response.status_code}"
- raise AirflowException(msg)
-
token = response.json()["access_token"]
expires_in = int(response.json()["expires_in"])
@@ -531,6 +544,39 @@ class SnowflakeHook(DbApiHook):
return token
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(multiplier=1, min=0, max=10),
+ retry=tenacity.retry_if_exception(_is_retryable_oauth_error),
+ reraise=True,
+ )
+ def _request_oauth_token(
+ self,
+ *,
+ url: str,
+ data: dict[str, Any],
+ client_id: str,
+ client_secret: str,
+ ):
+ """
+ Execute a single OAuth token request.
+
+ Performs one HTTP call and raises ``HTTPError`` for 4xx and 5xx
responses.
+ Retries are applied by the tenacity decorator on this method.
+ """
+ response = requests.post(
+ url,
+ data=data,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ auth=HTTPBasicAuth(client_id, client_secret),
+ timeout=OAUTH_REQUEST_TIMEOUT,
+ )
+
+ # Raise HTTPError for non-success responses so retry logic can decide
+ # whether the failure is retryable.
+ response.raise_for_status()
+ return response
+
def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
conn_params = self._get_conn_params()
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index c28fd895d12..da10a1efd72 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -30,6 +30,7 @@ import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
+from requests.exceptions import ConnectionError, HTTPError
from airflow.models import Connection
from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
@@ -1131,6 +1132,111 @@ class TestPytestSnowflakeHook:
timeout=30,
)
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake.timezone.utcnow")
+ @mock.patch("requests.post")
+ def test_get_oauth_token_retries_and_succeeds(self, requests_post,
mock_timezone_utcnow):
+
+ # Freeze time to prevent access token expiration.
+ t0 = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
+ mock_timezone_utcnow.side_effect = [t0, t0]
+
+ requests_post.side_effect = [
+ ConnectionError("temporary network error"),
+ Mock(
+ status_code=200,
+ json=lambda: {"access_token": "retry_token", "expires_in":
600},
+ raise_for_status=lambda: None,
+ ),
+ ]
+
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "client_id",
+ "password": "client_secret",
+ "extra": {
+ "account": "airflow",
+ "authenticator": "oauth",
+ "grant_type": "refresh_token",
+ "refresh_token": "secret_token",
+ },
+ }
+
+ with mock.patch.dict(
+ "os.environ",
+ {"AIRFLOW_CONN_TEST_CONN":
Connection(**connection_kwargs).get_uri()},
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ token = hook.get_oauth_token()
+
+ # Should retry.
+ assert token == "retry_token"
+ assert requests_post.call_count == 2
+
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake.requests.post")
+ def test_get_oauth_token_does_not_retry_on_client_error(self,
requests_post):
+
+ response = Mock(status_code=401)
+ http_error = HTTPError(response=response)
+
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = http_error
+
+ requests_post.return_value = mock_response
+
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "client_id",
+ "password": "client_secret",
+ "extra": {
+ "account": "airflow",
+ "authenticator": "oauth",
+ "grant_type": "refresh_token",
+ "refresh_token": "secret_token",
+ },
+ }
+
+ with mock.patch.dict(
+ "os.environ",
+ {"AIRFLOW_CONN_TEST_CONN":
Connection(**connection_kwargs).get_uri()},
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+
+ with pytest.raises(HTTPError):
+ hook.get_oauth_token()
+
+ # Should not retry.
+ assert requests_post.call_count == 1
+
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake.requests.post")
+ def test_get_oauth_token_fails_after_max_retries(self, requests_post):
+
+ # Always fail with retryable error.
+ requests_post.side_effect = ConnectionError("persistent network
failure")
+
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "client_id",
+ "password": "client_secret",
+ "extra": {
+ "account": "airflow",
+ "authenticator": "oauth",
+ "grant_type": "refresh_token",
+ "refresh_token": "secret_token",
+ },
+ }
+
+ with mock.patch.dict(
+ "os.environ",
+ {"AIRFLOW_CONN_TEST_CONN":
Connection(**connection_kwargs).get_uri()},
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+
+ with pytest.raises(ConnectionError):
+ hook.get_oauth_token()
+
+ # Stop after the third attempt.
+ assert requests_post.call_count == 3
+
def test_get_azure_oauth_token(self, mocker):
"""Test get_azure_oauth_token method gets token from provided
connection id"""
azure_conn_id = "azure_test_conn"