This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 6779c49812 Merge AWS DMS tests into the single module (#35585) 6779c49812 is described below commit 6779c49812f411ee0225d222a1f6700bc8580db7 Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Sun Nov 12 20:45:25 2023 +0400 Merge AWS DMS tests into the single module (#35585) --- tests/always/test_project_structure.py | 3 - .../aws/hooks/{test_dms_task.py => test_dms.py} | 0 tests/providers/amazon/aws/operators/test_dms.py | 270 +++++++++++++++++++++ .../amazon/aws/operators/test_dms_create_task.py | 85 ------- .../amazon/aws/operators/test_dms_delete_task.py | 56 ----- .../aws/operators/test_dms_describe_tasks.py | 97 -------- .../amazon/aws/operators/test_dms_start_task.py | 60 ----- .../amazon/aws/operators/test_dms_stop_task.py | 56 ----- .../aws/sensors/{test_dms_task.py => test_dms.py} | 0 9 files changed, 270 insertions(+), 357 deletions(-) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 0fd8d7db01..61331490cb 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -64,11 +64,8 @@ class TestProjectStructure: "tests/providers/amazon/aws/executors/ecs/test_ecs_executor_config.py", "tests/providers/amazon/aws/executors/ecs/test_utils.py", "tests/providers/amazon/aws/fs/test_s3.py", - "tests/providers/amazon/aws/hooks/test_dms.py", - "tests/providers/amazon/aws/operators/test_dms.py", "tests/providers/amazon/aws/operators/test_emr.py", "tests/providers/amazon/aws/operators/test_sagemaker.py", - "tests/providers/amazon/aws/sensors/test_dms.py", "tests/providers/amazon/aws/sensors/test_emr.py", "tests/providers/amazon/aws/sensors/test_sagemaker.py", "tests/providers/amazon/aws/test_exceptions.py", diff --git a/tests/providers/amazon/aws/hooks/test_dms_task.py b/tests/providers/amazon/aws/hooks/test_dms.py similarity index 100% rename from tests/providers/amazon/aws/hooks/test_dms_task.py rename to tests/providers/amazon/aws/hooks/test_dms.py diff --git a/tests/providers/amazon/aws/operators/test_dms.py b/tests/providers/amazon/aws/operators/test_dms.py new file mode 100644 index 0000000000..c56e01e4c3 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_dms.py @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from unittest import mock + +import pendulum +import pytest + +from airflow import DAG +from airflow.models import DagRun, TaskInstance +from airflow.providers.amazon.aws.hooks.dms import DmsHook +from airflow.providers.amazon.aws.operators.dms import ( + DmsCreateTaskOperator, + DmsDeleteTaskOperator, + DmsDescribeTasksOperator, + DmsStartTaskOperator, + DmsStopTaskOperator, +) +from airflow.utils import timezone + +TASK_ARN = "test_arn" + + +class TestDmsCreateTaskOperator: + TASK_DATA = { + "replication_task_id": "task_id", + "source_endpoint_arn": "source_endpoint", + "target_endpoint_arn": "target_endpoint", + "replication_instance_arn": "replication_arn", + "table_mappings": { + "rules": [ + { + "rule-type": "selection", + "rule-id": "1", + "rule-name": "1", + "object-locator": { + "schema-name": "test", + "table-name": "%", + }, + "rule-action": "include", + } + ] + }, + } + + def test_init(self): + create_operator = DmsCreateTaskOperator(task_id="create_task", **self.TASK_DATA) + + assert create_operator.replication_task_id == self.TASK_DATA["replication_task_id"] + assert create_operator.source_endpoint_arn == self.TASK_DATA["source_endpoint_arn"] + assert create_operator.target_endpoint_arn == self.TASK_DATA["target_endpoint_arn"] + assert create_operator.replication_instance_arn == self.TASK_DATA["replication_instance_arn"] + assert create_operator.migration_type == "full-load" + assert create_operator.table_mappings == self.TASK_DATA["table_mappings"] + + @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",)) + @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) + @mock.patch.object(DmsHook, "get_conn") + def test_create_task(self, mock_conn, mock_create_replication_task, mock_get_task_status): + dms_hook = DmsHook() + + create_task = DmsCreateTaskOperator(task_id="create_task", **self.TASK_DATA) + create_task.execute(None) + + mock_create_replication_task.assert_called_once_with(**self.TASK_DATA, migration_type="full-load") + + assert dms_hook.get_task_status(TASK_ARN) == "ready" + + @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",)) + @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) + @mock.patch.object(DmsHook, "get_conn") + def test_create_task_with_migration_type( + self, mock_conn, mock_create_replication_task, mock_get_task_status + ): + migration_type = "cdc" + dms_hook = DmsHook() + + create_task = DmsCreateTaskOperator( + task_id="create_task", migration_type=migration_type, **self.TASK_DATA + ) + create_task.execute(None) + + mock_create_replication_task.assert_called_once_with(**self.TASK_DATA, migration_type=migration_type) + + assert dms_hook.get_task_status(TASK_ARN) == "ready" + + +class TestDmsDeleteTaskOperator: + TASK_DATA = { + "replication_task_id": "task_id", + "source_endpoint_arn": "source_endpoint", + "target_endpoint_arn": "target_endpoint", + "replication_instance_arn": "replication_arn", + "migration_type": "full-load", + "table_mappings": {}, + } + + def test_init(self): + dms_operator = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=TASK_ARN) + + assert dms_operator.replication_task_arn == TASK_ARN + + @mock.patch.object(DmsHook, "get_task_status", side_effect=("deleting",)) + @mock.patch.object(DmsHook, "delete_replication_task") + @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) + @mock.patch.object(DmsHook, "get_conn") + def test_delete_task( + self, mock_conn, mock_create_replication_task, mock_delete_replication_task, mock_get_task_status + ): + dms_hook = DmsHook() + task = dms_hook.create_replication_task(**self.TASK_DATA) + + delete_task = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=task) + delete_task.execute(None) + + mock_delete_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) + + assert dms_hook.get_task_status(TASK_ARN) == "deleting" + + +class TestDmsDescribeTasksOperator: + FILTER = {"Name": "replication-task-arn", "Values": [TASK_ARN]} + MOCK_DATA = { + "replication_task_id": "test_task", + "source_endpoint_arn": "source-endpoint-arn", + "target_endpoint_arn": "target-endpoint-arn", + "replication_instance_arn": "replication-instance-arn", + "migration_type": "full-load", + "table_mappings": {}, + } + MOCK_RESPONSE = [ + { + "ReplicationTaskIdentifier": MOCK_DATA["replication_task_id"], + "SourceEndpointArn": MOCK_DATA["source_endpoint_arn"], + "TargetEndpointArn": MOCK_DATA["target_endpoint_arn"], + "ReplicationInstanceArn": MOCK_DATA["replication_instance_arn"], + "MigrationType": MOCK_DATA["migration_type"], + "TableMappings": json.dumps(MOCK_DATA["table_mappings"]), + "ReplicationTaskArn": TASK_ARN, + "Status": "creating", + } + ] + + def setup_method(self): + args = { + "owner": "airflow", + "start_date": pendulum.datetime(2018, 1, 1, tz="UTC"), + } + + self.dag = DAG("dms_describe_tasks_operator", default_args=args, schedule="@once") + + def test_init(self): + dms_operator = DmsDescribeTasksOperator( + task_id="describe_tasks", describe_tasks_kwargs={"Filters": [self.FILTER]} + ) + + assert dms_operator.describe_tasks_kwargs == {"Filters": [self.FILTER]} + + @mock.patch.object(DmsHook, "describe_replication_tasks", return_value=(None, MOCK_RESPONSE)) + @mock.patch.object(DmsHook, "get_conn") + def test_describe_tasks(self, mock_conn, mock_describe_replication_tasks): + describe_tasks_kwargs = {"Filters": [self.FILTER]} + describe_task = DmsDescribeTasksOperator( + task_id="describe_tasks", describe_tasks_kwargs=describe_tasks_kwargs + ) + describe_task.execute(None) + + mock_describe_replication_tasks.assert_called_once_with(**describe_tasks_kwargs) + + @pytest.mark.db_test + @mock.patch.object(DmsHook, "describe_replication_tasks", return_value=(None, MOCK_RESPONSE)) + @mock.patch.object(DmsHook, "get_conn") + def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_tasks): + describe_task = DmsDescribeTasksOperator( + task_id="describe_tasks", dag=self.dag, describe_tasks_kwargs={"Filters": [self.FILTER]} + ) + + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=describe_task) + ti.dag_run = dag_run + marker, response = describe_task.execute(ti.get_template_context()) + + assert marker is None + assert response == self.MOCK_RESPONSE + + +class TestDmsStartTaskOperator: + TASK_DATA = { + "replication_task_id": "task_id", + "source_endpoint_arn": "source_endpoint", + "target_endpoint_arn": "target_endpoint", + "replication_instance_arn": "replication_arn", + "migration_type": "full-load", + "table_mappings": {}, + } + + def test_init(self): + dms_operator = DmsStartTaskOperator(task_id="start_task", replication_task_arn=TASK_ARN) + + assert dms_operator.replication_task_arn == TASK_ARN + assert dms_operator.start_replication_task_type == "start-replication" + + @mock.patch.object(DmsHook, "get_task_status", side_effect=("starting",)) + @mock.patch.object(DmsHook, "start_replication_task") + @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) + @mock.patch.object(DmsHook, "get_conn") + def test_start_task( + self, mock_conn, mock_create_replication_task, mock_start_replication_task, mock_get_task_status + ): + dms_hook = DmsHook() + task = dms_hook.create_replication_task(**self.TASK_DATA) + + start_task = DmsStartTaskOperator(task_id="start_task", replication_task_arn=task) + start_task.execute(None) + + mock_start_replication_task.assert_called_once_with( + replication_task_arn=TASK_ARN, + start_replication_task_type="start-replication", + ) + + assert dms_hook.get_task_status(TASK_ARN) == "starting" + + +class TestDmsStopTaskOperator: + TASK_DATA = { + "replication_task_id": "task_id", + "source_endpoint_arn": "source_endpoint", + "target_endpoint_arn": "target_endpoint", + "replication_instance_arn": "replication_arn", + "migration_type": "full-load", + "table_mappings": {}, + } + + def test_init(self): + dms_operator = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=TASK_ARN) + + assert dms_operator.replication_task_arn == TASK_ARN + + @mock.patch.object(DmsHook, "get_task_status", side_effect=("stopping",)) + @mock.patch.object(DmsHook, "stop_replication_task") + @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) + @mock.patch.object(DmsHook, "get_conn") + def test_stop_task( + self, mock_conn, mock_create_replication_task, mock_stop_replication_task, mock_get_task_status + ): + dms_hook = DmsHook() + task = dms_hook.create_replication_task(**self.TASK_DATA) + + stop_task = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=task) + stop_task.execute(None) + + mock_stop_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) + + assert dms_hook.get_task_status(TASK_ARN) == "stopping" diff --git a/tests/providers/amazon/aws/operators/test_dms_create_task.py b/tests/providers/amazon/aws/operators/test_dms_create_task.py deleted file mode 100644 index 73d2dbd20c..0000000000 --- a/tests/providers/amazon/aws/operators/test_dms_create_task.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.providers.amazon.aws.operators.dms import DmsCreateTaskOperator - -TASK_ARN = "test_arn" -TASK_DATA = { - "replication_task_id": "task_id", - "source_endpoint_arn": "source_endpoint", - "target_endpoint_arn": "target_endpoint", - "replication_instance_arn": "replication_arn", - "table_mappings": { - "rules": [ - { - "rule-type": "selection", - "rule-id": "1", - "rule-name": "1", - "object-locator": { - "schema-name": "test", - "table-name": "%", - }, - "rule-action": "include", - } - ] - }, -} - - -class TestDmsCreateTaskOperator: - def test_init(self): - create_operator = DmsCreateTaskOperator(task_id="create_task", **TASK_DATA) - - assert create_operator.replication_task_id == TASK_DATA["replication_task_id"] - assert create_operator.source_endpoint_arn == TASK_DATA["source_endpoint_arn"] - assert create_operator.target_endpoint_arn == TASK_DATA["target_endpoint_arn"] - assert create_operator.replication_instance_arn == TASK_DATA["replication_instance_arn"] - assert create_operator.migration_type == "full-load" - assert create_operator.table_mappings == TASK_DATA["table_mappings"] - - @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",)) - @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) - @mock.patch.object(DmsHook, "get_conn") - def test_create_task(self, mock_conn, mock_create_replication_task, mock_get_task_status): - dms_hook = DmsHook() - - create_task = DmsCreateTaskOperator(task_id="create_task", **TASK_DATA) - create_task.execute(None) - - mock_create_replication_task.assert_called_once_with(**TASK_DATA, migration_type="full-load") - - assert dms_hook.get_task_status(TASK_ARN) == "ready" - - @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",)) - @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) - @mock.patch.object(DmsHook, "get_conn") - def test_create_task_with_migration_type( - self, mock_conn, mock_create_replication_task, mock_get_task_status - ): - migration_type = "cdc" - dms_hook = DmsHook() - - create_task = DmsCreateTaskOperator(task_id="create_task", migration_type=migration_type, **TASK_DATA) - create_task.execute(None) - - mock_create_replication_task.assert_called_once_with(**TASK_DATA, migration_type=migration_type) - - assert dms_hook.get_task_status(TASK_ARN) == "ready" diff --git a/tests/providers/amazon/aws/operators/test_dms_delete_task.py b/tests/providers/amazon/aws/operators/test_dms_delete_task.py deleted file mode 100644 index b7e66d8546..0000000000 --- a/tests/providers/amazon/aws/operators/test_dms_delete_task.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.providers.amazon.aws.operators.dms import DmsDeleteTaskOperator - -TASK_ARN = "test_arn" -TASK_DATA = { - "replication_task_id": "task_id", - "source_endpoint_arn": "source_endpoint", - "target_endpoint_arn": "target_endpoint", - "replication_instance_arn": "replication_arn", - "migration_type": "full-load", - "table_mappings": {}, -} - - -class TestDmsDeleteTaskOperator: - def test_init(self): - dms_operator = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=TASK_ARN) - - assert dms_operator.replication_task_arn == TASK_ARN - - @mock.patch.object(DmsHook, "get_task_status", side_effect=("deleting",)) - @mock.patch.object(DmsHook, "delete_replication_task") - @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) - @mock.patch.object(DmsHook, "get_conn") - def test_delete_task( - self, mock_conn, mock_create_replication_task, mock_delete_replication_task, mock_get_task_status - ): - dms_hook = DmsHook() - task = dms_hook.create_replication_task(**TASK_DATA) - - delete_task = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=task) - delete_task.execute(None) - - mock_delete_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) - - assert dms_hook.get_task_status(TASK_ARN) == "deleting" diff --git a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py b/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py deleted file mode 100644 index 2ebc294f83..0000000000 --- a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py +++ /dev/null @@ -1,97 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import json -from unittest import mock - -import pytest - -from airflow.models import DAG, DagRun, TaskInstance -from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.providers.amazon.aws.operators.dms import DmsDescribeTasksOperator -from airflow.utils import timezone -from airflow.utils.timezone import datetime - -TEST_DAG_ID = "unit_tests" -DEFAULT_DATE = datetime(2018, 1, 1) -MOCK_TASK_ARN = "test_arn" -FILTER = {"Name": "replication-task-arn", "Values": [MOCK_TASK_ARN]} -MOCK_DATA = { - "replication_task_id": "test_task", - "source_endpoint_arn": "source-endpoint-arn", - "target_endpoint_arn": "target-endpoint-arn", - "replication_instance_arn": "replication-instance-arn", - "migration_type": "full-load", - "table_mappings": {}, -} -MOCK_RESPONSE = [ - { - "ReplicationTaskIdentifier": MOCK_DATA["replication_task_id"], - "SourceEndpointArn": MOCK_DATA["source_endpoint_arn"], - "TargetEndpointArn": MOCK_DATA["target_endpoint_arn"], - "ReplicationInstanceArn": MOCK_DATA["replication_instance_arn"], - "MigrationType": MOCK_DATA["migration_type"], - "TableMappings": json.dumps(MOCK_DATA["table_mappings"]), - "ReplicationTaskArn": MOCK_TASK_ARN, - "Status": "creating", - } -] - - -class TestDmsDescribeTasksOperator: - def setup_method(self): - args = { - "owner": "airflow", - "start_date": DEFAULT_DATE, - } - - self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule="@once") - - def test_init(self): - dms_operator = DmsDescribeTasksOperator( - task_id="describe_tasks", describe_tasks_kwargs={"Filters": [FILTER]} - ) - - assert dms_operator.describe_tasks_kwargs == {"Filters": [FILTER]} - - @mock.patch.object(DmsHook, "describe_replication_tasks", return_value=(None, MOCK_RESPONSE)) - @mock.patch.object(DmsHook, "get_conn") - def test_describe_tasks(self, mock_conn, mock_describe_replication_tasks): - describe_tasks_kwargs = {"Filters": [FILTER]} - describe_task = DmsDescribeTasksOperator( - task_id="describe_tasks", describe_tasks_kwargs=describe_tasks_kwargs - ) - describe_task.execute(None) - - mock_describe_replication_tasks.assert_called_once_with(**describe_tasks_kwargs) - - @pytest.mark.db_test - @mock.patch.object(DmsHook, "describe_replication_tasks", return_value=(None, MOCK_RESPONSE)) - @mock.patch.object(DmsHook, "get_conn") - def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_tasks): - describe_task = DmsDescribeTasksOperator( - task_id="describe_tasks", dag=self.dag, describe_tasks_kwargs={"Filters": [FILTER]} - ) - - dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") - ti = TaskInstance(task=describe_task) - ti.dag_run = dag_run - marker, response = describe_task.execute(ti.get_template_context()) - - assert marker is None - assert response == MOCK_RESPONSE diff --git a/tests/providers/amazon/aws/operators/test_dms_start_task.py b/tests/providers/amazon/aws/operators/test_dms_start_task.py deleted file mode 100644 index e7d24e8a07..0000000000 --- a/tests/providers/amazon/aws/operators/test_dms_start_task.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.providers.amazon.aws.operators.dms import DmsStartTaskOperator - -TASK_ARN = "test_arn" -TASK_DATA = { - "replication_task_id": "task_id", - "source_endpoint_arn": "source_endpoint", - "target_endpoint_arn": "target_endpoint", - "replication_instance_arn": "replication_arn", - "migration_type": "full-load", - "table_mappings": {}, -} - - -class TestDmsStartTaskOperator: - def test_init(self): - dms_operator = DmsStartTaskOperator(task_id="start_task", replication_task_arn=TASK_ARN) - - assert dms_operator.replication_task_arn == TASK_ARN - assert dms_operator.start_replication_task_type == "start-replication" - - @mock.patch.object(DmsHook, "get_task_status", side_effect=("starting",)) - @mock.patch.object(DmsHook, "start_replication_task") - @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) - @mock.patch.object(DmsHook, "get_conn") - def test_start_task( - self, mock_conn, mock_create_replication_task, mock_start_replication_task, mock_get_task_status - ): - dms_hook = DmsHook() - task = dms_hook.create_replication_task(**TASK_DATA) - - start_task = DmsStartTaskOperator(task_id="start_task", replication_task_arn=task) - start_task.execute(None) - - mock_start_replication_task.assert_called_once_with( - replication_task_arn=TASK_ARN, - start_replication_task_type="start-replication", - ) - - assert dms_hook.get_task_status(TASK_ARN) == "starting" diff --git a/tests/providers/amazon/aws/operators/test_dms_stop_task.py b/tests/providers/amazon/aws/operators/test_dms_stop_task.py deleted file mode 100644 index 483dd6fef2..0000000000 --- a/tests/providers/amazon/aws/operators/test_dms_stop_task.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.providers.amazon.aws.operators.dms import DmsStopTaskOperator - -TASK_ARN = "test_arn" -TASK_DATA = { - "replication_task_id": "task_id", - "source_endpoint_arn": "source_endpoint", - "target_endpoint_arn": "target_endpoint", - "replication_instance_arn": "replication_arn", - "migration_type": "full-load", - "table_mappings": {}, -} - - -class TestDmsStopTaskOperator: - def test_init(self): - dms_operator = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=TASK_ARN) - - assert dms_operator.replication_task_arn == TASK_ARN - - @mock.patch.object(DmsHook, "get_task_status", side_effect=("stopping",)) - @mock.patch.object(DmsHook, "stop_replication_task") - @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN) - @mock.patch.object(DmsHook, "get_conn") - def test_stop_task( - self, mock_conn, mock_create_replication_task, mock_stop_replication_task, mock_get_task_status - ): - dms_hook = DmsHook() - task = dms_hook.create_replication_task(**TASK_DATA) - - stop_task = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=task) - stop_task.execute(None) - - mock_stop_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) - - assert dms_hook.get_task_status(TASK_ARN) == "stopping" diff --git a/tests/providers/amazon/aws/sensors/test_dms_task.py b/tests/providers/amazon/aws/sensors/test_dms.py similarity index 100% rename from tests/providers/amazon/aws/sensors/test_dms_task.py rename to tests/providers/amazon/aws/sensors/test_dms.py