This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3721c9a441 Use base aws classes in Amazon S3 Glacier Operators/Sensors (#35108)
3721c9a441 is described below

commit 3721c9a4413d3f5002b46589beeff490827cd9cb
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Tue Oct 24 18:54:21 2023 +0400

    Use base aws classes in Amazon S3 Glacier Operators/Sensors (#35108)
---
 airflow/providers/amazon/aws/hooks/glacier.py      |  6 +-
 airflow/providers/amazon/aws/operators/glacier.py  | 30 ++++-----
 airflow/providers/amazon/aws/sensors/glacier.py    | 15 ++---
 .../operators/s3/glacier.rst                       |  5 ++
 .../providers/amazon/aws/operators/test_glacier.py | 72 +++++++++++++++++-----
 tests/providers/amazon/aws/sensors/test_glacier.py | 65 ++++++++++++-------
 6 files changed, 124 insertions(+), 69 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py
index bd260000e7..4655b28e30 100644
--- a/airflow/providers/amazon/aws/hooks/glacier.py
+++ b/airflow/providers/amazon/aws/hooks/glacier.py
@@ -35,9 +35,9 @@ class GlacierHook(AwsBaseHook):
         - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
     """
 
-    def __init__(self, aws_conn_id: str = "aws_default") -> None:
-        super().__init__(client_type="glacier")
-        self.aws_conn_id = aws_conn_id
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs.update({"client_type": "glacier", "resource_type": None})
+        super().__init__(*args, **kwargs)
 
     def retrieve_inventory(self, vault_name: str) -> dict[str, Any]:
         """Initiate an Amazon Glacier inventory-retrieval job.
diff --git a/airflow/providers/amazon/aws/operators/glacier.py b/airflow/providers/amazon/aws/operators/glacier.py
index 54123e586d..3164004181 100644
--- a/airflow/providers/amazon/aws/operators/glacier.py
+++ b/airflow/providers/amazon/aws/operators/glacier.py
@@ -19,14 +19,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class GlacierCreateJobOperator(BaseOperator):
+class GlacierCreateJobOperator(AwsBaseOperator[GlacierHook]):
     """
     Initiate an Amazon Glacier inventory-retrieval job.
 
@@ -38,25 +39,18 @@ class GlacierCreateJobOperator(BaseOperator):
     :param vault_name: the Glacier vault on which job is executed
     """
 
-    template_fields: Sequence[str] = ("vault_name",)
+    aws_hook_class = GlacierHook
+    template_fields: Sequence[str] = aws_template_fields("vault_name")
 
-    def __init__(
-        self,
-        *,
-        aws_conn_id="aws_default",
-        vault_name: str,
-        **kwargs,
-    ):
+    def __init__(self, *, vault_name: str, **kwargs):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.vault_name = vault_name
 
     def execute(self, context: Context):
-        hook = GlacierHook(aws_conn_id=self.aws_conn_id)
-        return hook.retrieve_inventory(vault_name=self.vault_name)
+        return self.hook.retrieve_inventory(vault_name=self.vault_name)
 
 
-class GlacierUploadArchiveOperator(BaseOperator):
+class GlacierUploadArchiveOperator(AwsBaseOperator[GlacierHook]):
     """
     This operator adds an archive to an Amazon S3 Glacier vault.
 
@@ -74,7 +68,8 @@ class GlacierUploadArchiveOperator(BaseOperator):
     :param aws_conn_id: The reference to the AWS connection details
     """
 
-    template_fields: Sequence[str] = ("vault_name",)
+    aws_hook_class = GlacierHook
+    template_fields: Sequence[str] = aws_template_fields("vault_name")
 
     def __init__(
         self,
@@ -84,11 +79,9 @@ class GlacierUploadArchiveOperator(BaseOperator):
         checksum: str | None = None,
         archive_description: str | None = None,
         account_id: str | None = None,
-        aws_conn_id="aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.account_id = account_id
         self.vault_name = vault_name
         self.body = body
@@ -96,8 +89,7 @@ class GlacierUploadArchiveOperator(BaseOperator):
         self.archive_description = archive_description
 
     def execute(self, context: Context):
-        hook = GlacierHook(aws_conn_id=self.aws_conn_id)
-        return hook.get_conn().upload_archive(
+        return self.hook.conn.upload_archive(
             accountId=self.account_id,
             vaultName=self.vault_name,
             archiveDescription=self.archive_description,
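
With both operators subclassing AwsBaseOperator[GlacierHook], DAG authors get the generic AWS parameters (aws_conn_id, region_name, verify, botocore_config) on the operator itself, and the hook is built lazily via self.hook. A hypothetical DAG snippet, assuming these illustrative names:

    from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator

    create_job = GlacierCreateJobOperator(
        task_id="create_inventory_retrieval_job",
        vault_name="my-vault",                 # hypothetical vault name
        region_name="eu-west-1",               # generic parameter from AwsBaseOperator
        botocore_config={"read_timeout": 42},  # wrapped into a botocore Config by the base hook
    )
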
diff --git a/airflow/providers/amazon/aws/sensors/glacier.py b/airflow/providers/amazon/aws/sensors/glacier.py
index e9cc8fc4b7..7a65fc6fc3 100644
--- a/airflow/providers/amazon/aws/sensors/glacier.py
+++ b/airflow/providers/amazon/aws/sensors/glacier.py
@@ -18,12 +18,12 @@
 from __future__ import annotations
 
 from enum import Enum
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -36,7 +36,7 @@ class JobStatus(Enum):
     SUCCEEDED = "Succeeded"
 
 
-class GlacierJobOperationSensor(BaseSensorOperator):
+class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
     """
     Glacier sensor for checking job state. This operator runs only in reschedule mode.
 
@@ -63,12 +63,12 @@ class GlacierJobOperationSensor(BaseSensorOperator):
         prevent too much load on the scheduler.
     """
 
-    template_fields: Sequence[str] = ("vault_name", "job_id")
+    aws_hook_class = GlacierHook
+    template_fields: Sequence[str] = aws_template_fields("vault_name", "job_id")
 
     def __init__(
         self,
         *,
-        aws_conn_id: str = "aws_default",
         vault_name: str,
         job_id: str,
         poke_interval: int = 60 * 20,
@@ -76,16 +76,11 @@ class GlacierJobOperationSensor(BaseSensorOperator):
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.vault_name = vault_name
         self.job_id = job_id
         self.poke_interval = poke_interval
         self.mode = mode
 
-    @cached_property
-    def hook(self):
-        return GlacierHook(aws_conn_id=self.aws_conn_id)
-
     def poke(self, context: Context) -> bool:
         response = self.hook.describe_job(vault_name=self.vault_name, job_id=self.job_id)
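
The sensor follows the same pattern: the hand-rolled cached_property hook is gone, and AwsBaseSensor supplies it instead. A hypothetical usage sketch, assuming the job id is pulled from the create-job task via XCom (the "jobId" key is the boto3 initiate_job response field, not something this commit defines):

    from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor

    wait_for_job = GlacierJobOperationSensor(
        task_id="wait_for_inventory_job",
        vault_name="my-vault",  # hypothetical vault name
        job_id="{{ ti.xcom_pull(task_ids='create_inventory_retrieval_job')['jobId'] }}",
        region_name="eu-west-1",  # generic parameter from AwsBaseSensor
    )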
 
diff --git a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
index 9dca7a776c..c85e7ac294 100644
--- a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
+++ b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
@@ -27,6 +27,11 @@ Prerequisite Tasks
 
 .. include:: ../../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_glacier.py b/tests/providers/amazon/aws/operators/test_glacier.py
index d9afe50511..4dbd8f2f5a 100644
--- a/tests/providers/amazon/aws/operators/test_glacier.py
+++ b/tests/providers/amazon/aws/operators/test_glacier.py
@@ -17,13 +17,19 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Any
 from unittest import mock
 
+import pytest
+
 from airflow.providers.amazon.aws.operators.glacier import (
     GlacierCreateJobOperator,
     GlacierUploadArchiveOperator,
 )
 
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+
 AWS_CONN_ID = "aws_default"
 BUCKET_NAME = "airflow_bucket"
 FILENAME = "path/to/file/"
@@ -34,22 +40,60 @@ TASK_ID = "glacier_job"
 VAULT_NAME = "airflow"
 
 
-class TestGlacierCreateJobOperator:
-    @mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook")
+class BaseGlacierOperatorsTests:
+    op_class: type[AwsBaseOperator]
+    default_op_kwargs: dict[str, Any]
+
+    def test_base_aws_op_attributes(self):
+        op = self.op_class(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = self.op_class(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+
+class TestGlacierCreateJobOperator(BaseGlacierOperatorsTests):
+    op_class = GlacierCreateJobOperator
+
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
+        self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID}
+
+    @mock.patch.object(GlacierCreateJobOperator, "hook", new_callable=mock.PropertyMock)
     def test_execute(self, hook_mock):
-        op = GlacierCreateJobOperator(aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, task_id=TASK_ID)
+        op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
         op.execute(mock.MagicMock())
-        hook_mock.assert_called_once_with(aws_conn_id=AWS_CONN_ID)
         hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME)
 
 
-class TestGlacierUploadArchiveOperator:
-    @mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook.get_conn")
-    def test_execute(self, hook_mock):
-        op = GlacierUploadArchiveOperator(
-            aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, body=b"Test Data", task_id=TASK_ID
-        )
-        op.execute(mock.MagicMock())
-        hook_mock.return_value.upload_archive.assert_called_once_with(
-            accountId=None, vaultName=VAULT_NAME, archiveDescription=None, body=b"Test Data", checksum=None
-        )
+class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests):
+    op_class = GlacierUploadArchiveOperator
+
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
+        self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID, "body": b"Test Data"}
+
+    def test_execute(self):
+        with mock.patch.object(self.op_class.aws_hook_class, "conn", new_callable=mock.PropertyMock) as m:
+            op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
+            op.execute(mock.MagicMock())
+            m.return_value.upload_archive.assert_called_once_with(
+                accountId=None,
+                vaultName=VAULT_NAME,
+                archiveDescription=None,
+                body=b"Test Data",
+                checksum=None,
+            )
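
Because the generic-parameter assertions live in BaseGlacierOperatorsTests, any future Glacier operator test only has to declare op_class and a default_op_kwargs fixture to inherit them. A hypothetical sketch (SomeFutureGlacierOperator is not a real class):

    class TestSomeFutureGlacierOperator(BaseGlacierOperatorsTests):
        op_class = SomeFutureGlacierOperator  # hypothetical operator class

        @pytest.fixture(autouse=True)
        def setup_test_cases(self):
            # test_base_aws_op_attributes is inherited from the base class unchanged.
            self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID}
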
diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py
index 4213eed9d0..5019a4dd0c 100644
--- a/tests/providers/amazon/aws/sensors/test_glacier.py
+++ b/tests/providers/amazon/aws/sensors/test_glacier.py
@@ -28,49 +28,68 @@ SUCCEEDED = "Succeeded"
 IN_PROGRESS = "InProgress"
 
 
+@pytest.fixture
+def mocked_describe_job():
+    with mock.patch("airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job") as m:
+        yield m
+
+
 class TestAmazonGlacierSensor:
     def setup_method(self):
-        self.op = GlacierJobOperationSensor(
+        self.default_op_kwargs = dict(
             task_id="test_athena_sensor",
-            aws_conn_id="aws_default",
             vault_name="airflow",
             job_id="1a2b3c4d",
             poke_interval=60 * 20,
         )
+        self.op = GlacierJobOperationSensor(**self.default_op_kwargs, aws_conn_id=None)
 
-    @mock.patch(
-        "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
-        side_effect=[{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}],
-    )
-    def test_poke_succeeded(self, _):
+    def test_base_aws_op_attributes(self):
+        op = GlacierJobOperationSensor(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = GlacierJobOperationSensor(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+    def test_poke_succeeded(self, mocked_describe_job):
+        mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}]
         assert self.op.poke(None)
 
-    @mock.patch(
-        "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
-        side_effect=[{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}],
-    )
-    def test_poke_in_progress(self, _):
+    def test_poke_in_progress(self, mocked_describe_job):
+        mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}]
         assert not self.op.poke(None)
 
-    @mock.patch(
-        "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
-        side_effect=[{"Action": "", "StatusCode": ""}],
-    )
-    def test_poke_fail(self, _):
-        with pytest.raises(AirflowException) as ctx:
+    def test_poke_fail(self, mocked_describe_job):
+        mocked_describe_job.side_effect = [{"Action": "", "StatusCode": ""}]
+        with pytest.raises(AirflowException, match="Sensor failed"):
             self.op.poke(None)
-        assert "Sensor failed" in str(ctx.value)
 
     @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+        "soft_fail, expected_exception",
+        [
+            pytest.param(False, AirflowException, id="not-soft-fail"),
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+        ],
     )
-    @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.describe_job")
-    def test_fail_poke(self, describe_job, soft_fail, expected_exception):
+    def test_fail_poke(self, soft_fail, expected_exception, mocked_describe_job):
         self.op.soft_fail = soft_fail
         response = {"Action": "some action", "StatusCode": "Failed"}
         message = f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}'
         with pytest.raises(expected_exception, match=message):
-            describe_job.return_value = response
+            mocked_describe_job.return_value = response
             self.op.poke(context={})
 
 
