This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 7e9b9edd76a snowflake: pass through the ocsp_fail_open setting (#46476) 7e9b9edd76a is described below commit 7e9b9edd76a9535fd6137ab33142f2c735eb4d4b Author: Joe Crobak <joec...@gmail.com> AuthorDate: Wed Feb 5 17:02:01 2025 -0500 snowflake: pass through the ocsp_fail_open setting (#46476) For sqlalchemy, this needs to be passed in as part of connect_args. --- providers/snowflake/docs/connections/snowflake.rst | 1 + .../airflow/providers/snowflake/hooks/snowflake.py | 10 +++ .../snowflake/hooks/test_snowflake.py | 71 ++++++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 741d73a62e3..2d7076d120f 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -64,6 +64,7 @@ Extra (optional) * ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community <https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_. * ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack). * ``port``: Target Snowflake port to connect to (e.g., for local testing with LocalStack). + * ``ocsp_fail_open``: Set to ``True`` to use OCSP fail-open mode or ``False`` to use fail-close mode; if omitted, the Snowflake connector's own default applies. For details, see: `Choosing fail-open or fail-close mode <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#label-python-ocsp-choosing-fail-open-or-fail-close-mode>`_.
URI format example ^^^^^^^^^^^^^^^^^^ diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 45e12666b88..5777968b8d8 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -299,6 +299,12 @@ class SnowflakeHook(DbApiHook): if snowflake_port: conn_config["port"] = snowflake_port + # if a value for ocsp_fail_open is set, pass it along. + # Note the check is for `is not None` so that we can pass along `False` as a value. + ocsp_fail_open = extra_dict.get("ocsp_fail_open") + if ocsp_fail_open is not None: + conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open) + return conn_config def get_uri(self) -> str: @@ -320,6 +326,7 @@ class SnowflakeHook(DbApiHook): "client_request_mfa_token", "client_store_temporary_credential", "json_result_force_utf8_decoding", + "ocsp_fail_open", ] } ) @@ -345,6 +352,9 @@ class SnowflakeHook(DbApiHook): if "json_result_force_utf8_decoding" in conn_params: engine_kwargs.setdefault("connect_args", {}) engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True + if "ocsp_fail_open" in conn_params: + engine_kwargs.setdefault("connect_args", {}) + engine_kwargs["connect_args"]["ocsp_fail_open"] = conn_params["ocsp_fail_open"] for key in ["session_parameters", "private_key"]: if conn_params.get(key): engine_kwargs.setdefault("connect_args", {}) diff --git a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py index 775e9382729..b1a65b4293b 100644 --- a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py @@ -277,6 +277,60 @@ class TestPytestSnowflakeHook: "json_result_force_utf8_decoding": True, }, ), + ( + { + 
**BASE_CONNECTION_KWARGS, + "extra": { + **BASE_CONNECTION_KWARGS["extra"], + "ocsp_fail_open": True, + }, + }, + ( + "snowflake://user:pw@airflow.af_region/db/public?" + "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh" + ), + { + "account": "airflow", + "application": "AIRFLOW", + "authenticator": "snowflake", + "database": "db", + "password": "pw", + "region": "af_region", + "role": "af_role", + "schema": "public", + "session_parameters": None, + "user": "user", + "warehouse": "af_wh", + "ocsp_fail_open": True, + }, + ), + ( + { + **BASE_CONNECTION_KWARGS, + "extra": { + **BASE_CONNECTION_KWARGS["extra"], + "ocsp_fail_open": False, + }, + }, + ( + "snowflake://user:pw@airflow.af_region/db/public?" + "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh" + ), + { + "account": "airflow", + "application": "AIRFLOW", + "authenticator": "snowflake", + "database": "db", + "password": "pw", + "region": "af_region", + "role": "af_role", + "schema": "public", + "session_parameters": None, + "user": "user", + "warehouse": "af_wh", + "ocsp_fail_open": False, + }, + ), ], ) def test_hook_should_support_prepare_basic_conn_params_and_uri( @@ -530,6 +584,23 @@ class TestPytestSnowflakeHook: assert "private_key" in mock_create_engine.call_args.kwargs["connect_args"] assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_ocsp_fail_open(self): + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["ocsp_fail_open"] = "False" + + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine, + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + "snowflake://user:pw@airflow.af_region/db/public" + 
"?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh", + connect_args={"ocsp_fail_open": False}, + ) + assert mock_create_engine.return_value == conn + def test_hook_parameters_should_take_precedence(self): with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()