This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 5c887988b0 Refactor Eks Create Cluster Operator code (#31960) 5c887988b0 is described below commit 5c887988b02b02e60f693c9341013592a291ee27 Author: Syed Hussaain <103602455+syeda...@users.noreply.github.com> AuthorDate: Fri Jun 23 14:18:13 2023 -0700 Refactor Eks Create Cluster Operator code (#31960) * Refactor EksCreateClusterOperator to reuse code being used in multiple places * Update create_compute method to pass tests Add waiter params to EksCreateClusterOperator and EksCreateNodegroupOperator Update EksCreateFargateProfileTrigger and EksDeleteFargateProfileTrigger to use more consistent waiter names Update unit tests for triggers and operators --- airflow/providers/amazon/aws/operators/eks.py | 249 +++++++++++++++-------- airflow/providers/amazon/aws/triggers/eks.py | 62 +++--- tests/providers/amazon/aws/operators/test_eks.py | 32 ++- tests/providers/amazon/aws/triggers/test_eks.py | 52 ++--- 4 files changed, 247 insertions(+), 148 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 8131be4f65..e280da5e5a 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -17,10 +17,11 @@ """This module contains Amazon EKS operators.""" from __future__ import annotations +import logging import warnings from ast import literal_eval from datetime import timedelta -from typing import TYPE_CHECKING, Any, List, Sequence, cast +from typing import TYPE_CHECKING, List, Sequence, cast from botocore.exceptions import ClientError, WaiterError @@ -31,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.eks import ( EksCreateFargateProfileTrigger, EksDeleteFargateProfileTrigger, ) +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait try: from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator @@ -59,6 +61,75 @@ NODEGROUP_FULL_NAME = "Amazon EKS managed node groups" FARGATE_FULL_NAME = "AWS Fargate profiles" +def _create_compute( + compute: str | None, + cluster_name: str, + aws_conn_id: str, + region: str | None, + waiter_delay: int, + waiter_max_attempts: int, + wait_for_completion: bool = False, + nodegroup_name: str | None = None, + nodegroup_role_arn: str | None = None, + create_nodegroup_kwargs: dict | None = None, + fargate_profile_name: str | None = None, + fargate_pod_execution_role_arn: str | None = None, + fargate_selectors: list | None = None, + create_fargate_profile_kwargs: dict | None = None, + subnets: list[str] | None = None, +): + log = logging.getLogger(__name__) + eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region) + if compute == "nodegroup" and nodegroup_name: + + # this is to satisfy mypy + subnets = subnets or [] + create_nodegroup_kwargs = create_nodegroup_kwargs or {} + + eks_hook.create_nodegroup( + clusterName=cluster_name, + nodegroupName=nodegroup_name, + subnets=subnets, + nodeRole=nodegroup_role_arn, + **create_nodegroup_kwargs, + ) + if wait_for_completion: + log.info("Waiting for nodegroup to provision. This will take some time.") + wait( + waiter=eks_hook.conn.get_waiter("nodegroup_active"), + waiter_delay=waiter_delay, + max_attempts=waiter_max_attempts, + args={"clusterName": cluster_name, "nodegroupName": nodegroup_name}, + failure_message="Nodegroup creation failed", + status_message="Nodegroup status is", + status_args=["nodegroup.status"], + ) + elif compute == "fargate" and fargate_profile_name: + + # this is to satisfy mypy + create_fargate_profile_kwargs = create_fargate_profile_kwargs or {} + fargate_selectors = fargate_selectors or [] + + eks_hook.create_fargate_profile( + clusterName=cluster_name, + fargateProfileName=fargate_profile_name, + podExecutionRoleArn=fargate_pod_execution_role_arn, + selectors=fargate_selectors, + **create_fargate_profile_kwargs, + ) + if wait_for_completion: + log.info("Waiting for Fargate profile to provision. This will take some time.") + wait( + waiter=eks_hook.conn.get_waiter("fargate_profile_active"), + waiter_delay=waiter_delay, + max_attempts=waiter_max_attempts, + args={"clusterName": cluster_name, "fargateProfileName": fargate_profile_name}, + failure_message="Fargate profile creation failed", + status_message="Fargate profile status is", + status_args=["fargateProfile.status"], + ) + + class EksCreateClusterOperator(BaseOperator): """ Creates an Amazon EKS Cluster control plane. @@ -112,6 +183,8 @@ class EksCreateClusterOperator(BaseOperator): :param fargate_selectors: The selectors to match for pods to use this AWS Fargate profile. (templated) :param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargateProfile API (templated) + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster status + :param waiter_max_attempts: The maximum number of attempts to check the status of the cluster. """ @@ -137,7 +210,7 @@ class EksCreateClusterOperator(BaseOperator): self, cluster_name: str, cluster_role_arn: str, - resources_vpc_config: dict[str, Any], + resources_vpc_config: dict, compute: str | None = DEFAULT_COMPUTE_TYPE, create_cluster_kwargs: dict | None = None, nodegroup_name: str = DEFAULT_NODEGROUP_NAME, @@ -150,6 +223,8 @@ class EksCreateClusterOperator(BaseOperator): wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 40, **kwargs, ) -> None: self.compute = compute @@ -157,17 +232,21 @@ class EksCreateClusterOperator(BaseOperator): self.cluster_role_arn = cluster_role_arn self.resources_vpc_config = resources_vpc_config self.create_cluster_kwargs = create_cluster_kwargs or {} - self.nodegroup_name = nodegroup_name self.nodegroup_role_arn = nodegroup_role_arn - self.create_nodegroup_kwargs = create_nodegroup_kwargs or {} - self.fargate_profile_name = fargate_profile_name self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn - self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}] self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {} self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id self.region = region - super().__init__(**kwargs) + self.nodegroup_name = nodegroup_name + self.create_nodegroup_kwargs = create_nodegroup_kwargs or {} + self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}] + self.fargate_profile_name = fargate_profile_name + super().__init__( + **kwargs, + ) def execute(self, context: Context): if self.compute: @@ -183,13 +262,8 @@ class EksCreateClusterOperator(BaseOperator): compute=FARGATE_FULL_NAME, requirement="fargate_pod_execution_role_arn" ) ) - - eks_hook = EksHook( - aws_conn_id=self.aws_conn_id, - region_name=self.region, - ) - - eks_hook.create_cluster( + self.eks_hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + self.eks_hook.create_cluster( name=self.cluster_name, roleArn=self.cluster_role_arn, resourcesVpcConfig=self.resources_vpc_config, @@ -202,44 +276,38 @@ class EksCreateClusterOperator(BaseOperator): return None self.log.info("Waiting for EKS Cluster to provision. This will take some time.") - client = eks_hook.conn + client = self.eks_hook.conn try: - client.get_waiter("cluster_active").wait(name=self.cluster_name) + client.get_waiter("cluster_active").wait( + name=self.cluster_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) except (ClientError, WaiterError) as e: self.log.error("Cluster failed to start and will be torn down.\n %s", e) - eks_hook.delete_cluster(name=self.cluster_name) - client.get_waiter("cluster_deleted").wait(name=self.cluster_name) - raise - - if self.compute == "nodegroup": - eks_hook.create_nodegroup( - clusterName=self.cluster_name, - nodegroupName=self.nodegroup_name, - subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")), - nodeRole=self.nodegroup_role_arn, - **self.create_nodegroup_kwargs, - ) - if self.wait_for_completion: - self.log.info("Waiting for nodegroup to provision. This will take some time.") - client.get_waiter("nodegroup_active").wait( - clusterName=self.cluster_name, - nodegroupName=self.nodegroup_name, - ) - elif self.compute == "fargate": - eks_hook.create_fargate_profile( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - podExecutionRoleArn=self.fargate_pod_execution_role_arn, - selectors=self.fargate_selectors, - **self.create_fargate_profile_kwargs, + self.eks_hook.delete_cluster(name=self.cluster_name) + client.get_waiter("cluster_deleted").wait( + name=self.cluster_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) - if self.wait_for_completion: - self.log.info("Waiting for Fargate profile to provision. This will take some time.") - client.get_waiter("fargate_profile_active").wait( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - ) + raise + _create_compute( + compute=self.compute, + cluster_name=self.cluster_name, + aws_conn_id=self.aws_conn_id, + region=self.region, + wait_for_completion=self.wait_for_completion, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + nodegroup_name=self.nodegroup_name, + nodegroup_role_arn=self.nodegroup_role_arn, + create_nodegroup_kwargs=self.create_nodegroup_kwargs, + fargate_profile_name=self.fargate_profile_name, + fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn, + fargate_selectors=self.fargate_selectors, + create_fargate_profile_kwargs=self.create_fargate_profile_kwargs, + subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")), + ) class EksCreateNodegroupOperator(BaseOperator): @@ -265,6 +333,8 @@ class EksCreateNodegroupOperator(BaseOperator): maintained on each worker node). :param region: Which AWS region the connection should use. (templated) If this is None or empty then the default boto3 behaviour is used. + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup status + :param waiter_max_attempts: The maximum number of attempts to check the status of the nodegroup. """ @@ -289,8 +359,12 @@ class EksCreateNodegroupOperator(BaseOperator): wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 80, **kwargs, ) -> None: + self.nodegroup_subnets = nodegroup_subnets + self.compute = "nodegroup" self.cluster_name = cluster_name self.nodegroup_role_arn = nodegroup_role_arn self.nodegroup_name = nodegroup_name @@ -298,10 +372,15 @@ class EksCreateNodegroupOperator(BaseOperator): self.wait_for_completion = wait_for_completion self.aws_conn_id = aws_conn_id self.region = region - self.nodegroup_subnets = nodegroup_subnets - super().__init__(**kwargs) + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + super().__init__( + **kwargs, + ) def execute(self, context: Context): + self.log.info(self.task_id) if isinstance(self.nodegroup_subnets, str): nodegroup_subnets_list: list[str] = [] if self.nodegroup_subnets != "": @@ -314,25 +393,20 @@ class EksCreateNodegroupOperator(BaseOperator): self.nodegroup_subnets, ) self.nodegroup_subnets = nodegroup_subnets_list - - eks_hook = EksHook( + _create_compute( + compute=self.compute, + cluster_name=self.cluster_name, aws_conn_id=self.aws_conn_id, - region_name=self.region, - ) - eks_hook.create_nodegroup( - clusterName=self.cluster_name, - nodegroupName=self.nodegroup_name, + region=self.region, + wait_for_completion=self.wait_for_completion, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + nodegroup_name=self.nodegroup_name, + nodegroup_role_arn=self.nodegroup_role_arn, + create_nodegroup_kwargs=self.create_nodegroup_kwargs, subnets=self.nodegroup_subnets, - nodeRole=self.nodegroup_role_arn, - **self.create_nodegroup_kwargs, ) - if self.wait_for_completion: - self.log.info("Waiting for nodegroup to provision. This will take some time.") - eks_hook.conn.get_waiter("nodegroup_active").wait( - clusterName=self.cluster_name, nodegroupName=self.nodegroup_name - ) - class EksCreateFargateProfileOperator(BaseOperator): """ @@ -392,30 +466,34 @@ class EksCreateFargateProfileOperator(BaseOperator): **kwargs, ) -> None: self.cluster_name = cluster_name - self.pod_execution_role_arn = pod_execution_role_arn self.selectors = selectors + self.pod_execution_role_arn = pod_execution_role_arn self.fargate_profile_name = fargate_profile_name self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {} - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion self.aws_conn_id = aws_conn_id self.region = region self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable - super().__init__(**kwargs) + self.compute = "fargate" + super().__init__( + **kwargs, + ) def execute(self, context: Context): - eks_hook = EksHook( + _create_compute( + compute=self.compute, + cluster_name=self.cluster_name, aws_conn_id=self.aws_conn_id, - region_name=self.region, - ) - - eks_hook.create_fargate_profile( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - podExecutionRoleArn=self.pod_execution_role_arn, - selectors=self.selectors, - **self.create_fargate_profile_kwargs, + region=self.region, + wait_for_completion=self.wait_for_completion, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + fargate_profile_name=self.fargate_profile_name, + fargate_pod_execution_role_arn=self.pod_execution_role_arn, + fargate_selectors=self.selectors, + create_fargate_profile_kwargs=self.create_fargate_profile_kwargs, ) if self.deferrable: self.defer( @@ -423,21 +501,15 @@ class EksCreateFargateProfileOperator(BaseOperator): cluster_name=self.cluster_name, fargate_profile_name=self.fargate_profile_name, aws_conn_id=self.aws_conn_id, - poll_interval=self.waiter_delay, - max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + region=self.region, ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)), ) - elif self.wait_for_completion: - self.log.info("Waiting for Fargate profile to provision. This will take some time.") - eks_hook.conn.get_waiter("fargate_profile_active").wait( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, - ) def execute_complete(self, context, event=None): if event["status"] != "success": @@ -677,8 +749,9 @@ class EksDeleteFargateProfileOperator(BaseOperator): cluster_name=self.cluster_name, fargate_profile_name=self.fargate_profile_name, aws_conn_id=self.aws_conn_id, - poll_interval=self.waiter_delay, - max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + region=self.region, ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py index dddab74b30..8ccd88167c 100644 --- a/airflow/providers/amazon/aws/triggers/eks.py +++ b/airflow/providers/amazon/aws/triggers/eks.py @@ -33,8 +33,8 @@ class EksCreateFargateProfileTrigger(BaseTrigger): :param cluster_name: The name of the EKS cluster :param fargate_profile_name: The name of the fargate profile - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -42,15 +42,17 @@ class EksCreateFargateProfileTrigger(BaseTrigger): self, cluster_name: str, fargate_profile_name: str, - poll_interval: int, - max_attempts: int, + waiter_delay: int, + waiter_max_attempts: int, aws_conn_id: str, + region: str | None = None, ): self.cluster_name = cluster_name self.fargate_profile_name = fargate_profile_name - self.poll_interval = poll_interval - self.max_attempts = max_attempts + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id + self.region = region def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -58,24 +60,25 @@ class EksCreateFargateProfileTrigger(BaseTrigger): { "cluster_name": self.cluster_name, "fargate_profile_name": self.fargate_profile_name, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), + "waiter_delay": str(self.waiter_delay), + "waiter_max_attempts": str(self.waiter_max_attempts), "aws_conn_id": self.aws_conn_id, + "region": self.region, }, ) async def run(self): - self.hook = EksHook(aws_conn_id=self.aws_conn_id) + self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) async with self.hook.async_conn as client: attempt = 0 waiter = client.get_waiter("fargate_profile_active") - while attempt < int(self.max_attempts): + while attempt < int(self.waiter_max_attempts): attempt += 1 try: await waiter.wait( clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name, - WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1}, + WaiterConfig={"Delay": int(self.waiter_delay), "MaxAttempts": 1}, ) break except WaiterError as error: @@ -84,10 +87,10 @@ class EksCreateFargateProfileTrigger(BaseTrigger): self.log.info( "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): + await asyncio.sleep(int(self.waiter_delay)) + if attempt >= int(self.waiter_max_attempts): raise AirflowException( - f"Create Fargate Profile failed - max attempts reached: {self.max_attempts}" + f"Create Fargate Profile failed - max attempts reached: {self.waiter_max_attempts}" ) else: yield TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) @@ -100,8 +103,8 @@ class EksDeleteFargateProfileTrigger(BaseTrigger): :param cluster_name: The name of the EKS cluster :param fargate_profile_name: The name of the fargate profile - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -109,15 +112,17 @@ class EksDeleteFargateProfileTrigger(BaseTrigger): self, cluster_name: str, fargate_profile_name: str, - poll_interval: int, - max_attempts: int, + waiter_delay: int, + waiter_max_attempts: int, aws_conn_id: str, + region: str | None = None, ): self.cluster_name = cluster_name self.fargate_profile_name = fargate_profile_name - self.poll_interval = poll_interval - self.max_attempts = max_attempts + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id + self.region = region def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -125,24 +130,25 @@ class EksDeleteFargateProfileTrigger(BaseTrigger): { "cluster_name": self.cluster_name, "fargate_profile_name": self.fargate_profile_name, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), + "waiter_delay": str(self.waiter_delay), + "waiter_max_attempts": str(self.waiter_max_attempts), "aws_conn_id": self.aws_conn_id, + "region": self.region, }, ) async def run(self): - self.hook = EksHook(aws_conn_id=self.aws_conn_id) + self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) async with self.hook.async_conn as client: attempt = 0 waiter = client.get_waiter("fargate_profile_deleted") - while attempt < int(self.max_attempts): + while attempt < int(self.waiter_max_attempts): attempt += 1 try: await waiter.wait( clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name, - WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1}, + WaiterConfig={"Delay": int(self.waiter_delay), "MaxAttempts": 1}, ) break except WaiterError as error: @@ -151,10 +157,10 @@ class EksDeleteFargateProfileTrigger(BaseTrigger): self.log.info( "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): + await asyncio.sleep(int(self.waiter_delay)) + if attempt >= int(self.waiter_max_attempts): raise AirflowException( - f"Delete Fargate Profile failed - max attempts reached: {self.max_attempts}" + f"Delete Fargate Profile failed - max attempts reached: {self.waiter_max_attempts}" ) else: yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 089aef1704..311aad972d 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -200,7 +200,11 @@ class TestEksCreateClusterOperator: operator.execute({}) mock_create_cluster.assert_called_with(**convert_keys(parameters)) mock_create_nodegroup.assert_not_called() - mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME) + mock_waiter.assert_called_once_with( + mock.ANY, + name=CLUSTER_NAME, + WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY}, + ) assert_expected_waiter_type(mock_waiter, "ClusterActive") @mock.patch.object(Waiter, "wait") @@ -216,7 +220,11 @@ class TestEksCreateClusterOperator: mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params)) mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params)) - mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME) + mock_waiter.assert_called_once_with( + mock.ANY, + name=CLUSTER_NAME, + WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY}, + ) assert_expected_waiter_type(mock_waiter, "ClusterActive") @mock.patch.object(Waiter, "wait") @@ -235,7 +243,12 @@ class TestEksCreateClusterOperator: mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params)) # Calls waiter once for the cluster and once for the nodegroup. assert mock_waiter.call_count == 2 - mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME) + mock_waiter.assert_called_with( + mock.ANY, + clusterName=CLUSTER_NAME, + nodegroupName=NODEGROUP_NAME, + WaiterConfig={"MaxAttempts": mock.ANY}, + ) assert_expected_waiter_type(mock_waiter, "NodegroupActive") @mock.patch.object(Waiter, "wait") @@ -253,7 +266,11 @@ class TestEksCreateClusterOperator: mock_create_fargate_profile.assert_called_once_with( **convert_keys(self.create_fargate_profile_params) ) - mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME) + mock_waiter.assert_called_once_with( + mock.ANY, + name=CLUSTER_NAME, + WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY}, + ) assert_expected_waiter_type(mock_waiter, "ClusterActive") @mock.patch.object(Waiter, "wait") @@ -275,7 +292,10 @@ class TestEksCreateClusterOperator: # Calls waiter once for the cluster and once for the nodegroup. assert mock_waiter.call_count == 2 mock_waiter.assert_called_with( - mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME + mock.ANY, + clusterName=CLUSTER_NAME, + fargateProfileName=FARGATE_PROFILE_NAME, + WaiterConfig={"MaxAttempts": mock.ANY}, ) assert_expected_waiter_type(mock_waiter, "FargateProfileActive") @@ -377,7 +397,7 @@ class TestEksCreateFargateProfileOperator: mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME, - WaiterConfig={"Delay": 10, "MaxAttempts": 60}, + WaiterConfig={"MaxAttempts": mock.ANY}, ) assert_expected_waiter_type(mock_waiter, "FargateProfileActive") diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py index abab121d24..dbc71e7296 100644 --- a/tests/providers/amazon/aws/triggers/test_eks.py +++ b/tests/providers/amazon/aws/triggers/test_eks.py @@ -32,8 +32,8 @@ from airflow.triggers.base import TriggerEvent TEST_CLUSTER_IDENTIFIER = "test-cluster" TEST_FARGATE_PROFILE_NAME = "test-fargate-profile" -TEST_POLL_INTERVAL = 10 -TEST_MAX_ATTEMPTS = 10 +TEST_WAITER_DELAY = 10 +TEST_WAITER_MAX_ATTEMPTS = 10 TEST_AWS_CONN_ID = "test-aws-id" @@ -43,8 +43,8 @@ class TestEksCreateFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) class_path, args = eks_create_fargate_profile_trigger.serialize() @@ -52,8 +52,8 @@ class TestEksCreateFargateProfileTrigger: assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + assert args["waiter_delay"] == str(TEST_WAITER_DELAY) + assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) @pytest.mark.asyncio @mock.patch.object(EksHook, "async_conn") @@ -67,8 +67,8 @@ class TestEksCreateFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) generator = eks_create_fargate_profile_trigger.run() @@ -96,8 +96,8 @@ class TestEksCreateFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) generator = eks_create_fargate_profile_trigger.run() @@ -126,8 +126,8 @@ class TestEksCreateFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=2, ) with pytest.raises(AirflowException) as exc: generator = eks_create_fargate_profile_trigger.run() @@ -158,8 +158,8 @@ class TestEksCreateFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) with pytest.raises(AirflowException) as exc: @@ -175,8 +175,8 @@ class TestEksDeleteFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) class_path, args = eks_delete_fargate_profile_trigger.serialize() @@ -184,8 +184,8 @@ class TestEksDeleteFargateProfileTrigger: assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + assert args["waiter_delay"] == str(TEST_WAITER_DELAY) + assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) @pytest.mark.asyncio @mock.patch.object(EksHook, "async_conn") @@ -199,8 +199,8 @@ class TestEksDeleteFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) generator = eks_delete_fargate_profile_trigger.run() @@ -228,8 +228,8 @@ class TestEksDeleteFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) generator = eks_delete_fargate_profile_trigger.run() @@ -257,8 +257,8 @@ class TestEksDeleteFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=2, ) with pytest.raises(AirflowException) as exc: generator = eks_delete_fargate_profile_trigger.run() @@ -289,8 +289,8 @@ class TestEksDeleteFargateProfileTrigger: cluster_name=TEST_CLUSTER_IDENTIFIER, fargate_profile_name=TEST_FARGATE_PROFILE_NAME, aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, ) with pytest.raises(AirflowException) as exc: generator = eks_delete_fargate_profile_trigger.run()