This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73cd0d5f8e Add skip_on_exit_code to SSHOperator (#36303)
73cd0d5f8e is described below
commit 73cd0d5f8e99f93d2ff90a7cca04e6cdf086c359
Author: Maxim Martynov <[email protected]>
AuthorDate: Thu Dec 21 03:37:16 2023 +0300
Add skip_on_exit_code to SSHOperator (#36303)
---
airflow/providers/ssh/operators/ssh.py | 19 ++++++++++++++---
tests/providers/ssh/operators/test_ssh.py | 35 ++++++++++++++++++++++++++++++-
2 files changed, 50 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 076d5d9f91..621dc47539 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -20,12 +20,12 @@ from __future__ import annotations
import warnings
from base64 import b64encode
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Container, Sequence
from deprecated.classic import deprecated
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.models import BaseOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet
@@ -60,6 +60,9 @@ class SSHOperator(BaseOperator):
The default is ``False`` but note that `get_pty` is forced to ``True``
when the `command` starts with ``sudo``.
:param banner_timeout: timeout to wait for banner from the server in seconds
+ :param skip_on_exit_code: If command exits with this exit code, leave the task
+ in ``skipped`` state (default: None). If set to ``None``, any non-zero
+ exit code will be treated as a failure.
If *do_xcom_push* is *True*, the numeric exit code emitted by
the ssh session is pushed to XCom under key ``ssh_exit``.
@@ -91,6 +94,7 @@ class SSHOperator(BaseOperator):
environment: dict | None = None,
get_pty: bool = False,
banner_timeout: float = 30.0,
+ skip_on_exit_code: int | Container[int] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -106,6 +110,13 @@ class SSHOperator(BaseOperator):
self.environment = environment
self.get_pty = get_pty
self.banner_timeout = banner_timeout
+ self.skip_on_exit_code = (
+ skip_on_exit_code
+ if isinstance(skip_on_exit_code, Container)
+ else [skip_on_exit_code]
+ if skip_on_exit_code
+ else []
+ )
@cached_property
def ssh_hook(self) -> SSHHook:
@@ -141,7 +152,7 @@ class SSHOperator(BaseOperator):
self.log.info("Creating ssh_client")
return self.hook.get_conn()
- def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
+ def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> tuple[int, bytes, bytes]:
warnings.warn(
"exec_ssh_client_command method on SSHOperator is deprecated, call "
"`ssh_hook.exec_ssh_client_command` instead",
@@ -156,6 +167,8 @@ class SSHOperator(BaseOperator):
if context and self.do_xcom_push:
ti = context.get("task_instance")
ti.xcom_push(key="ssh_exit", value=exit_status)
+ if exit_status in self.skip_on_exit_code:
+ raise AirflowSkipException(f"SSH command returned exit code {exit_status}. Skipping.")
if exit_status != 0:
raise AirflowException(f"SSH operator error: exit status = {exit_status}")
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 1467b73e43..241f3c8c7b 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -23,7 +23,7 @@ from unittest import mock
import pytest
from paramiko.client import SSHClient
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import TaskInstance
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
@@ -203,6 +203,39 @@ class TestSSHOperator:
self.hook.get_conn.assert_called_once()
self.hook.get_conn.return_value.__exit__.assert_called_once()
+ @pytest.mark.parametrize(
+ "extra_kwargs, actual_exit_code, expected_exc",
+ [
+ ({}, 0, None),
+ ({}, 100, AirflowException),
+ ({"skip_on_exit_code": None}, 0, None),
+ ({"skip_on_exit_code": None}, 100, AirflowException),
+ ({"skip_on_exit_code": 100}, 100, AirflowSkipException),
+ ({"skip_on_exit_code": 100}, 101, AirflowException),
+ ({"skip_on_exit_code": [100]}, 100, AirflowSkipException),
+ ({"skip_on_exit_code": [100]}, 101, AirflowException),
+ ({"skip_on_exit_code": [100, 102]}, 101, AirflowException),
+ ({"skip_on_exit_code": (100,)}, 100, AirflowSkipException),
+ ({"skip_on_exit_code": (100,)}, 101, AirflowException),
+ ],
+ )
+ def test_skip(self, extra_kwargs, actual_exit_code, expected_exc):
+ command = "not_a_real_command"
+ self.exec_ssh_client_command.return_value = (actual_exit_code, b"", b"")
+
+ operator = SSHOperator(
+ task_id="test",
+ ssh_hook=self.hook,
+ command=command,
+ **extra_kwargs,
+ )
+
+ if expected_exc is None:
+ operator.execute({})
+ else:
+ with pytest.raises(expected_exc):
+ operator.execute({})
+
def test_command_errored(self):
# Test that run_ssh_client_command works on invalid commands
command = "not_a_real_command"