This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27b5f696a4 Add Deferrable mode for EMR Serverless Start Job Operator 
(#32534)
27b5f696a4 is described below

commit 27b5f696a48a088a23294c542acb46bd6e544809
Author: Syed Hussain <103602455+syeda...@users.noreply.github.com>
AuthorDate: Thu Jul 20 07:15:58 2023 -0700

    Add Deferrable mode for EMR Serverless Start Job Operator (#32534)
    
    * Add Deferrable mode for EMR Serverless Start Job Operator
---
 airflow/providers/amazon/aws/operators/emr.py      | 52 +++++++++++++--
 airflow/providers/amazon/aws/triggers/emr.py       | 74 +++++++++++++++++++++-
 .../operators/emr/emr_serverless.rst               |  2 +
 .../amazon/aws/operators/test_emr_serverless.py    | 41 +++++++++++-
 .../amazon/aws/triggers/test_emr_serverless.py     | 63 ++++++++++++++++++
 5 files changed, 225 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py 
b/airflow/providers/amazon/aws/operators/emr.py
index 4cb070da73..c01dbbbe91 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -33,6 +33,8 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrAddStepsTrigger,
     EmrContainerTrigger,
     EmrCreateJobFlowTrigger,
+    EmrServerlessStartApplicationTrigger,
+    EmrServerlessStartJobTrigger,
     EmrTerminateJobFlowTrigger,
 )
 from airflow.providers.amazon.aws.utils.waiter import waiter
@@ -1107,6 +1109,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
     :waiter_max_attempts: Number of times the waiter should poll the 
application to check the state.
         If not set, the waiter will use its default value.
     :param waiter_delay: Number of seconds between polling the state of the 
job run.
+    :param deferrable: If True, the operator will wait asynchronously for the 
job run to complete.
+        This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+        (default: False, but can be overridden in config file by setting 
default_deferrable to True)
     """
 
     template_fields: Sequence[str] = (
@@ -1137,6 +1142,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
         waiter_max_attempts: int | ArgNotSet = NOTSET,
         waiter_delay: int | ArgNotSet = NOTSET,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ):
         if waiter_check_interval_seconds is NOTSET:
@@ -1171,6 +1177,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         self.waiter_max_attempts = int(waiter_max_attempts)  # type: 
ignore[arg-type]
         self.waiter_delay = int(waiter_delay)  # type: ignore[arg-type]
         self.job_id: str | None = None
+        self.deferrable = deferrable
         super().__init__(**kwargs)
 
         self.client_request_token = client_request_token or str(uuid4())
@@ -1180,14 +1187,25 @@ class EmrServerlessStartJobOperator(BaseOperator):
         """Create and return an EmrServerlessHook."""
         return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
 
-    def execute(self, context: Context) -> str | None:
-        self.log.info("Starting job on Application: %s", self.application_id)
+    def execute(self, context: Context, event: dict[str, Any] | None = None) 
-> str | None:
 
         app_state = 
self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"]
         if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
+            self.log.info("Application state is %s", app_state)
+            self.log.info("Starting application %s", self.application_id)
             self.hook.conn.start_application(applicationId=self.application_id)
             waiter = self.hook.get_waiter("serverless_app_started")
-
+            if self.deferrable:
+                self.defer(
+                    trigger=EmrServerlessStartApplicationTrigger(
+                        application_id=self.application_id,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
+                )
             wait(
                 waiter=waiter,
                 waiter_max_attempts=self.waiter_max_attempts,
@@ -1197,7 +1215,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
                 status_message="Serverless Application status is",
                 status_args=["application.state", "application.stateDetails"],
             )
-
+        self.log.info("Starting job on Application: %s", self.application_id)
         response = self.hook.conn.start_job_run(
             clientToken=self.client_request_token,
             applicationId=self.application_id,
@@ -1213,6 +1231,18 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         self.job_id = response["jobRunId"]
         self.log.info("EMR serverless job started: %s", self.job_id)
+        if self.deferrable:
+            self.defer(
+                trigger=EmrServerlessStartJobTrigger(
+                    application_id=self.application_id,
+                    job_id=self.job_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
+            )
         if self.wait_for_completion:
             waiter = self.hook.get_waiter("serverless_job_completed")
             wait(
@@ -1227,8 +1257,20 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         return self.job_id
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        elif event["status"] == "success":
+            self.log.info("Serverless job completed")
+            return event["job_id"]
+
     def on_kill(self) -> None:
-        """Cancel the submitted job run."""
+        """
+        Cancel the submitted job run.
+
+        Note: this method will not run in deferrable mode.
+        """
         if self.job_id:
             self.log.info("Stopping job run with jobId - %s", self.job_id)
             response = 
self.hook.conn.cancel_job_run(applicationId=self.application_id, 
jobRunId=self.job_id)
diff --git a/airflow/providers/amazon/aws/triggers/emr.py 
b/airflow/providers/amazon/aws/triggers/emr.py
index 7deadc1f37..471a8a747a 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -283,3 +283,75 @@ class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
 
     def hook(self) -> AwsGenericHook:
         return EmrHook(self.aws_conn_id)
+
+
+class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
+    """
+    Poll an Emr Serverless application and wait for it to be started.
+
+    :param application_id: The ID of the application being polled.
+    :param waiter_delay: Polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made
+    :param aws_conn_id: Reference to AWS connection id
+    """
+
+    def __init__(
+        self,
+        application_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str = "aws_default",
+    ):
+        super().__init__(
+            serialized_fields={"application_id": application_id},
+            waiter_name="serverless_app_started",
+            waiter_args={"applicationId": application_id},
+            failure_message="Application failed to start",
+            status_message="Application status is",
+            status_queries=["application.state", "application.stateDetails"],
+            return_key="application_id",
+            return_value=application_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return EmrServerlessHook(self.aws_conn_id)
+
+
+class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
+    """
+    Poll an Emr Serverless job run and wait for it to be completed.
+
+    :param application_id: The ID of the application the job is being run on.
+    :param job_id: The ID of the job run.
+    :param waiter_delay: Polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made
+    :param aws_conn_id: Reference to AWS connection id
+    """
+
+    def __init__(
+        self,
+        application_id: str,
+        job_id: str | None,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"application_id": application_id, "job_id": 
job_id},
+            waiter_name="serverless_job_completed",
+            waiter_args={"applicationId": application_id, "jobRunId": job_id},
+            failure_message="Serverless Job failed",
+            status_message="Serverless Job status is",
+            status_queries=["jobRun.state", "jobRun.stateDetails"],
+            return_key="job_id",
+            return_value=job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return EmrServerlessHook(self.aws_conn_id)
diff --git 
a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst 
b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
index 4f74690da0..76638815b8 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
@@ -54,6 +54,8 @@ Start an EMR Serverless Job
 
 You can use 
:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStartJobOperator`
 to
 start an EMR Serverless Job.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as 
a parameter. This requires
+the aiobotocore module to be installed.
 
 .. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_emr_serverless.py
    :language: python
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py 
b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 8cb4eb1707..2d9f37830f 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -23,7 +23,7 @@ from uuid import UUID
 import pytest
 from botocore.exceptions import WaiterError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
 from airflow.providers.amazon.aws.operators.emr import (
     EmrServerlessCreateApplicationOperator,
@@ -730,6 +730,45 @@ class TestEmrServerlessStartJobOperator:
         assert operator.waiter_delay == expected[0]
         assert operator.waiter_max_attempts == expected[1]
 
+    @mock.patch.object(EmrServerlessHook, "conn")
+    def test_start_job_deferrable(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": 
"STARTED"}}
+        mock_conn.start_job_run.return_value = {
+            "jobRunId": job_run_id,
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred):
+            operator.execute(None)
+
+    @mock.patch.object(EmrServerlessHook, "get_waiter")
+    @mock.patch.object(EmrServerlessHook, "conn")
+    def test_start_job_deferrable_app_not_started(self, mock_conn, 
mock_get_waiter):
+        mock_get_waiter.return_value = True
+        mock_conn.get_application.return_value = {"application": {"state": 
"CREATING"}}
+        mock_conn.start_application.return_value = {
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred):
+            operator.execute(None)
+
 
 class TestEmrServerlessDeleteOperator:
     @mock.patch.object(EmrServerlessHook, "get_waiter")
diff --git a/tests/providers/amazon/aws/triggers/test_emr_serverless.py 
b/tests/providers/amazon/aws/triggers/test_emr_serverless.py
new file mode 100644
index 0000000000..029fc4ccbf
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_emr_serverless.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.emr import (
+    EmrServerlessStartApplicationTrigger,
+    EmrServerlessStartJobTrigger,
+)
+
+TEST_APPLICATION_ID = "test-application-id"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-aws-id"
+AWS_CONN_ID = "aws_emr_conn"
+TEST_JOB_ID = "test-job-id"
+
+
+class TestEmrTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            EmrServerlessStartApplicationTrigger(
+                application_id=TEST_APPLICATION_ID,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            ),
+            EmrServerlessStartJobTrigger(
+                application_id=TEST_APPLICATION_ID,
+                job_id=TEST_JOB_ID,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            ),
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2

Reply via email to