This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e3a25eccfb AIP-72: Handle External update TI state in Supervisor 
(#44406)
6e3a25eccfb is described below

commit 6e3a25eccfba136dc57385718f08f5f291a7db76
Author: Kaxil Naik <kaxiln...@apache.org>
AuthorDate: Thu Nov 28 23:58:34 2024 +0000

    AIP-72: Handle External update TI state in Supervisor (#44406)
    
    - Updated logic to handle externally updated TI state in Supervisor. This 
state could have been changed externally via the UI, CLI, API, etc.
    - Replaced `FASTEST_HEARTBEAT_INTERVAL` and `SLOWEST_HEARTBEAT_INTERVAL` 
with `MIN_HEARTBEAT_INTERVAL` and `HEARTBEAT_THRESHOLD` for better clarity
    
    This is part of my efforts to port LocalTaskJob tests to Supervisor: 
https://github.com/apache/airflow/issues/44356.
    
    This ports over `TestLocalTaskJob.test_mark_{success,failure}_no_kill`.
    
    This PR also allows retrying heartbeats:
    
    - Added `_last_successful_heartbeat` and `_last_heartbeat_attempt` for 
better separation of tracking successful heartbeats and retries.
    - `MIN_HEARTBEAT_INTERVAL` is now respected between heartbeat attempts, 
even after failures.
    - The number of retries is configurable via `MAX_FAILED_HEARTBEATS`
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 123 +++++++++++++----
 task_sdk/tests/conftest.py                         |   9 +-
 task_sdk/tests/execution_time/test_supervisor.py   | 152 +++++++++++++++++++--
 3 files changed, 242 insertions(+), 42 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 4058e46513a..1fdcf309dff 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -31,6 +31,7 @@ import weakref
 from collections.abc import Generator
 from contextlib import suppress
 from datetime import datetime, timezone
+from http import HTTPStatus
 from socket import socket, socketpair
 from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, TextIO, cast, overload
 from uuid import UUID
@@ -42,7 +43,7 @@ import psutil
 import structlog
 from pydantic import TypeAdapter
 
-from airflow.sdk.api.client import Client
+from airflow.sdk.api.client import Client, ServerResponseError
 from airflow.sdk.api.datamodels._generated import IntermediateTIState, 
TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import (
     DeferTask,
@@ -54,6 +55,8 @@ from airflow.sdk.execution_time.comms import (
 )
 
 if TYPE_CHECKING:
+    from selectors import SelectorKey
+
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
 
@@ -62,9 +65,12 @@ __all__ = ["WatchedSubprocess", "supervise"]
 log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
 
 # TODO: Pull this from config
-SLOWEST_HEARTBEAT_INTERVAL: int = 30
+#  (previously `[scheduler] local_task_job_heartbeat_sec` with the following 
as fallback if it is 0:
+#  `[scheduler] scheduler_zombie_task_threshold`)
+HEARTBEAT_THRESHOLD: int = 30
 # Don't heartbeat more often than this
-FASTEST_HEARTBEAT_INTERVAL: int = 5
+MIN_HEARTBEAT_INTERVAL: int = 5
+MAX_FAILED_HEARTBEATS: int = 3
 
 
 @overload
@@ -265,7 +271,13 @@ class WatchedSubprocess:
     _terminal_state: str | None = None
     _final_state: str | None = None
 
-    _last_heartbeat: float = 0
+    _last_successful_heartbeat: float = attrs.field(default=0, init=False)
+    _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
+
+    # After the failure of a heartbeat, we'll increment this counter. If it 
reaches `MAX_FAILED_HEARTBEATS`, we
+    # will kill the process. This is to handle temporary network issues etc. 
ensuring that the process
+    # does not hang around forever.
+    failed_heartbeats: int = attrs.field(default=0, init=False)
 
     selector: selectors.BaseSelector = 
attrs.field(factory=selectors.DefaultSelector)
 
@@ -320,7 +332,7 @@ class WatchedSubprocess:
         # reason)
         try:
             client.task_instances.start(ti.id, pid, 
datetime.now(tz=timezone.utc))
-            proc._last_heartbeat = time.monotonic()
+            proc._last_successful_heartbeat = time.monotonic()
         except Exception:
             # On any error kill that subprocess!
             proc.kill(signal.SIGKILL)
@@ -423,38 +435,55 @@ class WatchedSubprocess:
 
         This function:
 
-        - Polls the subprocess for output
-        - Sends heartbeats to the client to keep the task alive
-        - Checks if the subprocess has exited
+        - Waits for activity on file objects (e.g., subprocess stdout, stderr, 
logs, requests) using the selector.
+        - Processes events triggered on the monitored file objects, such as 
data availability or EOF.
+        - Sends heartbeats to ensure the process is alive and checks if the 
subprocess has exited.
         """
-        # Until we have a selector for the process, don't poll for more than 
10s, just in case it exists but
-        # doesn't produce any output
-        max_poll_interval = 10
-
         while self._exit_code is None or len(self.selector.get_map()):
-            last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+            last_heartbeat_ago = time.monotonic() - 
self._last_successful_heartbeat
             # Monitor the task to see if it's done. Wait in a syscall 
(`select`) for as long as possible
             # so we notice the subprocess finishing as quick as we can.
             max_wait_time = max(
                 0,  # Make sure this value is never negative,
                 min(
                     # Ensure we heartbeat _at most_ 75% through time the 
zombie threshold time
-                    SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
-                    max_poll_interval,
+                    HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75,
+                    MIN_HEARTBEAT_INTERVAL,
                 ),
             )
+            # Block until events are ready or the timeout is reached
+            # This listens for activity (e.g., subprocess output) on 
registered file objects
             events = self.selector.select(timeout=max_wait_time)
-            for key, _ in events:
-                socket_handler = key.data
-                need_more = socket_handler(key.fileobj)
-
-                if not need_more:
-                    self.selector.unregister(key.fileobj)
-                    key.fileobj.close()  # type: ignore[union-attr]
+            self._process_file_object_events(events)
 
             self._check_subprocess_exit()
             self._send_heartbeat_if_needed()
 
+    def _process_file_object_events(self, events: list[tuple[SelectorKey, 
int]]):
+        """
+        Process selector events by invoking handlers for each file object.
+
+        For each file object event, this method retrieves the associated 
handler and processes
+        the event. If the handler indicates that the file object no longer 
needs
+        monitoring (e.g., EOF or closed), the file object is unregistered and 
closed.
+        """
+        for key, _ in events:
+            # Retrieve the handler responsible for processing this file object 
(e.g., stdout, stderr)
+            socket_handler = key.data
+
+            # Example of handler behavior:
+            # If the subprocess writes "Hello, World!" to stdout:
+            # - `socket_handler` reads and processes the message.
+            # - If EOF is reached, the handler returns False to signal no more 
reads are expected.
+            need_more = socket_handler(key.fileobj)
+
+            # If the handler signals that the file object is no longer needed 
(EOF, closed, etc.)
+            # unregister it from the selector to stop monitoring; `wait()` 
blocks until all selectors
+            # are removed.
+            if not need_more:
+                self.selector.unregister(key.fileobj)
+                key.fileobj.close()  # type: ignore[union-attr]
+
     def _check_subprocess_exit(self):
         """Check if the subprocess has exited."""
         if self._exit_code is None:
@@ -466,14 +495,48 @@ class WatchedSubprocess:
 
     def _send_heartbeat_if_needed(self):
         """Send a heartbeat to the client if heartbeat interval has passed."""
-        if time.monotonic() - self._last_heartbeat >= 
FASTEST_HEARTBEAT_INTERVAL:
-            try:
-                self.client.task_instances.heartbeat(self.ti_id, 
pid=self._process.pid)
-                self._last_heartbeat = time.monotonic()
-            except Exception:
-                log.warning("Failed to send heartbeat", exc_info=True)
-                # TODO: If we couldn't heartbeat for X times the interval, 
kill ourselves
-                pass
+        # Respect the minimum interval between heartbeat attempts
+        if (time.monotonic() - self._last_heartbeat_attempt) < 
MIN_HEARTBEAT_INTERVAL:
+            return
+
+        self._last_heartbeat_attempt = time.monotonic()
+        try:
+            self.client.task_instances.heartbeat(self.ti_id, 
pid=self._process.pid)
+            # Update the last heartbeat time on success
+            self._last_successful_heartbeat = time.monotonic()
+
+            # Reset the counter on success
+            self.failed_heartbeats = 0
+        except ServerResponseError as e:
+            if e.response.status_code in {HTTPStatus.NOT_FOUND, 
HTTPStatus.CONFLICT}:
+                log.error(
+                    "Server indicated the task shouldn't be running anymore",
+                    detail=e.detail,
+                    status_code=e.response.status_code,
+                )
+                self.kill(signal.SIGTERM)
+            else:
+                # If we get any other error, we'll just log it and try again 
next time
+                self._handle_heartbeat_failures()
+        except Exception:
+            self._handle_heartbeat_failures()
+
+    def _handle_heartbeat_failures(self):
+        """Increment the failed heartbeats counter and kill the process if too 
many failures."""
+        self.failed_heartbeats += 1
+        log.warning(
+            "Failed to send heartbeat. Will be retried",
+            failed_heartbeats=self.failed_heartbeats,
+            ti_id=self.ti_id,
+            max_retries=MAX_FAILED_HEARTBEATS,
+            exc_info=True,
+        )
+        # If we've failed to heartbeat too many times, kill the process
+        if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS:
+            log.error(
+                "Too many failed heartbeats; terminating process", 
failed_heartbeats=self.failed_heartbeats
+            )
+            self.kill(signal.SIGTERM)
 
     @property
     def final_state(self):
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 9e03cb07963..f0864474092 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 import os
 from pathlib import Path
 from typing import TYPE_CHECKING, NoReturn
@@ -72,7 +73,7 @@ def test_dags_dir():
 
 
 @pytest.fixture
-def captured_logs():
+def captured_logs(request):
     import structlog
 
     from airflow.sdk.log import configure_logging, reset_logging
@@ -81,6 +82,12 @@ def captured_logs():
     reset_logging()
     configure_logging(enable_pretty_log=False)
 
+    # Get log level from test parameter, defaulting to INFO if not provided
+    log_level = getattr(request, "param", logging.INFO)
+
+    # We want to capture all logs, but we don't want to see them in the test 
output
+    
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))
+
     # But we need to replace remove the last processor (the one that turns 
JSON into text, as we want the
     # event dict for tests)
     cur_processors = structlog.get_config()["processors"]
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 3a50a73b3cf..53adf8df2e3 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -30,7 +30,6 @@ from unittest.mock import MagicMock
 
 import httpx
 import pytest
-import structlog
 from uuid6 import uuid7
 
 from airflow.sdk.api import client as sdk_client
@@ -64,9 +63,6 @@ def lineno():
 @pytest.mark.usefixtures("disable_capturing")
 class TestWatchedSubprocess:
     def test_reading_from_pipes(self, captured_logs, time_machine):
-        # Ignore anything lower than INFO for this test. Captured_logs resets 
things for us afterwards
-        
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
         def subprocess_main():
             # This is run in the subprocess!
 
@@ -177,9 +173,6 @@ class TestWatchedSubprocess:
         assert rc == -9
 
     def test_last_chance_exception_handling(self, capfd):
-        # Ignore anything lower than INFO for this test. Captured_logs resets 
things for us afterwards
-        
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
         def subprocess_main():
             # The real main() in task_runner catches exceptions! This is what 
would happen if we had a syntax
             # or import error for instance - a very early exception
@@ -210,7 +203,7 @@ class TestWatchedSubprocess:
         """Test that the WatchedSubprocess class regularly sends heartbeat 
requests, up to a certain frequency"""
         import airflow.sdk.execution_time.supervisor
 
-        monkeypatch.setattr(airflow.sdk.execution_time.supervisor, 
"FASTEST_HEARTBEAT_INTERVAL", 0.1)
+        monkeypatch.setattr(airflow.sdk.execution_time.supervisor, 
"MIN_HEARTBEAT_INTERVAL", 0.1)
 
         def subprocess_main():
             sys.stdin.readline()
@@ -241,9 +234,6 @@ class TestWatchedSubprocess:
     def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
         """Test running a simple DAG in a subprocess and capturing the 
output."""
 
-        # Ignore anything lower than INFO for this test.
-        
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
-
         instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
         time_machine.move_to(instant, tick=False)
 
@@ -299,6 +289,146 @@ class TestWatchedSubprocess:
             "previous_state": "running",
         }
 
+    @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, 
ids=["log_level=error"])
+    def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, 
mocker):
+        """
+        Test that ensures that the Supervisor does not cause the task to fail 
if the Task Instance is no longer
+        in the running state. Instead, it logs the error and terminates the 
task process if it
+        might be running in a different state or has already completed -- or 
running on a different worker.
+        """
+        import airflow.sdk.execution_time.supervisor
+
+        monkeypatch.setattr(airflow.sdk.execution_time.supervisor, 
"MIN_HEARTBEAT_INTERVAL", 0.1)
+
+        def subprocess_main():
+            sys.stdin.readline()
+            sleep(5)
+
+        ti_id = uuid7()
+
+        # Track the number of requests to simulate mixed responses
+        request_count = {"count": 0}
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/heartbeat":
+                request_count["count"] += 1
+                if request_count["count"] == 1:
+                    # First request succeeds
+                    return httpx.Response(status_code=204)
+                else:
+                    # Second request returns a conflict status code
+                    return httpx.Response(
+                        409,
+                        json={
+                            "reason": "not_running",
+                            "message": "TI is no longer in the running state 
and task should terminate",
+                            "current_state": "success",
+                        },
+                    )
+            # Return a 204 for all other requests like the initial call to 
mark the task as running
+            return httpx.Response(status_code=204)
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        # Wait for the subprocess to finish -- it should have been terminated
+        assert proc.wait() == -signal.SIGTERM
+
+        # Verify the number of requests made
+        assert request_count["count"] == 2
+        assert captured_logs == [
+            {
+                "detail": {
+                    "current_state": "success",
+                    "message": "TI is no longer in the running state and task 
should terminate",
+                    "reason": "not_running",
+                },
+                "event": "Server indicated the task shouldn't be running 
anymore",
+                "level": "error",
+                "status_code": 409,
+                "logger": "supervisor",
+                "timestamp": mocker.ANY,
+            }
+        ]
+
+    @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
+    def test_heartbeat_failures_handling(self, monkeypatch, mocker, 
captured_logs, time_machine):
+        """
+        Test that ensures the WatchedSubprocess kills the process after
+        MAX_FAILED_HEARTBEATS are exceeded.
+        """
+        max_failed_heartbeats = 3
+        min_heartbeat_interval = 5
+        monkeypatch.setattr(
+            "airflow.sdk.execution_time.supervisor.MAX_FAILED_HEARTBEATS", 
max_failed_heartbeats
+        )
+        monkeypatch.setattr(
+            "airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", 
min_heartbeat_interval
+        )
+
+        mock_process = mocker.Mock()
+        mock_process.pid = 12345
+
+        # Mock the client heartbeat method to raise an exception
+        mock_client_heartbeat = mocker.Mock(side_effect=Exception("Simulated 
heartbeat failure"))
+        client = mocker.Mock()
+        client.task_instances.heartbeat = mock_client_heartbeat
+
+        # Patch the kill method at the class level so we can assert it was 
called with the correct signal
+        mock_kill = 
mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
+
+        proc = WatchedSubprocess(
+            ti_id=TI_ID,
+            pid=mock_process.pid,
+            stdin=mocker.MagicMock(),
+            client=client,
+            process=mock_process,
+        )
+
+        time_now = tz.datetime(2024, 11, 28, 12, 0, 0)
+        time_machine.move_to(time_now, tick=False)
+
+        # Simulate sending heartbeats and ensure the process gets killed after 
max retries
+        for i in range(1, max_failed_heartbeats):
+            proc._send_heartbeat_if_needed()
+            assert proc.failed_heartbeats == i  # Increment happens after 
failure
+            mock_client_heartbeat.assert_called_with(TI_ID, 
pid=mock_process.pid)
+
+            # Ensure the retry log is present
+            expected_log = {
+                "event": "Failed to send heartbeat. Will be retried",
+                "failed_heartbeats": i,
+                "ti_id": TI_ID,
+                "max_retries": max_failed_heartbeats,
+                "level": "warning",
+                "logger": "supervisor",
+                "timestamp": mocker.ANY,
+                "exception": mocker.ANY,
+            }
+
+            assert expected_log in captured_logs
+
+            # Advance time by `min_heartbeat_interval` to allow the next 
heartbeat
+            time_machine.shift(min_heartbeat_interval)
+
+        # On the final failure, the process should be killed
+        proc._send_heartbeat_if_needed()
+
+        assert proc.failed_heartbeats == max_failed_heartbeats
+        mock_kill.assert_called_once_with(signal.SIGTERM)
+        mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid)
+        assert {
+            "event": "Too many failed heartbeats; terminating process",
+            "level": "error",
+            "failed_heartbeats": max_failed_heartbeats,
+            "logger": "supervisor",
+            "timestamp": mocker.ANY,
+        } in captured_logs
+
 
 class TestHandleRequest:
     @pytest.fixture

Reply via email to