This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
     new b9d677c  Add type hints to aws provider (#11531)
b9d677c is described below

commit b9d677cdd660e0be8278a64658e73359276a9682
Author: Katsunori Kanda <pot...@gmail.com>
AuthorDate: Thu Oct 22 09:49:22 2020 +0900

    Add type hints to aws provider (#11531)

    * Added type hints to aws provider

    * Update airflow/providers/amazon/aws/log/s3_task_handler.py

    * Fix expectation for submit_job

    * Fix documentation

    Co-authored-by: Kamil Breguła <mik-...@users.noreply.github.com>
---
 .../aws/hooks/elasticache_replication_group.py      | 45 +++++-----
 airflow/providers/amazon/aws/hooks/glue.py          |  4 +-
 airflow/providers/amazon/aws/hooks/sagemaker.py     | 66 +++++++--------
 .../amazon/aws/log/cloudwatch_task_handler.py       |  4 +-
 .../providers/amazon/aws/log/s3_task_handler.py     | 11 ++-
 airflow/providers/amazon/aws/operators/batch.py     | 39 +++++----
 .../amazon/aws/operators/cloud_formation.py         |  9 +-
 airflow/providers/amazon/aws/operators/datasync.py  | 97 +++++++++++++---------
 airflow/providers/amazon/aws/operators/ecs.py       | 77 +++++++++--------
 airflow/providers/amazon/aws/operators/glue.py      | 29 +++----
 .../providers/amazon/aws/operators/s3_bucket.py     |  4 +-
 .../amazon/aws/operators/s3_copy_object.py          | 15 ++--
 .../amazon/aws/operators/s3_delete_objects.py       | 12 ++-
 airflow/providers/amazon/aws/operators/s3_list.py   | 13 ++-
 .../amazon/aws/operators/sagemaker_base.py          | 11 ++-
 .../amazon/aws/operators/sagemaker_endpoint.py      | 17 ++--
 .../aws/operators/sagemaker_endpoint_config.py      |  4 +-
 .../amazon/aws/operators/sagemaker_model.py         |  4 +-
 .../amazon/aws/operators/sagemaker_processing.py    | 19 +++--
 .../amazon/aws/operators/sagemaker_training.py      | 15 ++--
 .../amazon/aws/operators/sagemaker_transform.py     | 17 ++--
 .../amazon/aws/operators/sagemaker_tuning.py        | 13 ++-
 airflow/providers/amazon/aws/operators/sns.py       | 11 +--
 airflow/providers/amazon/aws/operators/sqs.py       | 11 +--
 .../step_function_get_execution_output.py           | 10 ++-
 .../aws/operators/step_function_start_execution.py  |  4 +-
 .../amazon/aws/sensors/cloud_formation.py           | 21 +++--
 airflow/providers/amazon/aws/sensors/emr_base.py    | 18 ++--
 airflow/providers/amazon/aws/sensors/glue.py        |  2 +-
 .../amazon/aws/sensors/glue_catalog_partition.py    | 23 ++---
 airflow/providers/amazon/aws/sensors/redshift.py    | 20 +++--
 airflow/providers/amazon/aws/sensors/s3_key.py      | 23 ++---
 airflow/providers/amazon/aws/sensors/s3_prefix.py   | 20 +++--
 .../providers/amazon/aws/sensors/sagemaker_base.py  | 24 +++---
 .../amazon/aws/sensors/sagemaker_training.py        |  7 +-
 .../amazon/aws/sensors/sagemaker_transform.py       |  2 +-
 .../amazon/aws/sensors/sagemaker_tuning.py          |  2 +-
 airflow/providers/amazon/aws/sensors/sqs.py         | 19 +++--
 .../amazon/aws/sensors/step_function_execution.py   | 20 +++--
 tests/providers/amazon/aws/operators/test_batch.py  |  6 +-
 40 files changed, 459 insertions(+), 309 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
index d1c7409..54305d5 100644
--- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
+++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Optional from time import sleep @@ -40,15 +41,21 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): TERMINAL_STATES = frozenset({"available", "create-failed", "deleting"}) def __init__( - self, max_retries=10, exponential_back_off_factor=1, initial_poke_interval=60, *args, **kwargs + self, + max_retries: int = 10, + exponential_back_off_factor: float = 1, + initial_poke_interval: float = 60, + *args, + **kwargs, ): self.max_retries = max_retries self.exponential_back_off_factor = exponential_back_off_factor self.initial_poke_interval = initial_poke_interval - super().__init__(client_type='elasticache', *args, **kwargs) + kwargs["client_type"] = "elasticache" + super().__init__(*args, **kwargs) - def create_replication_group(self, config): + def create_replication_group(self, config: dict) -> dict: """ Call ElastiCache API for creating a replication group @@ -59,7 +66,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): """ return self.conn.create_replication_group(**config) - def delete_replication_group(self, replication_group_id): + def delete_replication_group(self, replication_group_id: str) -> dict: """ Call ElastiCache API for deleting a replication group @@ -70,7 +77,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): """ return self.conn.delete_replication_group(ReplicationGroupId=replication_group_id) - def describe_replication_group(self, replication_group_id): + def describe_replication_group(self, replication_group_id: str) -> dict: """ Call ElastiCache API for describing a replication group @@ -81,7 +88,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): """ return self.conn.describe_replication_groups(ReplicationGroupId=replication_group_id) - def get_replication_group_status(self, replication_group_id): + def get_replication_group_status(self, replication_group_id: str) -> str: """ Get current status of replication group @@ -92,7 +99,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): """ return self.describe_replication_group(replication_group_id)['ReplicationGroups'][0]['Status'] - def is_replication_group_available(self, replication_group_id): + def is_replication_group_available(self, replication_group_id: str) -> bool: """ Helper for checking if replication group is available or not @@ -105,10 +112,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): def wait_for_availability( self, - replication_group_id, - initial_sleep_time=None, - exponential_back_off_factor=None, - max_retries=None, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, ): """ Check if replication group is available or not by performing a describe over it @@ -164,10 +171,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): def wait_for_deletion( self, - replication_group_id, - initial_sleep_time=None, - exponential_back_off_factor=None, - max_retries=None, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, ): """ Helper for deleting a replication group ensuring it is either deleted or can't be deleted @@ -244,10 +251,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook): def ensure_delete_replication_group( self, - replication_group_id, - initial_sleep_time=None, - exponential_back_off_factor=None, - max_retries=None, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + 
exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, ): """ Delete a replication group ensuring it is either deleted or can't be deleted diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 8bc2e72..dbc8707 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -93,14 +93,14 @@ class AwsGlueJobHook(AwsBaseHook): self.log.error("Failed to create aws glue job, error: %s", general_error) raise - def initialize_job(self, script_arguments: Optional[List] = None) -> Dict[str, str]: + def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, str]: """ Initializes connection with AWS Glue to run job :return: """ glue_client = self.get_conn() - script_arguments = script_arguments or [] + script_arguments = script_arguments or {} try: job_name = self.get_or_create_glue_job() diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 9009967..af2733d 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -22,7 +22,7 @@ import tempfile import time import warnings from functools import partial -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Any, Callable, Generator from botocore.exceptions import ClientError @@ -51,7 +51,7 @@ class LogState: Position = collections.namedtuple('Position', ['timestamp', 'skip']) -def argmin(arr, f) -> Optional[int]: +def argmin(arr, f: Callable) -> Optional[int]: """Return the index, i, in arr that minimizes f(arr[i])""" min_value = None min_idx = None @@ -94,7 +94,9 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de return message != last_message -def secondary_training_status_message(job_description, prev_description): +def secondary_training_status_message( + job_description: Dict[str, List[dict]], prev_description: Optional[dict] +) -> str: """ Returns a string contains start time and the secondary training job status message. @@ -105,22 +107,14 @@ def secondary_training_status_message(job_description, prev_description): :return: Job status string to be printed. 
""" - if ( - job_description is None - or job_description.get('SecondaryStatusTransitions') is None - or len(job_description.get('SecondaryStatusTransitions')) == 0 - ): + current_transitions = job_description.get('SecondaryStatusTransitions') + if current_transitions is None or len(current_transitions) == 0: return '' - prev_description_secondary_transitions = ( - prev_description.get('SecondaryStatusTransitions') if prev_description is not None else None - ) - prev_transitions_num = ( - len(prev_description['SecondaryStatusTransitions']) - if prev_description_secondary_transitions is not None - else 0 - ) - current_transitions = job_description['SecondaryStatusTransitions'] + prev_transitions_num = 0 + if prev_description is not None: + if prev_description.get('SecondaryStatusTransitions') is not None: + prev_transitions_num = len(prev_description['SecondaryStatusTransitions']) transitions_to_print = ( current_transitions[-1:] @@ -278,7 +272,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip) - def multi_stream_iter(self, log_group, streams, positions=None): + def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator: """ Iterate over the available events coming from a set of log streams in a single log group interleaving the events from each stream so they're yielded in timestamp order. @@ -298,7 +292,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip) for s in streams ] - events = [] + events: List[Optional[Any]] = [] for event_stream in event_iters: if not event_stream: events.append(None) @@ -309,8 +303,8 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods events.append(None) while any(events): - i = argmin(events, lambda x: x['timestamp'] if x else 9999999999) - yield (i, events[i]) + i = argmin(events, lambda x: x['timestamp'] if x else 9999999999) or 0 + yield i, events[i] try: events[i] = next(event_iters[i]) except StopIteration: @@ -576,13 +570,13 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods def describe_training_job_with_log( self, - job_name, + job_name: str, positions, - stream_names, - instance_count, - state, - last_description, - last_describe_job_call, + stream_names: list, + instance_count: int, + state: int, + last_description: dict, + last_describe_job_call: float, ): """Return the training job info associated with job_name and print CloudWatch logs""" log_group = '/aws/sagemaker/TrainingJobs' @@ -635,7 +629,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods state = LogState.JOB_COMPLETE return state, last_description, last_describe_job_call - def describe_tuning_job(self, name: str): + def describe_tuning_job(self, name: str) -> dict: """ Return the tuning job info associated with the name @@ -645,7 +639,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name) - def describe_model(self, name: str): + def describe_model(self, name: str) -> dict: """ Return the SageMaker model info associated with the name @@ -655,7 +649,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ return self.get_conn().describe_model(ModelName=name) - def describe_transform_job(self, name: str): + def 
describe_transform_job(self, name: str) -> dict: """ Return the transform job info associated with the name @@ -665,7 +659,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ return self.get_conn().describe_transform_job(TransformJobName=name) - def describe_processing_job(self, name: str): + def describe_processing_job(self, name: str) -> dict: """ Return the processing job info associated with the name @@ -675,7 +669,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ return self.get_conn().describe_processing_job(ProcessingJobName=name) - def describe_endpoint_config(self, name: str): + def describe_endpoint_config(self, name: str) -> dict: """ Return the endpoint config info associated with the name @@ -685,7 +679,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ return self.get_conn().describe_endpoint_config(EndpointConfigName=name) - def describe_endpoint(self, name: str): + def describe_endpoint(self, name: str) -> dict: """ :param name: the name of the endpoint :type name: str @@ -697,7 +691,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods self, job_name: str, key: str, - describe_function, + describe_function: Callable, check_interval: int, max_ingestion_time: Optional[int] = None, non_terminal_states: Optional[Set] = None, @@ -916,7 +910,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods ) return results - def _list_request(self, partial_func, result_key: str, max_results: Optional[int] = None) -> List[Dict]: + def _list_request( + self, partial_func: Callable, result_key: str, max_results: Optional[int] = None + ) -> List[Dict]: """ All AWS boto3 list_* requests return results in batches (if the key "NextToken" is contained in the result, there are more results to fetch). The default AWS batch size is 10, and configurable up to diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index 5305784..fdd8154 100644 --- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -39,7 +39,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): :type filename_template: str """ - def __init__(self, base_log_folder, log_group_arn, filename_template): + def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: str): super().__init__(base_log_folder, filename_template) split_arn = log_group_arn.split(':') @@ -99,7 +99,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): {'end_of_log': True}, ) - def get_cloudwatch_logs(self, stream_name): + def get_cloudwatch_logs(self, stream_name: str) -> str: """ Return all logs from the given log stream. diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py index 8b32a2f..922e9ec 100644 --- a/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -31,7 +31,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): uploads to and reads from S3 remote storage. 
""" - def __init__(self, base_log_folder, s3_log_folder, filename_template): + def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str): super().__init__(base_log_folder, filename_template) self.remote_base = s3_log_folder self.log_relative_path = '' @@ -119,11 +119,12 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): else: return super()._read(ti, try_number) - def s3_log_exists(self, remote_log_location): + def s3_log_exists(self, remote_log_location: str) -> bool: """ Check if remote_log_location exists in remote storage :param remote_log_location: log's location in remote storage + :type remote_log_location: str :return: True if location exists else False """ try: @@ -132,7 +133,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): pass return False - def s3_read(self, remote_log_location, return_error=False): + def s3_read(self, remote_log_location: str, return_error: bool = False) -> str: """ Returns the log found at the remote_log_location. Returns '' if no logs are found or there is an error. @@ -142,6 +143,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): :param return_error: if True, returns a string error message if an error occurs. Otherwise returns '' when an error occurs. :type return_error: bool + :return: the log found at the remote_log_location """ try: return self.hook.read_key(remote_log_location) @@ -151,8 +153,9 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): # return error if needed if return_error: return msg + return '' - def s3_write(self, log, remote_log_location, append=True): + def s3_write(self, log: str, remote_log_location: str, append: bool = True): """ Writes the log to the remote_log_location. Fails silently if no hook was created. diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index eb0c4a4..82868e6 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -26,7 +26,7 @@ An Airflow operator for AWS Batch services - http://boto3.readthedocs.io/en/latest/reference/services/batch.html - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html """ -from typing import Dict, Optional +from typing import Dict, Optional, Any from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -48,13 +48,13 @@ class AwsBatchOperator(BaseOperator): :type job_queue: str :param overrides: the `containerOverrides` parameter for boto3 (templated) - :type overrides: Dict + :type overrides: Optional[dict] :param array_properties: the `arrayProperties` parameter for boto3 - :type array_properties: Dict + :type array_properties: Optional[dict] :param parameters: the `parameters` for boto3 (templated) - :type parameters: Dict + :type parameters: Optional[dict] :param job_id: the job ID, usually unknown (None) until the submit_job operation gets the jobId defined by AWS Batch @@ -101,18 +101,18 @@ class AwsBatchOperator(BaseOperator): def __init__( self, *, - job_name, - job_definition, - job_queue, - overrides, - array_properties=None, - parameters=None, - job_id=None, - waiters=None, - max_retries=None, - status_retries=None, - aws_conn_id=None, - region_name=None, + job_name: str, + job_definition: str, + job_queue: str, + overrides: dict, + array_properties: Optional[dict] = None, + parameters: Optional[dict] = None, + job_id: Optional[str] = None, + waiters: Optional[Any] = None, + max_retries: Optional[int] = None, + status_retries: Optional[int] = None, + 
aws_conn_id: Optional[str] = None, + region_name: Optional[str] = None, **kwargs, ): # pylint: disable=too-many-arguments @@ -121,9 +121,9 @@ class AwsBatchOperator(BaseOperator): self.job_name = job_name self.job_definition = job_definition self.job_queue = job_queue - self.overrides = overrides + self.overrides = overrides or {} self.array_properties = array_properties or {} - self.parameters = parameters + self.parameters = parameters or {} self.waiters = waiters self.hook = AwsBatchClientHook( max_retries=max_retries, @@ -181,6 +181,9 @@ class AwsBatchOperator(BaseOperator): :raises: AirflowException """ + if not self.job_id: + raise AirflowException('AWS Batch job - job_id was not found') + try: if self.waiters: self.waiters.wait_for_job(self.job_id) diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py index d6fa654..0f511c3 100644 --- a/airflow/providers/amazon/aws/operators/cloud_formation.py +++ b/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. """This module contains CloudFormation create/delete stack operators.""" -from typing import List +from typing import List, Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook @@ -43,7 +43,7 @@ class CloudFormationCreateStackOperator(BaseOperator): ui_color = '#6b9659' @apply_defaults - def __init__(self, *, stack_name, params, aws_conn_id='aws_default', **kwargs): + def __init__(self, *, stack_name: str, params: dict, aws_conn_id: str = 'aws_default', **kwargs): super().__init__(**kwargs) self.stack_name = stack_name self.params = params @@ -77,11 +77,12 @@ class CloudFormationDeleteStackOperator(BaseOperator): ui_fgcolor = '#FFF' @apply_defaults - def __init__(self, *, stack_name, params=None, aws_conn_id='aws_default', **kwargs): + def __init__( + self, *, stack_name: str, params: Optional[dict] = None, aws_conn_id: str = 'aws_default', **kwargs + ): super().__init__(**kwargs) self.params = params or {} self.stack_name = stack_name - self.params = params self.aws_conn_id = aws_conn_id def execute(self, context): diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index 9f2e9b3..b70ed96 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -19,6 +19,7 @@ import logging import random +from typing import Optional, List from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -108,19 +109,19 @@ class AWSDataSyncOperator(BaseOperator): def __init__( self, *, - aws_conn_id="aws_default", - wait_interval_seconds=5, - task_arn=None, - source_location_uri=None, - destination_location_uri=None, - allow_random_task_choice=False, - allow_random_location_choice=False, - create_task_kwargs=None, - create_source_location_kwargs=None, - create_destination_location_kwargs=None, - update_task_kwargs=None, - task_execution_kwargs=None, - delete_task_after_execution=False, + aws_conn_id: str = "aws_default", + wait_interval_seconds: int = 5, + task_arn: Optional[str] = None, + source_location_uri: Optional[str] = None, + destination_location_uri: Optional[str] = None, + allow_random_task_choice: bool = False, + allow_random_location_choice: bool = False, + create_task_kwargs: Optional[dict] = None, + create_source_location_kwargs: 
Optional[dict] = None, + create_destination_location_kwargs: Optional[dict] = None, + update_task_kwargs: Optional[dict] = None, + task_execution_kwargs: Optional[dict] = None, + delete_task_after_execution: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -163,27 +164,29 @@ class AWSDataSyncOperator(BaseOperator): ) # Others - self.hook = None + self.hook: Optional[AWSDataSyncHook] = None # Candidates - these are found in AWS as possible things # for us to use - self.candidate_source_location_arns = None - self.candidate_destination_location_arns = None - self.candidate_task_arns = None + self.candidate_source_location_arns: Optional[List[str]] = None + self.candidate_destination_location_arns: Optional[List[str]] = None + self.candidate_task_arns: Optional[List[str]] = None # Actuals - self.source_location_arn = None - self.destination_location_arn = None - self.task_execution_arn = None + self.source_location_arn: Optional[str] = None + self.destination_location_arn: Optional[str] = None + self.task_execution_arn: Optional[str] = None - def get_hook(self): + def get_hook(self) -> AWSDataSyncHook: """Create and return AWSDataSyncHook. :return AWSDataSyncHook: An AWSDataSyncHook instance. """ - if not self.hook: - self.hook = AWSDataSyncHook( - aws_conn_id=self.aws_conn_id, - wait_interval_seconds=self.wait_interval_seconds, - ) + if self.hook: + return self.hook + + self.hook = AWSDataSyncHook( + aws_conn_id=self.aws_conn_id, + wait_interval_seconds=self.wait_interval_seconds, + ) return self.hook def execute(self, context): @@ -221,7 +224,7 @@ class AWSDataSyncOperator(BaseOperator): return {"TaskArn": self.task_arn, "TaskExecutionArn": self.task_execution_arn} - def _get_tasks_and_locations(self): + def _get_tasks_and_locations(self) -> None: """Find existing DataSync Task based on source and dest Locations.""" hook = self.get_hook() @@ -244,7 +247,7 @@ class AWSDataSyncOperator(BaseOperator): ) self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns) - def choose_task(self, task_arn_list): + def choose_task(self, task_arn_list: list) -> Optional[str]: """Select 1 DataSync TaskArn from a list""" if not task_arn_list: return None @@ -258,7 +261,7 @@ class AWSDataSyncOperator(BaseOperator): return random.choice(task_arn_list) raise AirflowException("Unable to choose a Task from {}".format(task_arn_list)) - def choose_location(self, location_arn_list): + def choose_location(self, location_arn_list: List[str]) -> Optional[str]: """Select 1 DataSync LocationArn from a list""" if not location_arn_list: return None @@ -272,12 +275,15 @@ class AWSDataSyncOperator(BaseOperator): return random.choice(location_arn_list) raise AirflowException("Unable to choose a Location from {}".format(location_arn_list)) - def _create_datasync_task(self): + def _create_datasync_task(self) -> None: """Create a AWS DataSyncTask.""" + if not self.candidate_source_location_arns or not self.candidate_destination_location_arns: + return + hook = self.get_hook() self.source_location_arn = self.choose_location(self.candidate_source_location_arns) - if not self.source_location_arn and self.create_source_location_kwargs: + if not self.source_location_arn and self.source_location_uri and self.create_source_location_kwargs: self.log.info('Attempting to create source Location') self.source_location_arn = hook.create_location( self.source_location_uri, **self.create_source_location_kwargs @@ -288,7 +294,11 @@ class AWSDataSyncOperator(BaseOperator): ) self.destination_location_arn = 
self.choose_location(self.candidate_destination_location_arns) - if not self.destination_location_arn and self.create_destination_location_kwargs: + if ( + not self.destination_location_arn + and self.destination_location_uri + and self.create_destination_location_kwargs + ): self.log.info('Attempting to create destination Location') self.destination_location_arn = hook.create_location( self.destination_location_uri, **self.create_destination_location_kwargs @@ -305,18 +315,22 @@ class AWSDataSyncOperator(BaseOperator): if not self.task_arn: raise AirflowException("Task could not be created") self.log.info("Created a Task with TaskArn %s", self.task_arn) - return self.task_arn - def _update_datasync_task(self): + def _update_datasync_task(self) -> None: """Update a AWS DataSyncTask.""" + if not self.task_arn: + return + hook = self.get_hook() self.log.info("Updating TaskArn %s", self.task_arn) hook.update_task(self.task_arn, **self.update_task_kwargs) self.log.info("Updated TaskArn %s", self.task_arn) - return self.task_arn - def _execute_datasync_task(self): + def _execute_datasync_task(self) -> None: """Create and monitor an AWSDataSync TaskExecution for a Task.""" + if not self.task_arn: + raise AirflowException("Missing TaskArn") + hook = self.get_hook() # Create a task execution: @@ -340,9 +354,8 @@ class AWSDataSyncOperator(BaseOperator): if not result: raise AirflowException("Failed TaskExecutionArn %s" % self.task_execution_arn) - return self.task_execution_arn - def on_kill(self): + def on_kill(self) -> None: """Cancel the submitted DataSync task.""" hook = self.get_hook() if self.task_execution_arn: @@ -350,16 +363,18 @@ class AWSDataSyncOperator(BaseOperator): hook.cancel_task_execution(task_execution_arn=self.task_execution_arn) self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn) - def _delete_datasync_task(self): + def _delete_datasync_task(self) -> None: """Deletes an AWS DataSync Task.""" + if not self.task_arn: + return + hook = self.get_hook() # Delete task: self.log.info("Deleting Task with TaskArn %s", self.task_arn) hook.delete_task(self.task_arn) self.log.info("Task Deleted") - return self.task_arn - def _get_location_arns(self, location_uri): + def _get_location_arns(self, location_uri) -> List[str]: location_arns = self.get_hook().get_location_arns(location_uri) self.log.info("Found LocationArns %s for LocationUri %s", location_arns, location_uri) return location_arns diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index e2e85d9..e7f5abb 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -18,7 +18,9 @@ import re import sys from datetime import datetime -from typing import Dict, Optional +from typing import Optional + +from botocore.waiter import Waiter from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -40,19 +42,19 @@ class ECSProtocol(Protocol): - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html """ - def run_task(self, **kwargs): + def run_task(self, **kwargs) -> dict: """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501 # pylint: disable=line-too-long ... 
- def get_waiter(self, x: str): + def get_waiter(self, x: str) -> Waiter: """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501 # pylint: disable=line-too-long ... - def describe_tasks(self, cluster: str, tasks) -> Dict: + def describe_tasks(self, cluster: str, tasks) -> dict: """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501 # pylint: disable=line-too-long ... - def stop_task(self, cluster, task, reason: str) -> Dict: + def stop_task(self, cluster, task, reason: str) -> dict: """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501 # pylint: disable=line-too-long ... @@ -111,30 +113,28 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes """ ui_color = '#f0ede4' - client = None # type: Optional[ECSProtocol] - arn = None # type: Optional[str] template_fields = ('overrides',) @apply_defaults def __init__( self, *, - task_definition, - cluster, - overrides, # pylint: disable=too-many-arguments - aws_conn_id=None, - region_name=None, - launch_type='EC2', - group=None, - placement_constraints=None, - placement_strategy=None, - platform_version='LATEST', - network_configuration=None, - tags=None, - awslogs_group=None, - awslogs_region=None, - awslogs_stream_prefix=None, - propagate_tags=None, + task_definition: str, + cluster: str, + overrides: dict, # pylint: disable=too-many-arguments + aws_conn_id: Optional[str] = None, + region_name: Optional[str] = None, + launch_type: str = 'EC2', + group: Optional[str] = None, + placement_constraints: Optional[list] = None, + placement_strategy: Optional[list] = None, + platform_version: str = 'LATEST', + network_configuration: Optional[dict] = None, + tags: Optional[dict] = None, + awslogs_group: Optional[str] = None, + awslogs_region: Optional[str] = None, + awslogs_stream_prefix: Optional[str] = None, + propagate_tags: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -160,7 +160,9 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes if self.awslogs_region is None: self.awslogs_region = region_name - self.hook = None + self.hook: Optional[AwsBaseHook] = None + self.client: Optional[ECSProtocol] = None + self.arn: Optional[str] = None def execute(self, context): self.log.info( @@ -207,12 +209,18 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes self._check_success_task() self.log.info('ECS Task has been successfully executed: %s', response) - def _wait_for_task_ended(self): + def _wait_for_task_ended(self) -> None: + if not self.client or not self.arn: + return + waiter = self.client.get_waiter('tasks_stopped') waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow waiter.wait(cluster=self.cluster, tasks=[self.arn]) - def _check_success_task(self): + def _check_success_task(self) -> None: + if not self.client or not self.arn: + return + response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn]) self.log.info('ECS Task stopped, check status: %s', response) @@ -252,19 +260,22 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes ) ) - def get_hook(self): + def get_hook(self) -> AwsBaseHook: """Create and return an AwsHook.""" - if not self.hook: - self.hook = AwsBaseHook( - aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name - ) + if self.hook: 
+ return self.hook + + self.hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name) return self.hook - def get_logs_hook(self): + def get_logs_hook(self) -> AwsLogsHook: """Create and return an AwsLogsHook.""" return AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region) - def on_kill(self): + def on_kill(self) -> None: + if not self.client or not self.arn: + return + response = self.client.stop_task( cluster=self.cluster, task=self.arn, reason='Task killed by the user' ) diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 991135f..48d5661 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -18,6 +18,7 @@ from __future__ import unicode_literals import os.path +from typing import Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook @@ -61,24 +62,24 @@ class AwsGlueJobOperator(BaseOperator): def __init__( self, *, - job_name='aws_glue_default_job', - job_desc='AWS Glue Job with Airflow', - script_location=None, - concurrent_run_limit=None, - script_args=None, - retry_limit=None, - num_of_dpus=6, - aws_conn_id='aws_default', - region_name=None, - s3_bucket=None, - iam_role_name=None, + job_name: str = 'aws_glue_default_job', + job_desc: str = 'AWS Glue Job with Airflow', + script_location: Optional[str] = None, + concurrent_run_limit: Optional[int] = None, + script_args: Optional[dict] = None, + retry_limit: Optional[int] = None, + num_of_dpus: int = 6, + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, + s3_bucket: Optional[str] = None, + iam_role_name: Optional[str] = None, **kwargs, ): # pylint: disable=too-many-arguments super(AwsGlueJobOperator, self).__init__(**kwargs) self.job_name = job_name self.job_desc = job_desc self.script_location = script_location - self.concurrent_run_limit = concurrent_run_limit + self.concurrent_run_limit = concurrent_run_limit or 1 self.script_args = script_args or {} self.retry_limit = retry_limit self.num_of_dpus = num_of_dpus @@ -87,7 +88,7 @@ class AwsGlueJobOperator(BaseOperator): self.s3_bucket = s3_bucket self.iam_role_name = iam_role_name self.s3_protocol = "s3://" - self.s3_artifcats_prefix = 'artifacts/glue-scripts/' + self.s3_artifacts_prefix = 'artifacts/glue-scripts/' def execute(self, context): """ @@ -98,7 +99,7 @@ class AwsGlueJobOperator(BaseOperator): if self.script_location and not self.script_location.startswith(self.s3_protocol): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) script_name = os.path.basename(self.script_location) - s3_hook.load_file(self.script_location, self.s3_bucket, self.s3_artifcats_prefix + script_name) + s3_hook.load_file(self.script_location, self.s3_bucket, self.s3_artifacts_prefix + script_name) glue_job = AwsGlueJobHook( job_name=self.job_name, desc=self.job_desc, diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py b/airflow/providers/amazon/aws/operators/s3_bucket.py index 14d25cc..bb15baa 100644 --- a/airflow/providers/amazon/aws/operators/s3_bucket.py +++ b/airflow/providers/amazon/aws/operators/s3_bucket.py @@ -43,7 +43,7 @@ class S3CreateBucketOperator(BaseOperator): def __init__( self, *, - bucket_name, + bucket_name: str, aws_conn_id: Optional[str] = "aws_default", region_name: Optional[str] = None, **kwargs, @@ -81,7 +81,7 @@ class S3DeleteBucketOperator(BaseOperator): def __init__( self, - bucket_name, + bucket_name: 
str, force_delete: bool = False, aws_conn_id: Optional[str] = "aws_default", **kwargs, diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py b/airflow/providers/amazon/aws/operators/s3_copy_object.py index 4b2d290..052c9ad 100644 --- a/airflow/providers/amazon/aws/operators/s3_copy_object.py +++ b/airflow/providers/amazon/aws/operators/s3_copy_object.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -70,13 +71,13 @@ class S3CopyObjectOperator(BaseOperator): def __init__( self, *, - source_bucket_key, - dest_bucket_key, - source_bucket_name=None, - dest_bucket_name=None, - source_version_id=None, - aws_conn_id='aws_default', - verify=None, + source_bucket_key: str, + dest_bucket_key: str, + source_bucket_name: Optional[str] = None, + dest_bucket_name: Optional[str] = None, + source_version_id: Optional[str] = None, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[str, bool]] = None, **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py b/airflow/providers/amazon/aws/operators/s3_delete_objects.py index b6d267b..96c9e14 100644 --- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py +++ b/airflow/providers/amazon/aws/operators/s3_delete_objects.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -62,7 +63,16 @@ class S3DeleteObjectsOperator(BaseOperator): template_fields = ('keys', 'bucket', 'prefix') @apply_defaults - def __init__(self, *, bucket, keys=None, prefix=None, aws_conn_id='aws_default', verify=None, **kwargs): + def __init__( + self, + *, + bucket: str, + keys: Optional[Union[str, list]] = None, + prefix: Optional[str] = None, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): if not bool(keys) ^ bool(prefix): raise ValueError("Either keys or prefix should be set.") diff --git a/airflow/providers/amazon/aws/operators/s3_list.py b/airflow/providers/amazon/aws/operators/s3_list.py index 4c25e99..58d599d 100644 --- a/airflow/providers/amazon/aws/operators/s3_list.py +++ b/airflow/providers/amazon/aws/operators/s3_list.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Iterable +from typing import Iterable, Optional, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -70,7 +70,16 @@ class S3ListOperator(BaseOperator): ui_color = '#ffd700' @apply_defaults - def __init__(self, *, bucket, prefix='', delimiter='', aws_conn_id='aws_default', verify=None, **kwargs): + def __init__( + self, + *, + bucket: str, + prefix: str = '', + delimiter: str = '', + aws_conn_id: str = 'aws_default', + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): super().__init__(**kwargs) self.bucket = bucket self.prefix = prefix diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py index 19fb921..3fa1b2e 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_base.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_base.py @@ -19,6 +19,8 @@ import json from typing import Iterable +from cached_property import cached_property + from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.utils.decorators import apply_defaults @@ -41,12 +43,11 @@ class SageMakerBaseOperator(BaseOperator): integer_fields = [] # type: Iterable[Iterable[str]] @apply_defaults - def __init__(self, *, config, aws_conn_id='aws_default', **kwargs): + def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.config = config - self.hook = None def parse_integer(self, config, field): """Recursive method for parsing string fields holding integer values to integers.""" @@ -84,7 +85,6 @@ class SageMakerBaseOperator(BaseOperator): def preprocess_config(self): """Process the config into a usable form.""" self.log.info('Preprocessing the config and doing required s3_operations') - self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) self.hook.configure_s3_resources(self.config) self.parse_config_integers() @@ -97,3 +97,8 @@ class SageMakerBaseOperator(BaseOperator): def execute(self, context): raise NotImplementedError('Please implement execute() in sub class!') + + @cached_property + def hook(self): + """Return SageMakerHook""" + return SageMakerHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py index c7a89f2..53cfd93 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional from botocore.exceptions import ClientError @@ -74,11 +75,11 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): def __init__( self, *, - config, - wait_for_completion=True, - check_interval=30, - max_ingestion_time=None, - operation='create', + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + operation: str = 'create', **kwargs, ): super().__init__(config=config, **kwargs) @@ -92,12 +93,12 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): raise ValueError('Invalid value! 
Argument operation has to be one of "create" and "update"') self.create_integer_fields() - def create_integer_fields(self): + def create_integer_fields(self) -> None: """Set fields which should be casted to integers.""" if 'EndpointConfig' in self.config: self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']] - def expand_role(self): + def expand_role(self) -> None: if 'Model' not in self.config: return hook = AwsBaseHook(self.aws_conn_id, client_type='iam') @@ -105,7 +106,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): if 'ExecutionRoleArn' in config: config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() model_info = self.config.get('Model') diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py index 9bde451..bbf2be1 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py @@ -38,12 +38,12 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator): integer_fields = [['ProductionVariants', 'InitialInstanceCount']] @apply_defaults - def __init__(self, *, config, **kwargs): + def __init__(self, *, config: dict, **kwargs): super().__init__(config=config, **kwargs) self.config = config - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName']) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py index 122ceee..25730ea 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_model.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_model.py @@ -42,12 +42,12 @@ class SageMakerModelOperator(SageMakerBaseOperator): self.config = config - def expand_role(self): + def expand_role(self) -> None: if 'ExecutionRoleArn' in self.config: hook = AwsBaseHook(self.aws_conn_id, client_type='iam') self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() self.log.info('Creating SageMaker Model %s.', self.config['ModelName']) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py index c1bcac7..e56a987 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_processing.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -55,12 +56,12 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): def __init__( self, *, - config, - aws_conn_id, - wait_for_completion=True, - print_log=True, - check_interval=30, - max_ingestion_time=None, + config: dict, + aws_conn_id: str, + wait_for_completion: bool = True, + print_log: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 **kwargs, ): @@ -78,7 +79,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): self.max_ingestion_time = max_ingestion_time self._create_integer_fields() - def _create_integer_fields(self): + def _create_integer_fields(self) -> None: """Set fields which should be casted to integers.""" self.integer_fields = [ ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], @@ -87,12 +88,12 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): if 'StoppingCondition' in self.config: self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']] - def expand_role(self): + def expand_role(self) -> None: if 'RoleArn' in self.config: hook = AwsBaseHook(self.aws_conn_id, client_type='iam') self.config['RoleArn'] = hook.expand_role(self.config['RoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() processing_job_name = self.config["ProcessingJobName"] diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py index 6175a61..29c34f6 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_training.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -61,11 +62,11 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): def __init__( self, *, - config, - wait_for_completion=True, - print_log=True, - check_interval=30, - max_ingestion_time=None, + config: dict, + wait_for_completion: bool = True, + print_log: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 **kwargs, ): @@ -84,12 +85,12 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): f"Provided value: '{action_if_job_exists}'." ) - def expand_role(self): + def expand_role(self) -> None: if 'RoleArn' in self.config: hook = AwsBaseHook(self.aws_conn_id, client_type='iam') self.config['RoleArn'] = hook.expand_role(self.config['RoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() training_job_name = self.config["TrainingJobName"] diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py index 7ae8f3a..1dadb3d 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Optional, List from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -63,7 +64,13 @@ class SageMakerTransformOperator(SageMakerBaseOperator): @apply_defaults def __init__( - self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs + self, + *, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + **kwargs, ): super().__init__(config=config, **kwargs) self.config = config @@ -72,9 +79,9 @@ class SageMakerTransformOperator(SageMakerBaseOperator): self.max_ingestion_time = max_ingestion_time self.create_integer_fields() - def create_integer_fields(self): + def create_integer_fields(self) -> None: """Set fields which should be casted to integers.""" - self.integer_fields = [ + self.integer_fields: List[List[str]] = [ ['Transform', 'TransformResources', 'InstanceCount'], ['Transform', 'MaxConcurrentTransforms'], ['Transform', 'MaxPayloadInMB'], @@ -83,7 +90,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator): for field in self.integer_fields: field.pop(0) - def expand_role(self): + def expand_role(self) -> None: if 'Model' not in self.config: return config = self.config['Model'] @@ -91,7 +98,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator): hook = AwsBaseHook(self.aws_conn_id, client_type='iam') config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() model_config = self.config.get('Model') diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py index 483e541..f8df36a 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -56,7 +57,13 @@ class SageMakerTuningOperator(SageMakerBaseOperator): @apply_defaults def __init__( - self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs + self, + *, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + **kwargs, ): super().__init__(config=config, **kwargs) self.config = config @@ -64,14 +71,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator): self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time - def expand_role(self): + def expand_role(self) -> None: if 'TrainingJobDefinition' in self.config: config = self.config['TrainingJobDefinition'] if 'RoleArn' in config: hook = AwsBaseHook(self.aws_conn_id, client_type='iam') config['RoleArn'] = hook.expand_role(config['RoleArn']) - def execute(self, context): + def execute(self, context) -> dict: self.preprocess_config() self.log.info( diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index 8917dfe..1e88913 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -17,6 +17,7 @@ # under the License. 
"""Publish message to SNS queue""" +from typing import Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook @@ -47,11 +48,11 @@ class SnsPublishOperator(BaseOperator): def __init__( self, *, - target_arn, - message, - aws_conn_id='aws_default', - subject=None, - message_attributes=None, + target_arn: str, + message: str, + aws_conn_id: str = 'aws_default', + subject: Optional[str] = None, + message_attributes: Optional[dict] = None, **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/sqs.py b/airflow/providers/amazon/aws/operators/sqs.py index 6005195..afc50e4 100644 --- a/airflow/providers/amazon/aws/operators/sqs.py +++ b/airflow/providers/amazon/aws/operators/sqs.py @@ -16,6 +16,7 @@ # under the License. """Publish message to SQS queue""" +from typing import Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sqs import SQSHook @@ -46,11 +47,11 @@ class SQSPublishOperator(BaseOperator): def __init__( self, *, - sqs_queue, - message_content, - message_attributes=None, - delay_seconds=0, - aws_conn_id='aws_default', + sqs_queue: str, + message_content: str, + message_attributes: Optional[dict] = None, + delay_seconds: int = 0, + aws_conn_id: str = 'aws_default', **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 2eaa2c4..769f06c 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -16,6 +16,7 @@ # under the License. import json +from typing import Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook @@ -42,7 +43,14 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator): ui_color = '#f9c915' @apply_defaults - def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, **kwargs): + def __init__( + self, + *, + execution_arn: str, + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 0d8f446..b364ba5 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -55,8 +55,8 @@ class StepFunctionStartExecutionOperator(BaseOperator): state_machine_arn: str, name: Optional[str] = None, state_machine_input: Union[dict, str, None] = None, - aws_conn_id='aws_default', - region_name=None, + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py index 5c66a0c..2da691b 100644 --- a/airflow/providers/amazon/aws/sensors/cloud_formation.py +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains sensors for AWS CloudFormation.""" +from typing import Optional + from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults @@ -69,12 +71,19 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): ui_color = '#C5CAE9' @apply_defaults - def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs): + def __init__( + self, + *, + stack_name: str, + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.region_name = region_name self.stack_name = stack_name - self.hook = None + self.hook: Optional[AWSCloudFormationHook] = None def poke(self, context): stack_status = self.get_hook().get_stack_status(self.stack_name) @@ -84,8 +93,10 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): return False raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}') - def get_hook(self): + def get_hook(self) -> AWSCloudFormationHook: """Create and return an AWSCloudFormationHook""" - if not self.hook: - self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + if self.hook: + return self.hook + + self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py index f05197b..d862c6b 100644 --- a/airflow/providers/amazon/aws/sensors/emr_base.py +++ b/airflow/providers/amazon/aws/sensors/emr_base.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Iterable from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrHook @@ -42,17 +42,19 @@ class EmrBaseSensor(BaseSensorOperator): ui_color = '#66c3ff' @apply_defaults - def __init__(self, *, aws_conn_id='aws_default', **kwargs): + def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id - self.target_states = None # will be set in subclasses - self.failed_states = None # will be set in subclasses - self.hook = None + self.target_states: Optional[Iterable[str]] = None # will be set in subclasses + self.failed_states: Optional[Iterable[str]] = None # will be set in subclasses + self.hook: Optional[EmrHook] = None - def get_hook(self): + def get_hook(self) -> EmrHook: """Get EmrHook""" - if not self.hook: - self.hook = EmrHook(aws_conn_id=self.aws_conn_id) + if self.hook: + return self.hook + + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) return self.hook def poke(self, context): diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 7b2ce30..92876c3 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -36,7 +36,7 @@ class AwsGlueJobSensor(BaseSensorOperator): template_fields = ('job_name', 'run_id') @apply_defaults - def __init__(self, *, job_name, run_id, aws_conn_id='aws_default', **kwargs): + def __init__(self, *, job_name: str, run_id: str, aws_conn_id: str = 'aws_default', **kwargs): super().__init__(**kwargs) self.job_name = job_name self.run_id = run_id diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py index 7292626..3849094 100644 --- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Optional from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -59,12 +60,12 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator): def __init__( self, *, - table_name, - expression="ds='{{ ds }}'", - aws_conn_id='aws_default', - region_name=None, - database_name='default', - poke_interval=60 * 3, + table_name: str, + expression: str = "ds='{{ ds }}'", + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, + database_name: str = 'default', + poke_interval: int = 60 * 3, **kwargs, ): super().__init__(poke_interval=poke_interval, **kwargs) @@ -73,7 +74,7 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator): self.table_name = table_name self.expression = expression self.database_name = database_name - self.hook = None + self.hook: Optional[AwsGlueCatalogHook] = None def poke(self, context): """Checks for existence of the partition in the AWS Glue Catalog table""" @@ -85,8 +86,10 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator): return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression) - def get_hook(self): + def get_hook(self) -> AwsGlueCatalogHook: """Gets the AwsGlueCatalogHook""" - if not self.hook: - self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + if self.hook: + return self.hook + + self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/redshift.py b/airflow/providers/amazon/aws/sensors/redshift.py index 37f3521..106801a 100644 --- a/airflow/providers/amazon/aws/sensors/redshift.py +++ b/airflow/providers/amazon/aws/sensors/redshift.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -34,19 +35,28 @@ class AwsRedshiftClusterSensor(BaseSensorOperator): template_fields = ('cluster_identifier', 'target_status') @apply_defaults - def __init__(self, *, cluster_identifier, target_status='available', aws_conn_id='aws_default', **kwargs): + def __init__( + self, + *, + cluster_identifier: str, + target_status: str = 'available', + aws_conn_id: str = 'aws_default', + **kwargs, + ): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.target_status = target_status self.aws_conn_id = aws_conn_id - self.hook = None + self.hook: Optional[RedshiftHook] = None def poke(self, context): self.log.info('Poking for status : %s\nfor cluster %s', self.target_status, self.cluster_identifier) return self.get_hook().cluster_status(self.cluster_identifier) == self.target_status - def get_hook(self): + def get_hook(self) -> RedshiftHook: """Create and return a RedshiftHook""" - if not self.hook: - self.hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + if self.hook: + return self.hook + + self.hook = RedshiftHook(aws_conn_id=self.aws_conn_id) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py index 0c0f6e3..9eab08e 100644 --- a/airflow/providers/amazon/aws/sensors/s3_key.py +++ b/airflow/providers/amazon/aws/sensors/s3_key.py @@ -15,8 +15,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - - +from typing import Optional, Union from urllib.parse import urlparse from airflow.exceptions import AirflowException @@ -62,11 +61,11 @@ class S3KeySensor(BaseSensorOperator): def __init__( self, *, - bucket_key, - bucket_name=None, - wildcard_match=False, - aws_conn_id='aws_default', - verify=None, + bucket_key: str, + bucket_name: Optional[str] = None, + wildcard_match: bool = False, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[str, bool]] = None, **kwargs, ): super().__init__(**kwargs) @@ -91,7 +90,7 @@ class S3KeySensor(BaseSensorOperator): self.wildcard_match = wildcard_match self.aws_conn_id = aws_conn_id self.verify = verify - self.hook = None + self.hook: Optional[S3Hook] = None def poke(self, context): self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) @@ -99,8 +98,10 @@ class S3KeySensor(BaseSensorOperator): return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) return self.get_hook().check_for_key(self.bucket_key, self.bucket_name) - def get_hook(self): + def get_hook(self) -> S3Hook: """Create and return an S3Hook""" - if not self.hook: - self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + if self.hook: + return self.hook + + self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py index 4dc4900..13fb37c 100644 --- a/airflow/providers/amazon/aws/sensors/s3_prefix.py +++ b/airflow/providers/amazon/aws/sensors/s3_prefix.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Union from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -56,7 +57,14 @@ class S3PrefixSensor(BaseSensorOperator): @apply_defaults def __init__( - self, *, bucket_name, prefix, delimiter='/', aws_conn_id='aws_default', verify=None, **kwargs + self, + *, + bucket_name: str, + prefix: str, + delimiter: str = '/', + aws_conn_id: str = 'aws_default', + verify: Optional[Union[str, bool]] = None, + **kwargs, ): super().__init__(**kwargs) # Parse @@ -66,7 +74,7 @@ class S3PrefixSensor(BaseSensorOperator): self.full_url = "s3://" + bucket_name + '/' + prefix self.aws_conn_id = aws_conn_id self.verify = verify - self.hook = None + self.hook: Optional[S3Hook] = None def poke(self, context): self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name) @@ -74,8 +82,10 @@ class S3PrefixSensor(BaseSensorOperator): prefix=self.prefix, delimiter=self.delimiter, bucket_name=self.bucket_name ) - def get_hook(self): + def get_hook(self) -> S3Hook: """Create and return an S3Hook""" - if not self.hook: - self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + if self.hook: + return self.hook + + self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py index 6704b1a..f55b6cc 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_base.py @@ -15,6 +15,8 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Set + from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -32,15 +34,17 @@ class SageMakerBaseSensor(BaseSensorOperator): ui_color = '#ededed' @apply_defaults - def __init__(self, *, aws_conn_id='aws_default', **kwargs): + def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id - self.hook = None + self.hook: Optional[SageMakerHook] = None - def get_hook(self): + def get_hook(self) -> SageMakerHook: """Get SageMakerHook""" - if not self.hook: - self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) + if self.hook: + return self.hook + + self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) return self.hook def poke(self, context): @@ -62,22 +66,22 @@ class SageMakerBaseSensor(BaseSensorOperator): raise AirflowException('Sagemaker job failed for the following reason: %s' % failed_reason) return True - def non_terminal_states(self): + def non_terminal_states(self) -> Set[str]: """Placeholder for returning states with should not terminate.""" raise NotImplementedError('Please implement non_terminal_states() in subclass') - def failed_states(self): + def failed_states(self) -> Set[str]: """Placeholder for returning states with are considered failed.""" raise NotImplementedError('Please implement failed_states() in subclass') - def get_sagemaker_response(self): + def get_sagemaker_response(self) -> Optional[dict]: """Placeholder for checking status of a SageMaker task.""" raise NotImplementedError('Please implement get_sagemaker_response() in subclass') - def get_failed_reason_from_response(self, response): # pylint: disable=unused-argument + def get_failed_reason_from_response(self, response: dict) -> str: # pylint: disable=unused-argument """Placeholder for extracting the reason for failure from an AWS response.""" return 'Unknown' - def state_from_response(self, response): + def state_from_response(self, response: dict) -> str: """Placeholder for extracting the state from an AWS response.""" raise NotImplementedError('Please implement state_from_response() in subclass') diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py index 36403b8..9cd7668 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Optional import time @@ -44,13 +45,13 @@ class SageMakerTrainingSensor(SageMakerBaseSensor): self.print_log = print_log self.positions = {} self.stream_names = [] - self.instance_count = None - self.state = None + self.instance_count: Optional[int] = None + self.state: Optional[int] = None self.last_description = None self.last_describe_job_call = None self.log_resource_inited = False - def init_log_resource(self, hook): + def init_log_resource(self, hook: SageMakerHook) -> None: """Set tailing LogState for associated training job.""" description = hook.describe_training_job(self.job_name) self.instance_count = description['ResourceConfig']['InstanceCount'] diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py index 4108c98..a751e56 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py @@ -35,7 +35,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, job_name, **kwargs): + def __init__(self, *, job_name: str, **kwargs): super().__init__(**kwargs) self.job_name = job_name diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py index 794695b..96080e0 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py @@ -35,7 +35,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, job_name, **kwargs): + def __init__(self, *, job_name: str, **kwargs): super().__init__(**kwargs) self.job_name = job_name diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py index b6c467d..e7a250d 100644 --- a/airflow/providers/amazon/aws/sensors/sqs.py +++ b/airflow/providers/amazon/aws/sensors/sqs.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. """Reads and then deletes the message from SQS queue""" +from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sqs import SQSHook @@ -43,14 +44,20 @@ class SQSSensor(BaseSensorOperator): @apply_defaults def __init__( - self, *, sqs_queue, aws_conn_id='aws_default', max_messages=5, wait_time_seconds=1, **kwargs + self, + *, + sqs_queue, + aws_conn_id: str = 'aws_default', + max_messages: int = 5, + wait_time_seconds: int = 1, + **kwargs, ): super().__init__(**kwargs) self.sqs_queue = sqs_queue self.aws_conn_id = aws_conn_id self.max_messages = max_messages self.wait_time_seconds = wait_time_seconds - self.hook = None + self.hook: Optional[SQSHook] = None def poke(self, context): """ @@ -90,8 +97,10 @@ class SQSSensor(BaseSensorOperator): return False - def get_hook(self): + def get_hook(self) -> SQSHook: """Create and return an SQSHook""" - if not self.hook: - self.hook = SQSHook(aws_conn_id=self.aws_conn_id) + if self.hook: + return self.hook + + self.hook = SQSHook(aws_conn_id=self.aws_conn_id) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py index 6126670..75c7e8b 100644 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -16,6 +16,7 @@ # under the License. 
import json +from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook @@ -51,12 +52,19 @@ class StepFunctionExecutionSensor(BaseSensorOperator): ui_color = '#66c3ff' @apply_defaults - def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, **kwargs): + def __init__( + self, + *, + execution_arn: str, + aws_conn_id: str = 'aws_default', + region_name: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id self.region_name = region_name - self.hook = None + self.hook: Optional[StepFunctionHook] = None def poke(self, context): execution_status = self.get_hook().describe_execution(self.execution_arn) @@ -73,8 +81,10 @@ class StepFunctionExecutionSensor(BaseSensorOperator): self.xcom_push(context, 'output', output) return True - def get_hook(self): + def get_hook(self) -> StepFunctionHook: """Create and return a StepFunctionHook""" - if not self.hook: - self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + if self.hook: + return self.hook + + self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index acc87a1..505a71a 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -86,7 +86,7 @@ class TestAwsBatchOperator(unittest.TestCase): self.assertEqual(self.batch.waiters, None) self.assertEqual(self.batch.hook.max_retries, self.MAX_RETRIES) self.assertEqual(self.batch.hook.status_retries, self.STATUS_RETRIES) - self.assertEqual(self.batch.parameters, None) + self.assertEqual(self.batch.parameters, {}) self.assertEqual(self.batch.overrides, {}) self.assertEqual(self.batch.array_properties, {}) self.assertEqual(self.batch.hook.region_name, "eu-west-1") @@ -121,7 +121,7 @@ class TestAwsBatchOperator(unittest.TestCase): containerOverrides={}, jobDefinition="hello-world", arrayProperties={}, - parameters=None, + parameters={}, ) self.assertEqual(self.batch.job_id, JOB_ID) @@ -140,7 +140,7 @@ class TestAwsBatchOperator(unittest.TestCase): containerOverrides={}, jobDefinition="hello-world", arrayProperties={}, - parameters=None, + parameters={}, ) @mock.patch.object(AwsBatchClientHook, "check_job_success")
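The recurring change across the sensors in this patch is the typed, memoized get_hook(): the hook attribute is declared as Optional[SomeHook] and the method early-returns the cached instance instead of guarding with "if not self.hook". Below is a minimal standalone sketch of that pattern; FakeHook and FakeSensor are invented names for illustration only and are not part of the patch.

from typing import Optional


class FakeHook:
    """Stand-in for an AWS hook such as S3Hook or SQSHook (illustrative only)."""

    def __init__(self, aws_conn_id: str = 'aws_default', region_name: Optional[str] = None) -> None:
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name


class FakeSensor:
    """Illustrates the typed, lazily created hook used throughout the patch."""

    def __init__(self, *, aws_conn_id: str = 'aws_default', region_name: Optional[str] = None) -> None:
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        # Declared Optional so a type checker knows the attribute is unset until first use.
        self.hook: Optional[FakeHook] = None

    def get_hook(self) -> FakeHook:
        """Create the hook on first call, then return the cached instance."""
        if self.hook:
            return self.hook

        self.hook = FakeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        return self.hook

The early-return form keeps the happy path flat and pairs naturally with the concrete "-> FakeHook" return annotation, so callers never see an Optional hook.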
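As a usage illustration of the annotated keyword-only constructors, the snippet below instantiates SnsPublishOperator with the parameters shown in its hunk above. It is a sketch only: the ARN, task_id and message values are placeholders, the operator is shown outside a DAG for brevity, and message_attributes is omitted because it is Optional.

from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator

publish_greeting = SnsPublishOperator(
    task_id='publish_to_sns',  # standard BaseOperator argument, passed through **kwargs
    target_arn='arn:aws:sns:eu-west-1:123456789012:example-topic',  # placeholder ARN
    message='Hello from Airflow',
    subject='Optional subject line',  # Optional[str]; may be omitted
    aws_conn_id='aws_default',
)

Because every argument after "*" is keyword-only, a type checker can now flag a misspelled or mistyped parameter at the call site instead of at runtime.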
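The SageMakerBaseSensor hunk also gives its placeholder methods concrete return types (Set[str], dict, str). The hypothetical subclass below shows the shapes a subclass is now expected to return; the class name, the states and the response dictionary are invented for illustration and do not correspond to any class added by this commit.

from typing import Set

from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor


class ExampleSageMakerSensor(SageMakerBaseSensor):
    """Hypothetical subclass showing the return types expected by the base class."""

    def __init__(self, *, job_name: str, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name

    def non_terminal_states(self) -> Set[str]:
        return {'InProgress', 'Stopping'}

    def failed_states(self) -> Set[str]:
        return {'Failed'}

    def get_sagemaker_response(self) -> dict:
        # A real subclass would call something like
        # self.get_hook().describe_training_job(self.job_name) here;
        # this fixed dictionary only stands in for that response.
        return {'Status': 'InProgress', 'FailureReason': ''}

    def state_from_response(self, response: dict) -> str:
        return response['Status']

    def get_failed_reason_from_response(self, response: dict) -> str:
        return response.get('FailureReason', 'Unknown')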