This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f6c7388cfa Create SQLAlchemy engine from connection in DB Hook and added autocommit param to insert_rows method (#40669)
f6c7388cfa is described below

commit f6c7388cfa70874d84f312a5859a4f510fef0084
Author: David Blain <i...@dabla.be>
AuthorDate: Fri Jul 26 21:33:27 2024 +0200

    Create SQLAlchemy engine from connection in DB Hook and added autocommit param to insert_rows method (#40669)
    
    * refactor: Refactored get_sqlalchemy_engine method of DbApiHook to use the get_conn result to build the sqlalchemy engine
    
    * refactor: Added autocommit parameter to insert_rows, just like the run method, as this parameter will also be needed once we have the SQLInsertRowsOperator
    
    * refactor: Updated the docstring of the insert_rows method
    
    * refactor: Updated sql.pyi
    
    * refactor: Try to fix AttributeError: type object 'SkipDBTestsSession' has no attribute 'get_bind'
    
    * refactor: Implemented the sqlalchemy_url property for JdbcHook
    
    * refactor: Refactored get_sqlalchemy_engine in DbApiHook: if the hook implements the sqlalchemy_url property, use it; otherwise fall back to the original implementation with get_uri
    
    * refactor: Added SQLAlchemy Inspector property in DbApiHook
    
    * refactor: Reformatted test_sqlalchemy_url_with_sqlalchemy_scheme in TestJdbcHook
    
    * refactor: Fixed static checks in DbApiHook
    
    * refactor: Fixed some static checks
    
    * docs: Updated docstring of JdbcHook and mentioned the importance of the sqlalchemy_scheme parameter
    
    ---------
    
    Co-authored-by: David Blain <david.bl...@infrabel.be>
---
 airflow/providers/common/sql/hooks/sql.py  | 23 ++++++++++++++++++++---
 airflow/providers/common/sql/hooks/sql.pyi |  6 ++++--
 airflow/providers/jdbc/hooks/jdbc.py       | 26 +++++++++++++++++++++++++-
 airflow/settings.py                        | 10 ++++++++++
 tests/providers/jdbc/hooks/test_jdbc.py    | 15 +++++++++++++++
 5 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 4dba2d843a..dc66cc40bd 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -40,6 +40,7 @@ from urllib.parse import urlparse
 import sqlparse
 from more_itertools import chunked
 from sqlalchemy import create_engine
+from sqlalchemy.engine import Inspector
 
 from airflow.exceptions import (
     AirflowException,
@@ -242,7 +243,20 @@ class DbApiHook(BaseHook):
         """
         if engine_kwargs is None:
             engine_kwargs = {}
-        return create_engine(self.get_uri(), **engine_kwargs)
+        engine_kwargs["creator"] = self.get_conn
+
+        try:
+            url = self.sqlalchemy_url
+        except NotImplementedError:
+            url = self.get_uri()
+
+        self.log.debug("url: %s", url)
+        self.log.debug("engine_kwargs: %s", engine_kwargs)
+        return create_engine(url=url, **engine_kwargs)
+
+    @property
+    def inspector(self) -> Inspector:
+        return Inspector.from_engine(self.get_sqlalchemy_engine())
 
     def get_pandas_df(
         self,
@@ -571,6 +585,7 @@ class DbApiHook(BaseHook):
         replace=False,
         *,
         executemany=False,
+        autocommit=False,
         **kwargs,
     ):
         """
@@ -585,12 +600,14 @@ class DbApiHook(BaseHook):
         :param commit_every: The maximum number of rows to insert in one
             transaction. Set to 0 to insert all rows in one transaction.
         :param replace: Whether to replace instead of insert
-        :param executemany: (Deprecated) If True, all rows are inserted at once in
+        :param executemany: If True, all rows are inserted at once in
             chunks defined by the commit_every parameter. This only works if all rows
             have same number of column names, but leads to better performance.
+        :param autocommit: What to set the connection's autocommit setting to
+            before executing the query.
         """
         nb_rows = 0
-        with self._create_autocommit_connection() as conn:
+        with self._create_autocommit_connection(autocommit) as conn:
             conn.commit()
             with closing(conn.cursor()) as cur:
                 if self.supports_executemany or executemany:
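
For context, a minimal sketch of how the reworked get_sqlalchemy_engine and the new inspector property could be exercised. The connection id below is hypothetical, and the sketch assumes the connection's extra defines "sqlalchemy_scheme" so that sqlalchemy_url resolves (see the JdbcHook changes further down):

    from airflow.providers.jdbc.hooks.jdbc import JdbcHook

    # "my_jdbc_conn" is a hypothetical connection id used only for illustration.
    hook = JdbcHook(jdbc_conn_id="my_jdbc_conn")

    # create_engine() now receives creator=hook.get_conn, so SQLAlchemy obtains
    # raw DBAPI connections from the hook itself instead of only from the URI.
    engine = hook.get_sqlalchemy_engine()

    # The new inspector property wraps Inspector.from_engine(get_sqlalchemy_engine()).
    print(hook.inspector.get_table_names())
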
diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi
index 27142aeaf2..625ec1d320 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -42,7 +42,7 @@ from airflow.providers.openlineage.extractors import OperatorLineage as Operator
 from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
 from functools import cached_property as cached_property
 from pandas import DataFrame as DataFrame
-from sqlalchemy.engine import URL as URL
+from sqlalchemy.engine import Inspector, URL as URL
 from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload
 
 T = TypeVar("T")
@@ -64,7 +64,6 @@ class DbApiHook(BaseHook):
     log_sql: Incomplete
     descriptions: Incomplete
     def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...
-
     def get_conn_id(self) -> str: ...
     @cached_property
     def placeholder(self): ...
@@ -73,6 +72,8 @@ class DbApiHook(BaseHook):
     @property
     def sqlalchemy_url(self) -> URL: ...
     def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ...
+    @property
+    def inspector(self) -> Inspector: ...
     def get_pandas_df(
         self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs
     ) -> DataFrame: ...
@@ -123,6 +124,7 @@ class DbApiHook(BaseHook):
         replace: bool = False,
         *,
         executemany: bool = False,
+        autocommit: bool = False,
         **kwargs,
     ): ...
     def bulk_dump(self, table, tmp_file) -> None: ...
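
As a hedged usage sketch of the new autocommit keyword on insert_rows (the connection id, table, columns and rows below are placeholders), the flag is forwarded to the connection the hook opens, mirroring the existing run() parameter:

    from airflow.providers.jdbc.hooks.jdbc import JdbcHook

    # Hypothetical connection id and table; autocommit=True asks the hook to
    # enable autocommit on the connection before inserting, as run() already does.
    hook = JdbcHook(jdbc_conn_id="my_jdbc_conn")
    hook.insert_rows(
        table="my_table",
        rows=[(1, "a"), (2, "b")],
        target_fields=["id", "name"],
        commit_every=1000,
        autocommit=True,
    )
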
diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py
index cf5d2dd47d..81c63cbe3d 100644
--- a/airflow/providers/jdbc/hooks/jdbc.py
+++ b/airflow/providers/jdbc/hooks/jdbc.py
@@ -23,7 +23,9 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any
 
 import jaydebeapi
+from sqlalchemy.engine import URL
 
+from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 if TYPE_CHECKING:
@@ -60,7 +62,12 @@ class JdbcHook(DbApiHook):
            "providers.jdbc" section of the Airflow configuration. If you're enabling these options in Airflow
            configuration, you should make sure that you trust the users who can edit connections in the UI
            to not use it maliciously.
-        4. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the
+        4. Define the "sqlalchemy_scheme" property in the extra of the connection if you want to use the
+           SQLAlchemy engine from the JdbcHook.  When using the JdbcHook, the "sqlalchemy_scheme" will by
+           default have the "jdbc" value, which is a protocol, not a database scheme or dialect.  So in order
+           to be able to use SQLAlchemy with the JdbcHook, you need to define the "sqlalchemy_scheme"
+           property in the extra of the connection.
+        5. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the
            ``local_settings.py`` file.
 
     See :doc:`/connections/jdbc` for full documentation.
@@ -149,6 +156,23 @@ class JdbcHook(DbApiHook):
             self._driver_class = self.default_driver_class
         return self._driver_class
 
+    @property
+    def sqlalchemy_url(self) -> URL:
+        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
+        if sqlalchemy_scheme is None:
+            raise AirflowException(
+                "The parameter 'sqlalchemy_scheme' must be defined in extra for JDBC connections!"
+            )
+        return URL.create(
+            drivername=sqlalchemy_scheme,
+            username=conn.login,
+            password=conn.password,
+            host=conn.host,
+            port=conn.port,
+            database=conn.schema,
+        )
+
     def get_conn(self) -> jaydebeapi.Connection:
         conn: Connection = self.get_connection(self.get_conn_id())
         host: str = conn.host
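
To illustrate what the new JdbcHook.sqlalchemy_url assembles, here is a small sketch using sqlalchemy.engine.URL directly. The connection values are made up and mirror the test further down, assuming the connection's extra is {"sqlalchemy_scheme": "mssql"}:

    from sqlalchemy.engine import URL

    # Illustrative values only: this is what sqlalchemy_url builds from the
    # Airflow connection fields when extra defines "sqlalchemy_scheme": "mssql".
    url = URL.create(
        drivername="mssql",
        username="login",
        password="password",
        host="host",
        port=1234,
        database="schema",
    )
    assert str(url) == "mssql://login:password@host:1234/schema"
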
diff --git a/airflow/settings.py b/airflow/settings.py
index 6dc9880271..eb4053f50e 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -256,6 +256,16 @@ class SkipDBTestsSession:
     def remove(*args, **kwargs):
         pass
 
+    def get_bind(
+        self,
+        mapper=None,
+        clause=None,
+        bind=None,
+        _sa_skip_events=None,
+        _sa_skip_for_implicit_returning=False,
+    ):
+        pass
+
 
 class TracebackSession:
     """
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py
index 6e4387ee1a..cb38ce40ae 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -25,6 +25,7 @@ from unittest.mock import Mock, patch
 import jaydebeapi
 import pytest
 
+from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.jdbc.hooks.jdbc import JdbcHook, suppress_and_warn
 from airflow.utils import db
@@ -186,3 +187,17 @@ class TestJdbcHook:
         with pytest.raises(RuntimeError, match="Spam Egg"):
             with suppress_and_warn(KeyError):
                 raise RuntimeError("Spam Egg")
+
+    def test_sqlalchemy_url_without_sqlalchemy_scheme(self):
+        hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
+        hook = get_hook(hook_params=hook_params)
+
+        with pytest.raises(AirflowException):
+            hook.sqlalchemy_url
+
+    def test_sqlalchemy_url_with_sqlalchemy_scheme(self):
+        conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql")))
+        hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
+        hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+        assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema"

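Finally, a hedged sketch of how such a connection could be registered so the property is usable at runtime; the conn_id and credential values are hypothetical:

    import json

    from airflow.models import Connection
    from airflow.utils import db

    # Hypothetical JDBC connection whose extra carries "sqlalchemy_scheme",
    # which JdbcHook.sqlalchemy_url requires (it raises AirflowException otherwise).
    db.merge_conn(
        Connection(
            conn_id="my_jdbc_conn",
            conn_type="jdbc",
            host="host",
            login="login",
            password="password",
            port=1234,
            schema="schema",
            extra=json.dumps({"sqlalchemy_scheme": "mssql"}),
        )
    )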