This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 1a61eb3afd feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788) 1a61eb3afd is described below commit 1a61eb3afdaf0496ce6c308b1a36357cfd5728b7 Author: Lukas1v <51420170+luka...@users.noreply.github.com> AuthorDate: Sat Jun 8 07:44:57 2024 +0200 feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788) * feat: added connection options * feat: opensearch hook unit tests * feat: fallback to RequestsHttpConnection * fix: static checks * fix: static checks * fix: static checks * feat: opensearch static module loading --------- Co-authored-by: Lukas Verret <lukas.ver...@infrabel.be> --- airflow/providers/opensearch/hooks/opensearch.py | 21 ++++++++++++---- .../providers/opensearch/operators/opensearch.py | 12 +++++++++- .../providers/opensearch/hooks/test_opensearch.py | 28 ++++++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/airflow/providers/opensearch/hooks/opensearch.py b/airflow/providers/opensearch/hooks/opensearch.py index 2b4c254b4a..c5500be108 100644 --- a/airflow/providers/opensearch/hooks/opensearch.py +++ b/airflow/providers/opensearch/hooks/opensearch.py @@ -19,12 +19,16 @@ from __future__ import annotations import json from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Any from opensearchpy import OpenSearch, RequestsHttpConnection +if TYPE_CHECKING: + from opensearchpy import Connection as OpenSearchConnectionClass + from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.utils.strings import to_boolean class OpenSearchHook(BaseHook): @@ -40,13 +44,20 @@ class OpenSearchHook(BaseHook): conn_type = "opensearch" hook_name = "OpenSearch Hook" - def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs: Any): + def __init__( + self, + open_search_conn_id: str, + log_query: bool, + open_search_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection, + **kwargs: Any, + ): super().__init__(**kwargs) self.conn_id = open_search_conn_id self.log_query = log_query - self.use_ssl = self.conn.extra_dejson.get("use_ssl", False) - self.verify_certs = self.conn.extra_dejson.get("verify_certs", False) + self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False))) + self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False))) + self.connection_class = open_search_conn_class self.__SERVICE = "es" @cached_property @@ -62,7 +73,7 @@ class OpenSearchHook(BaseHook): http_auth=auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, - connection_class=RequestsHttpConnection, + connection_class=self.connection_class, ) return client diff --git a/airflow/providers/opensearch/operators/opensearch.py b/airflow/providers/opensearch/operators/opensearch.py index cc12b6e8b0..6599e5bf94 100644 --- a/airflow/providers/opensearch/operators/opensearch.py +++ b/airflow/providers/opensearch/operators/opensearch.py @@ -20,6 +20,7 @@ from __future__ import annotations from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence +from opensearchpy import RequestsHttpConnection from opensearchpy.exceptions import OpenSearchException from airflow.exceptions import AirflowException @@ -27,6 +28,8 @@ from airflow.models import BaseOperator from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook if TYPE_CHECKING: + from opensearchpy import Connection as OpenSearchConnectionClass + from airflow.utils.context import Context @@ -42,6 +45,7 @@ class OpenSearchQueryOperator(BaseOperator): :param search_object: A Search object from opensearch-dsl. :param index_name: The name of the index to search for documents. :param opensearch_conn_id: opensearch connection to use + :param opensearch_conn_class: opensearch connection class to use :param log_query: Whether to log the query used. Defaults to True and logs query used. """ @@ -54,6 +58,7 @@ class OpenSearchQueryOperator(BaseOperator): search_object: Any | None = None, index_name: str | None = None, opensearch_conn_id: str = "opensearch_default", + opensearch_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection, log_query: bool = True, **kwargs, ) -> None: @@ -61,13 +66,18 @@ class OpenSearchQueryOperator(BaseOperator): self.query = query self.index_name = index_name self.opensearch_conn_id = opensearch_conn_id + self.opensearch_conn_class = opensearch_conn_class self.log_query = log_query self.search_object = search_object @cached_property def hook(self) -> OpenSearchHook: """Get an instance of an OpenSearchHook.""" - return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id, log_query=self.log_query) + return OpenSearchHook( + open_search_conn_id=self.opensearch_conn_id, + open_search_conn_class=self.opensearch_conn_class, + log_query=self.log_query, + ) def execute(self, context: Context) -> Any: """Execute a search against a given index or a Search object on an OpenSearch Cluster.""" diff --git a/tests/providers/opensearch/hooks/test_opensearch.py b/tests/providers/opensearch/hooks/test_opensearch.py index 92f57d276e..84360ae73f 100644 --- a/tests/providers/opensearch/hooks/test_opensearch.py +++ b/tests/providers/opensearch/hooks/test_opensearch.py @@ -16,15 +16,21 @@ # under the License. from __future__ import annotations +from unittest import mock + +import opensearchpy import pytest +from opensearchpy import Urllib3HttpConnection from airflow.exceptions import AirflowException +from airflow.models import Connection from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook pytestmark = pytest.mark.db_test MOCK_SEARCH_RETURN = {"status": "test"} +DEFAULT_CONN = opensearchpy.connection.http_requests.RequestsHttpConnection class TestOpenSearchHook: @@ -46,3 +52,25 @@ class TestOpenSearchHook: hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True) with pytest.raises(AirflowException, match="must include one of either a query or a document id"): hook.delete(index_name="test_index") + + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_hook_param_bool(self, mock_get_connection): + mock_conn = Connection( + conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"} + ) + mock_get_connection.return_value = mock_conn + hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True) + + assert isinstance(hook.use_ssl, bool) + assert isinstance(hook.verify_certs, bool) + + def test_load_conn_param(self, mock_hook): + hook_default = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True) + assert hook_default.connection_class == DEFAULT_CONN + + hook_Urllib3 = OpenSearchHook( + open_search_conn_id="opensearch_default", + log_query=True, + open_search_conn_class=Urllib3HttpConnection, + ) + assert hook_Urllib3.connection_class == Urllib3HttpConnection