This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cd5509fc701 Reduce SSH connection churn in `SSHRemoteJobOperator` under high fan-out (#68115)
cd5509fc701 is described below
commit cd5509fc701cd18f32ae5b9625fa34a151c12f9e
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 9 00:27:50 2026 +0100
Reduce SSH connection churn in `SSHRemoteJobOperator` under high fan-out (#68115)
The operator and trigger opened a new SSH connection for every remote
command. A large expand() fan-out against one host drove the connection
rate past the remote sshd MaxStartups limit. sshd then drops connections,
which surfaces at submit time as "paramiko ... Error reading SSH protocol
banner" (an immediate EOF, not a banner timeout); when the cleanup
connection was dropped too, the job directory was left behind.
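sshd_config(5) documents MaxStartups "start:rate:full" as random early drop. Once `start` unauthenticated connections are pending, sshd refuses a new one with probability rate/100. That probability rises linearly to 100% at `full` (default 10:30:100). The hypothetical helper below models only that documented rule. It is not OpenSSH code, but it shows how quickly a burst of simultaneous submits starts getting refused.

```python
def drop_probability(pending: int, start: int = 10, rate: int = 30, full: int = 100) -> float:
    """Chance that a new connection is refused, modelling sshd's documented
    MaxStartups start:rate:full rule (hypothetical helper, not OpenSSH code)."""
    if pending < start:
        return 0.0  # below `start`: never refused
    if pending >= full:
        return 1.0  # at or above `full`: always refused
    # Linear ramp from rate% at `start` up to 100% at `full`.
    return (rate + (100 - rate) * (pending - start) / (full - start)) / 100


if __name__ == "__main__":
    for pending in (5, 10, 55, 100):
        print(pending, round(drop_probability(pending), 2))
    # 5 0.0
    # 10 0.3
    # 55 0.65
    # 100 1.0
```

With the default setting, 55 handshakes in flight already lose roughly two out of three new connections. That is why the changes below focus on opening fewer connections rather than waiting longer for each one.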
Changes:
- Trigger holds one connection for the whole poll loop instead of
reconnecting per command, with bounded jittered reconnect on drops and
asyncssh.Error treated as reconnectable.
- Operator reuses one connection for OS detection and submission.
- Cleanup retries instead of orphaning the job directory on a dropped
connection.
- Configurable conn_retry_attempts (operator/hook) for the submit burst,
plus command_timeout and max_reconnect_attempts forwarded to the trigger.
- SSHHookAsync sets a keepalive on the long-lived trigger connection.
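Below is a minimal, stdlib-only sketch of the reconnect policy from the trigger bullet above. It assumes a hypothetical connection object with `poll()` and `close()` methods; these are stand-ins, not the asyncssh API. The real trigger differs in three ways: it also treats asyncssh.Error as reconnectable, it runs separate status and log commands, and it yields TriggerEvents.

The sketch keeps one connection across polls. After a connection-level error, it closes and reopens the connection after a capped exponential delay plus random jitter. The failure count resets only after a full successful poll.

```python
import asyncio
import random
from typing import Awaitable, Callable, Optional, Protocol


class PollConn(Protocol):
    """Hypothetical stand-in for a reusable SSH connection (not the asyncssh API)."""

    async def poll(self) -> Optional[int]:
        """Return the job's exit code once it has finished, else None."""
        ...

    def close(self) -> None:
        ...


# Errors treated as "connection broken, reconnect" (the real trigger adds asyncssh.Error).
CONNECTION_ERRORS = (OSError, TimeoutError)


def reconnect_delay(attempt: int, cap: float = 30.0) -> float:
    """Exponential base (1, 2, 4, ... capped at `cap`) plus up to the same amount of jitter."""
    base = min(2 ** (attempt - 1), cap)
    return base + random.uniform(0, base)


async def poll_until_done(
    connect: Callable[[], Awaitable[PollConn]],
    max_reconnect_attempts: int = 5,
    poll_interval: float = 10.0,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> int:
    """Reuse one connection across polls; reconnect with backoff on connection errors."""
    conn: Optional[PollConn] = None
    failures = 0  # consecutive failures since the last fully successful poll
    try:
        while True:
            try:
                if conn is None:
                    conn = await connect()
                exit_code = await conn.poll()
            except CONNECTION_ERRORS:
                failures += 1
                if conn is not None:
                    conn.close()
                    conn = None
                if failures > max_reconnect_attempts:
                    raise  # budget exhausted: surface the last connection error
                await sleep(reconnect_delay(failures))
                continue
            # Only a completed poll resets the budget, so a connection that
            # handshakes but keeps failing its commands still runs out.
            failures = 0
            if exit_code is not None:
                return exit_code
            await sleep(poll_interval)
    finally:
        if conn is not None:
            conn.close()
```

Injecting `sleep` keeps the sketch testable without real delays. The random jitter spreads out the reconnects of many triggers that lost their connections at the same moment, so they do not hit sshd in lockstep.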
* Fix mypy return type and docs spellcheck in SSH remote job trigger
- _run_command decodes bytes stdout/stderr so the return matches
tuple[int, str, str] (asyncssh types them as bytes | str).
- Drop 'jittered'/'desynchronise' from docstrings (Sphinx spellcheck).
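The decoding fix in the first bullet can be illustrated with a small standalone helper. `to_text` and `normalize_result` are hypothetical names, and the plain arguments stand in for the fields of asyncssh's completed-process result. Mapping a missing exit status to 0 mirrors the `result.exit_status or 0` expression in the trigger diff below.

```python
from typing import Optional, Tuple, Union

# Output as the trigger handles it: bytes, str, or a missing value (None).
Output = Optional[Union[bytes, str]]


def to_text(value: Output) -> str:
    """Return command output as str; None becomes "" and bytes are decoded as UTF-8."""
    if value is None:
        return ""
    if isinstance(value, bytes):
        # errors="replace" turns undecodable bytes into U+FFFD instead of raising mid-poll.
        return value.decode("utf-8", errors="replace")
    return value


def normalize_result(exit_status: Optional[int], stdout: Output, stderr: Output) -> Tuple[int, str, str]:
    """Shape a result as tuple[int, str, str]; a missing exit status is reported as 0."""
    return exit_status or 0, to_text(stdout), to_text(stderr)
```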
* Address review: log hint on submit-connection failure + mapped-task docs
- execute() now logs an actionable hint when the submit connection fails
(sshd MaxStartups under concurrency, conn_retry_attempts,
pools/max_active_tis_per_dag), then re-raises the original error. Scoped
to the connect step, no error matching.
- High Fan-out docs: link the dynamic task mapping limits page and note
that the connection storm is not specific to mapped tasks (parallel runs
and high concurrency cause it too).
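The first review change logs a hint scoped to the connect step, then re-raises the error unchanged. That pattern might look like the minimal sketch below. `connect_with_hint` and the logger name are hypothetical. The operator itself logs through `self.log` around `self.ssh_hook.get_conn()`, and its message text differs.

```python
import logging
from typing import Callable, TypeVar

log = logging.getLogger("ssh_submit_example")  # hypothetical logger name
T = TypeVar("T")


def connect_with_hint(connect: Callable[[], T], host: str, conn_retry_attempts: int) -> T:
    """Open the submit connection; if that fails, log a hint and re-raise the original error."""
    try:
        return connect()
    except Exception:
        # Only the connect step is wrapped, so errors from the submitted command never get
        # this hint, and no matching on the error text is needed.
        log.error(
            "Failed to connect to %s to submit the remote job. If many SSH connections reach "
            "this host at once, sshd may refuse new ones (for example MaxStartups). Consider "
            "raising MaxStartups on the server, increasing conn_retry_attempts (currently %d), "
            "or limiting concurrency with a pool or max_active_tis_per_dag.",
            host,
            conn_retry_attempts,
        )
        raise  # re-raise unchanged: same exception type and traceback
```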
---
providers/ssh/docs/operators/ssh_remote_job.rst | 42 +++-
.../ssh/src/airflow/providers/ssh/hooks/ssh.py | 25 +-
.../providers/ssh/operators/ssh_remote_job.py | 242 +++++++++++-------
.../providers/ssh/triggers/ssh_remote_job.py | 252 +++++++++++--------
providers/ssh/tests/unit/ssh/hooks/test_ssh.py | 20 ++
.../unit/ssh/operators/test_ssh_remote_job.py | 131 ++++++++++
.../tests/unit/ssh/triggers/test_ssh_remote_job.py | 273 +++++++++++++--------
7 files changed, 698 insertions(+), 287 deletions(-)
diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst
index c78a0792740..7a8783b9861 100644
--- a/providers/ssh/docs/operators/ssh_remote_job.rst
+++ b/providers/ssh/docs/operators/ssh_remote_job.rst
@@ -164,6 +164,13 @@ Parameters
* ``remote_os`` (str, optional): Remote OS type (``"auto"``, ``"posix"``, ``"windows"``). Default: ``"auto"``
* ``skip_on_exit_code`` (int or list, optional): Exit code(s) that should cause task to skip instead of fail
+* ``conn_timeout`` (int, optional): SSH connection timeout in seconds
+* ``banner_timeout`` (float, optional): Seconds to wait for the SSH banner. Default: 30.0
+* ``conn_retry_attempts`` (int, optional): How many times to attempt the initial SSH connection for
+  submission and cleanup before failing. Default: 5. Raise this for large fan-outs where the remote
+  ``sshd`` transiently refuses connections (see :ref:`High Fan-out <howto/operator:SSHRemoteJobOperator:fanout>`)
+* ``cleanup_retries`` (int, optional): How many times to attempt remote directory cleanup before giving up
+  and leaving the directory in place. Default: 3
Remote OS Detection
-------------------
@@ -213,7 +220,9 @@ Limitations and Considerations
-------------------------------
**Network Interruptions**: While the operator is resilient to disconnections during monitoring,
-the initial job submission must succeed. If submission fails, the task will fail immediately.
+the initial job submission must succeed. The connection used for submission is retried
+(``conn_retry_attempts``); if every attempt fails, the task fails immediately. The trigger also
+reconnects automatically if the monitoring connection drops mid-job.
**Remote Process Management**: Jobs are detached using ``nohup`` (POSIX) or ``Start-Process`` (Windows).
If the remote host reboots during job execution, the job will be lost.
@@ -231,7 +240,36 @@ tasks can run on the same remote host without conflicts.
**Cleanup**: Use ``cleanup="on_success"`` or ``cleanup="always"`` to avoid accumulating
job directories on the remote host. For debugging, use ``cleanup="never"`` and manually
-inspect the job directory.
+inspect the job directory. Cleanup runs only when the job reaches completion, so tasks that
+are killed or time out can still leave a directory behind; for those, add a server-side TTL
+reaper (for example ``systemd-tmpfiles`` or a cron job) for the base directory.
+
+.. _howto/operator:SSHRemoteJobOperator:fanout:
+
+High Fan-out (Many Concurrent Tasks)
+-------------------------------------
+
+Many tasks targeting the same SSH server at once (a large ``.expand()`` fan-out, parallel DAG
+runs, or just high concurrency) can overwhelm it. Every task opens its own SSH connections (to
+submit, monitor, and clean up the job), and the remote ``sshd`` throttles concurrent
+*unauthenticated* connections via ``MaxStartups``. With the default ``10:30:100``, once 10
+connections are unauthenticated ``sshd`` refuses 30% of new ones, and the refusal rate rises
+linearly to 100% at 100. A dropped connection surfaces on the client as::
+
+ paramiko ... Error reading SSH protocol banner
+
+This is the server closing the socket before the handshake, not a slow banner, so raising
+``banner_timeout`` does not help.
+
+The operator and trigger keep the connection rate low: submission reuses a single connection
+for OS detection and the submit itself, and the trigger holds **one** connection for the whole
+poll loop instead of reconnecting on every status check. To push a high fan-out further:
+
+* Raise ``MaxStartups`` (and ``MaxSessions``) on the remote ``sshd`` -- this is the direct fix.
+* Increase ``conn_retry_attempts`` so transient refusals during the initial burst are retried.
+* Cap how many mapped tasks run at once with ``max_active_tis_per_dag`` (or a pool) instead of
+  releasing the entire fan-out simultaneously. See the "Placing Limits on Mapped Tasks" section of
+  :doc:`apache-airflow:authoring-and-scheduling/dynamic-task-mapping` for the available limits.
Comparison with SSHOperator
----------------------------
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index d029a50772b..54b7edf6008 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -82,6 +82,9 @@ class SSHHook(BaseHook):
lifetime of the transport
:param ciphers: list of ciphers to use in order of preference
:param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host
+ :param conn_retry_attempts: number of times to attempt the initial SSH connection before
+ giving up (default 3). Raising this helps when many tasks target the same SSH server at
+ once and some connections are transiently refused (e.g. ``sshd`` ``MaxStartups`` throttling).
"""
# List of classes to try loading private keys as, ordered (roughly) by most common to least common
@@ -130,9 +133,11 @@ class SSHHook(BaseHook):
ciphers: list[str] | None = None,
auth_timeout: int | None = None,
host_proxy_cmd: str | None = None,
+ conn_retry_attempts: int = 3,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
+ self.conn_retry_attempts = max(1, conn_retry_attempts)
self.remote_host = remote_host
self.username = username
self.password = password
@@ -344,7 +349,7 @@ class SSHHook(BaseHook):
for attempt in Retrying(
reraise=True,
wait=wait_fixed(3) + wait_random(0, 2),
- stop=stop_after_attempt(3),
+ stop=stop_after_attempt(self.conn_retry_attempts),
before_sleep=log_before_sleep,
):
with attempt:
@@ -553,6 +558,7 @@ class SSHHookAsync(BaseHook):
key_file: str = "",
passphrase: str = "",
private_key: str = "",
+ keepalive_interval: int = 30,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
@@ -564,6 +570,7 @@ class SSHHookAsync(BaseHook):
self.key_file = key_file
self.passphrase = passphrase
self.private_key = private_key
+ self.keepalive_interval = keepalive_interval
def _parse_extras(self, conn: Any) -> None:
"""Parse extra fields from the connection into instance fields."""
@@ -631,10 +638,26 @@ class SSHHookAsync(BaseHook):
conn_config["client_keys"] = [_private_key]
if self.passphrase:
conn_config["passphrase"] = self.passphrase
+ if self.keepalive_interval:
+ # The trigger holds one connection for the whole job; a keepalive stops idle
+ # NAT/firewall timeouts from silently dropping it between long poll intervals.
+ conn_config["keepalive_interval"] = self.keepalive_interval
ssh_client_conn = await asyncssh.connect(**conn_config)
return ssh_client_conn
+ async def get_conn(self):
+ """
+ Open an asyncssh connection that can be reused for multiple commands.
+
+ Unlike :meth:`run_command`, the returned connection is **not** closed
+ automatically; the caller owns its lifecycle (e.g.
+ ``async with await hook.get_conn() as conn: ...`` or an explicit
+ ``conn.close()``). Reusing one connection avoids a new TCP/SSH handshake
+ per command, which matters when many tasks poll the same SSH server.
+ """
+ return await self._get_conn()
+
async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]:
"""
Execute a command on the remote host asynchronously.
diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
index 783edb39b8f..dc3414eca68 100644
--- a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
+++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import time
import warnings
from collections.abc import Container, Sequence
from datetime import timedelta
@@ -74,6 +75,25 @@ class SSHRemoteJobOperator(BaseOperator):
:param skip_on_exit_code: Exit codes that should skip the task instead of failing
:param conn_timeout: SSH connection timeout in seconds
:param banner_timeout: Timeout waiting for SSH banner in seconds
+ :param conn_retry_attempts: How many times to attempt the initial SSH connection for
+ submission/cleanup before failing (default 5). Helps when many tasks hit the
+ same host at once and ``sshd`` transiently refuses connections (``MaxStartups``).
+ :param cleanup_retries: How many times to attempt remote directory cleanup before
+ giving up and leaving the directory in place (default 3). Prevents a transient SSH
+ failure during cleanup from orphaning the job directory on the remote host.
+ :param command_timeout: Per-command timeout in seconds for the trigger's status/log polls
+ (default 30.0).
+ :param max_reconnect_attempts: Consecutive connection failures the trigger tolerates (with
+ backoff) before failing the task while monitoring the remote job (default 5).
+
+ .. note::
+ A large ``expand()`` fan-out (or any burst of concurrent tasks) opens many SSH connections
+ against one host. The remote ``sshd`` throttles concurrent unauthenticated connections via
+ ``MaxStartups`` (default ``10:30:100``); when exceeded it drops connections, surfacing as
+ ``paramiko ... Error reading SSH protocol banner``. For high fan-out, raise ``MaxStartups``
+ on the server. The directory ``/tmp/airflow-ssh-jobs`` (POSIX) is only cleaned when
+ ``cleanup`` is set and the job reaches completion, so also consider a server-side TTL
+ reaper (for example ``systemd-tmpfiles``) for jobs that are killed or time out.
"""
template_fields: Sequence[str] = ("command", "environment", "remote_host", "remote_base_dir")
@@ -104,6 +124,10 @@ class SSHRemoteJobOperator(BaseOperator):
skip_on_exit_code: int | Container[int] | None = None,
conn_timeout: int | None = None,
banner_timeout: float = 30.0,
+ conn_retry_attempts: int = 5,
+ cleanup_retries: int = 3,
+ command_timeout: float = 30.0,
+ max_reconnect_attempts: int = 5,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -123,6 +147,10 @@ class SSHRemoteJobOperator(BaseOperator):
self.remote_os = remote_os
self.conn_timeout = conn_timeout
self.banner_timeout = banner_timeout
+ self.conn_retry_attempts = conn_retry_attempts
+ self.cleanup_retries = max(1, cleanup_retries)
+ self.command_timeout = command_timeout
+ self.max_reconnect_attempts = max_reconnect_attempts
self.skip_on_exit_code = (
skip_on_exit_code
if isinstance(skip_on_exit_code, Container)
@@ -170,67 +198,69 @@ class SSHRemoteJobOperator(BaseOperator):
remote_host=self.remote_host or "",
conn_timeout=self.conn_timeout,
banner_timeout=self.banner_timeout,
+ conn_retry_attempts=self.conn_retry_attempts,
)
- def _detect_remote_os(self) -> Literal["posix", "windows"]:
+ def _detect_remote_os(self, ssh_client) -> Literal["posix", "windows"]:
"""
- Detect the remote operating system.
+ Detect the remote operating system on an already-open SSH connection.
Uses a two-stage detection:
1. Try POSIX detection via `uname` (works on Linux, macOS, BSD,
Solaris, AIX, etc.)
2. Try Windows detection via PowerShell
3. Raise error if both fail
+
+ :param ssh_client: An open paramiko SSH client to reuse (avoids a second handshake).
"""
if self.remote_os != "auto":
return self.remote_os
self.log.info("Auto-detecting remote operating system...")
- with self.ssh_hook.get_conn() as ssh_client:
- try:
- exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
- ssh_client,
- build_posix_os_detection_command(),
- get_pty=False,
- environment=None,
- timeout=10,
- )
- if exit_status == 0 and stdout:
- output = stdout.decode("utf-8", errors="replace").strip().lower()
- posix_systems = [
- "linux",
- "darwin",
- "freebsd",
- "openbsd",
- "netbsd",
- "sunos",
- "aix",
- "hp-ux",
- ]
- if any(system in output for system in posix_systems):
- self.log.info("Detected POSIX system: %s", output)
- return "posix"
- except Exception as e:
- self.log.debug("POSIX detection failed: %s", e)
-
- try:
- exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
- ssh_client,
- build_windows_os_detection_command(),
- get_pty=False,
- environment=None,
- timeout=10,
- )
- if exit_status == 0 and stdout:
- output = stdout.decode("utf-8", errors="replace").strip()
- if "WINDOWS" in output.upper():
- self.log.info("Detected Windows system")
- return "windows"
- except Exception as e:
- self.log.debug("Windows detection failed: %s", e)
+ try:
+ exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ build_posix_os_detection_command(),
+ get_pty=False,
+ environment=None,
+ timeout=10,
+ )
+ if exit_status == 0 and stdout:
+ output = stdout.decode("utf-8", errors="replace").strip().lower()
+ posix_systems = [
+ "linux",
+ "darwin",
+ "freebsd",
+ "openbsd",
+ "netbsd",
+ "sunos",
+ "aix",
+ "hp-ux",
+ ]
+ if any(system in output for system in posix_systems):
+ self.log.info("Detected POSIX system: %s", output)
+ return "posix"
+ except Exception as e:
+ self.log.debug("POSIX detection failed: %s", e)
- raise AirflowException(
- "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'"
+ try:
+ exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ build_windows_os_detection_command(),
+ get_pty=False,
+ environment=None,
+ timeout=10,
)
+ if exit_status == 0 and stdout:
+ output = stdout.decode("utf-8", errors="replace").strip()
+ if "WINDOWS" in output.upper():
+ self.log.info("Detected Windows system")
+ return "windows"
+ except Exception as e:
+ self.log.debug("Windows detection failed: %s", e)
+
+ raise AirflowException(
+ "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'"
+ )
def execute(self, context: Context) -> None:
"""
@@ -241,9 +271,6 @@ class SSHRemoteJobOperator(BaseOperator):
if not self.command:
raise AirflowException("SSH operator error: command not specified.")
- self._detected_os = self._detect_remote_os()
- self.log.info("Remote OS: %s", self._detected_os)
-
ti = context["ti"]
self._job_id = generate_job_id(
dag_id=ti.dag_id,
@@ -253,27 +280,49 @@ class SSHRemoteJobOperator(BaseOperator):
)
self.log.info("Generated job ID: %s", self._job_id)
- self._paths = RemoteJobPaths(
- job_id=self._job_id,
- remote_os=self._detected_os,
- base_dir=self.remote_base_dir,
- )
-
- if self._detected_os == "posix":
- wrapper_cmd = build_posix_wrapper_command(
- command=self.command,
- paths=self._paths,
- environment=self.environment,
+ # Reuse a single connection for OS detection (when 'auto') and submission so the
+ # operator opens one SSH handshake per task instead of two. Under a large fan-out
+ # this halves the connection burst that triggers sshd MaxStartups throttling.
+ self.log.info("Connecting to %s", self.ssh_hook.remote_host)
+ try:
+ ssh_conn = self.ssh_hook.get_conn()
+ except Exception:
+ self.log.error(
+ "Failed to connect to %s to submit the remote job. When many SSH connections reach "
+ "the same host at once, the server can start refusing new ones before the handshake "
+ "(for example sshd MaxStartups). This is not limited to mapped tasks: parallel DAG "
+ "runs or high concurrency can cause it too. Try raising MaxStartups/MaxSessions on "
+ "the server, increasing conn_retry_attempts (currently %d), or reducing concurrency "
+ "with a pool (or max_active_tis_per_dag for mapped tasks). See the "
+ "SSHRemoteJobOperator 'High Fan-out' docs.",
+ self.ssh_hook.remote_host,
+ self.conn_retry_attempts,
)
- else:
- wrapper_cmd = build_windows_wrapper_command(
- command=self.command,
- paths=self._paths,
- environment=self.environment,
+ raise
+ with ssh_conn as ssh_client:
+ self._detected_os = self._detect_remote_os(ssh_client)
+ self.log.info("Remote OS: %s", self._detected_os)
+
+ self._paths = RemoteJobPaths(
+ job_id=self._job_id,
+ remote_os=self._detected_os,
+ base_dir=self.remote_base_dir,
)
- self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host)
- with self.ssh_hook.get_conn() as ssh_client:
+ if self._detected_os == "posix":
+ wrapper_cmd = build_posix_wrapper_command(
+ command=self.command,
+ paths=self._paths,
+ environment=self.environment,
+ )
+ else:
+ wrapper_cmd = build_windows_wrapper_command(
+ command=self.command,
+ paths=self._paths,
+ environment=self.environment,
+ )
+
+ self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host)
exit_status, stdout, stderr = self.ssh_hook.exec_ssh_client_command(
ssh_client,
wrapper_cmd,
@@ -320,6 +369,8 @@ class SSHRemoteJobOperator(BaseOperator):
poll_interval=self.poll_interval,
log_chunk_size=self.log_chunk_size,
log_offset=0,
+ command_timeout=self.command_timeout,
+ max_reconnect_attempts=self.max_reconnect_attempts,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout) if self.timeout else None,
@@ -361,6 +412,8 @@ class SSHRemoteJobOperator(BaseOperator):
poll_interval=self.poll_interval,
log_chunk_size=self.log_chunk_size,
log_offset=event.get("log_offset", 0),
+ command_timeout=self.command_timeout,
+ max_reconnect_attempts=self.max_reconnect_attempts,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout) if self.timeout else None,
@@ -389,25 +442,46 @@ class SSHRemoteJobOperator(BaseOperator):
self.log.info("Remote job completed successfully")
def _cleanup_remote_job(self, job_dir: str, remote_os: str) -> None:
- """Clean up the remote job directory."""
+ """
+ Clean up the remote job directory, retrying on transient SSH failures.
+
+ Under a large fan-out the cleanup connection can itself be refused by the
+ remote ``sshd`` (``MaxStartups``). Retrying a few times keeps a transient drop
+ from orphaning the job directory; if every attempt fails we log loudly and
+ leave the directory rather than failing the (already finished) task.
+ """
self.log.info("Cleaning up remote job directory: %s", job_dir)
- try:
- if remote_os == "posix":
- cleanup_cmd = build_posix_cleanup_command(job_dir)
- else:
- cleanup_cmd = build_windows_cleanup_command(job_dir)
+ if remote_os == "posix":
+ cleanup_cmd = build_posix_cleanup_command(job_dir)
+ else:
+ cleanup_cmd = build_windows_cleanup_command(job_dir)
- with self.ssh_hook.get_conn() as ssh_client:
- self.ssh_hook.exec_ssh_client_command(
- ssh_client,
- cleanup_cmd,
- get_pty=False,
- environment=None,
- timeout=30,
- )
- self.log.info("Remote cleanup completed")
- except Exception as e:
- self.log.warning("Failed to clean up remote job directory: %s", e)
+ last_error: Exception | None = None
+ for attempt in range(1, self.cleanup_retries + 1):
+ try:
+ with self.ssh_hook.get_conn() as ssh_client:
+ self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ cleanup_cmd,
+ get_pty=False,
+ environment=None,
+ timeout=30,
+ )
+ self.log.info("Remote cleanup completed")
+ return
+ except Exception as e:
+ last_error = e
+ self.log.warning("Cleanup attempt %d/%d failed: %s", attempt, self.cleanup_retries, e)
+ if attempt < self.cleanup_retries:
+ time.sleep(min(2**attempt, 10))
+
+ self.log.warning(
+ "Failed to clean up remote job directory after %d attempts; leaving orphaned "
+ "directory %s on the remote host (last error: %s)",
+ self.cleanup_retries,
+ job_dir,
+ last_error,
+ )
def on_kill(self) -> None:
"""
diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
index 0d4072c1ca4..2d390a3f2e6 100644
--- a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
+++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
@@ -20,10 +20,11 @@
from __future__ import annotations
import asyncio
+import random
from collections.abc import AsyncIterator
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
-import tenacity
+import asyncssh
from airflow.providers.ssh.hooks.ssh import SSHHookAsync
from airflow.providers.ssh.utils.remote_job import (
@@ -36,6 +37,16 @@ from airflow.providers.ssh.utils.remote_job import (
)
from airflow.triggers.base import BaseTrigger, TriggerEvent
+if TYPE_CHECKING:
+ from asyncssh import SSHClientConnection
+
+# Errors that mean the connection itself is broken/refused and the poll should
+# reconnect instead of failing the job. ``asyncssh.Error`` covers handshake,
+# protocol and disconnect failures (e.g. an sshd that drops the connection under
+# ``MaxStartups`` load); ``OSError`` covers TCP-level refusals; ``TimeoutError``
+# covers a wedged command or connection.
+_CONNECTION_ERRORS = (OSError, asyncssh.Error, TimeoutError)
+
class SSHRemoteJobTrigger(BaseTrigger):
"""
@@ -44,6 +55,13 @@ class SSHRemoteJobTrigger(BaseTrigger):
This trigger polls the remote host to check job completion status
and reads log output incrementally.
+ A single SSH connection is opened and reused for the whole poll loop instead
+ of reconnecting for every command. Opening a fresh TCP/SSH connection per poll
+ multiplies the connection rate against the remote ``sshd`` (which throttles
+ concurrent unauthenticated connections via ``MaxStartups``), so reuse keeps the
+ load flat when many tasks target the same host. If the connection drops, the
+ trigger transparently reconnects with backoff up to ``max_reconnect_attempts``.
+
:param ssh_conn_id: SSH connection ID from Airflow Connections
:param remote_host: Optional override for the remote host
:param job_id: Unique identifier for the remote job
@@ -54,6 +72,9 @@ class SSHRemoteJobTrigger(BaseTrigger):
:param poll_interval: Seconds between polling attempts
:param log_chunk_size: Maximum bytes to read per poll
:param log_offset: Current byte offset in the log file
+ :param command_timeout: Per-command timeout in seconds
+ :param max_reconnect_attempts: Consecutive connection failures tolerated before the
+ trigger gives up and emits an error event
"""
def __init__(
@@ -69,6 +90,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
log_chunk_size: int = 65536,
log_offset: int = 0,
command_timeout: float = 30.0,
+ max_reconnect_attempts: int = 5,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
@@ -82,6 +104,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
self.log_chunk_size = log_chunk_size
self.log_offset = log_offset
self.command_timeout = command_timeout
+ self.max_reconnect_attempts = max_reconnect_attempts
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the trigger for storage."""
@@ -99,6 +122,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
"log_chunk_size": self.log_chunk_size,
"log_offset": self.log_offset,
"command_timeout": self.command_timeout,
+ "max_reconnect_attempts": self.max_reconnect_attempts,
},
)
@@ -109,18 +133,42 @@ class SSHRemoteJobTrigger(BaseTrigger):
host=self.remote_host,
)
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
- retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
- reraise=True,
- )
- async def _check_completion(self, hook: SSHHookAsync) -> int | None:
+ async def _connect(self) -> SSHClientConnection:
+ """Open a reusable asyncssh connection. Separated out as a seam for testing."""
+ return await self._get_hook().get_conn()
+
+ @staticmethod
+ async def _close(conn: SSHClientConnection) -> None:
+ """Close a connection, swallowing teardown errors."""
+ try:
+ conn.close()
+ await conn.wait_closed()
+ except Exception:
+ # Teardown is best-effort; a failing close has nothing actionable to recover.
+ pass
+
+ def _reconnect_delay(self, attempt: int) -> float:
+ """Exponential backoff with randomness so reconnecting triggers do not retry in lockstep."""
+ base = min(2 ** (attempt - 1), 30)
+ return base + random.uniform(0, base)
+
+ async def _run_command(self, conn: SSHClientConnection, command: str) -> tuple[int, str, str]:
+ """Run a command on an existing connection, mirroring ``SSHHookAsync.run_command``."""
+ result = await conn.run(command, timeout=self.command_timeout, check=False)
+ stdout = result.stdout or ""
+ stderr = result.stderr or ""
+ # asyncssh types stdout/stderr as bytes | str; with the default text encoding they are
+ # str, but decode defensively so the helper holds if a binary connection is ever used.
+ if isinstance(stdout, bytes):
+ stdout = stdout.decode("utf-8", errors="replace")
+ if isinstance(stderr, bytes):
+ stderr = stderr.decode("utf-8", errors="replace")
+ return result.exit_status or 0, stdout, stderr
+
+ async def _check_completion(self, conn: SSHClientConnection) -> int | None:
"""
Check if the remote job has completed.
- Retries transient network errors up to 3 times with exponential backoff.
-
:return: Exit code if completed, None if still running
"""
if self.remote_os == "posix":
@@ -128,62 +176,32 @@ class SSHRemoteJobTrigger(BaseTrigger):
else:
cmd = build_windows_completion_check_command(self.exit_code_file)
- try:
- _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
- stdout = stdout.strip()
- if stdout and stdout.isdigit():
- return int(stdout)
- except (OSError, TimeoutError, ConnectionError) as e:
- self.log.warning("Transient error checking completion (will retry): %s", e)
- raise
- except Exception as e:
- self.log.warning("Error checking completion status: %s", e)
+ _, stdout, _ = await self._run_command(conn, cmd)
+ stdout = stdout.strip()
+ if stdout and stdout.isdigit():
+ return int(stdout)
return None
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
- retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
- reraise=True,
- )
- async def _get_log_size(self, hook: SSHHookAsync) -> int:
- """
- Get the current size of the log file in bytes.
-
- Retries transient network errors up to 3 times with exponential backoff.
- """
+ async def _get_log_size(self, conn: SSHClientConnection) -> int:
+ """Get the current size of the log file in bytes."""
if self.remote_os == "posix":
cmd = build_posix_file_size_command(self.log_file)
else:
cmd = build_windows_file_size_command(self.log_file)
- try:
- _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
- stdout = stdout.strip()
- if stdout and stdout.isdigit():
- return int(stdout)
- except (OSError, TimeoutError, ConnectionError) as e:
- self.log.warning("Transient error getting log size (will retry): %s", e)
- raise
- except Exception as e:
- self.log.warning("Error getting log file size: %s", e)
+ _, stdout, _ = await self._run_command(conn, cmd)
+ stdout = stdout.strip()
+ if stdout and stdout.isdigit():
+ return int(stdout)
return 0
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
- retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
- reraise=True,
- )
- async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]:
+ async def _read_log_chunk(self, conn: SSHClientConnection) -> tuple[str, int]:
"""
Read a chunk of logs from the current offset.
- Retries transient network errors up to 3 times with exponential backoff.
-
:return: Tuple of (log_chunk, new_offset)
"""
- file_size = await self._get_log_size(hook)
+ file_size = await self._get_log_size(conn)
if file_size <= self.log_offset:
return "", self.log_offset
@@ -195,47 +213,94 @@ class SSHRemoteJobTrigger(BaseTrigger):
else:
cmd = build_windows_log_tail_command(self.log_file, self.log_offset, bytes_to_read)
- try:
- exit_code, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
+ _, stdout, _ = await self._run_command(conn, cmd)
- # Advance offset by bytes requested, not decoded string length
- new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset
+ # Advance offset by bytes requested, not decoded string length
+ new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset
+ return stdout, new_offset
- return stdout, new_offset
- except (OSError, TimeoutError, ConnectionError) as e:
- self.log.warning("Transient error reading logs (will retry): %s", e)
- raise
- except Exception as e:
- self.log.warning("Error reading log chunk: %s", e)
- return "", self.log_offset
+ def _error_event(self, message: str) -> TriggerEvent:
+ return TriggerEvent(
+ {
+ "job_id": self.job_id,
+ "job_dir": self.job_dir,
+ "log_file": self.log_file,
+ "exit_code_file": self.exit_code_file,
+ "remote_os": self.remote_os,
+ "status": "error",
+ "done": True,
+ "exit_code": None,
+ "log_chunk": "",
+ "log_offset": self.log_offset,
+ "message": message,
+ }
+ )
async def run(self) -> AsyncIterator[TriggerEvent]:
"""
- Poll the remote job status and yield events with log chunks.
+ Poll the remote job status and yield a completion event.
- This method runs in a loop, checking the job status and reading
- logs at each poll interval. It yields a TriggerEvent each time
- with the current status and any new log output.
+ One connection is held for the whole loop. On a connection-level failure the
+ connection is dropped and re-established (with exponential backoff) up to
+ ``max_reconnect_attempts`` consecutive times; any other error, or exhausting the
+ reconnect budget, ends the trigger with an error event.
"""
- hook = self._get_hook()
+ conn: SSHClientConnection | None = None
+ # Consecutive failures since the last *fully successful* poll. A successful
+ # handshake alone does not reset this: a connection that handshakes but whose
+ # command channel keeps failing (e.g. ChannelOpenError under sshd MaxSessions)
+ # must still exhaust the budget instead of looping forever.
+ failures = 0
- while True:
- try:
- exit_code = await self._check_completion(hook)
- log_chunk, new_offset = await self._read_log_chunk(hook)
+ try:
+ while True:
+ if conn is None:
+ try:
+ conn = await self._connect()
+ except _CONNECTION_ERRORS as e:
+ failures += 1
+ if failures > self.max_reconnect_attempts:
+ raise
+ delay = self._reconnect_delay(failures)
+ self.log.warning(
+ "Failed to connect to remote host (attempt %d/%d), retrying in %.1fs: %s",
+ failures,
+ self.max_reconnect_attempts,
+ delay,
+ e,
+ )
+ await asyncio.sleep(delay)
+ continue
+
+ try:
+ exit_code = await self._check_completion(conn)
+ log_chunk, new_offset = await self._read_log_chunk(conn)
+ except _CONNECTION_ERRORS as e:
+ failures += 1
+ self.log.warning(
+ "Lost SSH connection while polling (attempt %d/%d), reconnecting: %s",
+ failures,
+ self.max_reconnect_attempts,
+ e,
+ )
+ await self._close(conn)
+ conn = None
+ if failures > self.max_reconnect_attempts:
+ raise
+ await asyncio.sleep(self._reconnect_delay(failures))
+ continue
- base_event = {
- "job_id": self.job_id,
- "job_dir": self.job_dir,
- "log_file": self.log_file,
- "exit_code_file": self.exit_code_file,
- "remote_os": self.remote_os,
- }
+ # A full poll cycle succeeded on this connection; clear the failure budget.
+ failures = 0
if exit_code is not None:
yield TriggerEvent(
{
- **base_event,
+ "job_id": self.job_id,
+ "job_dir": self.job_dir,
+ "log_file": self.log_file,
+ "exit_code_file": self.exit_code_file,
+ "remote_os": self.remote_os,
"status": "success" if exit_code == 0 else "failed",
"done": True,
"exit_code": exit_code,
@@ -251,21 +316,10 @@ class SSHRemoteJobTrigger(BaseTrigger):
self.log.info("%s", log_chunk.rstrip())
await asyncio.sleep(self.poll_interval)
- except Exception as e:
- self.log.exception("Error in SSH remote job trigger")
- yield TriggerEvent(
- {
- "job_id": self.job_id,
- "job_dir": self.job_dir,
- "log_file": self.log_file,
- "exit_code_file": self.exit_code_file,
- "remote_os": self.remote_os,
- "status": "error",
- "done": True,
- "exit_code": None,
- "log_chunk": "",
- "log_offset": self.log_offset,
- "message": f"Trigger error: {e}",
- }
- )
- return
+ except Exception as e:
+ self.log.exception("Error in SSH remote job trigger")
+ yield self._error_event(f"Trigger error: {e}")
+ return
+ finally:
+ if conn is not None:
+ await self._close(conn)
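The `run()` loop above resets its failure counter only after a complete poll cycle, never after a bare handshake. The sketch below shows the same control flow in a library-agnostic form, using plain `asyncio`. Hypothetical `connect`, `poll`, and `close` callables stand in for `_connect`, the `_check_completion` plus `_read_log_chunk` pair, and `_close`. This excerpt does not include the body of `_reconnect_delay`, so the capped exponential backoff with full jitter in `reconnect_delay` is an assumed policy, not the provider's actual constants. The real trigger also treats `asyncssh.Error` as reconnectable; here the caller would extend `retryable` to get the same behaviour.

```python
from __future__ import annotations

import asyncio
import random
from collections.abc import Awaitable, Callable
from typing import TypeVar

C = TypeVar("C")  # connection type
T = TypeVar("T")  # poll result type


def reconnect_delay(attempt: int, base: float = 1.0, cap: float = 30.0) -> float:
    """Assumed policy: capped exponential backoff with full jitter."""
    return random.uniform(0.0, min(cap, base * 2 ** (attempt - 1)))


async def poll_until_done(
    connect: Callable[[], Awaitable[C]],
    poll: Callable[[C], Awaitable[T | None]],
    close: Callable[[C], Awaitable[None]],
    *,
    max_reconnect_attempts: int,
    poll_interval: float,
    base_delay: float = 1.0,
    # TimeoutError and ConnectionError are both OSError subclasses.
    retryable: tuple[type[BaseException], ...] = (OSError,),
) -> T:
    """Hold one connection across polls; reconnect on retryable errors up to a budget."""
    conn: C | None = None
    failures = 0  # consecutive failures since the last *completed* poll
    try:
        while True:
            if conn is None:
                try:
                    conn = await connect()
                except retryable:
                    failures += 1
                    if failures > max_reconnect_attempts:
                        raise
                    await asyncio.sleep(reconnect_delay(failures, base=base_delay))
                    continue
            try:
                result = await poll(conn)
            except retryable:
                failures += 1
                await close(conn)
                conn = None  # force a fresh handshake on the next iteration
                if failures > max_reconnect_attempts:
                    raise
                await asyncio.sleep(reconnect_delay(failures, base=base_delay))
                continue
            failures = 0  # only a completed poll reaches this line
            if result is not None:
                return result
            await asyncio.sleep(poll_interval)
    finally:
        if conn is not None:
            await close(conn)
```

A host that keeps accepting handshakes but refuses channels (the `ChannelOpenError` under `MaxSessions` case described in the comments) never reaches `failures = 0`, so the budget still runs out.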
diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
index 289cdfa3e48..e277c1dc3f6 100644
--- a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
+++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
@@ -594,6 +594,26 @@ class TestSSHHook:
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.set_missing_host_key_policy.called is True
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_conn_retry_attempts_defaults_to_three(self, ssh_client):
+ hook = SSHHook(ssh_conn_id="ssh_default")
+ assert hook.conn_retry_attempts == 3
+
+ @mock.patch("time.sleep")
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_conn_retry_attempts_retries_until_limit(self, ssh_client, _mock_sleep):
+ """get_conn retries the configured number of times before re-raising."""
+ ssh_client.return_value.connect.side_effect = paramiko.ssh_exception.SSHException(
+ "Error reading SSH protocol banner"
+ )
+ hook = SSHHook(ssh_conn_id="ssh_default", conn_retry_attempts=4)
+ assert hook.conn_retry_attempts == 4
+
+ with pytest.raises(paramiko.ssh_exception.SSHException):
+ hook.get_conn()
+
+ assert ssh_client.return_value.connect.call_count == 4
+
@mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
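The two hook tests pin down the `conn_retry_attempts` contract. The default is 3. With `conn_retry_attempts=4`, a `connect` that always fails is called exactly four times before the last exception propagates, so the value counts total attempts rather than extra retries. The sketch below implements that contract synchronously. `connect_with_retries` is a hypothetical helper, and the fixed delay between attempts is an assumption: the tests only show that `time.sleep` is patched, and this excerpt does not show the hook's real wait strategy.

```python
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def connect_with_retries(
    connect: Callable[[], T],
    attempts: int = 3,
    delay: float = 3.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call ``connect`` up to ``attempts`` times in total, re-raising the last error."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return connect()
        except Exception:
            if attempt == attempts:
                raise  # budget exhausted: surface the original exception unchanged
            sleep(delay)  # assumed fixed delay; a real hook may add jitter
    raise AssertionError("unreachable")
```

In this sketch, `attempts=1` disables retrying entirely.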
diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
index 871f3981393..6ecf80baef2 100644
--- a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
+++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
@@ -123,6 +123,72 @@ class TestSSHRemoteJobOperator:
assert isinstance(exc_info.value.trigger, SSHRemoteJobTrigger)
assert exc_info.value.method_name == "execute_complete"
+ def test_execute_forwards_trigger_tuning_params(self):
+ """command_timeout and max_reconnect_attempts must reach the deferred trigger."""
+ self.mock_hook.exec_ssh_client_command.return_value = (0, b"job", b"")
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ remote_os="posix",
+ command_timeout=12.5,
+ max_reconnect_attempts=9,
+ )
+ mock_ti = mock.MagicMock()
+ mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1
+
+ with pytest.raises(TaskDeferred) as exc_info:
+ op.execute({"ti": mock_ti})
+
+ trigger = exc_info.value.trigger
+ assert trigger.command_timeout == 12.5
+ assert trigger.max_reconnect_attempts == 9
+
+ def test_execute_complete_re_defer_forwards_tuning_params(self):
+ """The re-defer path must also forward the trigger-tuning params."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ command_timeout=7.0,
+ max_reconnect_attempts=4,
+ )
+ event = {
+ "done": False,
+ "status": "running",
+ "job_id": "j",
+ "job_dir": "/tmp/airflow-ssh-jobs/j",
+ "log_file": "/tmp/airflow-ssh-jobs/j/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/j/exit_code",
+ "remote_os": "posix",
+ "log_chunk": "",
+ "log_offset": 0,
+ "exit_code": None,
+ }
+ with pytest.raises(TaskDeferred) as exc_info:
+ op.execute_complete({}, event)
+
+ trigger = exc_info.value.trigger
+ assert trigger.command_timeout == 7.0
+ assert trigger.max_reconnect_attempts == 4
+
+ def test_execute_connect_failure_is_reraised(self):
+ """A connection failure during submit is re-raised unchanged (advisory is log-only)."""
+ self.mock_hook.get_conn.side_effect = OSError("Error reading SSH protocol banner")
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ remote_os="posix",
+ )
+ mock_ti = mock.MagicMock()
+ mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1
+
+ with pytest.raises(OSError, match="Error reading SSH protocol banner"):
+ op.execute({"ti": mock_ti})
+
def test_execute_raises_if_no_command(self):
"""Test that execute raises if command is not specified."""
op = SSHRemoteJobOperator(
@@ -329,3 +395,68 @@ class TestSSHRemoteJobOperator:
# Should not raise even without active job
op.on_kill()
+
+ def test_execute_uses_single_connection_for_detect_and_submit(self):
+ """OS auto-detection and submission must share one SSH connection (one handshake)."""
+ self.mock_hook.exec_ssh_client_command.side_effect = [
+ (0, b"Linux", b""), # OS detection
+ (0, b"af_test_dag_test_task_run1_try1_abc123", b""), # submission
+ ]
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ remote_os="auto",
+ )
+
+ mock_ti = mock.MagicMock()
+ mock_ti.dag_id = "test_dag"
+ mock_ti.task_id = "test_task"
+ mock_ti.run_id = "run1"
+ mock_ti.try_number = 1
+
+ with pytest.raises(TaskDeferred):
+ op.execute({"ti": mock_ti})
+
+ # One handshake for the whole execute(): detection + submit reuse it.
+ self.mock_hook.get_conn.assert_called_once()
+ assert self.mock_hook.exec_ssh_client_command.call_count == 2
+ assert op._detected_os == "posix"
+
+ def test_cleanup_retries_then_succeeds(self):
+ """Cleanup retries on a transient SSH failure and stops once it succeeds."""
+ self.mock_hook.exec_ssh_client_command.side_effect = [
+ Exception("Error reading SSH protocol banner"),
+ (0, b"", b""),
+ ]
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ cleanup_retries=3,
+ )
+
+ with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep") as mock_sleep:
+ op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix")
+
+ assert self.mock_hook.exec_ssh_client_command.call_count == 2
+ mock_sleep.assert_called_once()
+
+ def test_cleanup_gives_up_after_retries_without_raising(self):
+ """When every cleanup attempt fails the task is not failed; the dir is left in place."""
+ self.mock_hook.exec_ssh_client_command.side_effect = Exception("connection refused")
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ cleanup_retries=3,
+ )
+
+ with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep"):
+ # Must not raise even though all attempts fail.
+ op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix")
+
+ assert self.mock_hook.exec_ssh_client_command.call_count == 3
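The two cleanup tests describe a best-effort loop:

- `cleanup_retries` bounds the total number of attempts.
- A sleep separates consecutive attempts.
- Exhausting the budget leaves the job directory in place instead of failing the task.

The standalone sketch below reproduces that behaviour. `best_effort_cleanup` is a hypothetical helper, and the fixed delay is an assumption; this excerpt does not show the operator's actual delay or log wording.

```python
import logging
import time
from collections.abc import Callable

log = logging.getLogger(__name__)


def best_effort_cleanup(
    run_cleanup: Callable[[], None],
    retries: int = 3,
    delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Try ``run_cleanup`` up to ``retries`` times; return True on success, never raise."""
    for attempt in range(1, retries + 1):
        try:
            run_cleanup()
            return True
        except Exception as exc:
            log.warning("Cleanup attempt %d/%d failed: %s", attempt, retries, exc)
            if attempt < retries:
                sleep(delay)  # assumed fixed delay between attempts
    # Budget exhausted: leave the remote job directory behind rather than fail the task.
    return False
```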
diff --git a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
index 67d9672a8cb..8e1d1e5953a 100644
--- a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
+++ b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
@@ -24,6 +24,20 @@ import pytest
from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger
+def _make_trigger(**overrides):
+ kwargs = dict(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+ kwargs.update(overrides)
+ return SSHRemoteJobTrigger(**kwargs)
+
+
class TestSSHRemoteJobTrigger:
def test_serialization(self):
"""Test that the trigger can be serialized correctly."""
@@ -38,6 +52,7 @@ class TestSSHRemoteJobTrigger:
poll_interval=10,
log_chunk_size=32768,
log_offset=1000,
+ max_reconnect_attempts=7,
)
classpath, kwargs = trigger.serialize()
@@ -53,145 +68,201 @@ class TestSSHRemoteJobTrigger:
assert kwargs["poll_interval"] == 10
assert kwargs["log_chunk_size"] == 32768
assert kwargs["log_offset"] == 1000
+ assert kwargs["max_reconnect_attempts"] == 7
+
+ def test_serialization_round_trips(self):
+ """The serialized kwargs must be sufficient to rebuild the trigger."""
+ trigger = _make_trigger(poll_interval=3, max_reconnect_attempts=2)
+ _, kwargs = trigger.serialize()
+ rebuilt = SSHRemoteJobTrigger(**kwargs)
+ assert rebuilt.serialize() == trigger.serialize()
def test_default_values(self):
"""Test default parameter values."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host=None,
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- )
+ trigger = _make_trigger()
assert trigger.poll_interval == 5
assert trigger.log_chunk_size == 65536
assert trigger.log_offset == 0
+ assert trigger.max_reconnect_attempts == 5
@pytest.mark.asyncio
async def test_run_job_completed_success(self):
"""Test trigger when job completes successfully."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host=None,
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- )
+ trigger = _make_trigger()
- with mock.patch.object(trigger, "_check_completion", return_value=0):
- with mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)):
- events = []
- async for event in trigger.run():
- events.append(event)
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+ mock.patch.object(trigger, "_close", return_value=None),
+ mock.patch.object(trigger, "_check_completion", return_value=0),
+ mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)),
+ ):
+ events = [event async for event in trigger.run()]
- assert len(events) == 1
- assert events[0].payload["status"] == "success"
- assert events[0].payload["done"] is True
- assert events[0].payload["exit_code"] == 0
- assert events[0].payload["log_chunk"] == "Final output\n"
+ assert len(events) == 1
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] == 0
+ assert events[0].payload["log_chunk"] == "Final output\n"
@pytest.mark.asyncio
async def test_run_job_completed_failure(self):
"""Test trigger when job completes with failure."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host=None,
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- )
+ trigger = _make_trigger()
- with mock.patch.object(trigger, "_check_completion", return_value=1):
- with mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)):
- events = []
- async for event in trigger.run():
- events.append(event)
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+ mock.patch.object(trigger, "_close", return_value=None),
+ mock.patch.object(trigger, "_check_completion", return_value=1),
+ mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)),
+ ):
+ events = [event async for event in trigger.run()]
- assert len(events) == 1
- assert events[0].payload["status"] == "failed"
- assert events[0].payload["done"] is True
- assert events[0].payload["exit_code"] == 1
+ assert len(events) == 1
+ assert events[0].payload["status"] == "failed"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] == 1
@pytest.mark.asyncio
- async def test_run_job_polls_until_completion(self):
- """Test trigger polls without yielding until job completes."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host=None,
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- poll_interval=0.01,
- )
+ async def test_run_reuses_single_connection_across_polls(self):
+ """The connection is opened once and reused for every poll, not per command."""
+ trigger = _make_trigger(poll_interval=0.01)
poll_count = 0
async def mock_check_completion(_):
nonlocal poll_count
poll_count += 1
- # Return None (still running) for first 2 polls, then exit code 0
- if poll_count < 3:
- return None
+ return None if poll_count < 3 else 0
+
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+ mock.patch.object(trigger, "_close", return_value=None) as mock_close,
+ mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion),
+ mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)),
+ ):
+ events = [event async for event in trigger.run()]
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "success"
+ assert poll_count == 3
+ # One connect for the whole loop, closed once at teardown.
+ assert mock_connect.call_count == 1
+ assert mock_close.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_run_reconnects_on_connection_drop(self):
+ """A connection-level error mid-poll drops the connection and reconnects."""
+ trigger = _make_trigger(poll_interval=0.01, max_reconnect_attempts=3)
+
+ calls = {"check": 0}
+
+ async def flaky_check(_):
+ calls["check"] += 1
+ if calls["check"] == 1:
+ raise OSError("Error reading SSH protocol banner")
return 0
- with mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion):
- with mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)):
- events = []
- async for event in trigger.run():
- events.append(event)
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+ mock.patch.object(trigger, "_close", return_value=None) as mock_close,
+ mock.patch.object(trigger, "_check_completion", side_effect=flaky_check),
+ mock.patch.object(trigger, "_read_log_chunk", return_value=("out\n", 10)),
+ mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+ ):
+ events = [event async for event in trigger.run()]
- # Only one event should be yielded (the completion event)
- assert len(events) == 1
- assert events[0].payload["status"] == "success"
- assert events[0].payload["done"] is True
- assert events[0].payload["exit_code"] == 0
- # Should have polled 3 times
- assert poll_count == 3
+ assert len(events) == 1
+ assert events[0].payload["status"] == "success"
+ # Initial connect + one reconnect after the dropped poll.
+ assert mock_connect.call_count == 2
+ # Dropped connection closed during reconnect, plus final teardown close.
+ assert mock_close.call_count == 2
@pytest.mark.asyncio
- async def test_run_handles_exception(self):
- """Test trigger handles exceptions gracefully."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host=None,
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- )
+ async def test_run_gives_up_after_max_reconnects(self):
+ """When connections keep failing, the trigger emits a single error event."""
+ trigger = _make_trigger(max_reconnect_attempts=2)
+
+ with (
+ mock.patch.object(trigger, "_connect", side_effect=OSError("connection refused")),
+ mock.patch.object(trigger, "_close", return_value=None),
+ mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+ ):
+ events = [event async for event in trigger.run()]
- with mock.patch.object(trigger, "_check_completion", side_effect=Exception("Connection failed")):
- events = []
- async for event in trigger.run():
- events.append(event)
+ assert len(events) == 1
+ assert events[0].payload["status"] == "error"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] is None
+ assert "connection refused" in events[0].payload["message"]
- assert len(events) == 1
- assert events[0].payload["status"] == "error"
- assert events[0].payload["done"] is True
- assert "Connection failed" in events[0].payload["message"]
+ @pytest.mark.asyncio
+ async def test_run_gives_up_when_polls_keep_failing_despite_reconnects(self):
+ """A connection that handshakes but whose polls keep failing must still hit the cap.
+
+ Regression: the reconnect budget must not reset on a bare successful handshake, or a
+ connection that reconnects fine but never completes a poll (e.g. ChannelOpenError under
+ sshd MaxSessions) would loop forever and the task would defer indefinitely.
+ """
+ trigger = _make_trigger(max_reconnect_attempts=2)
+
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+ mock.patch.object(trigger, "_close", return_value=None),
+ mock.patch.object(
+ trigger, "_check_completion", side_effect=OSError("channel open failed")
+ ) as mock_check,
+ mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+ ):
+ events = [event async for event in trigger.run()]
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "error"
+ assert "channel open failed" in events[0].payload["message"]
+ # Budget = 2 -> third consecutive failure ends it (handshake succeeds each round
+ # but never resets the counter because no poll ever completes).
+ assert mock_check.call_count == 3
+ assert mock_connect.call_count == 3
+
+ @pytest.mark.asyncio
+ async def test_run_handles_unexpected_exception(self):
+ """A non-connection error surfaces immediately as an error event."""
+ trigger = _make_trigger()
+
+ with (
+ mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+ mock.patch.object(trigger, "_close", return_value=None),
+ mock.patch.object(trigger, "_check_completion", side_effect=ValueError("boom")),
+ ):
+ events = [event async for event in trigger.run()]
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "error"
+ assert events[0].payload["done"] is True
+ assert "boom" in events[0].payload["message"]
def test_get_hook(self):
"""Test hook creation."""
- trigger = SSHRemoteJobTrigger(
- ssh_conn_id="test_conn",
- remote_host="custom.host.com",
- job_id="test_job",
- job_dir="/tmp/job",
- log_file="/tmp/job/stdout.log",
- exit_code_file="/tmp/job/exit_code",
- remote_os="posix",
- )
+ trigger = _make_trigger(remote_host="custom.host.com")
hook = trigger._get_hook()
assert hook.ssh_conn_id == "test_conn"
assert hook.host == "custom.host.com"
+
+ @pytest.mark.asyncio
+ async def test_run_command_uses_existing_connection(self):
+ """_run_command runs on the passed connection without opening a new one."""
+ trigger = _make_trigger(command_timeout=12.0)
+
+ result = mock.MagicMock()
+ result.exit_status = 0
+ result.stdout = "42"
+ result.stderr = ""
+ conn = mock.MagicMock()
+ conn.run = mock.AsyncMock(return_value=result)
+
+ exit_code, stdout, stderr = await trigger._run_command(conn, "echo 42")
+
+ conn.run.assert_awaited_once_with("echo 42", timeout=12.0, check=False)
+ assert (exit_code, stdout, stderr) == (0, "42", "")
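The commit message says `_run_command` now decodes `bytes` stdout and stderr. asyncssh types these streams as `bytes | str`, so decoding lets the method honour its `tuple[int, str, str]` return type. The helper sketched below does that kind of normalization. Its names (`as_text`, `normalize_result`) are hypothetical, and so are three behaviours: the UTF-8 default, `errors="replace"`, and mapping `None` to an empty string. None of this is the provider's code.

```python
from __future__ import annotations


def as_text(value: bytes | str | None, encoding: str = "utf-8") -> str:
    """Coerce an SSH process stream to ``str``.

    Assumptions: ``None`` (stream not captured) becomes ``""``, and undecodable
    bytes are replaced with U+FFFD instead of raising.
    """
    if value is None:
        return ""
    if isinstance(value, bytes):
        return value.decode(encoding, errors="replace")
    return value


def normalize_result(
    exit_status: int,
    stdout: bytes | str | None,
    stderr: bytes | str | None,
) -> tuple[int, str, str]:
    """Shape a completed command into the ``(exit_code, stdout, stderr)`` triple."""
    return exit_status, as_text(stdout), as_text(stderr)
```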