This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit e02271003931c2d9df58a2b769c89cac8d4eb0a2 Author: Joffrey Bienvenu <joffrey.bienv...@infrabel.be> AuthorDate: Wed Nov 22 15:34:34 2023 +0100 Fix HttpOperator pagination with `str` data (#35782) * feat: Restrict `data` parameter typing Follows the hook's typing * feat: Implement `data` override when string * feat: Improve docstring about merging and overriding behavior * fix: Add correct typing for mypy * feat: add test * fix: remove unused imports * fix: Update SimpleHttpOperator docstring * feat: Correctly test parameters overriding --- airflow/providers/http/operators/http.py | 50 +++++++++++++++-------- airflow/providers/http/triggers/http.py | 2 +- tests/providers/http/operators/test_http.py | 63 ++++++++++++++++++++++------- 3 files changed, 84 insertions(+), 31 deletions(-) diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index 96415ed977..524de8c585 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -56,12 +56,17 @@ class HttpOperator(BaseOperator): :param pagination_function: A callable that generates the parameters used to call the API again, based on the previous response. Typically used when the API is paginated and returns for e.g a cursor, a 'next page id', or a 'next page URL'. When provided, the Operator will call the API - repeatedly until this callable returns None. Also, the result of the Operator will become by - default a list of Response.text objects (instead of a single response object). Same with the - other injected functions (like response_check, response_filter, ...) which will also receive a - list of Response object. This function receives a Response object form previous call, and should - return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`), which will be merged - and will override the one used in the initial API call. + repeatedly until this callable returns None. 
The result of the Operator will become by default a + list of Response.text objects (instead of a single response object). Same with the other injected + functions (like response_check, response_filter, ...) which will also receive a list of Response + objects. This function receives a Response object from the previous call, and should return a nested + dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options`. + Those keys will be merged and/or override the parameters provided in the HttpOperator declaration. + Parameters are merged when they are both a dictionary (e.g.: HttpOperator.headers will be merged + with the `headers` dict provided by this function). When merging, dict items returned by this + function will override initial ones (e.g.: if both HttpOperator.headers and `headers` have a 'cookie' + item, the one provided by `headers` is kept). Parameters are simply overridden when either of them is a + string (e.g.: HttpOperator.endpoint is overridden by `endpoint`). :param response_check: A check against the 'requests' response object. The callable takes the response object as the first positional argument and optionally any number of keyword arguments available in the context dictionary. @@ -101,7 +106,7 @@ class HttpOperator(BaseOperator): *, endpoint: str | None = None, method: str = "POST", - data: Any = None, + data: dict[str, Any] | str | None = None, headers: dict[str, str] | None = None, pagination_function: Callable[..., Any] | None = None, response_check: Callable[..., bool] | None = None, @@ -271,9 +276,16 @@ class HttpOperator(BaseOperator): :param next_page_params: A dictionary containing the parameters for the next page. :return: A dictionary containing the merged parameters.
""" + data: str | dict | None = None # makes mypy happy + next_page_data_param = next_page_params.get("data") + if isinstance(self.data, dict) and isinstance(next_page_data_param, dict): + data = merge_dicts(self.data, next_page_data_param) + else: + data = next_page_data_param or self.data + return dict( endpoint=next_page_params.get("endpoint") or self.endpoint, - data=merge_dicts(self.data, next_page_params.get("data", {})), + data=data, headers=merge_dicts(self.headers, next_page_params.get("headers", {})), extra_options=merge_dicts(self.extra_options, next_page_params.get("extra_options", {})), ) @@ -294,14 +306,20 @@ class SimpleHttpOperator(HttpOperator): :param data: The data to pass. POST-data in POST/PUT and params in the URL for a GET request. (templated) :param headers: The HTTP headers to be added to the GET request - :param pagination_function: A callable that generates the parameters used to call the API again. - Typically used when the API is paginated and returns for e.g a cursor, a 'next page id', or - a 'next page URL'. When provided, the Operator will call the API repeatedly until this callable - returns None. Also, the result of the Operator will become by default a list of Response.text - objects (instead of a single response object). Same with the other injected functions (like - response_check, response_filter, ...) which will also receive a list of Response object. This - function should return a dict of parameters (`endpoint`, `data`, `headers`, `extra_options`), - which will be merged and override the one used in the initial API call. + :param pagination_function: A callable that generates the parameters used to call the API again, + based on the previous response. Typically used when the API is paginated and returns for e.g a + cursor, a 'next page id', or a 'next page URL'. When provided, the Operator will call the API + repeatedly until this callable returns None. 
The result of the Operator will become by default a + list of Response.text objects (instead of a single response object). Same with the other injected + functions (like response_check, response_filter, ...) which will also receive a list of Response + objects. This function receives a Response object from the previous call, and should return a nested + dictionary with the following optional keys: `endpoint`, `data`, `headers` and `extra_options`. + Those keys will be merged and/or override the parameters provided in the HttpOperator declaration. + Parameters are merged when they are both a dictionary (e.g.: HttpOperator.headers will be merged + with the `headers` dict provided by this function). When merging, dict items returned by this + function will override initial ones (e.g.: if both HttpOperator.headers and `headers` have a 'cookie' + item, the one provided by `headers` is kept). Parameters are simply overridden when either of them is a + string (e.g.: HttpOperator.endpoint is overridden by `endpoint`). :param response_check: A check against the 'requests' response object. The callable takes the response object as the first positional argument and optionally any number of keyword arguments available in the context dictionary.
diff --git a/airflow/providers/http/triggers/http.py b/airflow/providers/http/triggers/http.py index 89aa9ca606..b4598984f3 100644 --- a/airflow/providers/http/triggers/http.py +++ b/airflow/providers/http/triggers/http.py @@ -56,7 +56,7 @@ class HttpTrigger(BaseTrigger): method: str = "POST", endpoint: str | None = None, headers: dict[str, str] | None = None, - data: Any = None, + data: dict[str, Any] | str | None = None, extra_options: dict[str, Any] | None = None, ): super().__init__() diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index 451cd93d44..dfd82a17ae 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -24,6 +24,7 @@ from unittest import mock import pytest from requests import Response +from requests.models import RequestEncodingMixin from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.http.operators.http import HttpOperator @@ -112,41 +113,75 @@ class TestHttpOperator: ) assert result == "content" - def test_paginated_responses(self, requests_mock): + @pytest.mark.parametrize( + "data, headers, extra_options, pagination_data, pagination_headers, pagination_extra_options", + [ + ({"data": 1}, {"x-head": "1"}, {"verify": False}, {"data": 2}, {"x-head": "0"}, {"verify": True}), + ("data foo", {"x-head": "1"}, {"verify": False}, {"data": 2}, {"x-head": "0"}, {"verify": True}), + ("data foo", {"x-head": "1"}, {"verify": False}, "data bar", {"x-head": "0"}, {"verify": True}), + ({"data": 1}, {"x-head": "1"}, {"verify": False}, "data foo", {"x-head": "0"}, {"verify": True}), + ], + ) + def test_pagination( + self, + requests_mock, + data, + headers, + extra_options, + pagination_data, + pagination_headers, + pagination_extra_options, + ): """ Test that the HttpOperator calls repetitively the API when a pagination_function is provided, and as long as this function returns a dictionary that override previous' call 
parameters. """ - iterations: int = 0 + is_second_call: bool = False def pagination_function(response: Response) -> dict | None: """Paginated function which returns None at the second call.""" - nonlocal iterations - if iterations < 2: - iterations += 1 + nonlocal is_second_call + if not is_second_call: + is_second_call = True return dict( endpoint=response.json()["endpoint"], - data={}, - headers={}, - extra_options={}, + data=pagination_data, + headers=pagination_headers, + extra_options=pagination_extra_options, ) return None - requests_mock.get("http://www.example.com/foo", json={"value": 5, "endpoint": "bar"}) - requests_mock.get("http://www.example.com/bar", json={"value": 10, "endpoint": "foo"}) + first_endpoint = requests_mock.post("http://www.example.com/1", json={"value": 5, "endpoint": "2"}) + second_endpoint = requests_mock.post("http://www.example.com/2", json={"value": 10, "endpoint": "3"}) operator = HttpOperator( task_id="test_HTTP_op", - method="GET", - endpoint="/foo", + method="POST", + endpoint="/1", + data=data, + headers=headers, + extra_options=extra_options, http_conn_id="HTTP_EXAMPLE", pagination_function=pagination_function, response_filter=lambda resp: [entry.json()["value"] for entry in resp], ) result = operator.execute({}) - assert result == [5, 10, 5] - def test_async_paginated_responses(self, requests_mock): + # Ensure the initial call is made with parameters passed to the Operator + first_call = first_endpoint.request_history[0] + assert first_call.headers.items() >= headers.items() + assert first_call.body == RequestEncodingMixin._encode_params(data) + assert first_call.verify is extra_options["verify"] + + # Ensure the second - paginated - call is made with parameters merged from the pagination function + second_call = second_endpoint.request_history[0] + assert second_call.headers.items() >= pagination_headers.items() + assert second_call.body == RequestEncodingMixin._encode_params(pagination_data) + assert second_call.verify is 
pagination_extra_options["verify"] + + assert result == [5, 10] + + def test_async_pagination(self, requests_mock): """ Test that the HttpOperator calls asynchronously and repetitively the API when a pagination_function is provided, and as long as this function