ashb commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3169346979
##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
The supervisor imports `airflow.models.connection.Connection` solely to call `test_connection()` on it.
That method dispatches to the provider hook via `ProvidersManager` (core), which adds a new ORM dependency to task-sdk and is the wrong direction: task-sdk should not be pulling in core ORM models.

The right fix, which belongs in this PR since connection testing is its stated purpose: **Add `test_connection()` to `airflow.sdk.definitions.connection.Connection`**, dispatching via `ProvidersManagerTaskRuntime` (already in task-sdk for exactly this kind of hook dispatch):

```python
def test_connection(self) -> tuple[bool, str]:
    from airflow.sdk.execution_time.context import _preset_connections
    from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
    from airflow.sdk._shared.module_loading import import_string

    hook_info = ProvidersManagerTaskRuntime().hooks.get(self.conn_type)
    if not hook_info:
        return False, f"Unknown conn_type: {self.conn_type!r}"

    # Merge with (rather than replace) any preset already active in the outer context.
    outer = _preset_connections.get() or {}
    token = _preset_connections.set({**outer, self.conn_id: self})
    try:
        hook_class = import_string(hook_info.hook_class_name)
        # Hooks take the conn id under a hook-specific kwarg (e.g. `postgres_conn_id`),
        # recorded on the hook info as `connection_id_attribute_name`; a plain
        # `conn_id=` kwarg would be ignored by most hooks.
        hook = hook_class(**{hook_info.connection_id_attribute_name: self.conn_id})
        if not hasattr(hook, "test_connection"):
            return False, f"Hook {type(hook).__name__} doesn't implement test_connection"
        return hook.test_connection()
    finally:
        _preset_connections.reset(token)
```

`_preset_connections` is the right injection point (it is checked before all secrets backends, including user-configured custom ones), but it should be encapsulated inside `test_connection()` rather than wired manually in the supervisor. That reduces the supervisor to:

```python
conn = SDKConnection(conn_id=r.conn_id, conn_type=r.conn_type, ...)
with TimeoutPosix(seconds=timeout, error_message=f"Connection test timed out after {timeout}s"):
    success, message = conn.test_connection()
```

This also fixes the second issue here: the `signal.signal(signal.SIGALRM, ...)` + `signal.alarm()` pair should be replaced with `TimeoutPosix` from `airflow.sdk.execution_time.timeout`. It already exists, handles non-POSIX platforms gracefully, and uses `signal.setitimer` (sub-second capable) rather than the integer-only `signal.alarm`.

---
Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting

##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py:
##########
@@ -259,6 +284,87 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse:
         os.environ.pop(conn_env_var, None)
 
 
+@connections_router.post(
+    "/test-async",
+    status_code=status.HTTP_202_ACCEPTED,
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_403_FORBIDDEN,
+            status.HTTP_409_CONFLICT,
+            status.HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    dependencies=[Depends(requires_access_connection(method="POST")), Depends(action_logging())],
+)
+def test_connection_async(
+    test_body: ConnectionTestRequestBody,
+    session: SessionDep,
+) -> ConnectionTestQueuedResponse:
+    """
+    Queue an async connection test to be executed on a worker.
+
+    The connection data is stored in the test request table and the worker
+    reads from there. Returns a token to poll for the result via
+    GET /connections/test-async/{token}.
+    """
+    _ensure_test_connection_enabled()
+    _ensure_executor_is_configured(test_body.executor)
+
+    connection_test = ConnectionTestRequest(
+        connection_id=test_body.connection_id,
+        conn_type=test_body.conn_type,
+        host=test_body.host,
+        login=test_body.login,
+        password=test_body.password,
+        schema=test_body.schema_,
+        port=test_body.port,
+        extra=test_body.extra,
+        commit_on_success=test_body.commit_on_success,
+        executor=test_body.executor,
+        queue=test_body.queue,
+    )
+    session.add(connection_test)

Review Comment:
`session.add(connection_test)` will result in an `IntegrityError` (raised when the session flushes) if there is already an active test for the same `connection_id`, because of the `uq_connection_test_request_active_conn` unique constraint on `active_connection_id`. The endpoint documents a 409 response, but nothing catches the DB error, so the caller gets a 500 instead.

##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py:
##########
@@ -78,12 +79,49 @@ class ConnectionCollectionResponse(BaseModel):
 
 
 class ConnectionTestResponse(BaseModel):
-    """Connection Test serializer for responses."""
+    """Connection Test serializer for synchronous test responses."""
 
     status: bool
     message: str
 
 
+class ConnectionTestRequestBody(StrictBaseModel):
+    """Request body for async connection test."""
+
+    connection_id: str
+    conn_type: str
+    host: str | None = None
+    login: str | None = None
+    schema_: str | None = Field(None, alias="schema")
+    port: int | None = None
+    password: str | None = None
+    extra: str | None = None
+    commit_on_success: bool = Field(

Review Comment:
`ConnectionTestRequestBody` duplicates almost all of `ConnectionBody` (same `connection_id`, `conn_type`, `host`, `login`, `schema_`, `port`, `password`, `extra` fields).
Use `ConnectionBody` as the base class and extend it:

```python
class ConnectionTestRequestBody(ConnectionBody):
    commit_on_success: bool = Field(default=False, description="...")
    executor: str | None = None
    queue: str | None = None
```

`ConnectionBody` already has the `extra` JSON validator, the `connection_id` pattern constraint, and the standard field definitions. Duplicating those creates drift: for example, `ConnectionBody.connection_id` has `pattern=r"^[\w.-]+$"`, but `ConnectionTestRequestBody` drops it.

We should also exclude `commit_on_success` from this schema, as it makes no sense on the TaskSDK side. I think we've done something like this elsewhere; there's a pydantic setting/flag we can set in `Field()`.

##########
task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import signal
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.execution_time.context import _preset_connections
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+    *,
+    connection_test_id: uuid.UUID,
+    connection_id: str,
+    timeout: int,
+    token: str,
+    server: str,
+) -> int:
+    """Execute a connection test on the worker and report the result via the Execution API."""
+    from airflow.models.connection import Connection

Review Comment:
🚨 We cannot introduce new dependencies from the Task SDK on core Airflow. We will need to do something else, like remove `test_connection()` from `models.connection` and move it into the SDK version.

##########
task-sdk/src/airflow/sdk/api/client.py:
##########
@@ -872,6 +874,25 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
         return HITLDetailResponse.model_validate_json(resp.read())
 
 
+class ConnectionTestOperations:
+    __slots__ = ("client",)
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionResponse:
+        """Fetch connection data for a test request from the API server."""
+        resp = self.client.get(f"connection-tests/{connection_test_id}/connection")
+        return ConnectionResponse.model_validate_json(resp.read())

Review Comment:
The return type is `ConnectionResponse`, but the endpoint (`GET /connection-tests/{id}/connection`) returns `ConnectionTestConnectionResponse`. Both models happen to have the same fields today, so it works at runtime, but the annotation is wrong and will silently break if the two models ever diverge.
```suggestion
    def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionTestConnectionResponse:
        """Fetch connection data for a test request from the API server."""
        resp = self.client.get(f"connection-tests/{connection_test_id}/connection")
        return ConnectionTestConnectionResponse.model_validate_json(resp.read())
```

---
Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting

##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -144,7 +145,16 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
     return Variable(**var_result.model_dump(exclude={"type"}))
 
 
+_preset_connections: ContextVar[dict[str, Connection]] = ContextVar("_preset_connections", default={})

Review Comment:
`default={}` is a single shared mutable object: every context that hasn't called `.set()` gets back the exact same dict from `.get()`. The current code never mutates it, so it doesn't blow up today, but it's a trap for anyone who writes `_preset_connections.get()[conn_id] = conn` instead of using `.set()`. Change to `default=None` and guard at call sites:

```python
_preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar("_preset_connections", default=None)

# in _get_connection:
preset = _preset_connections.get()
if preset and conn_id in preset:
    ...
```

Also: `.set({preset.conn_id: preset})` replaces the entire dict, shadowing any outer preset for the duration of the call. If `test_connection()` is ever called from within a task that has its own preset set (e.g. test tooling), the outer entries disappear until `.reset()`.
Merge instead of replace:

```python
outer = _preset_connections.get() or {}
token = _preset_connections.set({**outer, self.conn_id: self})
```

---
Drafted-by: Claude Code (claude-sonnet-4-6); reviewed by @ashb before posting

##########
airflow-core/src/airflow/api_fastapi/execution_api/security.py:
##########
@@ -190,6 +190,14 @@ async def require_auth(
                 detail="Token subject does not match task instance ID",
             )
 
+    if "ct:self" in security_scopes.scopes:

Review Comment:
```suggestion
    elif "ct:self" in security_scopes.scopes:
```

I think? `ti:self` and `ct:self` can't co-exist?

##########
airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml:
##########
@@ -1844,6 +1844,115 @@ paths:
       security:
       - OAuth2PasswordBearer: []
       - HTTPBearer: []
+  /api/v2/connections/test-async:
+    post:
+      tags:
+      - Connection
+      summary: Test Connection Async
+      description: 'Queue an async connection test to be executed on a worker.

Review Comment:
Yes, this is async, but also no, this isn't async 😁 It's not hitting async Python code. I think let's make the URL `/api/v2/connections/enqueue-test` and just remove "async" from the description?

##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)

Review Comment:
Two routers seem needless?

```suggestion
router = VersionedAPIRouter(
    route_class=ExecutionAPIRoute,
    dependencies=[
        Security(require_auth, scopes=["ct:self"]),
    ],
)
```

etc.

##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)
+
+
+@ct_id_router.get(
+    "/{connection_test_id}/connection",

Review Comment:
I'm a little bit worried about this endpoint exposing or leaking the password somehow. I wonder if we should allow each ct to be fetched only once?

##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py:
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Response, Security, status
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+    ConnectionTestConnectionResponse,
+    ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+    TERMINAL_STATES,
+    ConnectionTestRequest,
+    ConnectionTestState,
+)
+
+router = VersionedAPIRouter()
+
+ct_id_router = VersionedAPIRouter(
+    route_class=ExecutionAPIRoute,
+    dependencies=[
+        Security(require_auth, scopes=["ct:self"]),
+    ],
+)
+
+
+@ct_id_router.get(
+    "/{connection_test_id}/connection",
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+    },
+)
+def get_connection_test_connection(
+    connection_test_id: UUID,
+    session: SessionDep,
+) -> ConnectionTestConnectionResponse:
+    """Return the connection data stored in a test request (called by workers)."""
+    ct = session.get(ConnectionTestRequest, connection_test_id)
+    if ct is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": f"Connection test {connection_test_id} not found",
+            },
+        )
+
+    return ConnectionTestConnectionResponse(
+        conn_id=ct.connection_id,
+        conn_type=ct.conn_type,
+        host=ct.host,
+        login=ct.login,
+        password=ct.password,
+        schema=ct.schema,
+        port=ct.port,
+        extra=ct.extra,
+    )
+
+
+@ct_id_router.patch(
+    "/{connection_test_id}",
+    status_code=status.HTTP_204_NO_CONTENT,
+    dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+        status.HTTP_409_CONFLICT: {"description": "Connection test already in a terminal state"},
+    },
+)
+def patch_connection_test(
+    connection_test_id: UUID,
+    body: ConnectionTestResultBody,
+    response: Response,
+    session: SessionDep,
+    services=DepContainer,
+    token: TIToken = CurrentTIToken,
+) -> None:
+    """Update the result of a connection test (called by workers)."""
+    ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True)
+    if ct is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": f"Connection test {connection_test_id} not found",
+            },
+        )
+
+    if ct.state in TERMINAL_STATES:
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail={
+                "reason": "conflict",
+                "message": f"Connection test {connection_test_id} is already in terminal state: {ct.state}",
+            },
+        )
+
+    ct.state = body.state
+    ct.result_message = body.result_message
+
+    if body.state == ConnectionTestState.SUCCESS and ct.commit_on_success:
+        ct.commit_to_connection_table(session=session)
+
+    # JWTReissueMiddleware also writes Refreshed-API-Token but skips workload tokens, so we set it here for the workload→execution swap.
+    if token.claims.scope == "workload":
+        generator: JWTGenerator = services.get(JWTGenerator)
+        execution_token = generator.generate(extras={"sub": str(connection_test_id), "scope": "execution"})
+        response.headers["Refreshed-API-Token"] = execution_token

Review Comment:
Why do we need this? Isn't this in practice only ever used when the conn test is finishing, so the token doesn't matter anymore?

##########
airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py:
##########

Review Comment:
This version is already released. We can't retcon it. You'll need to add a new migration instead.

##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -2602,6 +2602,33 @@ scheduler:
       type: float
       example: ~
       default: "120.0"
+    connection_test_timeout:
+      description: |
+        Maximum number of seconds an async connection test is allowed to run
+        before it is considered timed out. The scheduler reaper uses this value
+        plus a grace period to mark stale tests as failed.
+      version_added: 3.3.0
+      type: integer
+      example: ~
+      default: "60"
+    max_connection_test_concurrency:

Review Comment:
Ditto.

##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -2602,6 +2602,33 @@ scheduler:
       type: float
       example: ~
       default: "120.0"
+    connection_test_timeout:

Review Comment:
This does not belong under the `scheduler` section.

##########
airflow-core/src/airflow/executors/base_executor.py:
##########
@@ -216,6 +218,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
         self.team_name: str | None = team_name
         self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
         self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
+        self.queued_connection_tests: dict[ConnectionTestKey, workloads.TestConnection] = {}

Review Comment:
Hmmmmmmm, I'm really not liking the growth of this pattern. I wonder if there should be a single `queue` dict?
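One possible shape for that single dict, as a hypothetical sketch: `WorkloadKind`, `QueueKey` and `WorkloadQueue` are invented names rather than existing `BaseExecutor` APIs, and plain strings stand in for `workloads.ExecuteTask` and the other workload types.

```python
from __future__ import annotations

from collections.abc import Hashable
from dataclasses import dataclass, field
from enum import Enum


class WorkloadKind(Enum):
    # Hypothetical tags, one per existing queued_* dict on the executor.
    TASK = "task"
    CALLBACK = "callback"
    CONNECTION_TEST = "connection_test"


@dataclass(frozen=True)
class QueueKey:
    kind: WorkloadKind
    key: Hashable  # e.g. a TaskInstanceKey, a callback id, or a ConnectionTestKey


@dataclass
class WorkloadQueue:
    """A single insertion-ordered dict holding every queued workload, keyed by (kind, key)."""

    items: dict[QueueKey, object] = field(default_factory=dict)

    def enqueue(self, kind: WorkloadKind, key: Hashable, workload: object) -> None:
        # Re-queuing the same (kind, key) replaces the workload but keeps its
        # original position, as with any dict assignment.
        self.items[QueueKey(kind, key)] = workload

    def pop_kind(self, kind: WorkloadKind) -> list[object]:
        """Remove and return all workloads of one kind, oldest first."""
        keys = [k for k in self.items if k.kind is kind]
        return [self.items.pop(k) for k in keys]

    def pop_all(self) -> list[object]:
        """Remove and return every queued workload in enqueue order."""
        workloads = list(self.items.values())
        self.items.clear()
        return workloads

    def __len__(self) -> int:
        return len(self.items)


if __name__ == "__main__":
    q = WorkloadQueue()
    q.enqueue(WorkloadKind.TASK, ("dag", "task", "run", 1), "ExecuteTask#1")
    q.enqueue(WorkloadKind.CONNECTION_TEST, "ct-1", "TestConnection#1")
    print(q.pop_kind(WorkloadKind.CONNECTION_TEST))  # ['TestConnection#1']
    print(len(q))  # 1
```

The trade-off is typing: separate dicts give each queue a precise value type (`dict[TaskInstanceKey, workloads.ExecuteTask]` and so on), whereas a single dict needs either `object` as here or a union of workload types, with callers narrowing on `kind`. In exchange, adding a new workload type becomes a new enum member rather than a new executor attribute.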
