This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a6f959bd5 check sagemaker processing job status before deferring (#36658)
5a6f959bd5 is described below

commit 5a6f959bd5826409a8d15a894edf36d0e76ef77a
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Wed Jan 10 20:52:35 2024 +0800

    check sagemaker processing job status before deferring (#36658)
---
 .../providers/amazon/aws/operators/sagemaker.py    | 15 +++++-
 .../aws/operators/test_sagemaker_processing.py     | 60 ++++++++++++++++++++--
 2 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index e8f5f0880c..1b4ffc45af 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -283,8 +283,20 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Sagemaker Processing Job creation failed: 
{response}")
 
         if self.deferrable and self.wait_for_completion:
+            response = 
self.hook.describe_processing_job(self.config["ProcessingJobName"])
+            status = response["ProcessingJobStatus"]
+            if status in self.hook.failed_states:
+                raise AirflowException(f"SageMaker job failed because 
{response['FailureReason']}")
+            elif status == "Completed":
+                self.log.info("%s completed successfully.", self.task_id)
+                return {"Processing": serialize(response)}
+
+            timeout = self.execution_timeout
+            if self.max_ingestion_time:
+                timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
             self.defer(
-                timeout=self.execution_timeout,
+                timeout=timeout,
                 trigger=SageMakerTrigger(
                     job_name=self.config["ProcessingJobName"],
                     job_type="Processing",
@@ -304,6 +316,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         else:
             self.log.info(event["message"])
         self.serialized_job = 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}
 
     def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 0135ba13fe..3a9c9c21f1 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -101,6 +101,10 @@ class TestSageMakerProcessingOperator:
             check_interval=5,
         )
 
+        self.defer_processing_config_kwargs = dict(
+            task_id="test_sagemaker_operator", wait_for_completion=True, 
check_interval=5, deferrable=True
+        )
+
     @mock.patch.object(SageMakerHook, "describe_processing_job")
     @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", 
return_value=0)
     @mock.patch.object(
@@ -243,6 +247,9 @@ class TestSageMakerProcessingOperator:
                 action_if_job_exists="not_fail_or_increment",
             )
 
+    @mock.patch.object(
+        SageMakerHook, "describe_processing_job", 
return_value={"ProcessingJobStatus": "InProgress"}
+    )
     @mock.patch.object(
         SageMakerHook,
         "create_processing_job",
@@ -252,17 +259,64 @@ class TestSageMakerProcessingOperator:
         },
     )
     @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
-    def test_operator_defer(self, mock_job_exists, mock_processing):
+    def test_operator_defer(self, mock_job_exists, mock_processing, 
mock_describe):
         sagemaker_operator = SageMakerProcessingOperator(
-            **self.processing_config_kwargs,
+            **self.defer_processing_config_kwargs,
             config=CREATE_PROCESSING_PARAMS,
-            deferrable=True,
         )
         sagemaker_operator.wait_for_completion = True
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
 
+    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
+    @mock.patch.object(
+        SageMakerHook, "describe_processing_job", 
return_value={"ProcessingJobStatus": "Completed"}
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": 
{"HTTPStatusCode": 200}},
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_complete_before_defer(
+        self, mock_job_exists, mock_processing, mock_describe, mock_defer
+    ):
+        sagemaker_operator = SageMakerProcessingOperator(
+            **self.defer_processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+        )
+        sagemaker_operator.execute(context=None)
+        assert not mock_defer.called
+
+    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_processing_job",
+        return_value={"ProcessingJobStatus": "Failed", "FailureReason": "It 
failed"},
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": 
{"HTTPStatusCode": 200}},
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_failed_before_defer(
+        self,
+        mock_job_exists,
+        mock_processing,
+        mock_describe,
+        mock_defer,
+    ):
+        sagemaker_operator = SageMakerProcessingOperator(
+            **self.defer_processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+        )
+        with pytest.raises(AirflowException):
+            sagemaker_operator.execute(context=None)
+
+        assert not mock_defer.called
+
     @mock.patch.object(
         SageMakerHook,
         "describe_processing_job",

Reply via email to