This is an automated email from the ASF dual-hosted git repository. phanikumv pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 88c9596f4a check job_status before BatchOperator execute in deferrable mode (#36523) 88c9596f4a is described below commit 88c9596f4aaff492dda8b0b87fa60ee16444e9b6 Author: Wei Lee <weilee...@gmail.com> AuthorDate: Wed Jan 10 20:38:28 2024 +0800 check job_status before BatchOperator execute in deferrable mode (#36523) --- airflow/providers/amazon/aws/operators/batch.py | 48 +++++++++++------ tests/providers/amazon/aws/operators/test_batch.py | 63 +++++++++++++++++++++- 2 files changed, 94 insertions(+), 17 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index fe6f9dadb6..8a124b4027 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -230,7 +230,7 @@ class BatchOperator(BaseOperator): region_name=self.region_name, ) - def execute(self, context: Context): + def execute(self, context: Context) -> str | None: """Submit and monitor an AWS Batch job. :raises: AirflowException @@ -238,28 +238,46 @@ class BatchOperator(BaseOperator): self.submit_job(context) if self.deferrable: - self.defer( - timeout=self.execution_timeout, - trigger=BatchJobTrigger( - job_id=self.job_id, - waiter_max_attempts=self.max_retries, - aws_conn_id=self.aws_conn_id, - region_name=self.region_name, - waiter_delay=self.poll_interval, - ), - method_name="execute_complete", - ) + if not self.job_id: + raise AirflowException("AWS Batch job - job_id was not found") + + job = self.hook.get_job_description(self.job_id) + job_status = job.get("status") + if job_status == self.hook.SUCCESS_STATE: + self.log.info("Job completed.") + return self.job_id + elif job_status == self.hook.FAILURE_STATE: + raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state") + elif job_status in self.hook.INTERMEDIATE_STATES: + self.defer( + timeout=self.execution_timeout, + trigger=BatchJobTrigger( + job_id=self.job_id, + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + waiter_delay=self.poll_interval, + ), + method_name="execute_complete", + ) + + raise AirflowException(f"Unexpected status: {job_status}") if self.wait_for_completion: self.monitor_job(context) return self.job_id - def execute_complete(self, context, event=None): + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + if event is None: + err_msg = "Trigger error: event is None" + self.log.info(err_msg) + raise AirflowException(err_msg) + if event["status"] != "success": raise AirflowException(f"Error while running job: {event}") - else: - self.log.info("Job completed.") + + self.log.info("Job completed.") return event["job_id"] def on_kill(self): diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 020f071786..313d721b3a 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -268,8 +268,11 @@ class TestBatchOperator: container_overrides={"a": "b"}, ) + @mock.patch.object(BatchClientHook, "get_job_description") @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def test_defer_if_deferrable_param_set(self, mock_client): + def test_defer_if_deferrable_param_set(self, mock_client, mock_get_job_description): + mock_get_job_description.return_value = {"status": "SUBMITTED"} + batch = BatchOperator( task_id="task", job_name=JOB_NAME, @@ -280,9 +283,65 @@ class TestBatchOperator: ) with pytest.raises(TaskDeferred) as exc: - batch.execute(context=None) + batch.execute(self.mock_context) assert isinstance(exc.value.trigger, BatchJobTrigger) + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") + def test_defer_but_failed_due_to_job_id_not_found(self, mock_client): + """Test that an AirflowException is raised if job_id is not set before deferral.""" + mock_client.return_value.submit_job.return_value = { + "jobName": JOB_NAME, + "jobId": None, + } + + batch = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + do_xcom_push=False, + deferrable=True, + ) + with pytest.raises(AirflowException) as exc: + batch.execute(self.mock_context) + assert "AWS Batch job - job_id was not found" in str(exc.value) + + @mock.patch.object(BatchClientHook, "get_job_description") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") + def test_defer_but_success_before_deferred(self, mock_client, mock_get_job_description): + """Test that an AirflowException is raised if job_id is not set before deferral.""" + mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES + mock_get_job_description.return_value = {"status": "SUCCEEDED"} + + batch = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + do_xcom_push=False, + deferrable=True, + ) + assert batch.execute(self.mock_context) == JOB_ID + + @mock.patch.object(BatchClientHook, "get_job_description") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") + def test_defer_but_fail_before_deferred(self, mock_client, mock_get_job_description): + """Test that an AirflowException is raised if job_id is not set before deferral.""" + mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES + mock_get_job_description.return_value = {"status": "FAILED"} + + batch = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + do_xcom_push=False, + deferrable=True, + ) + with pytest.raises(AirflowException) as exc: + batch.execute(self.mock_context) + assert f"Error while running job: {JOB_ID} is in FAILED state" in str(exc.value) + @mock.patch.object(BatchClientHook, "get_job_description") @mock.patch.object(BatchClientHook, "wait_for_job") @mock.patch.object(BatchClientHook, "check_job_success")