This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b358d0a6de Add `retry_args` parameter to `HttpOperator` (#40086)
b358d0a6de is described below

commit b358d0a6de3a7c2fabdabaccf8c5edab9e1d0ecf
Author: Bora Berke Sahin <67373739+borabe...@users.noreply.github.com>
AuthorDate: Thu Jun 6 15:16:30 2024 +0300

    Add `retry_args` parameter to `HttpOperator` (#40086)
    
    * Add `retry_args` parameter to `HttpOperator`
    
    * Add unit tests and fix bug in pagination
---
 airflow/providers/http/operators/http.py    | 18 +++++++-
 tests/providers/http/operators/test_http.py | 71 +++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py
index 3c04a88573..6c086395e5 100644
--- a/airflow/providers/http/operators/http.py
+++ b/airflow/providers/http/operators/http.py
@@ -89,6 +89,8 @@ class HttpOperator(BaseOperator):
     :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
         ``socket.TCP_KEEPINTVL``)
     :param deferrable: Run operator in the deferrable mode
+    :param retry_args: Arguments which define the retry behaviour.
+        See Tenacity documentation at https://github.com/jd/tenacity
     """
 
     conn_id_field = "http_conn_id"
@@ -120,6 +122,7 @@ class HttpOperator(BaseOperator):
         tcp_keep_alive_count: int = 20,
         tcp_keep_alive_interval: int = 30,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        retry_args: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -139,6 +142,7 @@ class HttpOperator(BaseOperator):
         self.tcp_keep_alive_count = tcp_keep_alive_count
         self.tcp_keep_alive_interval = tcp_keep_alive_interval
         self.deferrable = deferrable
+        self.retry_args = retry_args
 
     @property
     def hook(self) -> HttpHook:
@@ -167,7 +171,12 @@ class HttpOperator(BaseOperator):
 
     def execute_sync(self, context: Context) -> Any:
         self.log.info("Calling HTTP method")
-        response = self.hook.run(self.endpoint, self.data, self.headers, self.extra_options)
+        if self.retry_args:
+            response = self.hook.run_with_advanced_retry(
+                self.retry_args, self.endpoint, self.data, self.headers, self.extra_options
+            )
+        else:
+            response = self.hook.run(self.endpoint, self.data, self.headers, self.extra_options)
         response = self.paginate_sync(response=response)
         return self.process_response(context=context, response=response)
 
@@ -180,7 +189,12 @@ class HttpOperator(BaseOperator):
             next_page_params = self.pagination_function(response)
             if not next_page_params:
                 break
-            response = self.hook.run(**self._merge_next_page_parameters(next_page_params))
+            if self.retry_args:
+                response = self.hook.run_with_advanced_retry(
+                    self.retry_args, **self._merge_next_page_parameters(next_page_params)
+                )
+            else:
+                response = self.hook.run(**self._merge_next_page_parameters(next_page_params))
             all_responses.append(response)
         return all_responses
 
diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py
index dfd82a17ae..3c0d4eb438 100644
--- a/tests/providers/http/operators/test_http.py
+++ b/tests/providers/http/operators/test_http.py
@@ -19,14 +19,18 @@ from __future__ import annotations
 
 import base64
 import contextlib
+import json
 import pickle
 from unittest import mock
+from unittest.mock import call, patch
 
 import pytest
+import tenacity
 from requests import Response
 from requests.models import RequestEncodingMixin
 
 from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.http.hooks.http import HttpHook
 from airflow.providers.http.operators.http import HttpOperator
 from airflow.providers.http.triggers.http import HttpTrigger
 
@@ -228,3 +232,70 @@ class TestHttpOperator:
                 **create_resume_response_parameters(), paginated_responses=[make_response_object()]
             )
             assert result == ['{"value": 5}', '{"value": 5}']
+
+    @patch.object(HttpHook, "run_with_advanced_retry")
+    def test_retry_args(self, mock_run_with_advanced_retry, requests_mock):
+        requests_mock.get("http://www.example.com", exc=Exception("Example Exception"))
+        retry_args = dict(
+            wait=tenacity.wait_none(),
+            stop=tenacity.stop_after_attempt(5),
+            retry=tenacity.retry_if_exception_type(Exception),
+        )
+        operator = HttpOperator(
+            task_id="test_HTTP_op",
+            method="GET",
+            endpoint="/",
+            http_conn_id="HTTP_EXAMPLE",
+            retry_args=retry_args,
+        )
+        operator.execute({})
+        mock_run_with_advanced_retry.assert_called_with(retry_args, "/", {}, {}, {})
+        assert mock_run_with_advanced_retry.call_count == 1
+
+    @patch.object(HttpHook, "run_with_advanced_retry")
+    def test_pagination_retry_args(
+        self,
+        mock_run_with_advanced_retry,
+        requests_mock,
+    ):
+        is_second_call: bool = False
+
+        def pagination_function(response: Response) -> dict | None:
+            """Paginated function which returns None at the second call."""
+            nonlocal is_second_call
+            if not is_second_call:
+                is_second_call = True
+                return dict(
+                    endpoint=response.json()["endpoint"],
+                )
+            return None
+
+        retry_args = dict(
+            wait=tenacity.wait_none(),
+            stop=tenacity.stop_after_attempt(5),
+            retry=tenacity.retry_if_exception_type(Exception),
+        )
+        operator = HttpOperator(
+            task_id="test_HTTP_op",
+            method="GET",
+            endpoint="/",
+            http_conn_id="HTTP_EXAMPLE",
+            pagination_function=pagination_function,
+            retry_args=retry_args,
+        )
+
+        response = Response()
+        response.status_code = 200
+        response._content = json.dumps({"value": 5, "endpoint": "/"}).encode("utf-8")
+        response.headers["Content-Type"] = "application/json"
+
+        mock_run_with_advanced_retry.return_value = response
+        operator.execute({})
+        mock_run_with_advanced_retry.assert_has_calls(
+            [
+                call(retry_args, "/", {}, {}, {}),
+                call(retry_args, endpoint="/", data={}, headers={}, extra_options={}),
+            ]
+        )
+
+        assert mock_run_with_advanced_retry.call_count == 2

Reply via email to