This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new caa401da794 Add support for timeout to BatchOperator (#45660)
caa401da794 is described below
commit caa401da7948057cff056c67f272307f1f5c7f4f
Author: Nate Robinson <[email protected]>
AuthorDate: Fri Jan 17 16:56:06 2025 -0500
Add support for timeout to BatchOperator (#45660)
An execution timeout for the submitted Batch job can now be passed through
the operator to the boto3 submit_job API call (as timeout.attemptDurationSeconds).
---
providers/src/airflow/providers/amazon/aws/operators/batch.py | 6 ++++++
providers/tests/amazon/aws/operators/test_batch.py | 7 +++++++
2 files changed, 13 insertions(+)
diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py
b/providers/src/airflow/providers/amazon/aws/operators/batch.py
index 3df00fb04c3..e69508d8931 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/batch.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py
@@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
If it is an array job, only the logs of the first task will be printed.
:param awslogs_fetch_interval: The interval with which cloudwatch logs are
to be fetched, 30 sec.
:param poll_interval: (Deferrable mode only) Time in seconds to wait
between polling.
+ :param submit_job_timeout: Execution timeout in seconds for submitted
batch job.
.. note::
Any custom waiters must return a waiter for these calls:
@@ -184,6 +185,7 @@ class BatchOperator(BaseOperator):
poll_interval: int = 30,
awslogs_enabled: bool = False,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+ submit_job_timeout: int | None = None,
**kwargs,
) -> None:
BaseOperator.__init__(self, **kwargs)
@@ -208,6 +210,7 @@ class BatchOperator(BaseOperator):
self.poll_interval = poll_interval
self.awslogs_enabled = awslogs_enabled
self.awslogs_fetch_interval = awslogs_fetch_interval
+ self.submit_job_timeout = submit_job_timeout
# params for hook
self.max_retries = max_retries
@@ -315,6 +318,9 @@ class BatchOperator(BaseOperator):
"schedulingPriorityOverride": self.scheduling_priority_override,
}
+ if self.submit_job_timeout:
+ args["timeout"] = {"attemptDurationSeconds":
self.submit_job_timeout}
+
try:
response = self.hook.client.submit_job(**trim_none_values(args))
except Exception as e:
diff --git a/providers/tests/amazon/aws/operators/test_batch.py
b/providers/tests/amazon/aws/operators/test_batch.py
index 0c14c256edb..c1b1d847b7d 100644
--- a/providers/tests/amazon/aws/operators/test_batch.py
+++ b/providers/tests/amazon/aws/operators/test_batch.py
@@ -70,6 +70,7 @@ class TestBatchOperator:
aws_conn_id="airflow_test",
region_name="eu-west-1",
tags={},
+ submit_job_timeout=3600,
)
self.client_mock = self.get_client_type_mock.return_value
# We're mocking all actual AWS calls and don't need a connection. This
@@ -109,6 +110,7 @@ class TestBatchOperator:
assert self.batch.hook.client == self.client_mock
assert self.batch.tags == {}
assert self.batch.wait_for_completion is True
+ assert self.batch.submit_job_timeout == 3600
self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1")
@@ -141,6 +143,7 @@ class TestBatchOperator:
assert issubclass(type(batch_job.hook.client),
botocore.client.BaseClient)
assert batch_job.tags == {}
assert batch_job.wait_for_completion is True
+ assert batch_job.submit_job_timeout is None
def test_template_fields_overrides(self):
assert self.batch.template_fields == (
@@ -181,6 +184,7 @@ class TestBatchOperator:
parameters={},
retryStrategy={"attempts": 1},
tags={},
+ timeout={"attemptDurationSeconds": 3600},
)
assert self.batch.job_id == JOB_ID
@@ -205,6 +209,7 @@ class TestBatchOperator:
parameters={},
retryStrategy={"attempts": 1},
tags={},
+ timeout={"attemptDurationSeconds": 3600},
)
@mock.patch.object(BatchClientHook, "get_job_description")
@@ -261,6 +266,7 @@ class TestBatchOperator:
parameters={},
retryStrategy={"attempts": 1},
tags={},
+ timeout={"attemptDurationSeconds": 3600},
)
@mock.patch.object(BatchClientHook, "get_job_description")
@@ -359,6 +365,7 @@ class TestBatchOperator:
parameters={},
retryStrategy={"attempts": 1},
tags={},
+ timeout={"attemptDurationSeconds": 3600},
)
@mock.patch.object(BatchClientHook, "check_job_success")