ashb commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3169346979


##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
   The supervisor imports `airflow.models.connection.Connection` solely to call 
`test_connection()` on it. That method dispatches to the provider hook via 
`ProvidersManager` (core), which adds a new ORM dependency to task-sdk and is 
the wrong direction — task-sdk should not be pulling in core ORM models.
   
   The right fix, which belongs in this PR since connection testing is its 
stated purpose:
   
   **Add `test_connection()` to 
`airflow.sdk.definitions.connection.Connection`**, dispatching via 
`ProvidersManagerTaskRuntime` (already in task-sdk for exactly this kind of 
hook dispatch):
   
   ```python
   def test_connection(self) -> tuple[bool, str]:
       from airflow.sdk.execution_time.context import _preset_connections
       from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
       from airflow.sdk._shared.module_loading import import_string
   
       hook_info = ProvidersManagerTaskRuntime().hooks.get(self.conn_type)
       if not hook_info:
           return False, f"Unknown conn_type: {self.conn_type!r}"
       outer = _preset_connections.get() or {}
       token = _preset_connections.set({**outer, self.conn_id: self})
       try:
           hook = import_string(hook_info.hook_class_name)(conn_id=self.conn_id)
           if not hasattr(hook, "test_connection"):
               return False, f"Hook {type(hook).__name__} doesn't implement test_connection"
           return hook.test_connection()
       finally:
           _preset_connections.reset(token)
   ```
   
   `_preset_connections` is the right injection point (checked before all 
secrets backends, including user-configured custom ones), but it should be 
encapsulated inside `test_connection()` rather than wired manually in the 
supervisor.
   
   That reduces the supervisor to:
   
   ```python
   conn = SDKConnection(conn_id=r.conn_id, conn_type=r.conn_type, ...)
   with TimeoutPosix(seconds=timeout, error_message=f"Connection test timed out after {timeout}s"):
       success, message = conn.test_connection()
   ```
   
   This also fixes the second issue here: the hand-rolled `signal.signal(signal.SIGALRM, ...)` + `signal.alarm()` pair should be replaced with `TimeoutPosix` from `airflow.sdk.execution_time.timeout`. It already exists, handles non-POSIX platforms gracefully, and uses `signal.setitimer` (sub-second capable) rather than the integer-only `signal.alarm`.
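
To illustrate the `setitimer` point, here is a minimal stdlib-only sketch. It is not Airflow's `TimeoutPosix`; `timeout_after` and `run_with_timeout` are hypothetical names, and the sketch assumes a POSIX platform and the main thread:

```python
import signal
import time
from contextlib import contextmanager


@contextmanager
def timeout_after(seconds: float, message: str = "timed out"):
    """Hypothetical helper: raise TimeoutError if the body runs longer than `seconds`."""

    def _on_alarm(signum, frame):
        raise TimeoutError(message)

    previous = signal.signal(signal.SIGALRM, _on_alarm)
    # setitimer accepts floats; signal.alarm() only accepts whole seconds.
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel any pending timer
        signal.signal(signal.SIGALRM, previous)  # restore the previous handler


def run_with_timeout(seconds: float, work_seconds: float) -> str:
    """Sleep for `work_seconds` under a `seconds` budget and report what happened."""
    try:
        with timeout_after(seconds, f"timed out after {seconds}s"):
            time.sleep(work_seconds)
        return "finished"
    except TimeoutError as exc:
        return str(exc)
```

A 50 ms budget like `run_with_timeout(0.05, 1.0)` cannot be expressed with `signal.alarm`, which only takes whole seconds.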
   
   ---
   Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting
   



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py:
##########
@@ -259,6 +284,87 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse:
         os.environ.pop(conn_env_var, None)
 
 
+@connections_router.post(
+    "/test-async",
+    status_code=status.HTTP_202_ACCEPTED,
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_403_FORBIDDEN,
+            status.HTTP_409_CONFLICT,
+            status.HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    dependencies=[Depends(requires_access_connection(method="POST")), Depends(action_logging())],
+)
+def test_connection_async(
+    test_body: ConnectionTestRequestBody,
+    session: SessionDep,
+) -> ConnectionTestQueuedResponse:
+    """
+    Queue an async connection test to be executed on a worker.
+
+    The connection data is stored in the test request table and the worker
+    reads from there. Returns a token to poll for the result via
+    GET /connections/test-async/{token}.
+    """
+    _ensure_test_connection_enabled()
+    _ensure_executor_is_configured(test_body.executor)
+
+    connection_test = ConnectionTestRequest(
+        connection_id=test_body.connection_id,
+        conn_type=test_body.conn_type,
+        host=test_body.host,
+        login=test_body.login,
+        password=test_body.password,
+        schema=test_body.schema_,
+        port=test_body.port,
+        extra=test_body.extra,
+        commit_on_success=test_body.commit_on_success,
+        executor=test_body.executor,
+        queue=test_body.queue,
+    )
+    session.add(connection_test)

Review Comment:
   The insert queued by `session.add(connection_test)` will raise an `IntegrityError` at flush/commit time if there is already an active test for the same `connection_id` (the `uq_connection_test_request_active_conn` unique constraint on `active_connection_id`). The endpoint documents a 409 response, but nothing catches the DB error, so the caller gets a 500 instead.
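
The shape of the fix can be sketched with the stdlib alone. This is an analogue under stated assumptions, not the endpoint code: `sqlite3` stands in for the SQLAlchemy session, `Conflict` is a hypothetical stand-in for an HTTP 409 response, and the table is reduced to the one uniquely constrained column:

```python
from __future__ import annotations

import sqlite3


class Conflict(Exception):
    """Hypothetical stand-in for an HTTP 409 response."""

    status_code = 409


def enqueue_test(db: sqlite3.Connection, connection_id: str) -> int | None:
    """Insert a test request, translating a unique-constraint violation into a Conflict."""
    try:
        with db:  # commits on success, rolls back on error
            cur = db.execute(
                "INSERT INTO connection_test_request (active_connection_id) VALUES (?)",
                (connection_id,),
            )
    except sqlite3.IntegrityError as exc:
        # Without this translation the raw DB error would surface as a server error.
        raise Conflict(f"A test for connection {connection_id!r} is already active") from exc
    return cur.lastrowid


db = sqlite3.connect(":memory:")
db.execute(
    "CREATE TABLE connection_test_request ("
    " id INTEGER PRIMARY KEY,"
    " active_connection_id TEXT UNIQUE)"
)
```

In the real endpoint the equivalent would presumably be to flush inside a `try` and raise an `HTTPException` with `status.HTTP_409_CONFLICT` when the `IntegrityError` is caught.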



##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py:
##########
@@ -78,12 +79,49 @@ class ConnectionCollectionResponse(BaseModel):
 
 
 class ConnectionTestResponse(BaseModel):
-    """Connection Test serializer for responses."""
+    """Connection Test serializer for synchronous test responses."""
 
     status: bool
     message: str
 
 
+class ConnectionTestRequestBody(StrictBaseModel):
+    """Request body for async connection test."""
+
+    connection_id: str
+    conn_type: str
+    host: str | None = None
+    login: str | None = None
+    schema_: str | None = Field(None, alias="schema")
+    port: int | None = None
+    password: str | None = None
+    extra: str | None = None
+    commit_on_success: bool = Field(

Review Comment:
   `ConnectionTestRequestBody` duplicates almost all of `ConnectionBody` (same 
`connection_id`, `conn_type`, `host`, `login`, `schema_`, `port`, `password`, 
`extra` fields). Use `ConnectionBody` as a component and extend it:
   
   ```python
   class ConnectionTestRequestBody(ConnectionBody):
       commit_on_success: bool = Field(default=False, description="...")
       executor: str | None = None
       queue: str | None = None
   ```
   
   `ConnectionBody` already has the `extra` JSON validator, the `connection_id` 
pattern constraint, and the standard field definitions. Duplicating those 
creates drift — e.g. `ConnectionBody.connection_id` has `pattern=r"^[\w.-]+$"` 
but `ConnectionTestRequestBody` drops it.
   
   We should also exclude `commit_on_success` from this schema, as it makes no sense on the TaskSDK side. I think we've done something like this elsewhere; there's a pydantic setting/flag we can set in `Field()`.
   



##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
   🚨 We cannot introduce new dependencies from TaskSDK on core Airflow. We will need to do something else, like removing `test_connection()` from `models.connection` and moving it into the SDK version.



##########
task-sdk/src/airflow/sdk/api/client.py:
##########
@@ -872,6 +874,25 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
         return HITLDetailResponse.model_validate_json(resp.read())
 
 
+class ConnectionTestOperations:
+    __slots__ = ("client",)
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionResponse:
+        """Fetch connection data for a test request from the API server."""
+        resp = self.client.get(f"connection-tests/{connection_test_id}/connection")
+        return ConnectionResponse.model_validate_json(resp.read())

Review Comment:
   Return type is `ConnectionResponse` but the endpoint (`GET 
/connection-tests/{id}/connection`) returns `ConnectionTestConnectionResponse`. 
Both models happen to have the same fields today so it works at runtime, but 
the annotation is wrong and will silently break if the two models ever diverge.
   
   ```suggestion
       def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionTestConnectionResponse:
           """Fetch connection data for a test request from the API server."""
           resp = self.client.get(f"connection-tests/{connection_test_id}/connection")
           return ConnectionTestConnectionResponse.model_validate_json(resp.read())
   ```
   
   ---
   Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting
   



##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -144,7 +145,16 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
     return Variable(**var_result.model_dump(exclude={"type"}))
 
 
+_preset_connections: ContextVar[dict[str, Connection]] = ContextVar("_preset_connections", default={})

Review Comment:
   `default={}` is a single shared mutable object — every context that hasn't 
called `.set()` gets back the exact same dict from `.get()`. The current code 
never mutates it, so it doesn't blow up today, but it's a trap for anyone who 
writes `_preset_connections.get()[conn_id] = conn` instead of using `.set()`.
   
   Change to `default=None` and guard at call sites:
   
   ```python
   _preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar("_preset_connections", default=None)
   
   # in _get_connection:
   preset = _preset_connections.get()
   if preset and conn_id in preset:
       ...
   ```
   
   Also: `.set({preset.conn_id: preset})` replaces the entire dict, shadowing 
any outer preset for the duration of the call. If `test_connection()` is ever 
called from within a task that has its own preset set (e.g. test tooling), the 
outer entries disappear until `.reset()`. Merge instead of replace:
   
   ```python
   outer = _preset_connections.get() or {}
   token = _preset_connections.set({**outer, self.conn_id: self})
   ```
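
Both points can be reproduced with `contextvars` alone. The following is a minimal sketch in which `shared_default`, `safe_default`, and the helper functions are hypothetical stand-ins, and plain strings take the place of `Connection` objects:

```python
from __future__ import annotations

from contextvars import ContextVar, copy_context

# Hypothetical stand-ins for the real variable.
shared_default: ContextVar[dict[str, str]] = ContextVar("shared_default", default={})
safe_default: ContextVar[dict[str, str] | None] = ContextVar("safe_default", default=None)


def leak_into_default() -> None:
    # Mutates the one dict object that every context without a .set() shares.
    shared_default.get()["db"] = "leaked"


def merged_preset(conn_id: str, conn: str) -> dict[str, str]:
    # Merge with any outer preset rather than replacing it, then restore on exit.
    outer = safe_default.get() or {}
    token = safe_default.set({**outer, conn_id: conn})
    try:
        return dict(safe_default.get() or {})
    finally:
        safe_default.reset(token)


# The mutation happens in one context...
copy_context().run(leak_into_default)
# ...and is visible from a completely separate one.
leaked = copy_context().run(shared_default.get)

outer_token = safe_default.set({"outer": "a"})
inside = merged_preset("inner", "b")  # sees both the outer and the inner entry
after = safe_default.get()  # the outer preset is intact again
safe_default.reset(outer_token)
```

Here `leaked` contains the `"db"` entry even though neither context ever called `.set()`, while `inside` keeps the outer entry alongside the new one.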
   
   ---
   Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting
   



##########
airflow-core/src/airflow/api_fastapi/execution_api/security.py:
##########
@@ -190,6 +190,14 @@ async def require_auth(
                 detail="Token subject does not match task instance ID",
             )
 
+    if "ct:self" in security_scopes.scopes:

Review Comment:
   ```suggestion
       elif "ct:self" in security_scopes.scopes:
   ```
   I think? `ti:self` and `ct:self` can't co-exist?



##########
airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml:
##########
@@ -1844,6 +1844,115 @@ paths:
       security:
       - OAuth2PasswordBearer: []
       - HTTPBearer: []
+  /api/v2/connections/test-async:
+    post:
+      tags:
+      - Connection
+      summary: Test Connection Async
+      description: 'Queue an async connection test to be executed on a worker.

Review Comment:
   Yes, this is async, but also no, this isn't async 😁
   
   It's not hitting async Python code.
   
   I think let's make the URL `/api/v2/connections/enqueue-test` and just remove "async" from the description?



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)

Review Comment:
   Two routers seems needless?
   
   ```suggestion
   router = VersionedAPIRouter(
       route_class=ExecutionAPIRoute,
       dependencies=[
           Security(require_auth, scopes=["ct:self"]),
       ],
   )
   ```
   
   etc.



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)
+
+
+@ct_id_router.get(
+    "/{connection_test_id}/connection",

Review Comment:
   I'm a little bit worried about this endpoint exposing or leaking the 
password somehow. I wonder if we should allow each ct to only be fetched once? 



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)
+
+
+@ct_id_router.get(
+    "/{connection_test_id}/connection",
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+    },
+)
+def get_connection_test_connection(
+    connection_test_id: UUID,
+    session: SessionDep,
+) -> ConnectionTestConnectionResponse:
+    """Return the connection data stored in a test request (called by workers)."""
+    ct = session.get(ConnectionTestRequest, connection_test_id)
+    if ct is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": f"Connection test {connection_test_id} not found",
+            },
+        )
+
+    return ConnectionTestConnectionResponse(
+        conn_id=ct.connection_id,
+        conn_type=ct.conn_type,
+        host=ct.host,
+        login=ct.login,
+        password=ct.password,
+        schema=ct.schema,
+        port=ct.port,
+        extra=ct.extra,
+    )
+
+
+@ct_id_router.patch(
+    "/{connection_test_id}",
+    status_code=status.HTTP_204_NO_CONTENT,
+    dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+        status.HTTP_409_CONFLICT: {"description": "Connection test already in a terminal state"},
+    },
+)
+def patch_connection_test(
+    connection_test_id: UUID,
+    body: ConnectionTestResultBody,
+    response: Response,
+    session: SessionDep,
+    services=DepContainer,
+    token: TIToken = CurrentTIToken,
+) -> None:
+    """Update the result of a connection test (called by workers)."""
+    ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True)
+    if ct is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": f"Connection test {connection_test_id} not found",
+            },
+        )
+
+    if ct.state in TERMINAL_STATES:
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail={
+                "reason": "conflict",
+                "message": f"Connection test {connection_test_id} is already in terminal state: {ct.state}",
+            },
+        )
+
+    ct.state = body.state
+    ct.result_message = body.result_message
+
+    if body.state == ConnectionTestState.SUCCESS and ct.commit_on_success:
+        ct.commit_to_connection_table(session=session)
+
+    # JWTReissueMiddleware also writes Refreshed-API-Token but skips workload tokens, so we set it here for the workload→execution swap.
+    if token.claims.scope == "workload":
+        generator: JWTGenerator = services.get(JWTGenerator)
+        execution_token = generator.generate(extras={"sub": str(connection_test_id), "scope": "execution"})
+        response.headers["Refreshed-API-Token"] = execution_token

Review Comment:
   Why do we need this? In practice, isn't this only ever used when the connection test is finishing, so the token doesn't matter anymore?



##########
airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py:
##########


Review Comment:
   This version is already released. We can't retcon it. You'll need to add a new migration instead.



##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -2602,6 +2602,33 @@ scheduler:
       type: float
       example: ~
       default: "120.0"
+    connection_test_timeout:
+      description: |
+        Maximum number of seconds an async connection test is allowed to run
+        before it is considered timed out. The scheduler reaper uses this value
+        plus a grace period to mark stale tests as failed.
+      version_added: 3.3.0
+      type: integer
+      example: ~
+      default: "60"
+    max_connection_test_concurrency:

Review Comment:
   Ditto: this does not belong under the `scheduler` section either.



##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -2602,6 +2602,33 @@ scheduler:
       type: float
       example: ~
       default: "120.0"
+    connection_test_timeout:

Review Comment:
   This does not belong under the `scheduler` section.



##########
airflow-core/src/airflow/executors/base_executor.py:
##########
@@ -216,6 +218,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
         self.team_name: str | None = team_name
         self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
         self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
+        self.queued_connection_tests: dict[ConnectionTestKey, workloads.TestConnection] = {}

Review Comment:
   Hmmmmmmm I'm really not liking the growth of this pattern. I wonder if there 
should be a single `queue` dict?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]