This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1a61eb3afd feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788)
1a61eb3afd is described below

commit 1a61eb3afdaf0496ce6c308b1a36357cfd5728b7
Author: Lukas1v <51420170+luka...@users.noreply.github.com>
AuthorDate: Sat Jun 8 07:44:57 2024 +0200

    feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788)
    
    * feat: added connection options
    
    * feat: opensearch hook unit tests
    
    * feat: fallback to RequestsHttpConnection
    
    * fix: static checks
    
    * fix: static checks
    
    * fix: static checks
    
    * feat: opensearch static module loading
    
    ---------
    
    Co-authored-by: Lukas Verret <lukas.ver...@infrabel.be>
---
 airflow/providers/opensearch/hooks/opensearch.py   | 21 ++++++++++++----
 .../providers/opensearch/operators/opensearch.py   | 12 +++++++++-
 .../providers/opensearch/hooks/test_opensearch.py  | 28 ++++++++++++++++++++++
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/opensearch/hooks/opensearch.py b/airflow/providers/opensearch/hooks/opensearch.py
index 2b4c254b4a..c5500be108 100644
--- a/airflow/providers/opensearch/hooks/opensearch.py
+++ b/airflow/providers/opensearch/hooks/opensearch.py
@@ -19,12 +19,16 @@ from __future__ import annotations
 
 import json
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from opensearchpy import OpenSearch, RequestsHttpConnection
 
+if TYPE_CHECKING:
+    from opensearchpy import Connection as OpenSearchConnectionClass
+
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
+from airflow.utils.strings import to_boolean
 
 
 class OpenSearchHook(BaseHook):
@@ -40,13 +44,20 @@ class OpenSearchHook(BaseHook):
     conn_type = "opensearch"
     hook_name = "OpenSearch Hook"
 
-    def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs: Any):
+    def __init__(
+        self,
+        open_search_conn_id: str,
+        log_query: bool,
+        open_search_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
+        **kwargs: Any,
+    ):
         super().__init__(**kwargs)
         self.conn_id = open_search_conn_id
         self.log_query = log_query
 
-        self.use_ssl = self.conn.extra_dejson.get("use_ssl", False)
-        self.verify_certs = self.conn.extra_dejson.get("verify_certs", False)
+        self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False)))
+        self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
+        self.connection_class = open_search_conn_class
         self.__SERVICE = "es"
 
     @cached_property
@@ -62,7 +73,7 @@ class OpenSearchHook(BaseHook):
             http_auth=auth,
             use_ssl=self.use_ssl,
             verify_certs=self.verify_certs,
-            connection_class=RequestsHttpConnection,
+            connection_class=self.connection_class,
         )
         return client
 
diff --git a/airflow/providers/opensearch/operators/opensearch.py b/airflow/providers/opensearch/operators/opensearch.py
index cc12b6e8b0..6599e5bf94 100644
--- a/airflow/providers/opensearch/operators/opensearch.py
+++ b/airflow/providers/opensearch/operators/opensearch.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from opensearchpy import RequestsHttpConnection
 from opensearchpy.exceptions import OpenSearchException
 
 from airflow.exceptions import AirflowException
@@ -27,6 +28,8 @@ from airflow.models import BaseOperator
 from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook
 
 if TYPE_CHECKING:
+    from opensearchpy import Connection as OpenSearchConnectionClass
+
     from airflow.utils.context import Context
 
 
@@ -42,6 +45,7 @@ class OpenSearchQueryOperator(BaseOperator):
     :param search_object: A Search object from opensearch-dsl.
     :param index_name: The name of the index to search for documents.
     :param opensearch_conn_id: opensearch connection to use
+    :param opensearch_conn_class: opensearch connection class to use
     :param log_query: Whether to log the query used. Defaults to True and logs query used.
     """
 
@@ -54,6 +58,7 @@ class OpenSearchQueryOperator(BaseOperator):
         search_object: Any | None = None,
         index_name: str | None = None,
         opensearch_conn_id: str = "opensearch_default",
+        opensearch_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
         log_query: bool = True,
         **kwargs,
     ) -> None:
@@ -61,13 +66,18 @@ class OpenSearchQueryOperator(BaseOperator):
         self.query = query
         self.index_name = index_name
         self.opensearch_conn_id = opensearch_conn_id
+        self.opensearch_conn_class = opensearch_conn_class
         self.log_query = log_query
         self.search_object = search_object
 
     @cached_property
     def hook(self) -> OpenSearchHook:
         """Get an instance of an OpenSearchHook."""
-        return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id, log_query=self.log_query)
+        return OpenSearchHook(
+            open_search_conn_id=self.opensearch_conn_id,
+            open_search_conn_class=self.opensearch_conn_class,
+            log_query=self.log_query,
+        )
 
     def execute(self, context: Context) -> Any:
         """Execute a search against a given index or a Search object on an OpenSearch Cluster."""
diff --git a/tests/providers/opensearch/hooks/test_opensearch.py b/tests/providers/opensearch/hooks/test_opensearch.py
index 92f57d276e..84360ae73f 100644
--- a/tests/providers/opensearch/hooks/test_opensearch.py
+++ b/tests/providers/opensearch/hooks/test_opensearch.py
@@ -16,15 +16,21 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
+import opensearchpy
 import pytest
+from opensearchpy import Urllib3HttpConnection
 
 from airflow.exceptions import AirflowException
+from airflow.models import Connection
 from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook
 
 pytestmark = pytest.mark.db_test
 
 
 MOCK_SEARCH_RETURN = {"status": "test"}
+DEFAULT_CONN = opensearchpy.connection.http_requests.RequestsHttpConnection
 
 
 class TestOpenSearchHook:
@@ -46,3 +52,25 @@ class TestOpenSearchHook:
         hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
         with pytest.raises(AirflowException, match="must include one of either a query or a document id"):
             hook.delete(index_name="test_index")
+
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_hook_param_bool(self, mock_get_connection):
+        mock_conn = Connection(
+            conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"}
+        )
+        mock_get_connection.return_value = mock_conn
+        hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
+
+        assert isinstance(hook.use_ssl, bool)
+        assert isinstance(hook.verify_certs, bool)
+
+    def test_load_conn_param(self, mock_hook):
+        hook_default = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
+        assert hook_default.connection_class == DEFAULT_CONN
+
+        hook_Urllib3 = OpenSearchHook(
+            open_search_conn_id="opensearch_default",
+            log_query=True,
+            open_search_conn_class=Urllib3HttpConnection,
+        )
+        assert hook_Urllib3.connection_class == Urllib3HttpConnection

Reply via email to