o-nikolas commented on code in PR #28827: URL: https://github.com/apache/airflow/pull/28827#discussion_r1066269743
########## airflow/providers/amazon/aws/operators/emr.py: ########## @@ -538,42 +544,76 @@ def __init__( emr_conn_id: str | None = "emr_default", job_flow_overrides: str | dict[str, Any] | None = None, region_name: str | None = None, + wait_for_completion: bool = False, + waiter_countdown: int | None = None, + waiter_check_interval_seconds: int = 60, **kwargs, ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.emr_conn_id = emr_conn_id self.job_flow_overrides = job_flow_overrides or {} self.region_name = region_name + self.wait_for_completion = wait_for_completion + self.waiter_countdown = waiter_countdown + self.waiter_check_interval_seconds = waiter_check_interval_seconds + + self._job_flow_id: str | None = None - def execute(self, context: Context) -> str: - emr = EmrHook( + @cached_property + def _emr_hook(self) -> EmrHook: + """Create and return an EmrHook.""" + return EmrHook( aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name ) + def execute(self, context: Context) -> str | None: self.log.info( - "Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s", self.aws_conn_id, self.emr_conn_id + "Creating job flow using aws_conn_id: %s, emr_conn_id: %s", self.aws_conn_id, self.emr_conn_id ) if isinstance(self.job_flow_overrides, str): job_flow_overrides: dict[str, Any] = ast.literal_eval(self.job_flow_overrides) self.job_flow_overrides = job_flow_overrides else: job_flow_overrides = self.job_flow_overrides - response = emr.create_job_flow(job_flow_overrides) + response = self._emr_hook.create_job_flow(job_flow_overrides) if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: - raise AirflowException(f"JobFlow creation failed: {response}") + raise AirflowException(f"Job flow creation failed: {response}") else: - job_flow_id = response["JobFlowId"] - self.log.info("JobFlow with id %s created", job_flow_id) + self._job_flow_id = response["JobFlowId"] + self.log.info("Job flow with id %s created", 
self._job_flow_id) EmrClusterLink.persist( context=context, operator=self, - region_name=emr.conn_region_name, - aws_partition=emr.conn_partition, - job_flow_id=job_flow_id, + region_name=self._emr_hook.conn_region_name, + aws_partition=self._emr_hook.conn_partition, + job_flow_id=self._job_flow_id, ) - return job_flow_id + + if self.wait_for_completion: + # Didn't use a boto-supplied waiter because those don't support waiting for WAITING state. + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#waiters + waiter( + get_state_callable=self._emr_hook.get_conn().describe_cluster, Review Comment: Glad to see someone using this already to create new custom waiters! :star_struck: ########## airflow/providers/amazon/aws/utils/waiter.py: ########## @@ -60,14 +60,20 @@ def waiter( break if state in failure_states: raise AirflowException(f"{object_type.title()} reached failure state {state}.") - if countdown > check_interval_seconds: - countdown -= check_interval_seconds + + if countdown is None: Review Comment: I like the new feature! Thanks for adding it :smile: If `countdown` is `None` you could set it to `float('inf')` before the while loop and then the rest of the code would function as it was before and you would not need the extra branching statements or the duplicated logs/sleeps (you'd need to add `float` as a possible type for the param as well). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org