This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new caa401da794 Add support for timeout to BatchOperator (#45660)
caa401da794 is described below

commit caa401da7948057cff056c67f272307f1f5c7f4f
Author: Nate Robinson <[email protected]>
AuthorDate: Fri Jan 17 16:56:06 2025 -0500

    Add support for timeout to BatchOperator (#45660)
    
    An execution timeout (attemptDurationSeconds) for the submitted Batch job
    can now be passed through the operator to the boto3 submit_job API call.
---
 providers/src/airflow/providers/amazon/aws/operators/batch.py | 6 ++++++
 providers/tests/amazon/aws/operators/test_batch.py            | 7 +++++++
 2 files changed, 13 insertions(+)

diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py 
b/providers/src/airflow/providers/amazon/aws/operators/batch.py
index 3df00fb04c3..e69508d8931 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/batch.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py
@@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
         If it is an array job, only the logs of the first task will be printed.
     :param awslogs_fetch_interval: The interval with which cloudwatch logs are 
to be fetched, 30 sec.
     :param poll_interval: (Deferrable mode only) Time in seconds to wait 
between polling.
+    :param submit_job_timeout: Execution timeout in seconds for submitted 
batch job.
 
     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -184,6 +185,7 @@ class BatchOperator(BaseOperator):
         poll_interval: int = 30,
         awslogs_enabled: bool = False,
         awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+        submit_job_timeout: int | None = None,
         **kwargs,
     ) -> None:
         BaseOperator.__init__(self, **kwargs)
@@ -208,6 +210,7 @@ class BatchOperator(BaseOperator):
         self.poll_interval = poll_interval
         self.awslogs_enabled = awslogs_enabled
         self.awslogs_fetch_interval = awslogs_fetch_interval
+        self.submit_job_timeout = submit_job_timeout
 
         # params for hook
         self.max_retries = max_retries
@@ -315,6 +318,9 @@ class BatchOperator(BaseOperator):
             "schedulingPriorityOverride": self.scheduling_priority_override,
         }
 
+        if self.submit_job_timeout:
+            args["timeout"] = {"attemptDurationSeconds": 
self.submit_job_timeout}
+
         try:
             response = self.hook.client.submit_job(**trim_none_values(args))
         except Exception as e:
diff --git a/providers/tests/amazon/aws/operators/test_batch.py 
b/providers/tests/amazon/aws/operators/test_batch.py
index 0c14c256edb..c1b1d847b7d 100644
--- a/providers/tests/amazon/aws/operators/test_batch.py
+++ b/providers/tests/amazon/aws/operators/test_batch.py
@@ -70,6 +70,7 @@ class TestBatchOperator:
             aws_conn_id="airflow_test",
             region_name="eu-west-1",
             tags={},
+            submit_job_timeout=3600,
         )
         self.client_mock = self.get_client_type_mock.return_value
         # We're mocking all actual AWS calls and don't need a connection. This
@@ -109,6 +110,7 @@ class TestBatchOperator:
         assert self.batch.hook.client == self.client_mock
         assert self.batch.tags == {}
         assert self.batch.wait_for_completion is True
+        assert self.batch.submit_job_timeout == 3600
 
         
self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1")
 
@@ -141,6 +143,7 @@ class TestBatchOperator:
         assert issubclass(type(batch_job.hook.client), 
botocore.client.BaseClient)
         assert batch_job.tags == {}
         assert batch_job.wait_for_completion is True
+        assert batch_job.submit_job_timeout is None
 
     def test_template_fields_overrides(self):
         assert self.batch.template_fields == (
@@ -181,6 +184,7 @@ class TestBatchOperator:
             parameters={},
             retryStrategy={"attempts": 1},
             tags={},
+            timeout={"attemptDurationSeconds": 3600},
         )
 
         assert self.batch.job_id == JOB_ID
@@ -205,6 +209,7 @@ class TestBatchOperator:
             parameters={},
             retryStrategy={"attempts": 1},
             tags={},
+            timeout={"attemptDurationSeconds": 3600},
         )
 
     @mock.patch.object(BatchClientHook, "get_job_description")
@@ -261,6 +266,7 @@ class TestBatchOperator:
             parameters={},
             retryStrategy={"attempts": 1},
             tags={},
+            timeout={"attemptDurationSeconds": 3600},
         )
 
     @mock.patch.object(BatchClientHook, "get_job_description")
@@ -359,6 +365,7 @@ class TestBatchOperator:
             parameters={},
             retryStrategy={"attempts": 1},
             tags={},
+            timeout={"attemptDurationSeconds": 3600},
         )
 
     @mock.patch.object(BatchClientHook, "check_job_success")

Reply via email to