This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 73106932771ba3a1e8a08de3905b569d5462c416
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Tue Nov 21 11:12:56 2023 +0400

    Remove backcompat inheritance for DbApiHook (#35754)
    
    * Remove backcompat inheritance for DbApiHook
    
    * Replace jwt_file with jwt__file
    
    * Simplify Trino test
---
 airflow/providers/apache/impala/hooks/impala.py    |  3 +-
 airflow/providers/common/sql/hooks/sql.py          | 18 +---------
 airflow/providers/common/sql/hooks/sql.pyi         |  4 +--
 .../providers/elasticsearch/hooks/elasticsearch.py |  4 +--
 airflow/providers/trino/hooks/trino.py             | 19 ++++++++---
 tests/providers/odbc/hooks/test_odbc.py            |  3 +-
 tests/providers/trino/hooks/test_trino.py          | 38 +++++++++++++++-------
 7 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/airflow/providers/apache/impala/hooks/impala.py b/airflow/providers/apache/impala/hooks/impala.py
index ab19865a9e..b8c79b4e25 100644
--- a/airflow/providers/apache/impala/hooks/impala.py
+++ b/airflow/providers/apache/impala/hooks/impala.py
@@ -35,7 +35,8 @@ class ImpalaHook(DbApiHook):
     hook_name = "Impala"
 
     def get_conn(self) -> Connection:
-        connection = self.get_connection(self.impala_conn_id)  # pylint: 
disable=no-member
+        conn_id: str = getattr(self, self.conn_name_attr)
+        connection = self.get_connection(conn_id)
         return connect(
             host=connection.host,
             port=connection.port,
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index ab4eda5d8e..bb85dedc1c 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -34,12 +34,10 @@ from typing import (
 from urllib.parse import urlparse
 
 import sqlparse
-from packaging.version import Version
 from sqlalchemy import create_engine
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.version import version
 
 if TYPE_CHECKING:
     from pandas import DataFrame
@@ -120,21 +118,7 @@ class ConnectorProtocol(Protocol):
         """
 
 
-# In case we are running it on Airflow 2.4+, we should use BaseHook, but on 
Airflow 2.3 and below
-# We want the DbApiHook to derive from the original DbApiHook from airflow, 
because otherwise
-# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" 
will refuse to
-# accept the new Hooks as not derived from the original DbApiHook
-if Version(version) < Version("2.4"):
-    try:
-        from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
-    except ImportError:
-        # just in case we have a problem with circular import
-        BaseForDbApiHook: type[BaseHook] = BaseHook  # type: ignore[no-redef]
-else:
-    BaseForDbApiHook: type[BaseHook] = BaseHook  # type: ignore[no-redef]
-
-
-class DbApiHook(BaseForDbApiHook):
+class DbApiHook(BaseHook):
     """
     Abstract base class for sql hooks.
 
diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi
index dedac037df..41bd6ebf47 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -32,8 +32,8 @@ Definition of the public interface for 
airflow.providers.common.sql.hooks.sql
 isort:skip_file
 """
 from _typeshed import Incomplete
-from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
-from typing import Any, Callable, Iterable, Mapping, Sequence
+from airflow.hooks.base import BaseHook as BaseForDbApiHook
+from typing import Any, Callable, Iterable, Mapping, Sequence, Union
 from typing_extensions import Protocol
 
 def return_single_query_results(
diff --git a/airflow/providers/elasticsearch/hooks/elasticsearch.py b/airflow/providers/elasticsearch/hooks/elasticsearch.py
index 6c93586892..2d9fca4a97 100644
--- a/airflow/providers/elasticsearch/hooks/elasticsearch.py
+++ b/airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -108,9 +108,7 @@ class ElasticsearchSQLHook(DbApiHook):
         if conn.extra_dejson.get("timeout", False):
             conn_args["timeout"] = conn.extra_dejson["timeout"]
 
-        conn = connect(**conn_args)
-
-        return conn
+        return connect(**conn_args)
 
     def get_uri(self) -> str:
         conn_id = getattr(self, self.conn_name_attr)
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 03195fe452..798109dc3f 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import json
 import os
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable, Mapping, TypeVar
 
 import trino
@@ -28,6 +29,7 @@ from trino.transaction import IsolationLevel
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.utils.helpers import exactly_one
 from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, 
DEFAULT_FORMAT_PREFIX
 
 if TYPE_CHECKING:
@@ -99,11 +101,20 @@ class TrinoHook(DbApiHook):
         elif db.password:
             auth = trino.auth.BasicAuthentication(db.login, db.password)  # 
type: ignore[attr-defined]
         elif extra.get("auth") == "jwt":
-            if "jwt__file" in extra:
-                with open(extra.get("jwt__file")) as jwt_file:
-                    token = jwt_file.read()
+            if not exactly_one(jwt_file := "jwt__file" in extra, jwt_token := 
"jwt__token" in extra):
+                msg = (
+                    "When auth set to 'jwt' then expected exactly one 
parameter 'jwt__file' or 'jwt__token'"
+                    " in connection extra, but "
+                )
+                if jwt_file and jwt_token:
+                    msg += "provided both."
+                else:
+                    msg += "none of them provided."
+                raise ValueError(msg)
+            elif jwt_file:
+                token = Path(extra["jwt__file"]).read_text()
             else:
-                token = extra.get("jwt__token")
+                token = extra["jwt__token"]
             auth = trino.auth.JWTAuthentication(token=token)
         elif extra.get("auth") == "certs":
             auth = trino.auth.CertificateAuthentication(
diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py
index ad763b934b..03e09a8adf 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -79,7 +79,8 @@ class TestOdbcHook:
         class UnitTestOdbcHook(OdbcHook):
             conn_name_attr = "test_conn_id"
 
-            def get_connection(self, conn_id: str):
+            @classmethod
+            def get_connection(cls, conn_id: str):
                 return connection
 
             def get_conn(self):
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index 8aeb4cbe08..5d61c056bc 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -18,9 +18,7 @@
 from __future__ import annotations
 
 import json
-import os
 import re
-from tempfile import TemporaryDirectory
 from unittest import mock
 from unittest.mock import patch
 
@@ -40,16 +38,10 @@ CERT_AUTHENTICATION = 
"airflow.providers.trino.hooks.trino.trino.auth.Certificat
 
 
 @pytest.fixture()
-def jwt_token_file():
-    # Couldn't get this working with TemporaryFile, using TemporaryDirectory 
instead
-    # Save a phony jwt to a temporary file for the trino hook to read from
-    with TemporaryDirectory() as tmp_dir:
-        tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")
-
-        with open(tmp_jwt_file, "w") as tmp_file:
-            tmp_file.write('{"phony":"jwt"}')
-
-        yield tmp_jwt_file
+def jwt_token_file(tmp_path):
+    jwt_file = tmp_path / "jwt.json"
+    jwt_file.write_text('{"phony":"jwt"}')
+    yield jwt_file.__fspath__()
 
 
 class TestTrinoHookConn:
@@ -140,6 +132,28 @@ class TestTrinoHookConn:
         TrinoHook().get_conn()
         self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
 
+    @pytest.mark.parametrize(
+        "jwt_file, jwt_token, error_suffix",
+        [
+            pytest.param(True, True, "provided both", 
id="provided-both-params"),
+            pytest.param(False, False, "none of them provided", 
id="no-jwt-provided"),
+        ],
+    )
+    @patch(HOOK_GET_CONNECTION)
+    def test_exactly_one_jwt_token(
+        self, mock_get_connection, jwt_file, jwt_token, error_suffix, 
jwt_token_file
+    ):
+        error_match = f"When auth set to 'jwt'.*{error_suffix}"
+        extras = {"auth": "jwt"}
+        if jwt_file:
+            extras["jwt__file"] = jwt_token_file
+        if jwt_token:
+            extras["jwt__token"] = "TEST_JWT_TOKEN"
+
+        self.set_get_connection_return_value(mock_get_connection, 
extra=json.dumps(extras))
+        with pytest.raises(ValueError, match=error_match):
+            TrinoHook().get_conn()
+
     @patch(CERT_AUTHENTICATION)
     @patch(TRINO_DBAPI_CONNECT)
     @patch(HOOK_GET_CONNECTION)

Reply via email to