This is an automated email from the ASF dual-hosted git repository.

bugraoz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6cf9d0b0269 Adding OAuth support for SnowflakeHook  (#47191)
6cf9d0b0269 is described below

commit 6cf9d0b0269c0ca02c3453160ad02301c8fdd9fa
Author: Sharashchandra Desai <[email protected]>
AuthorDate: Sat Apr 19 03:13:52 2025 +0530

    Adding OAuth support for SnowflakeHook  (#47191)
    
    * feat(snowflake_hook): Adding oauth support for SnowflakeHook
    
    * refactor(snowflake_hook): adding the access_token directly to the conn_config dict instead of via variable
    
    * docs(snowflake_connection): Updating the docs to include changes for OAuth
---
 providers/snowflake/docs/connections/snowflake.rst |  3 +-
 .../airflow/providers/snowflake/hooks/snowflake.py | 29 +++++++++
 .../tests/unit/snowflake/hooks/test_snowflake.py   | 74 ++++++++++++++++++++++
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst
index 2d7076d120f..903bc17c5da 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -39,10 +39,11 @@ Configuring the Connection
 --------------------------
 
 Login
-    Specify the snowflake username.
+    Specify the snowflake username. For OAuth, the OAuth Client ID.
 
 Password
     Specify the snowflake password. For public key authentication, the passphrase for the private key.
+    For OAuth, the OAuth Client Secret.
 
 Schema (optional)
     Specify the snowflake schema to be used.
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 088c7177171..9cbb5542491 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -26,8 +26,10 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
 from urllib.parse import urlparse
 
+import requests
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
+from requests.auth import HTTPBasicAuth
 from snowflake import connector
 from snowflake.connector import DictCursor, SnowflakeConnection, util_text
 from snowflake.sqlalchemy import URL
@@ -185,6 +187,30 @@ class SnowflakeHook(DbApiHook):
             return extra_dict[field_name] or None
         return extra_dict.get(backcompat_key) or None
 
+    def get_oauth_token(self, conn_config: dict) -> str:
+        """Generate temporary OAuth access token using refresh token in connection details."""
+        url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+        data = {
+            "grant_type": "refresh_token",
+            "refresh_token": conn_config["refresh_token"],
+            "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
+        }
+        response = requests.post(
+            url,
+            data=data,
+            headers={
+                "Content-Type": "application/x-www-form-urlencoded",
+            },
+            auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]),  # type: ignore[arg-type]
+        )
+
+        try:
+            response.raise_for_status()
+        except requests.exceptions.HTTPError as e:  # pragma: no cover
+            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
+            raise AirflowException(msg)
+        return response.json()["access_token"]
+
     @cached_property
     def _get_conn_params(self) -> dict[str, str | None]:
         """
@@ -289,8 +315,11 @@ class SnowflakeHook(DbApiHook):
             conn_config["client_id"] = conn.login
             conn_config["client_secret"] = conn.password
             conn_config.pop("login", None)
+            conn_config.pop("user", None)
             conn_config.pop("password", None)
 
+            conn_config["token"] = self.get_oauth_token(conn_config=conn_config)
+
         # configure custom target hostname and port, if specified
         snowflake_host = extra_dict.get("host")
         snowflake_port = extra_dict.get("port")
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 2897921dc0c..b682e2e8511 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -22,6 +22,7 @@ import sys
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any
 from unittest import mock
+from unittest.mock import Mock, PropertyMock
 
 import pytest
 from cryptography.hazmat.backends import default_backend
@@ -51,6 +52,21 @@ BASE_CONNECTION_KWARGS: dict = {
     },
 }
 
+CONN_PARAMS_OAUTH = {
+    "account": "airflow",
+    "application": "AIRFLOW",
+    "authenticator": "oauth",
+    "database": "db",
+    "client_id": "test_client_id",
+    "client_secret": "test_client_pw",
+    "refresh_token": "secrettoken",
+    "region": "af_region",
+    "role": "af_role",
+    "schema": "public",
+    "session_parameters": None,
+    "warehouse": "af_wh",
+}
+
 
 @pytest.fixture
 def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
@@ -483,6 +499,39 @@ class TestPytestSnowflakeHook:
         ):
             SnowflakeHook(snowflake_conn_id="test_conn").get_conn()
 
+    @mock.patch("requests.post")
+    def test_get_conn_params_should_support_oauth(self, requests_post):
+        requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {
+                "access_token": "supersecretaccesstoken",
+                "expires_in": 600,
+                "refresh_token": "secrettoken",
+                "token_type": "Bearer",
+                "username": "test_user",
+            },
+        )
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "test_client_id",
+            "password": "test_client_secret",
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "refresh_token": "secrettoken",
+            },
+        }
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            assert "user" not in hook._get_conn_params
+            assert "password" not in hook._get_conn_params
+            assert "refresh_token" in hook._get_conn_params
+            assert "token" in hook._get_conn_params
+            assert hook._get_conn_params["authenticator"] == "oauth"
+
     def test_should_add_partner_info(self):
         with mock.patch.dict(
             "os.environ",
@@ -816,3 +865,28 @@ class TestPytestSnowflakeHook:
                     "airflow_provider_version": provider_version,
                 }
             )
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
+    @mock.patch("requests.post")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
+    def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
+        """Test get_oauth_token method makes the right http request"""
+        BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
+        mock_conn_param.return_value = CONN_PARAMS_OAUTH
+        requests_post.return_value.status_code = 200
+        mock_auth.return_value = BASIC_AUTH
+        hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+        hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH)
+        requests_post.assert_called_once_with(
+            f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request",
+            data={
+                "grant_type": "refresh_token",
+                "refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
+                "redirect_uri": "https://localhost.com",
+            },
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            auth=BASIC_AUTH,
+        )

Reply via email to