This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 6fd8f36893 Unit tests: Reorganize AWS custom waiter (#38785) 6fd8f36893 is described below commit 6fd8f368937711b4edd3b8b0ecf08080afee27c1 Author: D. Ferruzzi <ferru...@amazon.com> AuthorDate: Sat Apr 6 01:59:45 2024 -0700 Unit tests: Reorganize AWS custom waiter (#38785) --- tests/providers/amazon/aws/waiters/test_batch.py | 78 +++++ .../amazon/aws/waiters/test_custom_waiters.py | 341 --------------------- tests/providers/amazon/aws/waiters/test_dynamo.py | 82 +++++ tests/providers/amazon/aws/waiters/test_ecs.py | 139 +++++++++ tests/providers/amazon/aws/waiters/test_eks.py | 56 ++++ tests/providers/amazon/aws/waiters/test_emr.py | 95 ++++++ 6 files changed, 450 insertions(+), 341 deletions(-) diff --git a/tests/providers/amazon/aws/waiters/test_batch.py b/tests/providers/amazon/aws/waiters/test_batch.py new file mode 100644 index 0000000000..f71caf9eda --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_batch.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import boto3 +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook + + +class TestCustomBatchServiceWaiters: + JOB_ID = "test_job_id" + + @pytest.fixture(autouse=True) + def setup_test_cases(self, monkeypatch): + self.client = boto3.client("batch", region_name="eu-west-3") + monkeypatch.setattr(BatchClientHook, "conn", self.client) + + @pytest.fixture + def mock_describe_jobs(self): + """Mock ``BatchClientHook.Client.describe_jobs`` method.""" + with mock.patch.object(self.client, "describe_jobs") as m: + yield m + + def test_service_waiters(self): + hook_waiters = BatchClientHook(aws_conn_id=None).list_waiters() + assert "batch_job_complete" in hook_waiters + + @staticmethod + def describe_jobs(status: str): + """ + Helper function for generate minimal DescribeJobs response for a single job. + https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html + """ + return { + "jobs": [ + { + "status": status, + }, + ], + } + + def test_job_succeeded(self, mock_describe_jobs): + """Test job succeeded""" + mock_describe_jobs.side_effect = [ + self.describe_jobs(BatchClientHook.RUNNING_STATE), + self.describe_jobs(BatchClientHook.SUCCESS_STATE), + ] + waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") + waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2}) + + def test_job_failed(self, mock_describe_jobs): + """Test job failed""" + mock_describe_jobs.side_effect = [ + self.describe_jobs(BatchClientHook.RUNNING_STATE), + self.describe_jobs(BatchClientHook.FAILURE_STATE), + ] + waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") + + with pytest.raises(WaiterError, match="Waiter encountered a terminal failure state"): + waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2}) diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py index 3272d603c4..82e0594337 100644 --- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py +++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py @@ -17,22 +17,13 @@ from __future__ import annotations -import json -from typing import Sequence from unittest import mock import boto3 import pytest -from botocore.exceptions import WaiterError from botocore.waiter import WaiterModel -from moto import mock_aws from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook -from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook -from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates -from airflow.providers.amazon.aws.hooks.eks import EksHook -from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter @@ -92,335 +83,3 @@ class TestBaseWaiter: with mock.patch("botocore.client.BaseClient.get_waiter") as m: hook.get_waiter(waiter_name="FooBar") m.assert_called_once_with("FooBar") - - -class TestCustomEKSServiceWaiters: - def test_service_waiters(self): - hook = EksHook() - with open(hook.waiter_path) as config_file: - expected_waiters = json.load(config_file)["waiters"] - - for waiter in list(expected_waiters.keys()): - assert waiter in hook.list_waiters() - assert waiter in hook._list_custom_waiters() - - @mock_aws - def test_existing_waiter_inherited(self): - """ - AwsBaseHook::get_waiter will first check if there is a custom waiter with the - provided name and pass that through is it exists, otherwise it will check the - custom waiters for the given service. This test checks to make sure that the - waiter is the same whichever way you get it and no modifications are made. - """ - hook_waiter = EksHook().get_waiter("cluster_active") - client_waiter = EksHook().conn.get_waiter("cluster_active") - boto_waiter = boto3.client("eks").get_waiter("cluster_active") - - assert_all_match(hook_waiter.name, client_waiter.name, boto_waiter.name) - assert_all_match(len(hook_waiter.__dict__), len(client_waiter.__dict__), len(boto_waiter.__dict__)) - for attr in hook_waiter.__dict__: - # Not all attributes in a Waiter are directly comparable - # so the best we can do it make sure the same attrs exist. - assert hasattr(boto_waiter, attr) - assert hasattr(client_waiter, attr) - - -class TestCustomECSServiceWaiters: - """Test waiters from ``amazon/aws/waiters/ecs.json``.""" - - @pytest.fixture(autouse=True) - def setup_test_cases(self, monkeypatch): - self.client = boto3.client("ecs", region_name="eu-west-3") - monkeypatch.setattr(EcsHook, "conn", self.client) - - @pytest.fixture - def mock_describe_clusters(self): - """Mock ``ECS.Client.describe_clusters`` method.""" - with mock.patch.object(self.client, "describe_clusters") as m: - yield m - - @pytest.fixture - def mock_describe_task_definition(self): - """Mock ``ECS.Client.describe_task_definition`` method.""" - with mock.patch.object(self.client, "describe_task_definition") as m: - yield m - - def test_service_waiters(self): - hook_waiters = EcsHook(aws_conn_id=None).list_waiters() - assert "cluster_active" in hook_waiters - assert "cluster_inactive" in hook_waiters - - @staticmethod - def describe_clusters( - status: str | EcsClusterStates, cluster_name: str = "spam-egg", failures: dict | list | None = None - ): - """ - Helper function for generate minimal DescribeClusters response for single job. - https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html - """ - if isinstance(status, EcsClusterStates): - status = status.value - else: - assert status in EcsClusterStates.__members__.values() - - failures = failures or [] - if isinstance(failures, dict): - failures = [failures] - - return {"clusters": [{"clusterName": cluster_name, "status": status}], "failures": failures} - - def test_cluster_active(self, mock_describe_clusters): - """Test cluster reach Active state during creation.""" - mock_describe_clusters.side_effect = [ - self.describe_clusters(EcsClusterStates.DEPROVISIONING), - self.describe_clusters(EcsClusterStates.PROVISIONING), - self.describe_clusters(EcsClusterStates.ACTIVE), - ] - waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") - waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) - - @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"]) - def test_cluster_active_failure_states(self, mock_describe_clusters, state): - """Test cluster reach inactive state during creation.""" - mock_describe_clusters.side_effect = [ - self.describe_clusters(EcsClusterStates.PROVISIONING), - self.describe_clusters(state), - ] - waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") - with pytest.raises(WaiterError, match=f'matched expected path: "{state}"'): - waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) - - def test_cluster_active_failure_reasons(self, mock_describe_clusters): - """Test cluster reach failure state during creation.""" - mock_describe_clusters.side_effect = [ - self.describe_clusters(EcsClusterStates.PROVISIONING), - self.describe_clusters(EcsClusterStates.PROVISIONING, failures={"reason": "MISSING"}), - ] - waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") - with pytest.raises(WaiterError, match='matched expected path: "MISSING"'): - waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) - - def test_cluster_inactive(self, mock_describe_clusters): - """Test cluster reach Inactive state during deletion.""" - mock_describe_clusters.side_effect = [ - self.describe_clusters(EcsClusterStates.ACTIVE), - self.describe_clusters(EcsClusterStates.ACTIVE), - self.describe_clusters(EcsClusterStates.INACTIVE), - ] - waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive") - waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) - - def test_cluster_inactive_failure_reasons(self, mock_describe_clusters): - """Test cluster reach failure state during deletion.""" - mock_describe_clusters.side_effect = [ - self.describe_clusters(EcsClusterStates.ACTIVE), - self.describe_clusters(EcsClusterStates.DEPROVISIONING), - self.describe_clusters(EcsClusterStates.DEPROVISIONING, failures={"reason": "MISSING"}), - ] - waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive") - waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) - - @staticmethod - def describe_task_definition(status: str | EcsTaskDefinitionStates, task_definition: str = "spam-egg"): - """ - Helper function for generate minimal DescribeTaskDefinition response for single job. - https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html - """ - if isinstance(status, EcsTaskDefinitionStates): - status = status.value - else: - assert status in EcsTaskDefinitionStates.__members__.values() - - return { - "taskDefinition": { - "taskDefinitionArn": ( - f"arn:aws:ecs:eu-west-3:123456789012:task-definition/{task_definition}:42" - ), - "status": status, - } - } - - -class TestCustomDynamoDBServiceWaiters: - """Test waiters from ``amazon/aws/waiters/dynamodb.json``.""" - - STATUS_COMPLETED = "COMPLETED" - STATUS_FAILED = "FAILED" - STATUS_IN_PROGRESS = "IN_PROGRESS" - - @pytest.fixture(autouse=True) - def setup_test_cases(self, monkeypatch): - self.resource = boto3.resource("dynamodb", region_name="eu-west-3") - monkeypatch.setattr(DynamoDBHook, "conn", self.resource) - self.client = self.resource.meta.client - - @pytest.fixture - def mock_describe_export(self): - """Mock ``DynamoDBHook.Client.describe_export`` method.""" - with mock.patch.object(self.client, "describe_export") as m: - yield m - - def test_service_waiters(self): - hook_waiters = DynamoDBHook(aws_conn_id=None).list_waiters() - assert "export_table" in hook_waiters - - @staticmethod - def describe_export(status: str): - """ - Helper function for generate minimal DescribeExport response for single job. - https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DescribeExport.html - """ - return {"ExportDescription": {"ExportStatus": status}} - - def test_export_table_to_point_in_time_completed(self, mock_describe_export): - """Test state transition from `in progress` to `completed` during init.""" - waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table") - mock_describe_export.side_effect = [ - self.describe_export(self.STATUS_IN_PROGRESS), - self.describe_export(self.STATUS_COMPLETED), - ] - waiter.wait( - ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", - WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, - ) - - def test_export_table_to_point_in_time_failed(self, mock_describe_export): - """Test state transition from `in progress` to `failed` during init.""" - with mock.patch("boto3.client") as client: - client.return_value = self.client - mock_describe_export.side_effect = [ - self.describe_export(self.STATUS_IN_PROGRESS), - self.describe_export(self.STATUS_FAILED), - ] - waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client) - with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'): - waiter.wait( - ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", - WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, - ) - - -class TestCustomBatchServiceWaiters: - """Test waiters from ``amazon/aws/waiters/batch.json``.""" - - JOB_ID = "test_job_id" - - @pytest.fixture(autouse=True) - def setup_test_cases(self, monkeypatch): - self.client = boto3.client("batch", region_name="eu-west-3") - monkeypatch.setattr(BatchClientHook, "conn", self.client) - - @pytest.fixture - def mock_describe_jobs(self): - """Mock ``BatchClientHook.Client.describe_jobs`` method.""" - with mock.patch.object(self.client, "describe_jobs") as m: - yield m - - def test_service_waiters(self): - hook_waiters = BatchClientHook(aws_conn_id=None).list_waiters() - assert "batch_job_complete" in hook_waiters - - @staticmethod - def describe_jobs(status: str): - """ - Helper function for generate minimal DescribeJobs response for a single job. - https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html - """ - return { - "jobs": [ - { - "status": status, - }, - ], - } - - def test_job_succeeded(self, mock_describe_jobs): - """Test job succeeded""" - mock_describe_jobs.side_effect = [ - self.describe_jobs(BatchClientHook.RUNNING_STATE), - self.describe_jobs(BatchClientHook.SUCCESS_STATE), - ] - waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") - waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2}) - - def test_job_failed(self, mock_describe_jobs): - """Test job failed""" - mock_describe_jobs.side_effect = [ - self.describe_jobs(BatchClientHook.RUNNING_STATE), - self.describe_jobs(BatchClientHook.FAILURE_STATE), - ] - waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") - - with pytest.raises(WaiterError, match="Waiter encountered a terminal failure state"): - waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2}) - - -class TestCustomEmrServiceWaiters: - """Test waiters from ``amazon/aws/waiters/emr.json``.""" - - JOBFLOW_ID = "test_jobflow_id" - STEP_ID1 = "test_step_id_1" - STEP_ID2 = "test_step_id_2" - - @pytest.fixture(autouse=True) - def setup_test_cases(self, monkeypatch): - self.client = boto3.client("emr", region_name="eu-west-3") - monkeypatch.setattr(EmrHook, "conn", self.client) - - @pytest.fixture - def mock_list_steps(self): - """Mock ``EmrHook.Client.list_steps`` method.""" - with mock.patch.object(self.client, "list_steps") as m: - yield m - - def test_service_waiters(self): - hook_waiters = EmrHook(aws_conn_id=None).list_waiters() - assert "steps_wait_for_terminal" in hook_waiters - - @staticmethod - def list_steps(step_records: Sequence[tuple[str, str]]): - """ - Helper function to generate minimal ListSteps response. - https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html - """ - return { - "Steps": [ - { - "Id": step_record[0], - "Status": { - "State": step_record[1], - }, - } - for step_record in step_records - ], - } - - def test_steps_succeeded(self, mock_list_steps): - """Test steps succeeded""" - mock_list_steps.side_effect = [ - self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, "RUNNING")]), - self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, "COMPLETED")]), - self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2, "COMPLETED")]), - ] - waiter = EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal") - waiter.wait( - ClusterId=self.JOBFLOW_ID, - StepIds=[self.STEP_ID1, self.STEP_ID2], - WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, - ) - - def test_steps_failed(self, mock_list_steps): - """Test steps failed""" - mock_list_steps.side_effect = [ - self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, "RUNNING")]), - self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, "COMPLETED")]), - self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2, "COMPLETED")]), - ] - waiter = EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal") - - with pytest.raises(WaiterError, match="Waiter encountered a terminal failure state"): - waiter.wait( - ClusterId=self.JOBFLOW_ID, - StepIds=[self.STEP_ID1, self.STEP_ID2], - WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, - ) diff --git a/tests/providers/amazon/aws/waiters/test_dynamo.py b/tests/providers/amazon/aws/waiters/test_dynamo.py new file mode 100644 index 0000000000..be94f68081 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_dynamo.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import boto3 +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook + + +class TestCustomDynamoDBServiceWaiters: + STATUS_COMPLETED = "COMPLETED" + STATUS_FAILED = "FAILED" + STATUS_IN_PROGRESS = "IN_PROGRESS" + + @pytest.fixture(autouse=True) + def setup_test_cases(self, monkeypatch): + self.resource = boto3.resource("dynamodb", region_name="eu-west-3") + monkeypatch.setattr(DynamoDBHook, "conn", self.resource) + self.client = self.resource.meta.client + + @pytest.fixture + def mock_describe_export(self): + """Mock ``DynamoDBHook.Client.describe_export`` method.""" + with mock.patch.object(self.client, "describe_export") as m: + yield m + + def test_service_waiters(self): + hook_waiters = DynamoDBHook(aws_conn_id=None).list_waiters() + assert "export_table" in hook_waiters + + @staticmethod + def describe_export(status: str): + """ + Helper function for generate minimal DescribeExport response for single job. + https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DescribeExport.html + """ + return {"ExportDescription": {"ExportStatus": status}} + + def test_export_table_to_point_in_time_completed(self, mock_describe_export): + """Test state transition from `in progress` to `completed` during init.""" + waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table") + mock_describe_export.side_effect = [ + self.describe_export(self.STATUS_IN_PROGRESS), + self.describe_export(self.STATUS_COMPLETED), + ] + waiter.wait( + ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) + + def test_export_table_to_point_in_time_failed(self, mock_describe_export): + """Test state transition from `in progress` to `failed` during init.""" + with mock.patch("boto3.client") as client: + client.return_value = self.client + mock_describe_export.side_effect = [ + self.describe_export(self.STATUS_IN_PROGRESS), + self.describe_export(self.STATUS_FAILED), + ] + waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client) + with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'): + waiter.wait( + ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) diff --git a/tests/providers/amazon/aws/waiters/test_ecs.py b/tests/providers/amazon/aws/waiters/test_ecs.py new file mode 100644 index 0000000000..5742cd59b1 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_ecs.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import boto3 +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates + + +class TestCustomECSServiceWaiters: + @pytest.fixture(autouse=True) + def setup_test_cases(self, monkeypatch): + self.client = boto3.client("ecs", region_name="eu-west-3") + monkeypatch.setattr(EcsHook, "conn", self.client) + + @pytest.fixture + def mock_describe_clusters(self): + """Mock ``ECS.Client.describe_clusters`` method.""" + with mock.patch.object(self.client, "describe_clusters") as m: + yield m + + @pytest.fixture + def mock_describe_task_definition(self): + """Mock ``ECS.Client.describe_task_definition`` method.""" + with mock.patch.object(self.client, "describe_task_definition") as m: + yield m + + def test_service_waiters(self): + hook_waiters = EcsHook(aws_conn_id=None).list_waiters() + assert "cluster_active" in hook_waiters + assert "cluster_inactive" in hook_waiters + + @staticmethod + def describe_clusters( + status: str | EcsClusterStates, cluster_name: str = "spam-egg", failures: dict | list | None = None + ): + """ + Helper function for generate minimal DescribeClusters response for single job. + https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html + """ + if isinstance(status, EcsClusterStates): + status = status.value + else: + assert status in EcsClusterStates.__members__.values() + + failures = failures or [] + if isinstance(failures, dict): + failures = [failures] + + return {"clusters": [{"clusterName": cluster_name, "status": status}], "failures": failures} + + def test_cluster_active(self, mock_describe_clusters): + """Test cluster reach Active state during creation.""" + mock_describe_clusters.side_effect = [ + self.describe_clusters(EcsClusterStates.DEPROVISIONING), + self.describe_clusters(EcsClusterStates.PROVISIONING), + self.describe_clusters(EcsClusterStates.ACTIVE), + ] + waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") + waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) + + @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"]) + def test_cluster_active_failure_states(self, mock_describe_clusters, state): + """Test cluster reach inactive state during creation.""" + mock_describe_clusters.side_effect = [ + self.describe_clusters(EcsClusterStates.PROVISIONING), + self.describe_clusters(state), + ] + waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") + with pytest.raises(WaiterError, match=f'matched expected path: "{state}"'): + waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) + + def test_cluster_active_failure_reasons(self, mock_describe_clusters): + """Test cluster reach failure state during creation.""" + mock_describe_clusters.side_effect = [ + self.describe_clusters(EcsClusterStates.PROVISIONING), + self.describe_clusters(EcsClusterStates.PROVISIONING, failures={"reason": "MISSING"}), + ] + waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active") + with pytest.raises(WaiterError, match='matched expected path: "MISSING"'): + waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) + + def test_cluster_inactive(self, mock_describe_clusters): + """Test cluster reach Inactive state during deletion.""" + mock_describe_clusters.side_effect = [ + self.describe_clusters(EcsClusterStates.ACTIVE), + self.describe_clusters(EcsClusterStates.ACTIVE), + self.describe_clusters(EcsClusterStates.INACTIVE), + ] + waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive") + waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) + + def test_cluster_inactive_failure_reasons(self, mock_describe_clusters): + """Test cluster reach failure state during deletion.""" + mock_describe_clusters.side_effect = [ + self.describe_clusters(EcsClusterStates.ACTIVE), + self.describe_clusters(EcsClusterStates.DEPROVISIONING), + self.describe_clusters(EcsClusterStates.DEPROVISIONING, failures={"reason": "MISSING"}), + ] + waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive") + waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) + + @staticmethod + def describe_task_definition(status: str | EcsTaskDefinitionStates, task_definition: str = "spam-egg"): + """ + Helper function for generate minimal DescribeTaskDefinition response for single job. + https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html + """ + if isinstance(status, EcsTaskDefinitionStates): + status = status.value + else: + assert status in EcsTaskDefinitionStates.__members__.values() + + return { + "taskDefinition": { + "taskDefinitionArn": ( + f"arn:aws:ecs:eu-west-3:123456789012:task-definition/{task_definition}:42" + ), + "status": status, + } + } diff --git a/tests/providers/amazon/aws/waiters/test_eks.py b/tests/providers/amazon/aws/waiters/test_eks.py new file mode 100644 index 0000000000..9013c8b7c6 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_eks.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json + +import boto3 +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.eks import EksHook +from tests.providers.amazon.aws.waiters.test_custom_waiters import assert_all_match + + +class TestCustomEKSServiceWaiters: + def test_service_waiters(self): + hook = EksHook() + with open(hook.waiter_path) as config_file: + expected_waiters = json.load(config_file)["waiters"] + + for waiter in list(expected_waiters.keys()): + assert waiter in hook.list_waiters() + assert waiter in hook._list_custom_waiters() + + @mock_aws + def test_existing_waiter_inherited(self): + """ + AwsBaseHook::get_waiter will first check if there is a custom waiter with the + provided name and pass that through is it exists, otherwise it will check the + custom waiters for the given service. This test checks to make sure that the + waiter is the same whichever way you get it and no modifications are made. + """ + hook_waiter = EksHook().get_waiter("cluster_active") + client_waiter = EksHook().conn.get_waiter("cluster_active") + boto_waiter = boto3.client("eks").get_waiter("cluster_active") + + assert_all_match(hook_waiter.name, client_waiter.name, boto_waiter.name) + assert_all_match(len(hook_waiter.__dict__), len(client_waiter.__dict__), len(boto_waiter.__dict__)) + for attr in hook_waiter.__dict__: + # Not all attributes in a Waiter are directly comparable + # so the best we can do it make sure the same attrs exist. + assert hasattr(boto_waiter, attr) + assert hasattr(client_waiter, attr) diff --git a/tests/providers/amazon/aws/waiters/test_emr.py b/tests/providers/amazon/aws/waiters/test_emr.py new file mode 100644 index 0000000000..a02327f004 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_emr.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Sequence +from unittest import mock + +import boto3 +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.emr import EmrHook + + +class TestCustomEmrServiceWaiters: + JOBFLOW_ID = "test_jobflow_id" + STEP_ID1 = "test_step_id_1" + STEP_ID2 = "test_step_id_2" + + @pytest.fixture(autouse=True) + def setup_test_cases(self, monkeypatch): + self.client = boto3.client("emr", region_name="eu-west-3") + monkeypatch.setattr(EmrHook, "conn", self.client) + + @pytest.fixture + def mock_list_steps(self): + """Mock ``EmrHook.Client.list_steps`` method.""" + with mock.patch.object(self.client, "list_steps") as m: + yield m + + def test_service_waiters(self): + hook_waiters = EmrHook(aws_conn_id=None).list_waiters() + assert "steps_wait_for_terminal" in hook_waiters + + @staticmethod + def list_steps(step_records: Sequence[tuple[str, str]]): + """ + Helper function to generate minimal ListSteps response. + https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html + """ + return { + "Steps": [ + { + "Id": step_record[0], + "Status": { + "State": step_record[1], + }, + } + for step_record in step_records + ], + } + + def test_steps_succeeded(self, mock_list_steps): + """Test steps succeeded""" + mock_list_steps.side_effect = [ + self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, "RUNNING")]), + self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, "COMPLETED")]), + self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2, "COMPLETED")]), + ] + waiter = EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal") + waiter.wait( + ClusterId=self.JOBFLOW_ID, + StepIds=[self.STEP_ID1, self.STEP_ID2], + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) + + def test_steps_failed(self, mock_list_steps): + """Test steps failed""" + mock_list_steps.side_effect = [ + self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, "RUNNING")]), + self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, "COMPLETED")]), + self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2, "COMPLETED")]), + ] + waiter = EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal") + + with pytest.raises(WaiterError, match="Waiter encountered a terminal failure state"): + waiter.wait( + ClusterId=self.JOBFLOW_ID, + StepIds=[self.STEP_ID1, self.STEP_ID2], + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + )