This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

The following commit(s) were added to refs/heads/main by this push:
     new 6e3a25eccfb  AIP-72: Handle External update TI state in Supervisor (#44406)
6e3a25eccfb is described below

commit 6e3a25eccfba136dc57385718f08f5f291a7db76
Author: Kaxil Naik <kaxiln...@apache.org>
AuthorDate: Thu Nov 28 23:58:34 2024 +0000

    AIP-72: Handle External update TI state in Supervisor (#44406)

    - Updated logic to handle an externally updated TI state in the Supervisor. These states
      could have been changed externally via the UI, CLI, API, etc.
    - Replaced `FASTEST_HEARTBEAT_INTERVAL` and `SLOWEST_HEARTBEAT_INTERVAL` with
      `MIN_HEARTBEAT_INTERVAL` and `HEARTBEAT_THRESHOLD` for better clarity.

    This is part of my efforts to port LocalTaskJob tests to Supervisor:
    https://github.com/apache/airflow/issues/44356. This ports over
    `TestLocalTaskJob.test_mark_{success,failure}_no_kill`.

    This PR also allows retrying heartbeats:

    - Added `_last_successful_heartbeat` and `_last_heartbeat_attempt` to track successful
      heartbeats and retry attempts separately.
    - `MIN_HEARTBEAT_INTERVAL` is now respected between heartbeat attempts, even after failures.
    - The number of retries is configurable via `MAX_FAILED_HEARTBEATS`.
---
 .../src/airflow/sdk/execution_time/supervisor.py | 123 +++++++++++++----
 task_sdk/tests/conftest.py | 9 +-
 task_sdk/tests/execution_time/test_supervisor.py | 152 +++++++++++++++++++--
 3 files changed, 242 insertions(+), 42 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 4058e46513a..1fdcf309dff 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -31,6 +31,7 @@ import weakref from collections.abc import Generator from contextlib import suppress from datetime import datetime, timezone +from http import HTTPStatus from socket import socket, socketpair from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload from uuid import UUID @@ -42,7 +43,7 @@ import psutil import structlog from pydantic import TypeAdapter -from airflow.sdk.api.client import Client +from airflow.sdk.api.client import Client, ServerResponseError from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( DeferTask, @@ -54,6 +55,8 @@ from airflow.sdk.execution_time.comms import ( ) if TYPE_CHECKING: + from selectors import SelectorKey + from structlog.typing import FilteringBoundLogger, WrappedLogger @@ -62,9 +65,12 @@ __all__ = ["WatchedSubprocess", "supervise"] log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor") # TODO: Pull this from config -SLOWEST_HEARTBEAT_INTERVAL: int = 30 +# (previously `[scheduler] local_task_job_heartbeat_sec` with the following as fallback if it is 0: +# `[scheduler] scheduler_zombie_task_threshold`) +HEARTBEAT_THRESHOLD: int = 30 # Don't heartbeat more often than this -FASTEST_HEARTBEAT_INTERVAL: int = 5 +MIN_HEARTBEAT_INTERVAL: int = 5 +MAX_FAILED_HEARTBEATS: int = 3 @overload @@ -265,7 +271,13 @@ class WatchedSubprocess: _terminal_state: str | None = None _final_state: str | None = None - _last_heartbeat: float = 0 + _last_successful_heartbeat: float = attrs.field(default=0, init=False) + _last_heartbeat_attempt: float = attrs.field(default=0, init=False) + + # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we + # will kill the process.
This is to handle temporary network issues etc. ensuring that the process + # does not hang around forever. + failed_heartbeats: int = attrs.field(default=0, init=False) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) @@ -320,7 +332,7 @@ class WatchedSubprocess: # reason) try: client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc)) - proc._last_heartbeat = time.monotonic() + proc._last_successful_heartbeat = time.monotonic() except Exception: # On any error kill that subprocess! proc.kill(signal.SIGKILL) @@ -423,38 +435,55 @@ class WatchedSubprocess: This function: - - Polls the subprocess for output - - Sends heartbeats to the client to keep the task alive - - Checks if the subprocess has exited + - Waits for activity on file objects (e.g., subprocess stdout, stderr, logs, requests) using the selector. + - Processes events triggered on the monitored file objects, such as data availability or EOF. + - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - # Until we have a selector for the process, don't poll for more than 10s, just in case it exists but - # doesn't produce any output - max_poll_interval = 10 - while self._exit_code is None or len(self.selector.get_map()): - last_heartbeat_ago = time.monotonic() - self._last_heartbeat + last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. max_wait_time = max( 0, # Make sure this value is never negative, min( # Ensure we heartbeat _at most_ 75% through time the zombie threshold time - SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75, - max_poll_interval, + HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75, + MIN_HEARTBEAT_INTERVAL, ), ) + # Block until events are ready or the timeout is reached + # This listens for activity (e.g., subprocess output) on registered file objects events = self.selector.select(timeout=max_wait_time) - for key, _ in events: - socket_handler = key.data - need_more = socket_handler(key.fileobj) - - if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + self._process_file_object_events(events) self._check_subprocess_exit() self._send_heartbeat_if_needed() + def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]): + """ + Process selector events by invoking handlers for each file object. + + For each file object event, this method retrieves the associated handler and processes + the event. If the handler indicates that the file object no longer needs + monitoring (e.g., EOF or closed), the file object is unregistered and closed. + """ + for key, _ in events: + # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) + socket_handler = key.data + + # Example of handler behavior: + # If the subprocess writes "Hello, World!" to stdout: + # - `socket_handler` reads and processes the message. + # - If EOF is reached, the handler returns False to signal no more reads are expected. + need_more = socket_handler(key.fileobj) + + # If the handler signals that the file object is no longer needed (EOF, closed, etc.) + # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors + # are removed. 
+ if not need_more: + self.selector.unregister(key.fileobj) + key.fileobj.close() # type: ignore[union-attr] + def _check_subprocess_exit(self): """Check if the subprocess has exited.""" if self._exit_code is None: @@ -466,14 +495,48 @@ class WatchedSubprocess: def _send_heartbeat_if_needed(self): """Send a heartbeat to the client if heartbeat interval has passed.""" - if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL: - try: - self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid) - self._last_heartbeat = time.monotonic() - except Exception: - log.warning("Failed to send heartbeat", exc_info=True) - # TODO: If we couldn't heartbeat for X times the interval, kill ourselves - pass + # Respect the minimum interval between heartbeat attempts + if (time.monotonic() - self._last_heartbeat_attempt) < MIN_HEARTBEAT_INTERVAL: + return + + self._last_heartbeat_attempt = time.monotonic() + try: + self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid) + # Update the last heartbeat time on success + self._last_successful_heartbeat = time.monotonic() + + # Reset the counter on success + self.failed_heartbeats = 0 + except ServerResponseError as e: + if e.response.status_code in {HTTPStatus.NOT_FOUND, HTTPStatus.CONFLICT}: + log.error( + "Server indicated the task shouldn't be running anymore", + detail=e.detail, + status_code=e.response.status_code, + ) + self.kill(signal.SIGTERM) + else: + # If we get any other error, we'll just log it and try again next time + self._handle_heartbeat_failures() + except Exception: + self._handle_heartbeat_failures() + + def _handle_heartbeat_failures(self): + """Increment the failed heartbeats counter and kill the process if too many failures.""" + self.failed_heartbeats += 1 + log.warning( + "Failed to send heartbeat. Will be retried", + failed_heartbeats=self.failed_heartbeats, + ti_id=self.ti_id, + max_retries=MAX_FAILED_HEARTBEATS, + exc_info=True, + ) + # If we've failed to heartbeat too many times, kill the process + if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS: + log.error( + "Too many failed heartbeats; terminating process", failed_heartbeats=self.failed_heartbeats + ) + self.kill(signal.SIGTERM) @property def final_state(self): diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 9e03cb07963..f0864474092 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import logging import os from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -72,7 +73,7 @@ def test_dags_dir(): @pytest.fixture -def captured_logs(): +def captured_logs(request): import structlog from airflow.sdk.log import configure_logging, reset_logging @@ -81,6 +82,12 @@ def captured_logs(): reset_logging() configure_logging(enable_pretty_log=False) + # Get log level from test parameter, defaulting to INFO if not provided + log_level = getattr(request, "param", logging.INFO) + + # We want to capture all logs, but we don't want to see them in the test output + structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level)) + # But we need to replace remove the last processor (the one that turns JSON into text, as we want the # event dict for tests) cur_processors = structlog.get_config()["processors"] diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 3a50a73b3cf..53adf8df2e3 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -30,7 +30,6 @@ from unittest.mock import MagicMock import httpx import pytest -import structlog from uuid6 import uuid7 from airflow.sdk.api import client as sdk_client @@ -64,9 +63,6 @@ def lineno(): @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: def test_reading_from_pipes(self, captured_logs, time_machine): - # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards - structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO)) - def subprocess_main(): # This is run in the subprocess! @@ -177,9 +173,6 @@ class TestWatchedSubprocess: assert rc == -9 def test_last_chance_exception_handling(self, capfd): - # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards - structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO)) - def subprocess_main(): # The real main() in task_runner catches exceptions! This is what would happen if we had a syntax # or import error for instance - a very early exception @@ -210,7 +203,7 @@ class TestWatchedSubprocess: """Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency""" import airflow.sdk.execution_time.supervisor - monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "FASTEST_HEARTBEAT_INTERVAL", 0.1) + monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): sys.stdin.readline() @@ -241,9 +234,6 @@ class TestWatchedSubprocess: def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): """Test running a simple DAG in a subprocess and capturing the output.""" - # Ignore anything lower than INFO for this test. - structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO)) - instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) time_machine.move_to(instant, tick=False) @@ -299,6 +289,146 @@ class TestWatchedSubprocess: "previous_state": "running", } + @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"]) + def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker): + """ + Test that ensures that the Supervisor does not cause the task to fail if the Task Instance is no longer + in the running state. 
Instead, it logs the error and terminates the task process if it + might be running in a different state or has already completed -- or running on a different worker. + """ + import airflow.sdk.execution_time.supervisor + + monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) + + def subprocess_main(): + sys.stdin.readline() + sleep(5) + + ti_id = uuid7() + + # Track the number of requests to simulate mixed responses + request_count = {"count": 0} + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/heartbeat": + request_count["count"] += 1 + if request_count["count"] == 1: + # First request succeeds + return httpx.Response(status_code=204) + else: + # Second request returns a conflict status code + return httpx.Response( + 409, + json={ + "reason": "not_running", + "message": "TI is no longer in the running state and task should terminate", + "current_state": "success", + }, + ) + # Return a 204 for all other requests like the initial call to mark the task as running + return httpx.Response(status_code=204) + + proc = WatchedSubprocess.start( + path=os.devnull, + ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + client=make_client(transport=httpx.MockTransport(handle_request)), + target=subprocess_main, + ) + + # Wait for the subprocess to finish -- it should have been terminated + assert proc.wait() == -signal.SIGTERM + + # Verify the number of requests made + assert request_count["count"] == 2 + assert captured_logs == [ + { + "detail": { + "current_state": "success", + "message": "TI is no longer in the running state and task should terminate", + "reason": "not_running", + }, + "event": "Server indicated the task shouldn't be running anymore", + "level": "error", + "status_code": 409, + "logger": "supervisor", + "timestamp": mocker.ANY, + } + ] + + @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True) + def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine): + """ + Test that ensures the WatchedSubprocess kills the process after + MAX_FAILED_HEARTBEATS are exceeded. 
+ """ + max_failed_heartbeats = 3 + min_heartbeat_interval = 5 + monkeypatch.setattr( + "airflow.sdk.execution_time.supervisor.MAX_FAILED_HEARTBEATS", max_failed_heartbeats + ) + monkeypatch.setattr( + "airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", min_heartbeat_interval + ) + + mock_process = mocker.Mock() + mock_process.pid = 12345 + + # Mock the client heartbeat method to raise an exception + mock_client_heartbeat = mocker.Mock(side_effect=Exception("Simulated heartbeat failure")) + client = mocker.Mock() + client.task_instances.heartbeat = mock_client_heartbeat + + # Patch the kill method at the class level so we can assert it was called with the correct signal + mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill") + + proc = WatchedSubprocess( + ti_id=TI_ID, + pid=mock_process.pid, + stdin=mocker.MagicMock(), + client=client, + process=mock_process, + ) + + time_now = tz.datetime(2024, 11, 28, 12, 0, 0) + time_machine.move_to(time_now, tick=False) + + # Simulate sending heartbeats and ensure the process gets killed after max retries + for i in range(1, max_failed_heartbeats): + proc._send_heartbeat_if_needed() + assert proc.failed_heartbeats == i # Increment happens after failure + mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid) + + # Ensure the retry log is present + expected_log = { + "event": "Failed to send heartbeat. Will be retried", + "failed_heartbeats": i, + "ti_id": TI_ID, + "max_retries": max_failed_heartbeats, + "level": "warning", + "logger": "supervisor", + "timestamp": mocker.ANY, + "exception": mocker.ANY, + } + + assert expected_log in captured_logs + + # Advance time by `min_heartbeat_interval` to allow the next heartbeat + time_machine.shift(min_heartbeat_interval) + + # On the final failure, the process should be killed + proc._send_heartbeat_if_needed() + + assert proc.failed_heartbeats == max_failed_heartbeats + mock_kill.assert_called_once_with(signal.SIGTERM) + mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid) + assert { + "event": "Too many failed heartbeats; terminating process", + "level": "error", + "failed_heartbeats": max_failed_heartbeats, + "logger": "supervisor", + "timestamp": mocker.ANY, + } in captured_logs + class TestHandleRequest: @pytest.fixture