This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6e3a25eccfb AIP-72: Handle External update TI state in Supervisor (#44406)
6e3a25eccfb is described below
commit 6e3a25eccfba136dc57385718f08f5f291a7db76
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Nov 28 23:58:34 2024 +0000
AIP-72: Handle External update TI state in Supervisor (#44406)
- Updated logic to handle externally updated TI state in Supervisor. These states could have been changed externally via the UI, CLI, API, etc. (see the sketch after this list)
- Replaced `FASTEST_HEARTBEAT_INTERVAL` and `SLOWEST_HEARTBEAT_INTERVAL` with `MIN_HEARTBEAT_INTERVAL` and `HEARTBEAT_THRESHOLD` for better clarity
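Concretely, a 404 or 409 from the heartbeat endpoint now means "this task should no longer be running", and the supervisor terminates the process instead of retrying. A minimal sketch of that decision, using a hypothetical `handle_heartbeat_error` helper (the committed code inlines this in `WatchedSubprocess._send_heartbeat_if_needed`, see the diff below):

    import signal
    from http import HTTPStatus

    def handle_heartbeat_error(proc, error) -> None:
        """React to a failed heartbeat: terminate on 404/409, otherwise retry."""
        status = error.response.status_code
        if status in (HTTPStatus.NOT_FOUND, HTTPStatus.CONFLICT):
            # The TI state was changed externally (UI, CLI, API, or another
            # worker picked it up): stop the process rather than let it run on.
            proc.kill(signal.SIGTERM)
        else:
            # Any other error is treated as transient: count it and retry
            # on the next heartbeat attempt.
            proc._handle_heartbeat_failures()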
This is part of my efforts to port LocalTaskJob tests to Supervisor:
https://github.com/apache/airflow/issues/44356.
This ports over `TestLocalTaskJob.test_mark_{success,failure}_no_kill`.
This PR also allows retrying heartbeats:
- Added `_last_successful_heartbeat` and `_last_heartbeat_attempt` to track successful heartbeats and retry attempts separately.
- `MIN_HEARTBEAT_INTERVAL` is now respected between heartbeat attempts, even after failures.
- The number of retries is configurable via `MAX_FAILED_HEARTBEATS` (sketched below).
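In outline, the retry pacing works like the sketch below. This is a simplified, illustrative stand-in (the `HeartbeatMonitor`/`tick` names are hypothetical); the committed logic lives in `WatchedSubprocess._send_heartbeat_if_needed` in the diff that follows:

    import signal
    import time

    MIN_HEARTBEAT_INTERVAL = 5  # seconds between heartbeat *attempts*
    MAX_FAILED_HEARTBEATS = 3   # consecutive failures tolerated before killing

    class HeartbeatMonitor:
        """Simplified stand-in for the supervisor's heartbeat bookkeeping."""

        def __init__(self, send_heartbeat, kill):
            self._send_heartbeat = send_heartbeat  # e.g. client.task_instances.heartbeat
            self._kill = kill                      # e.g. WatchedSubprocess.kill
            self._last_successful_heartbeat = 0.0
            self._last_heartbeat_attempt = 0.0
            self.failed_heartbeats = 0

        def tick(self):
            # Respect the minimum interval between attempts, even after failures
            if time.monotonic() - self._last_heartbeat_attempt < MIN_HEARTBEAT_INTERVAL:
                return
            self._last_heartbeat_attempt = time.monotonic()
            try:
                self._send_heartbeat()
                self._last_successful_heartbeat = time.monotonic()
                self.failed_heartbeats = 0  # reset the failure counter on success
            except Exception:
                self.failed_heartbeats += 1
                if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS:
                    self._kill(signal.SIGTERM)

Tracking `_last_heartbeat_attempt` separately from `_last_successful_heartbeat` is what lets the supervisor rate-limit retries without pushing back the zombie-detection clock.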
---
.../src/airflow/sdk/execution_time/supervisor.py | 123 +++++++++++++----
task_sdk/tests/conftest.py | 9 +-
task_sdk/tests/execution_time/test_supervisor.py | 152 +++++++++++++++++++--
3 files changed, 242 insertions(+), 42 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 4058e46513a..1fdcf309dff 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -31,6 +31,7 @@ import weakref
from collections.abc import Generator
from contextlib import suppress
from datetime import datetime, timezone
+from http import HTTPStatus
from socket import socket, socketpair
from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload
from uuid import UUID
@@ -42,7 +43,7 @@ import psutil
import structlog
from pydantic import TypeAdapter
-from airflow.sdk.api.client import Client
+from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
DeferTask,
@@ -54,6 +55,8 @@ from airflow.sdk.execution_time.comms import (
)
if TYPE_CHECKING:
+ from selectors import SelectorKey
+
from structlog.typing import FilteringBoundLogger, WrappedLogger
@@ -62,9 +65,12 @@ __all__ = ["WatchedSubprocess", "supervise"]
log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
# TODO: Pull this from config
-SLOWEST_HEARTBEAT_INTERVAL: int = 30
+# (previously `[scheduler] local_task_job_heartbeat_sec` with the following as fallback if it is 0:
+# `[scheduler] scheduler_zombie_task_threshold`)
+HEARTBEAT_THRESHOLD: int = 30
# Don't heartbeat more often than this
-FASTEST_HEARTBEAT_INTERVAL: int = 5
+MIN_HEARTBEAT_INTERVAL: int = 5
+MAX_FAILED_HEARTBEATS: int = 3
@overload
@@ -265,7 +271,13 @@ class WatchedSubprocess:
_terminal_state: str | None = None
_final_state: str | None = None
- _last_heartbeat: float = 0
+ _last_successful_heartbeat: float = attrs.field(default=0, init=False)
+ _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
+
+ # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we
+ # will kill the process. This is to handle temporary network issues etc., ensuring that the process
+ # does not hang around forever.
+ failed_heartbeats: int = attrs.field(default=0, init=False)
selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
@@ -320,7 +332,7 @@ class WatchedSubprocess:
# reason)
try:
client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
- proc._last_heartbeat = time.monotonic()
+ proc._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
proc.kill(signal.SIGKILL)
@@ -423,38 +435,55 @@ class WatchedSubprocess:
This function:
- - Polls the subprocess for output
- - Sends heartbeats to the client to keep the task alive
- - Checks if the subprocess has exited
+ - Waits for activity on file objects (e.g., subprocess stdout, stderr, logs, requests) using the selector.
+ - Processes events triggered on the monitored file objects, such as data availability or EOF.
+ - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited.
"""
- # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
- # doesn't produce any output
- max_poll_interval = 10
-
while self._exit_code is None or len(self.selector.get_map()):
- last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+ last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
# so we notice the subprocess finishing as quickly as we can.
max_wait_time = max(
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% through the zombie threshold time
- SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
- max_poll_interval,
+ HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75,
+ MIN_HEARTBEAT_INTERVAL,
),
)
+ # Block until events are ready or the timeout is reached
+ # This listens for activity (e.g., subprocess output) on registered file objects
events = self.selector.select(timeout=max_wait_time)
- for key, _ in events:
- socket_handler = key.data
- need_more = socket_handler(key.fileobj)
-
- if not need_more:
- self.selector.unregister(key.fileobj)
- key.fileobj.close() # type: ignore[union-attr]
+ self._process_file_object_events(events)
self._check_subprocess_exit()
self._send_heartbeat_if_needed()
+ def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]):
+ """
+ Process selector events by invoking handlers for each file object.
+
+ For each file object event, this method retrieves the associated handler and processes
+ the event. If the handler indicates that the file object no longer needs
+ monitoring (e.g., EOF or closed), the file object is unregistered and closed.
+ """
+ for key, _ in events:
+ # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
+ socket_handler = key.data
+
+ # Example of handler behavior:
+ # If the subprocess writes "Hello, World!" to stdout:
+ # - `socket_handler` reads and processes the message.
+ # - If EOF is reached, the handler returns False to signal no more reads are expected.
+ need_more = socket_handler(key.fileobj)
+
+ # If the handler signals that the file object is no longer needed (EOF, closed, etc.),
+ # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors
+ # are removed.
+ if not need_more:
+ self.selector.unregister(key.fileobj)
+ key.fileobj.close() # type: ignore[union-attr]
+
def _check_subprocess_exit(self):
"""Check if the subprocess has exited."""
if self._exit_code is None:
@@ -466,14 +495,48 @@ class WatchedSubprocess:
def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
- if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL:
- try:
- self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
- self._last_heartbeat = time.monotonic()
- except Exception:
- log.warning("Failed to send heartbeat", exc_info=True)
- # TODO: If we couldn't heartbeat for X times the interval, kill ourselves
- pass
+ # Respect the minimum interval between heartbeat attempts
+ if (time.monotonic() - self._last_heartbeat_attempt) < MIN_HEARTBEAT_INTERVAL:
+ return
+
+ self._last_heartbeat_attempt = time.monotonic()
+ try:
+ self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
+ # Update the last heartbeat time on success
+ self._last_successful_heartbeat = time.monotonic()
+
+ # Reset the counter on success
+ self.failed_heartbeats = 0
+ except ServerResponseError as e:
+ if e.response.status_code in {HTTPStatus.NOT_FOUND, HTTPStatus.CONFLICT}:
+ log.error(
+ "Server indicated the task shouldn't be running anymore",
+ detail=e.detail,
+ status_code=e.response.status_code,
+ )
+ self.kill(signal.SIGTERM)
+ else:
+ # If we get any other error, we'll just log it and try again next time
+ self._handle_heartbeat_failures()
+ except Exception:
+ self._handle_heartbeat_failures()
+
+ def _handle_heartbeat_failures(self):
+ """Increment the failed heartbeats counter and kill the process if too
many failures."""
+ self.failed_heartbeats += 1
+ log.warning(
+ "Failed to send heartbeat. Will be retried",
+ failed_heartbeats=self.failed_heartbeats,
+ ti_id=self.ti_id,
+ max_retries=MAX_FAILED_HEARTBEATS,
+ exc_info=True,
+ )
+ # If we've failed to heartbeat too many times, kill the process
+ if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS:
+ log.error(
+ "Too many failed heartbeats; terminating process",
failed_heartbeats=self.failed_heartbeats
+ )
+ self.kill(signal.SIGTERM)
@property
def final_state(self):
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 9e03cb07963..f0864474092 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn
@@ -72,7 +73,7 @@ def test_dags_dir():
@pytest.fixture
-def captured_logs():
+def captured_logs(request):
import structlog
from airflow.sdk.log import configure_logging, reset_logging
@@ -81,6 +82,12 @@ def captured_logs():
reset_logging()
configure_logging(enable_pretty_log=False)
+ # Get log level from test parameter, defaulting to INFO if not provided
+ log_level = getattr(request, "param", logging.INFO)
+
+ # We want to capture all logs, but we don't want to see them in the test output
+ structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))
+
# But we need to remove the last processor (the one that turns JSON into text, as we want the
# event dict for tests)
cur_processors = structlog.get_config()["processors"]
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 3a50a73b3cf..53adf8df2e3 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -30,7 +30,6 @@ from unittest.mock import MagicMock
import httpx
import pytest
-import structlog
from uuid6 import uuid7
from airflow.sdk.api import client as sdk_client
@@ -64,9 +63,6 @@ def lineno():
@pytest.mark.usefixtures("disable_capturing")
class TestWatchedSubprocess:
def test_reading_from_pipes(self, captured_logs, time_machine):
- # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
- structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
def subprocess_main():
# This is run in the subprocess!
@@ -177,9 +173,6 @@ class TestWatchedSubprocess:
assert rc == -9
def test_last_chance_exception_handling(self, capfd):
- # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
- structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
def subprocess_main():
# The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
# or import error for instance - a very early exception
@@ -210,7 +203,7 @@ class TestWatchedSubprocess:
"""Test that the WatchedSubprocess class regularly sends heartbeat
requests, up to a certain frequency"""
import airflow.sdk.execution_time.supervisor
- monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "FASTEST_HEARTBEAT_INTERVAL", 0.1)
+ monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)
def subprocess_main():
sys.stdin.readline()
@@ -241,9 +234,6 @@ class TestWatchedSubprocess:
def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"""Test running a simple DAG in a subprocess and capturing the
output."""
- # Ignore anything lower than INFO for this test.
- structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
time_machine.move_to(instant, tick=False)
@@ -299,6 +289,146 @@ class TestWatchedSubprocess:
"previous_state": "running",
}
+ @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True,
ids=["log_level=error"])
+ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker):
+ """
+ Test that the Supervisor does not fail the task when the Task Instance is no longer in the
+ running state. Instead, it logs the error and terminates the task process, since the TI may be
+ in a different state, may have already completed, or may be running on a different worker.
+ """
+ import airflow.sdk.execution_time.supervisor
+
+ monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)
+
+ def subprocess_main():
+ sys.stdin.readline()
+ sleep(5)
+
+ ti_id = uuid7()
+
+ # Track the number of requests to simulate mixed responses
+ request_count = {"count": 0}
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/heartbeat":
+ request_count["count"] += 1
+ if request_count["count"] == 1:
+ # First request succeeds
+ return httpx.Response(status_code=204)
+ else:
+ # Second request returns a conflict status code
+ return httpx.Response(
+ 409,
+ json={
+ "reason": "not_running",
+ "message": "TI is no longer in the running state
and task should terminate",
+ "current_state": "success",
+ },
+ )
+ # Return a 204 for all other requests like the initial call to mark the task as running
+ return httpx.Response(status_code=204)
+
+ proc = WatchedSubprocess.start(
+ path=os.devnull,
+ ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
+ client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
+ # Wait for the subprocess to finish -- it should have been terminated
+ assert proc.wait() == -signal.SIGTERM
+
+ # Verify the number of requests made
+ assert request_count["count"] == 2
+ assert captured_logs == [
+ {
+ "detail": {
+ "current_state": "success",
+ "message": "TI is no longer in the running state and task
should terminate",
+ "reason": "not_running",
+ },
+ "event": "Server indicated the task shouldn't be running
anymore",
+ "level": "error",
+ "status_code": 409,
+ "logger": "supervisor",
+ "timestamp": mocker.ANY,
+ }
+ ]
+
+ @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
+ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine):
+ """
+ Test that the WatchedSubprocess kills the process once MAX_FAILED_HEARTBEATS is exceeded.
+ """
+ max_failed_heartbeats = 3
+ min_heartbeat_interval = 5
+ monkeypatch.setattr(
+ "airflow.sdk.execution_time.supervisor.MAX_FAILED_HEARTBEATS",
max_failed_heartbeats
+ )
+ monkeypatch.setattr(
+ "airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL",
min_heartbeat_interval
+ )
+
+ mock_process = mocker.Mock()
+ mock_process.pid = 12345
+
+ # Mock the client heartbeat method to raise an exception
+ mock_client_heartbeat = mocker.Mock(side_effect=Exception("Simulated
heartbeat failure"))
+ client = mocker.Mock()
+ client.task_instances.heartbeat = mock_client_heartbeat
+
+ # Patch the kill method at the class level so we can assert it was called with the correct signal
+ mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
+
+ proc = WatchedSubprocess(
+ ti_id=TI_ID,
+ pid=mock_process.pid,
+ stdin=mocker.MagicMock(),
+ client=client,
+ process=mock_process,
+ )
+
+ time_now = tz.datetime(2024, 11, 28, 12, 0, 0)
+ time_machine.move_to(time_now, tick=False)
+
+ # Simulate sending heartbeats and ensure the process gets killed after max retries
+ for i in range(1, max_failed_heartbeats):
+ proc._send_heartbeat_if_needed()
+ assert proc.failed_heartbeats == i  # Increment happens after failure
+ mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid)
+
+ # Ensure the retry log is present
+ expected_log = {
+ "event": "Failed to send heartbeat. Will be retried",
+ "failed_heartbeats": i,
+ "ti_id": TI_ID,
+ "max_retries": max_failed_heartbeats,
+ "level": "warning",
+ "logger": "supervisor",
+ "timestamp": mocker.ANY,
+ "exception": mocker.ANY,
+ }
+
+ assert expected_log in captured_logs
+
+ # Advance time by `min_heartbeat_interval` to allow the next heartbeat
+ time_machine.shift(min_heartbeat_interval)
+
+ # On the final failure, the process should be killed
+ proc._send_heartbeat_if_needed()
+
+ assert proc.failed_heartbeats == max_failed_heartbeats
+ mock_kill.assert_called_once_with(signal.SIGTERM)
+ mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid)
+ assert {
+ "event": "Too many failed heartbeats; terminating process",
+ "level": "error",
+ "failed_heartbeats": max_failed_heartbeats,
+ "logger": "supervisor",
+ "timestamp": mocker.ANY,
+ } in captured_logs
+
class TestHandleRequest:
@pytest.fixture