This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cd5509fc701 Reduce SSH connection churn in `SSHRemoteJobOperator` under high fan-out (#68115)
cd5509fc701 is described below

commit cd5509fc701cd18f32ae5b9625fa34a151c12f9e
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 9 00:27:50 2026 +0100

    Reduce SSH connection churn in `SSHRemoteJobOperator` under high fan-out (#68115)
    
    The operator and trigger opened a new SSH connection for every remote
    command. A large expand() fan-out against one host drove the connection
    rate past the remote sshd MaxStartups limit, which drops connections and
    surfaces as "paramiko ... Error reading SSH protocol banner" (an immediate
    EOF, not a banner timeout) at submit time, and left job directories behind
    when the cleanup connection was dropped too.
    
    Changes:
    - Trigger holds one connection for the whole poll loop instead of
      reconnecting per command, with bounded jittered reconnect on drops and
      asyncssh.Error treated as reconnectable.
    - Operator reuses one connection for OS detection and submission.
    - Cleanup retries instead of orphaning the job directory on a dropped
      connection.
    - Configurable conn_retry_attempts (operator/hook) for the submit burst,
      plus command_timeout and max_reconnect_attempts forwarded to the trigger.
    - SSHHookAsync sets a keepalive on the long-lived trigger connection.
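The connection-reuse and bounded-reconnect behaviour described in the first bullet can be illustrated with a stdlib-only sketch. It is hypothetical and is not the trigger's code. `poll_until_done`, `reconnect_delay`, the `Conn` interface and the give-up rule (stop once consecutive failures exceed `max_reconnect_attempts`) are all assumptions. The sketch also catches only `OSError`/`TimeoutError`, whereas the trigger additionally treats `asyncssh.Error` as reconnectable. Only the backoff formula mirrors `_reconnect_delay` from the trigger diff.

```python
import asyncio
import random
from typing import Awaitable, Callable, Optional, Protocol


class Conn(Protocol):
    """Hypothetical minimal connection interface (stands in for an asyncssh connection)."""

    async def run(self, command: str) -> str: ...

    def close(self) -> None: ...


def reconnect_delay(attempt: int) -> float:
    """Capped exponential backoff plus a random spread, as in the trigger's _reconnect_delay."""
    base = min(2 ** (attempt - 1), 30)
    return base + random.uniform(0, base)


async def poll_until_done(
    connect: Callable[[], Awaitable[Conn]],
    check_cmd: str,
    poll_interval: float,
    max_reconnect_attempts: int,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> int:
    """Poll over one reused connection and reconnect only after it breaks (hypothetical helper)."""
    conn: Optional[Conn] = None
    failures = 0  # consecutive connection failures
    try:
        while True:
            try:
                if conn is None:
                    conn = await connect()
                out = (await conn.run(check_cmd)).strip()
                failures = 0  # a successful command resets the failure budget
            except (OSError, TimeoutError) as exc:  # the real trigger also catches asyncssh.Error
                if conn is not None:
                    conn.close()
                    conn = None
                failures += 1
                if failures > max_reconnect_attempts:
                    raise RuntimeError(f"giving up after {failures} connection failures") from exc
                await sleep(reconnect_delay(failures))
                continue
            if out.isdigit():
                return int(out)  # the job has written its exit code
            await sleep(poll_interval)
    finally:
        if conn is not None:
            conn.close()
```

Resetting the counter after every successful command makes the budget count consecutive failures rather than total failures over the life of the job.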
    
    * Fix mypy return type and docs spellcheck in SSH remote job trigger
    
    - _run_command decodes bytes stdout/stderr so the return matches
      tuple[int, str, str] (asyncssh types them as bytes | str).
    - Drop 'jittered'/'desynchronise' from docstrings (Sphinx spellcheck).
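A minimal sketch of the normalization the first bullet describes, assuming a stream value typed `bytes | str | None`. The helper name `to_text` is hypothetical and does not appear in the patch, which inlines the same logic in `_run_command`.

```python
from typing import Union


def to_text(value: Union[bytes, str, None]) -> str:
    """Hypothetical helper: normalize a bytes | str | None stream value to str."""
    if value is None:
        return ""
    if isinstance(value, bytes):
        # Undecodable bytes become U+FFFD instead of raising UnicodeDecodeError.
        return value.decode("utf-8", errors="replace")
    return value
```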
    
    * Address review: log hint on submit-connection failure + mapped-task docs
    
    - execute() now logs an actionable hint when the submit connection fails (sshd
      MaxStartups under concurrency, conn_retry_attempts, pools/max_active_tis_per_dag),
      then re-raises the original error. Scoped to the connect step, no error matching.
    - High Fan-out docs: link the dynamic task mapping limits page and note the
      storm is not specific to mapped tasks (parallel runs/high concurrency too).
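The shape of the first change, reduced to a hypothetical stdlib sketch: only the connect call is wrapped, the hint is logged, and the original exception propagates unchanged. `connect_with_hint` and the logger name are illustrative. The real operator logs through `self.log`, and its message text differs.

```python
import logging
from typing import Callable, TypeVar

T = TypeVar("T")
log = logging.getLogger("ssh_remote_job_sketch")  # hypothetical logger name


def connect_with_hint(connect: Callable[[], T], host: str, retry_attempts: int) -> T:
    """Open the submit connection; on failure log an actionable hint, then re-raise as-is."""
    try:
        return connect()
    except Exception:
        # Only the connect step is wrapped, so failures of the submit command itself are
        # not mislabelled, and no matching on the exception message is needed.
        log.error(
            "Failed to connect to %s. The server may be refusing new connections under load "
            "(for example sshd MaxStartups); consider raising MaxStartups, increasing "
            "conn_retry_attempts (currently %d), or limiting concurrency with a pool.",
            host,
            retry_attempts,
        )
        raise
```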
---
 providers/ssh/docs/operators/ssh_remote_job.rst    |  42 +++-
 .../ssh/src/airflow/providers/ssh/hooks/ssh.py     |  25 +-
 .../providers/ssh/operators/ssh_remote_job.py      | 242 +++++++++++-------
 .../providers/ssh/triggers/ssh_remote_job.py       | 252 +++++++++++--------
 providers/ssh/tests/unit/ssh/hooks/test_ssh.py     |  20 ++
 .../unit/ssh/operators/test_ssh_remote_job.py      | 131 ++++++++++
 .../tests/unit/ssh/triggers/test_ssh_remote_job.py | 273 +++++++++++++--------
 7 files changed, 698 insertions(+), 287 deletions(-)

diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst
index c78a0792740..7a8783b9861 100644
--- a/providers/ssh/docs/operators/ssh_remote_job.rst
+++ b/providers/ssh/docs/operators/ssh_remote_job.rst
@@ -164,6 +164,13 @@ Parameters
 
 * ``remote_os`` (str, optional): Remote OS type (``"auto"``, ``"posix"``, ``"windows"``). Default: ``"auto"``
 * ``skip_on_exit_code`` (int or list, optional): Exit code(s) that should cause task to skip instead of fail
+* ``conn_timeout`` (int, optional): SSH connection timeout in seconds
+* ``banner_timeout`` (float, optional): Seconds to wait for the SSH banner. Default: 30.0
+* ``conn_retry_attempts`` (int, optional): How many times to attempt the initial SSH connection for
+  submission and cleanup before failing. Default: 5. Raise this for large fan-outs where the remote
+  ``sshd`` transiently refuses connections (see :ref:`High Fan-out <howto/operator:SSHRemoteJobOperator:fanout>`)
+* ``cleanup_retries`` (int, optional): How many times to retry remote directory cleanup before giving up
+  and leaving the directory in place. Default: 3
 
 Remote OS Detection
 -------------------
@@ -213,7 +220,9 @@ Limitations and Considerations
 -------------------------------
 
 **Network Interruptions**: While the operator is resilient to disconnections during monitoring,
-the initial job submission must succeed. If submission fails, the task will fail immediately.
+the initial job submission must succeed. The connection used for submission is retried
+(``conn_retry_attempts``); if every attempt fails, the task fails immediately. The trigger also
+reconnects automatically if the monitoring connection drops mid-job.
 
 **Remote Process Management**: Jobs are detached using ``nohup`` (POSIX) or ``Start-Process`` (Windows).
 If the remote host reboots during job execution, the job will be lost.
@@ -231,7 +240,36 @@ tasks can run on the same remote host without conflicts.
 
 **Cleanup**: Use ``cleanup="on_success"`` or ``cleanup="always"`` to avoid accumulating
 job directories on the remote host. For debugging, use ``cleanup="never"`` and manually
-inspect the job directory.
+inspect the job directory. Cleanup runs only when the job reaches completion, so tasks that
+are killed or time out can still leave a directory behind; for those, add a server-side TTL
+reaper (for example ``systemd-tmpfiles`` or a cron job) for the base directory.
+
+.. _howto/operator:SSHRemoteJobOperator:fanout:
+
+High Fan-out (Many Concurrent Tasks)
+-------------------------------------
+
+Many tasks targeting the same SSH server at once (a large ``.expand()`` fan-out, parallel DAG
+runs, or just high concurrency) can overwhelm it. Each remote
+command opens a new SSH connection, and the remote ``sshd`` throttles concurrent
+*unauthenticated* connections via ``MaxStartups`` (default ``10:30:100``: start randomly
+dropping at 10 concurrent, reaching 100% at 100). A dropped connection surfaces on the client
+as::
+
+    paramiko ... Error reading SSH protocol banner
+
+This is the server closing the socket before the handshake, not a slow banner, so raising
+``banner_timeout`` does not help.
+
+The operator and trigger keep the connection rate low: submission reuses a single connection
+for OS detection and the submit itself, and the trigger holds **one** connection for the whole
+poll loop instead of reconnecting on every status check. To push a high fan-out further:
+
+* Raise ``MaxStartups`` (and ``MaxSessions``) on the remote ``sshd`` -- this is the direct fix.
+* Increase ``conn_retry_attempts`` so transient refusals during the initial burst are retried.
+* Cap how many mapped tasks run at once with ``max_active_tis_per_dag`` (or a pool) instead of
+  releasing the entire fan-out simultaneously. See the "Placing Limits on Mapped Tasks" section of
+  :doc:`apache-airflow:authoring-and-scheduling/dynamic-task-mapping` for the available limits.
 
 Comparison with SSHOperator
 ----------------------------
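For intuition about the ``10:30:100`` default mentioned in the new High Fan-out section, here is a sketch of the refusal rule as sshd_config(5) describes it. There are no refusals below `start` unauthenticated connections. At `start`, new connections are refused with probability `rate`/100, and that probability rises linearly to 100% at `full`. The function name is hypothetical, and OpenSSH's integer arithmetic and exact boundary handling may differ slightly between versions.

```python
def max_startups_drop_probability(
    unauthenticated: int, start: int = 10, rate: int = 30, full: int = 100
) -> float:
    """Hypothetical model of sshd's MaxStartups start:rate:full refusal probability."""
    if unauthenticated < start:
        return 0.0  # below the threshold every new connection is accepted
    if unauthenticated >= full:
        return 1.0  # at or above `full` every new connection is refused
    # Linear ramp from rate% at `start` to 100% at `full`.
    fraction = (unauthenticated - start) / (full - start)
    return (rate + (100 - rate) * fraction) / 100
```

With the defaults, `max_startups_drop_probability(55)` is about 0.65, so roughly two in three new connections are refused while 55 handshakes are pending. This is why a burst of mapped tasks can fail well before 100 concurrent connections.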
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index d029a50772b..54b7edf6008 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -82,6 +82,9 @@ class SSHHook(BaseHook):
         lifetime of the transport
     :param ciphers: list of ciphers to use in order of preference
     :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host
+    :param conn_retry_attempts: number of times to attempt the initial SSH connection before
+        giving up (default 3). Raising this helps when many tasks target the same SSH server at
+        once and some connections are transiently refused (e.g. ``sshd`` ``MaxStartups`` throttling).
     """
 
     # List of classes to try loading private keys as, ordered (roughly) by most common to least common
@@ -130,9 +133,11 @@ class SSHHook(BaseHook):
         ciphers: list[str] | None = None,
         auth_timeout: int | None = None,
         host_proxy_cmd: str | None = None,
+        conn_retry_attempts: int = 3,
     ) -> None:
         super().__init__()
         self.ssh_conn_id = ssh_conn_id
+        self.conn_retry_attempts = max(1, conn_retry_attempts)
         self.remote_host = remote_host
         self.username = username
         self.password = password
@@ -344,7 +349,7 @@ class SSHHook(BaseHook):
         for attempt in Retrying(
             reraise=True,
             wait=wait_fixed(3) + wait_random(0, 2),
-            stop=stop_after_attempt(3),
+            stop=stop_after_attempt(self.conn_retry_attempts),
             before_sleep=log_before_sleep,
         ):
             with attempt:
@@ -553,6 +558,7 @@ class SSHHookAsync(BaseHook):
         key_file: str = "",
         passphrase: str = "",
         private_key: str = "",
+        keepalive_interval: int = 30,
     ) -> None:
         super().__init__()
         self.ssh_conn_id = ssh_conn_id
@@ -564,6 +570,7 @@ class SSHHookAsync(BaseHook):
         self.key_file = key_file
         self.passphrase = passphrase
         self.private_key = private_key
+        self.keepalive_interval = keepalive_interval
 
     def _parse_extras(self, conn: Any) -> None:
         """Parse extra fields from the connection into instance fields."""
@@ -631,10 +638,26 @@ class SSHHookAsync(BaseHook):
             conn_config["client_keys"] = [_private_key]
         if self.passphrase:
             conn_config["passphrase"] = self.passphrase
+        if self.keepalive_interval:
+            # The trigger holds one connection for the whole job; a keepalive stops idle
+            # NAT/firewall timeouts from silently dropping it between long poll intervals.
+            conn_config["keepalive_interval"] = self.keepalive_interval
 
         ssh_client_conn = await asyncssh.connect(**conn_config)
         return ssh_client_conn
 
+    async def get_conn(self):
+        """
+        Open an asyncssh connection that can be reused for multiple commands.
+
+        Unlike :meth:`run_command`, the returned connection is **not** closed
+        automatically; the caller owns its lifecycle (e.g.
+        ``async with await hook.get_conn() as conn: ...`` or an explicit
+        ``conn.close()``). Reusing one connection avoids a new TCP/SSH handshake
+        per command, which matters when many tasks poll the same SSH server.
+        """
+        return await self._get_conn()
+
     async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]:
         """
         Execute a command on the remote host asynchronously.
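The sync hook change only makes the attempt count of its existing tenacity `Retrying` loop configurable. A stdlib approximation of that policy might look like the sketch below: a 3 s fixed wait plus a 0 to 2 s random wait, at least one attempt, and the last error re-raised. `connect_with_retries` and the injectable `sleep` are hypothetical, added so the behaviour can be exercised without a server. Like tenacity's default `retry` predicate, it retries on any exception.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def connect_with_retries(
    connect: Callable[[], T],
    conn_retry_attempts: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Hypothetical helper: try `connect` up to `conn_retry_attempts` times."""
    attempts = max(1, conn_retry_attempts)  # same clamp as the hook: never fewer than one try
    for attempt in range(1, attempts + 1):
        try:
            return connect()
        except Exception:
            if attempt == attempts:
                raise  # like reraise=True: surface the original exception, not a wrapper
            sleep(3 + random.uniform(0, 2))  # approximates wait_fixed(3) + wait_random(0, 2)
    raise AssertionError("unreachable")
```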
diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
index 783edb39b8f..dc3414eca68 100644
--- a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
+++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import time
 import warnings
 from collections.abc import Container, Sequence
 from datetime import timedelta
@@ -74,6 +75,25 @@ class SSHRemoteJobOperator(BaseOperator):
     :param skip_on_exit_code: Exit codes that should skip the task instead of failing
     :param conn_timeout: SSH connection timeout in seconds
     :param banner_timeout: Timeout waiting for SSH banner in seconds
+    :param conn_retry_attempts: How many times to attempt the initial SSH connection for
+        submission/cleanup before failing (default 5). Helps when many mapped tasks hit the
+        same host at once and ``sshd`` transiently refuses connections (``MaxStartups``).
+    :param cleanup_retries: How many times to attempt remote directory cleanup before
+        giving up and leaving the directory in place (default 3). Prevents a transient SSH
+        failure during cleanup from orphaning the job directory on the remote host.
+    :param command_timeout: Per-command timeout in seconds for the trigger's status/log polls
+        (default 30.0).
+    :param max_reconnect_attempts: Consecutive connection failures the trigger tolerates (with
+        backoff) before failing the task while monitoring the remote job (default 5).
+
+    .. note::
+        A large ``expand()`` fan-out opens many SSH connections against one host. The remote
+        ``sshd`` throttles concurrent unauthenticated connections via ``MaxStartups`` (default
+        ``10:30:100``); when exceeded it drops connections, surfacing as
+        ``paramiko ... Error reading SSH protocol banner``. For high fan-out, raise ``MaxStartups``
+        on the server. The directory ``/tmp/airflow-ssh-jobs`` (POSIX) is only cleaned when
+        ``cleanup`` is set and the job reaches completion, so also consider a server-side TTL
+        reaper (for example ``systemd-tmpfiles``) for jobs that are killed or time out.
     """
 
     template_fields: Sequence[str] = ("command", "environment", "remote_host", "remote_base_dir")
@@ -104,6 +124,10 @@ class SSHRemoteJobOperator(BaseOperator):
         skip_on_exit_code: int | Container[int] | None = None,
         conn_timeout: int | None = None,
         banner_timeout: float = 30.0,
+        conn_retry_attempts: int = 5,
+        cleanup_retries: int = 3,
+        command_timeout: float = 30.0,
+        max_reconnect_attempts: int = 5,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -123,6 +147,10 @@ class SSHRemoteJobOperator(BaseOperator):
         self.remote_os = remote_os
         self.conn_timeout = conn_timeout
         self.banner_timeout = banner_timeout
+        self.conn_retry_attempts = conn_retry_attempts
+        self.cleanup_retries = max(1, cleanup_retries)
+        self.command_timeout = command_timeout
+        self.max_reconnect_attempts = max_reconnect_attempts
         self.skip_on_exit_code = (
             skip_on_exit_code
             if isinstance(skip_on_exit_code, Container)
@@ -170,67 +198,69 @@ class SSHRemoteJobOperator(BaseOperator):
             remote_host=self.remote_host or "",
             conn_timeout=self.conn_timeout,
             banner_timeout=self.banner_timeout,
+            conn_retry_attempts=self.conn_retry_attempts,
         )
 
-    def _detect_remote_os(self) -> Literal["posix", "windows"]:
+    def _detect_remote_os(self, ssh_client) -> Literal["posix", "windows"]:
         """
-        Detect the remote operating system.
+        Detect the remote operating system on an already-open SSH connection.
 
         Uses a two-stage detection:
         1. Try POSIX detection via `uname` (works on Linux, macOS, BSD, Solaris, AIX, etc.)
         2. Try Windows detection via PowerShell
         3. Raise error if both fail
+
+        :param ssh_client: An open paramiko SSH client to reuse (avoids a second handshake).
         """
         if self.remote_os != "auto":
             return self.remote_os
 
         self.log.info("Auto-detecting remote operating system...")
-        with self.ssh_hook.get_conn() as ssh_client:
-            try:
-                exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
-                    ssh_client,
-                    build_posix_os_detection_command(),
-                    get_pty=False,
-                    environment=None,
-                    timeout=10,
-                )
-                if exit_status == 0 and stdout:
-                    output = stdout.decode("utf-8", errors="replace").strip().lower()
-                    posix_systems = [
-                        "linux",
-                        "darwin",
-                        "freebsd",
-                        "openbsd",
-                        "netbsd",
-                        "sunos",
-                        "aix",
-                        "hp-ux",
-                    ]
-                    if any(system in output for system in posix_systems):
-                        self.log.info("Detected POSIX system: %s", output)
-                        return "posix"
-            except Exception as e:
-                self.log.debug("POSIX detection failed: %s", e)
-
-            try:
-                exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
-                    ssh_client,
-                    build_windows_os_detection_command(),
-                    get_pty=False,
-                    environment=None,
-                    timeout=10,
-                )
-                if exit_status == 0 and stdout:
-                    output = stdout.decode("utf-8", errors="replace").strip()
-                    if "WINDOWS" in output.upper():
-                        self.log.info("Detected Windows system")
-                        return "windows"
-            except Exception as e:
-                self.log.debug("Windows detection failed: %s", e)
+        try:
+            exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+                ssh_client,
+                build_posix_os_detection_command(),
+                get_pty=False,
+                environment=None,
+                timeout=10,
+            )
+            if exit_status == 0 and stdout:
+                output = stdout.decode("utf-8", errors="replace").strip().lower()
+                posix_systems = [
+                    "linux",
+                    "darwin",
+                    "freebsd",
+                    "openbsd",
+                    "netbsd",
+                    "sunos",
+                    "aix",
+                    "hp-ux",
+                ]
+                if any(system in output for system in posix_systems):
+                    self.log.info("Detected POSIX system: %s", output)
+                    return "posix"
+        except Exception as e:
+            self.log.debug("POSIX detection failed: %s", e)
 
-            raise AirflowException(
-                "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'"
+        try:
+            exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+                ssh_client,
+                build_windows_os_detection_command(),
+                get_pty=False,
+                environment=None,
+                timeout=10,
             )
+            if exit_status == 0 and stdout:
+                output = stdout.decode("utf-8", errors="replace").strip()
+                if "WINDOWS" in output.upper():
+                    self.log.info("Detected Windows system")
+                    return "windows"
+        except Exception as e:
+            self.log.debug("Windows detection failed: %s", e)
+
+        raise AirflowException(
+            "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'"
+        )
 
     def execute(self, context: Context) -> None:
         """
@@ -241,9 +271,6 @@ class SSHRemoteJobOperator(BaseOperator):
         if not self.command:
             raise AirflowException("SSH operator error: command not specified.")
 
-        self._detected_os = self._detect_remote_os()
-        self.log.info("Remote OS: %s", self._detected_os)
-
         ti = context["ti"]
         self._job_id = generate_job_id(
             dag_id=ti.dag_id,
@@ -253,27 +280,49 @@ class SSHRemoteJobOperator(BaseOperator):
         )
         self.log.info("Generated job ID: %s", self._job_id)
 
-        self._paths = RemoteJobPaths(
-            job_id=self._job_id,
-            remote_os=self._detected_os,
-            base_dir=self.remote_base_dir,
-        )
-
-        if self._detected_os == "posix":
-            wrapper_cmd = build_posix_wrapper_command(
-                command=self.command,
-                paths=self._paths,
-                environment=self.environment,
+        # Reuse a single connection for OS detection (when 'auto') and submission so the
+        # operator opens one SSH handshake per task instead of two. Under a large fan-out
+        # this halves the connection burst that triggers sshd MaxStartups throttling.
+        self.log.info("Connecting to %s", self.ssh_hook.remote_host)
+        try:
+            ssh_conn = self.ssh_hook.get_conn()
+        except Exception:
+            self.log.error(
+                "Failed to connect to %s to submit the remote job. When many SSH connections reach "
+                "the same host at once, the server can start refusing new ones before the handshake "
+                "(for example sshd MaxStartups). This is not limited to mapped tasks: parallel DAG "
+                "runs or high concurrency can cause it too. Try raising MaxStartups/MaxSessions on "
+                "the server, increasing conn_retry_attempts (currently %d), or reducing concurrency "
+                "with a pool (or max_active_tis_per_dag for mapped tasks). See the "
+                "SSHRemoteJobOperator 'High Fan-out' docs.",
+                self.ssh_hook.remote_host,
+                self.conn_retry_attempts,
             )
-        else:
-            wrapper_cmd = build_windows_wrapper_command(
-                command=self.command,
-                paths=self._paths,
-                environment=self.environment,
+            raise
+        with ssh_conn as ssh_client:
+            self._detected_os = self._detect_remote_os(ssh_client)
+            self.log.info("Remote OS: %s", self._detected_os)
+
+            self._paths = RemoteJobPaths(
+                job_id=self._job_id,
+                remote_os=self._detected_os,
+                base_dir=self.remote_base_dir,
             )
 
-        self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host)
-        with self.ssh_hook.get_conn() as ssh_client:
+            if self._detected_os == "posix":
+                wrapper_cmd = build_posix_wrapper_command(
+                    command=self.command,
+                    paths=self._paths,
+                    environment=self.environment,
+                )
+            else:
+                wrapper_cmd = build_windows_wrapper_command(
+                    command=self.command,
+                    paths=self._paths,
+                    environment=self.environment,
+                )
+
+            self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host)
             exit_status, stdout, stderr = self.ssh_hook.exec_ssh_client_command(
                 ssh_client,
                 wrapper_cmd,
@@ -320,6 +369,8 @@ class SSHRemoteJobOperator(BaseOperator):
                 poll_interval=self.poll_interval,
                 log_chunk_size=self.log_chunk_size,
                 log_offset=0,
+                command_timeout=self.command_timeout,
+                max_reconnect_attempts=self.max_reconnect_attempts,
             ),
             method_name="execute_complete",
             timeout=timedelta(seconds=self.timeout) if self.timeout else None,
@@ -361,6 +412,8 @@ class SSHRemoteJobOperator(BaseOperator):
                     poll_interval=self.poll_interval,
                     log_chunk_size=self.log_chunk_size,
                     log_offset=event.get("log_offset", 0),
+                    command_timeout=self.command_timeout,
+                    max_reconnect_attempts=self.max_reconnect_attempts,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.timeout) if self.timeout else None,
@@ -389,25 +442,46 @@ class SSHRemoteJobOperator(BaseOperator):
         self.log.info("Remote job completed successfully")
 
     def _cleanup_remote_job(self, job_dir: str, remote_os: str) -> None:
-        """Clean up the remote job directory."""
+        """
+        Clean up the remote job directory, retrying on transient SSH failures.
+
+        Under a large fan-out the cleanup connection can itself be refused by the
+        remote ``sshd`` (``MaxStartups``). Retrying a few times keeps a transient drop
+        from orphaning the job directory; if every attempt fails we log loudly and
+        leave the directory rather than failing the (already finished) task.
+        """
         self.log.info("Cleaning up remote job directory: %s", job_dir)
-        try:
-            if remote_os == "posix":
-                cleanup_cmd = build_posix_cleanup_command(job_dir)
-            else:
-                cleanup_cmd = build_windows_cleanup_command(job_dir)
+        if remote_os == "posix":
+            cleanup_cmd = build_posix_cleanup_command(job_dir)
+        else:
+            cleanup_cmd = build_windows_cleanup_command(job_dir)
 
-            with self.ssh_hook.get_conn() as ssh_client:
-                self.ssh_hook.exec_ssh_client_command(
-                    ssh_client,
-                    cleanup_cmd,
-                    get_pty=False,
-                    environment=None,
-                    timeout=30,
-                )
-            self.log.info("Remote cleanup completed")
-        except Exception as e:
-            self.log.warning("Failed to clean up remote job directory: %s", e)
+        last_error: Exception | None = None
+        for attempt in range(1, self.cleanup_retries + 1):
+            try:
+                with self.ssh_hook.get_conn() as ssh_client:
+                    self.ssh_hook.exec_ssh_client_command(
+                        ssh_client,
+                        cleanup_cmd,
+                        get_pty=False,
+                        environment=None,
+                        timeout=30,
+                    )
+                self.log.info("Remote cleanup completed")
+                return
+            except Exception as e:
+                last_error = e
+                self.log.warning("Cleanup attempt %d/%d failed: %s", attempt, self.cleanup_retries, e)
+                if attempt < self.cleanup_retries:
+                    time.sleep(min(2**attempt, 10))
+
+        self.log.warning(
+            "Failed to clean up remote job directory after %d attempts; leaving orphaned "
+            "directory %s on the remote host (last error: %s)",
+            self.cleanup_retries,
+            job_dir,
+            last_error,
+        )
 
     def on_kill(self) -> None:
         """
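The retry policy in `_cleanup_remote_job` reduces to the stdlib sketch below. It makes up to `cleanup_retries` attempts with a `min(2**attempt, 10)` second pause between them, and it logs a warning instead of raising when every attempt fails. `cleanup_with_retries`, the injected callables and the boolean return value are hypothetical additions for testability; the operator method itself returns `None`.

```python
import logging
import time
from typing import Callable, Optional

log = logging.getLogger("cleanup_sketch")  # hypothetical logger name


def cleanup_with_retries(
    run_cleanup: Callable[[], None],
    job_dir: str,
    cleanup_retries: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Hypothetical helper: best-effort cleanup that never raises; True on success."""
    retries = max(1, cleanup_retries)
    last_error: Optional[Exception] = None
    for attempt in range(1, retries + 1):
        try:
            run_cleanup()
            return True
        except Exception as exc:
            last_error = exc
            log.warning("Cleanup attempt %d/%d failed: %s", attempt, retries, exc)
            if attempt < retries:
                sleep(min(2**attempt, 10))  # 2 s, 4 s, 8 s, then capped at 10 s
    log.warning("Leaving %s in place after %d attempts (last error: %s)", job_dir, retries, last_error)
    return False
```

Swallowing the final error is deliberate: by the time cleanup runs, the remote job has already finished, so a dropped cleanup connection should not turn a successful task into a failed one.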
diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
index 0d4072c1ca4..2d390a3f2e6 100644
--- a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
+++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
@@ -20,10 +20,11 @@
 from __future__ import annotations
 
 import asyncio
+import random
 from collections.abc import AsyncIterator
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
-import tenacity
+import asyncssh
 
 from airflow.providers.ssh.hooks.ssh import SSHHookAsync
 from airflow.providers.ssh.utils.remote_job import (
@@ -36,6 +37,16 @@ from airflow.providers.ssh.utils.remote_job import (
 )
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
+if TYPE_CHECKING:
+    from asyncssh import SSHClientConnection
+
+# Errors that mean the connection itself is broken/refused and the poll should
+# reconnect instead of failing the job. ``asyncssh.Error`` covers handshake,
+# protocol and disconnect failures (e.g. an sshd that drops the connection under
+# ``MaxStartups`` load); ``OSError`` covers TCP-level refusals; ``TimeoutError``
+# covers a wedged command or connection.
+_CONNECTION_ERRORS = (OSError, asyncssh.Error, TimeoutError)
+
 
 class SSHRemoteJobTrigger(BaseTrigger):
     """
@@ -44,6 +55,13 @@ class SSHRemoteJobTrigger(BaseTrigger):
     This trigger polls the remote host to check job completion status
     and reads log output incrementally.
 
+    A single SSH connection is opened and reused for the whole poll loop instead
+    of reconnecting for every command. Opening a fresh TCP/SSH connection per poll
+    multiplies the connection rate against the remote ``sshd`` (which throttles
+    concurrent unauthenticated connections via ``MaxStartups``), so reuse keeps the
+    load flat when many tasks target the same host. If the connection drops, the
+    trigger transparently reconnects with backoff up to ``max_reconnect_attempts``.
+
     :param ssh_conn_id: SSH connection ID from Airflow Connections
     :param remote_host: Optional override for the remote host
     :param job_id: Unique identifier for the remote job
@@ -54,6 +72,9 @@ class SSHRemoteJobTrigger(BaseTrigger):
     :param poll_interval: Seconds between polling attempts
     :param log_chunk_size: Maximum bytes to read per poll
     :param log_offset: Current byte offset in the log file
+    :param command_timeout: Per-command timeout in seconds
+    :param max_reconnect_attempts: Consecutive connection failures tolerated before the
+        trigger gives up and emits an error event
     """
 
     def __init__(
@@ -69,6 +90,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
         log_chunk_size: int = 65536,
         log_offset: int = 0,
         command_timeout: float = 30.0,
+        max_reconnect_attempts: int = 5,
     ) -> None:
         super().__init__()
         self.ssh_conn_id = ssh_conn_id
@@ -82,6 +104,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
         self.log_chunk_size = log_chunk_size
         self.log_offset = log_offset
         self.command_timeout = command_timeout
+        self.max_reconnect_attempts = max_reconnect_attempts
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize the trigger for storage."""
@@ -99,6 +122,7 @@ class SSHRemoteJobTrigger(BaseTrigger):
                 "log_chunk_size": self.log_chunk_size,
                 "log_offset": self.log_offset,
                 "command_timeout": self.command_timeout,
+                "max_reconnect_attempts": self.max_reconnect_attempts,
             },
         )
 
@@ -109,18 +133,42 @@ class SSHRemoteJobTrigger(BaseTrigger):
             host=self.remote_host,
         )
 
-    @tenacity.retry(
-        stop=tenacity.stop_after_attempt(3),
-        wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
-        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
-        reraise=True,
-    )
-    async def _check_completion(self, hook: SSHHookAsync) -> int | None:
+    async def _connect(self) -> SSHClientConnection:
+        """Open a reusable asyncssh connection. Separated out as a seam for testing."""
+        return await self._get_hook().get_conn()
+
+    @staticmethod
+    async def _close(conn: SSHClientConnection) -> None:
+        """Close a connection, swallowing teardown errors."""
+        try:
+            conn.close()
+            await conn.wait_closed()
+        except Exception:
+            # Teardown is best-effort; a failing close has nothing actionable to recover.
+            pass
+
+    def _reconnect_delay(self, attempt: int) -> float:
+        """Exponential backoff with randomness so reconnecting triggers do not retry in lockstep."""
+        base = min(2 ** (attempt - 1), 30)
+        return base + random.uniform(0, base)
+
+    async def _run_command(self, conn: SSHClientConnection, command: str) -> tuple[int, str, str]:
+        """Run a command on an existing connection, mirroring ``SSHHookAsync.run_command``."""
+        result = await conn.run(command, timeout=self.command_timeout, check=False)
+        stdout = result.stdout or ""
+        stderr = result.stderr or ""
+        # asyncssh types stdout/stderr as bytes | str; with the default text encoding they are
+        # str, but decode defensively so the helper holds if a binary connection is ever used.
+        if isinstance(stdout, bytes):
+            stdout = stdout.decode("utf-8", errors="replace")
+        if isinstance(stderr, bytes):
+            stderr = stderr.decode("utf-8", errors="replace")
+        return result.exit_status or 0, stdout, stderr
+
+    async def _check_completion(self, conn: SSHClientConnection) -> int | None:
         """
         Check if the remote job has completed.
 
-        Retries transient network errors up to 3 times with exponential backoff.
-
         :return: Exit code if completed, None if still running
         """
         if self.remote_os == "posix":
@@ -128,62 +176,32 @@ class SSHRemoteJobTrigger(BaseTrigger):
         else:
             cmd = build_windows_completion_check_command(self.exit_code_file)
 
-        try:
-            _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
-            stdout = stdout.strip()
-            if stdout and stdout.isdigit():
-                return int(stdout)
-        except (OSError, TimeoutError, ConnectionError) as e:
-            self.log.warning("Transient error checking completion (will retry): %s", e)
-            raise
-        except Exception as e:
-            self.log.warning("Error checking completion status: %s", e)
+        _, stdout, _ = await self._run_command(conn, cmd)
+        stdout = stdout.strip()
+        if stdout and stdout.isdigit():
+            return int(stdout)
         return None
 
-    @tenacity.retry(
-        stop=tenacity.stop_after_attempt(3),
-        wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
-        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
-        reraise=True,
-    )
-    async def _get_log_size(self, hook: SSHHookAsync) -> int:
-        """
-        Get the current size of the log file in bytes.
-
-        Retries transient network errors up to 3 times with exponential backoff.
-        """
+    async def _get_log_size(self, conn: SSHClientConnection) -> int:
+        """Get the current size of the log file in bytes."""
         if self.remote_os == "posix":
             cmd = build_posix_file_size_command(self.log_file)
         else:
             cmd = build_windows_file_size_command(self.log_file)
 
-        try:
-            _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
-            stdout = stdout.strip()
-            if stdout and stdout.isdigit():
-                return int(stdout)
-        except (OSError, TimeoutError, ConnectionError) as e:
-            self.log.warning("Transient error getting log size (will retry): %s", e)
-            raise
-        except Exception as e:
-            self.log.warning("Error getting log file size: %s", e)
+        _, stdout, _ = await self._run_command(conn, cmd)
+        stdout = stdout.strip()
+        if stdout and stdout.isdigit():
+            return int(stdout)
         return 0
 
-    @tenacity.retry(
-        stop=tenacity.stop_after_attempt(3),
-        wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
-        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
-        reraise=True,
-    )
-    async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]:
+    async def _read_log_chunk(self, conn: SSHClientConnection) -> tuple[str, int]:
         """
         Read a chunk of logs from the current offset.
 
-        Retries transient network errors up to 3 times with exponential backoff.
-
         :return: Tuple of (log_chunk, new_offset)
         """
-        file_size = await self._get_log_size(hook)
+        file_size = await self._get_log_size(conn)
         if file_size <= self.log_offset:
             return "", self.log_offset
 
@@ -195,47 +213,94 @@ class SSHRemoteJobTrigger(BaseTrigger):
         else:
             cmd = build_windows_log_tail_command(self.log_file, self.log_offset, bytes_to_read)
 
-        try:
-            exit_code, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
+        _, stdout, _ = await self._run_command(conn, cmd)
 
-            # Advance offset by bytes requested, not decoded string length
-            new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset
+        # Advance offset by bytes requested, not decoded string length
+        new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset
+        return stdout, new_offset
 
-            return stdout, new_offset
-        except (OSError, TimeoutError, ConnectionError) as e:
-            self.log.warning("Transient error reading logs (will retry): %s", e)
-            raise
-        except Exception as e:
-            self.log.warning("Error reading log chunk: %s", e)
-            return "", self.log_offset
+    def _error_event(self, message: str) -> TriggerEvent:
+        return TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_dir": self.job_dir,
+                "log_file": self.log_file,
+                "exit_code_file": self.exit_code_file,
+                "remote_os": self.remote_os,
+                "status": "error",
+                "done": True,
+                "exit_code": None,
+                "log_chunk": "",
+                "log_offset": self.log_offset,
+                "message": message,
+            }
+        )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """
-        Poll the remote job status and yield events with log chunks.
+        Poll the remote job status and yield a completion event.
 
-        This method runs in a loop, checking the job status and reading
-        logs at each poll interval. It yields a TriggerEvent each time
-        with the current status and any new log output.
+        One connection is held for the whole loop. On a connection-level failure the
+        connection is dropped and re-established (with exponential backoff) up to
+        ``max_reconnect_attempts`` consecutive times; any other error, or exhausting the
+        reconnect budget, ends the trigger with an error event.
         """
-        hook = self._get_hook()
+        conn: SSHClientConnection | None = None
+        # Consecutive failures since the last *fully successful* poll. A successful
+        # handshake alone does not reset this: a connection that handshakes but whose
+        # command channel keeps failing (e.g. ChannelOpenError under sshd MaxSessions)
+        # must still exhaust the budget instead of looping forever.
+        failures = 0
 
-        while True:
-            try:
-                exit_code = await self._check_completion(hook)
-                log_chunk, new_offset = await self._read_log_chunk(hook)
+        try:
+            while True:
+                if conn is None:
+                    try:
+                        conn = await self._connect()
+                    except _CONNECTION_ERRORS as e:
+                        failures += 1
+                        if failures > self.max_reconnect_attempts:
+                            raise
+                        delay = self._reconnect_delay(failures)
+                        self.log.warning(
+                            "Failed to connect to remote host (attempt %d/%d), retrying in %.1fs: %s",
+                            failures,
+                            self.max_reconnect_attempts,
+                            delay,
+                            e,
+                        )
+                        await asyncio.sleep(delay)
+                        continue
+
+                try:
+                    exit_code = await self._check_completion(conn)
+                    log_chunk, new_offset = await self._read_log_chunk(conn)
+                except _CONNECTION_ERRORS as e:
+                    failures += 1
+                    self.log.warning(
+                        "Lost SSH connection while polling (attempt %d/%d), reconnecting: %s",
+                        failures,
+                        self.max_reconnect_attempts,
+                        e,
+                    )
+                    await self._close(conn)
+                    conn = None
+                    if failures > self.max_reconnect_attempts:
+                        raise
+                    await asyncio.sleep(self._reconnect_delay(failures))
+                    continue
 
-                base_event = {
-                    "job_id": self.job_id,
-                    "job_dir": self.job_dir,
-                    "log_file": self.log_file,
-                    "exit_code_file": self.exit_code_file,
-                    "remote_os": self.remote_os,
-                }
+                # A full poll cycle succeeded on this connection; clear the failure budget.
+                failures = 0
 
                 if exit_code is not None:
                     yield TriggerEvent(
                         {
-                            **base_event,
+                            "job_id": self.job_id,
+                            "job_dir": self.job_dir,
+                            "log_file": self.log_file,
+                            "exit_code_file": self.exit_code_file,
+                            "remote_os": self.remote_os,
                             "status": "success" if exit_code == 0 else "failed",
                             "done": True,
                             "exit_code": exit_code,
@@ -251,21 +316,10 @@ class SSHRemoteJobTrigger(BaseTrigger):
                     self.log.info("%s", log_chunk.rstrip())
                 await asyncio.sleep(self.poll_interval)
 
-            except Exception as e:
-                self.log.exception("Error in SSH remote job trigger")
-                yield TriggerEvent(
-                    {
-                        "job_id": self.job_id,
-                        "job_dir": self.job_dir,
-                        "log_file": self.log_file,
-                        "exit_code_file": self.exit_code_file,
-                        "remote_os": self.remote_os,
-                        "status": "error",
-                        "done": True,
-                        "exit_code": None,
-                        "log_chunk": "",
-                        "log_offset": self.log_offset,
-                        "message": f"Trigger error: {e}",
-                    }
-                )
-                return
+        except Exception as e:
+            self.log.exception("Error in SSH remote job trigger")
+            yield self._error_event(f"Trigger error: {e}")
+            return
+        finally:
+            if conn is not None:
+                await self._close(conn)
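For readers checking the budget arithmetic of the new `run()` loop, here is a minimal, stdlib-only sketch of the same reconnect policy. It is not the provider's code: `connect`, `poll` and `close` are hypothetical async callables standing in for `_connect`, one full `_check_completion`/`_read_log_chunk` cycle and `_close`, and `CONNECTION_ERRORS` is an assumed stand-in for `_CONNECTION_ERRORS`, whose definition is outside this excerpt. It keeps two choices visible in the diff: the failure counter only resets after a complete poll, and each wait is `min(2 ** (attempt - 1), 30)` plus up to the same amount again of random jitter, so a single wait falls between 1 and 60 seconds.

```python
import asyncio
import random

# Assumed stand-in for the trigger's _CONNECTION_ERRORS tuple (not shown in this diff).
CONNECTION_ERRORS = (OSError, TimeoutError, ConnectionError)


def reconnect_delay(attempt: int, cap: float = 30.0) -> float:
    """Exponential base capped at ``cap``, plus up to one extra base of random jitter."""
    base = min(2 ** (attempt - 1), cap)
    return base + random.uniform(0, base)


async def poll_until_done(connect, poll, close, max_reconnect_attempts, poll_interval=0.0, sleep=asyncio.sleep):
    """Reuse one connection across polls and reconnect within a bounded budget.

    connect/poll/close are hypothetical async callables: poll(conn) returns an exit
    code once the job has finished, or None while it is still running.
    """
    conn = None
    failures = 0  # consecutive failures since the last *complete* poll
    try:
        while True:
            if conn is None:
                try:
                    conn = await connect()
                except CONNECTION_ERRORS:
                    failures += 1
                    if failures > max_reconnect_attempts:
                        raise
                    await sleep(reconnect_delay(failures))
                    continue
            try:
                exit_code = await poll(conn)
            except CONNECTION_ERRORS:
                # Drop the broken connection; the next iteration reconnects.
                failures += 1
                await close(conn)
                conn = None
                if failures > max_reconnect_attempts:
                    raise
                await sleep(reconnect_delay(failures))
                continue
            failures = 0  # a handshake alone never resets the budget; a full poll does
            if exit_code is not None:
                return exit_code
            await sleep(poll_interval)
    finally:
        # Close whatever connection is still open, on success or on failure.
        if conn is not None:
            await close(conn)
```

With `max_reconnect_attempts=2` and a poll that always fails, the sketch makes three connects and three polls before re-raising, the same counts that `test_run_gives_up_when_polls_keep_failing_despite_reconnects` asserts. The real trigger turns that final exception into an error `TriggerEvent` instead of propagating it.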
diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
index 289cdfa3e48..e277c1dc3f6 100644
--- a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
+++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
@@ -594,6 +594,26 @@ class TestSSHHook:
             assert ssh_client.return_value.connect.called is True
             assert ssh_client.return_value.set_missing_host_key_policy.called is True
 
+    @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+    def test_conn_retry_attempts_defaults_to_three(self, ssh_client):
+        hook = SSHHook(ssh_conn_id="ssh_default")
+        assert hook.conn_retry_attempts == 3
+
+    @mock.patch("time.sleep")
+    @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+    def test_conn_retry_attempts_retries_until_limit(self, ssh_client, _mock_sleep):
+        """get_conn retries the configured number of times before re-raising."""
+        ssh_client.return_value.connect.side_effect = paramiko.ssh_exception.SSHException(
+            "Error reading SSH protocol banner"
+        )
+        hook = SSHHook(ssh_conn_id="ssh_default", conn_retry_attempts=4)
+        assert hook.conn_retry_attempts == 4
+
+        with pytest.raises(paramiko.ssh_exception.SSHException):
+            hook.get_conn()
+
+        assert ssh_client.return_value.connect.call_count == 4
+
     @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
     def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client):
         hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
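`test_conn_retry_attempts_retries_until_limit` counts total `connect()` calls: `conn_retry_attempts=4` means four attempts in all, not a first try followed by four retries. A stdlib sketch of that counting rule, assuming a fixed wait between attempts and none after the last one (the hook's actual wait strategy is not shown in this diff); `connect_with_retries`, `retry_on` and `wait` are illustrative names, not `SSHHook` API.

```python
import time


def connect_with_retries(connect, attempts=3, wait=1.0, retry_on=(Exception,), sleep=time.sleep):
    """Call connect() at most ``attempts`` times in total and re-raise the last error.

    Illustrative only: a fixed wait between attempts, no wait after the final one.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return connect()
        except retry_on:
            if attempt == attempts:
                raise  # budget spent: surface the original error unchanged
            sleep(wait)
```

Under this counting rule, four attempts that all fail produce four `connect()` calls and three waits, which matches the `call_count == 4` the hook test expects.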
diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
index 871f3981393..6ecf80baef2 100644
--- a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
+++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
@@ -123,6 +123,72 @@ class TestSSHRemoteJobOperator:
         assert isinstance(exc_info.value.trigger, SSHRemoteJobTrigger)
         assert exc_info.value.method_name == "execute_complete"
 
+    def test_execute_forwards_trigger_tuning_params(self):
+        """command_timeout and max_reconnect_attempts must reach the deferred trigger."""
+        self.mock_hook.exec_ssh_client_command.return_value = (0, b"job", b"")
+
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            remote_os="posix",
+            command_timeout=12.5,
+            max_reconnect_attempts=9,
+        )
+        mock_ti = mock.MagicMock()
+        mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1
+
+        with pytest.raises(TaskDeferred) as exc_info:
+            op.execute({"ti": mock_ti})
+
+        trigger = exc_info.value.trigger
+        assert trigger.command_timeout == 12.5
+        assert trigger.max_reconnect_attempts == 9
+
+    def test_execute_complete_re_defer_forwards_tuning_params(self):
+        """The re-defer path must also forward the trigger-tuning params."""
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            command_timeout=7.0,
+            max_reconnect_attempts=4,
+        )
+        event = {
+            "done": False,
+            "status": "running",
+            "job_id": "j",
+            "job_dir": "/tmp/airflow-ssh-jobs/j",
+            "log_file": "/tmp/airflow-ssh-jobs/j/stdout.log",
+            "exit_code_file": "/tmp/airflow-ssh-jobs/j/exit_code",
+            "remote_os": "posix",
+            "log_chunk": "",
+            "log_offset": 0,
+            "exit_code": None,
+        }
+        with pytest.raises(TaskDeferred) as exc_info:
+            op.execute_complete({}, event)
+
+        trigger = exc_info.value.trigger
+        assert trigger.command_timeout == 7.0
+        assert trigger.max_reconnect_attempts == 4
+
+    def test_execute_connect_failure_is_reraised(self):
+        """A connection failure during submit is re-raised unchanged (advisory is log-only)."""
+        self.mock_hook.get_conn.side_effect = OSError("Error reading SSH protocol banner")
+
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            remote_os="posix",
+        )
+        mock_ti = mock.MagicMock()
+        mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1
+
+        with pytest.raises(OSError, match="Error reading SSH protocol banner"):
+            op.execute({"ti": mock_ti})
+
     def test_execute_raises_if_no_command(self):
         """Test that execute raises if command is not specified."""
         op = SSHRemoteJobOperator(
@@ -329,3 +395,68 @@ class TestSSHRemoteJobOperator:
 
         # Should not raise even without active job
         op.on_kill()
+
+    def test_execute_uses_single_connection_for_detect_and_submit(self):
+        """OS auto-detection and submission must share one SSH connection (one handshake)."""
+        self.mock_hook.exec_ssh_client_command.side_effect = [
+            (0, b"Linux", b""),  # OS detection
+            (0, b"af_test_dag_test_task_run1_try1_abc123", b""),  # submission
+        ]
+
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            remote_os="auto",
+        )
+
+        mock_ti = mock.MagicMock()
+        mock_ti.dag_id = "test_dag"
+        mock_ti.task_id = "test_task"
+        mock_ti.run_id = "run1"
+        mock_ti.try_number = 1
+
+        with pytest.raises(TaskDeferred):
+            op.execute({"ti": mock_ti})
+
+        # One handshake for the whole execute(): detection + submit reuse it.
+        self.mock_hook.get_conn.assert_called_once()
+        assert self.mock_hook.exec_ssh_client_command.call_count == 2
+        assert op._detected_os == "posix"
+
+    def test_cleanup_retries_then_succeeds(self):
+        """Cleanup retries on a transient SSH failure and stops once it succeeds."""
+        self.mock_hook.exec_ssh_client_command.side_effect = [
+            Exception("Error reading SSH protocol banner"),
+            (0, b"", b""),
+        ]
+
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            cleanup_retries=3,
+        )
+
+        with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep") as mock_sleep:
+            op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix")
+
+        assert self.mock_hook.exec_ssh_client_command.call_count == 2
+        mock_sleep.assert_called_once()
+
+    def test_cleanup_gives_up_after_retries_without_raising(self):
+        """When every cleanup attempt fails the task is not failed; the dir is left in place."""
+        self.mock_hook.exec_ssh_client_command.side_effect = Exception("connection refused")
+
+        op = SSHRemoteJobOperator(
+            task_id="test_task",
+            ssh_conn_id="test_conn",
+            command="/path/to/script.sh",
+            cleanup_retries=3,
+        )
+
+        with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep"):
+            # Must not raise even though all attempts fail.
+            op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix")
+
+        assert self.mock_hook.exec_ssh_client_command.call_count == 3
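The two cleanup tests pin down a best-effort contract: pause between attempts, stop at the first success, and never raise once the attempt budget is spent. A sketch of that contract, assuming cleanup is a single synchronous callable; `cleanup_best_effort` and `run_cleanup` are hypothetical names, and the operator's real `_cleanup_remote_job` may differ in its waits and logging.

```python
import logging
import time

log = logging.getLogger(__name__)


def cleanup_best_effort(run_cleanup, job_dir, retries=3, wait=1.0, sleep=time.sleep):
    """Try run_cleanup(job_dir) up to ``retries`` times in total; never raise.

    Returns True once an attempt succeeds, or False if every attempt failed, in which
    case the directory is left in place and only a warning is logged.
    """
    for attempt in range(1, retries + 1):
        try:
            run_cleanup(job_dir)
            return True
        except Exception as e:  # every failure is treated as transient in this sketch
            if attempt == retries:
                log.warning("Giving up cleaning %s after %d attempts: %s", job_dir, retries, e)
                return False
            sleep(wait)  # wait only between attempts, never after the last one
    return False  # reached only when retries < 1
```

Sleeping only between attempts is what makes "fail once, then succeed" produce exactly one sleep, the behavior `test_cleanup_retries_then_succeeds` checks with `assert_called_once()`.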
diff --git a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
index 67d9672a8cb..8e1d1e5953a 100644
--- a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
+++ b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
@@ -24,6 +24,20 @@ import pytest
 from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger
 
 
+def _make_trigger(**overrides):
+    kwargs = dict(
+        ssh_conn_id="test_conn",
+        remote_host=None,
+        job_id="test_job",
+        job_dir="/tmp/job",
+        log_file="/tmp/job/stdout.log",
+        exit_code_file="/tmp/job/exit_code",
+        remote_os="posix",
+    )
+    kwargs.update(overrides)
+    return SSHRemoteJobTrigger(**kwargs)
+
+
 class TestSSHRemoteJobTrigger:
     def test_serialization(self):
         """Test that the trigger can be serialized correctly."""
@@ -38,6 +52,7 @@ class TestSSHRemoteJobTrigger:
             poll_interval=10,
             log_chunk_size=32768,
             log_offset=1000,
+            max_reconnect_attempts=7,
         )
 
         classpath, kwargs = trigger.serialize()
@@ -53,145 +68,201 @@ class TestSSHRemoteJobTrigger:
         assert kwargs["poll_interval"] == 10
         assert kwargs["log_chunk_size"] == 32768
         assert kwargs["log_offset"] == 1000
+        assert kwargs["max_reconnect_attempts"] == 7
+
+    def test_serialization_round_trips(self):
+        """The serialized kwargs must be sufficient to rebuild the trigger."""
+        trigger = _make_trigger(poll_interval=3, max_reconnect_attempts=2)
+        _, kwargs = trigger.serialize()
+        rebuilt = SSHRemoteJobTrigger(**kwargs)
+        assert rebuilt.serialize() == trigger.serialize()
 
     def test_default_values(self):
         """Test default parameter values."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host=None,
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-        )
+        trigger = _make_trigger()
 
         assert trigger.poll_interval == 5
         assert trigger.log_chunk_size == 65536
         assert trigger.log_offset == 0
+        assert trigger.max_reconnect_attempts == 5
 
     @pytest.mark.asyncio
     async def test_run_job_completed_success(self):
         """Test trigger when job completes successfully."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host=None,
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-        )
+        trigger = _make_trigger()
 
-        with mock.patch.object(trigger, "_check_completion", return_value=0):
-            with mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)):
-                events = []
-                async for event in trigger.run():
-                    events.append(event)
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+            mock.patch.object(trigger, "_close", return_value=None),
+            mock.patch.object(trigger, "_check_completion", return_value=0),
+            mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)),
+        ):
+            events = [event async for event in trigger.run()]
 
-                assert len(events) == 1
-                assert events[0].payload["status"] == "success"
-                assert events[0].payload["done"] is True
-                assert events[0].payload["exit_code"] == 0
-                assert events[0].payload["log_chunk"] == "Final output\n"
+        assert len(events) == 1
+        assert events[0].payload["status"] == "success"
+        assert events[0].payload["done"] is True
+        assert events[0].payload["exit_code"] == 0
+        assert events[0].payload["log_chunk"] == "Final output\n"
 
     @pytest.mark.asyncio
     async def test_run_job_completed_failure(self):
         """Test trigger when job completes with failure."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host=None,
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-        )
+        trigger = _make_trigger()
 
-        with mock.patch.object(trigger, "_check_completion", return_value=1):
-            with mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)):
-                events = []
-                async for event in trigger.run():
-                    events.append(event)
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+            mock.patch.object(trigger, "_close", return_value=None),
+            mock.patch.object(trigger, "_check_completion", return_value=1),
+            mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)),
+        ):
+            events = [event async for event in trigger.run()]
 
-                assert len(events) == 1
-                assert events[0].payload["status"] == "failed"
-                assert events[0].payload["done"] is True
-                assert events[0].payload["exit_code"] == 1
+        assert len(events) == 1
+        assert events[0].payload["status"] == "failed"
+        assert events[0].payload["done"] is True
+        assert events[0].payload["exit_code"] == 1
 
     @pytest.mark.asyncio
-    async def test_run_job_polls_until_completion(self):
-        """Test trigger polls without yielding until job completes."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host=None,
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-            poll_interval=0.01,
-        )
+    async def test_run_reuses_single_connection_across_polls(self):
+        """The connection is opened once and reused for every poll, not per command."""
+        trigger = _make_trigger(poll_interval=0.01)
 
         poll_count = 0
 
         async def mock_check_completion(_):
             nonlocal poll_count
             poll_count += 1
-            # Return None (still running) for first 2 polls, then exit code 0
-            if poll_count < 3:
-                return None
+            return None if poll_count < 3 else 0
+
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+            mock.patch.object(trigger, "_close", return_value=None) as mock_close,
+            mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion),
+            mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)),
+        ):
+            events = [event async for event in trigger.run()]
+
+        assert len(events) == 1
+        assert events[0].payload["status"] == "success"
+        assert poll_count == 3
+        # One connect for the whole loop, closed once at teardown.
+        assert mock_connect.call_count == 1
+        assert mock_close.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_run_reconnects_on_connection_drop(self):
+        """A connection-level error mid-poll drops the connection and reconnects."""
+        trigger = _make_trigger(poll_interval=0.01, max_reconnect_attempts=3)
+
+        calls = {"check": 0}
+
+        async def flaky_check(_):
+            calls["check"] += 1
+            if calls["check"] == 1:
+                raise OSError("Error reading SSH protocol banner")
             return 0
 
-        with mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion):
-            with mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)):
-                events = []
-                async for event in trigger.run():
-                    events.append(event)
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+            mock.patch.object(trigger, "_close", return_value=None) as mock_close,
+            mock.patch.object(trigger, "_check_completion", side_effect=flaky_check),
+            mock.patch.object(trigger, "_read_log_chunk", return_value=("out\n", 10)),
+            mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+        ):
+            events = [event async for event in trigger.run()]
 
-                # Only one event should be yielded (the completion event)
-                assert len(events) == 1
-                assert events[0].payload["status"] == "success"
-                assert events[0].payload["done"] is True
-                assert events[0].payload["exit_code"] == 0
-                # Should have polled 3 times
-                assert poll_count == 3
+        assert len(events) == 1
+        assert events[0].payload["status"] == "success"
+        # Initial connect + one reconnect after the dropped poll.
+        assert mock_connect.call_count == 2
+        # Dropped connection closed during reconnect, plus final teardown close.
+        assert mock_close.call_count == 2
 
     @pytest.mark.asyncio
-    async def test_run_handles_exception(self):
-        """Test trigger handles exceptions gracefully."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host=None,
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-        )
+    async def test_run_gives_up_after_max_reconnects(self):
+        """When connections keep failing, the trigger emits a single error event."""
+        trigger = _make_trigger(max_reconnect_attempts=2)
+
+        with (
+            mock.patch.object(trigger, "_connect", side_effect=OSError("connection refused")),
+            mock.patch.object(trigger, "_close", return_value=None),
+            mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+        ):
+            events = [event async for event in trigger.run()]
 
-        with mock.patch.object(trigger, "_check_completion", side_effect=Exception("Connection failed")):
-            events = []
-            async for event in trigger.run():
-                events.append(event)
+        assert len(events) == 1
+        assert events[0].payload["status"] == "error"
+        assert events[0].payload["done"] is True
+        assert events[0].payload["exit_code"] is None
+        assert "connection refused" in events[0].payload["message"]
 
-            assert len(events) == 1
-            assert events[0].payload["status"] == "error"
-            assert events[0].payload["done"] is True
-            assert "Connection failed" in events[0].payload["message"]
+    @pytest.mark.asyncio
+    async def test_run_gives_up_when_polls_keep_failing_despite_reconnects(self):
+        """A connection that handshakes but whose polls keep failing must still hit the cap.
+
+        Regression: the reconnect budget must not reset on a bare successful handshake, or a
+        connection that reconnects fine but never completes a poll (e.g. ChannelOpenError under
+        sshd MaxSessions) would loop forever and the task would defer indefinitely.
+        """
+        trigger = _make_trigger(max_reconnect_attempts=2)
+
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect,
+            mock.patch.object(trigger, "_close", return_value=None),
+            mock.patch.object(
+                trigger, "_check_completion", side_effect=OSError("channel open failed")
+            ) as mock_check,
+            mock.patch("asyncio.sleep", new=mock.AsyncMock()),
+        ):
+            events = [event async for event in trigger.run()]
+
+        assert len(events) == 1
+        assert events[0].payload["status"] == "error"
+        assert "channel open failed" in events[0].payload["message"]
+        # Budget = 2 -> third consecutive failure ends it (handshake succeeds each round
+        # but never resets the counter because no poll ever completes).
+        assert mock_check.call_count == 3
+        assert mock_connect.call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_run_handles_unexpected_exception(self):
+        """A non-connection error surfaces immediately as an error event."""
+        trigger = _make_trigger()
+
+        with (
+            mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()),
+            mock.patch.object(trigger, "_close", return_value=None),
+            mock.patch.object(trigger, "_check_completion", side_effect=ValueError("boom")),
+        ):
+            events = [event async for event in trigger.run()]
+
+        assert len(events) == 1
+        assert events[0].payload["status"] == "error"
+        assert events[0].payload["done"] is True
+        assert "boom" in events[0].payload["message"]
 
     def test_get_hook(self):
         """Test hook creation."""
-        trigger = SSHRemoteJobTrigger(
-            ssh_conn_id="test_conn",
-            remote_host="custom.host.com",
-            job_id="test_job",
-            job_dir="/tmp/job",
-            log_file="/tmp/job/stdout.log",
-            exit_code_file="/tmp/job/exit_code",
-            remote_os="posix",
-        )
+        trigger = _make_trigger(remote_host="custom.host.com")
 
         hook = trigger._get_hook()
         assert hook.ssh_conn_id == "test_conn"
         assert hook.host == "custom.host.com"
+
+    @pytest.mark.asyncio
+    async def test_run_command_uses_existing_connection(self):
+        """_run_command runs on the passed connection without opening a new one."""
+        trigger = _make_trigger(command_timeout=12.0)
+
+        result = mock.MagicMock()
+        result.exit_status = 0
+        result.stdout = "42"
+        result.stderr = ""
+        conn = mock.MagicMock()
+        conn.run = mock.AsyncMock(return_value=result)
+
+        exit_code, stdout, stderr = await trigger._run_command(conn, "echo 42")
+
+        conn.run.assert_awaited_once_with("echo 42", timeout=12.0, check=False)
+        assert (exit_code, stdout, stderr) == (0, "42", "")
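The new `_run_command` normalizes asyncssh output before handing it to the polling helpers. Below is a standalone sketch of that post-processing. It uses `types.SimpleNamespace` as a hypothetical stand-in for asyncssh's completed-process result so it runs without a connection; `to_text` and `normalize_result` are illustrative names.

```python
from types import SimpleNamespace


def to_text(value) -> str:
    """Normalize a bytes | str | None stream to str; invalid UTF-8 bytes become U+FFFD."""
    if value is None:
        return ""
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value


def normalize_result(result) -> tuple[int, str, str]:
    """Shape a completed-process-like object into (exit_code, stdout, stderr).

    Mirrors the post-processing in the diff: a missing exit status is reported as 0,
    and stdout/stderr are decoded if they arrive as bytes.
    """
    return result.exit_status or 0, to_text(result.stdout), to_text(result.stderr)


# Usage with a fake result object (no SSH involved):
fake = SimpleNamespace(exit_status=None, stdout=b"42", stderr=None)
exit_code, stdout, stderr = normalize_result(fake)
```

Every caller in this diff discards the exit code (`_, stdout, _ = ...`) and parses only stdout, so the `or 0` fallback does not affect completion detection.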
