This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e9b9edd76a snowflake: pass through the ocsp_fail_open setting (#46476)
7e9b9edd76a is described below

commit 7e9b9edd76a9535fd6137ab33142f2c735eb4d4b
Author: Joe Crobak <joec...@gmail.com>
AuthorDate: Wed Feb 5 17:02:01 2025 -0500

    snowflake: pass through the ocsp_fail_open setting (#46476)
    
    For sqlalchemy, this needs to be passed in as part of connect_args.
---
 providers/snowflake/docs/connections/snowflake.rst |  1 +
 .../airflow/providers/snowflake/hooks/snowflake.py | 10 +++
 .../snowflake/hooks/test_snowflake.py              | 71 ++++++++++++++++++++++
 3 files changed, 82 insertions(+)

diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst
index 741d73a62e3..2d7076d120f 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -64,6 +64,7 @@ Extra (optional)
     * ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community <https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_.
     * ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack).
     * ``port``: Target Snowflake port to connect to (e.g., for local testing with LocalStack).
+    * ``ocsp_fail_open``: Specify `ocsp_fail_open <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#label-python-ocsp-choosing-fail-open-or-fail-close-mode>`_.
 
 URI format example
 ^^^^^^^^^^^^^^^^^^
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 45e12666b88..5777968b8d8 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -299,6 +299,12 @@ class SnowflakeHook(DbApiHook):
         if snowflake_port:
             conn_config["port"] = snowflake_port
 
+        # if a value for ocsp_fail_open is set, pass it along.
+        # Note the check is for `is not None` so that we can pass along `False` as a value.
+        ocsp_fail_open = extra_dict.get("ocsp_fail_open")
+        if ocsp_fail_open is not None:
+            conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)
+
         return conn_config
 
     def get_uri(self) -> str:
@@ -320,6 +326,7 @@ class SnowflakeHook(DbApiHook):
                     "client_request_mfa_token",
                     "client_store_temporary_credential",
                     "json_result_force_utf8_decoding",
+                    "ocsp_fail_open",
                 ]
             }
         )
@@ -345,6 +352,9 @@ class SnowflakeHook(DbApiHook):
         if "json_result_force_utf8_decoding" in conn_params:
             engine_kwargs.setdefault("connect_args", {})
             engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True
+        if "ocsp_fail_open" in conn_params:
+            engine_kwargs.setdefault("connect_args", {})
+            engine_kwargs["connect_args"]["ocsp_fail_open"] = conn_params["ocsp_fail_open"]
         for key in ["session_parameters", "private_key"]:
             if conn_params.get(key):
                 engine_kwargs.setdefault("connect_args", {})
diff --git a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
index 775e9382729..b1a65b4293b 100644
--- a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
@@ -277,6 +277,60 @@ class TestPytestSnowflakeHook:
                     "json_result_force_utf8_decoding": True,
                 },
             ),
+            (
+                {
+                    **BASE_CONNECTION_KWARGS,
+                    "extra": {
+                        **BASE_CONNECTION_KWARGS["extra"],
+                        "ocsp_fail_open": True,
+                    },
+                },
+                (
+                    "snowflake://user:pw@airflow.af_region/db/public?"
+                    "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh"
+                ),
+                {
+                    "account": "airflow",
+                    "application": "AIRFLOW",
+                    "authenticator": "snowflake",
+                    "database": "db",
+                    "password": "pw",
+                    "region": "af_region",
+                    "role": "af_role",
+                    "schema": "public",
+                    "session_parameters": None,
+                    "user": "user",
+                    "warehouse": "af_wh",
+                    "ocsp_fail_open": True,
+                },
+            ),
+            (
+                {
+                    **BASE_CONNECTION_KWARGS,
+                    "extra": {
+                        **BASE_CONNECTION_KWARGS["extra"],
+                        "ocsp_fail_open": False,
+                    },
+                },
+                (
+                    "snowflake://user:pw@airflow.af_region/db/public?"
+                    "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh"
+                ),
+                {
+                    "account": "airflow",
+                    "application": "AIRFLOW",
+                    "authenticator": "snowflake",
+                    "database": "db",
+                    "password": "pw",
+                    "region": "af_region",
+                    "role": "af_role",
+                    "schema": "public",
+                    "session_parameters": None,
+                    "user": "user",
+                    "warehouse": "af_wh",
+                    "ocsp_fail_open": False,
+                },
+            ),
         ],
     )
     def test_hook_should_support_prepare_basic_conn_params_and_uri(
@@ -530,6 +584,23 @@ class TestPytestSnowflakeHook:
             assert "private_key" in mock_create_engine.call_args.kwargs["connect_args"]
             assert mock_create_engine.return_value == conn
 
+    def test_get_sqlalchemy_engine_should_support_ocsp_fail_open(self):
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["ocsp_fail_open"] = "False"
+
+        with (
+            mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+            mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine,
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn = hook.get_sqlalchemy_engine()
+            mock_create_engine.assert_called_once_with(
+                "snowflake://user:pw@airflow.af_region/db/public"
+                "?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh",
+                connect_args={"ocsp_fail_open": False},
+            )
+            assert mock_create_engine.return_value == conn
+
     def test_hook_parameters_should_take_precedence(self):
         with mock.patch.dict(
             "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()

Reply via email to