This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f3666e7236f Deferrable support for HttpOperator (#45228)
f3666e7236f is described below
commit f3666e7236f9e8ea31cd6752a4dc0f7a9a8001a7
Author: TakayukiTanabeSS <[email protected]>
AuthorDate: Fri Jan 17 13:48:37 2025 +0900
Deferrable support for HttpOperator (#45228)
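* Let callers own the aiohttp ClientSession passed to HttpAsyncHook.run(), so the session stays open until the response has been read.
* Added HttpMethodException (and HttpErrorException) in a new exceptions module.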
* Update providers/src/airflow/providers/http/hooks/http.py
Co-authored-by: Wei Lee <[email protected]>
* Update providers/src/airflow/providers/http/hooks/http.py
Co-authored-by: Wei Lee <[email protected]>
* Address review feedback
* Fix pre-commit issues
---------
Co-authored-by: Wei Lee <[email protected]>
---
providers/src/airflow/providers/http/exceptions.py | 27 ++++
providers/src/airflow/providers/http/hooks/http.py | 90 ++++++------
.../src/airflow/providers/http/triggers/http.py | 31 +++--
providers/tests/http/hooks/test_http.py | 153 +++++++++++++--------
4 files changed, 185 insertions(+), 116 deletions(-)
diff --git a/providers/src/airflow/providers/http/exceptions.py
b/providers/src/airflow/providers/http/exceptions.py
new file mode 100644
index 00000000000..7f0852ebf32
--- /dev/null
+++ b/providers/src/airflow/providers/http/exceptions.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.exceptions import AirflowException
+
+
+class HttpErrorException(AirflowException):
+ """Raised when an Http hook request returns an error status and no further retry will be made."""
+
+
+class HttpMethodException(AirflowException):
+ """Raised when the Http hook is configured with an HTTP method it does not support."""
diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py
index a179739275e..b22a01f8283 100644
--- a/providers/src/airflow/providers/http/hooks/http.py
+++ b/providers/src/airflow/providers/http/hooks/http.py
@@ -32,6 +32,7 @@ from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
+from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException
 
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse
@@ -359,6 +360,7 @@ class HttpAsyncHook(BaseHook):
 
     async def run(
         self,
+        session: aiohttp.ClientSession,
         endpoint: str | None = None,
         data: dict[str, Any] | str | None = None,
         json: dict[str, Any] | str | None = None,
@@ -410,54 +412,52 @@ class HttpAsyncHook(BaseHook):
 
         url = _url_from_endpoint(self.base_url, endpoint)
 
-        async with aiohttp.ClientSession() as session:
-            if self.method == "GET":
-                request_func = session.get
-            elif self.method == "POST":
-                request_func = session.post
-            elif self.method == "PATCH":
-                request_func = session.patch
-            elif self.method == "HEAD":
-                request_func = session.head
-            elif self.method == "PUT":
-                request_func = session.put
-            elif self.method == "DELETE":
-                request_func = session.delete
-            elif self.method == "OPTIONS":
-                request_func = session.options
-            else:
-                raise AirflowException(f"Unexpected HTTP Method: {self.method}")
-
-            for attempt in range(1, 1 + self.retry_limit):
-                response = await request_func(
+        if self.method == "GET":
+            request_func = session.get
+        elif self.method == "POST":
+            request_func = session.post
+        elif self.method == "PATCH":
+            request_func = session.patch
+        elif self.method == "HEAD":
+            request_func = session.head
+        elif self.method == "PUT":
+            request_func = session.put
+        elif self.method == "DELETE":
+            request_func = session.delete
+        elif self.method == "OPTIONS":
+            request_func = session.options
+        else:
+            raise HttpMethodException(f"Unexpected HTTP Method: {self.method}")
+
+        for attempt in range(1, 1 + self.retry_limit):
+            response = await request_func(
+                url,
+                params=data if self.method == "GET" else None,
+                data=data if self.method in ("POST", "PUT", "PATCH") else None,
+                json=json,
+                headers=_headers,
+                auth=auth,
+                **extra_options,
+            )
+            try:
+                response.raise_for_status()
+            except ClientResponseError as e:
+                self.log.warning(
+                    "[Try %d of %d] Request to %s failed.",
+                    attempt,
+                    self.retry_limit,
                     url,
-                    params=data if self.method == "GET" else None,
-                    data=data if self.method in ("POST", "PUT", "PATCH") else None,
-                    json=json,
-                    headers=_headers,
-                    auth=auth,
-                    **extra_options,
                 )
-                try:
-                    response.raise_for_status()
-                except ClientResponseError as e:
-                    self.log.warning(
-                        "[Try %d of %d] Request to %s failed.",
-                        attempt,
-                        self.retry_limit,
-                        url,
-                    )
-                    if not self._retryable_error_async(e) or attempt == self.retry_limit:
-                        self.log.exception("HTTP error with status: %s", e.status)
-                        # In this case, the user probably made a mistake.
-                        # Don't retry.
-                        raise AirflowException(f"{e.status}:{e.message}")
-                    else:
-                        await asyncio.sleep(self.retry_delay)
-                else:
-                    return response
+                if not self._retryable_error_async(e) or attempt == self.retry_limit:
+                    self.log.exception("HTTP error with status: %s", e.status)
+                    # In this case, the user probably made a mistake.
+                    # Don't retry.
+                    raise HttpErrorException(f"{e.status}:{e.message}")
+                await asyncio.sleep(self.retry_delay)
             else:
-                raise NotImplementedError  # should not reach this, but makes mypy happy
+                return response
+
+        raise NotImplementedError  # should not reach this, but makes mypy happy
 
     @classmethod
     def _process_extra_options_from_connection(cls, conn: Connection, extra_options: dict) -> dict:
diff --git a/providers/src/airflow/providers/http/triggers/http.py
b/providers/src/airflow/providers/http/triggers/http.py
index c527d86ae54..d25d3a55cfb 100644
--- a/providers/src/airflow/providers/http/triggers/http.py
+++ b/providers/src/airflow/providers/http/triggers/http.py
@@ -22,6 +22,7 @@ import pickle
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any
+import aiohttp
import requests
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
@@ -94,13 +95,15 @@ class HttpTrigger(BaseTrigger):
auth_type=self.auth_type,
)
try:
- client_response = await hook.run(
- endpoint=self.endpoint,
- data=self.data,
- headers=self.headers,
- extra_options=self.extra_options,
- )
- response = await self._convert_response(client_response)
+ async with aiohttp.ClientSession() as session:  # response is converted before this session closes
+ client_response = await hook.run(
+ session=session,
+ endpoint=self.endpoint,
+ data=self.data,
+ headers=self.headers,
+ extra_options=self.extra_options,
+ )
+ response = await self._convert_response(client_response)
yield TriggerEvent(
{
"status": "success",
@@ -181,12 +184,14 @@ class HttpSensorTrigger(BaseTrigger):
hook = self._get_async_hook()
while True:
try:
- await hook.run(
- endpoint=self.endpoint,
- data=self.data,
- headers=self.headers,
- extra_options=self.extra_options,
- )
+ async with aiohttp.ClientSession() as session:  # a new session is opened for each poll attempt
+ await hook.run(
+ session=session,
+ endpoint=self.endpoint,
+ data=self.data,
+ headers=self.headers,
+ extra_options=self.extra_options,
+ )
yield TriggerEvent(True)
return
except AirflowException as exc:
diff --git a/providers/tests/http/hooks/test_http.py
b/providers/tests/http/hooks/test_http.py
index bd381a7155b..82a1ff97651 100644
--- a/providers/tests/http/hooks/test_http.py
+++ b/providers/tests/http/hooks/test_http.py
@@ -25,6 +25,7 @@ import os
from http import HTTPStatus
from unittest import mock
+import aiohttp
import pytest
import requests
import tenacity
@@ -565,7 +566,8 @@ class TestHttpAsyncHook:
AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
),
):
- await hook.run(endpoint="non_existent_endpoint")
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session,
endpoint="non_existent_endpoint")
@pytest.mark.asyncio
async def test_do_api_call_async_retryable_error(self, caplog,
aioresponse):
@@ -581,7 +583,8 @@ class TestHttpAsyncHook:
AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
),
):
- await hook.run(endpoint="non_existent_endpoint")
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session,
endpoint="non_existent_endpoint")
assert "[Try 3 of 3] Request to
http://httpbin.org/non_existent_endpoint failed" in caplog.text
@@ -593,61 +596,69 @@ class TestHttpAsyncHook:
json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"}
with pytest.raises(AirflowException, match="Unexpected HTTP Method:
NOPE"):
- await hook.run(endpoint="non_existent_endpoint", data=json)
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session,
endpoint="non_existent_endpoint", data=json)
@pytest.mark.asyncio
- async def test_async_post_request(self, aioresponse):
+ async def test_async_post_request(self):
"""Test api call asynchronously for POST request."""
hook = HttpAsyncHook()
- aioresponse.post(
- "http://test:8080/v1/test",
- status=200,
- payload='{"status":{"status": 200}}',
- reason="OK",
- )
+ with aioresponses() as m:
+ m.post(
+ "http://test:8080/v1/test",
+ status=200,
+ payload='{"status":{"status": 200}}',
+ reason="OK",
+ )
- with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
- resp = await hook.run("v1/test")
- assert resp.status == 200
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
+ async with aiohttp.ClientSession() as session:
+ resp = await hook.run(session=session, endpoint="v1/test")
+ assert resp.status == 200
@pytest.mark.asyncio
- async def test_async_post_request_with_error_code(self, aioresponse):
+ async def test_async_post_request_with_error_code(self):
"""Test api call asynchronously for POST request with error."""
hook = HttpAsyncHook()
- aioresponse.post(
- "http://test:8080/v1/test",
- status=418,
- payload='{"status":{"status": 418}}',
- reason="I am teapot",
- )
+ with aioresponses() as m:
+ m.post(
+ "http://test:8080/v1/test",
+ status=418,
+ payload='{"status":{"status": 418}}',
+ reason="I am teapot",
+ )
- with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
- with pytest.raises(AirflowException):
- await hook.run("v1/test")
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
+ async with aiohttp.ClientSession() as session:
+ with pytest.raises(AirflowException):
+ await hook.run(session=session, endpoint="v1/test")
@pytest.mark.asyncio
- async def test_async_request_uses_connection_extra(self, aioresponse):
+ async def test_async_request_uses_connection_extra(self):
"""Test api call asynchronously with a connection that has extra
field."""
connection_extra = {"bearer": "test"}
- aioresponse.post(
- "http://test:8080/v1/test",
- status=200,
- payload='{"status":{"status": 200}}',
- reason="OK",
- )
+ with aioresponses() as m:  # NOTE(review): likely unused, since aiohttp.ClientSession.post is patched below
+ m.post(
+ "http://test:8080/v1/test",
+ status=200,
+ payload='{"status":{"status": 200}}',
+ reason="OK",
+ )
- with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
- hook = HttpAsyncHook()
- with mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function:
- await hook.run("v1/test")
- headers = mocked_function.call_args.kwargs.get("headers")
- assert all(
- key in headers and headers[key] == value for key, value in
connection_extra.items()
- )
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
+ hook = HttpAsyncHook()
+ with mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function:
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session, endpoint="v1/test")
+ headers =
mocked_function.call_args.kwargs.get("headers")
+ assert all(
+ key in headers and headers[key] == value
+ for key, value in connection_extra.items()
+ )
@pytest.mark.asyncio
async def
test_async_request_uses_connection_extra_with_requests_parameters(self):
@@ -670,18 +681,29 @@ class TestHttpAsyncHook:
with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=airflow_connection):
hook = HttpAsyncHook()
- with mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function:
- await hook.run("v1/test")
- headers = mocked_function.call_args.kwargs.get("headers")
- assert all(
- key in headers and headers[key] == value for key, value in
connection_extra.items()
+
+ with aioresponses() as m:  # NOTE(review): likely unused, since aiohttp.ClientSession.post is patched below
+ m.post(
+ "http://test:8080/v1/test",
+ status=200,
+ payload='{"status":{"status": 200}}',
+ reason="OK",
)
- assert mocked_function.call_args.kwargs.get("proxy") == proxy
- assert mocked_function.call_args.kwargs.get("timeout") == 60
- assert mocked_function.call_args.kwargs.get("verify_ssl") is
False
- assert mocked_function.call_args.kwargs.get("allow_redirects")
is False
- assert mocked_function.call_args.kwargs.get("max_redirects")
== 3
- assert mocked_function.call_args.kwargs.get("trust_env") is
False
+
+ with mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function:
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session, endpoint="v1/test")
+ headers =
mocked_function.call_args.kwargs.get("headers")
+ assert all(
+ key in headers and headers[key] == value
+ for key, value in connection_extra.items()
+ )
+ assert mocked_function.call_args.kwargs.get("proxy")
== proxy
+ assert mocked_function.call_args.kwargs.get("timeout")
== 60
+ assert
mocked_function.call_args.kwargs.get("verify_ssl") is False
+ assert
mocked_function.call_args.kwargs.get("allow_redirects") is False
+ assert
mocked_function.call_args.kwargs.get("max_redirects") == 3
+ assert
mocked_function.call_args.kwargs.get("trust_env") is False
def test_process_extra_options_from_connection(self):
extra_options = {}
@@ -718,9 +740,19 @@ class TestHttpAsyncHook:
schema = conn.schema or "http" # default to http
with mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection):
hook = HttpAsyncHook()
+
+ with aioresponses() as m:  # NOTE(review): likely unused, since aiohttp.ClientSession.post is patched below
+ m.post(
+ f"{schema}://test:8080/v1/test",
+ status=200,
+ payload='{"status":{"status": 200}}',
+ reason="OK",
+ )
+
with mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function:
- await hook.run("v1/test")
- assert mocked_function.call_args.args[0] ==
f"{schema}://{conn.host}v1/test"
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session, endpoint="v1/test")
+ assert mocked_function.call_args.args[0] ==
f"{schema}://{conn.host}v1/test"
@pytest.mark.asyncio
async def test_build_request_url_from_endpoint_param(self):
@@ -728,9 +760,16 @@ class TestHttpAsyncHook:
return Connection(conn_id=conn_id, conn_type="http")
hook = HttpAsyncHook()
- with (
- mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_empty_conn),
- mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function,
- ):
- await hook.run("test.com:8080/v1/test")
- assert mocked_function.call_args.args[0] ==
"http://test.com:8080/v1/test"
+
+ with aioresponses() as m:  # NOTE(review): likely unused, since aiohttp.ClientSession.post is patched below
+ m.post(
+ "http://test.com:8080/v1/test", status=200,
payload='{"status":{"status": 200}}', reason="OK"
+ )
+
+ with (
+ mock.patch("airflow.hooks.base.BaseHook.get_connection",
side_effect=get_empty_conn),
+ mock.patch("aiohttp.ClientSession.post",
new_callable=mock.AsyncMock) as mocked_function,
+ ):
+ async with aiohttp.ClientSession() as session:
+ await hook.run(session=session,
endpoint="test.com:8080/v1/test")
+ assert mocked_function.call_args.args[0] ==
"http://test.com:8080/v1/test"