This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 23a3c33c189 Add SubprocessCoordinator (#67635)
23a3c33c189 is described below
commit 23a3c33c18958eff7d626595afc6d3a21db40bd6
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Wed Jun 3 10:38:09 2026 +0800
Add SubprocessCoordinator (#67635)
* Add SocketCoordinator
* Make JavaCoordinator inherit SocketCoordinator
* Name the coordinator base after its subprocess role, not its socket transport
- Rename SocketCoordinator -> SubprocessCoordinator and
  _SocketActivitySubprocess -> _PopenActivitySubprocess: the TCP transport
  is an implementation detail that subclasses don't care about, while the
  subprocess they launch is the point.
- Flatten the coordinators/socket/ package into a single private
  coordinators/_subprocess.py module (this also drops the stray subpackage
  __version__).
- Spec the mock_client test fixture with Client / TaskInstanceOperations so
  the client call surface is checked against the real interfaces.
---
.../{java/coordinator.py => _subprocess.py} | 269 +++-------
.../airflow/sdk/coordinators/java/coordinator.py | 253 +---------
.../task_sdk/coordinators/java/test_coordinator.py | 434 +---------------
.../test_coordinator.py => test_subprocess.py} | 544 ++++++---------------
4 files changed, 241 insertions(+), 1259 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
similarity index 51%
copy from task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
copy to task-sdk/src/airflow/sdk/coordinators/_subprocess.py
index e91a26006c5..da3ba5243a5 100644
--- a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
+++ b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
@@ -15,32 +15,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Java runtime coordinator that launches a JVM subprocess for Dag file
processing and task execution."""
+"""
+Common subprocess coordinator scaffolding.
+
+Coordinators that launch a subprocess and communicate with it over two TCP
+sockets (``--comm`` and ``--logs``) — Java, native executables, and any
+future runtime that follows the same wire convention — can subclass
+:class:`SubprocessCoordinator` and reuse the resource-tracking, accept, and
+draining machinery in this module rather than re-implementing it.
+"""
from __future__ import annotations
-import email
import itertools
import os
-import pathlib
import selectors
import signal
import socket
-import stat
import subprocess
import time
-import zipfile
from typing import TYPE_CHECKING, TypeVar, cast
import attrs
import structlog
from airflow.sdk.execution_time.coordinator import BaseCoordinator
-from airflow.sdk.execution_time.schema import get_schema_version_migrator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess,
NeverRaised, ProcessTracker
if TYPE_CHECKING:
- from collections.abc import Iterable, Iterator, Sequence
+ from collections.abc import Sequence
from structlog.typing import FilteringBoundLogger
from typing_extensions import Self
@@ -51,7 +54,7 @@ if TYPE_CHECKING:
Tracked = TypeVar("Tracked", socket.socket, subprocess.Popen)
-log: FilteringBoundLogger =
structlog.get_logger(logger_name="coordinators.java")
+log: FilteringBoundLogger =
structlog.get_logger(logger_name="coordinators.subprocess")
def _start_server() -> socket.socket:
@@ -62,119 +65,6 @@ def _start_server() -> socket.socket:
return server
-def _find_jars(items: Iterable[pathlib.Path]) -> Iterator[pathlib.Path]:
- """
- Yield JAR files under *items*, descending into directories.
-
- A symlink loop or a directory that hardlinks into one of its ancestors
- would otherwise recurse until the interpreter stack is exhausted, so
- directories are deduplicated by ``(st_dev, st_ino)`` for the duration
- of a single scan.
- """
- seen_dirs: set[tuple[int, int]] = set()
- yield from _walk_jars(items, seen_dirs)
-
-
-def _walk_jars(items: Iterable[pathlib.Path], seen_dirs: set[tuple[int, int]])
-> Iterator[pathlib.Path]:
- for item in items:
- try:
- st = item.stat()
- except OSError:
- continue
- if stat.S_ISDIR(st.st_mode):
- key = (st.st_dev, st.st_ino)
- if key in seen_dirs:
- log.debug("Skipping already-visited directory", path=item)
- continue
- seen_dirs.add(key)
- yield from _walk_jars(_iter_dir(item), seen_dirs)
- elif stat.S_ISREG(st.st_mode) and item.suffix == ".jar":
- yield item
-
-
-def _iter_dir(directory: pathlib.Path) -> Iterator[pathlib.Path]:
- # iterdir() is lazy, so an unreadable directory raises only once iteration
- # starts; swallow it here so a single bad directory does not abort the
scan.
- try:
- yield from directory.iterdir()
- except OSError:
- return
-
-
-def _calculate_classpath(jars_root: Sequence[pathlib.Path]) -> str:
- jars = (p.as_posix() for p in _find_jars(jars_root))
- return os.pathsep.join(sorted(jars)) # Keep output deterministic.
-
-
[email protected]
-class _JarMetadata:
- main_class: str
- schema_version: str
-
- @classmethod
- def from_jar(cls, path: pathlib.Path) -> Self | None:
- try:
- with zipfile.ZipFile(path) as zf:
- try:
- manifest_info = zf.getinfo("META-INF/MANIFEST.MF")
- except KeyError:
- log.debug("JAR does not contain META-INF/MANIFEST.MF;
ignored", path=path)
- return None
- with zf.open(manifest_info) as f:
- manifest = email.message_from_binary_file(f)
- return cls(manifest["Main-Class"],
manifest["Airflow-Supervisor-Schema-Version"])
- except zipfile.BadZipFile:
- log.exception("Cannot read JAR; ignored", path=path)
- return None
-
-
-def _validate_schema_version(instance, _, value) -> str:
- return get_schema_version_migrator().resolve_version(str(value))
-
-
[email protected]
-class _JarInfo:
- main_class: str
- schema_version: str = attrs.field(validator=_validate_schema_version)
-
- @attrs.define
- class _Progress:
- main_class: str | None = attrs.field(init=False, default=None)
- schema_version: str | None = attrs.field(init=False, default=None)
-
- def collect(self) -> _JarInfo | None:
- if self.main_class is None or self.schema_version is None:
- return None
- return _JarInfo(self.main_class, self.schema_version)
-
- @classmethod
- def find(cls, roots: Sequence[pathlib.Path], main_class: str) -> _JarInfo:
- log.debug("Finding JARs recursively", roots=roots)
- progress = cls._Progress()
- for p in _find_jars(roots):
- if (metadata := _JarMetadata.from_jar(p)) is None:
- continue
- if metadata.main_class and ((main_class == metadata.main_class) or
not main_class):
- log.debug("JAR located with Main-Class metadata", path=p,
main_class=metadata.main_class)
- progress.main_class = metadata.main_class
- if metadata.schema_version:
- log.debug(
- "JAR located with Airflow-Supervisor-Schema-Version
metadata",
- path=p,
- schema_version=metadata.schema_version,
- )
- progress.schema_version = metadata.schema_version
- if (result := progress.collect()) is not None:
- return result
- if progress.main_class is not None:
- tp = "cannot find a JAR with Airflow-Supervisor-Schema-Version
metadata in {1}"
- elif main_class:
- tp = "cannot find a JAR with Main-Class matching {0!r} in {1}"
- else:
- tp = "cannot find a JAR with Main-Class metadata in {1}"
- raise FileNotFoundError(tp.format(main_class,
os.pathsep.join(os.fspath(p.resolve()) for p in roots)))
-
-
def _accept_connections(
servers: dict[str, socket.socket],
drains: dict[str, socket.socket],
@@ -183,7 +73,7 @@ def _accept_connections(
max_wait: float = 10.0,
drain_size: int = 4096,
) -> tuple[dict[socket.socket, socket.socket], dict[socket.socket, bytes]]:
- """Block until the Java process connects to servers."""
+ """Block until the subprocess connects to servers, draining stdout/stderr
along the way."""
accepted: dict[socket.socket, socket.socket] = {}
drained: dict[socket.socket, bytes] = {s: b"" for s in drains.values()}
with selectors.DefaultSelector() as sel:
@@ -243,6 +133,16 @@ class PopenTracker(ProcessTracker):
@attrs.define(kw_only=True)
class _ResourceTracker:
+ """
+ Context manager that auto-closes tracked sockets and terminates tracked
Popen objects.
+
+ A subprocess startup is built up incrementally: bind sockets, spawn the
+ child, accept its connections. If any step fails, the half-set-up state
+ must be released. Calling :meth:`track` after each successful step records
+ what to release; :meth:`untrack` removes ownership once another component
+ (e.g. the activity subprocess instance) has taken over.
+ """
+
timeout: float
tracked: dict[int, socket.socket | subprocess.Popen] =
attrs.field(init=False, factory=dict)
@@ -272,8 +172,17 @@ class _ResourceTracker:
@attrs.define(kw_only=True)
-class _JavaActivitySubprocess(ActivitySubprocess):
- """Java task runner process."""
+class _PopenActivitySubprocess(ActivitySubprocess):
+ """
+ Activity subprocess that talks to the parent over two TCP sockets.
+
+ The subclass-supplied *command* is launched with ``--comm=<host:port>``
+ and ``--logs=<host:port>`` appended; the subprocess MUST connect back to
+ both ports before *startup_timeout* elapses. Anything the subprocess
+ writes to stdout/stderr before connecting is drained and forwarded to
+ :meth:`_register_pipe_readers` via the ``data=`` kwarg so log lines are
+ not lost.
+ """
_comm_server: socket.socket
_logs_server: socket.socket
@@ -287,14 +196,11 @@ class _JavaActivitySubprocess(ActivitySubprocess):
bundle_info,
logger: FilteringBoundLogger | None = None,
sentry_integration: str = "",
- java_executable: str,
- jvm_args: list[str],
- jars_root: Sequence[pathlib.Path],
- main_class: str,
+ command: Sequence[str],
+ subprocess_schema_version: str | None = None,
startup_timeout: float = 10.0,
**kwargs,
) -> Self:
- jar = _JarInfo.find(jars_root, main_class)
with _ResourceTracker(timeout=startup_timeout) as tracker:
comm_server, logs_server = tracker.track(_start_server(),
_start_server())
stdout_r, stdout_w = tracker.track(*socket.socketpair())
@@ -302,12 +208,7 @@ class _JavaActivitySubprocess(ActivitySubprocess):
proc = subprocess.Popen(
[
- java_executable,
- "-classpath",
- _calculate_classpath(jars_root),
- *jvm_args,
- jar.main_class,
- # Arguments to MainClass...
+ *command,
"--comm={0[0]}:{0[1]}".format(comm_server.getsockname()),
"--logs={0[0]}:{0[1]}".format(logs_server.getsockname()),
],
@@ -334,7 +235,7 @@ class _JavaActivitySubprocess(ActivitySubprocess):
process_log=logger or
structlog.get_logger(logger_name="task").bind(),
start_time=time.monotonic(),
stdin=socks[comm_server],
- subprocess_schema_version=jar.schema_version,
+ subprocess_schema_version=subprocess_schema_version,
comm_server=comm_server,
logs_server=logs_server,
**kwargs,
@@ -350,8 +251,8 @@ class _JavaActivitySubprocess(ActivitySubprocess):
sentry_integration=sentry_integration,
)
- # Untrack everything left. 'self' keeps track of these and close
the
- # servers when the subprocess exits in 'wait'.
+ # Untrack everything left. 'self' keeps track of these and closes
+ # the servers when the subprocess exits in 'wait'.
tracker.untrack(comm_server, logs_server, proc)
return self
@@ -362,70 +263,37 @@ class _JavaActivitySubprocess(ActivitySubprocess):
return code
-def _convert_jars_root(
- value: None | os.PathLike[str] | pathlib.Path | list[os.PathLike[str] |
pathlib.Path],
-) -> list[pathlib.Path]:
- if value is None:
- return []
- if isinstance(value, (str, os.PathLike, pathlib.Path)):
- return [pathlib.Path(value).expanduser()]
- return [pathlib.Path(v).expanduser() for v in value]
-
-
@attrs.define(kw_only=True)
-class JavaCoordinator(BaseCoordinator):
+class SubprocessCoordinator(BaseCoordinator):
"""
- Coordinator that launches a JVM subprocess for DAG parsing and task
execution.
-
- Configuration is taken from the ``[sdk] coordinators`` entry that
constructs
- this instance::
-
- {
- "name": "jdk-17",
- "classpath": "airflow.sdk.coordinators.java.JavaCoordinator",
- "kwargs": {
- "java_executable": "/usr/lib/jvm/java-17-openjdk/bin/java",
- "jvm_args": ["-Xmx1024m"],
- "jars_root": ["~/airflow/jars"],
- },
- }
-
- :param java_executable: Path to the ``java`` command (defaults to
- ``"java"``, which relies on ``$PATH``).
- :param jvm_args: Extra arguments passed to the JVM (e.g. ``["-Xmx512m"]``).
- :param jars_root: A list of directories scanned for JAR bundles.
- :param main_class: Explicit entry point to execute with *java_executable*.
- :param task_startup_timeout: Maximum time the coordinator waits for a task
- process to start, in seconds. The default is 10 seconds.
-
- If *main_class* is not explicitly set, JavaCoordinator scans *jars_root* to
- find an executable JAR (one with Main-Class set in its metadata). If more
- than one executable JAR is found, it may be nondeterministic which one ends
- up being executed.
-
- A JAR containing metadata *Airflow-Supervisor-Schema-Version* should also
be
- available to specify the wire schema version. The JAR containing the Java
- SDK automatically sets this, so you don't generally need to do anything if
- dependency JARs are deployed as-is. If you repackage the dependencies,
- however, you must also reproduce the metadata entry in one of the JARs.
-
- The default *task_startup_timeout* should plenty long enough since a task-
- containing JAR is not supposed to consume significant time to perform setup
- (it should happen in individual tasks instead). However, if the launch time
- has to be so slow, you can increase the timeout to give the JAR more time.
- Note that decreasing the value is generally not meaningful since the
- coordinator does not need to wait for the full period.
+ Abstract base for coordinators that launch a subprocess and IPC over TCP
sockets.
+
+ Subclasses provide the per-task subprocess command and the supervisor
+ wire-schema version via :meth:`_build_execute_task_command`. The rest of
+ the socket lifecycle — listening, spawning the child, accepting
+ connections, draining startup output, and tearing everything down on
+ failure — is handled here.
+
+ :param task_startup_timeout: Maximum time the coordinator waits for the
+ subprocess to connect to both servers, in seconds. The default is 10
+ seconds.
"""
- java_executable: str = "java"
- jvm_args: list[str] = attrs.field(factory=list)
- jars_root: list[pathlib.Path] = attrs.field(
- converter=_convert_jars_root,
- validator=attrs.validators.min_len(1),
- )
- main_class: str = ""
task_startup_timeout: float = 10.0
+ def _build_execute_task_command(self, *, what: TaskInstanceDTO) ->
tuple[list[str], str | None]:
+ """
+ Build the subprocess command and resolve its supervisor wire-schema
version for *what*.
+
+ Returns a ``(command, subprocess_schema_version)`` pair. *command*
+ MUST NOT include the ``--comm`` / ``--logs`` flags — those are
+ appended by :class:`_PopenActivitySubprocess` once the listening
+ sockets have been bound. A ``None`` schema version disables schema
+ migration; messages are then exchanged at the runtime's native wire
+ format.
+ """
+ raise NotImplementedError
+
def execute_task(
self,
*,
@@ -438,7 +306,8 @@ class JavaCoordinator(BaseCoordinator):
subprocess_logs_to_stdout: bool,
**kwargs,
) -> BaseCoordinator.ExecutionResult:
- process = _JavaActivitySubprocess.start(
+ command, subprocess_schema_version =
self._build_execute_task_command(what=what)
+ process = _PopenActivitySubprocess.start(
what=what,
dag_rel_path=dag_rel_path,
bundle_info=bundle_info,
@@ -446,10 +315,8 @@ class JavaCoordinator(BaseCoordinator):
logger=logger,
subprocess_logs_to_stdout=subprocess_logs_to_stdout,
sentry_integration=sentry_integration,
- java_executable=self.java_executable,
- jvm_args=self.jvm_args,
- jars_root=self.jars_root,
- main_class=self.main_class,
+ command=command,
+ subprocess_schema_version=subprocess_schema_version,
startup_timeout=self.task_startup_timeout,
)
exit_code = process.wait()
diff --git a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
b/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
index e91a26006c5..d6aebe707f5 100644
--- a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
+++ b/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
@@ -20,24 +20,17 @@
from __future__ import annotations
import email
-import itertools
import os
import pathlib
-import selectors
-import signal
-import socket
import stat
-import subprocess
-import time
import zipfile
-from typing import TYPE_CHECKING, TypeVar, cast
+from typing import TYPE_CHECKING
import attrs
import structlog
-from airflow.sdk.execution_time.coordinator import BaseCoordinator
+from airflow.sdk.coordinators._subprocess import SubprocessCoordinator
from airflow.sdk.execution_time.schema import get_schema_version_migrator
-from airflow.sdk.execution_time.supervisor import ActivitySubprocess,
NeverRaised, ProcessTracker
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
@@ -45,23 +38,11 @@ if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
from typing_extensions import Self
- from airflow.sdk.api.client import Client
- from airflow.sdk.api.datamodels._generated import BundleInfo
from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
- Tracked = TypeVar("Tracked", socket.socket, subprocess.Popen)
-
log: FilteringBoundLogger =
structlog.get_logger(logger_name="coordinators.java")
-def _start_server() -> socket.socket:
- server = socket.socket()
- server.bind(("127.0.0.1", 0))
- server.setblocking(True)
- server.listen(1) # Just need to listen to the child process.
- return server
-
-
def _find_jars(items: Iterable[pathlib.Path]) -> Iterator[pathlib.Path]:
"""
Yield JAR files under *items*, descending into directories.
@@ -175,193 +156,6 @@ class _JarInfo:
raise FileNotFoundError(tp.format(main_class,
os.pathsep.join(os.fspath(p.resolve()) for p in roots)))
-def _accept_connections(
- servers: dict[str, socket.socket],
- drains: dict[str, socket.socket],
- proc: subprocess.Popen,
- *,
- max_wait: float = 10.0,
- drain_size: int = 4096,
-) -> tuple[dict[socket.socket, socket.socket], dict[socket.socket, bytes]]:
- """Block until the Java process connects to servers."""
- accepted: dict[socket.socket, socket.socket] = {}
- drained: dict[socket.socket, bytes] = {s: b"" for s in drains.values()}
- with selectors.DefaultSelector() as sel:
- for key, soc in itertools.chain(servers.items(), drains.items()):
- sel.register(soc, selectors.EVENT_READ, data=key)
- deadline = time.monotonic() + max_wait
- while len(accepted) < len(servers):
- remaining = deadline - time.monotonic()
- if remaining <= 0:
- for s in accepted.values():
- s.close()
- raise TimeoutError("process did not connect within timeout")
- if proc.poll() is not None:
- for s in accepted.values():
- s.close()
- raise RuntimeError(f"process exited with {proc.returncode}
before connecting")
- for event, _ in sel.select(timeout=min(remaining, 1.0)):
- soc = cast("socket.socket", event.fileobj)
- if soc in drained:
- if incoming := soc.recv(drain_size):
- log.debug("Draining child process stream",
key=event.data)
- drained[soc] += incoming
- else:
- log.warning("Child stream closed before ready!",
key=event.data)
- sel.unregister(soc)
- else:
- log.debug("Accepting child process connection",
key=event.data)
- conn, _ = soc.accept()
- sel.unregister(soc)
- accepted[soc] = conn
- return accepted, drained
-
-
-class PopenTracker(ProcessTracker):
- """
- Process tracker backed by :class:`subprocess.Popen`.
-
- :meta private:
- """
-
- ProcessNotFound = NeverRaised
- TimeoutExpired = subprocess.TimeoutExpired
-
- def __init__(self, impl: subprocess.Popen) -> None:
- self._impl = impl
-
- @property
- def pid(self) -> int:
- return self._impl.pid
-
- def send_signal(self, s: signal.Signals) -> None:
- self._impl.send_signal(s)
-
- def wait(self, timeout: float | None) -> int:
- return self._impl.wait(timeout)
-
-
-@attrs.define(kw_only=True)
-class _ResourceTracker:
- timeout: float
- tracked: dict[int, socket.socket | subprocess.Popen] =
attrs.field(init=False, factory=dict)
-
- def __enter__(self):
- return self
-
- def __exit__(self, *exc_info):
- for o in self.tracked.values():
- match o:
- case socket.socket():
- o.close()
- case subprocess.Popen():
- o.terminate()
- try:
- o.wait(self.timeout)
- except subprocess.TimeoutExpired:
- o.kill()
-
- def track(self, *objects: Tracked) -> tuple[Tracked, ...]:
- self.tracked.update((id(o), o) for o in objects)
- return objects
-
- def untrack(self, *objects: Tracked) -> tuple[Tracked, ...]:
- for o in objects:
- self.tracked.pop(id(o), None)
- return objects
-
-
-@attrs.define(kw_only=True)
-class _JavaActivitySubprocess(ActivitySubprocess):
- """Java task runner process."""
-
- _comm_server: socket.socket
- _logs_server: socket.socket
-
- @classmethod
- def start( # type: ignore[override]
- cls,
- *,
- what: TaskInstanceDTO,
- dag_rel_path: str | os.PathLike[str],
- bundle_info,
- logger: FilteringBoundLogger | None = None,
- sentry_integration: str = "",
- java_executable: str,
- jvm_args: list[str],
- jars_root: Sequence[pathlib.Path],
- main_class: str,
- startup_timeout: float = 10.0,
- **kwargs,
- ) -> Self:
- jar = _JarInfo.find(jars_root, main_class)
- with _ResourceTracker(timeout=startup_timeout) as tracker:
- comm_server, logs_server = tracker.track(_start_server(),
_start_server())
- stdout_r, stdout_w = tracker.track(*socket.socketpair())
- stderr_r, stderr_w = tracker.track(*socket.socketpair())
-
- proc = subprocess.Popen(
- [
- java_executable,
- "-classpath",
- _calculate_classpath(jars_root),
- *jvm_args,
- jar.main_class,
- # Arguments to MainClass...
- "--comm={0[0]}:{0[1]}".format(comm_server.getsockname()),
- "--logs={0[0]}:{0[1]}".format(logs_server.getsockname()),
- ],
- stdout=stdout_w.fileno(),
- stderr=stderr_w.fileno(),
- )
- tracker.track(proc)
- for soc in tracker.untrack(stdout_w, stderr_w):
- soc.close()
- log.info("Starting subprocess", pid=proc.pid)
-
- socks, drained = _accept_connections(
- {"comm": comm_server, "logs": logs_server},
- {"stdout": stdout_r, "stderr": stderr_r},
- proc,
- max_wait=startup_timeout,
- )
- tracker.track(*socks.values())
-
- self = cls(
- id=what.id,
- pid=proc.pid,
- process=PopenTracker(proc),
- process_log=logger or
structlog.get_logger(logger_name="task").bind(),
- start_time=time.monotonic(),
- stdin=socks[comm_server],
- subprocess_schema_version=jar.schema_version,
- comm_server=comm_server,
- logs_server=logs_server,
- **kwargs,
- )
- self._register_pipe_readers(
- *tracker.untrack(stdout_r, stderr_r, socks[comm_server],
socks[logs_server]),
- data=drained,
- )
- self._on_child_started(
- ti=what,
- dag_rel_path=dag_rel_path,
- bundle_info=bundle_info,
- sentry_integration=sentry_integration,
- )
-
- # Untrack everything left. 'self' keeps track of these and close
the
- # servers when the subprocess exits in 'wait'.
- tracker.untrack(comm_server, logs_server, proc)
-
- return self
-
- def wait(self) -> int:
- code = super().wait()
- self._close_unused_sockets(self._comm_server, self._logs_server)
- return code
-
-
def _convert_jars_root(
value: None | os.PathLike[str] | pathlib.Path | list[os.PathLike[str] |
pathlib.Path],
) -> list[pathlib.Path]:
@@ -373,7 +167,7 @@ def _convert_jars_root(
@attrs.define(kw_only=True)
-class JavaCoordinator(BaseCoordinator):
+class JavaCoordinator(SubprocessCoordinator):
"""
Coordinator that launches a JVM subprocess for DAG parsing and task
execution.
@@ -424,33 +218,14 @@ class JavaCoordinator(BaseCoordinator):
validator=attrs.validators.min_len(1),
)
main_class: str = ""
- task_startup_timeout: float = 10.0
-
- def execute_task(
- self,
- *,
- what: TaskInstanceDTO,
- dag_rel_path: str | os.PathLike[str],
- bundle_info: BundleInfo,
- client: Client,
- logger: FilteringBoundLogger | None = None,
- sentry_integration: str = "",
- subprocess_logs_to_stdout: bool,
- **kwargs,
- ) -> BaseCoordinator.ExecutionResult:
- process = _JavaActivitySubprocess.start(
- what=what,
- dag_rel_path=dag_rel_path,
- bundle_info=bundle_info,
- client=client,
- logger=logger,
- subprocess_logs_to_stdout=subprocess_logs_to_stdout,
- sentry_integration=sentry_integration,
- java_executable=self.java_executable,
- jvm_args=self.jvm_args,
- jars_root=self.jars_root,
- main_class=self.main_class,
- startup_timeout=self.task_startup_timeout,
- )
- exit_code = process.wait()
- return self.ExecutionResult(exit_code, process.final_state)
+
+ def _build_execute_task_command(self, *, what: TaskInstanceDTO) ->
tuple[list[str], str | None]:
+ jar = _JarInfo.find(self.jars_root, self.main_class)
+ command = [
+ self.java_executable,
+ "-classpath",
+ _calculate_classpath(self.jars_root),
+ *self.jvm_args,
+ jar.main_class,
+ ]
+ return command, jar.schema_version
diff --git a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
b/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
index 25b9a323b5d..8670a0e895c 100644
--- a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
+++ b/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
@@ -18,28 +18,21 @@
from __future__ import annotations
-import contextlib
import os
import pathlib
import re
import socket
import subprocess
-import threading
-import time
import zipfile
-from unittest.mock import ANY, MagicMock, call, patch
+from unittest.mock import MagicMock, patch
import pytest
from uuid6 import uuid7
from airflow.sdk.coordinators.java.coordinator import (
JavaCoordinator,
- _accept_connections,
_calculate_classpath,
_JarInfo,
- _JavaActivitySubprocess,
- _ResourceTracker,
- _start_server,
_walk_jars,
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
@@ -51,10 +44,6 @@ from tests_common.test_utils.version_compat import
AIRFLOW_V_3_3_PLUS
if not AIRFLOW_V_3_3_PLUS:
pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0",
allow_module_level=True)
-METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml"
-DAG_CODE_PATH = "dag_source.py"
-TEST_MAIN_CLASS = "com.example.MyBundle"
-
def _make_ti(dag_id: str = "test_dag", queue: str = "java") -> TaskInstanceDTO:
return TaskInstanceDTO(
@@ -89,45 +78,6 @@ def _make_jar(
return path
-class TestStartServer:
- def test_returns_listening_socket(self):
- server = _start_server()
- try:
- _, port = server.getsockname()
- finally:
- server.close()
- assert port > 0
-
- def test_two_calls_return_different_ports(self):
- s1 = _start_server()
- s2 = _start_server()
- try:
- _, port1 = s1.getsockname()
- _, port2 = s2.getsockname()
- finally:
- s1.close()
- s2.close()
- assert port1 != port2
-
- def test_accepts_connection(self):
- conn = client = None
- server = _start_server()
- try:
- _, port = server.getsockname()
- client = socket.socket()
- client.connect(("127.0.0.1", port))
- conn, _ = server.accept()
- conn.sendall(b"ping")
- received = client.recv(4)
- finally:
- if conn:
- conn.close()
- if client:
- client.close()
- server.close()
- assert received == b"ping"
-
-
class TestCalculateClasspath:
def test_single_jar(self, tmp_path):
jar = tmp_path.joinpath("app.jar")
@@ -272,225 +222,6 @@ class TestWalkJars:
mock_log.debug.assert_any_call("Skipping already-visited directory",
path=sub)
-class TestAcceptConnections:
- def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0)
-> None:
- def _connect():
- time.sleep(delay)
- c = socket.socket()
- with contextlib.suppress(OSError): # Server may already be closed
in teardown.
- c.connect(addr)
-
- threading.Thread(target=_connect, daemon=True).start()
-
- def test_accepts_single_server(self):
- server = _start_server()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
-
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
-
- try:
- accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
- assert server in accepted
- accepted[server].close()
- finally:
- server.close()
-
- def test_accepts_multiple_servers(self):
- comm_server = _start_server()
- logs_server = _start_server()
- _, comm_port = comm_server.getsockname()
- _, logs_port = logs_server.getsockname()
-
- self._connect_after_delay(("127.0.0.1", comm_port))
- self._connect_after_delay(("127.0.0.1", logs_port))
-
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
-
- try:
- accepted, _ = _accept_connections({"comm": comm_server, "logs":
logs_server}, {}, mock_proc)
- assert set(accepted) == {comm_server, logs_server}
- for sock in accepted.values():
- sock.close()
- finally:
- comm_server.close()
- logs_server.close()
-
- def test_raises_timeout_when_no_connection(self):
- server = _start_server()
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
- try:
- with pytest.raises(TimeoutError, match="did not connect within
timeout"):
- _accept_connections({"comm": server}, {}, mock_proc,
max_wait=0.05)
- finally:
- server.close()
-
- def test_raises_runtime_error_if_process_exits_before_connecting(self):
- server = _start_server()
- mock_proc = MagicMock(spec=subprocess.Popen)
- # proc has already exited
- mock_proc.poll.return_value = 1
- mock_proc.returncode = 1
- try:
- with pytest.raises(RuntimeError, match="process exited with 1"):
- _accept_connections({"comm": server}, {}, mock_proc)
- finally:
- server.close()
-
- def test_returned_sockets_are_connected(self):
- """Accepted sockets should be real, usable connections."""
- server = _start_server()
- _, port = server.getsockname()
-
- client = socket.socket()
- client.connect(("127.0.0.1", port))
-
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
-
- try:
- accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
- accepted[server].sendall(b"hello")
- assert client.recv(5) == b"hello"
- accepted[server].close()
- client.close()
- finally:
- server.close()
-
- def test_empty_drains_returns_empty_drained_dict(self):
- """When drains={} the returned drained mapping must also be empty."""
- server = _start_server()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
- try:
- _, drained = _accept_connections({"comm": server}, {}, mock_proc)
- assert drained == {}
- finally:
- server.close()
-
- def test_drain_socket_present_in_drained_dict(self):
- """The drained dict must be keyed by the drain socket objects."""
- server = _start_server()
- drain_r, drain_w = socket.socketpair()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
- try:
- _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
- assert drain_r in drained
- finally:
- server.close()
- drain_r.close()
- drain_w.close()
-
- def test_bytes_written_to_drain_socket_are_returned(self):
- """Bytes written to a drain socket before the connection is accepted
- must be captured and returned in the drained dict."""
- server = _start_server()
- drain_r, drain_w = socket.socketpair()
- _, port = server.getsockname()
-
- drain_w.sendall(b"early output\n")
- self._connect_after_delay(("127.0.0.1", port), delay=0.05)
-
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
- try:
- _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
- assert drained[drain_r] == b"early output\n"
- finally:
- server.close()
- drain_r.close()
- drain_w.close()
-
- def test_accepted_dict_keyed_by_server_socket_object(self):
- """The returned accepted mapping must use server socket objects as
keys,
- not the string names passed in the servers dict."""
- server = _start_server()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
- try:
- accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
- # Key must be the socket object itself, not the string "comm"
- assert server in accepted
- assert "comm" not in accepted
- accepted[server].close()
- finally:
- server.close()
-
-
-class TestResourceTracker:
- """Unit tests for the _ResourceTracker context manager introduced in this
PR.
-
- _ResourceTracker tracks sockets and Popen objects and ensures they are
- closed/terminated on context-manager exit, unless explicitly untracked
- beforehand.
- """
-
- def test_track_returns_passed_objects_as_tuple(self):
- tracker = _ResourceTracker(timeout=0.1)
- sock = MagicMock(spec=socket.socket)
- result = tracker.track(sock)
- assert result == (sock,)
-
- def test_track_multiple_objects_returns_all(self):
- tracker = _ResourceTracker(timeout=0.1)
- sock1 = MagicMock(spec=socket.socket)
- sock2 = MagicMock(spec=socket.socket)
- result = tracker.track(sock1, sock2)
- assert set(result) == {sock1, sock2}
-
- def test_untrack_returns_objects(self):
- tracker = _ResourceTracker(timeout=0.1)
- sock = MagicMock(spec=socket.socket)
- tracker.track(sock)
- result = tracker.untrack(sock)
- assert result == (sock,)
-
- def test_context_manager_closes_tracked_socket_on_exit(self):
- sock = MagicMock(spec=socket.socket)
- with _ResourceTracker(timeout=0.1) as tracker:
- tracker.track(sock)
- sock.close.assert_called_once()
-
- def test_context_manager_terminates_tracked_popen_on_exit(self):
- proc = MagicMock(spec=subprocess.Popen)
- with _ResourceTracker(timeout=0.1) as tracker:
- tracker.track(proc)
- proc.terminate.assert_called_once()
-
- def test_untracked_socket_not_closed_on_exit(self):
- sock = MagicMock(spec=socket.socket)
- with _ResourceTracker(timeout=0.1) as tracker:
- tracker.track(sock)
- tracker.untrack(sock)
- sock.close.assert_not_called()
-
- def test_only_remaining_tracked_objects_cleaned_up(self):
- """After untracking one socket the other must still be closed."""
- sock_keep = MagicMock(spec=socket.socket)
- sock_release = MagicMock(spec=socket.socket)
- with _ResourceTracker(timeout=0.1) as tracker:
- tracker.track(sock_keep, sock_release)
- tracker.untrack(sock_release)
- sock_keep.close.assert_called_once()
- sock_release.close.assert_not_called()
-
- def test_untrack_unknown_object_does_not_raise(self):
- sock = MagicMock(spec=socket.socket)
- tracker = _ResourceTracker(timeout=0.1)
- # Untracking something never tracked must be a no-op, not an error
- tracker.untrack(sock)
-
-
class TestJavaCoordinatorAttributes:
def test_default_kwargs(self):
coordinator = JavaCoordinator(jars_root="/airflow/java-bundles")
@@ -551,11 +282,11 @@ class TestJavaCoordinatorExecuteTask:
with (
patch(
- "airflow.sdk.coordinators.java.coordinator.subprocess.Popen",
+ "airflow.sdk.coordinators._subprocess.subprocess.Popen",
side_effect=capture_popen,
),
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{servers["comm"]: comm_sock, servers["logs"]: logs_sock},
{soc: b"" for soc in drains.values()},
@@ -636,7 +367,7 @@ class TestJavaCoordinatorExecuteTask:
with (
patch("subprocess.Popen", return_value=mock_proc),
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{servers["comm"]: comm_sock, servers["logs"]: logs_sock},
{soc: b"" for soc in drains.values()},
@@ -657,160 +388,3 @@ class TestJavaCoordinatorExecuteTask:
assert isinstance(result, BaseCoordinator.ExecutionResult)
assert result.exit_code == 0
-
-
-class TestJavaActivitySubprocessStart:
- """
- Unit tests for _JavaActivitySubprocess.start().
-
- These tests mock subprocess.Popen and _accept_connections to verify that
- start() wires up the right command and stores the right sockets,
- without requiring a real Java runtime.
- """
-
- def _start_with_mocks(
- self,
- jars_root: pathlib.Path,
- mock_client,
- *,
- java_executable: str = "java",
- jvm_args: list[str] | None = None,
- ti: TaskInstanceDTO | None = None,
- ):
- """Call _JavaActivitySubprocess.start() with all subprocess machinery
mocked out."""
- ti = ti or _make_ti()
-
- mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.pid = 12345
- comm_sock = MagicMock(spec=socket.socket)
- logs_sock = MagicMock(spec=socket.socket)
-
- with (
- patch(
- "airflow.sdk.coordinators.java.coordinator.subprocess.Popen",
- return_value=mock_proc,
- ) as popen_mock,
- patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
- side_effect=lambda servers, drains, proc, **kw: (
- {servers["comm"]: comm_sock, servers["logs"]: logs_sock},
- {soc: b"" for soc in drains.values()},
- ),
- ),
- patch.object(ActivitySubprocess, "_register_pipe_readers"),
- patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
- ):
- proc = _JavaActivitySubprocess.start(
- what=ti,
- dag_rel_path="dags/test.jar",
- bundle_info=MagicMock(),
- client=mock_client,
- java_executable=java_executable,
- jvm_args=jvm_args or [],
- jars_root=[jars_root],
- main_class="",
- subprocess_logs_to_stdout=False,
- )
-
- return proc, popen_mock
-
- def test_stdin_is_comm_socket(self, jars_root, mock_client):
- """stdin (used by send_msg) must be the accepted comm socket."""
- ti = _make_ti()
- comm_sock = MagicMock(spec=socket.socket)
- logs_sock = MagicMock(spec=socket.socket)
-
- with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
- patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
- side_effect=lambda servers, drains, proc, **kw: (
- {servers["comm"]: comm_sock, servers["logs"]: logs_sock},
- {soc: b"" for soc in drains.values()},
- ),
- ),
- patch.object(ActivitySubprocess, "_register_pipe_readers"),
- patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
- ):
- popen_mock.return_value.pid = 12345
- proc = _JavaActivitySubprocess.start(
- what=ti,
- dag_rel_path="dags/test.jar",
- bundle_info=MagicMock(),
- client=MagicMock(),
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
- subprocess_logs_to_stdout=False,
- )
-
- assert proc.stdin is comm_sock
-
- def test_pid_taken_from_popen(self, jars_root, mock_client):
- proc, _ = self._start_with_mocks(jars_root, mock_client)
- assert proc.pid == 12345
-
- def test_on_child_started_called(self, jars_root, mock_client):
- ti = _make_ti()
- with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
- patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
- side_effect=lambda servers, drains, proc, **kw: (
- {soc: MagicMock(spec=socket.socket) for soc in
servers.values()},
- {soc: b"" for soc in drains.values()},
- ),
- ),
- patch.object(ActivitySubprocess, "_register_pipe_readers"),
- patch.object(ActivitySubprocess, "_on_child_started") as
mock_on_started,
- patch("psutil.Process"),
- ):
- popen_mock.return_value.pid = 12345
- _JavaActivitySubprocess.start(
- what=ti,
- dag_rel_path="dags/test.jar",
- bundle_info=MagicMock(),
- client=mock_client,
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
- subprocess_logs_to_stdout=False,
- )
-
- mock_on_started.assert_called_once()
- kwargs = mock_on_started.call_args.kwargs
- assert kwargs["ti"] is ti
- assert kwargs["dag_rel_path"] == "dags/test.jar"
-
- def test_register_pipe_readers_called_with_four_sockets(self, jars_root,
mock_client):
- """Both socketpair read-ends and both TCP sockets must be registered,
with a data kwarg."""
- with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
- patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
- side_effect=lambda servers, drains, proc, **kw: (
- {soc: MagicMock(spec=socket.socket) for soc in
servers.values()},
- {soc: b"" for soc in drains.values()},
- ),
- ),
- patch.object(ActivitySubprocess, "_register_pipe_readers") as
mock_register,
- patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
- ):
- popen_mock.return_value.pid = 12345
- _JavaActivitySubprocess.start(
- what=_make_ti(),
- dag_rel_path="dags/test.jar",
- bundle_info=MagicMock(),
- client=mock_client,
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
- subprocess_logs_to_stdout=False,
- )
- assert mock_register.mock_calls == [call(ANY, ANY, ANY, ANY, data=ANY)]
diff --git a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
similarity index 50%
copy from task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
copy to task-sdk/tests/task_sdk/coordinators/test_subprocess.py
index 25b9a323b5d..e5105e9f223 100644
--- a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
+++ b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
@@ -15,32 +15,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from __future__ import annotations
import contextlib
-import os
-import pathlib
-import re
import socket
import subprocess
import threading
import time
-import zipfile
from unittest.mock import ANY, MagicMock, call, patch
+import attrs
import pytest
from uuid6 import uuid7
-from airflow.sdk.coordinators.java.coordinator import (
- JavaCoordinator,
+from airflow.sdk.api.client import Client, TaskInstanceOperations
+from airflow.sdk.coordinators._subprocess import (
+ SubprocessCoordinator,
_accept_connections,
- _calculate_classpath,
- _JarInfo,
- _JavaActivitySubprocess,
+ _PopenActivitySubprocess,
_ResourceTracker,
_start_server,
- _walk_jars,
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess
@@ -51,12 +45,8 @@ from tests_common.test_utils.version_compat import
AIRFLOW_V_3_3_PLUS
if not AIRFLOW_V_3_3_PLUS:
pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0",
allow_module_level=True)
-METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml"
-DAG_CODE_PATH = "dag_source.py"
-TEST_MAIN_CLASS = "com.example.MyBundle"
-
-def _make_ti(dag_id: str = "test_dag", queue: str = "java") -> TaskInstanceDTO:
+def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") ->
TaskInstanceDTO:
return TaskInstanceDTO(
id=uuid7(),
dag_version_id=uuid7(),
@@ -71,31 +61,14 @@ def _make_ti(dag_id: str = "test_dag", queue: str = "java")
-> TaskInstanceDTO:
)
-def _make_jar(
- path: pathlib.Path,
- *,
- main_class: str | None = "com.example.Main",
- schema_version: str | None = None,
-) -> pathlib.Path:
- """Write a minimal JAR with (optionally) a Main-Class manifest entry."""
- lines = ["Manifest-Version: 1.0"]
- if main_class:
- lines.append(f"Main-Class: {main_class}")
- if schema_version:
- lines.append(f"Airflow-Supervisor-Schema-Version: {schema_version}")
- manifest = "\n".join(lines) + "\n\n"
- with zipfile.ZipFile(path, "w") as zf:
- zf.writestr("META-INF/MANIFEST.MF", manifest)
- return path
-
-
class TestStartServer:
def test_returns_listening_socket(self):
server = _start_server()
try:
- _, port = server.getsockname()
+ host, port = server.getsockname()
finally:
server.close()
+ assert host == "127.0.0.1"
assert port > 0
def test_two_calls_return_different_ports(self):
@@ -128,156 +101,12 @@ class TestStartServer:
assert received == b"ping"
-class TestCalculateClasspath:
- def test_single_jar(self, tmp_path):
- jar = tmp_path.joinpath("app.jar")
- jar.write_bytes(b"")
- result = _calculate_classpath([tmp_path])
- assert result == jar.as_posix()
-
- def test_multiple_jars_all_included(self, tmp_path):
- tmp_path.joinpath("a.jar").write_bytes(b"")
- tmp_path.joinpath("b.jar").write_bytes(b"")
- tmp_path.joinpath("c.jar").write_bytes(b"")
- result = _calculate_classpath([tmp_path])
- entries = set(result.split(os.pathsep))
- assert entries == {
- tmp_path.joinpath("a.jar").as_posix(),
- tmp_path.joinpath("b.jar").as_posix(),
- tmp_path.joinpath("c.jar").as_posix(),
- }
-
- def test_non_jar_files_excluded(self, tmp_path):
- jar = tmp_path.joinpath("app.jar")
- jar.write_bytes(b"")
- tmp_path.joinpath("readme.txt").write_bytes(b"")
- tmp_path.joinpath("config.yaml").write_bytes(b"")
- result = _calculate_classpath([tmp_path])
- assert result == jar.as_posix()
-
- def test_empty_directory_returns_empty_string(self, tmp_path):
- result = _calculate_classpath([tmp_path])
- assert result == ""
-
-
-class TestMainJar:
- def test_returns_main_class_from_jar(self, tmp_path):
- _make_jar(tmp_path.joinpath("app.jar"), main_class="com.example.Main",
schema_version="2026-06-16")
- assert _JarInfo.find([tmp_path], "") == _JarInfo("com.example.Main",
"2026-06-16")
-
- def test_no_jars_raises_file_not_found(self, tmp_path):
- with pytest.raises(FileNotFoundError,
match=re.escape(str(tmp_path.resolve()))):
- _JarInfo.find([tmp_path], "")
-
- def test_jar_without_main_class_not_returned(self, tmp_path):
- _make_jar(tmp_path.joinpath("app.jar"), main_class=None)
- with pytest.raises(FileNotFoundError):
- _JarInfo.find([tmp_path], "")
-
- def test_jar_with_main_class_but_no_schema_version_raises(self, tmp_path):
- """A JAR with Main-Class but no Airflow-Supervisor-Schema-Version must
raise ValueError."""
- _make_jar(tmp_path.joinpath("app.jar"), main_class="com.example.Main")
- with pytest.raises(FileNotFoundError,
match="Airflow-Supervisor-Schema-Version"):
- _JarInfo.find([tmp_path], "")
-
- def test_non_jar_files_skipped(self, tmp_path):
- tmp_path.joinpath("readme.txt").write_bytes(b"not a jar")
- _make_jar(tmp_path.joinpath("app.jar"), main_class="com.example.Main",
schema_version="2026-06-16")
- assert _JarInfo.find([tmp_path], "") == _JarInfo("com.example.Main",
"2026-06-16")
-
- def test_first_jar_missing_main_class_falls_through_to_second(self,
tmp_path):
- # Alphabetically: a.jar (no Main-Class), b.jar (has Main-Class).
- _make_jar(tmp_path.joinpath("a.jar"), main_class=None)
- _make_jar(tmp_path.joinpath("b.jar"),
main_class="com.example.Fallback", schema_version="2026-06-16")
- assert _JarInfo.find([tmp_path], "") ==
_JarInfo("com.example.Fallback", "2026-06-16")
-
- def test_fully_qualified_class_name_preserved(self, tmp_path):
- _make_jar(
- tmp_path.joinpath("app.jar"),
- main_class="org.apache.airflow.sdk.java.TaskRunner",
- schema_version="2026-06-16",
- )
- assert _JarInfo.find([tmp_path], "") == _JarInfo(
- main_class="org.apache.airflow.sdk.java.TaskRunner",
- schema_version="2026-06-16",
- )
-
- def test_find_by_explicit_main_class(self, tmp_path):
- """When a main_class filter is given, only the matching JAR is
returned."""
- _make_jar(tmp_path.joinpath("a.jar"), main_class="com.example.Alpha",
schema_version="2026-06-16")
- _make_jar(tmp_path.joinpath("b.jar"), main_class="com.example.Beta",
schema_version="2026-06-16")
- result = _JarInfo.find([tmp_path], "com.example.Beta")
- assert result.main_class == "com.example.Beta"
-
- def test_find_by_explicit_main_class_not_present_raises(self, tmp_path):
- """When no JAR matches the main_class filter, FileNotFoundError is
raised."""
- _make_jar(tmp_path.joinpath("app.jar"), main_class="com.example.Main",
schema_version="2026-06-16")
- with pytest.raises(FileNotFoundError, match="com.example.Missing"):
- _JarInfo.find([tmp_path], "com.example.Missing")
-
- def test_symlink_cycle_does_not_infinite_recurse(self, tmp_path):
- nested = tmp_path / "inner"
- nested.mkdir()
- _make_jar(nested / "app.jar", main_class="com.example.Loop",
schema_version="2026-06-16")
- loop = nested / "loop"
- try:
- loop.symlink_to(tmp_path)
- except (OSError, NotImplementedError):
- pytest.skip("symlinks not supported on this platform")
-
- result = _JarInfo.find([tmp_path], "com.example.Loop")
- assert result == _JarInfo("com.example.Loop", "2026-06-16")
-
-
-class TestWalkJars:
- def test_skips_directory_whose_key_is_already_in_seen_dirs(self, tmp_path):
- """A directory whose (st_dev, st_ino) is already in seen_dirs is
skipped."""
- _make_jar(tmp_path / "app.jar", main_class="com.example.Main",
schema_version="2026-06-16")
- st = tmp_path.stat()
- seen_dirs: set[tuple[int, int]] = {(st.st_dev, st.st_ino)}
- assert list(_walk_jars([tmp_path], seen_dirs)) == []
-
- def test_records_visited_directories_in_seen_dirs(self, tmp_path):
- """Every directory descended into is added to seen_dirs."""
- sub = tmp_path / "sub"
- sub.mkdir()
- _make_jar(sub / "app.jar", main_class="com.example.Main",
schema_version="2026-06-16")
- seen_dirs: set[tuple[int, int]] = set()
- list(_walk_jars([tmp_path], seen_dirs))
- assert (tmp_path.stat().st_dev, tmp_path.stat().st_ino) in seen_dirs
- assert (sub.stat().st_dev, sub.stat().st_ino) in seen_dirs
-
- def test_symlink_cycle_yields_each_jar_once(self, tmp_path):
- """A symlink that loops back to an ancestor must not yield the same
JAR twice."""
- nested = tmp_path / "inner"
- nested.mkdir()
- jar = _make_jar(nested / "app.jar", main_class="com.example.Loop",
schema_version="2026-06-16")
- loop = nested / "loop"
- try:
- loop.symlink_to(tmp_path)
- except (OSError, NotImplementedError):
- pytest.skip("symlinks not supported on this platform")
-
- seen_dirs: set[tuple[int, int]] = set()
- yielded = list(_walk_jars([tmp_path], seen_dirs))
- assert [p.resolve() for p in yielded] == [jar.resolve()]
-
- def test_skip_logged_when_directory_revisited(self, tmp_path):
- """A revisited directory triggers the 'Skipping already-visited
directory' debug log."""
- sub = tmp_path / "sub"
- sub.mkdir()
- seen_dirs: set[tuple[int, int]] = {(sub.stat().st_dev,
sub.stat().st_ino)}
- with patch("airflow.sdk.coordinators.java.coordinator.log") as
mock_log:
- list(_walk_jars([sub], seen_dirs))
- mock_log.debug.assert_any_call("Skipping already-visited directory",
path=sub)
-
-
class TestAcceptConnections:
def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0)
-> None:
def _connect():
time.sleep(delay)
c = socket.socket()
- with contextlib.suppress(OSError): # Server may already be closed
in teardown.
+ with contextlib.suppress(OSError):
c.connect(addr)
threading.Thread(target=_connect, daemon=True).start()
@@ -297,7 +126,7 @@ class TestAcceptConnections:
finally:
server.close()
- def test_accepts_multiple_servers(self):
+ def test_accepts_multiple_servers_keyed_by_server_socket(self):
comm_server = _start_server()
logs_server = _start_server()
_, comm_port = comm_server.getsockname()
@@ -310,104 +139,105 @@ class TestAcceptConnections:
mock_proc.poll.return_value = None
try:
- accepted, _ = _accept_connections({"comm": comm_server, "logs":
logs_server}, {}, mock_proc)
+ accepted, drained = _accept_connections({"comm": comm_server,
"logs": logs_server}, {}, mock_proc)
assert set(accepted) == {comm_server, logs_server}
+ assert drained == {}
for sock in accepted.values():
sock.close()
finally:
comm_server.close()
logs_server.close()
- def test_raises_timeout_when_no_connection(self):
+ def test_empty_drains_returns_empty_drained_dict(self):
server = _start_server()
+ _, port = server.getsockname()
+ self._connect_after_delay(("127.0.0.1", port))
+
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.poll.return_value = None
try:
- with pytest.raises(TimeoutError, match="did not connect within
timeout"):
- _accept_connections({"comm": server}, {}, mock_proc,
max_wait=0.05)
+ _, drained = _accept_connections({"comm": server}, {}, mock_proc)
+ assert drained == {}
finally:
server.close()
- def test_raises_runtime_error_if_process_exits_before_connecting(self):
+ def test_drain_socket_present_in_drained_dict(self):
server = _start_server()
+ drain_r, drain_w = socket.socketpair()
+ _, port = server.getsockname()
+ self._connect_after_delay(("127.0.0.1", port))
+
mock_proc = MagicMock(spec=subprocess.Popen)
- # proc has already exited
- mock_proc.poll.return_value = 1
- mock_proc.returncode = 1
+ mock_proc.poll.return_value = None
try:
- with pytest.raises(RuntimeError, match="process exited with 1"):
- _accept_connections({"comm": server}, {}, mock_proc)
+ _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
+ assert drain_r in drained
finally:
+ drain_r.close()
+ drain_w.close()
server.close()
- def test_returned_sockets_are_connected(self):
- """Accepted sockets should be real, usable connections."""
+ def test_drain_captures_early_output(self):
+ """Bytes written to the drain socket before the comm server accepts
+ must be captured and returned in the drained dict."""
server = _start_server()
+ drain_r, drain_w = socket.socketpair()
_, port = server.getsockname()
- client = socket.socket()
- client.connect(("127.0.0.1", port))
+ drain_w.sendall(b"early output\n")
+ drain_w.shutdown(socket.SHUT_WR)
+ self._connect_after_delay(("127.0.0.1", port), delay=0.05)
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.poll.return_value = None
-
try:
- accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
- accepted[server].sendall(b"hello")
- assert client.recv(5) == b"hello"
- accepted[server].close()
- client.close()
+ _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
+ assert drained[drain_r] == b"early output\n"
finally:
+ drain_r.close()
+ drain_w.close()
server.close()
- def test_empty_drains_returns_empty_drained_dict(self):
- """When drains={} the returned drained mapping must also be empty."""
+ def test_raises_timeout_when_no_connection(self):
server = _start_server()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.poll.return_value = None
try:
- _, drained = _accept_connections({"comm": server}, {}, mock_proc)
- assert drained == {}
+ with pytest.raises(TimeoutError, match="did not connect within
timeout"):
+ _accept_connections({"comm": server}, {}, mock_proc,
max_wait=0.05)
finally:
server.close()
- def test_drain_socket_present_in_drained_dict(self):
- """The drained dict must be keyed by the drain socket objects."""
+ def test_raises_runtime_error_if_process_exits_before_connecting(self):
server = _start_server()
- drain_r, drain_w = socket.socketpair()
- _, port = server.getsockname()
- self._connect_after_delay(("127.0.0.1", port))
mock_proc = MagicMock(spec=subprocess.Popen)
- mock_proc.poll.return_value = None
+ mock_proc.poll.return_value = 1
+ mock_proc.returncode = 1
try:
- _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
- assert drain_r in drained
+ with pytest.raises(RuntimeError, match="process exited with 1"):
+ _accept_connections({"comm": server}, {}, mock_proc)
finally:
server.close()
- drain_r.close()
- drain_w.close()
- def test_bytes_written_to_drain_socket_are_returned(self):
- """Bytes written to a drain socket before the connection is accepted
- must be captured and returned in the drained dict."""
+ def test_returned_sockets_are_connected(self):
+ """Accepted sockets should be real, usable connections."""
server = _start_server()
- drain_r, drain_w = socket.socketpair()
_, port = server.getsockname()
- drain_w.sendall(b"early output\n")
- self._connect_after_delay(("127.0.0.1", port), delay=0.05)
+ client = socket.socket()
+ client.connect(("127.0.0.1", port))
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.poll.return_value = None
+
try:
- _, drained = _accept_connections({"comm": server}, {"stdout":
drain_r}, mock_proc)
- assert drained[drain_r] == b"early output\n"
+ accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+ accepted[server].sendall(b"hello")
+ assert client.recv(5) == b"hello"
+ accepted[server].close()
+ client.close()
finally:
server.close()
- drain_r.close()
- drain_w.close()
def test_accepted_dict_keyed_by_server_socket_object(self):
"""The returned accepted mapping must use server socket objects as
keys,
@@ -428,7 +258,8 @@ class TestAcceptConnections:
class TestResourceTracker:
- """Unit tests for the _ResourceTracker context manager introduced in this
PR.
+ """
+ Unit tests for the _ResourceTracker context manager.
_ResourceTracker tracks sockets and Popen objects and ensures they are
closed/terminated on context-manager exit, unless explicitly untracked
@@ -491,71 +322,77 @@ class TestResourceTracker:
tracker.untrack(sock)
-class TestJavaCoordinatorAttributes:
- def test_default_kwargs(self):
- coordinator = JavaCoordinator(jars_root="/airflow/java-bundles")
- assert coordinator.java_executable == "java"
- assert coordinator.jvm_args == []
- assert coordinator.jars_root == [pathlib.Path("/airflow/java-bundles")]
-
- def test_custom_kwargs(self):
- coordinator = JavaCoordinator(
- java_executable="/opt/java/bin/java",
- jvm_args=["-Xmx512m", "-Xms256m"],
- jars_root=["/airflow/java-bundles"],
- )
- assert coordinator.java_executable == "/opt/java/bin/java"
- assert coordinator.jvm_args == ["-Xmx512m", "-Xms256m"]
- assert coordinator.jars_root == [pathlib.Path("/airflow/java-bundles")]
[email protected](kw_only=True)
+class _StubSubprocessCoordinator(SubprocessCoordinator):
+ """Minimal SubprocessCoordinator subclass used to exercise the base
machinery."""
+ command: list[str]
+ schema_version: str | None = None
[email protected]
-def jars_root(tmp_path):
- _make_jar(tmp_path.joinpath("app.jar"),
main_class="com.example.TaskRunner", schema_version="2026-06-16")
- return tmp_path
+ def _build_execute_task_command(self, *, what):
+ return list(self.command), self.schema_version
@pytest.fixture
def mock_client(make_ti_context):
- client = MagicMock()
+ client = MagicMock(spec=Client)
+ client.task_instances = MagicMock(spec=TaskInstanceOperations)
client.task_instances.start.return_value = make_ti_context()
return client
-class TestJavaCoordinatorExecuteTask:
+class TestSubprocessCoordinatorAttributes:
+ def test_default_startup_timeout(self):
+ coordinator = _StubSubprocessCoordinator(command=["/bin/true"])
+ assert coordinator.task_startup_timeout == 10.0
+
+ def test_custom_startup_timeout(self):
+ coordinator = _StubSubprocessCoordinator(command=["/bin/true"],
task_startup_timeout=2.5)
+ assert coordinator.task_startup_timeout == 2.5
+
+ def test_build_execute_task_command_default_raises(self):
+ class _Plain(SubprocessCoordinator):
+ pass
+
+ with pytest.raises(NotImplementedError):
+ _Plain()._build_execute_task_command(what=_make_ti())
+
+
+class TestSubprocessCoordinatorExecuteTask:
def _captured_popen_cmd(
self,
- jars_root: pathlib.Path,
mock_client,
*,
- java_executable: str = "java",
- jvm_args: list[str] | None = None,
- ) -> list[str]:
- """Run execute_task with mocked subprocess and return the command
list."""
+ command: list[str],
+ schema_version: str | None = None,
+ ) -> tuple[list[str], str | None]:
ti = _make_ti()
- coordinator = JavaCoordinator(
- java_executable=java_executable,
- jvm_args=jvm_args or [],
- jars_root=jars_root,
- )
+ coordinator = _StubSubprocessCoordinator(command=command,
schema_version=schema_version)
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.pid = 12345
comm_sock = MagicMock(spec=socket.socket)
logs_sock = MagicMock(spec=socket.socket)
popen_calls: list = []
+ cls_kwargs: dict = {}
def capture_popen(cmd, **kwargs):
popen_calls.append(cmd)
return mock_proc
+ original_start = _PopenActivitySubprocess.__dict__["start"].__func__
+
+ def spy_start(cls, **kwargs):
+ cls_kwargs.update(kwargs)
+ return original_start(cls, **kwargs)
+
with (
patch(
- "airflow.sdk.coordinators.java.coordinator.subprocess.Popen",
+ "airflow.sdk.coordinators._subprocess.subprocess.Popen",
side_effect=capture_popen,
),
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{servers["comm"]: comm_sock, servers["logs"]: logs_sock},
{soc: b"" for soc in drains.values()},
@@ -564,69 +401,59 @@ class TestJavaCoordinatorExecuteTask:
patch.object(ActivitySubprocess, "_register_pipe_readers"),
patch.object(ActivitySubprocess, "_on_child_started"),
patch.object(ActivitySubprocess, "wait", return_value=0),
- patch("psutil.Process"),
+ patch.object(
+ _PopenActivitySubprocess,
+ "start",
+ classmethod(spy_start),
+ ),
):
coordinator.execute_task(
what=ti,
- dag_rel_path="dags/test.jar",
+ dag_rel_path="bundle",
bundle_info=MagicMock(),
client=mock_client,
subprocess_logs_to_stdout=False,
)
assert popen_calls, "subprocess.Popen was not called"
- return popen_calls[0]
+ return popen_calls[0], cls_kwargs.get("subprocess_schema_version")
- def test_java_executable_is_first_arg(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(
- jars_root, mock_client,
java_executable="/usr/lib/jvm/java-17/bin/java"
- )
- assert cmd[0] == "/usr/lib/jvm/java-17/bin/java"
-
- def test_classpath_flag_and_value_present(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client)
- assert "-classpath" in cmd
- cp_idx = cmd.index("-classpath")
- classpath = cmd[cp_idx + 1]
- assert jars_root.joinpath("app.jar").as_posix() in classpath
-
- def test_main_class_present(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client)
- assert "com.example.TaskRunner" in cmd
+ def test_command_prefix_preserved(self, mock_client):
+ cmd, _ = self._captured_popen_cmd(mock_client,
command=["/path/to/runtime", "arg1"])
+ assert cmd[:2] == ["/path/to/runtime", "arg1"]
- def test_comm_and_logs_args_present(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client)
+ def test_comm_and_logs_flags_appended(self, mock_client):
+ cmd, _ = self._captured_popen_cmd(mock_client,
command=["/path/to/runtime"])
comm_args = [a for a in cmd if a.startswith("--comm=")]
logs_args = [a for a in cmd if a.startswith("--logs=")]
assert len(comm_args) == 1
assert len(logs_args) == 1
- def test_comm_and_logs_contain_port(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client)
+ def test_comm_and_logs_contain_port(self, mock_client):
+ cmd, _ = self._captured_popen_cmd(mock_client,
command=["/path/to/runtime"])
comm_arg = next(a for a in cmd if a.startswith("--comm="))
logs_arg = next(a for a in cmd if a.startswith("--logs="))
# format is host:port
assert ":" in comm_arg.split("=", 1)[1]
assert ":" in logs_arg.split("=", 1)[1]
- def test_jvm_args_inserted_before_main_class(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client,
jvm_args=["-Xmx512m", "-Dsome.prop=value"])
- main_idx = cmd.index("com.example.TaskRunner")
- for jvm_arg in ["-Xmx512m", "-Dsome.prop=value"]:
- assert jvm_arg in cmd
- assert cmd.index(jvm_arg) < main_idx
-
- def test_comm_and_logs_after_main_class(self, jars_root, mock_client):
- cmd = self._captured_popen_cmd(jars_root, mock_client)
- main_idx = cmd.index("com.example.TaskRunner")
+ def test_comm_and_logs_after_user_command(self, mock_client):
+ cmd, _ = self._captured_popen_cmd(mock_client,
command=["/path/to/runtime", "user-arg"])
+ user_idx = cmd.index("user-arg")
comm_idx = next(i for i, a in enumerate(cmd) if
a.startswith("--comm="))
logs_idx = next(i for i, a in enumerate(cmd) if
a.startswith("--logs="))
- assert comm_idx > main_idx
- assert logs_idx > main_idx
+ assert user_idx < comm_idx
+ assert user_idx < logs_idx
- def test_returns_execution_result(self, jars_root, mock_client):
+ def test_schema_version_forwarded(self, mock_client):
+ _, schema = self._captured_popen_cmd(
+ mock_client, command=["/path/to/runtime"],
schema_version="2026-06-16"
+ )
+ assert schema == "2026-06-16"
+
+ def test_returns_execution_result(self, mock_client):
ti = _make_ti()
- coordinator = JavaCoordinator(jars_root=jars_root)
+ coordinator = _StubSubprocessCoordinator(command=["/bin/true"])
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.pid = 99999
@@ -636,7 +463,7 @@ class TestJavaCoordinatorExecuteTask:
with (
patch("subprocess.Popen", return_value=mock_proc),
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{servers["comm"]: comm_sock, servers["logs"]: logs_sock},
{soc: b"" for soc in drains.values()},
@@ -645,11 +472,10 @@ class TestJavaCoordinatorExecuteTask:
patch.object(ActivitySubprocess, "_register_pipe_readers"),
patch.object(ActivitySubprocess, "_on_child_started"),
patch.object(ActivitySubprocess, "wait", return_value=0),
- patch("psutil.Process"),
):
result = coordinator.execute_task(
what=ti,
- dag_rel_path="dags/test.jar",
+ dag_rel_path="bundle",
bundle_info=MagicMock(),
client=mock_client,
subprocess_logs_to_stdout=False,
@@ -659,27 +485,9 @@ class TestJavaCoordinatorExecuteTask:
assert result.exit_code == 0
-class TestJavaActivitySubprocessStart:
- """
- Unit tests for _JavaActivitySubprocess.start().
-
- These tests mock subprocess.Popen and _accept_connections to verify that
- start() wires up the right command and stores the right sockets,
- without requiring a real Java runtime.
- """
-
- def _start_with_mocks(
- self,
- jars_root: pathlib.Path,
- mock_client,
- *,
- java_executable: str = "java",
- jvm_args: list[str] | None = None,
- ti: TaskInstanceDTO | None = None,
- ):
- """Call _JavaActivitySubprocess.start() with all subprocess machinery
mocked out."""
- ti = ti or _make_ti()
-
+class TestPopenActivitySubprocessStart:
+ def _start_with_mocks(self, mock_client, *, command: list[str],
schema_version=None):
+ ti = _make_ti()
mock_proc = MagicMock(spec=subprocess.Popen)
mock_proc.pid = 12345
comm_sock = MagicMock(spec=socket.socket)
@@ -687,11 +495,11 @@ class TestJavaActivitySubprocessStart:
with (
patch(
- "airflow.sdk.coordinators.java.coordinator.subprocess.Popen",
+ "airflow.sdk.coordinators._subprocess.subprocess.Popen",
return_value=mock_proc,
) as popen_mock,
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{servers["comm"]: comm_sock, servers["logs"]: logs_sock},
{soc: b"" for soc in drains.values()},
@@ -699,66 +507,32 @@ class TestJavaActivitySubprocessStart:
),
patch.object(ActivitySubprocess, "_register_pipe_readers"),
patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
):
- proc = _JavaActivitySubprocess.start(
+ proc = _PopenActivitySubprocess.start(
what=ti,
- dag_rel_path="dags/test.jar",
+ dag_rel_path="bundle",
bundle_info=MagicMock(),
client=mock_client,
- java_executable=java_executable,
- jvm_args=jvm_args or [],
- jars_root=[jars_root],
- main_class="",
- subprocess_logs_to_stdout=False,
- )
-
- return proc, popen_mock
-
- def test_stdin_is_comm_socket(self, jars_root, mock_client):
- """stdin (used by send_msg) must be the accepted comm socket."""
- ti = _make_ti()
- comm_sock = MagicMock(spec=socket.socket)
- logs_sock = MagicMock(spec=socket.socket)
-
- with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
- patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
- side_effect=lambda servers, drains, proc, **kw: (
- {servers["comm"]: comm_sock, servers["logs"]: logs_sock},
- {soc: b"" for soc in drains.values()},
- ),
- ),
- patch.object(ActivitySubprocess, "_register_pipe_readers"),
- patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
- ):
- popen_mock.return_value.pid = 12345
- proc = _JavaActivitySubprocess.start(
- what=ti,
- dag_rel_path="dags/test.jar",
- bundle_info=MagicMock(),
- client=MagicMock(),
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
+ command=command,
+ subprocess_schema_version=schema_version,
subprocess_logs_to_stdout=False,
)
+ return proc, popen_mock, comm_sock
+ def test_stdin_is_comm_socket(self, mock_client):
+ proc, _, comm_sock = self._start_with_mocks(mock_client,
command=["/bin/true"])
assert proc.stdin is comm_sock
- def test_pid_taken_from_popen(self, jars_root, mock_client):
- proc, _ = self._start_with_mocks(jars_root, mock_client)
+ def test_pid_taken_from_popen(self, mock_client):
+ proc, _, _ = self._start_with_mocks(mock_client, command=["/bin/true"])
assert proc.pid == 12345
- def test_on_child_started_called(self, jars_root, mock_client):
+ def test_on_child_started_called(self, mock_client):
ti = _make_ti()
with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
+ patch("airflow.sdk.coordinators._subprocess.subprocess.Popen") as
popen_mock,
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{soc: MagicMock(spec=socket.socket) for soc in
servers.values()},
{soc: b"" for soc in drains.values()},
@@ -766,32 +540,28 @@ class TestJavaActivitySubprocessStart:
),
patch.object(ActivitySubprocess, "_register_pipe_readers"),
patch.object(ActivitySubprocess, "_on_child_started") as
mock_on_started,
- patch("psutil.Process"),
):
popen_mock.return_value.pid = 12345
- _JavaActivitySubprocess.start(
+ _PopenActivitySubprocess.start(
what=ti,
- dag_rel_path="dags/test.jar",
+ dag_rel_path="bundle",
bundle_info=MagicMock(),
client=mock_client,
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
+ command=["/bin/true"],
subprocess_logs_to_stdout=False,
)
mock_on_started.assert_called_once()
kwargs = mock_on_started.call_args.kwargs
assert kwargs["ti"] is ti
- assert kwargs["dag_rel_path"] == "dags/test.jar"
+ assert kwargs["dag_rel_path"] == "bundle"
- def test_register_pipe_readers_called_with_four_sockets(self, jars_root,
mock_client):
+ def test_register_pipe_readers_called_with_four_sockets(self, mock_client):
"""Both socketpair read-ends and both TCP sockets must be registered,
with a data kwarg."""
with (
-
patch("airflow.sdk.coordinators.java.coordinator.subprocess.Popen") as
popen_mock,
+ patch("airflow.sdk.coordinators._subprocess.subprocess.Popen") as
popen_mock,
patch(
-
"airflow.sdk.coordinators.java.coordinator._accept_connections",
+ "airflow.sdk.coordinators._subprocess._accept_connections",
side_effect=lambda servers, drains, proc, **kw: (
{soc: MagicMock(spec=socket.socket) for soc in
servers.values()},
{soc: b"" for soc in drains.values()},
@@ -799,18 +569,14 @@ class TestJavaActivitySubprocessStart:
),
patch.object(ActivitySubprocess, "_register_pipe_readers") as
mock_register,
patch.object(ActivitySubprocess, "_on_child_started"),
- patch("psutil.Process"),
):
popen_mock.return_value.pid = 12345
- _JavaActivitySubprocess.start(
+ _PopenActivitySubprocess.start(
what=_make_ti(),
- dag_rel_path="dags/test.jar",
+ dag_rel_path="bundle",
bundle_info=MagicMock(),
client=mock_client,
- java_executable="java",
- jvm_args=[],
- jars_root=[jars_root],
- main_class="",
+ command=["/bin/true"],
subprocess_logs_to_stdout=False,
)
assert mock_register.mock_calls == [call(ANY, ANY, ANY, ANY, data=ANY)]