This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 88c9596f4a check job_status before BatchOperator execute in deferrable mode (#36523)
88c9596f4a is described below

commit 88c9596f4aaff492dda8b0b87fa60ee16444e9b6
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Wed Jan 10 20:38:28 2024 +0800

    check job_status before BatchOperator execute in deferrable mode (#36523)
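
    In short: with deferrable=True, BatchOperator now checks the job's status right
    after submission and only defers to BatchJobTrigger while the job is still in an
    intermediate state; an already-succeeded job returns its job_id immediately, and an
    already-failed job raises AirflowException without starting a trigger. A minimal
    usage sketch follows; the task_id, job_name, job_queue, and job_definition values
    are illustrative placeholders, and treating poll_interval and max_retries as
    constructor arguments is an assumption based on the self.poll_interval /
    self.max_retries attributes used in the diff below:

        from airflow.providers.amazon.aws.operators.batch import BatchOperator

        # Illustrative deferrable task; names and numbers are placeholders, not from this commit.
        submit_batch_job = BatchOperator(
            task_id="submit_batch_job",
            job_name="example-job",
            job_queue="example-queue",
            job_definition="example-definition",
            deferrable=True,     # enables the pre-deferral status check added below
            poll_interval=30,    # mapped onto the trigger's waiter_delay
            max_retries=100,     # mapped onto the trigger's waiter_max_attempts
        )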
---
 airflow/providers/amazon/aws/operators/batch.py    | 48 +++++++++++------
 tests/providers/amazon/aws/operators/test_batch.py | 63 +++++++++++++++++++++-
 2 files changed, 94 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index fe6f9dadb6..8a124b4027 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -230,7 +230,7 @@ class BatchOperator(BaseOperator):
             region_name=self.region_name,
         )
 
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str | None:
         """Submit and monitor an AWS Batch job.
 
         :raises: AirflowException
@@ -238,28 +238,46 @@ class BatchOperator(BaseOperator):
         self.submit_job(context)
 
         if self.deferrable:
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=BatchJobTrigger(
-                    job_id=self.job_id,
-                    waiter_max_attempts=self.max_retries,
-                    aws_conn_id=self.aws_conn_id,
-                    region_name=self.region_name,
-                    waiter_delay=self.poll_interval,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.job_id:
+                raise AirflowException("AWS Batch job - job_id was not found")
+
+            job = self.hook.get_job_description(self.job_id)
+            job_status = job.get("status")
+            if job_status == self.hook.SUCCESS_STATE:
+                self.log.info("Job completed.")
+                return self.job_id
+            elif job_status == self.hook.FAILURE_STATE:
+                raise AirflowException(f"Error while running job: 
{self.job_id} is in {job_status} state")
+            elif job_status in self.hook.INTERMEDIATE_STATES:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=BatchJobTrigger(
+                        job_id=self.job_id,
+                        waiter_max_attempts=self.max_retries,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        waiter_delay=self.poll_interval,
+                    ),
+                    method_name="execute_complete",
+                )
+
+            raise AirflowException(f"Unexpected status: {job_status}")
 
         if self.wait_for_completion:
             self.monitor_job(context)
 
         return self.job_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.info(err_msg)
+            raise AirflowException(err_msg)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")
         return event["job_id"]
 
     def on_kill(self):
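
Condensed view of the new deferrable path in execute() above, as a standalone sketch rather than a drop-in replacement: hook, job_id, and defer_to_trigger stand in for the operator's self.hook, self.job_id, and the self.defer(...) call; the SUBMITTED/SUCCEEDED/FAILED status strings match the tests below, while the exact membership of INTERMEDIATE_STATES is defined by BatchClientHook.

    from airflow.exceptions import AirflowException

    def check_before_deferral(hook, job_id, defer_to_trigger):
        # Look up the current status before handing the job off to a trigger.
        job_status = hook.get_job_description(job_id).get("status")
        if job_status == hook.SUCCESS_STATE:          # "SUCCEEDED": nothing left to wait for
            return job_id
        if job_status == hook.FAILURE_STATE:          # "FAILED": fail fast, no trigger started
            raise AirflowException(f"Error while running job: {job_id} is in {job_status} state")
        if job_status in hook.INTERMEDIATE_STATES:    # e.g. "SUBMITTED": hand off to the trigger
            return defer_to_trigger(job_id)           # stands in for self.defer(trigger=BatchJobTrigger(...))
        raise AirflowException(f"Unexpected status: {job_status}")

In the operator itself, self.defer(...) raises TaskDeferred, so for intermediate states execution never reaches the final "Unexpected status" raise.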
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 020f071786..313d721b3a 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -268,8 +268,11 @@ class TestBatchOperator:
                 container_overrides={"a": "b"},
             )
 
+    @mock.patch.object(BatchClientHook, "get_job_description")
     @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def test_defer_if_deferrable_param_set(self, mock_client):
+    def test_defer_if_deferrable_param_set(self, mock_client, mock_get_job_description):
+        mock_get_job_description.return_value = {"status": "SUBMITTED"}
+
         batch = BatchOperator(
             task_id="task",
             job_name=JOB_NAME,
@@ -280,9 +283,65 @@ class TestBatchOperator:
         )
 
         with pytest.raises(TaskDeferred) as exc:
-            batch.execute(context=None)
+            batch.execute(self.mock_context)
         assert isinstance(exc.value.trigger, BatchJobTrigger)
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_defer_but_failed_due_to_job_id_not_found(self, mock_client):
+        """Test that an AirflowException is raised if job_id is not set before 
deferral."""
+        mock_client.return_value.submit_job.return_value = {
+            "jobName": JOB_NAME,
+            "jobId": None,
+        }
+
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException) as exc:
+            batch.execute(self.mock_context)
+        assert "AWS Batch job - job_id was not found" in str(exc.value)
+
+    @mock.patch.object(BatchClientHook, "get_job_description")
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_defer_but_success_before_deferred(self, mock_client, mock_get_job_description):
+        """Test that an AirflowException is raised if job_id is not set before 
deferral."""
+        mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
+        mock_get_job_description.return_value = {"status": "SUCCEEDED"}
+
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        assert batch.execute(self.mock_context) == JOB_ID
+
+    @mock.patch.object(BatchClientHook, "get_job_description")
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_defer_but_fail_before_deferred(self, mock_client, mock_get_job_description):
+        """Test that an AirflowException is raised if job_id is not set before 
deferral."""
+        mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
+        mock_get_job_description.return_value = {"status": "FAILED"}
+
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException) as exc:
+            batch.execute(self.mock_context)
+        assert f"Error while running job: {JOB_ID} is in FAILED state" in 
str(exc.value)
+
     @mock.patch.object(BatchClientHook, "get_job_description")
     @mock.patch.object(BatchClientHook, "wait_for_job")
     @mock.patch.object(BatchClientHook, "check_job_success")
