ashb commented on code in PR #62343: URL: https://github.com/apache/airflow/pull/62343#discussion_r3169346979
##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
The supervisor imports `airflow.models.connection.Connection` solely to call `test_connection()` on it. That method dispatches to the provider hook via `ProvidersManager` (core), which adds a new ORM dependency to task-sdk and is the wrong direction: task-sdk should not be pulling in core ORM models.

The right fix, which belongs in this PR since connection testing is its stated purpose: **Add `test_connection()` to `airflow.sdk.definitions.connection.Connection`**, dispatching via `ProvidersManagerTaskRuntime` (already in task-sdk for exactly this kind of hook dispatch):

```python
def test_connection(self) -> tuple[bool, str]:
    from airflow.sdk.execution_time.context import _preset_connections
    from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
    from airflow.sdk._shared.module_loading import import_string

    hook_info = ProvidersManagerTaskRuntime().hooks.get(self.conn_type)
    if not hook_info:
        return False, f"Unknown conn_type: {self.conn_type!r}"

    outer = _preset_connections.get() or {}
    token = _preset_connections.set({**outer, self.conn_id: self})
    try:
        hook = import_string(hook_info.hook_class_name)(conn_id=self.conn_id)
        if not hasattr(hook, "test_connection"):
            return False, f"Hook {type(hook).__name__} doesn't implement test_connection"
        return hook.test_connection()
    finally:
        _preset_connections.reset(token)
```

`_preset_connections` is the right injection point (checked before all secrets backends, including user-configured custom ones), but it should be encapsulated inside `test_connection()` rather than wired manually in the supervisor. That reduces the supervisor to:

```python
conn = SDKConnection(conn_id=r.conn_id, conn_type=r.conn_type, ...)
with TimeoutPosix(seconds=timeout, error_message=f"Connection test timed out after {timeout}s"):
    success, message = conn.test_connection()
```

Which also fixes the second issue here: `signal.signal(signal.SIGALRM, ...)` + `signal.alarm()` should use `TimeoutPosix` from `airflow.sdk.execution_time.timeout`. It already exists, handles non-POSIX gracefully, and uses `signal.setitimer` (sub-second capable) rather than integer-only `signal.alarm`.

##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -144,7 +145,16 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
     return Variable(**var_result.model_dump(exclude={"type"}))
 
 
+_preset_connections: ContextVar[dict[str, Connection]] = ContextVar("_preset_connections", default={})

Review Comment:
`default={}` is a single shared mutable object: every context that hasn't called `.set()` gets back the exact same dict from `.get()`. The current code never mutates it, so it doesn't blow up today, but it's a trap for anyone who writes `_preset_connections.get()[conn_id] = conn` instead of using `.set()`. Change to `default=None` and guard at call sites:

```python
_preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar("_preset_connections", default=None)

# in _get_connection:
preset = _preset_connections.get()
if preset and conn_id in preset:
    ...
```

Also: `.set({preset.conn_id: preset})` replaces the entire dict, shadowing any outer preset for the duration of the call. If `test_connection()` is ever called from within a task that has its own preset set (e.g. test tooling), the outer entries disappear until `.reset()`. Merge instead of replace:

```python
outer = _preset_connections.get() or {}
token = _preset_connections.set({**outer, self.conn_id: self})
```
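For anyone who wants to see the two `ContextVar` points from the `context.py` comment in isolation, here is a minimal standalone sketch. It uses plain strings in place of `Connection` objects, and `shared_default`, `presets`, `leak_into_default`, and `with_preset` are hypothetical names made up for illustration; none of them exist in task-sdk.

```python
from __future__ import annotations

import contextvars
from contextvars import ContextVar

# Hypothetical stand-ins for illustration only; not the Airflow objects.
shared_default: ContextVar[dict[str, str]] = ContextVar("shared_default", default={})
presets: ContextVar[dict[str, str] | None] = ContextVar("presets", default=None)


def leak_into_default() -> None:
    # No .set() has happened, so .get() hands back the one default dict.
    # Mutating it in place changes what every other context sees.
    shared_default.get()["leaked"] = "oops"


def with_preset(conn_id: str, value: str) -> dict[str, str]:
    """Merge one entry over any outer preset, then restore the outer value."""
    outer = presets.get() or {}
    token = presets.set({**outer, conn_id: value})
    try:
        # Snapshot what a lookup would see while the merged preset is active.
        return dict(presets.get() or {})
    finally:
        presets.reset(token)


# Part 1: run the mutation inside a copy of the current context.
# The write still lands in the shared default and is visible out here.
contextvars.copy_context().run(leak_into_default)
leaked_outside = "leaked" in shared_default.get()

# Part 2: an outer preset survives a nested one and is restored afterwards.
outer_token = presets.set({"outer_conn": "a"})
seen_inside = with_preset("inner_conn", "b")
seen_after = presets.get()
presets.reset(outer_token)

print(leaked_outside, seen_inside, seen_after)
```

With `default=None` there is no shared dict to mutate by accident, and because `with_preset` builds a new dict on every `.set()`, the outer mapping is never modified in place, so `.reset(token)` restores it unchanged.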
