This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 27b5f696a4 Add Deferrable mode for EMR Serverless Start Job Operator (#32534) 27b5f696a4 is described below commit 27b5f696a48a088a23294c542acb46bd6e544809 Author: Syed Hussain <103602455+syeda...@users.noreply.github.com> AuthorDate: Thu Jul 20 07:15:58 2023 -0700 Add Deferrable mode for EMR Serverless Start Job Operator (#32534) * Add Deferrable mode for EMR Serverless Start Job Operator --- airflow/providers/amazon/aws/operators/emr.py | 52 +++++++++++++-- airflow/providers/amazon/aws/triggers/emr.py | 74 +++++++++++++++++++++- .../operators/emr/emr_serverless.rst | 2 + .../amazon/aws/operators/test_emr_serverless.py | 41 +++++++++++- .../amazon/aws/triggers/test_emr_serverless.py | 63 ++++++++++++++++++ 5 files changed, 225 insertions(+), 7 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 4cb070da73..c01dbbbe91 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -33,6 +33,8 @@ from airflow.providers.amazon.aws.triggers.emr import ( EmrAddStepsTrigger, EmrContainerTrigger, EmrCreateJobFlowTrigger, + EmrServerlessStartApplicationTrigger, + EmrServerlessStartJobTrigger, EmrTerminateJobFlowTrigger, ) from airflow.providers.amazon.aws.utils.waiter import waiter @@ -1107,6 +1109,9 @@ class EmrServerlessStartJobOperator(BaseOperator): :waiter_max_attempts: Number of times the waiter should poll the application to check the state. If not set, the waiter will use its default value. :param waiter_delay: Number of seconds between polling the state of the job run. + :param deferrable: If True, the operator will wait asynchronously for the job run to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False, but can be overridden in config file by setting default_deferrable to True) """ template_fields: Sequence[str] = ( @@ -1137,6 +1142,7 @@ class EmrServerlessStartJobOperator(BaseOperator): waiter_check_interval_seconds: int | ArgNotSet = NOTSET, waiter_max_attempts: int | ArgNotSet = NOTSET, waiter_delay: int | ArgNotSet = NOTSET, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): if waiter_check_interval_seconds is NOTSET: @@ -1171,6 +1177,7 @@ class EmrServerlessStartJobOperator(BaseOperator): self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type] self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] self.job_id: str | None = None + self.deferrable = deferrable super().__init__(**kwargs) self.client_request_token = client_request_token or str(uuid4()) @@ -1180,14 +1187,25 @@ class EmrServerlessStartJobOperator(BaseOperator): """Create and return an EmrServerlessHook.""" return EmrServerlessHook(aws_conn_id=self.aws_conn_id) - def execute(self, context: Context) -> str | None: - self.log.info("Starting job on Application: %s", self.application_id) + def execute(self, context: Context, event: dict[str, Any] | None = None) -> str | None: app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"] if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES: + self.log.info("Application state is %s", app_state) + self.log.info("Starting application %s", self.application_id) self.hook.conn.start_application(applicationId=self.application_id) waiter = self.hook.get_waiter("serverless_app_started") - + if self.deferrable: + self.defer( + trigger=EmrServerlessStartApplicationTrigger( + application_id=self.application_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute", + timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay), + ) wait( waiter=waiter, waiter_max_attempts=self.waiter_max_attempts, @@ -1197,7 +1215,7 @@ class EmrServerlessStartJobOperator(BaseOperator): status_message="Serverless Application status is", status_args=["application.state", "application.stateDetails"], ) - + self.log.info("Starting job on Application: %s", self.application_id) response = self.hook.conn.start_job_run( clientToken=self.client_request_token, applicationId=self.application_id, @@ -1213,6 +1231,18 @@ class EmrServerlessStartJobOperator(BaseOperator): self.job_id = response["jobRunId"] self.log.info("EMR serverless job started: %s", self.job_id) + if self.deferrable: + self.defer( + trigger=EmrServerlessStartJobTrigger( + application_id=self.application_id, + job_id=self.job_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) if self.wait_for_completion: waiter = self.hook.get_waiter("serverless_job_completed") wait( @@ -1227,8 +1257,20 @@ class EmrServerlessStartJobOperator(BaseOperator): return self.job_id + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None: + self.log.error("Trigger error: event is None") + raise AirflowException("Trigger error: event is None") + elif event["status"] == "success": + self.log.info("Serverless job completed") + return event["job_id"] + def on_kill(self) -> None: - """Cancel the submitted job run.""" + """ + Cancel the submitted job run. + + Note: this method will not run in deferrable mode. 
+ """ if self.job_id: self.log.info("Stopping job run with jobId - %s", self.job_id) response = self.hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id) diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index 7deadc1f37..471a8a747a 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -283,3 +283,75 @@ class EmrStepSensorTrigger(AwsBaseWaiterTrigger): def hook(self) -> AwsGenericHook: return EmrHook(self.aws_conn_id) + + +class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger): + """ + Poll an Emr Serverless application and wait for it to be started. + + :param application_id: The ID of the application being polled. 
+ :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made + :param aws_conn_id: Reference to AWS connection id + """ + + def __init__( + self, + application_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + aws_conn_id: str = "aws_default", + ): + super().__init__( + serialized_fields={"application_id": application_id}, + waiter_name="serverless_app_started", + waiter_args={"applicationId": application_id}, + failure_message="Application failed to start", + status_message="Application status is", + status_queries=["application.state", "application.stateDetails"], + return_key="application_id", + return_value=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return EmrServerlessHook(self.aws_conn_id) + + +class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger): + """ + Poll an Emr Serverless job run and wait for it to be completed. + + :param application_id: The ID of the application the job is being run on. + :param job_id: The ID of the job run. 
+ :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made + :param aws_conn_id: Reference to AWS connection id + """ + + def __init__( + self, + application_id: str, + job_id: str | None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + aws_conn_id: str = "aws_default", + ) -> None: + super().__init__( + serialized_fields={"application_id": application_id, "job_id": job_id}, + waiter_name="serverless_job_completed", + waiter_args={"applicationId": application_id, "jobRunId": job_id}, + failure_message="Serverless Job failed", + status_message="Serverless Job status is", + status_queries=["jobRun.state", "jobRun.stateDetails"], + return_key="job_id", + return_value=job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return EmrServerlessHook(self.aws_conn_id) diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst index 4f74690da0..76638815b8 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst @@ -54,6 +54,8 @@ Start an EMR Serverless Job You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStartJobOperator` to start an EMR Serverless Job. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires +the aiobotocore module to be installed. .. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr_serverless.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 8cb4eb1707..2d9f37830f 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -23,7 +23,7 @@ from uuid import UUID import pytest from botocore.exceptions import WaiterError -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.operators.emr import ( EmrServerlessCreateApplicationOperator, @@ -730,6 +730,45 @@ class TestEmrServerlessStartJobOperator: assert operator.waiter_delay == expected[0] assert operator.waiter_max_attempts == expected[1] + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_deferrable(self, mock_conn): + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + operator.execute(None) + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter): + mock_get_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} + mock_conn.start_application.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + operator = EmrServerlessStartJobOperator( + 
task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + operator.execute(None) + class TestEmrServerlessDeleteOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") diff --git a/tests/providers/amazon/aws/triggers/test_emr_serverless.py b/tests/providers/amazon/aws/triggers/test_emr_serverless.py new file mode 100644 index 0000000000..029fc4ccbf --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr_serverless.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.triggers.emr import ( + EmrServerlessStartApplicationTrigger, + EmrServerlessStartJobTrigger, +) + +TEST_APPLICATION_ID = "test-application-id" +TEST_WAITER_DELAY = 10 +TEST_WAITER_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-aws-id" +AWS_CONN_ID = "aws_emr_conn" +TEST_JOB_ID = "test-job-id" + + +class TestEmrTriggers: + @pytest.mark.parametrize( + "trigger", + [ + EmrServerlessStartApplicationTrigger( + application_id=TEST_APPLICATION_ID, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + ), + EmrServerlessStartJobTrigger( + application_id=TEST_APPLICATION_ID, + job_id=TEST_JOB_ID, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2