This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 3721c9a441 Use base aws classes in Amazon S3 Glacier Operators/Sensors (#35108) 3721c9a441 is described below commit 3721c9a4413d3f5002b46589beeff490827cd9cb Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Tue Oct 24 18:54:21 2023 +0400 Use base aws classes in Amazon S3 Glacier Operators/Sensors (#35108) --- airflow/providers/amazon/aws/hooks/glacier.py | 6 +- airflow/providers/amazon/aws/operators/glacier.py | 30 ++++----- airflow/providers/amazon/aws/sensors/glacier.py | 15 ++--- .../operators/s3/glacier.rst | 5 ++ .../providers/amazon/aws/operators/test_glacier.py | 72 +++++++++++++++++----- tests/providers/amazon/aws/sensors/test_glacier.py | 65 ++++++++++++------- 6 files changed, 124 insertions(+), 69 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py index bd260000e7..4655b28e30 100644 --- a/airflow/providers/amazon/aws/hooks/glacier.py +++ b/airflow/providers/amazon/aws/hooks/glacier.py @@ -35,9 +35,9 @@ class GlacierHook(AwsBaseHook): - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - def __init__(self, aws_conn_id: str = "aws_default") -> None: - super().__init__(client_type="glacier") - self.aws_conn_id = aws_conn_id + def __init__(self, *args, **kwargs) -> None: + kwargs.update({"client_type": "glacier", "resource_type": None}) + super().__init__(*args, **kwargs) def retrieve_inventory(self, vault_name: str) -> dict[str, Any]: """Initiate an Amazon Glacier inventory-retrieval job. diff --git a/airflow/providers/amazon/aws/operators/glacier.py b/airflow/providers/amazon/aws/operators/glacier.py index 54123e586d..3164004181 100644 --- a/airflow/providers/amazon/aws/operators/glacier.py +++ b/airflow/providers/amazon/aws/operators/glacier.py @@ -19,14 +19,15 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glacier import GlacierHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class GlacierCreateJobOperator(BaseOperator): +class GlacierCreateJobOperator(AwsBaseOperator[GlacierHook]): """ Initiate an Amazon Glacier inventory-retrieval job. @@ -38,25 +39,18 @@ class GlacierCreateJobOperator(BaseOperator): :param vault_name: the Glacier vault on which job is executed """ - template_fields: Sequence[str] = ("vault_name",) + aws_hook_class = GlacierHook + template_fields: Sequence[str] = aws_template_fields("vault_name") - def __init__( - self, - *, - aws_conn_id="aws_default", - vault_name: str, - **kwargs, - ): + def __init__(self, *, vault_name: str, **kwargs): super().__init__(**kwargs) - self.aws_conn_id = aws_conn_id self.vault_name = vault_name def execute(self, context: Context): - hook = GlacierHook(aws_conn_id=self.aws_conn_id) - return hook.retrieve_inventory(vault_name=self.vault_name) + return self.hook.retrieve_inventory(vault_name=self.vault_name) -class GlacierUploadArchiveOperator(BaseOperator): +class GlacierUploadArchiveOperator(AwsBaseOperator[GlacierHook]): """ This operator add an archive to an Amazon S3 Glacier vault. @@ -74,7 +68,8 @@ class GlacierUploadArchiveOperator(BaseOperator): :param aws_conn_id: The reference to the AWS connection details """ - template_fields: Sequence[str] = ("vault_name",) + aws_hook_class = GlacierHook + template_fields: Sequence[str] = aws_template_fields("vault_name") def __init__( self, @@ -84,11 +79,9 @@ class GlacierUploadArchiveOperator(BaseOperator): checksum: str | None = None, archive_description: str | None = None, account_id: str | None = None, - aws_conn_id="aws_default", **kwargs, ): super().__init__(**kwargs) - self.aws_conn_id = aws_conn_id self.account_id = account_id self.vault_name = vault_name self.body = body @@ -96,8 +89,7 @@ class GlacierUploadArchiveOperator(BaseOperator): self.archive_description = archive_description def execute(self, context: Context): - hook = GlacierHook(aws_conn_id=self.aws_conn_id) - return hook.get_conn().upload_archive( + return self.hook.conn.upload_archive( accountId=self.account_id, vaultName=self.vault_name, archiveDescription=self.archive_description, diff --git a/airflow/providers/amazon/aws/sensors/glacier.py b/airflow/providers/amazon/aws/sensors/glacier.py index e9cc8fc4b7..7a65fc6fc3 100644 --- a/airflow/providers/amazon/aws/sensors/glacier.py +++ b/airflow/providers/amazon/aws/sensors/glacier.py @@ -18,12 +18,12 @@ from __future__ import annotations from enum import Enum -from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.glacier import GlacierHook -from airflow.sensors.base import BaseSensorOperator +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context @@ -36,7 +36,7 @@ class JobStatus(Enum): SUCCEEDED = "Succeeded" -class GlacierJobOperationSensor(BaseSensorOperator): +class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]): """ Glacier sensor for checking job state. This operator runs only in reschedule mode. @@ -63,12 +63,12 @@ class GlacierJobOperationSensor(BaseSensorOperator): prevent too much load on the scheduler. """ - template_fields: Sequence[str] = ("vault_name", "job_id") + aws_hook_class = GlacierHook + template_fields: Sequence[str] = aws_template_fields("vault_name", "job_id") def __init__( self, *, - aws_conn_id: str = "aws_default", vault_name: str, job_id: str, poke_interval: int = 60 * 20, @@ -76,16 +76,11 @@ class GlacierJobOperationSensor(BaseSensorOperator): **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.aws_conn_id = aws_conn_id self.vault_name = vault_name self.job_id = job_id self.poke_interval = poke_interval self.mode = mode - @cached_property - def hook(self): - return GlacierHook(aws_conn_id=self.aws_conn_id) - def poke(self, context: Context) -> bool: response = self.hook.describe_job(vault_name=self.vault_name, job_id=self.job_id) diff --git a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst index 9dca7a776c..c85e7ac294 100644 --- a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst +++ b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst @@ -27,6 +27,11 @@ Prerequisite Tasks .. include:: ../../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_glacier.py b/tests/providers/amazon/aws/operators/test_glacier.py index d9afe50511..4dbd8f2f5a 100644 --- a/tests/providers/amazon/aws/operators/test_glacier.py +++ b/tests/providers/amazon/aws/operators/test_glacier.py @@ -17,13 +17,19 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING, Any from unittest import mock +import pytest + from airflow.providers.amazon.aws.operators.glacier import ( GlacierCreateJobOperator, GlacierUploadArchiveOperator, ) +if TYPE_CHECKING: + from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator + AWS_CONN_ID = "aws_default" BUCKET_NAME = "airflow_bucket" FILENAME = "path/to/file/" @@ -34,22 +40,60 @@ TASK_ID = "glacier_job" VAULT_NAME = "airflow" -class TestGlacierCreateJobOperator: - @mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook") +class BaseGlacierOperatorsTests: + op_class: type[AwsBaseOperator] + default_op_kwargs: dict[str, Any] + + def test_base_aws_op_attributes(self): + op = self.op_class(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = self.op_class( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + +class TestGlacierCreateJobOperator(BaseGlacierOperatorsTests): + op_class = GlacierCreateJobOperator + + @pytest.fixture(autouse=True) + def setup_test_cases(self): + self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID} + + @mock.patch.object(GlacierCreateJobOperator, "hook", new_callable=mock.PropertyMock) def test_execute(self, hook_mock): - op = GlacierCreateJobOperator(aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, task_id=TASK_ID) + op = self.op_class(aws_conn_id=None, **self.default_op_kwargs) op.execute(mock.MagicMock()) - hook_mock.assert_called_once_with(aws_conn_id=AWS_CONN_ID) hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME) -class TestGlacierUploadArchiveOperator: - @mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook.get_conn") - def test_execute(self, hook_mock): - op = GlacierUploadArchiveOperator( - aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, body=b"Test Data", task_id=TASK_ID - ) - op.execute(mock.MagicMock()) - hook_mock.return_value.upload_archive.assert_called_once_with( - accountId=None, vaultName=VAULT_NAME, archiveDescription=None, body=b"Test Data", checksum=None - ) +class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests): + op_class = GlacierUploadArchiveOperator + + @pytest.fixture(autouse=True) + def setup_test_cases(self): + self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID, "body": b"Test Data"} + + def test_execute(self): + with mock.patch.object(self.op_class.aws_hook_class, "conn", new_callable=mock.PropertyMock) as m: + op = self.op_class(aws_conn_id=None, **self.default_op_kwargs) + op.execute(mock.MagicMock()) + m.return_value.upload_archive.assert_called_once_with( + accountId=None, + vaultName=VAULT_NAME, + archiveDescription=None, + body=b"Test Data", + checksum=None, + ) diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py index 4213eed9d0..5019a4dd0c 100644 --- a/tests/providers/amazon/aws/sensors/test_glacier.py +++ b/tests/providers/amazon/aws/sensors/test_glacier.py @@ -28,49 +28,68 @@ SUCCEEDED = "Succeeded" IN_PROGRESS = "InProgress" +@pytest.fixture +def mocked_describe_job(): + with mock.patch("airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job") as m: + yield m + + class TestAmazonGlacierSensor: def setup_method(self): - self.op = GlacierJobOperationSensor( + self.default_op_kwargs = dict( task_id="test_athena_sensor", - aws_conn_id="aws_default", vault_name="airflow", job_id="1a2b3c4d", poke_interval=60 * 20, ) + self.op = GlacierJobOperationSensor(**self.default_op_kwargs, aws_conn_id=None) - @mock.patch( - "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job", - side_effect=[{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}], - ) - def test_poke_succeeded(self, _): + def test_base_aws_op_attributes(self): + op = GlacierJobOperationSensor(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = GlacierJobOperationSensor( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + def test_poke_succeeded(self, mocked_describe_job): + mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}] assert self.op.poke(None) - @mock.patch( - "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job", - side_effect=[{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}], - ) - def test_poke_in_progress(self, _): + def test_poke_in_progress(self, mocked_describe_job): + mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}] assert not self.op.poke(None) - @mock.patch( - "airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job", - side_effect=[{"Action": "", "StatusCode": ""}], - ) - def test_poke_fail(self, _): - with pytest.raises(AirflowException) as ctx: + def test_poke_fail(self, mocked_describe_job): + mocked_describe_job.side_effect = [{"Action": "", "StatusCode": ""}] + with pytest.raises(AirflowException, match="Sensor failed"): self.op.poke(None) - assert "Sensor failed" in str(ctx.value) @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], ) - @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.describe_job") - def test_fail_poke(self, describe_job, soft_fail, expected_exception): + def test_fail_poke(self, soft_fail, expected_exception, mocked_describe_job): self.op.soft_fail = soft_fail response = {"Action": "some action", "StatusCode": "Failed"} message = f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}' with pytest.raises(expected_exception, match=message): - describe_job.return_value = response + mocked_describe_job.return_value = response self.op.poke(context={})