This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 73106932771ba3a1e8a08de3905b569d5462c416 Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Tue Nov 21 11:12:56 2023 +0400 Remove backcompat inheritance for DbApiHook (#35754) * Remove backcompat inheritance for DbApiHook * jwt_file > jwt__file * simplify trino test --- airflow/providers/apache/impala/hooks/impala.py | 3 +- airflow/providers/common/sql/hooks/sql.py | 18 +--------- airflow/providers/common/sql/hooks/sql.pyi | 4 +-- .../providers/elasticsearch/hooks/elasticsearch.py | 4 +-- airflow/providers/trino/hooks/trino.py | 19 ++++++++--- tests/providers/odbc/hooks/test_odbc.py | 3 +- tests/providers/trino/hooks/test_trino.py | 38 +++++++++++++++------- 7 files changed, 49 insertions(+), 40 deletions(-) diff --git a/airflow/providers/apache/impala/hooks/impala.py b/airflow/providers/apache/impala/hooks/impala.py index ab19865a9e..b8c79b4e25 100644 --- a/airflow/providers/apache/impala/hooks/impala.py +++ b/airflow/providers/apache/impala/hooks/impala.py @@ -35,7 +35,8 @@ class ImpalaHook(DbApiHook): hook_name = "Impala" def get_conn(self) -> Connection: - connection = self.get_connection(self.impala_conn_id) # pylint: disable=no-member + conn_id: str = getattr(self, self.conn_name_attr) + connection = self.get_connection(conn_id) return connect( host=connection.host, port=connection.port, diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index ab4eda5d8e..bb85dedc1c 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -34,12 +34,10 @@ from typing import ( from urllib.parse import urlparse import sqlparse -from packaging.version import Version from sqlalchemy import create_engine from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.version import version if TYPE_CHECKING: from pandas import DataFrame @@ -120,21 +118,7 @@ class ConnectorProtocol(Protocol): """ -# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below -# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise -# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to -# accept the new Hooks as not derived from the original DbApiHook -if Version(version) < Version("2.4"): - try: - from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook - except ImportError: - # just in case we have a problem with circular import - BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] -else: - BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] - - -class DbApiHook(BaseForDbApiHook): +class DbApiHook(BaseHook): """ Abstract base class for sql hooks. diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi index dedac037df..41bd6ebf47 100644 --- a/airflow/providers/common/sql/hooks/sql.pyi +++ b/airflow/providers/common/sql/hooks/sql.pyi @@ -32,8 +32,8 @@ Definition of the public interface for airflow.providers.common.sql.hooks.sql isort:skip_file """ from _typeshed import Incomplete -from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook -from typing import Any, Callable, Iterable, Mapping, Sequence +from airflow.hooks.base import BaseHook as BaseForDbApiHook +from typing import Any, Callable, Iterable, Mapping, Sequence, Union from typing_extensions import Protocol def return_single_query_results( diff --git a/airflow/providers/elasticsearch/hooks/elasticsearch.py b/airflow/providers/elasticsearch/hooks/elasticsearch.py index 6c93586892..2d9fca4a97 100644 --- a/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -108,9 +108,7 @@ class ElasticsearchSQLHook(DbApiHook): if conn.extra_dejson.get("timeout", False): conn_args["timeout"] = conn.extra_dejson["timeout"] - conn = connect(**conn_args) - - return conn + return connect(**conn_args) def get_uri(self) -> str: conn_id = getattr(self, self.conn_name_attr) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 03195fe452..798109dc3f 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -19,6 +19,7 @@ from __future__ import annotations import json import os +from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable, Mapping, TypeVar import trino @@ -28,6 +29,7 @@ from trino.transaction import IsolationLevel from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.utils.helpers import exactly_one from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, DEFAULT_FORMAT_PREFIX if TYPE_CHECKING: @@ -99,11 +101,20 @@ class TrinoHook(DbApiHook): elif db.password: auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined] elif extra.get("auth") == "jwt": - if "jwt__file" in extra: - with open(extra.get("jwt__file")) as jwt_file: - token = jwt_file.read() + if not exactly_one(jwt_file := "jwt__file" in extra, jwt_token := "jwt__token" in extra): + msg = ( + "When auth set to 'jwt' then expected exactly one parameter 'jwt__file' or 'jwt__token'" + " in connection extra, but " + ) + if jwt_file and jwt_token: + msg += "provided both." + else: + msg += "none of them provided." + raise ValueError(msg) + elif jwt_file: + token = Path(extra["jwt__file"]).read_text() else: - token = extra.get("jwt__token") + token = extra["jwt__token"] auth = trino.auth.JWTAuthentication(token=token) elif extra.get("auth") == "certs": auth = trino.auth.CertificateAuthentication( diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index ad763b934b..03e09a8adf 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -79,7 +79,8 @@ class TestOdbcHook: class UnitTestOdbcHook(OdbcHook): conn_name_attr = "test_conn_id" - def get_connection(self, conn_id: str): + @classmethod + def get_connection(cls, conn_id: str): return connection def get_conn(self): diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py index 8aeb4cbe08..5d61c056bc 100644 --- a/tests/providers/trino/hooks/test_trino.py +++ b/tests/providers/trino/hooks/test_trino.py @@ -18,9 +18,7 @@ from __future__ import annotations import json -import os import re -from tempfile import TemporaryDirectory from unittest import mock from unittest.mock import patch @@ -40,16 +38,10 @@ CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.Certificat @pytest.fixture() -def jwt_token_file(): - # Couldn't get this working with TemporaryFile, using TemporaryDirectory instead - # Save a phony jwt to a temporary file for the trino hook to read from - with TemporaryDirectory() as tmp_dir: - tmp_jwt_file = os.path.join(tmp_dir, "jwt.json") - - with open(tmp_jwt_file, "w") as tmp_file: - tmp_file.write('{"phony":"jwt"}') - - yield tmp_jwt_file +def jwt_token_file(tmp_path): + jwt_file = tmp_path / "jwt.json" + jwt_file.write_text('{"phony":"jwt"}') + yield jwt_file.__fspath__() class TestTrinoHookConn: @@ -140,6 +132,28 @@ class TestTrinoHookConn: TrinoHook().get_conn() self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth) + @pytest.mark.parametrize( + "jwt_file, jwt_token, error_suffix", + [ + pytest.param(True, True, "provided both", id="provided-both-params"), + pytest.param(False, False, "none of them provided", id="no-jwt-provided"), + ], + ) + @patch(HOOK_GET_CONNECTION) + def test_exactly_one_jwt_token( + self, mock_get_connection, jwt_file, jwt_token, error_suffix, jwt_token_file + ): + error_match = f"When auth set to 'jwt'.*{error_suffix}" + extras = {"auth": "jwt"} + if jwt_file: + extras["jwt__file"] = jwt_token_file + if jwt_token: + extras["jwt__token"] = "TEST_JWT_TOKEN" + + self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras)) + with pytest.raises(ValueError, match=error_match): + TrinoHook().get_conn() + @patch(CERT_AUTHENTICATION) @patch(TRINO_DBAPI_CONNECT) @patch(HOOK_GET_CONNECTION)