This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 545e4d505e Extend hooks arguments into `AwsBaseWaiterTrigger` (#34884)
545e4d505e is described below

commit 545e4d505e669473f42a6637f5593d0860dac086
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Thu Oct 12 22:59:45 2023 +0400

    Extend hooks arguments into `AwsBaseWaiterTrigger` (#34884)
    
    * Extend hooks arguments into `AwsBaseWaiterTrigger`
    
    * Use prune dictionary AwsBaseWaiterTrigger
    
    ---------
    Co-authored-by: Vincent Beck <vincb...@amazon.com>
    
    * Add links to boto3 documentation in docstring
    
    * Add super() into the AwsBaseWaiterTrigger
---
 airflow/providers/amazon/aws/triggers/base.py    | 27 ++++++++++++++++--
 tests/providers/amazon/aws/triggers/test_base.py | 36 +++++++++++++++++++++++-
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/triggers/base.py b/airflow/providers/amazon/aws/triggers/base.py
index a6fc6104dd..9b2e8696e4 100644
--- a/airflow/providers/amazon/aws/triggers/base.py
+++ b/airflow/providers/amazon/aws/triggers/base.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
 
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -55,6 +56,11 @@ class AwsBaseWaiterTrigger(BaseTrigger):
     :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
+    :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
+        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client.
+        To be used to build the hook. For available key-values see:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     def __init__(
@@ -72,7 +78,10 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         waiter_max_attempts: int,
         aws_conn_id: str | None,
         region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
     ):
+        super().__init__()
         # parameters that should be hardcoded in the child's implem
         self.serialized_fields = serialized_fields
 
@@ -90,6 +99,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         self.attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         # here we put together the "common" params,
@@ -102,9 +113,19 @@ class AwsBaseWaiterTrigger(BaseTrigger):
             },
             **self.serialized_fields,
         )
-        if self.region_name:
-            # if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
-            params["region_name"] = self.region_name
+
+        # if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
+        params.update(
+            prune_dict(
+                {
+                    # Keep previous behaviour when empty string in region_name evaluated as `None`
+                    "region_name": self.region_name or None,
+                    "verify": self.verify,
+                    "botocore_config": self.botocore_config,
+                }
+            )
+        )
+
         return (
             # remember that self is an instance of the subclass here, not of this class.
             self.__class__.__module__ + "." + self.__class__.__qualname__,
diff --git a/tests/providers/amazon/aws/triggers/test_base.py b/tests/providers/amazon/aws/triggers/test_base.py
index 9334e555b9..e5866596be 100644
--- a/tests/providers/amazon/aws/triggers/test_base.py
+++ b/tests/providers/amazon/aws/triggers/test_base.py
@@ -63,7 +63,41 @@ class TestAwsBaseWaiterTrigger:
         assert "region_name" in args
         assert args["region_name"] == "my_region"
 
-    def test_region_not_serialized_if_omitted(self):
+    @pytest.mark.parametrize("verify", [True, False, pytest.param("/foo/bar.pem", id="path")])
+    def test_verify_serialized(self, verify):
+        self.trigger.verify = verify
+        _, args = self.trigger.serialize()
+
+        assert "verify" in args
+        assert args["verify"] == verify
+
+    @pytest.mark.parametrize(
+        "botocore_config",
+        [
+            pytest.param({"read_timeout": 10, "connect_timeout": 42, "keepalive": True}, id="non-empty-dict"),
+            pytest.param({}, id="empty-dict"),
+        ],
+    )
+    def test_botocore_config_serialized(self, botocore_config):
+        self.trigger.botocore_config = botocore_config
+        _, args = self.trigger.serialize()
+
+        assert "botocore_config" in args
+        assert args["botocore_config"] == botocore_config
+
+    @pytest.mark.parametrize("param_name", ["region_name", "verify", "botocore_config"])
+    def test_hooks_args_not_serialized_if_omitted(self, param_name):
+        _, args = self.trigger.serialize()
+
+        assert param_name not in args
+
+    def test_region_name_not_serialized_if_empty_string(self):
+        """
+        Compatibility with previous behaviour when empty string region name not serialised.
+
+        It would evaluate as None, however empty string it is not valid region name in boto3.
+        """
+        self.trigger.region_name = ""
         _, args = self.trigger.serialize()
 
         assert "region_name" not in args

Reply via email to