This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 61412b3af83 Add HTTP retry handling into task SDK api.client (#45121)
61412b3af83 is described below
commit 61412b3af83610cf21588ef3ddaff300dce50f37
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Dec 27 23:51:46 2024 +0100
Add HTTP retry handling into task SDK api.client (#45121)
* Add HTTP retry handling into task SDK api.client
* Add logging of call failures
* Prevent task sdk tests with LocalExecutor fail with retries
* Review feedback
* Review feedback, Adjust wording
* Correct time parameters to float
* Review Feedback
---
task_sdk/pyproject.toml | 1 +
task_sdk/src/airflow/sdk/api/client.py | 28 ++++
task_sdk/tests/api/test_client.py | 157 ++++++++++++++++-----
.../commands/remote_commands/test_task_command.py | 2 +
4 files changed, 155 insertions(+), 33 deletions(-)
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index aa0271c85fc..a27f4cb7c91 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"msgspec>=0.18.6",
"psutil>=6.1.0",
"structlog>=24.4.0",
+ "retryhttp>=1.2.0",
]
classifiers = [
"Framework :: Apache Airflow",
diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 7488ef3e88a..ee4144c7f54 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -17,6 +17,8 @@
from __future__ import annotations
+import logging
+import os
import sys
import uuid
from http import HTTPStatus
@@ -26,6 +28,8 @@ import httpx
import msgspec
import structlog
from pydantic import BaseModel
+from retryhttp import retry, wait_retry_after
+from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7
from airflow.sdk import __version__
@@ -268,6 +272,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"text": "Hello, world!"})
+# Config options for SDK how retries on HTTP requests should be handled
+# Note: Given defaults make attempts after 1, 3, 7, 15, 31seconds, 1:03, 2:07, 3:37 and fails after 5:07min
+# So far there is no other config facility in SDK we use ENV for the moment
+# TODO: Consider these env variables while handling airflow confs in task sdk
+API_RETRIES = int(os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10))
+API_RETRY_WAIT_MIN = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0))
+API_RETRY_WAIT_MAX = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0))
+
+
class Client(httpx.Client):
def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
if (not base_url) ^ dry_run:
@@ -289,6 +302,21 @@ class Client(httpx.Client):
**kwargs,
)
+ _default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)
+
+ @retry(
+ reraise=True,
+ max_attempt_number=API_RETRIES,
+ wait_server_errors=_default_wait,
+ wait_network_errors=_default_wait,
+ wait_timeouts=_default_wait,
+ wait_rate_limited=wait_retry_after(fallback=_default_wait), # No infinite timeout on HTTP 429
+ before_sleep=before_log(log, logging.WARNING),
+ )
+ def request(self, *args, **kwargs):
+ """Implement a convenience for httpx.Client.request with a retry layer."""
+ return super().request(*args, **kwargs)
+
# We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
# methods on one object prefixed with the object type (`.task_instances.update` rather than
# `task_instance_update` etc.)
diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py
index 279502793ee..c52feb96766 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+from unittest import mock
import httpx
import pytest
@@ -30,18 +31,28 @@ from airflow.utils import timezone
from airflow.utils.state import TerminalTIState
-class TestClient:
- def test_error_parsing(self):
- def handle_request(request: httpx.Request) -> httpx.Response:
- """
- A transport handle that always returns errors
- """
+def make_client(transport: httpx.MockTransport) -> Client:
+ """Get a client with a custom transport"""
+ return Client(base_url="test://server", token="", transport=transport)
- return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg":
"err", "type": "required"}]})
- client = Client(
- base_url=None, dry_run=True, token="", mounts={"'http://":
httpx.MockTransport(handle_request)}
- )
+def make_client_w_responses(responses: list[httpx.Response]) -> Client:
+ """Helper fixture to create a mock client with custom responses."""
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ return responses.pop(0)
+
+ return Client(
+ base_url=None, dry_run=True, token="", mounts={"'http://":
httpx.MockTransport(handle_request)}
+ )
+
+
+class TestClient:
+ def test_error_parsing(self):
+ responses = [
+ httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err",
"type": "required"}]})
+ ]
+ client = make_client_w_responses(responses)
with pytest.raises(ServerResponseError) as err:
client.get("http://error")
@@ -53,39 +64,92 @@ class TestClient:
]
def test_error_parsing_plain_text(self):
- def handle_request(request: httpx.Request) -> httpx.Response:
- """
- A transport handle that always returns errors
- """
-
- return httpx.Response(422, content=b"Internal Server Error")
-
- client = Client(
- base_url=None, dry_run=True, token="", mounts={"'http://":
httpx.MockTransport(handle_request)}
- )
+ responses = [httpx.Response(422, content=b"Internal Server Error")]
+ client = make_client_w_responses(responses)
with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)
def test_error_parsing_other_json(self):
- def handle_request(request: httpx.Request) -> httpx.Response:
- # Some other json than an error body.
- return httpx.Response(404, json={"detail": "Not found"})
-
- client = Client(
- base_url=None, dry_run=True, token="", mounts={"'http://":
httpx.MockTransport(handle_request)}
- )
+ responses = [httpx.Response(404, json={"detail": "Not found"})]
+ client = make_client_w_responses(responses)
with pytest.raises(ServerResponseError) as err:
client.get("http://error")
assert err.value.args == ("Not found",)
assert err.value.detail is None
+ @mock.patch("time.sleep", return_value=None)
+ def test_retry_handling_unrecoverable_error(self, mock_sleep):
+ responses: list[httpx.Response] = [
+ *[httpx.Response(500, text="Internal Server Error")] * 11,
+ httpx.Response(200, json={"detail": "Recovered from error - but
will fail before"}),
+ httpx.Response(400, json={"detail": "Should not get here"}),
+ ]
+ client = make_client_w_responses(responses)
-def make_client(transport: httpx.MockTransport) -> Client:
- """Get a client with a custom transport"""
- return Client(base_url="test://server", token="", transport=transport)
+ with pytest.raises(httpx.HTTPStatusError) as err:
+ client.get("http://error")
+ assert not isinstance(err.value, ServerResponseError)
+ assert len(responses) == 3
+ assert mock_sleep.call_count == 9
+
+ @mock.patch("time.sleep", return_value=None)
+ def test_retry_handling_recovered(self, mock_sleep):
+ responses: list[httpx.Response] = [
+ *[httpx.Response(500, text="Internal Server Error")] * 3,
+ httpx.Response(200, json={"detail": "Recovered from error"}),
+ httpx.Response(400, json={"detail": "Should not get here"}),
+ ]
+ client = make_client_w_responses(responses)
+
+ response = client.get("http://error")
+ assert response.status_code == 200
+ assert len(responses) == 1
+ assert mock_sleep.call_count == 3
+
+ @mock.patch("time.sleep", return_value=None)
+ def test_retry_handling_overload(self, mock_sleep):
+ responses: list[httpx.Response] = [
+ httpx.Response(429, text="I am really busy atm, please back-off",
headers={"Retry-After": "37"}),
+ httpx.Response(200, json={"detail": "Recovered from error"}),
+ httpx.Response(400, json={"detail": "Should not get here"}),
+ ]
+ client = make_client_w_responses(responses)
+
+ response = client.get("http://error")
+ assert response.status_code == 200
+ assert len(responses) == 1
+ assert mock_sleep.call_count == 1
+ assert mock_sleep.call_args[0][0] == 37
+
+ @mock.patch("time.sleep", return_value=None)
+ def test_retry_handling_non_retry_error(self, mock_sleep):
+ responses: list[httpx.Response] = [
+ httpx.Response(422, json={"detail": "Somehow this is a bad
request"}),
+ httpx.Response(400, json={"detail": "Should not get here"}),
+ ]
+ client = make_client_w_responses(responses)
+
+ with pytest.raises(ServerResponseError) as err:
+ client.get("http://error")
+ assert len(responses) == 1
+ assert mock_sleep.call_count == 0
+ assert err.value.args == ("Somehow this is a bad request",)
+
+ @mock.patch("time.sleep", return_value=None)
+ def test_retry_handling_ok(self, mock_sleep):
+ responses: list[httpx.Response] = [
+ httpx.Response(200, json={"detail": "Recovered from error"}),
+ httpx.Response(400, json={"detail": "Should not get here"}),
+ ]
+ client = make_client_w_responses(responses)
+
+ response = client.get("http://error")
+ assert response.status_code == 200
+ assert len(responses) == 1
+ assert mock_sleep.call_count == 0
class TestTaskInstanceOperations:
@@ -95,7 +159,8 @@ class TestTaskInstanceOperations:
response parsing.
"""
- def test_task_instance_start(self, make_ti_context):
+ @mock.patch("time.sleep", return_value=None) # To have retries not
slowing down tests
+ def test_task_instance_start(self, mock_sleep, make_ti_context):
# Simulate a successful response from the server that starts a task
ti_id = uuid6.uuid7()
start_date = "2024-10-31T12:00:00Z"
@@ -105,7 +170,14 @@ class TestTaskInstanceOperations:
run_type="manual",
)
+ # ...including a validation that retry really works
+ call_count = 0
+
def handle_request(request: httpx.Request) -> httpx.Response:
+ nonlocal call_count
+ call_count += 1
+ if call_count < 4:
+ return httpx.Response(status_code=500, json={"detail":
"Internal Server Error"})
if request.url.path == f"/task-instances/{ti_id}/run":
actual_body = json.loads(request.read())
assert actual_body["pid"] == 100
@@ -120,6 +192,7 @@ class TestTaskInstanceOperations:
client = make_client(transport=httpx.MockTransport(handle_request))
resp = client.task_instances.start(ti_id, 100, start_date)
assert resp == ti_context
+ assert call_count == 4
@pytest.mark.parametrize("state", [state for state in TerminalTIState])
def test_task_instance_finish(self, state):
@@ -245,9 +318,17 @@ class TestVariableOperations:
response parsing.
"""
- def test_variable_get_success(self):
+ @mock.patch("time.sleep", return_value=None) # To have retries not
slowing down tests
+ def test_variable_get_success(self, mock_sleep):
# Simulate a successful response from the server with a variable
+ # ...including a validation that retry really works
+ call_count = 0
+
def handle_request(request: httpx.Request) -> httpx.Response:
+ nonlocal call_count
+ call_count += 1
+ if call_count < 2:
+ return httpx.Response(status_code=500, json={"detail":
"Internal Server Error"})
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=200,
@@ -261,6 +342,7 @@ class TestVariableOperations:
assert isinstance(result, VariableResponse)
assert result.key == "test_key"
assert result.value == "test_value"
+ assert call_count == 2
def test_variable_not_found(self):
# Simulate a 404 response from the server
@@ -323,9 +405,17 @@ class TestXCOMOperations:
pytest.param({"key": "test_key", "value": {"key2": "value2"}},
id="nested-dict-value"),
],
)
- def test_xcom_get_success(self, value):
+ @mock.patch("time.sleep", return_value=None) # To have retries not
slowing down tests
+ def test_xcom_get_success(self, mock_sleep, value):
# Simulate a successful response from the server when getting an xcom
+ # ...including a validation that retry really works
+ call_count = 0
+
def handle_request(request: httpx.Request) -> httpx.Response:
+ nonlocal call_count
+ call_count += 1
+ if call_count < 3:
+ return httpx.Response(status_code=500, json={"detail":
"Internal Server Error"})
if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
return httpx.Response(
status_code=201,
@@ -343,6 +433,7 @@ class TestXCOMOperations:
assert isinstance(result, XComResponse)
assert result.key == "test_key"
assert result.value == value
+ assert call_count == 3
def test_xcom_get_success_with_map_index(self):
# Simulate a successful response from the server when getting an xcom
with map_index passed
diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py
index 66177c2d84e..843d6817cdc 100644
--- a/tests/cli/commands/remote_commands/test_task_command.py
+++ b/tests/cli/commands/remote_commands/test_task_command.py
@@ -496,6 +496,8 @@ class TestCliTasks:
mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
) as get_default_mock,
+ mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued
+ mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
):
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2", executor="foo_executor_alias")