This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c61ca40d7 Add backward compatibility for elasticsearch<8 (#33281)
3c61ca40d7 is described below

commit 3c61ca40d7dfea4bb51d17704f9da88d7edd08c4
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Thu Aug 10 17:11:39 2023 +0530

    Add backward compatibility for elasticsearch<8 (#33281)
    
    * Add backward compatibility for elasticsearch<8
---
 .../providers/elasticsearch/log/es_task_handler.py | 29 +++++++++-
 .../elasticsearch/log/test_es_task_handler.py      | 63 ++++++++++++++++++++++
 2 files changed, 91 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 03bfe247c5..0a85f178ba 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -25,7 +25,7 @@ from datetime import datetime
 from operator import attrgetter
 from time import time
 from typing import TYPE_CHECKING, Any, Callable, List, Tuple
-from urllib.parse import quote
+from urllib.parse import quote, urlparse
 
 # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
 import elasticsearch
@@ -98,6 +98,12 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         log_id_template: str | None = None,
     ):
         es_kwargs = es_kwargs or {}
+        # For elasticsearch>8,arguments like retry_timeout have changed for elasticsearch to retry_on_timeout
+        # in Elasticsearch() compared to previous versions.
+        # Read more at: https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
+        if es_kwargs.get("retry_timeout"):
+            es_kwargs["retry_on_timeout"] = es_kwargs.pop("retry_timeout")
+        host = self.format_url(host)
         super().__init__(base_log_folder, filename_template)
         self.closed = False
 
@@ -126,6 +132,27 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         self._doc_type_map: dict[Any, Any] = {}
         self._doc_type: list[Any] = []
 
+    @staticmethod
+    def format_url(host: str) -> str:
+        """
+        Formats the given host string to ensure it starts with 'http'.
+        Checks if the host string represents a valid URL.
+
+        :params host: The host string to format and check.
+        """
+        parsed_url = urlparse(host)
+
+        # Check if the scheme is either http or https
+        if not parsed_url.scheme:
+            host = "http://" + host
+            parsed_url = urlparse(host)
+
+        # Basic validation for a valid URL
+        if not parsed_url.netloc:
+            raise ValueError(f"'{host}' is not a valid URL.")
+
+        return host
+
     def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
         with create_session() as session:
             dag_run = ti.get_dagrun(session=session)
diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py
index 7ae894f22a..4ffa958819 100644
--- a/tests/providers/elasticsearch/log/test_es_task_handler.py
+++ b/tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -125,6 +125,69 @@ class TestElasticsearchTaskHandler:
             "on 2023-07-09 07:47:32+00:00"
         )
 
+    @pytest.mark.parametrize(
+        "host, expected",
+        [
+            ("http://localhost:9200", "http://localhost:9200"),
+            ("https://localhost:9200", "https://localhost:9200"),
+            ("localhost:9200", "http://localhost:9200"),
+            ("someurl", "http://someurl"),
+            ("https://", "ValueError"),
+        ],
+    )
+    def test_format_url(self, host, expected):
+        """
+        Test the format_url method of the ElasticsearchTaskHandler class.
+        """
+        if expected == "ValueError":
+            with pytest.raises(ValueError):
+                assert ElasticsearchTaskHandler.format_url(host) == expected
+        else:
+            assert ElasticsearchTaskHandler.format_url(host) == expected
+
+    def test_elasticsearch_constructor_retry_timeout_handling(self):
+        """
+        Test if the ElasticsearchTaskHandler constructor properly handles the retry_timeout argument.
+        """
+        # Mock the Elasticsearch client
+        with mock.patch(
+            "airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch"
+        ) as mock_es:
+            # Test when 'retry_timeout' is present in es_kwargs
+            es_kwargs = {"retry_timeout": 10}
+            ElasticsearchTaskHandler(
+                base_log_folder="dummy_folder",
+                end_of_log_mark="end_of_log_mark",
+                write_stdout=False,
+                json_format=False,
+                json_fields="fields",
+                host_field="host",
+                offset_field="offset",
+                es_kwargs=es_kwargs,
+            )
+
+            # Check the arguments with which the Elasticsearch client is instantiated
+            mock_es.assert_called_once_with("http://localhost:9200", retry_on_timeout=10)
+
+            # Reset the mock for the next test
+            mock_es.reset_mock()
+
+            # Test when 'retry_timeout' is not present in es_kwargs
+            es_kwargs = {}
+            ElasticsearchTaskHandler(
+                base_log_folder="dummy_folder",
+                end_of_log_mark="end_of_log_mark",
+                write_stdout=False,
+                json_format=False,
+                json_fields="fields",
+                host_field="host",
+                offset_field="offset",
+                es_kwargs=es_kwargs,
+            )
+
+            # Check that the Elasticsearch client is instantiated without the 'retry_on_timeout' argument
+            mock_es.assert_called_once_with("http://localhost:9200")
+
     def test_client(self):
         assert isinstance(self.es_task_handler.client, elasticsearch.Elasticsearch)
         assert self.es_task_handler.index_patterns == "_all"

Reply via email to