This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new f6c7388cfa Create SQLAlchemy engine from connection in DB Hook and added autocommit param to insert_rows method (#40669) f6c7388cfa is described below commit f6c7388cfa70874d84f312a5859a4f510fef0084 Author: David Blain <i...@dabla.be> AuthorDate: Fri Jul 26 21:33:27 2024 +0200 Create SQLAlchemy engine from connection in DB Hook and added autocommit param to insert_rows method (#40669) * refactor: Refactored get_sqlalchemy_engine method of DbApiHook to use the get_conn result to build the sqlalchemy engine * refactor: Added autocommit parameter to insert_rows just like with the run method as this parameter will also be needed once we have the SQLInsertRowsOperator * refactor: Updated the docstring of the insert_rows method * refactor: Updated sql.pyi * refactor: Try to fix AttributeError: type object 'SkipDBTestsSession' has no attribute 'get_bind' * refactor: Implemented the sqlalchemy_url property for JdbcHook * refactor: Refactored get_sqlalchemy_engine in DbApiHook, if Hook implements the sqlalchemy_url property then use it, otherwise fall back to original implementation with get_uri * refactor: Added SQLAlchemy Inspector property in DbApiHook * refactor: Reformatted test_sqlalchemy_url_with_sqlalchemy_scheme in TestJdbcHook * refactor: Fixed static checks in DbApiHook * refactor: Fixed some static checks * docs: Updated docstring of JdbcHook and mentioned importance of sqlalchemy_scheme parameter --------- Co-authored-by: David Blain <david.bl...@infrabel.be> --- airflow/providers/common/sql/hooks/sql.py | 23 ++++++++++++++++++++--- airflow/providers/common/sql/hooks/sql.pyi | 6 ++++-- airflow/providers/jdbc/hooks/jdbc.py | 26 +++++++++++++++++++++++++- airflow/settings.py | 10 ++++++++++ tests/providers/jdbc/hooks/test_jdbc.py | 15 +++++++++++++++ 5 files changed, 74 insertions(+), 6 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py 
b/airflow/providers/common/sql/hooks/sql.py index 4dba2d843a..dc66cc40bd 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -40,6 +40,7 @@ from urllib.parse import urlparse import sqlparse from more_itertools import chunked from sqlalchemy import create_engine +from sqlalchemy.engine import Inspector from airflow.exceptions import ( AirflowException, @@ -242,7 +243,20 @@ class DbApiHook(BaseHook): """ if engine_kwargs is None: engine_kwargs = {} - return create_engine(self.get_uri(), **engine_kwargs) + engine_kwargs["creator"] = self.get_conn + + try: + url = self.sqlalchemy_url + except NotImplementedError: + url = self.get_uri() + + self.log.debug("url: %s", url) + self.log.debug("engine_kwargs: %s", engine_kwargs) + return create_engine(url=url, **engine_kwargs) + + @property + def inspector(self) -> Inspector: + return Inspector.from_engine(self.get_sqlalchemy_engine()) def get_pandas_df( self, @@ -571,6 +585,7 @@ class DbApiHook(BaseHook): replace=False, *, executemany=False, + autocommit=False, **kwargs, ): """ @@ -585,12 +600,14 @@ class DbApiHook(BaseHook): :param commit_every: The maximum number of rows to insert in one transaction. Set to 0 to insert all rows in one transaction. :param replace: Whether to replace instead of insert - :param executemany: (Deprecated) If True, all rows are inserted at once in + :param executemany: If True, all rows are inserted at once in chunks defined by the commit_every parameter. This only works if all rows have same number of column names, but leads to better performance. + :param autocommit: What to set the connection's autocommit setting to + before executing the query. 
""" nb_rows = 0 - with self._create_autocommit_connection() as conn: + with self._create_autocommit_connection(autocommit) as conn: conn.commit() with closing(conn.cursor()) as cur: if self.supports_executemany or executemany: diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi index 27142aeaf2..625ec1d320 100644 --- a/airflow/providers/common/sql/hooks/sql.pyi +++ b/airflow/providers/common/sql/hooks/sql.pyi @@ -42,7 +42,7 @@ from airflow.providers.openlineage.extractors import OperatorLineage as Operator from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo from functools import cached_property as cached_property from pandas import DataFrame as DataFrame -from sqlalchemy.engine import URL as URL +from sqlalchemy.engine import Inspector, URL as URL from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload T = TypeVar("T") @@ -64,7 +64,6 @@ class DbApiHook(BaseHook): log_sql: Incomplete descriptions: Incomplete def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ... - def get_conn_id(self) -> str: ... @cached_property def placeholder(self): ... @@ -73,6 +72,8 @@ class DbApiHook(BaseHook): @property def sqlalchemy_url(self) -> URL: ... def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ... + @property + def inspector(self) -> Inspector: ... def get_pandas_df( self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs ) -> DataFrame: ... @@ -123,6 +124,7 @@ class DbApiHook(BaseHook): replace: bool = False, *, executemany: bool = False, + autocommit: bool = False, **kwargs, ): ... def bulk_dump(self, table, tmp_file) -> None: ... 
diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py index cf5d2dd47d..81c63cbe3d 100644 --- a/airflow/providers/jdbc/hooks/jdbc.py +++ b/airflow/providers/jdbc/hooks/jdbc.py @@ -23,7 +23,9 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any import jaydebeapi +from sqlalchemy.engine import URL +from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -60,7 +62,12 @@ class JdbcHook(DbApiHook): "providers.jdbc" section of the Airflow configuration. If you're enabling these options in Airflow configuration, you should make sure that you trust the users who can edit connections in the UI to not use it maliciously. - 4. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the + 4. Define the "sqlalchemy_scheme" property in the extra of the connection if you want to use the + SQLAlchemy engine from the JdbcHook. When using the JdbcHook, the "sqlalchemy_scheme" will by + default have the "jdbc" value, which is a protocol, not a database scheme or dialect. So in order + to be able to use SQLAlchemy with the JdbcHook, you need to define the "sqlalchemy_scheme" + property in the extra of the connection. + 5. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the ``local_settings.py`` file. See :doc:`/connections/jdbc` for full documentation. @@ -149,6 +156,23 @@ class JdbcHook(DbApiHook): self._driver_class = self.default_driver_class return self._driver_class + @property + def sqlalchemy_url(self) -> URL: + conn = self.get_connection(getattr(self, self.conn_name_attr)) + sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme") + if sqlalchemy_scheme is None: + raise AirflowException( + "The parameter 'sqlalchemy_scheme' must be defined in extra for JDBC connections!" 
+ ) + return URL.create( + drivername=sqlalchemy_scheme, + username=conn.login, + password=conn.password, + host=conn.host, + port=conn.port, + database=conn.schema, + ) + def get_conn(self) -> jaydebeapi.Connection: conn: Connection = self.get_connection(self.get_conn_id()) host: str = conn.host diff --git a/airflow/settings.py b/airflow/settings.py index 6dc9880271..eb4053f50e 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -256,6 +256,16 @@ class SkipDBTestsSession: def remove(*args, **kwargs): pass + def get_bind( + self, + mapper=None, + clause=None, + bind=None, + _sa_skip_events=None, + _sa_skip_for_implicit_returning=False, + ): + pass + class TracebackSession: """ diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py index 6e4387ee1a..cb38ce40ae 100644 --- a/tests/providers/jdbc/hooks/test_jdbc.py +++ b/tests/providers/jdbc/hooks/test_jdbc.py @@ -25,6 +25,7 @@ from unittest.mock import Mock, patch import jaydebeapi import pytest +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.jdbc.hooks.jdbc import JdbcHook, suppress_and_warn from airflow.utils import db @@ -186,3 +187,17 @@ class TestJdbcHook: with pytest.raises(RuntimeError, match="Spam Egg"): with suppress_and_warn(KeyError): raise RuntimeError("Spam Egg") + + def test_sqlalchemy_url_without_sqlalchemy_scheme(self): + hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"} + hook = get_hook(hook_params=hook_params) + + with pytest.raises(AirflowException): + hook.sqlalchemy_url + + def test_sqlalchemy_url_with_sqlalchemy_scheme(self): + conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql"))) + hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"} + hook = get_hook(conn_params=conn_params, hook_params=hook_params) + + assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema"