ashb commented on code in PR #44465:
URL: https://github.com/apache/airflow/pull/44465#discussion_r1863167664


##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -391,12 +391,55 @@ def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str],
         self.stdin.write(msg.model_dump_json().encode())
         self.stdin.write(b"\n")
 
-    def kill(self, signal: signal.Signals = signal.SIGINT):
+    def kill(
+        self,
+        signal_to_send: signal.Signals = signal.SIGINT,
+        escalation_delay: float = 5.0,
+        bypass_escalation: bool = False,
+    ):
+        """
+        Attempt to terminate the subprocess with a given signal.
+
+        If the process does not exit within `escalation_delay` seconds, escalate to SIGTERM and eventually SIGKILL if necessary.
+
+        :param signal_to_send: The signal to send initially (default is SIGINT).
+        :param escalation_delay: Time in seconds to wait before escalating to a stronger signal.
+        :param bypass_escalation: If True, send the signal directly to the process without escalation.
+        """
         if self._exit_code is not None:
             return
 
-        with suppress(ProcessLookupError):
-            os.kill(self.pid, signal)
+        if bypass_escalation:
+            with suppress(ProcessLookupError):
+                os.kill(self.pid, signal_to_send)
+            return
+
+        # Escalation sequence: SIGINT -> SIGTERM -> SIGKILL
+        escalation_path = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL]
+        if signal_to_send in escalation_path:
+            # Start from `initial_signal`
+            escalation_path = escalation_path[escalation_path.index(signal_to_send) :]
+
+        for sig in escalation_path:
+            try:
+                if sig == signal.SIGKILL:
+                    self._process.kill()
+                elif sig == signal.SIGTERM:
+                    self._process.terminate()
+                else:
+                    os.kill(self.pid, sig)
+
+                self._exit_code = self._process.wait(timeout=escalation_delay)

Review Comment:
   No, not really. The flow could be something like this:
   
   1. We send SIGINT to the process.
   2. It catches it and sends a log message.
   3. The first call to `self.selector.select` would catch that, and we would then enter `wait`.
   4. The subprocess now sends a request to, for example, update an XCom value, and blocks waiting to read the response.
   5. The supervisor is blocked in `wait`.
   6. The supervisor's wait times out, and we then send SIGKILL.
   
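To make steps 4 to 6 concrete, here is a minimal thread-based simulation. It is a sketch under assumptions and not Airflow code. A thread stands in for the task subprocess, a `socket.socketpair()` stands in for the request channel, and `child_task` and `run_supervisor` are hypothetical names.

A supervisor that only blocks in a `wait`-style call (`Thread.join` here, `Popen.wait` in the diff) never answers the request, so the child stays blocked until the timeout. A supervisor that keeps servicing the socket answers the request, and the child finishes.

```python
import selectors
import socket
import threading
import time


def child_task(conn: socket.socket) -> None:
    # Hypothetical stand-in for the task: send a request, then block on the reply.
    conn.sendall(b"set_xcom\n")
    conn.recv(1024)  # returns the reply, or b"" once the supervisor half-closes
    conn.close()


def run_supervisor(service_socket: bool, timeout: float) -> bool:
    """Return True if the simulated child finished within `timeout` seconds."""
    sup_end, child_end = socket.socketpair()
    child = threading.Thread(target=child_task, args=(child_end,), daemon=True)
    child.start()
    if not service_socket:
        # Step 5: block in a wait-style call without reading the request.
        child.join(timeout)
    else:
        sel = selectors.DefaultSelector()
        sel.register(sup_end, selectors.EVENT_READ)
        deadline = time.monotonic() + timeout
        while sel.get_map() and (remaining := deadline - time.monotonic()) > 0:
            for key, _ in sel.select(timeout=remaining):
                if key.fileobj.recv(1024):
                    key.fileobj.sendall(b"ok\n")  # answer the request
                else:
                    sel.unregister(key.fileobj)  # EOF: the child closed its end
        sel.close()
        child.join(max(0.0, deadline - time.monotonic()))
    finished = not child.is_alive()
    if not finished:
        # Half-close our end so the still-blocked child sees EOF and can exit.
        sup_end.shutdown(socket.SHUT_WR)
    child.join()
    sup_end.close()
    return finished
```

With these definitions, `run_supervisor(service_socket=False, timeout=0.2)` returns `False`, and `run_supervisor(service_socket=True, timeout=2.0)` returns `True`.
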
   Remember, the socket being closed (which also happens when the process is hard killed with kill -9) counts as an event in the selector.
   
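The following is a small standard-library illustration of that point. It is an assumption-laden sketch and not Airflow code: a `socket.socketpair()` stands in for the supervisor/task channel, and `closed_peer_wakes_selector` is a hypothetical helper.

Closing the peer end makes `selector.select()` report the socket as readable, and reading from it then yields `b""`.

```python
from __future__ import annotations

import selectors
import socket


def closed_peer_wakes_selector() -> tuple[int, int, bytes | None]:
    """Return (events before close, events after close, bytes read after close)."""
    sel = selectors.DefaultSelector()
    supervisor_end, child_end = socket.socketpair()
    try:
        sel.register(supervisor_end, selectors.EVENT_READ)
        # Nothing has been written yet, so select() simply times out.
        before = sel.select(timeout=0.05)
        # Simulate the child going away (e.g. `kill -9`): its end gets closed.
        child_end.close()
        after = sel.select(timeout=1.0)
        # A closed peer is reported as readable; recv() then returns b"" (EOF).
        data = supervisor_end.recv(1024) if after else None
        return len(before), len(after), data
    finally:
        sel.close()
        supervisor_end.close()
        child_end.close()  # no-op if already closed
```
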
   Instead of blocking in `wait`, we want to call this again:
   
   ```python
   events = self.selector.select(timeout=escalation_delay)
   self._process_file_object_events(events)

   self._check_subprocess_exit()
   ```
   
   I also suspect all three of those calls should be moved into `_process_file_object_events`, and that method renamed, so that the call looks something like:
   
   ```python
   def _service_subprocess(self, max_time: float):
       events = self.selector.select(timeout=max_time)
       for key, _ in events:
           socket_handler = key.data
           ...  # existing per-event handling from _process_file_object_events

       self._check_subprocess_exit()
   ```
   
   That way `_monitor_subprocess` and `kill` can both simply call `_service_subprocess`.

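A minimal sketch of that shape follows. It assumes POSIX signals and a child whose only channel is its stdout pipe; the real supervisor has several sockets and handlers.

`SupervisorSketch` is hypothetical. `selector`, `_service_subprocess`, `_check_subprocess_exit` and `kill` reuse the names above, but their bodies here are illustrative and are not the PR's implementation. `bypass_escalation` is also omitted.

The point it shows is that `kill` keeps calling `_service_subprocess` until each escalation deadline, rather than blocking in `self._process.wait(timeout=escalation_delay)`.

```python
from __future__ import annotations

import os
import selectors
import signal
import subprocess
import sys
import time


class SupervisorSketch:
    """Hypothetical stand-in for the supervisor; only the child's stdout is serviced."""

    def __init__(self, argv: list[str]) -> None:
        self._process = subprocess.Popen(argv, stdout=subprocess.PIPE)
        self.pid = self._process.pid
        self._exit_code: int | None = None
        self.received: list[bytes] = []
        self.selector = selectors.DefaultSelector()
        self.selector.register(self._process.stdout, selectors.EVENT_READ)

    def _service_subprocess(self, max_time: float) -> None:
        """Wait up to `max_time` for child I/O, handle it, then check for exit."""
        if self.selector.get_map():
            for key, _ in self.selector.select(timeout=max_time):
                data = os.read(key.fileobj.fileno(), 4096)
                if data:
                    self.received.append(data)  # stand-in for the real handlers
                else:
                    # EOF: the child closed its end (this also happens on kill -9).
                    self.selector.unregister(key.fileobj)
        else:
            # Nothing left to read, so a bounded wait() can no longer deadlock.
            try:
                self._process.wait(timeout=max_time)
            except subprocess.TimeoutExpired:
                pass
        self._check_subprocess_exit()

    def _check_subprocess_exit(self) -> None:
        if self._exit_code is None:
            self._exit_code = self._process.poll()

    def kill(
        self,
        signal_to_send: signal.Signals = signal.SIGINT,
        escalation_delay: float = 5.0,
    ) -> None:
        if self._exit_code is not None:
            return
        # Same escalation order as the diff: SIGINT -> SIGTERM -> SIGKILL.
        escalation_path = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL]
        if signal_to_send in escalation_path:
            escalation_path = escalation_path[escalation_path.index(signal_to_send) :]
        for sig in escalation_path:
            try:
                os.kill(self.pid, sig)
            except ProcessLookupError:
                pass  # already gone; _check_subprocess_exit will notice
            deadline = time.monotonic() + escalation_delay
            # Keep servicing the child instead of blocking in wait().
            while self._exit_code is None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                self._service_subprocess(remaining)
            if self._exit_code is not None:
                return
```

For a child that ignores SIGINT and SIGTERM, `kill(escalation_delay=0.3)` should spend roughly 0.3 s on each of those signals while still draining the pipe. It then sends SIGKILL and leaves `_exit_code == -signal.SIGKILL`.

Once every pipe has reached EOF there is nothing left to service, so the sketch falls back to a bounded `wait()` at that point.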


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
