This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c887988b0 Refactor Eks Create Cluster Operator code (#31960)
5c887988b0 is described below

commit 5c887988b02b02e60f693c9341013592a291ee27
Author: Syed Hussaain <103602455+syeda...@users.noreply.github.com>
AuthorDate: Fri Jun 23 14:18:13 2023 -0700

    Refactor Eks Create Cluster Operator code (#31960)
    
    * Refactor EksCreateClusterOperator to reuse the compute-creation code shared with the nodegroup and Fargate profile operators
    
    * Update the _create_compute helper function so that the tests pass
    Add waiter params to EksCreateClusterOperator and EksCreateNodegroupOperator
    Update EksCreateFargateProfileTrigger and EksDeleteFargateProfileTrigger to use consistent waiter parameter names (waiter_delay, waiter_max_attempts instead of poll_interval, max_attempts)
    Update unit tests for triggers and operators
---
 airflow/providers/amazon/aws/operators/eks.py    | 249 +++++++++++++++--------
 airflow/providers/amazon/aws/triggers/eks.py     |  62 +++---
 tests/providers/amazon/aws/operators/test_eks.py |  32 ++-
 tests/providers/amazon/aws/triggers/test_eks.py  |  52 ++---
 4 files changed, 247 insertions(+), 148 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 8131be4f65..e280da5e5a 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -17,10 +17,11 @@
 """This module contains Amazon EKS operators."""
 from __future__ import annotations
 
+import logging
 import warnings
 from ast import literal_eval
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, List, Sequence, cast
+from typing import TYPE_CHECKING, List, Sequence, cast
 
 from botocore.exceptions import ClientError, WaiterError
 
@@ -31,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
     EksDeleteFargateProfileTrigger,
 )
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 try:
     from airflow.providers.cncf.kubernetes.operators.pod import 
KubernetesPodOperator
@@ -59,6 +61,75 @@ NODEGROUP_FULL_NAME = "Amazon EKS managed node groups"
 FARGATE_FULL_NAME = "AWS Fargate profiles"
 
 
+def _create_compute(
+    compute: str | None,
+    cluster_name: str,
+    aws_conn_id: str,
+    region: str | None,
+    waiter_delay: int,
+    waiter_max_attempts: int,
+    wait_for_completion: bool = False,
+    nodegroup_name: str | None = None,
+    nodegroup_role_arn: str | None = None,
+    create_nodegroup_kwargs: dict | None = None,
+    fargate_profile_name: str | None = None,
+    fargate_pod_execution_role_arn: str | None = None,
+    fargate_selectors: list | None = None,
+    create_fargate_profile_kwargs: dict | None = None,
+    subnets: list[str] | None = None,
+):
+    log = logging.getLogger(__name__)
+    eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region)
+    if compute == "nodegroup" and nodegroup_name:
+
+        # this is to satisfy mypy
+        subnets = subnets or []
+        create_nodegroup_kwargs = create_nodegroup_kwargs or {}
+
+        eks_hook.create_nodegroup(
+            clusterName=cluster_name,
+            nodegroupName=nodegroup_name,
+            subnets=subnets,
+            nodeRole=nodegroup_role_arn,
+            **create_nodegroup_kwargs,
+        )
+        if wait_for_completion:
+            log.info("Waiting for nodegroup to provision.  This will take some 
time.")
+            wait(
+                waiter=eks_hook.conn.get_waiter("nodegroup_active"),
+                waiter_delay=waiter_delay,
+                max_attempts=waiter_max_attempts,
+                args={"clusterName": cluster_name, "nodegroupName": 
nodegroup_name},
+                failure_message="Nodegroup creation failed",
+                status_message="Nodegroup status is",
+                status_args=["nodegroup.status"],
+            )
+    elif compute == "fargate" and fargate_profile_name:
+
+        # this is to satisfy mypy
+        create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
+        fargate_selectors = fargate_selectors or []
+
+        eks_hook.create_fargate_profile(
+            clusterName=cluster_name,
+            fargateProfileName=fargate_profile_name,
+            podExecutionRoleArn=fargate_pod_execution_role_arn,
+            selectors=fargate_selectors,
+            **create_fargate_profile_kwargs,
+        )
+        if wait_for_completion:
+            log.info("Waiting for Fargate profile to provision.  This will 
take some time.")
+            wait(
+                waiter=eks_hook.conn.get_waiter("fargate_profile_active"),
+                waiter_delay=waiter_delay,
+                max_attempts=waiter_max_attempts,
+                args={"clusterName": cluster_name, "fargateProfileName": 
fargate_profile_name},
+                failure_message="Fargate profile creation failed",
+                status_message="Fargate profile status is",
+                status_args=["fargateProfile.status"],
+            )
+
+
 class EksCreateClusterOperator(BaseOperator):
     """
     Creates an Amazon EKS Cluster control plane.
@@ -112,6 +183,8 @@ class EksCreateClusterOperator(BaseOperator):
     :param fargate_selectors: The selectors to match for pods to use this AWS 
Fargate profile. (templated)
     :param create_fargate_profile_kwargs: Optional parameters to pass to the 
CreateFargateProfile API
          (templated)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive 
calls to check cluster status
+    :param waiter_max_attempts: The maximum number of attempts to check the 
status of the cluster.
 
     """
 
@@ -137,7 +210,7 @@ class EksCreateClusterOperator(BaseOperator):
         self,
         cluster_name: str,
         cluster_role_arn: str,
-        resources_vpc_config: dict[str, Any],
+        resources_vpc_config: dict,
         compute: str | None = DEFAULT_COMPUTE_TYPE,
         create_cluster_kwargs: dict | None = None,
         nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
@@ -150,6 +223,8 @@ class EksCreateClusterOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: str | None = None,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
         **kwargs,
     ) -> None:
         self.compute = compute
@@ -157,17 +232,21 @@ class EksCreateClusterOperator(BaseOperator):
         self.cluster_role_arn = cluster_role_arn
         self.resources_vpc_config = resources_vpc_config
         self.create_cluster_kwargs = create_cluster_kwargs or {}
-        self.nodegroup_name = nodegroup_name
         self.nodegroup_role_arn = nodegroup_role_arn
-        self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
-        self.fargate_profile_name = fargate_profile_name
         self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
-        self.fargate_selectors = fargate_selectors or [{"namespace": 
DEFAULT_NAMESPACE_NAME}]
         self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or 
{}
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id
         self.region = region
-        super().__init__(**kwargs)
+        self.nodegroup_name = nodegroup_name
+        self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
+        self.fargate_selectors = fargate_selectors or [{"namespace": 
DEFAULT_NAMESPACE_NAME}]
+        self.fargate_profile_name = fargate_profile_name
+        super().__init__(
+            **kwargs,
+        )
 
     def execute(self, context: Context):
         if self.compute:
@@ -183,13 +262,8 @@ class EksCreateClusterOperator(BaseOperator):
                         compute=FARGATE_FULL_NAME, 
requirement="fargate_pod_execution_role_arn"
                     )
                 )
-
-        eks_hook = EksHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region,
-        )
-
-        eks_hook.create_cluster(
+        self.eks_hook = EksHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
+        self.eks_hook.create_cluster(
             name=self.cluster_name,
             roleArn=self.cluster_role_arn,
             resourcesVpcConfig=self.resources_vpc_config,
@@ -202,44 +276,38 @@ class EksCreateClusterOperator(BaseOperator):
             return None
 
         self.log.info("Waiting for EKS Cluster to provision.  This will take 
some time.")
-        client = eks_hook.conn
+        client = self.eks_hook.conn
 
         try:
-            client.get_waiter("cluster_active").wait(name=self.cluster_name)
+            client.get_waiter("cluster_active").wait(
+                name=self.cluster_name,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": 
self.waiter_max_attempts},
+            )
         except (ClientError, WaiterError) as e:
             self.log.error("Cluster failed to start and will be torn down.\n 
%s", e)
-            eks_hook.delete_cluster(name=self.cluster_name)
-            client.get_waiter("cluster_deleted").wait(name=self.cluster_name)
-            raise
-
-        if self.compute == "nodegroup":
-            eks_hook.create_nodegroup(
-                clusterName=self.cluster_name,
-                nodegroupName=self.nodegroup_name,
-                subnets=cast(List[str], 
self.resources_vpc_config.get("subnetIds")),
-                nodeRole=self.nodegroup_role_arn,
-                **self.create_nodegroup_kwargs,
-            )
-            if self.wait_for_completion:
-                self.log.info("Waiting for nodegroup to provision.  This will 
take some time.")
-                client.get_waiter("nodegroup_active").wait(
-                    clusterName=self.cluster_name,
-                    nodegroupName=self.nodegroup_name,
-                )
-        elif self.compute == "fargate":
-            eks_hook.create_fargate_profile(
-                clusterName=self.cluster_name,
-                fargateProfileName=self.fargate_profile_name,
-                podExecutionRoleArn=self.fargate_pod_execution_role_arn,
-                selectors=self.fargate_selectors,
-                **self.create_fargate_profile_kwargs,
+            self.eks_hook.delete_cluster(name=self.cluster_name)
+            client.get_waiter("cluster_deleted").wait(
+                name=self.cluster_name,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": 
self.waiter_max_attempts},
             )
-            if self.wait_for_completion:
-                self.log.info("Waiting for Fargate profile to provision.  This 
will take some time.")
-                client.get_waiter("fargate_profile_active").wait(
-                    clusterName=self.cluster_name,
-                    fargateProfileName=self.fargate_profile_name,
-                )
+            raise
+        _create_compute(
+            compute=self.compute,
+            cluster_name=self.cluster_name,
+            aws_conn_id=self.aws_conn_id,
+            region=self.region,
+            wait_for_completion=self.wait_for_completion,
+            waiter_delay=self.waiter_delay,
+            waiter_max_attempts=self.waiter_max_attempts,
+            nodegroup_name=self.nodegroup_name,
+            nodegroup_role_arn=self.nodegroup_role_arn,
+            create_nodegroup_kwargs=self.create_nodegroup_kwargs,
+            fargate_profile_name=self.fargate_profile_name,
+            fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
+            fargate_selectors=self.fargate_selectors,
+            create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
+            subnets=cast(List[str], 
self.resources_vpc_config.get("subnetIds")),
+        )
 
 
 class EksCreateNodegroupOperator(BaseOperator):
@@ -265,6 +333,8 @@ class EksCreateNodegroupOperator(BaseOperator):
          maintained on each worker node).
     :param region: Which AWS region the connection should use. (templated)
         If this is None or empty then the default boto3 behaviour is used.
+    :param waiter_delay: Time (in seconds) to wait between two consecutive 
calls to check nodegroup status
+    :param waiter_max_attempts: The maximum number of attempts to check the 
status of the nodegroup.
 
     """
 
@@ -289,8 +359,12 @@ class EksCreateNodegroupOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: str | None = None,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 80,
         **kwargs,
     ) -> None:
+        self.nodegroup_subnets = nodegroup_subnets
+        self.compute = "nodegroup"
         self.cluster_name = cluster_name
         self.nodegroup_role_arn = nodegroup_role_arn
         self.nodegroup_name = nodegroup_name
@@ -298,10 +372,15 @@ class EksCreateNodegroupOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.aws_conn_id = aws_conn_id
         self.region = region
-        self.nodegroup_subnets = nodegroup_subnets
-        super().__init__(**kwargs)
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+        super().__init__(
+            **kwargs,
+        )
 
     def execute(self, context: Context):
+        self.log.info(self.task_id)
         if isinstance(self.nodegroup_subnets, str):
             nodegroup_subnets_list: list[str] = []
             if self.nodegroup_subnets != "":
@@ -314,25 +393,20 @@ class EksCreateNodegroupOperator(BaseOperator):
                         self.nodegroup_subnets,
                     )
             self.nodegroup_subnets = nodegroup_subnets_list
-
-        eks_hook = EksHook(
+        _create_compute(
+            compute=self.compute,
+            cluster_name=self.cluster_name,
             aws_conn_id=self.aws_conn_id,
-            region_name=self.region,
-        )
-        eks_hook.create_nodegroup(
-            clusterName=self.cluster_name,
-            nodegroupName=self.nodegroup_name,
+            region=self.region,
+            wait_for_completion=self.wait_for_completion,
+            waiter_delay=self.waiter_delay,
+            waiter_max_attempts=self.waiter_max_attempts,
+            nodegroup_name=self.nodegroup_name,
+            nodegroup_role_arn=self.nodegroup_role_arn,
+            create_nodegroup_kwargs=self.create_nodegroup_kwargs,
             subnets=self.nodegroup_subnets,
-            nodeRole=self.nodegroup_role_arn,
-            **self.create_nodegroup_kwargs,
         )
 
-        if self.wait_for_completion:
-            self.log.info("Waiting for nodegroup to provision.  This will take 
some time.")
-            eks_hook.conn.get_waiter("nodegroup_active").wait(
-                clusterName=self.cluster_name, 
nodegroupName=self.nodegroup_name
-            )
-
 
 class EksCreateFargateProfileOperator(BaseOperator):
     """
@@ -392,30 +466,34 @@ class EksCreateFargateProfileOperator(BaseOperator):
         **kwargs,
     ) -> None:
         self.cluster_name = cluster_name
-        self.pod_execution_role_arn = pod_execution_role_arn
         self.selectors = selectors
+        self.pod_execution_role_arn = pod_execution_role_arn
         self.fargate_profile_name = fargate_profile_name
         self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or 
{}
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
         self.aws_conn_id = aws_conn_id
         self.region = region
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
-        super().__init__(**kwargs)
+        self.compute = "fargate"
+        super().__init__(
+            **kwargs,
+        )
 
     def execute(self, context: Context):
-        eks_hook = EksHook(
+        _create_compute(
+            compute=self.compute,
+            cluster_name=self.cluster_name,
             aws_conn_id=self.aws_conn_id,
-            region_name=self.region,
-        )
-
-        eks_hook.create_fargate_profile(
-            clusterName=self.cluster_name,
-            fargateProfileName=self.fargate_profile_name,
-            podExecutionRoleArn=self.pod_execution_role_arn,
-            selectors=self.selectors,
-            **self.create_fargate_profile_kwargs,
+            region=self.region,
+            wait_for_completion=self.wait_for_completion,
+            waiter_delay=self.waiter_delay,
+            waiter_max_attempts=self.waiter_max_attempts,
+            fargate_profile_name=self.fargate_profile_name,
+            fargate_pod_execution_role_arn=self.pod_execution_role_arn,
+            fargate_selectors=self.selectors,
+            create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
         )
         if self.deferrable:
             self.defer(
@@ -423,21 +501,15 @@ class EksCreateFargateProfileOperator(BaseOperator):
                     cluster_name=self.cluster_name,
                     fargate_profile_name=self.fargate_profile_name,
                     aws_conn_id=self.aws_conn_id,
-                    poll_interval=self.waiter_delay,
-                    max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    region=self.region,
                 ),
                 method_name="execute_complete",
                 # timeout is set to ensure that if a trigger dies, the timeout 
does not restart
                 # 60 seconds is added to allow the trigger to exit gracefully 
(i.e. yield TriggerEvent)
                 timeout=timedelta(seconds=(self.waiter_max_attempts * 
self.waiter_delay + 60)),
             )
-        elif self.wait_for_completion:
-            self.log.info("Waiting for Fargate profile to provision.  This 
will take some time.")
-            eks_hook.conn.get_waiter("fargate_profile_active").wait(
-                clusterName=self.cluster_name,
-                fargateProfileName=self.fargate_profile_name,
-                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": 
self.waiter_max_attempts},
-            )
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -677,8 +749,9 @@ class EksDeleteFargateProfileOperator(BaseOperator):
                     cluster_name=self.cluster_name,
                     fargate_profile_name=self.fargate_profile_name,
                     aws_conn_id=self.aws_conn_id,
-                    poll_interval=self.waiter_delay,
-                    max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    region=self.region,
                 ),
                 method_name="execute_complete",
                 # timeout is set to ensure that if a trigger dies, the timeout 
does not restart
diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py
index dddab74b30..8ccd88167c 100644
--- a/airflow/providers/amazon/aws/triggers/eks.py
+++ b/airflow/providers/amazon/aws/triggers/eks.py
@@ -33,8 +33,8 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
 
     :param cluster_name: The name of the EKS cluster
     :param fargate_profile_name: The name of the fargate profile
-    :param poll_interval: The amount of time in seconds to wait between 
attempts.
-    :param max_attempts: The maximum number of attempts to be made.
+    :param waiter_delay: The amount of time in seconds to wait between 
attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     """
 
@@ -42,15 +42,17 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
         self,
         cluster_name: str,
         fargate_profile_name: str,
-        poll_interval: int,
-        max_attempts: int,
+        waiter_delay: int,
+        waiter_max_attempts: int,
         aws_conn_id: str,
+        region: str | None = None,
     ):
         self.cluster_name = cluster_name
         self.fargate_profile_name = fargate_profile_name
-        self.poll_interval = poll_interval
-        self.max_attempts = max_attempts
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id
+        self.region = region
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -58,24 +60,25 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
             {
                 "cluster_name": self.cluster_name,
                 "fargate_profile_name": self.fargate_profile_name,
-                "poll_interval": str(self.poll_interval),
-                "max_attempts": str(self.max_attempts),
+                "waiter_delay": str(self.waiter_delay),
+                "waiter_max_attempts": str(self.waiter_max_attempts),
                 "aws_conn_id": self.aws_conn_id,
+                "region": self.region,
             },
         )
 
     async def run(self):
-        self.hook = EksHook(aws_conn_id=self.aws_conn_id)
+        self.hook = EksHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
         async with self.hook.async_conn as client:
             attempt = 0
             waiter = client.get_waiter("fargate_profile_active")
-            while attempt < int(self.max_attempts):
+            while attempt < int(self.waiter_max_attempts):
                 attempt += 1
                 try:
                     await waiter.wait(
                         clusterName=self.cluster_name,
                         fargateProfileName=self.fargate_profile_name,
-                        WaiterConfig={"Delay": int(self.poll_interval), 
"MaxAttempts": 1},
+                        WaiterConfig={"Delay": int(self.waiter_delay), 
"MaxAttempts": 1},
                     )
                     break
                 except WaiterError as error:
@@ -84,10 +87,10 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
                     self.log.info(
                         "Status of fargate profile is %s", 
error.last_response["fargateProfile"]["status"]
                     )
-                    await asyncio.sleep(int(self.poll_interval))
-        if attempt >= int(self.max_attempts):
+                    await asyncio.sleep(int(self.waiter_delay))
+        if attempt >= int(self.waiter_max_attempts):
             raise AirflowException(
-                f"Create Fargate Profile failed - max attempts reached: 
{self.max_attempts}"
+                f"Create Fargate Profile failed - max attempts reached: 
{self.waiter_max_attempts}"
             )
         else:
             yield TriggerEvent({"status": "success", "message": "Fargate 
Profile Created"})
@@ -100,8 +103,8 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
 
     :param cluster_name: The name of the EKS cluster
     :param fargate_profile_name: The name of the fargate profile
-    :param poll_interval: The amount of time in seconds to wait between 
attempts.
-    :param max_attempts: The maximum number of attempts to be made.
+    :param waiter_delay: The amount of time in seconds to wait between 
attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     """
 
@@ -109,15 +112,17 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
         self,
         cluster_name: str,
         fargate_profile_name: str,
-        poll_interval: int,
-        max_attempts: int,
+        waiter_delay: int,
+        waiter_max_attempts: int,
         aws_conn_id: str,
+        region: str | None = None,
     ):
         self.cluster_name = cluster_name
         self.fargate_profile_name = fargate_profile_name
-        self.poll_interval = poll_interval
-        self.max_attempts = max_attempts
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id
+        self.region = region
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -125,24 +130,25 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
             {
                 "cluster_name": self.cluster_name,
                 "fargate_profile_name": self.fargate_profile_name,
-                "poll_interval": str(self.poll_interval),
-                "max_attempts": str(self.max_attempts),
+                "waiter_delay": str(self.waiter_delay),
+                "waiter_max_attempts": str(self.waiter_max_attempts),
                 "aws_conn_id": self.aws_conn_id,
+                "region": self.region,
             },
         )
 
     async def run(self):
-        self.hook = EksHook(aws_conn_id=self.aws_conn_id)
+        self.hook = EksHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
         async with self.hook.async_conn as client:
             attempt = 0
             waiter = client.get_waiter("fargate_profile_deleted")
-            while attempt < int(self.max_attempts):
+            while attempt < int(self.waiter_max_attempts):
                 attempt += 1
                 try:
                     await waiter.wait(
                         clusterName=self.cluster_name,
                         fargateProfileName=self.fargate_profile_name,
-                        WaiterConfig={"Delay": int(self.poll_interval), 
"MaxAttempts": 1},
+                        WaiterConfig={"Delay": int(self.waiter_delay), 
"MaxAttempts": 1},
                     )
                     break
                 except WaiterError as error:
@@ -151,10 +157,10 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
                     self.log.info(
                         "Status of fargate profile is %s", 
error.last_response["fargateProfile"]["status"]
                     )
-                    await asyncio.sleep(int(self.poll_interval))
-        if attempt >= int(self.max_attempts):
+                    await asyncio.sleep(int(self.waiter_delay))
+        if attempt >= int(self.waiter_max_attempts):
             raise AirflowException(
-                f"Delete Fargate Profile failed - max attempts reached: 
{self.max_attempts}"
+                f"Delete Fargate Profile failed - max attempts reached: 
{self.waiter_max_attempts}"
             )
         else:
             yield TriggerEvent({"status": "success", "message": "Fargate 
Profile Deleted"})
diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py
index 089aef1704..311aad972d 100644
--- a/tests/providers/amazon/aws/operators/test_eks.py
+++ b/tests/providers/amazon/aws/operators/test_eks.py
@@ -200,7 +200,11 @@ class TestEksCreateClusterOperator:
         operator.execute({})
         mock_create_cluster.assert_called_with(**convert_keys(parameters))
         mock_create_nodegroup.assert_not_called()
-        mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+        mock_waiter.assert_called_once_with(
+            mock.ANY,
+            name=CLUSTER_NAME,
+            WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+        )
         assert_expected_waiter_type(mock_waiter, "ClusterActive")
 
     @mock.patch.object(Waiter, "wait")
@@ -216,7 +220,11 @@ class TestEksCreateClusterOperator:
 
         
mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params))
         
mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
-        mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+        mock_waiter.assert_called_once_with(
+            mock.ANY,
+            name=CLUSTER_NAME,
+            WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+        )
         assert_expected_waiter_type(mock_waiter, "ClusterActive")
 
     @mock.patch.object(Waiter, "wait")
@@ -235,7 +243,12 @@ class TestEksCreateClusterOperator:
         
mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
         # Calls waiter once for the cluster and once for the nodegroup.
         assert mock_waiter.call_count == 2
-        mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, 
nodegroupName=NODEGROUP_NAME)
+        mock_waiter.assert_called_with(
+            mock.ANY,
+            clusterName=CLUSTER_NAME,
+            nodegroupName=NODEGROUP_NAME,
+            WaiterConfig={"MaxAttempts": mock.ANY},
+        )
         assert_expected_waiter_type(mock_waiter, "NodegroupActive")
 
     @mock.patch.object(Waiter, "wait")
@@ -253,7 +266,11 @@ class TestEksCreateClusterOperator:
         mock_create_fargate_profile.assert_called_once_with(
             **convert_keys(self.create_fargate_profile_params)
         )
-        mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+        mock_waiter.assert_called_once_with(
+            mock.ANY,
+            name=CLUSTER_NAME,
+            WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+        )
         assert_expected_waiter_type(mock_waiter, "ClusterActive")
 
     @mock.patch.object(Waiter, "wait")
@@ -275,7 +292,10 @@ class TestEksCreateClusterOperator:
         # Calls waiter once for the cluster and once for the nodegroup.
         assert mock_waiter.call_count == 2
         mock_waiter.assert_called_with(
-            mock.ANY, clusterName=CLUSTER_NAME, 
fargateProfileName=FARGATE_PROFILE_NAME
+            mock.ANY,
+            clusterName=CLUSTER_NAME,
+            fargateProfileName=FARGATE_PROFILE_NAME,
+            WaiterConfig={"MaxAttempts": mock.ANY},
         )
         assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
 
@@ -377,7 +397,7 @@ class TestEksCreateFargateProfileOperator:
             mock.ANY,
             clusterName=CLUSTER_NAME,
             fargateProfileName=FARGATE_PROFILE_NAME,
-            WaiterConfig={"Delay": 10, "MaxAttempts": 60},
+            WaiterConfig={"MaxAttempts": mock.ANY},
         )
         assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
 
diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py
index abab121d24..dbc71e7296 100644
--- a/tests/providers/amazon/aws/triggers/test_eks.py
+++ b/tests/providers/amazon/aws/triggers/test_eks.py
@@ -32,8 +32,8 @@ from airflow.triggers.base import TriggerEvent
 
 TEST_CLUSTER_IDENTIFIER = "test-cluster"
 TEST_FARGATE_PROFILE_NAME = "test-fargate-profile"
-TEST_POLL_INTERVAL = 10
-TEST_MAX_ATTEMPTS = 10
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
 TEST_AWS_CONN_ID = "test-aws-id"
 
 
@@ -43,8 +43,8 @@ class TestEksCreateFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         class_path, args = eks_create_fargate_profile_trigger.serialize()
@@ -52,8 +52,8 @@ class TestEksCreateFargateProfileTrigger:
         assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
         assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME
         assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
-        assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+        assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+        assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
 
     @pytest.mark.asyncio
     @mock.patch.object(EksHook, "async_conn")
@@ -67,8 +67,8 @@ class TestEksCreateFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         generator = eks_create_fargate_profile_trigger.run()
@@ -96,8 +96,8 @@ class TestEksCreateFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         generator = eks_create_fargate_profile_trigger.run()
@@ -126,8 +126,8 @@ class TestEksCreateFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=2,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=2,
         )
         with pytest.raises(AirflowException) as exc:
             generator = eks_create_fargate_profile_trigger.run()
@@ -158,8 +158,8 @@ class TestEksCreateFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         with pytest.raises(AirflowException) as exc:
@@ -175,8 +175,8 @@ class TestEksDeleteFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         class_path, args = eks_delete_fargate_profile_trigger.serialize()
@@ -184,8 +184,8 @@ class TestEksDeleteFargateProfileTrigger:
         assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
         assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME
         assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
-        assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+        assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+        assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
 
     @pytest.mark.asyncio
     @mock.patch.object(EksHook, "async_conn")
@@ -199,8 +199,8 @@ class TestEksDeleteFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         generator = eks_delete_fargate_profile_trigger.run()
@@ -228,8 +228,8 @@ class TestEksDeleteFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
 
         generator = eks_delete_fargate_profile_trigger.run()
@@ -257,8 +257,8 @@ class TestEksDeleteFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=2,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=2,
         )
         with pytest.raises(AirflowException) as exc:
             generator = eks_delete_fargate_profile_trigger.run()
@@ -289,8 +289,8 @@ class TestEksDeleteFargateProfileTrigger:
             cluster_name=TEST_CLUSTER_IDENTIFIER,
             fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
             aws_conn_id=TEST_AWS_CONN_ID,
-            poll_interval=TEST_POLL_INTERVAL,
-            max_attempts=TEST_MAX_ATTEMPTS,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
         )
         with pytest.raises(AirflowException) as exc:
             generator = eks_delete_fargate_profile_trigger.run()


Reply via email to