kaxil commented on code in PR #68533:
URL: https://github.com/apache/airflow/pull/68533#discussion_r3409506360


##########
devel-common/src/tests_common/test_utils/in_process_taskrun.py:
##########
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""DB-free, xdist-safe execution of a task through a *real* supervisor socket.
+
+`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and
+has **no real socket**, so operators that spawn a subprocess which re-connects to
+the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``,
+``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``.
+
+This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery
+(created explicitly for VirtualEnv operators) but serves every Execution-API call
+from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the
+subprocess gets a working supervisor socket without touching the metadata DB. The
+result: such tests need no ``@pytest.mark.db_test`` and run under xdist.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from unittest import mock
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from airflow.sdk.types import Operator
+
+
+class _StubXComs:
+    """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip)."""
+
+    def __init__(self) -> None:
+        self.store: dict[tuple, Any] = {}
+
+    def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs):
+        self.store[(dag_id, run_id, task_id, key, map_index)] = value
+
+    def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False):
+        from airflow.sdk.api.datamodels._generated import XComResponse
+        from airflow.sdk.exceptions import ErrorType
+        from airflow.sdk.execution_time.comms import ErrorResponse
+
+        if (dag_id, run_id, task_id, key, map_index) in self.store:
+            return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)])
+        return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND)
+
+    def delete(self, *args, **kwargs):
+        return None
+
+
+class _StubVariables:
+    def __init__(self, values: dict[str, Any] | None = None) -> None:
+        self.store = dict(values or {})
+
+    def get(self, key):
+        from airflow.sdk.api.datamodels._generated import VariableResponse
+        from airflow.sdk.exceptions import ErrorType
+        from airflow.sdk.execution_time.comms import ErrorResponse
+
+        if key in self.store:
+            return VariableResponse(key=key, value=self.store[key])
+        return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND)
+
+    def set(self, key, value, description=None):
+        self.store[key] = value
+
+    def delete(self, key):
+        self.store.pop(key, None)
+        return None
+
+
+class _StubConnections:
+    def __init__(self, conns: dict[str, Any] | None = None) -> None:
+        self.store = dict(conns or {})
+
+    def get(self, conn_id):
+        from airflow.sdk.exceptions import ErrorType
+        from airflow.sdk.execution_time.comms import ErrorResponse
+
+        if conn_id in self.store:
+            return self.store[conn_id]
+        return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
+
+
+class _InMemoryExecutionClient:

Review Comment:
   Most of this in-memory client already exists, just not collected in one place:
   
   - `Client(dry_run=True)` installs [`noop_handler`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/api/client.py#L1098), a no-network, no-DB client that already backs `InProcessTestSupervisor` (its `_api_client` builds the client with [`dry_run=True`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/api/client.py#L1168)). The only thing it lacks is a memory of writes: `noop_handler` discards them ("It doesn't make sense for returning connections etc.").
   - `run_task` already round-trips XCom in memory by spying on `XCom.set`/`XCom.get_one` ([pytest_plugin.py](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/devel-common/src/tests_common/pytest_plugin.py#L2731)), which is what `_StubXComs` re-implements.
   
   So the genuinely new capability is "a dry-run client that remembers." Teaching `noop_handler`/dry-run to round-trip from an in-memory dict would let this client and the three `_Stub*` classes collapse into the existing one instead of living on as a test-only parallel implementation.
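
   A minimal sketch of the "remembers writes" idea, using only the standard library. Everything here is hypothetical: `RememberingStore`, its `handle` signature and the status/payload shapes are illustration only and are not the real `noop_handler` or any Airflow API. It shows only the behavioural difference between discarding writes and keeping them in a dict.

```python
from __future__ import annotations

from typing import Any


class RememberingStore:
    """Hypothetical in-memory backing for a dry-run handler (not an Airflow API)."""

    def __init__(self) -> None:
        # Keyed by (resource, key), e.g. ("xcoms", "dag/run/task/return_value").
        self._data: dict[tuple[str, str], Any] = {}

    def handle(self, method: str, resource: str, key: str, value: Any = None) -> tuple[int, Any]:
        """Return (status, payload), remembering writes instead of discarding them."""
        slot = (resource, key)
        if method == "PUT":
            self._data[slot] = value
            return 200, {"ok": True}
        if method == "GET":
            if slot in self._data:
                return 200, {"key": key, "value": self._data[slot]}
            return 404, {"reason": "not_found"}
        if method == "DELETE":
            self._data.pop(slot, None)
            return 200, {"ok": True}
        return 405, {"reason": "method_not_allowed"}


store = RememberingStore()
store.handle("PUT", "xcoms", "dag/run/task/return_value", 42)
status, payload = store.handle("GET", "xcoms", "dag/run/task/return_value")
missing_status, _ = store.handle("GET", "variables", "absent")
```

   With something along these lines behind the dry-run path, one store could serve XComs, variables and connections, which is the collapse suggested above.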
   
   Separately, the `__getattr__` MagicMock fallback (L113) returns a MagicMock for any unmodeled resource (assets, dag_runs, task_store), so a future venv test that round-trips one of those would pass against a mock instead of real data. Raising on unmodeled names (keeping the `__`-dunder guard) makes that fail loudly.
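
   One possible reading of "raise on unmodeled names", sketched with a hypothetical `StrictClient` rather than the PR's class. Choosing `NotImplementedError` for unmodeled resources is an assumption of this sketch: unlike `AttributeError`, it is not swallowed by `hasattr` or `getattr(obj, name, default)`, so the dunder guard still matters for protocol probes such as `copy` and `pickle`.

```python
from __future__ import annotations

from typing import Any


class StrictClient:
    """Hypothetical stand-in: modeled resources exist, anything else raises."""

    def __init__(self) -> None:
        self.xcoms: dict[Any, Any] = {}
        self.variables: dict[str, Any] = {}

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails, so the modeled
        # attributes set in __init__ never reach this method.
        if name.startswith("__"):
            # Protocol probes (copy, pickle, ...) expect AttributeError.
            raise AttributeError(name)
        raise NotImplementedError(
            f"stub client does not model {name!r}; add an in-memory stub for it"
        )


client = StrictClient()
try:
    client.assets
    unmodeled_raised = False
except NotImplementedError:
    unmodeled_raised = True
```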



##########
devel-common/src/tests_common/test_utils/in_process_taskrun.py:
##########
@@ -0,0 +1,190 @@
+class _InMemoryExecutionClient:
+    """In-memory stand-in for the Task SDK execution-API ``Client`` (no metadata DB)."""
+
+    def __init__(self, ti_context, variables=None, connections=None) -> None:
+        self.task_instances = mock.MagicMock(name="stub.task_instances")
+        self.task_instances.start.return_value = ti_context
+        self.xcoms = _StubXComs()
+        self.variables = _StubVariables(variables)
+        self.connections = _StubConnections(connections)
+
+    def __getattr__(self, name):
+        # Resources we don't model (assets, dag_runs, hitl, task_store, ...) are
+        # absorbed — venv operator tests don't exercise them.
+        if name.startswith("__"):
+            raise AttributeError(name)
+        return mock.MagicMock(name=f"stub_client.{name}")
+
+
+class TaskRunResultNoDB:
+    """Result of :func:`run_task_no_db`, mirroring the ``run_task`` fixture surface."""
+
+    def __init__(self, result, client: _InMemoryExecutionClient, ti) -> None:
+        self._result = result
+        self.client = client
+        self._ti = ti
+
+    @property
+    def state(self):
+        return self._result.state
+
+    @property
+    def error(self):
+        return self._result.error
+
+    @property
+    def msg(self):
+        return self._result.msg
+
+    def xcom_get(
+        self,
+        key: str = "return_value",
+        task_id: str | None = None,
+        dag_id: str | None = None,
+        run_id: str | None = None,
+        map_index: int | None = None,
+    ) -> Any:
+        task_id = task_id or self._ti.task_id
+        dag_id = dag_id or self._ti.dag_id
+        run_id = run_id or self._ti.run_id
+        map_index = map_index if map_index is not None else self._ti.map_index
+        return self.client.xcoms.store.get((dag_id, run_id, task_id, key, map_index))
+
+
+def run_task_no_db(
+    task: Operator,
+    create_runtime_ti: Callable[..., Any],
+    *,
+    logical_date: Any | None = None,
+    variables: dict[str, Any] | None = None,
+    connections: dict[str, Any] | None = None,
+) -> TaskRunResultNoDB:
+    """Run *task* DB-free through the real-socket in-process supervisor."""
+    from uuid6 import uuid7
+
+    from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceDTO
+    from airflow.sdk.execution_time.supervisor import InProcessTestSupervisor
+
+    ti_kwargs = {} if logical_date is None else {"logical_date": logical_date}
+    rti = create_runtime_ti(task, **ti_kwargs)
+    ti_context = rti._ti_context_from_server
+
+    # `start()` model_dumps `what`; the plain DTO dumps cleanly, whereas the
+    # operator-laden RuntimeTaskInstance trips forward refs (RetryPolicy/WeightRuleParam).
+    what = TaskInstanceDTO(
+        id=rti.id,
+        task_id=rti.task_id,
+        dag_id=rti.dag_id,
+        run_id=rti.run_id,
+        try_number=rti.try_number,
+        map_index=rti.map_index,
+        dag_version_id=uuid7(),
+        queue="default",
+    )
+
+    client = _InMemoryExecutionClient(ti_context, variables=variables, connections=connections)
+
+    class _StubBackendSupervisor(InProcessTestSupervisor):

Review Comment:
   The "override `_api_client` to skip the DB" seam is already an established pattern ([test_supervisor.py](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/tests/task_sdk/execution_time/test_supervisor.py#L3441) does exactly this with a fake client), and [`run_task_in_process`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L2181) already wraps `.start()`. Rather than defining a subclass per call, consider an optional `client=` param on `InProcessTestSupervisor.start()` (it currently hardcodes `client=cls._api_client(task.dag)` at [#L2038](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L2038)). Callers would then inject the in-memory client, the subclass goes away, and `run_task_no_db` becomes a thin wrapper over `run_task_in_process`.
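
   The shape of the suggested `client=` seam, as a toy sketch. `FakeSupervisor` and `run_with_client` are hypothetical stand-ins and do not reproduce the real `InProcessTestSupervisor.start()` signature or body. The only point is the optional keyword that falls back to `_api_client` when the caller passes nothing.

```python
from __future__ import annotations

from typing import Any


class FakeSupervisor:
    """Toy stand-in for a supervisor; not Airflow's InProcessTestSupervisor."""

    @classmethod
    def _api_client(cls, dag: Any = None) -> Any:
        # Stands in for the default (DB-backed or dry-run) client factory.
        return {"kind": "default"}

    @classmethod
    def start(cls, what: Any, *, dag: Any = None, client: Any | None = None) -> dict[str, Any]:
        # Fall back to the class default only when no client was injected.
        if client is None:
            client = cls._api_client(dag)
        return {"what": what, "client": client}


def run_with_client(what: Any, client: Any) -> dict[str, Any]:
    """Thin wrapper: inject a client instead of subclassing per call."""
    return FakeSupervisor.start(what, client=client)


default_run = FakeSupervisor.start("ti-1")
injected_run = run_with_client("ti-2", {"kind": "in-memory"})
```

   Existing callers keep their behaviour because the default is unchanged, and tests can pass any client object without touching the class.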



##########
devel-common/src/tests_common/test_utils/in_process_taskrun.py:
##########
@@ -0,0 +1,190 @@
+class TaskRunResultNoDB:

Review Comment:
   This duplicates the existing [`TaskRunResult`](https://github.com/apache/airflow/blob/06d4b1ea7d6e7795f1724662598e1f6885df0ac8/task-sdk/src/airflow/sdk/execution_time/supervisor.py#L1953), which exposes the same `state`/`msg`/`error` and is exactly what `_StubBackendSupervisor.start()` returns (captured as `result` at L189). Only `xcom_get` is new. Returning the existing `TaskRunResult` plus a small XCom accessor avoids a second result type that has to be kept in sync with the first.
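
   A sketch of the "existing result plus a small accessor" shape. `FakeResult` and `FakeTI` are hypothetical stand-ins for `TaskRunResult` and the runtime task instance. The free function `xcom_get` mirrors the defaulting logic of the `xcom_get` method in the diff and reads from a plain dict keyed the same way as `_StubXComs.store`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeResult:
    """Stand-in for the existing result type (state/msg/error only)."""

    state: str
    msg: Any = None
    error: BaseException | None = None


@dataclass
class FakeTI:
    """Stand-in for the task instance that supplies default identifiers."""

    dag_id: str
    run_id: str
    task_id: str
    map_index: int = -1


def xcom_get(
    store: dict[tuple, Any],
    ti: FakeTI,
    key: str = "return_value",
    *,
    task_id: str | None = None,
    dag_id: str | None = None,
    run_id: str | None = None,
    map_index: int | None = None,
) -> Any:
    """Look up an XCom in the in-memory store, defaulting identifiers from ti."""
    resolved_index = map_index if map_index is not None else ti.map_index
    lookup = (dag_id or ti.dag_id, run_id or ti.run_id, task_id or ti.task_id, key, resolved_index)
    return store.get(lookup)


ti = FakeTI(dag_id="d", run_id="r", task_id="t")
store = {("d", "r", "t", "return_value", -1): "hello"}
result = FakeResult(state="success")
value = xcom_get(store, ti)
```

   The helper returns the result object unchanged, so there is no second class to keep in sync, and the accessor stays a few lines next to the store it reads.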



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
