This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f3666e7236f Deferrable support for HttpOperator (#45228)
f3666e7236f is described below

commit f3666e7236f9e8ea31cd6752a4dc0f7a9a8001a7
Author: TakayukiTanabeSS <takayuki.tan...@sansan.com>
AuthorDate: Fri Jan 17 13:48:37 2025 +0900

    Deferrable support for HttpOperator (#45228)
    
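    * Fixed the lifetime relationship between the aiohttp session and its response: HttpAsyncHook.run() now takes a caller-owned ClientSession, so the response remains usable while the caller reads it inside the session context.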
    
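    * Added HttpMethodException and HttpErrorException (providers/http/exceptions.py) for invalid HTTP methods and HTTP error responses in the Http hook.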
    
    * Update providers/src/airflow/providers/http/hooks/http.py
    
    Co-authored-by: Wei Lee <weilee...@gmail.com>
    
    * Update providers/src/airflow/providers/http/hooks/http.py
    
    Co-authored-by: Wei Lee <weilee...@gmail.com>
    
    * Address review feedback
    
    * Fix pre-commit findings
    
    ---------
    
    Co-authored-by: Wei Lee <weilee...@gmail.com>
---
 providers/src/airflow/providers/http/exceptions.py |  27 ++++
 providers/src/airflow/providers/http/hooks/http.py |  90 ++++++------
 .../src/airflow/providers/http/triggers/http.py    |  31 +++--
 providers/tests/http/hooks/test_http.py            | 153 +++++++++++++--------
 4 files changed, 185 insertions(+), 116 deletions(-)

diff --git a/providers/src/airflow/providers/http/exceptions.py 
b/providers/src/airflow/providers/http/exceptions.py
new file mode 100644
index 00000000000..7f0852ebf32
--- /dev/null
+++ b/providers/src/airflow/providers/http/exceptions.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.exceptions import AirflowException
+
+
+class HttpErrorException(AirflowException):
+    """Exception raised for HTTP error in Http hook."""
+
+
+class HttpMethodException(AirflowException):
+    """Exception raised for invalid HTTP methods in Http hook."""
diff --git a/providers/src/airflow/providers/http/hooks/http.py 
b/providers/src/airflow/providers/http/hooks/http.py
index a179739275e..b22a01f8283 100644
--- a/providers/src/airflow/providers/http/hooks/http.py
+++ b/providers/src/airflow/providers/http/hooks/http.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import asyncio
 from typing import TYPE_CHECKING, Any, Callable
 from urllib.parse import urlparse
 
@@ -32,6 +31,7 @@ from requests_toolbelt.adapters.socket_options import 
TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
+from airflow.providers.http.exceptions import HttpErrorException, 
HttpMethodException
 
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse
@@ -359,6 +359,7 @@ class HttpAsyncHook(BaseHook):
 
     async def run(
         self,
+        session: aiohttp.ClientSession,
         endpoint: str | None = None,
         data: dict[str, Any] | str | None = None,
         json: dict[str, Any] | str | None = None,
@@ -410,54 +411,51 @@ class HttpAsyncHook(BaseHook):
 
         url = _url_from_endpoint(self.base_url, endpoint)
 
-        async with aiohttp.ClientSession() as session:
-            if self.method == "GET":
-                request_func = session.get
-            elif self.method == "POST":
-                request_func = session.post
-            elif self.method == "PATCH":
-                request_func = session.patch
-            elif self.method == "HEAD":
-                request_func = session.head
-            elif self.method == "PUT":
-                request_func = session.put
-            elif self.method == "DELETE":
-                request_func = session.delete
-            elif self.method == "OPTIONS":
-                request_func = session.options
-            else:
-                raise AirflowException(f"Unexpected HTTP Method: 
{self.method}")
-
-            for attempt in range(1, 1 + self.retry_limit):
-                response = await request_func(
+        if self.method == "GET":
+            request_func = session.get
+        elif self.method == "POST":
+            request_func = session.post
+        elif self.method == "PATCH":
+            request_func = session.patch
+        elif self.method == "HEAD":
+            request_func = session.head
+        elif self.method == "PUT":
+            request_func = session.put
+        elif self.method == "DELETE":
+            request_func = session.delete
+        elif self.method == "OPTIONS":
+            request_func = session.options
+        else:
+            raise HttpMethodException(f"Unexpected HTTP Method: {self.method}")
+
+        for attempt in range(1, 1 + self.retry_limit):
+            response = await request_func(
+                url,
+                params=data if self.method == "GET" else None,
+                data=data if self.method in ("POST", "PUT", "PATCH") else None,
+                json=json,
+                headers=_headers,
+                auth=auth,
+                **extra_options,
+            )
+            try:
+                response.raise_for_status()
+            except ClientResponseError as e:
+                self.log.warning(
+                    "[Try %d of %d] Request to %s failed.",
+                    attempt,
+                    self.retry_limit,
                     url,
-                    params=data if self.method == "GET" else None,
-                    data=data if self.method in ("POST", "PUT", "PATCH") else 
None,
-                    json=json,
-                    headers=_headers,
-                    auth=auth,
-                    **extra_options,
                 )
-                try:
-                    response.raise_for_status()
-                except ClientResponseError as e:
-                    self.log.warning(
-                        "[Try %d of %d] Request to %s failed.",
-                        attempt,
-                        self.retry_limit,
-                        url,
-                    )
-                    if not self._retryable_error_async(e) or attempt == 
self.retry_limit:
-                        self.log.exception("HTTP error with status: %s", 
e.status)
-                        # In this case, the user probably made a mistake.
-                        # Don't retry.
-                        raise AirflowException(f"{e.status}:{e.message}")
-                    else:
-                        await asyncio.sleep(self.retry_delay)
-                else:
-                    return response
+                if not self._retryable_error_async(e) or attempt == 
self.retry_limit:
+                    self.log.exception("HTTP error with status: %s", e.status)
+                    # In this case, the user probably made a mistake.
+                    # Don't retry.
+                    raise HttpErrorException(f"{e.status}:{e.message}")
             else:
-                raise NotImplementedError  # should not reach this, but makes 
mypy happy
+                return response
+
+        raise NotImplementedError  # should not reach this, but makes mypy 
happy
 
     @classmethod
     def _process_extra_options_from_connection(cls, conn: Connection, 
extra_options: dict) -> dict:
diff --git a/providers/src/airflow/providers/http/triggers/http.py 
b/providers/src/airflow/providers/http/triggers/http.py
index c527d86ae54..d25d3a55cfb 100644
--- a/providers/src/airflow/providers/http/triggers/http.py
+++ b/providers/src/airflow/providers/http/triggers/http.py
@@ -22,6 +22,7 @@ import pickle
 from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING, Any
 
+import aiohttp
 import requests
 from requests.cookies import RequestsCookieJar
 from requests.structures import CaseInsensitiveDict
@@ -94,13 +95,15 @@ class HttpTrigger(BaseTrigger):
             auth_type=self.auth_type,
         )
         try:
-            client_response = await hook.run(
-                endpoint=self.endpoint,
-                data=self.data,
-                headers=self.headers,
-                extra_options=self.extra_options,
-            )
-            response = await self._convert_response(client_response)
+            async with aiohttp.ClientSession() as session:
+                client_response = await hook.run(
+                    session=session,
+                    endpoint=self.endpoint,
+                    data=self.data,
+                    headers=self.headers,
+                    extra_options=self.extra_options,
+                )
+                response = await self._convert_response(client_response)
             yield TriggerEvent(
                 {
                     "status": "success",
@@ -181,12 +184,14 @@ class HttpSensorTrigger(BaseTrigger):
         hook = self._get_async_hook()
         while True:
             try:
-                await hook.run(
-                    endpoint=self.endpoint,
-                    data=self.data,
-                    headers=self.headers,
-                    extra_options=self.extra_options,
-                )
+                async with aiohttp.ClientSession() as session:
+                    await hook.run(
+                        session=session,
+                        endpoint=self.endpoint,
+                        data=self.data,
+                        headers=self.headers,
+                        extra_options=self.extra_options,
+                    )
                 yield TriggerEvent(True)
                 return
             except AirflowException as exc:
diff --git a/providers/tests/http/hooks/test_http.py 
b/providers/tests/http/hooks/test_http.py
index bd381a7155b..82a1ff97651 100644
--- a/providers/tests/http/hooks/test_http.py
+++ b/providers/tests/http/hooks/test_http.py
@@ -25,6 +25,7 @@ import os
 from http import HTTPStatus
 from unittest import mock
 
+import aiohttp
 import pytest
 import requests
 import tenacity
@@ -565,7 +566,8 @@ class TestHttpAsyncHook:
                 AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/";,
             ),
         ):
-            await hook.run(endpoint="non_existent_endpoint")
+            async with aiohttp.ClientSession() as session:
+                await hook.run(session=session, 
endpoint="non_existent_endpoint")
 
     @pytest.mark.asyncio
     async def test_do_api_call_async_retryable_error(self, caplog, 
aioresponse):
@@ -581,7 +583,8 @@ class TestHttpAsyncHook:
                 AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/";,
             ),
         ):
-            await hook.run(endpoint="non_existent_endpoint")
+            async with aiohttp.ClientSession() as session:
+                await hook.run(session=session, 
endpoint="non_existent_endpoint")
 
         assert "[Try 3 of 3] Request to 
http://httpbin.org/non_existent_endpoint failed" in caplog.text
 
@@ -593,61 +596,69 @@ class TestHttpAsyncHook:
         json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"}
 
         with pytest.raises(AirflowException, match="Unexpected HTTP Method: 
NOPE"):
-            await hook.run(endpoint="non_existent_endpoint", data=json)
+            async with aiohttp.ClientSession() as session:
+                await hook.run(session=session, 
endpoint="non_existent_endpoint", data=json)
 
     @pytest.mark.asyncio
-    async def test_async_post_request(self, aioresponse):
+    async def test_async_post_request(self):
         """Test api call asynchronously for POST request."""
         hook = HttpAsyncHook()
 
-        aioresponse.post(
-            "http://test:8080/v1/test";,
-            status=200,
-            payload='{"status":{"status": 200}}',
-            reason="OK",
-        )
+        with aioresponses() as m:
+            m.post(
+                "http://test:8080/v1/test";,
+                status=200,
+                payload='{"status":{"status": 200}}',
+                reason="OK",
+            )
 
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-            resp = await hook.run("v1/test")
-            assert resp.status == 200
+            with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+                async with aiohttp.ClientSession() as session:
+                    resp = await hook.run(session=session, endpoint="v1/test")
+                    assert resp.status == 200
 
     @pytest.mark.asyncio
-    async def test_async_post_request_with_error_code(self, aioresponse):
+    async def test_async_post_request_with_error_code(self):
         """Test api call asynchronously for POST request with error."""
         hook = HttpAsyncHook()
 
-        aioresponse.post(
-            "http://test:8080/v1/test";,
-            status=418,
-            payload='{"status":{"status": 418}}',
-            reason="I am teapot",
-        )
+        with aioresponses() as m:
+            m.post(
+                "http://test:8080/v1/test";,
+                status=418,
+                payload='{"status":{"status": 418}}',
+                reason="I am teapot",
+            )
 
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-            with pytest.raises(AirflowException):
-                await hook.run("v1/test")
+            with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+                async with aiohttp.ClientSession() as session:
+                    with pytest.raises(AirflowException):
+                        await hook.run(session=session, endpoint="v1/test")
 
     @pytest.mark.asyncio
-    async def test_async_request_uses_connection_extra(self, aioresponse):
+    async def test_async_request_uses_connection_extra(self):
         """Test api call asynchronously with a connection that has extra 
field."""
 
         connection_extra = {"bearer": "test"}
 
-        aioresponse.post(
-            "http://test:8080/v1/test";,
-            status=200,
-            payload='{"status":{"status": 200}}',
-            reason="OK",
-        )
+        with aioresponses() as m:
+            m.post(
+                "http://test:8080/v1/test";,
+                status=200,
+                payload='{"status":{"status": 200}}',
+                reason="OK",
+            )
 
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-            hook = HttpAsyncHook()
-            with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
-                await hook.run("v1/test")
-                headers = mocked_function.call_args.kwargs.get("headers")
-                assert all(
-                    key in headers and headers[key] == value for key, value in 
connection_extra.items()
-                )
+            with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+                hook = HttpAsyncHook()
+                with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
+                    async with aiohttp.ClientSession() as session:
+                        await hook.run(session=session, endpoint="v1/test")
+                        headers = 
mocked_function.call_args.kwargs.get("headers")
+                        assert all(
+                            key in headers and headers[key] == value
+                            for key, value in connection_extra.items()
+                        )
 
     @pytest.mark.asyncio
     async def 
test_async_request_uses_connection_extra_with_requests_parameters(self):
@@ -670,18 +681,29 @@ class TestHttpAsyncHook:
 
         with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=airflow_connection):
             hook = HttpAsyncHook()
-            with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
-                await hook.run("v1/test")
-                headers = mocked_function.call_args.kwargs.get("headers")
-                assert all(
-                    key in headers and headers[key] == value for key, value in 
connection_extra.items()
+
+            with aioresponses() as m:
+                m.post(
+                    "http://test:8080/v1/test";,
+                    status=200,
+                    payload='{"status":{"status": 200}}',
+                    reason="OK",
                 )
-                assert mocked_function.call_args.kwargs.get("proxy") == proxy
-                assert mocked_function.call_args.kwargs.get("timeout") == 60
-                assert mocked_function.call_args.kwargs.get("verify_ssl") is 
False
-                assert mocked_function.call_args.kwargs.get("allow_redirects") 
is False
-                assert mocked_function.call_args.kwargs.get("max_redirects") 
== 3
-                assert mocked_function.call_args.kwargs.get("trust_env") is 
False
+
+                with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
+                    async with aiohttp.ClientSession() as session:
+                        await hook.run(session=session, endpoint="v1/test")
+                        headers = 
mocked_function.call_args.kwargs.get("headers")
+                        assert all(
+                            key in headers and headers[key] == value
+                            for key, value in connection_extra.items()
+                        )
+                        assert mocked_function.call_args.kwargs.get("proxy") 
== proxy
+                        assert mocked_function.call_args.kwargs.get("timeout") 
== 60
+                        assert 
mocked_function.call_args.kwargs.get("verify_ssl") is False
+                        assert 
mocked_function.call_args.kwargs.get("allow_redirects") is False
+                        assert 
mocked_function.call_args.kwargs.get("max_redirects") == 3
+                        assert 
mocked_function.call_args.kwargs.get("trust_env") is False
 
     def test_process_extra_options_from_connection(self):
         extra_options = {}
@@ -718,9 +740,19 @@ class TestHttpAsyncHook:
         schema = conn.schema or "http"  # default to http
         with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
             hook = HttpAsyncHook()
+
+            with aioresponses() as m:
+                m.post(
+                    f"{schema}://test:8080/v1/test",
+                    status=200,
+                    payload='{"status":{"status": 200}}',
+                    reason="OK",
+                )
+
             with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
-                await hook.run("v1/test")
-                assert mocked_function.call_args.args[0] == 
f"{schema}://{conn.host}v1/test"
+                async with aiohttp.ClientSession() as session:
+                    await hook.run(session=session, endpoint="v1/test")
+                    assert mocked_function.call_args.args[0] == 
f"{schema}://{conn.host}v1/test"
 
     @pytest.mark.asyncio
     async def test_build_request_url_from_endpoint_param(self):
@@ -728,9 +760,16 @@ class TestHttpAsyncHook:
             return Connection(conn_id=conn_id, conn_type="http")
 
         hook = HttpAsyncHook()
-        with (
-            mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_empty_conn),
-            mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function,
-        ):
-            await hook.run("test.com:8080/v1/test")
-            assert mocked_function.call_args.args[0] == 
"http://test.com:8080/v1/test";
+
+        with aioresponses() as m:
+            m.post(
+                "http://test.com:8080/v1/test";, status=200, 
payload='{"status":{"status": 200}}', reason="OK"
+            )
+
+            with (
+                mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_empty_conn),
+                mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function,
+            ):
+                async with aiohttp.ClientSession() as session:
+                    await hook.run(session=session, 
endpoint="test.com:8080/v1/test")
+                    assert mocked_function.call_args.args[0] == 
"http://test.com:8080/v1/test";

Reply via email to