This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 77a6b4f419 Add args to docker service ContainerSpec (#39464)
77a6b4f419 is described below

commit 77a6b4f41917e05009c34fd9be9b7fa4192a11bc
Author: Guy Driesen <guydrie...@users.noreply.github.com>
AuthorDate: Sat May 18 00:40:42 2024 +0200

    Add args to docker service ContainerSpec (#39464)
    
    * Add args to docker service ContainerSpec
    
    * args is a list in ContainerSpec
    
    * fix ContainerSpec assertion
    
    * fix args formatter
    
    * fix ContainerSpec assert
    
    * remove some spaces
    
    * add docker service args list test case
    
    * replace ast.literal_eval with json.loads
    
    * remove json string representation
    
    ---------
    
    Co-authored-by: Guy Driesen <19373791+guydrie...@users.noreply.github.com>
---
 airflow/providers/docker/operators/docker_swarm.py | 19 ++++++
 .../docker/operators/test_docker_swarm.py          | 77 ++++++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/airflow/providers/docker/operators/docker_swarm.py b/airflow/providers/docker/operators/docker_swarm.py
index b9fc6c89a7..a05bfdc897 100644
--- a/airflow/providers/docker/operators/docker_swarm.py
+++ b/airflow/providers/docker/operators/docker_swarm.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import re
+import shlex
 from datetime import datetime
 from time import sleep
 from typing import TYPE_CHECKING
@@ -58,6 +59,7 @@ class DockerSwarmOperator(DockerOperator):
         container's process exits.
         The default is False.
     :param command: Command to be run in the container. (templated)
+    :param args: Arguments to the command.
     :param docker_url: URL of the host running the docker daemon.
         Default is the value of the ``DOCKER_HOST`` environment variable or unix://var/run/docker.sock
         if it is unset.
@@ -106,6 +108,7 @@ class DockerSwarmOperator(DockerOperator):
         self,
         *,
         image: str,
+        args: str | list[str] | None = None,
         enable_logging: bool = True,
         configs: list[types.ConfigReference] | None = None,
         secrets: list[types.SecretReference] | None = None,
@@ -116,6 +119,7 @@ class DockerSwarmOperator(DockerOperator):
         **kwargs,
     ) -> None:
         super().__init__(image=image, **kwargs)
+        self.args = args
         self.enable_logging = enable_logging
         self.service = None
         self.configs = configs
@@ -136,6 +140,7 @@ class DockerSwarmOperator(DockerOperator):
                 container_spec=types.ContainerSpec(
                     image=self.image,
                     command=self.format_command(self.command),
+                    args=self.format_args(self.args),
                     mounts=self.mounts,
                     env=self.environment,
                     user=self.user,
@@ -225,6 +230,20 @@ class DockerSwarmOperator(DockerOperator):
             sleep(2)
             last_line_logged, last_timestamp = 
stream_new_logs(last_line_logged, since=last_timestamp)
 
+    @staticmethod
+    def format_args(args: list[str] | str | None) -> list[str] | None:
+        """Retrieve args.
+
+        The args string is parsed to a list.
+
+        :param args: args to the docker service
+
+        :return: the args as list
+        """
+        if isinstance(args, str):
+            return shlex.split(args)
+        return args
+
     def on_kill(self) -> None:
         if self.hook.client_created and self.service is not None:
             self.log.info("Removing docker service: %s", self.service["ID"])
diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py
index 5576eec083..29661123d5 100644
--- a/tests/providers/docker/operators/test_docker_swarm.py
+++ b/tests/providers/docker/operators/test_docker_swarm.py
@@ -84,6 +84,7 @@ class TestDockerSwarmOperator:
         types_mock.ContainerSpec.assert_called_once_with(
             image="ubuntu:latest",
             command="env",
+            args=None,
             user="unittest",
             mounts=[types.Mount(source="/host/path", target="/container/path", 
type="bind")],
             tty=True,
@@ -254,3 +255,79 @@ class TestDockerSwarmOperator:
             placement=None,
         )
         types_mock.Resources.assert_not_called()
+
+    @mock.patch("airflow.providers.docker.operators.docker_swarm.types")
+    def test_service_args_str(self, types_mock, docker_api_client_patcher):
+        mock_obj = mock.Mock()
+
+        client_mock = mock.Mock(spec=APIClient)
+        client_mock.create_service.return_value = {"ID": "some_id"}
+        client_mock.images.return_value = []
+        client_mock.pull.return_value = [b'{"status":"pull log"}']
+        client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
+        types_mock.TaskTemplate.return_value = mock_obj
+        types_mock.ContainerSpec.return_value = mock_obj
+        types_mock.RestartPolicy.return_value = mock_obj
+        types_mock.Resources.return_value = mock_obj
+
+        docker_api_client_patcher.return_value = client_mock
+
+        operator = DockerSwarmOperator(
+            image="ubuntu:latest",
+            command="env",
+            args="--show",
+            task_id="unittest",
+            auto_remove="success",
+            enable_logging=False,
+        )
+        operator.execute(None)
+
+        types_mock.ContainerSpec.assert_called_once_with(
+            image="ubuntu:latest",
+            command="env",
+            args=["--show"],
+            user=None,
+            mounts=[],
+            tty=False,
+            env={"AIRFLOW_TMP_DIR": "/tmp/airflow"},
+            configs=None,
+            secrets=None,
+        )
+
+    @mock.patch("airflow.providers.docker.operators.docker_swarm.types")
+    def test_service_args_list(self, types_mock, docker_api_client_patcher):
+        mock_obj = mock.Mock()
+
+        client_mock = mock.Mock(spec=APIClient)
+        client_mock.create_service.return_value = {"ID": "some_id"}
+        client_mock.images.return_value = []
+        client_mock.pull.return_value = [b'{"status":"pull log"}']
+        client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
+        types_mock.TaskTemplate.return_value = mock_obj
+        types_mock.ContainerSpec.return_value = mock_obj
+        types_mock.RestartPolicy.return_value = mock_obj
+        types_mock.Resources.return_value = mock_obj
+
+        docker_api_client_patcher.return_value = client_mock
+
+        operator = DockerSwarmOperator(
+            image="ubuntu:latest",
+            command="env",
+            args=["--show"],
+            task_id="unittest",
+            auto_remove="success",
+            enable_logging=False,
+        )
+        operator.execute(None)
+
+        types_mock.ContainerSpec.assert_called_once_with(
+            image="ubuntu:latest",
+            command="env",
+            args=["--show"],
+            user=None,
+            mounts=[],
+            tty=False,
+            env={"AIRFLOW_TMP_DIR": "/tmp/airflow"},
+            configs=None,
+            secrets=None,
+        )

Reply via email to