This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e02271003931c2d9df58a2b769c89cac8d4eb0a2
Author: Joffrey Bienvenu <joffrey.bienv...@infrabel.be>
AuthorDate: Wed Nov 22 15:34:34 2023 +0100

    Fix HttpOperator pagination with `str` data (#35782)
    
    * feat: Restrict `data` parameter typing
    
    Follows the hook's typing
    
    * feat: Implement `data` override when string
    
    * feat: Improve docstring about merging and overriding behavior
    
    * fix: Add correct typing for mypy
    
    * feat: add test
    
    * fix: remove unused imports
    
    * fix: Update SimpleHttpOperator docstring
    
    * feat: Correctly test parameters overriding
---
 airflow/providers/http/operators/http.py    | 50 +++++++++++++++--------
 airflow/providers/http/triggers/http.py     |  2 +-
 tests/providers/http/operators/test_http.py | 63 ++++++++++++++++++++++-------
 3 files changed, 84 insertions(+), 31 deletions(-)

diff --git a/airflow/providers/http/operators/http.py 
b/airflow/providers/http/operators/http.py
index 96415ed977..524de8c585 100644
--- a/airflow/providers/http/operators/http.py
+++ b/airflow/providers/http/operators/http.py
@@ -56,12 +56,17 @@ class HttpOperator(BaseOperator):
     :param pagination_function: A callable that generates the parameters used 
to call the API again,
         based on the previous response. Typically used when the API is 
paginated and returns for e.g a
         cursor, a 'next page id', or a 'next page URL'. When provided, the 
Operator will call the API
-        repeatedly until this callable returns None. Also, the result of the 
Operator will become by
-        default a list of Response.text objects (instead of a single response 
object). Same with the
-        other injected functions (like response_check, response_filter, ...) 
which will also receive a
-        list of Response object. This function receives a Response object form 
previous call, and should
-        return a dict of parameters (`endpoint`, `data`, `headers`, 
`extra_options`), which will be merged
-        and will override the one used in the initial API call.
+        repeatedly until this callable returns None. The result of the Operator will become by default a
+        list of Response.text objects (instead of a single response object). Same with the other injected
+        functions (like response_check, response_filter, ...) which will also receive a list of Response
+        objects. This function receives the Response object from the previous call, and should return a
+        nested dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options`.
+        Those keys will be merged with and/or override the parameters provided in the HttpOperator declaration.
+        Parameters are merged when they are both dictionaries (e.g.: HttpOperator.headers will be merged
+        with the `headers` dict provided by this function). When merging, dict items returned by this
+        function will override initial ones (e.g: if both HttpOperator.headers and `headers` have a 'cookie'
+        item, the one provided by `headers` is kept). Parameters are simply overridden when either of them
+        is a string (e.g.: HttpOperator.endpoint is overridden by `endpoint`).
     :param response_check: A check against the 'requests' response object.
         The callable takes the response object as the first positional argument
         and optionally any number of keyword arguments available in the 
context dictionary.
@@ -101,7 +106,7 @@ class HttpOperator(BaseOperator):
         *,
         endpoint: str | None = None,
         method: str = "POST",
-        data: Any = None,
+        data: dict[str, Any] | str | None = None,
         headers: dict[str, str] | None = None,
         pagination_function: Callable[..., Any] | None = None,
         response_check: Callable[..., bool] | None = None,
@@ -271,9 +276,16 @@ class HttpOperator(BaseOperator):
         :param next_page_params: A dictionary containing the parameters for 
the next page.
         :return: A dictionary containing the merged parameters.
         """
+        data: str | dict | None = None  # makes mypy happy
+        next_page_data_param = next_page_params.get("data")
+        if isinstance(self.data, dict) and isinstance(next_page_data_param, 
dict):
+            data = merge_dicts(self.data, next_page_data_param)
+        else:
+            data = next_page_data_param or self.data
+
         return dict(
             endpoint=next_page_params.get("endpoint") or self.endpoint,
-            data=merge_dicts(self.data, next_page_params.get("data", {})),
+            data=data,
             headers=merge_dicts(self.headers, next_page_params.get("headers", 
{})),
             extra_options=merge_dicts(self.extra_options, 
next_page_params.get("extra_options", {})),
         )
@@ -294,14 +306,20 @@ class SimpleHttpOperator(HttpOperator):
     :param data: The data to pass. POST-data in POST/PUT and params
         in the URL for a GET request. (templated)
     :param headers: The HTTP headers to be added to the GET request
-    :param pagination_function: A callable that generates the parameters used 
to call the API again.
-        Typically used when the API is paginated and returns for e.g a cursor, 
a 'next page id', or
-        a 'next page URL'. When provided, the Operator will call the API 
repeatedly until this callable
-        returns None. Also, the result of the Operator will become by default 
a list of Response.text
-        objects (instead of a single response object). Same with the other 
injected functions (like
-        response_check, response_filter, ...) which will also receive a list 
of Response object. This
-        function should return a dict of parameters (`endpoint`, `data`, 
`headers`, `extra_options`),
-        which will be merged and override the one used in the initial API call.
+    :param pagination_function: A callable that generates the parameters used to call the API again,
+        based on the previous response. Typically used when the API is paginated and returns, e.g., a
+        cursor, a 'next page id', or a 'next page URL'. When provided, the Operator will call the API
+        repeatedly until this callable returns None. The result of the Operator will become by default a
+        list of Response.text objects (instead of a single response object). Same with the other injected
+        functions (like response_check, response_filter, ...) which will also receive a list of Response
+        objects. This function receives the Response object from the previous call, and should return a
+        nested dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options`.
+        Those keys will be merged with and/or override the parameters provided in the HttpOperator declaration.
+        Parameters are merged when they are both dictionaries (e.g.: HttpOperator.headers will be merged
+        with the `headers` dict provided by this function). When merging, dict items returned by this
+        function will override initial ones (e.g: if both HttpOperator.headers and `headers` have a 'cookie'
+        item, the one provided by `headers` is kept). Parameters are simply overridden when either of them
+        is a string (e.g.: HttpOperator.endpoint is overridden by `endpoint`).
     :param response_check: A check against the 'requests' response object.
         The callable takes the response object as the first positional argument
         and optionally any number of keyword arguments available in the 
context dictionary.
diff --git a/airflow/providers/http/triggers/http.py 
b/airflow/providers/http/triggers/http.py
index 89aa9ca606..b4598984f3 100644
--- a/airflow/providers/http/triggers/http.py
+++ b/airflow/providers/http/triggers/http.py
@@ -56,7 +56,7 @@ class HttpTrigger(BaseTrigger):
         method: str = "POST",
         endpoint: str | None = None,
         headers: dict[str, str] | None = None,
-        data: Any = None,
+        data: dict[str, Any] | str | None = None,
         extra_options: dict[str, Any] | None = None,
     ):
         super().__init__()
diff --git a/tests/providers/http/operators/test_http.py 
b/tests/providers/http/operators/test_http.py
index 451cd93d44..dfd82a17ae 100644
--- a/tests/providers/http/operators/test_http.py
+++ b/tests/providers/http/operators/test_http.py
@@ -24,6 +24,7 @@ from unittest import mock
 
 import pytest
 from requests import Response
+from requests.models import RequestEncodingMixin
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.http.operators.http import HttpOperator
@@ -112,41 +113,75 @@ class TestHttpOperator:
         )
         assert result == "content"
 
-    def test_paginated_responses(self, requests_mock):
+    @pytest.mark.parametrize(
+        "data, headers, extra_options, pagination_data, pagination_headers, 
pagination_extra_options",
+        [
+            ({"data": 1}, {"x-head": "1"}, {"verify": False}, {"data": 2}, 
{"x-head": "0"}, {"verify": True}),
+            ("data foo", {"x-head": "1"}, {"verify": False}, {"data": 2}, 
{"x-head": "0"}, {"verify": True}),
+            ("data foo", {"x-head": "1"}, {"verify": False}, "data bar", 
{"x-head": "0"}, {"verify": True}),
+            ({"data": 1}, {"x-head": "1"}, {"verify": False}, "data foo", 
{"x-head": "0"}, {"verify": True}),
+        ],
+    )
+    def test_pagination(
+        self,
+        requests_mock,
+        data,
+        headers,
+        extra_options,
+        pagination_data,
+        pagination_headers,
+        pagination_extra_options,
+    ):
         """
         Test that the HttpOperator calls repetitively the API when a
         pagination_function is provided, and as long as this function returns
         a dictionary that override previous' call parameters.
         """
-        iterations: int = 0
+        is_second_call: bool = False
 
         def pagination_function(response: Response) -> dict | None:
             """Paginated function which returns None at the second call."""
-            nonlocal iterations
-            if iterations < 2:
-                iterations += 1
+            nonlocal is_second_call
+            if not is_second_call:
+                is_second_call = True
                 return dict(
                     endpoint=response.json()["endpoint"],
-                    data={},
-                    headers={},
-                    extra_options={},
+                    data=pagination_data,
+                    headers=pagination_headers,
+                    extra_options=pagination_extra_options,
                 )
             return None
 
-        requests_mock.get("http://www.example.com/foo";, json={"value": 5, 
"endpoint": "bar"})
-        requests_mock.get("http://www.example.com/bar";, json={"value": 10, 
"endpoint": "foo"})
+        first_endpoint = requests_mock.post("http://www.example.com/1";, 
json={"value": 5, "endpoint": "2"})
+        second_endpoint = requests_mock.post("http://www.example.com/2";, 
json={"value": 10, "endpoint": "3"})
         operator = HttpOperator(
             task_id="test_HTTP_op",
-            method="GET",
-            endpoint="/foo",
+            method="POST",
+            endpoint="/1",
+            data=data,
+            headers=headers,
+            extra_options=extra_options,
             http_conn_id="HTTP_EXAMPLE",
             pagination_function=pagination_function,
             response_filter=lambda resp: [entry.json()["value"] for entry in 
resp],
         )
         result = operator.execute({})
-        assert result == [5, 10, 5]
 
-    def test_async_paginated_responses(self, requests_mock):
+        # Ensure the initial call is made with parameters passed to the 
Operator
+        first_call = first_endpoint.request_history[0]
+        assert first_call.headers.items() >= headers.items()
+        assert first_call.body == RequestEncodingMixin._encode_params(data)
+        assert first_call.verify is extra_options["verify"]
+
+        # Ensure the second - paginated - call is made with parameters merged 
from the pagination function
+        second_call = second_endpoint.request_history[0]
+        assert second_call.headers.items() >= pagination_headers.items()
+        assert second_call.body == 
RequestEncodingMixin._encode_params(pagination_data)
+        assert second_call.verify is pagination_extra_options["verify"]
+
+        assert result == [5, 10]
+
+    def test_async_pagination(self, requests_mock):
         """
         Test that the HttpOperator calls asynchronously and repetitively
         the API when a pagination_function is provided, and as long as this 
function

Reply via email to