This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0edbe91368 Restrict direct usage of driver params via extras for JDBC 
connection (#31849)
0edbe91368 is described below

commit 0edbe913685e6a21905bc4bb52c6a084bdcdf953
Author: Pankaj Koti <[email protected]>
AuthorDate: Mon Jun 12 23:59:17 2023 +0530

    Restrict direct usage of driver params via extras for JDBC connection 
(#31849)
---
 airflow/providers/jdbc/CHANGELOG.rst               |  14 +++
 airflow/providers/jdbc/hooks/jdbc.py               | 111 ++++++++++++-----
 airflow/providers/jdbc/provider.yaml               |   1 +
 .../connections/jdbc.rst                           |  19 ++-
 tests/providers/jdbc/hooks/test_jdbc.py            | 132 ++++++++++++++-------
 5 files changed, 203 insertions(+), 74 deletions(-)

diff --git a/airflow/providers/jdbc/CHANGELOG.rst 
b/airflow/providers/jdbc/CHANGELOG.rst
index 0bbabfb0e6..9e2a1534ac 100644
--- a/airflow/providers/jdbc/CHANGELOG.rst
+++ b/airflow/providers/jdbc/CHANGELOG.rst
@@ -24,6 +24,20 @@
 Changelog
 ---------
 
+4.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+Driver parameters (driver path and driver class) set in the connection extra are 
no longer used by default. To configure them, you can use one of the following methods:
+
+1. Supply them as constructor arguments when instantiating the hook.
+2. Set the "driver_path" and/or "driver_class" parameters in the "hook_params" 
dictionary when creating the hook using SQL operators.
+3. Set the "driver_path" and/or "driver_class" extra in the connection and 
correspondingly enable the "allow_driver_path_in_extra" and/or 
"allow_driver_class_in_extra" options in the "providers.jdbc" section of the 
Airflow configuration.
+4. Patch the "JdbcHook.default_driver_path" and/or 
"JdbcHook.default_driver_class" values in the "local_settings.py" file.
+
+
 3.4.0
 .....
 
diff --git a/airflow/providers/jdbc/hooks/jdbc.py 
b/airflow/providers/jdbc/hooks/jdbc.py
index 1d9a3425d6..0a1656abd4 100644
--- a/airflow/providers/jdbc/hooks/jdbc.py
+++ b/airflow/providers/jdbc/hooks/jdbc.py
@@ -31,6 +31,25 @@ class JdbcHook(DbApiHook):
     JDBC URL, username and password will be taken from the predefined 
connection.
     Note that the whole JDBC URL must be specified in the "host" field in the 
DB.
     Raises an airflow error if the given connection id doesn't exist.
+
+    To configure driver parameters, you can use the following methods:
+        1. Supply them as constructor arguments when instantiating the hook.
+        2. Set the "driver_path" and/or "driver_class" parameters in the 
"hook_params" dictionary when
+           creating the hook using SQL operators.
+        3. Set the "driver_path" and/or "driver_class" extra in the connection 
and correspondingly enable
+           the "allow_driver_path_in_extra" and/or 
"allow_driver_class_in_extra" options in the
+           "providers.jdbc" section of the Airflow configuration. If you're 
enabling these options in Airflow
+           configuration, you should make sure that you trust the users who 
can edit connections in the UI
+           to not use it maliciously.
+        4. Patch the ``JdbcHook.default_driver_path`` and/or 
``JdbcHook.default_driver_class`` values in the
+           "local_settings.py" file.
+
+    See :doc:`/connections/jdbc` for full documentation.
+
+    :param args: passed to DbApiHook
+    :param driver_path: path to the JDBC driver jar file, or a comma-separated 
list of jar paths. See above for more info
+    :param driver_class: name of the JDBC driver class. See above for more info
+    :param kwargs: passed to DbApiHook
     """
 
     conn_name_attr = "jdbc_conn_id"
@@ -39,57 +58,89 @@ class JdbcHook(DbApiHook):
     hook_name = "JDBC Connection"
     supports_autocommit = True
 
-    @staticmethod
-    def get_connection_form_widgets() -> dict[str, Any]:
-        """Get connection widgets to add to connection form."""
-        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
-        from flask_babel import lazy_gettext
-        from wtforms import StringField
+    default_driver_path: str | None = None
+    default_driver_class: str | None = None
 
-        return {
-            "drv_path": StringField(lazy_gettext("Driver Path"), 
widget=BS3TextFieldWidget()),
-            "drv_clsname": StringField(lazy_gettext("Driver Class"), 
widget=BS3TextFieldWidget()),
-        }
+    def __init__(
+        self,
+        *args,
+        driver_path: str | None = None,
+        driver_class: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self._driver_path = driver_path
+        self._driver_class = driver_class
 
     @staticmethod
     def get_ui_field_behaviour() -> dict[str, Any]:
         """Get custom field behaviour."""
         return {
-            "hidden_fields": ["port", "schema", "extra"],
+            "hidden_fields": ["port", "schema"],
             "relabeling": {"host": "Connection URL"},
         }
 
-    def _get_field(self, extras: dict, field_name: str):
-        """Get field from extra.
+    @property
+    def connection_extra_lower(self) -> dict:
+        """
+        ``connection.extra_dejson`` but where keys are converted to lower case.
 
-        This first checks the short name, then check for prefixed name for
-        backward compatibility.
+        This is used internally for case-insensitive access of jdbc params.
         """
-        backcompat_prefix = "extra__jdbc__"
-        if field_name.startswith("extra__"):
-            raise ValueError(
-                f"Got prefixed name {field_name}; please remove the 
'{backcompat_prefix}' prefix "
-                "when using this method."
-            )
-        if field_name in extras:
-            return extras[field_name] or None
-        prefixed_name = f"{backcompat_prefix}{field_name}"
-        return extras.get(prefixed_name) or None
+        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        return {k.lower(): v for k, v in conn.extra_dejson.items()}
+
+    @property
+    def driver_path(self) -> str | None:
+        from airflow.configuration import conf
+
+        extra_driver_path = self.connection_extra_lower.get("driver_path")
+        if extra_driver_path:
+            if conf.getboolean("providers.jdbc", "allow_driver_path_in_extra", 
fallback=False):
+                self._driver_path = extra_driver_path
+            else:
+                self.log.warning(
+                    "You have supplied 'driver_path' via connection extra but 
it will not be used. In order "
+                    "to use 'driver_path' from extra you must set airflow 
config setting "
+                    "`allow_driver_path_in_extra = True` in section 
`providers.jdbc`. Alternatively you may "
+                    "specify it via 'driver_path' parameter of the hook 
constructor or via 'hook_params' "
+                    "dictionary with key 'driver_path' if using SQL operators."
+                )
+        if not self._driver_path:
+            self._driver_path = self.default_driver_path
+        return self._driver_path
+
+    @property
+    def driver_class(self) -> str | None:
+        from airflow.configuration import conf
+
+        extra_driver_class = self.connection_extra_lower.get("driver_class")
+        if extra_driver_class:
+            if conf.getboolean("providers.jdbc", 
"allow_driver_class_in_extra", fallback=False):
+                self._driver_class = extra_driver_class
+            else:
+                self.log.warning(
+                    "You have supplied 'driver_class' via connection extra but 
it will not be used. In order "
+                    "to use 'driver_class' from extra you must set airflow 
config setting "
+                    "`allow_driver_class_in_extra = True` in section 
`providers.jdbc`. Alternatively you may "
+                    "specify it via 'driver_class' parameter of the hook 
constructor or via 'hook_params' "
+                    "dictionary with key 'driver_class' if using SQL 
operators."
+                )
+        if not self._driver_class:
+            self._driver_class = self.default_driver_class
+        return self._driver_class
 
     def get_conn(self) -> jaydebeapi.Connection:
         conn: Connection = self.get_connection(getattr(self, 
self.conn_name_attr))
-        extras = conn.extra_dejson
         host: str = conn.host
         login: str = conn.login
         psw: str = conn.password
-        jdbc_driver_loc: str | None = self._get_field(extras, "drv_path")
-        jdbc_driver_name: str | None = self._get_field(extras, "drv_clsname")
 
         conn = jaydebeapi.connect(
-            jclassname=jdbc_driver_name,
+            jclassname=self.driver_class,
             url=str(host),
             driver_args=[str(login), str(psw)],
-            jars=jdbc_driver_loc.split(",") if jdbc_driver_loc else None,
+            jars=self.driver_path.split(",") if self.driver_path else None,
         )
         return conn
 
diff --git a/airflow/providers/jdbc/provider.yaml 
b/airflow/providers/jdbc/provider.yaml
index f4fc9f88ab..d180e3ac00 100644
--- a/airflow/providers/jdbc/provider.yaml
+++ b/airflow/providers/jdbc/provider.yaml
@@ -23,6 +23,7 @@ description: |
 
 suspended: false
 versions:
+  - 4.0.0
   - 3.4.0
   - 3.3.0
   - 3.2.1
diff --git a/docs/apache-airflow-providers-jdbc/connections/jdbc.rst 
b/docs/apache-airflow-providers-jdbc/connections/jdbc.rst
index a646188212..f00b5f4e18 100644
--- a/docs/apache-airflow-providers-jdbc/connections/jdbc.rst
+++ b/docs/apache-airflow-providers-jdbc/connections/jdbc.rst
@@ -43,6 +43,19 @@ Port (optional)
 Extra (optional)
     Specify the extra parameters (as json dictionary) that can be used in JDBC 
connection. The following parameters out of the standard python parameters are 
supported:
 
-    * ``conn_prefix`` - Used to build the connection url in ``JdbcOperator``, 
added in front of host (``conn_prefix`` ``host`` [: ``port`` ] / ``schema``)
-    * ``drv_clsname`` - Full qualified Java class name of the JDBC driver. For 
``JdbcOperator``.
-    * ``drv_path`` - Jar filename or sequence of filenames for the JDBC driver 
libs. For ``JdbcOperator``.
+    - ``driver_class``
+        * Fully qualified Java class name of the JDBC driver. For 
``JdbcOperator``.
+          Note that this is only considered if ``allow_driver_class_in_extra`` 
is set to True in airflow config section
+          ``providers.jdbc`` (by default it is not considered).  Note: if 
setting this config from env vars, use
+          ``AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_CLASS_IN_EXTRA=true``.
+
+    - ``driver_path``
+        * Jar filename or sequence of filenames for the JDBC driver libs. For 
``JdbcOperator``.
+          Note that this is only considered if ``allow_driver_path_in_extra`` 
is set to True in airflow config section
+          ``providers.jdbc`` (by default it is not considered).  Note: if 
setting this config from env vars, use
+          ``AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_PATH_IN_EXTRA=true``.
+
+    .. note::
+        Setting ``allow_driver_path_in_extra`` or 
``allow_driver_class_in_extra`` to True allows users to set the driver
+        via the Airflow Connection's ``extra`` field.  By default this is not 
allowed.  If enabling this functionality,
+        you should make sure that you trust the users who can edit connections 
in the UI to not use it maliciously.
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py 
b/tests/providers/jdbc/hooks/test_jdbc.py
index 50913b0237..4374a0cad6 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -18,12 +18,10 @@
 from __future__ import annotations
 
 import json
-import os
+import logging
+from unittest import mock
 from unittest.mock import Mock, patch
 
-import pytest
-from pytest import param
-
 from airflow.models import Connection
 from airflow.providers.jdbc.hooks.jdbc import JdbcHook
 from airflow.utils import db
@@ -31,6 +29,22 @@ from airflow.utils import db
 jdbc_conn_mock = Mock(name="jdbc_conn")
 
 
+def get_hook(hook_params=None, conn_params=None):
+    hook_params = hook_params or {}
+    conn_params = conn_params or {}
+    connection = Connection(
+        **{
+            **dict(login="login", password="password", host="host", 
schema="schema", port=1234),
+            **conn_params,
+        }
+    )
+
+    hook = JdbcHook(**hook_params)
+    hook.get_connection = Mock()
+    hook.get_connection.return_value = connection
+    return hook
+
+
 class TestJdbcHook:
     def setup_method(self):
         db.merge_conn(
@@ -41,8 +55,8 @@ class TestJdbcHook:
                 port=443,
                 extra=json.dumps(
                     {
-                        "extra__jdbc__drv_path": 
"/path1/test.jar,/path2/t.jar2",
-                        "extra__jdbc__drv_clsname": "com.driver.main",
+                        "driver_path": "/path1/test.jar,/path2/t.jar2",
+                        "driver_class": "com.driver.main",
                     }
                 ),
             )
@@ -70,39 +84,75 @@ class TestJdbcHook:
         jdbc_hook.get_autocommit(jdbc_conn)
         jdbc_conn.jconn.getAutoCommit.assert_called_once_with()
 
-    @pytest.mark.parametrize(
-        "uri",
-        [
-            param(
-                "a://?extra__jdbc__drv_path=abc&extra__jdbc__drv_clsname=abc",
-                id="prefix",
-            ),
-            param("a://?drv_path=abc&drv_clsname=abc", id="no-prefix"),
-        ],
-    )
-    @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
-    def test_backcompat_prefix_works(self, mock_connect, uri):
-        with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
-            hook = JdbcHook("my_conn")
-            hook.get_conn()
-            mock_connect.assert_called_with(
-                jclassname="abc",
-                url="",
-                driver_args=["None", "None"],
-                jars="abc".split(","),
-            )
+    def test_driver_hook_params(self):
+        hook = get_hook(hook_params=dict(driver_path="Blah driver path", 
driver_class="Blah driver class"))
+        assert hook.driver_path == "Blah driver path"
+        assert hook.driver_class == "Blah driver class"
 
-    @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
-    def test_backcompat_prefix_both_prefers_short(self, mock_connect):
-        with patch.dict(
-            os.environ,
-            {"AIRFLOW_CONN_MY_CONN": 
"a://?drv_path=non-prefixed&extra__jdbc__drv_path=prefixed"},
-        ):
-            hook = JdbcHook("my_conn")
-            hook.get_conn()
-            mock_connect.assert_called_with(
-                jclassname=None,
-                url="",
-                driver_args=["None", "None"],
-                jars="non-prefixed".split(","),
-            )
+    def test_driver_in_extra_not_used(self):
+        conn_params = dict(
+            extra=json.dumps(dict(driver_path="ExtraDriverPath", 
driver_class="ExtraDriverClass"))
+        )
+        hook_params = {"driver_path": "ParamDriverPath", "driver_class": 
"ParamDriverClass"}
+        hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+        assert hook.driver_path == "ParamDriverPath"
+        assert hook.driver_class == "ParamDriverClass"
+
+    def test_driver_extra_raises_warning_by_default(self, caplog):
+        with caplog.at_level(logging.WARNING, 
logger="airflow.providers.jdbc.hooks.test_jdbc"):
+            driver_path = get_hook(conn_params=dict(extra='{"driver_path": 
"Blah driver path"}')).driver_path
+            assert (
+                "You have supplied 'driver_path' via connection extra but it 
will not be used"
+            ) in caplog.text
+            assert driver_path is None
+
+            driver_class = get_hook(
+                conn_params=dict(extra='{"driver_class": "Blah driver class"}')
+            ).driver_class
+            assert (
+                "You have supplied 'driver_class' via connection extra but it 
will not be used"
+            ) in caplog.text
+            assert driver_class is None
+
+    @mock.patch.dict("os.environ", 
{"AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_PATH_IN_EXTRA": "TRUE"})
+    @mock.patch.dict("os.environ", 
{"AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_CLASS_IN_EXTRA": "TRUE"})
+    def test_driver_extra_works_when_allow_driver_extra(self):
+        hook = get_hook(
+            conn_params=dict(extra='{"driver_path": "Blah driver path", 
"driver_class": "Blah driver class"}')
+        )
+        assert hook.driver_path == "Blah driver path"
+        assert hook.driver_class == "Blah driver class"
+
+    def test_default_driver_set(self):
+        with patch.object(JdbcHook, "default_driver_path", "Blah driver path") 
as _, patch.object(
+            JdbcHook, "default_driver_class", "Blah driver class"
+        ) as _:
+            hook = get_hook()
+            assert hook.driver_path == "Blah driver path"
+            assert hook.driver_class == "Blah driver class"
+
+    def test_driver_none_by_default(self):
+        hook = get_hook()
+        assert hook.driver_path is None
+        assert hook.driver_class is None
+
+    def 
test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, 
caplog):
+        with patch.object(JdbcHook, "default_driver_path", "Blah driver path"):
+            with caplog.at_level(logging.WARNING, 
logger="airflow.providers.jdbc.hooks.test_jdbc"):
+                driver_path = get_hook(
+                    conn_params=dict(extra='{"driver_path": "Blah driver 
path2"}')
+                ).driver_path
+                assert (
+                    "have supplied 'driver_path' via connection extra but it 
will not be used"
+                ) in caplog.text
+                assert driver_path == "Blah driver path"
+
+        with patch.object(JdbcHook, "default_driver_class", "Blah driver 
class"):
+            with caplog.at_level(logging.WARNING, 
logger="airflow.providers.jdbc.hooks.test_jdbc"):
+                driver_class = get_hook(
+                    conn_params=dict(extra='{"driver_class": "Blah driver 
class2"}')
+                ).driver_class
+                assert (
+                    "have supplied 'driver_class' via connection extra but it 
will not be used"
+                ) in caplog.text
+                assert driver_class == "Blah driver class"

Reply via email to