kaxil commented on code in PR #68533: URL: https://github.com/apache/airflow/pull/68533#discussion_r3409506360
########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. + +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: Review Comment: Most of this in-memory client already exists, just not collected in one place: - `Client(dry_run=True)` installs [`noop_handler`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/api/client.py#L1098), a no-network/no-DB client that already backs `InProcessTestSupervisor` (its `_api_client` builds the client with [`dry_run=True`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/api/client.py#L1168)). The only thing it lacks is remembering writes: `noop_handler` discards them ("It doesn't make sense for returning connections etc."). - `run_task` already round-trips XCom in-memory by spying `XCom.set`/`XCom.get_one` ([pytest_plugin.py](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/devel-common/src/tests_common/pytest_plugin.py#L2731)), which is what `_StubXComs` re-implements. So the genuinely new capability is "a dry-run client that remembers." Teaching `noop_handler`/dry-run to round-trip from an in-memory dict would let this client and the three `_Stub*` classes collapse into the existing one instead of a test-only parallel. Separately, the `__getattr__` MagicMock fallback (L113) returns a MagicMock for any unmodeled resource (assets, dag_runs, task_store), so a future venv test that round-trips one of those would pass against a mock instead of real data. Raising on unmodeled names (keeping the `__`-dunder guard) makes that fail loudly. ########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. + +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: + """In-memory stand-in for the Task SDK execution-API ``Client`` (no metadata DB).""" + + def __init__(self, ti_context, variables=None, connections=None) -> None: + self.task_instances = mock.MagicMock(name="stub.task_instances") + self.task_instances.start.return_value = ti_context + self.xcoms = _StubXComs() + self.variables = _StubVariables(variables) + self.connections = _StubConnections(connections) + + def __getattr__(self, name): + # Resources we don't model (assets, dag_runs, hitl, task_store, ...) are + # absorbed — venv operator tests don't exercise them. + if name.startswith("__"): + raise AttributeError(name) + return mock.MagicMock(name=f"stub_client.{name}") + + +class TaskRunResultNoDB: + """Result of :func:`run_task_no_db`, mirroring the ``run_task`` fixture surface.""" + + def __init__(self, result, client: _InMemoryExecutionClient, ti) -> None: + self._result = result + self.client = client + self._ti = ti + + @property + def state(self): + return self._result.state + + @property + def error(self): + return self._result.error + + @property + def msg(self): + return self._result.msg + + def xcom_get( + self, + key: str = "return_value", + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> Any: + task_id = task_id or self._ti.task_id + dag_id = dag_id or self._ti.dag_id + run_id = run_id or self._ti.run_id + map_index = map_index if map_index is not None else self._ti.map_index + return self.client.xcoms.store.get((dag_id, run_id, task_id, key, map_index)) + + +def run_task_no_db( + task: Operator, + create_runtime_ti: Callable[..., Any], + *, + logical_date: Any | None = None, + variables: dict[str, Any] | None = None, + connections: dict[str, Any] | None = None, +) -> TaskRunResultNoDB: + """Run *task* DB-free through the real-socket in-process supervisor.""" + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceDTO + from airflow.sdk.execution_time.supervisor import InProcessTestSupervisor + + ti_kwargs = {} if logical_date is None else {"logical_date": logical_date} + rti = create_runtime_ti(task, **ti_kwargs) + ti_context = rti._ti_context_from_server + + # `start()` model_dumps `what`; the plain DTO dumps cleanly, whereas the + # operator-laden RuntimeTaskInstance trips forward refs (RetryPolicy/WeightRuleParam). + what = TaskInstanceDTO( + id=rti.id, + task_id=rti.task_id, + dag_id=rti.dag_id, + run_id=rti.run_id, + try_number=rti.try_number, + map_index=rti.map_index, + dag_version_id=uuid7(), + queue="default", + ) + + client = _InMemoryExecutionClient(ti_context, variables=variables, connections=connections) + + class _StubBackendSupervisor(InProcessTestSupervisor): Review Comment: The "override `_api_client` to skip the DB" seam is already an established pattern ([test_supervisor.py](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/tests/task_sdk/execution_time/test_supervisor.py#L3441) does exactly this with a fake client), and [`run_task_in_process`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L2181) already wraps `.start()`. Rather than a subclass per call, consider an optional `client=` param on `InProcessTestSupervisor.start()` (it currently hardcodes `client=cls._api_client(task.dag)` at [#L2038](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L2038)). Callers inject the in-memory client, the subclass goes away, and `run_task_no_db` becomes a thin wrapper over `run_task_in_process`. ########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. + +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: + """In-memory stand-in for the Task SDK execution-API ``Client`` (no metadata DB).""" + + def __init__(self, ti_context, variables=None, connections=None) -> None: + self.task_instances = mock.MagicMock(name="stub.task_instances") + self.task_instances.start.return_value = ti_context + self.xcoms = _StubXComs() + self.variables = _StubVariables(variables) + self.connections = _StubConnections(connections) + + def __getattr__(self, name): + # Resources we don't model (assets, dag_runs, hitl, task_store, ...) are + # absorbed — venv operator tests don't exercise them. + if name.startswith("__"): + raise AttributeError(name) + return mock.MagicMock(name=f"stub_client.{name}") + + +class TaskRunResultNoDB: Review Comment: This duplicates the existing [`TaskRunResult`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L1953), which exposes the same `state`/`msg`/`error` and is exactly what `_StubBackendSupervisor.start()` returns (captured as `result` at L189). Only `xcom_get` is new. Returning the existing `TaskRunResult` plus a small xcom accessor avoids a second result type that has to be kept in sync with the first. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
