ashb commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3169346979


##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
   The supervisor imports `airflow.models.connection.Connection` solely to call 
`test_connection()` on it. That method dispatches to the provider hook via 
`ProvidersManager` (core), so the import gives task-sdk a new dependency on a 
core ORM model. That is the wrong direction: task-sdk should not be pulling in 
core ORM models.
   
   The right fix, which belongs in this PR since connection testing is its 
stated purpose:
   
   **Add `test_connection()` to 
`airflow.sdk.definitions.connection.Connection`**, dispatching via 
`ProvidersManagerTaskRuntime` (already in task-sdk for exactly this kind of 
hook dispatch):
   
   ```python
   def test_connection(self) -> tuple[bool, str]:
       from airflow.sdk.execution_time.context import _preset_connections
       from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
       from airflow.sdk._shared.module_loading import import_string
   
       hook_info = ProvidersManagerTaskRuntime().hooks.get(self.conn_type)
       if not hook_info:
           return False, f"Unknown conn_type: {self.conn_type!r}"
       outer = _preset_connections.get() or {}
       token = _preset_connections.set({**outer, self.conn_id: self})
       try:
           hook = import_string(hook_info.hook_class_name)(conn_id=self.conn_id)
           if not hasattr(hook, "test_connection"):
               return False, f"Hook {type(hook).__name__} doesn't implement test_connection"
           return hook.test_connection()
       finally:
           _preset_connections.reset(token)
   ```
   
   `_preset_connections` is the right injection point (checked before all 
secrets backends, including user-configured custom ones), but it should be 
encapsulated inside `test_connection()` rather than wired manually in the 
supervisor.
   
   That reduces the supervisor to:
   
   ```python
   from airflow.sdk.execution_time.timeout import TimeoutPosix

   conn = SDKConnection(conn_id=r.conn_id, conn_type=r.conn_type, ...)
   with TimeoutPosix(seconds=timeout, error_message=f"Connection test timed out after {timeout}s"):
       success, message = conn.test_connection()
   ```
   
   This also fixes the second issue here: the hand-rolled 
`signal.signal(signal.SIGALRM, ...)` plus `signal.alarm()` should be replaced by 
`TimeoutPosix` from `airflow.sdk.execution_time.timeout`. It already exists, 
handles non-POSIX gracefully, and uses `signal.setitimer` (sub-second capable) 
rather than the integer-only `signal.alarm`.
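For reference, a minimal sketch of what a `setitimer`-based timeout looks like. `itimer_timeout()` is a hypothetical helper written for illustration, not the actual `TimeoutPosix` implementation; it assumes a POSIX platform and must run in the main thread, since Python only installs signal handlers and delivers signals there.

```python
from __future__ import annotations

import contextlib
import signal
from collections.abc import Iterator
from types import FrameType


@contextlib.contextmanager
def itimer_timeout(seconds: float, error_message: str = "Timeout") -> Iterator[None]:
    """Raise TimeoutError if the body runs longer than ``seconds`` (POSIX, main thread)."""

    def _handle(signum: int, frame: FrameType | None) -> None:
        raise TimeoutError(error_message)

    previous = signal.signal(signal.SIGALRM, _handle)
    # setitimer accepts a float, unlike signal.alarm(), which only takes whole seconds.
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Cancel any pending timer, then restore the handler that was there before.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous if previous is not None else signal.SIG_DFL)
```

Cancelling the timer in `finally` matters: without it, a body that finishes early could still be interrupted later by a stale `SIGALRM`.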
   
   



##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -144,7 +145,16 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
     return Variable(**var_result.model_dump(exclude={"type"}))
 
 
+_preset_connections: ContextVar[dict[str, Connection]] = ContextVar("_preset_connections", default={})

Review Comment:
   `default={}` is a single shared mutable object — every context that hasn't 
called `.set()` gets back the exact same dict from `.get()`. The current code 
never mutates it, so it doesn't blow up today, but it's a trap for anyone who 
writes `_preset_connections.get()[conn_id] = conn` instead of using `.set()`.
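The trap can be reproduced with the standard library alone (the variable `shared` and the function `careless_write()` are illustrative names): a write made inside an isolated context copy still leaks out, because every context that never called `.set()` receives the same default dict.

```python
from contextvars import ContextVar, copy_context

# Same shape as the reviewed declaration, with a plain dict instead of Connection.
shared: ContextVar[dict] = ContextVar("shared", default={})


def careless_write() -> None:
    # Mutates the default object in place instead of calling .set().
    shared.get()["leaked"] = "conn"


# Run the write in a separate copy of the current context...
copy_context().run(careless_write)

# ...yet the mutation is visible here, because both contexts share one dict.
print(shared.get())  # → {'leaked': 'conn'}
```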
   
   Change to `default=None` and guard at call sites:
   
   ```python
   _preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar("_preset_connections", default=None)
   
   # in _get_connection:
   preset = _preset_connections.get()
   if preset and conn_id in preset:
       ...
   ```
   
   Also: `.set({preset.conn_id: preset})` replaces the entire dict, shadowing 
any outer preset for the duration of the call. If `test_connection()` is ever 
called from within a task that has its own preset set (e.g. test tooling), the 
outer entries disappear until `.reset()`. Merge instead of replace:
   
   ```python
   outer = _preset_connections.get() or {}
   token = _preset_connections.set({**outer, self.conn_id: self})
   ```
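A small demonstration of replace versus merge, using strings in place of `Connection` objects; the names `presets` and `lookup()` are illustrative only.

```python
from __future__ import annotations

from contextvars import ContextVar

presets: ContextVar[dict[str, str] | None] = ContextVar("presets", default=None)


def lookup(conn_id: str) -> str | None:
    # None-safe read, matching the default=None guard suggested above.
    return (presets.get() or {}).get(conn_id)


# Simulate an outer caller (e.g. test tooling) that already set a preset.
outer_token = presets.set({"outer_conn": "outer"})

# Replace: the outer entry is hidden until reset().
token = presets.set({"tested": "inner"})
replaced_view = lookup("outer_conn")  # None
presets.reset(token)

# Merge: the outer entry stays visible alongside the tested connection.
outer = presets.get() or {}
token = presets.set({**outer, "tested": "inner"})
merged_view = lookup("outer_conn")  # "outer"
presets.reset(token)

presets.reset(outer_token)
```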
   


