This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73cd0d5f8e Add skip_on_exit_code to SSHOperator (#36303)
73cd0d5f8e is described below

commit 73cd0d5f8e99f93d2ff90a7cca04e6cdf086c359
Author: Maxim Martynov <[email protected]>
AuthorDate: Thu Dec 21 03:37:16 2023 +0300

    Add skip_on_exit_code to SSHOperator (#36303)
---
 airflow/providers/ssh/operators/ssh.py    | 19 ++++++++++++++---
 tests/providers/ssh/operators/test_ssh.py | 35 ++++++++++++++++++++++++++++++-
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 076d5d9f91..621dc47539 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -20,12 +20,12 @@ from __future__ import annotations
 import warnings
 from base64 import b64encode
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Container, Sequence
 
 from deprecated.classic import deprecated
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.models import BaseOperator
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.utils.types import NOTSET, ArgNotSet
@@ -60,6 +60,9 @@ class SSHOperator(BaseOperator):
         The default is ``False`` but note that `get_pty` is forced to ``True``
         when the `command` starts with ``sudo``.
     :param banner_timeout: timeout to wait for banner from the server in seconds
+    :param skip_on_exit_code: If command exits with this exit code, leave the task
+        in ``skipped`` state (default: None). If set to ``None``, any non-zero
+        exit code will be treated as a failure.
 
     If *do_xcom_push* is *True*, the numeric exit code emitted by
     the ssh session is pushed to XCom under key ``ssh_exit``.
@@ -91,6 +94,7 @@ class SSHOperator(BaseOperator):
         environment: dict | None = None,
         get_pty: bool = False,
         banner_timeout: float = 30.0,
+        skip_on_exit_code: int | Container[int] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -106,6 +110,13 @@ class SSHOperator(BaseOperator):
         self.environment = environment
         self.get_pty = get_pty
         self.banner_timeout = banner_timeout
+        self.skip_on_exit_code = (
+            skip_on_exit_code
+            if isinstance(skip_on_exit_code, Container)
+            else [skip_on_exit_code]
+            if skip_on_exit_code
+            else []
+        )
 
     @cached_property
     def ssh_hook(self) -> SSHHook:
@@ -141,7 +152,7 @@ class SSHOperator(BaseOperator):
         self.log.info("Creating ssh_client")
         return self.hook.get_conn()
 
-    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
+    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> tuple[int, bytes, bytes]:
         warnings.warn(
             "exec_ssh_client_command method on SSHOperator is deprecated, call "
             "`ssh_hook.exec_ssh_client_command` instead",
@@ -156,6 +167,8 @@ class SSHOperator(BaseOperator):
         if context and self.do_xcom_push:
             ti = context.get("task_instance")
             ti.xcom_push(key="ssh_exit", value=exit_status)
+        if exit_status in self.skip_on_exit_code:
+            raise AirflowSkipException(f"SSH command returned exit code {exit_status}. Skipping.")
         if exit_status != 0:
             raise AirflowException(f"SSH operator error: exit status = {exit_status}")
 
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 1467b73e43..241f3c8c7b 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -23,7 +23,7 @@ from unittest import mock
 import pytest
 from paramiko.client import SSHClient
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models import TaskInstance
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.providers.ssh.operators.ssh import SSHOperator
@@ -203,6 +203,39 @@ class TestSSHOperator:
         self.hook.get_conn.assert_called_once()
         self.hook.get_conn.return_value.__exit__.assert_called_once()
 
+    @pytest.mark.parametrize(
+        "extra_kwargs, actual_exit_code, expected_exc",
+        [
+            ({}, 0, None),
+            ({}, 100, AirflowException),
+            ({"skip_on_exit_code": None}, 0, None),
+            ({"skip_on_exit_code": None}, 100, AirflowException),
+            ({"skip_on_exit_code": 100}, 100, AirflowSkipException),
+            ({"skip_on_exit_code": 100}, 101, AirflowException),
+            ({"skip_on_exit_code": [100]}, 100, AirflowSkipException),
+            ({"skip_on_exit_code": [100]}, 101, AirflowException),
+            ({"skip_on_exit_code": [100, 102]}, 101, AirflowException),
+            ({"skip_on_exit_code": (100,)}, 100, AirflowSkipException),
+            ({"skip_on_exit_code": (100,)}, 101, AirflowException),
+        ],
+    )
+    def test_skip(self, extra_kwargs, actual_exit_code, expected_exc):
+        command = "not_a_real_command"
+        self.exec_ssh_client_command.return_value = (actual_exit_code, b"", b"")
+
+        operator = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            command=command,
+            **extra_kwargs,
+        )
+
+        if expected_exc is None:
+            operator.execute({})
+        else:
+            with pytest.raises(expected_exc):
+                operator.execute({})
+
     def test_command_errored(self):
         # Test that run_ssh_client_command works on invalid commands
         command = "not_a_real_command"

Reply via email to