This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a10b3fccb09 Allow configuration of sqlalchemy query parameter for
JdbcHook and PostgresHook through extras (#44910)
a10b3fccb09 is described below
commit a10b3fccb09805397e607df4cd3ded6194d20170
Author: David Blain <[email protected]>
AuthorDate: Wed Dec 18 16:39:46 2024 +0100
Allow configuration of sqlalchemy query parameter for JdbcHook and
PostgresHook through extras (#44910)
---
providers/src/airflow/providers/jdbc/hooks/jdbc.py | 4 +++
.../airflow/providers/postgres/hooks/postgres.py | 28 +++++++++++------
providers/tests/jdbc/hooks/test_jdbc.py | 17 ++++++++++
providers/tests/postgres/hooks/test_postgres.py | 36 +++++++++++++++++++++-
4 files changed, 74 insertions(+), 11 deletions(-)
diff --git a/providers/src/airflow/providers/jdbc/hooks/jdbc.py
b/providers/src/airflow/providers/jdbc/hooks/jdbc.py
index 808b946bd97..07b5fc42d9a 100644
--- a/providers/src/airflow/providers/jdbc/hooks/jdbc.py
+++ b/providers/src/airflow/providers/jdbc/hooks/jdbc.py
@@ -152,6 +152,9 @@ class JdbcHook(DbApiHook):
@property
def sqlalchemy_url(self) -> URL:
conn = self.connection
+ sqlalchemy_query = conn.extra_dejson.get("sqlalchemy_query", {})
+ if not isinstance(sqlalchemy_query, dict):
+ raise AirflowException("The parameter 'sqlalchemy_query' must be
of type dict!")
sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
if sqlalchemy_scheme is None:
raise AirflowException(
@@ -164,6 +167,7 @@ class JdbcHook(DbApiHook):
host=conn.host,
port=conn.port,
database=conn.schema,
+ query=sqlalchemy_query,
)
def get_sqlalchemy_engine(self, engine_kwargs=None):
diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py
b/providers/src/airflow/providers/postgres/hooks/postgres.py
index f5dcfe2df49..9b657c14416 100644
--- a/providers/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/src/airflow/providers/postgres/hooks/postgres.py
@@ -29,6 +29,7 @@ import psycopg2.extras
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL
+from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
@@ -85,6 +86,17 @@ class PostgresHook(DbApiHook):
hook_name = "Postgres"
supports_autocommit = True
supports_executemany = True
+ ignored_extra_options = {
+ "iam",
+ "redshift",
+ "redshift-serverless",
+ "cursor",
+ "cluster-identifier",
+ "workgroup-name",
+ "aws_conn_id",
+ "sqlalchemy_scheme",
+ "sqlalchemy_query",
+ }
def __init__(
self, *args, options: str | None = None, enable_log_db_messages: bool
= False, **kwargs
@@ -97,7 +109,10 @@ class PostgresHook(DbApiHook):
@property
def sqlalchemy_url(self) -> URL:
- conn = self.get_connection(self.get_conn_id())
+ conn = self.connection
+ query = conn.extra_dejson.get("sqlalchemy_query", {})
+ if not isinstance(query, dict):
+ raise AirflowException("The parameter 'sqlalchemy_query' must be
of type dict!")
return URL.create(
drivername="postgresql",
username=conn.login,
@@ -105,6 +120,7 @@ class PostgresHook(DbApiHook):
host=conn.host,
port=conn.port,
database=self.database or conn.schema,
+ query=query,
)
def _get_cursor(self, raw_cursor: str) -> CursorType:
@@ -143,15 +159,7 @@ class PostgresHook(DbApiHook):
conn_args["options"] = self.options
for arg_name, arg_val in conn.extra_dejson.items():
- if arg_name not in [
- "iam",
- "redshift",
- "redshift-serverless",
- "cursor",
- "cluster-identifier",
- "workgroup-name",
- "aws_conn_id",
- ]:
+ if arg_name not in self.ignored_extra_options:
conn_args[arg_name] = arg_val
self.conn = psycopg2.connect(**conn_args)
diff --git a/providers/tests/jdbc/hooks/test_jdbc.py
b/providers/tests/jdbc/hooks/test_jdbc.py
index 73015b5b522..ce4e5266234 100644
--- a/providers/tests/jdbc/hooks/test_jdbc.py
+++ b/providers/tests/jdbc/hooks/test_jdbc.py
@@ -219,6 +219,23 @@ class TestJdbcHook:
assert str(hook.sqlalchemy_url) ==
"mssql://login:password@host:1234/schema"
+ def test_sqlalchemy_url_with_sqlalchemy_scheme_and_query(self):
+ conn_params = dict(
+ extra=json.dumps(dict(sqlalchemy_scheme="mssql",
sqlalchemy_query={"servicename": "test"}))
+ )
+ hook_params = {"driver_path": "ParamDriverPath", "driver_class":
"ParamDriverClass"}
+ hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+ assert str(hook.sqlalchemy_url) ==
"mssql://login:password@host:1234/schema?servicename=test"
+
+ def test_sqlalchemy_url_with_sqlalchemy_scheme_and_wrong_query_value(self):
+ conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql",
sqlalchemy_query="wrong type")))
+ hook_params = {"driver_path": "ParamDriverPath", "driver_class":
"ParamDriverClass"}
+ hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+ with pytest.raises(AirflowException):
+ hook.sqlalchemy_url
+
def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
jdbc_hook = get_hook(
conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}),
diff --git a/providers/tests/postgres/hooks/test_postgres.py
b/providers/tests/postgres/hooks/test_postgres.py
index 7a720534d4b..76206d57958 100644
--- a/providers/tests/postgres/hooks/test_postgres.py
+++ b/providers/tests/postgres/hooks/test_postgres.py
@@ -25,6 +25,7 @@ from unittest import mock
import psycopg2.extras
import pytest
+from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.types import NOTSET
@@ -65,9 +66,42 @@ class TestPostgresHookConn:
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() ==
"postgresql://login:password@host:5432/database"
+ def test_sqlalchemy_url(self):
+ conn = Connection(login="login-conn", password="password-conn",
host="host", schema="database")
+ hook = PostgresHook(connection=conn)
+ assert str(hook.sqlalchemy_url) ==
"postgresql://login-conn:password-conn@host/database"
+
+ def test_sqlalchemy_url_with_sqlalchemy_query(self):
+ conn = Connection(
+ login="login-conn",
+ password="password-conn",
+ host="host",
+ schema="database",
+ extra=dict(sqlalchemy_query={"gssencmode": "disable"}),
+ )
+ hook = PostgresHook(connection=conn)
+
+ assert (
+ str(hook.sqlalchemy_url)
+ ==
"postgresql://login-conn:password-conn@host/database?gssencmode=disable"
+ )
+
+ def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self):
+ conn = Connection(
+ login="login-conn",
+ password="password-conn",
+ host="host",
+ schema="database",
+ extra=dict(sqlalchemy_query="wrong type"),
+ )
+ hook = PostgresHook(connection=conn)
+
+ with pytest.raises(AirflowException):
+ hook.sqlalchemy_url
+
@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_cursor(self, mock_connect):
- self.connection.extra = '{"cursor": "dictcursor"}'
+ self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query":
{"gssencmode": "disable"}}'
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
cursor_factory=psycopg2.extras.DictCursor,