This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c4928da40 introduce a method to convert dictionaries to boto-style key-value lists (#28816)
2c4928da40 is described below

commit 2c4928da40667cd4d52030b8b79419175948cb85
Author: Raphaël Vandon <114772123+vandonr-...@users.noreply.github.com>
AuthorDate: Tue Jan 24 15:45:16 2023 -0800

    introduce a method to convert dictionaries to boto-style key-value lists (#28816)
    
    * accept either dict or list for tags
---
 airflow/providers/amazon/aws/hooks/s3.py           | 28 ++++++++++------
 airflow/providers/amazon/aws/hooks/sagemaker.py    |  5 ++-
 airflow/providers/amazon/aws/operators/rds.py      | 32 ++++++++++--------
 airflow/providers/amazon/aws/operators/s3.py       |  4 +--
 .../providers/amazon/aws/operators/sagemaker.py    |  4 +--
 airflow/providers/amazon/aws/utils/tags.py         | 38 ++++++++++++++++++++++
 tests/providers/amazon/aws/hooks/test_s3.py        |  9 +++++
 7 files changed, 89 insertions(+), 31 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 89c9261cb6..f88274747d 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -43,6 +43,7 @@ from botocore.exceptions import ClientError
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import chunks
 
 T = TypeVar("T", bound=Callable)
@@ -1063,36 +1064,43 @@ class S3Hook(AwsBaseHook):
     @provide_bucket_name
     def put_bucket_tagging(
         self,
-        tag_set: list[dict[str, str]] | None = None,
+        tag_set: dict[str, str] | list[dict[str, str]] | None = None,
         key: str | None = None,
         value: str | None = None,
         bucket_name: str | None = None,
     ) -> None:
         """
-        Overwrites the existing TagSet with provided tags.  Must provide either a TagSet or a key/value pair.
+        Overwrites the existing TagSet with provided tags.
+        Must provide a TagSet, a key/value pair, or both.
 
         .. seealso::
             - :external+boto3:py:meth:`S3.Client.put_bucket_tagging`
 
-        :param tag_set: A List containing the key/value pairs for the tags.
+        :param tag_set: A dictionary containing the key/value pairs for the tags,
+            or a list already formatted for the API
         :param key: The Key for the new TagSet entry.
         :param value: The Value for the new TagSet entry.
         :param bucket_name: The name of the bucket.
+
         :return: None
         """
-        self.log.info("S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, 
value, tag_set)
-        if not tag_set:
-            tag_set = []
+        formatted_tags = format_tags(tag_set)
+
         if key and value:
-            tag_set.append({"Key": key, "Value": value})
-        elif not tag_set or (key or value):
-            message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair."
+            formatted_tags.append({"Key": key, "Value": value})
+        elif key or value:
+            message = (
+                "Key and Value must be specified as a pair. "
+                f"Only one of the two had a value (key: '{key}', value: 
'{value}')"
+            )
             self.log.error(message)
             raise ValueError(message)
 
+        self.log.info("Tagging S3 Bucket %s with %s", bucket_name, 
formatted_tags)
+
         try:
             s3_client = self.get_conn()
-            s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": tag_set})
+            s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": formatted_tags})
         except ClientError as e:
             self.log.error(e)
             raise e
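
For context, a minimal usage sketch of the updated hook (the bucket name and tags
are illustrative, and a configured default AWS connection is assumed):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook()
    # Dict form, converted internally by format_tags to [{"Key": "Color", "Value": "Green"}]
    hook.put_bucket_tagging(bucket_name="my-bucket", tag_set={"Color": "Green"})
    # The pre-existing list form and the key/value pair form keep working unchanged.
    hook.put_bucket_tagging(bucket_name="my-bucket", tag_set=[{"Key": "Color", "Value": "Green"}])
    hook.put_bucket_tagging(bucket_name="my-bucket", key="Color", value="Green")
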
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index c5aeb3d9ed..4c731f2051 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -35,6 +35,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils import timezone
 
 
@@ -1100,9 +1101,7 @@ class SageMakerHook(AwsBaseHook):
 
         :return: the ARN of the pipeline execution launched.
         """
-        if pipeline_params is None:
-            pipeline_params = {}
-        formatted_params = [{"Name": kvp[0], "Value": kvp[1]} for kvp in pipeline_params.items()]
+        formatted_params = format_tags(pipeline_params, key_label="Name")
 
         try:
             res = self.conn.start_pipeline_execution(
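
The same helper covers the SageMaker pipeline parameters by overriding the key
label; a small sketch with an illustrative parameter name and value:

    from airflow.providers.amazon.aws.utils.tags import format_tags

    # Reproduces the list comprehension removed above.
    format_tags({"InputPath": "s3://bucket/input"}, key_label="Name")
    # -> [{"Name": "InputPath", "Value": "s3://bucket/input"}]
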
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index c10e969c8b..2f2cf58438 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -25,6 +25,7 @@ from mypy_boto3_rds.type_defs import TagTypeDef
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
+from airflow.providers.amazon.aws.utils.tags import format_tags
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -64,7 +65,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_identifier: The identifier of the instance or cluster that you want to create the snapshot of
     :param db_snapshot_identifier: The identifier for the DB snapshot
-    :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
+    :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion:  If True, waits for creation of the DB snapshot to complete. (default: True)
     """
@@ -77,7 +78,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         db_type: str,
         db_identifier: str,
         db_snapshot_identifier: str,
-        tags: Sequence[TagTypeDef] | None = None,
+        tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
         aws_conn_id: str = "aws_conn_id",
         **kwargs,
@@ -86,7 +87,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         self.db_type = RdsDbType(db_type)
         self.db_identifier = db_identifier
         self.db_snapshot_identifier = db_snapshot_identifier
-        self.tags = tags or []
+        self.tags = tags
         self.wait_for_completion = wait_for_completion
 
     def execute(self, context: Context) -> str:
@@ -97,11 +98,12 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
             self.db_snapshot_identifier,
         )
 
+        formatted_tags = format_tags(self.tags)
         if self.db_type.value == "instance":
             create_instance_snap = self.hook.conn.create_db_snapshot(
                 DBInstanceIdentifier=self.db_identifier,
                 DBSnapshotIdentifier=self.db_snapshot_identifier,
-                Tags=self.tags,
+                Tags=formatted_tags,
             )
             create_response = json.dumps(create_instance_snap, default=str)
             if self.wait_for_completion:
@@ -110,7 +112,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
             create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
                 DBClusterIdentifier=self.db_identifier,
                 DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
-                Tags=self.tags,
+                Tags=formatted_tags,
             )
             create_response = json.dumps(create_cluster_snap, default=str)
             if self.wait_for_completion:
@@ -132,7 +134,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
     :param source_db_snapshot_identifier: The identifier of the source snapshot
     :param target_db_snapshot_identifier: The identifier of the target snapshot
     :param kms_key_id: The AWS KMS key identifier for an encrypted DB snapshot
-    :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
+    :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param copy_tags: Whether to copy all tags from the source snapshot to the target snapshot (default False)
     :param pre_signed_url: The URL that contains a Signature Version 4 signed request
@@ -159,7 +161,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         source_db_snapshot_identifier: str,
         target_db_snapshot_identifier: str,
         kms_key_id: str = "",
-        tags: Sequence[TagTypeDef] | None = None,
+        tags: Sequence[TagTypeDef] | dict | None = None,
         copy_tags: bool = False,
         pre_signed_url: str = "",
         option_group_name: str = "",
@@ -175,7 +177,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         self.source_db_snapshot_identifier = source_db_snapshot_identifier
         self.target_db_snapshot_identifier = target_db_snapshot_identifier
         self.kms_key_id = kms_key_id
-        self.tags = tags or []
+        self.tags = tags
         self.copy_tags = copy_tags
         self.pre_signed_url = pre_signed_url
         self.option_group_name = option_group_name
@@ -190,12 +192,13 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
             self.target_db_snapshot_identifier,
         )
 
+        formatted_tags = format_tags(self.tags)
         if self.db_type.value == "instance":
             copy_instance_snap = self.hook.conn.copy_db_snapshot(
                 SourceDBSnapshotIdentifier=self.source_db_snapshot_identifier,
                 TargetDBSnapshotIdentifier=self.target_db_snapshot_identifier,
                 KmsKeyId=self.kms_key_id,
-                Tags=self.tags,
+                Tags=formatted_tags,
                 CopyTags=self.copy_tags,
                 PreSignedUrl=self.pre_signed_url,
                 OptionGroupName=self.option_group_name,
@@ -212,7 +215,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
                 SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
                 TargetDBClusterSnapshotIdentifier=self.target_db_snapshot_identifier,
                 KmsKeyId=self.kms_key_id,
-                Tags=self.tags,
+                Tags=formatted_tags,
                 CopyTags=self.copy_tags,
                 PreSignedUrl=self.pre_signed_url,
                 SourceRegion=self.source_region,
@@ -403,7 +406,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         `USER Events <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Events.Messages.html>`__
     :param source_ids: The list of identifiers of the event sources for which events are returned
     :param enabled: A value that indicates whether to activate the subscription (default True)l
-    :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
+    :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion:  If True, waits for creation of the subscription to complete. (default: True)
     """
@@ -426,7 +429,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         event_categories: Sequence[str] | None = None,
         source_ids: Sequence[str] | None = None,
         enabled: bool = True,
-        tags: Sequence[TagTypeDef] | None = None,
+        tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
@@ -439,12 +442,13 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         self.event_categories = event_categories or []
         self.source_ids = source_ids or []
         self.enabled = enabled
-        self.tags = tags or []
+        self.tags = tags
         self.wait_for_completion = wait_for_completion
 
     def execute(self, context: Context) -> str:
         self.log.info("Creating event subscription '%s' to '%s'", 
self.subscription_name, self.sns_topic_arn)
 
+        formatted_tags = format_tags(self.tags)
         create_subscription = self.hook.conn.create_event_subscription(
             SubscriptionName=self.subscription_name,
             SnsTopicArn=self.sns_topic_arn,
@@ -452,7 +456,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
             EventCategories=self.event_categories,
             SourceIds=self.source_ids,
             Enabled=self.enabled,
-            Tags=self.tags,
+            Tags=formatted_tags,
         )
 
         if self.wait_for_completion:
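
With this change the RDS operators also accept a plain dict for tags; a hedged
sketch (task id and identifiers are illustrative, not part of the change):

    from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator

    create_snapshot = RdsCreateDbSnapshotOperator(
        task_id="create_db_snapshot",            # standard BaseOperator argument
        db_type="instance",
        db_identifier="my-db-instance",          # illustrative identifier
        db_snapshot_identifier="my-db-snapshot",
        tags={"team": "data", "env": "prod"},    # formatted by format_tags at execute() time
    )
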
diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py
index d748da67bb..d9ab1ab75c 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -163,7 +163,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
         If a key is provided, a value must be provided as well.
     :param value: The value portion of the key/value pair for a tag to be added.
         If a value is provided, a key must be provided as well.
-    :param tag_set: A List of key/value pairs.
+    :param tag_set: A dictionary containing the tags, or a List of key/value pairs.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is None or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
@@ -179,7 +179,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
         bucket_name: str,
         key: str | None = None,
         value: str | None = None,
-        tag_set: list[dict[str, str]] | None = None,
+        tag_set: dict | list[dict[str, str]] | None = None,
         aws_conn_id: str | None = "aws_default",
         **kwargs,
     ) -> None:
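
A corresponding sketch for the S3 operator (bucket name and tags are illustrative):

    from airflow.providers.amazon.aws.operators.s3 import S3PutBucketTaggingOperator

    put_tags = S3PutBucketTaggingOperator(
        task_id="put_bucket_tags",       # standard BaseOperator argument
        bucket_name="my-bucket",         # illustrative
        tag_set={"Color": "Green"},      # dict form now accepted; the hook converts it
    )
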
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index f0191f2f53..aa6130e3f8 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -28,6 +28,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
+from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
@@ -1090,11 +1091,10 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator):
 
     def execute(self, context: Context) -> str:
         sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-        tags_set = [{"Key": kvp[0], "Value": kvp[1]} for kvp in self.tags.items()]
         params = {
             "ExperimentName": self.name,
             "Description": self.description,
-            "Tags": tags_set,
+            "Tags": format_tags(self.tags),
         }
         ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
         arn = ans["ExperimentArn"]
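
A hedged sketch for the experiment operator; the constructor arguments are
inferred from the attributes used in execute() and should be treated as an
assumption rather than a confirmed signature:

    from airflow.providers.amazon.aws.operators.sagemaker import SageMakerCreateExperimentOperator

    create_experiment = SageMakerCreateExperimentOperator(
        task_id="create_experiment",          # standard BaseOperator argument
        name="my-experiment",                 # illustrative (assumed parameter name)
        description="illustrative experiment",
        tags={"team": "data"},                # passed through format_tags before the API call
    )
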
diff --git a/airflow/providers/amazon/aws/utils/tags.py b/airflow/providers/amazon/aws/utils/tags.py
new file mode 100644
index 0000000000..c8afb124b6
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/tags.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+
+def format_tags(source: Any, *, key_label: str = "Key", value_label: str = "Value"):
+    """
+    If given a dictionary, formats it as an array of objects with a key and a value field to be passed to boto
+    calls that expect this format.
+    Else, assumes that it's already in the right format and returns it as is. We do not validate
+    the format here since it's done by boto anyway, and the error wouldn't be clearer if thrown from here.
+
+    :param source: a dict from which keys and values are read
+    :param key_label: optional, the label to use for keys if not "Key"
+    :param value_label: optional, the label to use for values if not "Value"
+    """
+    if source is None:
+        return []
+    elif isinstance(source, dict):
+        return [{key_label: kvp[0], value_label: kvp[1]} for kvp in source.items()]
+    else:
+        return source
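
The behaviour of the new helper, summarised as a small example (tag names are
illustrative):

    from airflow.providers.amazon.aws.utils.tags import format_tags

    format_tags({"env": "prod", "team": "data"})
    # -> [{"Key": "env", "Value": "prod"}, {"Key": "team", "Value": "data"}]

    format_tags(None)
    # -> []

    # A list (or any other non-dict value) is returned as is; validation is left to boto.
    format_tags([{"Key": "env", "Value": "prod"}])
    # -> [{"Key": "env", "Value": "prod"}]
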
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 6ee49ffb26..6eaec58dbf 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -734,6 +734,15 @@ class TestAwsS3Hook:
 
         assert hook.get_bucket_tagging(bucket_name="new_bucket") == tag_set
 
+    @mock_s3
+    def test_put_bucket_tagging_with_dict(self):
+        hook = S3Hook()
+        hook.create_bucket(bucket_name="new_bucket")
+        tag_set = {"Color": "Green"}
+        hook.put_bucket_tagging(bucket_name="new_bucket", tag_set=tag_set)
+
+        assert hook.get_bucket_tagging(bucket_name="new_bucket") == [{"Key": "Color", "Value": "Green"}]
+
     @mock_s3
     def test_put_bucket_tagging_with_pair(self):
         hook = S3Hook()
