This is an automated email from the ASF dual-hosted git repository.
bugraoz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6cf9d0b0269 Adding OAuth support for SnowflakeHook (#47191)
6cf9d0b0269 is described below
commit 6cf9d0b0269c0ca02c3453160ad02301c8fdd9fa
Author: Sharashchandra Desai <[email protected]>
AuthorDate: Sat Apr 19 03:13:52 2025 +0530
Adding OAuth support for SnowflakeHook (#47191)
* feat(snowflake_hook): Adding oauth support for SnowflakeHook
* refactor(snowflake_hook): adding the access_token directly to the
conn_config dict instead of via variable
* docs(snowflake_connection): Updating the docs to include changes for OAuth
---
providers/snowflake/docs/connections/snowflake.rst | 3 +-
.../airflow/providers/snowflake/hooks/snowflake.py | 29 +++++++++
.../tests/unit/snowflake/hooks/test_snowflake.py | 74 ++++++++++++++++++++++
3 files changed, 105 insertions(+), 1 deletion(-)
diff --git a/providers/snowflake/docs/connections/snowflake.rst
b/providers/snowflake/docs/connections/snowflake.rst
index 2d7076d120f..903bc17c5da 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -39,10 +39,11 @@ Configuring the Connection
--------------------------
Login
- Specify the snowflake username.
+ Specify the snowflake username. For OAuth, the OAuth Client ID.
Password
Specify the snowflake password. For public key authentication, the
passphrase for the private key.
+ For OAuth, the OAuth Client Secret; the OAuth refresh token itself is supplied in the ``refresh_token`` extra field.
Schema (optional)
Specify the snowflake schema to be used.
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 088c7177171..9cbb5542491 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -26,8 +26,10 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from urllib.parse import urlparse
+import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
+from requests.auth import HTTPBasicAuth
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
@@ -185,6 +187,30 @@ class SnowflakeHook(DbApiHook):
return extra_dict[field_name] or None
return extra_dict.get(backcompat_key) or None
+    def get_oauth_token(self, conn_config: dict) -> str:
+        """
+        Exchange the configured OAuth refresh token for a temporary Snowflake access token.
+
+        A ``refresh_token`` grant is POSTed to the account's ``/oauth/token-request`` endpoint,
+        authenticated with HTTP Basic auth built from the OAuth client ID and client secret.
+
+        :param conn_config: Connection parameters. Must contain ``account``, ``refresh_token``,
+            ``client_id`` and ``client_secret``; ``redirect_uri`` is optional and defaults to
+            ``https://localhost.com``.
+        :return: The ``access_token`` field of the token endpoint's JSON response.
+        :raises AirflowException: If the token endpoint responds with an HTTP error status.
+        """
+        url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+        data = {
+            "grant_type": "refresh_token",
+            "refresh_token": conn_config["refresh_token"],
+            "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
+        }
+        # NOTE(review): no ``timeout`` is passed, so this call can block indefinitely if the
+        # token endpoint stops responding; consider adding one.
+        response = requests.post(
+            url,
+            data=data,
+            headers={
+                "Content-Type": "application/x-www-form-urlencoded",
+            },
+            auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]),  # type: ignore[arg-type]
+        )
+
+        try:
+            response.raise_for_status()
+        except requests.exceptions.HTTPError as e:  # pragma: no cover
+            # ``.text`` decodes with replacement characters, so an error body that is not valid
+            # UTF-8 cannot hide the HTTP failure behind a UnicodeDecodeError.
+            msg = f"Response: {e.response.text} Status Code: {e.response.status_code}"
+            raise AirflowException(msg) from e
+        return response.json()["access_token"]
+
@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
"""
@@ -289,8 +315,11 @@ class SnowflakeHook(DbApiHook):
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config.pop("login", None)
+ conn_config.pop("user", None)
conn_config.pop("password", None)
+ conn_config["token"] =
self.get_oauth_token(conn_config=conn_config)
+
# configure custom target hostname and port, if specified
snowflake_host = extra_dict.get("host")
snowflake_port = extra_dict.get("port")
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 2897921dc0c..b682e2e8511 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -22,6 +22,7 @@ import sys
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from unittest import mock
+from unittest.mock import Mock, PropertyMock
import pytest
from cryptography.hazmat.backends import default_backend
@@ -51,6 +52,21 @@ BASE_CONNECTION_KWARGS: dict = {
},
}
+# Connection-parameter dict for an OAuth-style connection, handed directly to
+# ``SnowflakeHook.get_oauth_token`` in tests. It carries client_id/client_secret
+# (in place of user/password) and the refresh_token that gets exchanged for an access token.
+CONN_PARAMS_OAUTH: dict = {
+    "account": "airflow",
+    "application": "AIRFLOW",
+    "authenticator": "oauth",
+    "database": "db",
+    "client_id": "test_client_id",
+    "client_secret": "test_client_pw",
+    "refresh_token": "secrettoken",
+    "region": "af_region",
+    "role": "af_role",
+    "schema": "public",
+    "session_parameters": None,
+    "warehouse": "af_wh",
+}
+
@pytest.fixture
def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
@@ -483,6 +499,39 @@ class TestPytestSnowflakeHook:
):
SnowflakeHook(snowflake_conn_id="test_conn").get_conn()
+    @mock.patch("requests.post")
+    def test_get_conn_params_should_support_oauth(self, requests_post):
+        """A connection with a refresh_token extra is configured for OAuth with a fetched access token."""
+        requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {
+                "access_token": "supersecretaccesstoken",
+                "expires_in": 600,
+                "refresh_token": "secrettoken",
+                "token_type": "Bearer",
+                "username": "test_user",
+            },
+        )
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "test_client_id",
+            "password": "test_client_secret",
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "refresh_token": "secrettoken",
+            },
+        }
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            # Read the cached property once; every assertion below inspects the same dict.
+            conn_params = hook._get_conn_params
+
+            # Login/password are carried over as the OAuth client credentials, not as user/password.
+            assert "user" not in conn_params
+            assert "password" not in conn_params
+            assert conn_params["client_id"] == "test_client_id"
+            assert conn_params["client_secret"] == "test_client_secret"
+            assert conn_params["refresh_token"] == "secrettoken"
+            assert conn_params["authenticator"] == "oauth"
+            # The token handed to the connector is the one returned by the (mocked) token endpoint.
+            assert conn_params["token"] == "supersecretaccesstoken"
+            # _get_conn_params is cached, so the token endpoint is hit exactly once.
+            requests_post.assert_called_once()
+
def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
@@ -816,3 +865,28 @@ class TestPytestSnowflakeHook:
"airflow_provider_version": provider_version,
}
)
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
+    @mock.patch("requests.post")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
+    def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
+        """get_oauth_token posts a refresh_token grant and returns the issued access token."""
+        basic_auth = {"Authorization": "Basic usernamepassword"}
+        mock_conn_param.return_value = CONN_PARAMS_OAUTH
+        requests_post.return_value.status_code = 200
+        requests_post.return_value.json.return_value = {"access_token": "supersecretaccesstoken"}
+        mock_auth.return_value = basic_auth
+
+        hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+        token = hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH)
+
+        assert token == "supersecretaccesstoken"
+        # Basic auth is built from the OAuth client credentials in the passed conn_config.
+        mock_auth.assert_called_once_with(CONN_PARAMS_OAUTH["client_id"], CONN_PARAMS_OAUTH["client_secret"])
+        requests_post.assert_called_once_with(
+            f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request",
+            data={
+                "grant_type": "refresh_token",
+                "refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
+                # CONN_PARAMS_OAUTH has no redirect_uri, so the hook's default is used.
+                "redirect_uri": "https://localhost.com",
+            },
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            auth=basic_auth,
+        )