ellisms commented on code in PR #64274: URL: https://github.com/apache/airflow/pull/64274#discussion_r3041296005
########## providers/amazon/tests/unit/amazon/aws/operators/test_neptune_analytics.py: ########## @@ -0,0 +1,1314 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.neptune_analytics import NeptuneAnalyticsHook +from airflow.providers.amazon.aws.operators.neptune_analytics import ( + NeptuneCancelImportTaskOperator, + NeptuneCreateGraphOperator, + NeptuneCreateGraphWithImportOperator, + NeptuneCreatePrivateGraphEndpointOperator, + NeptuneDeleteGraphOperator, + NeptuneDeletePrivateGraphEndpointOperator, + NeptuneStartImportTaskOperator, +) +from airflow.providers.common.compat.sdk import TaskDeferred + +GRAPH_NAME = "test_graph" +GRAPH_ID = "test-graph-id" +VPC_ID = "vpc-12345" +SUBNET_IDS = ["subnet-1", "subnet-2"] +SECURITY_GROUP_IDS = ["sg-1", "sg-2"] +ENDPOINT_ID = "vpce-12345" +SOURCE_S3_URI = "s3://my-bucket/my-data/" +ROLE_ARN = "arn:aws:iam::123456789012:role/NeptuneImportRole" + + +@pytest.fixture +def hook() -> Generator[NeptuneAnalyticsHook, None, None]: + with mock_aws(): + yield NeptuneAnalyticsHook(aws_conn_id="aws_default") + + 
+class TestNeptuneCreateGraphOperator: + def test_template_fields(self): + # Verify template_fields includes the expected fields + fields = NeptuneCreateGraphOperator.template_fields + assert "graph_name" in fields + assert "vector_search_config" in fields + assert "provisioned_memory" in fields + + def test_template_fields_renderers(self): + assert NeptuneCreateGraphOperator.template_fields_renderers == {"vector_search_config": "json"} + + def test_operator_extra_links(self): + from airflow.providers.amazon.aws.links.neptune_analytics import NeptuneGraphLink + + assert len(NeptuneCreateGraphOperator.operator_extra_links) == 1 + assert isinstance(NeptuneCreateGraphOperator.operator_extra_links[0], NeptuneGraphLink) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + ) + + assert operator.public_connectivity is None + assert operator.replica_count is None + assert operator.deletion_protect is False + assert operator.kms_key is None + assert operator.tags is None + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + provisionedMemory=16, + deletionProtection=False, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + 
provisioned_memory=16, + public_connectivity=True, + replica_count=3, + kms_key_id="test-key", + tags={"key1": "test"}, + deletion_protection=True, + ) + + assert operator.public_connectivity is True + assert operator.replica_count == 3 + assert operator.deletion_protect is True + assert operator.kms_key == "test-key" + assert operator.tags == {"key1": "test"} + + operator.execute(None) + + mock_conn.create_graph.assert_called_once_with( + graphName=GRAPH_NAME, + vectorSearchConfiguration={"test": 123}, + replicaCount=3, + publicConnectivity=True, + provisionedMemory=16, + deletionProtection=True, + kmsKeyIdentifier="test-key", + tags={"key1": "test"}, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph(self, mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=False, + ) + resp = operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + assert "graph_id" in resp + assert resp["graph_id"] == GRAPH_ID + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_graph_wait_for_completion(self, mock_hook_get_waiter, mock_conn, mock_persist): + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + provisioned_memory=16, + vector_search_config={"test": 123}, + wait_for_completion=True, + ) + resp = operator.execute(None) + + 
mock_hook_get_waiter.assert_called_once_with("graph_available") + assert "graph_id" in resp + assert resp["graph_id"] == GRAPH_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_persist_called_with_correct_args(self, mock_conn): + """Test that NeptuneGraphLink.persist is called with the correct arguments.""" + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + wait_for_completion=False, + ) + + mock_context = mock.MagicMock() + with mock.patch( + "airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist" + ) as mock_persist: + operator.execute(mock_context) + + mock_persist.assert_called_once_with( + context=mock_context, + operator=operator, + region_name=mock.ANY, + aws_partition=mock.ANY, + graph_id=GRAPH_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.operators.neptune_analytics.NeptuneGraphLink.persist") + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_deferrable_defers_with_graph_available_trigger(self, mock_conn, mock_persist): + """Test that deferrable mode defers with NeptuneGraphAvailableTrigger.""" + from airflow.providers.amazon.aws.triggers.neptune_analytics import NeptuneGraphAvailableTrigger + + mock_conn.create_graph.return_value = {"id": GRAPH_ID, "status": "CREATING"} + + operator = NeptuneCreateGraphOperator( + task_id="test_task", + graph_name=GRAPH_NAME, + vector_search_config={"test": 123}, + provisioned_memory=16, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + operator.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, NeptuneGraphAvailableTrigger) + assert exc_info.value.method_name == "execute_complete" + + +class TestNeptuneCreatePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, 
mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id is None + assert operator.subnet_ids is None + assert operator.vpc_security_group_ids is None + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + result = operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + ) + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + assert result is not None + assert result["vpc_endpoint_id"] == ENDPOINT_ID + assert result["graph_id"] == GRAPH_ID + assert result["vpc_id"] == VPC_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + subnet_ids=SUBNET_IDS, + vpc_security_group_ids=SECURITY_GROUP_IDS, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.subnet_ids == SUBNET_IDS + assert operator.vpc_security_group_ids == SECURITY_GROUP_IDS + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.create_private_graph_endpoint.assert_called_once_with( 
+ graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + subnetIds=SUBNET_IDS, + vpcSecurityGroupIds=SECURITY_GROUP_IDS, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_create_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.create_private_graph_endpoint.return_value = { + "status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + result = operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_available") + mock_hook_get_waiter.return_value.wait.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, + ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_sets_vpc_id_from_response(self, mock_conn): + """When vpc_id is not provided, the operator should use the vpc_id from the API response.""" + mock_conn.create_private_graph_endpoint.return_value = { + 
"status": "CREATING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + ) + + assert operator.vpc_id is None + result = operator.execute(None) + + # vpc_id should be set from the create response + assert operator.vpc_id == VPC_ID + assert result["vpc_id"] == VPC_ID + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_create_endpoint_failed_status(self, mock_conn): + from airflow.providers.common.compat.sdk import AirflowException + + mock_conn.create_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises(AirflowException, match=f"Private endpoint failed to create for graph {GRAPH_ID}"): + operator.execute(None) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_execute_complete(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID + ) + + result = operator.execute_complete(None, {"status": "success"}) + + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + assert result == {"vpc_endpoint_id": ENDPOINT_ID, "graph_id": GRAPH_ID, "vpc_id": VPC_ID} + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_get_graph_endpoint_id(self, mock_conn): + mock_conn.get_private_graph_endpoint.return_value = { + "vpcEndpointId": ENDPOINT_ID, + } + + operator = NeptuneCreatePrivateGraphEndpointOperator( + task_id="test_task", graph_identifier=GRAPH_ID, vpc_id=VPC_ID + ) + + result = 
operator._get_graph_endpoint_id() + + assert result == ENDPOINT_ID + mock_conn.get_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + +class TestNeptuneDeletePrivateGraphEndpointOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_custom_args(self, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + waiter_delay=60, + waiter_max_attempts=100, + ) + + assert operator.graph_identifier == GRAPH_ID + assert operator.vpc_id == VPC_ID + assert operator.waiter_delay == 60 + assert operator.waiter_max_attempts == 100 + + operator.execute(None) + + mock_conn.delete_private_graph_endpoint.assert_called_once_with( + graphIdentifier=GRAPH_ID, + vpcId=VPC_ID, + ) + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_no_wait(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + 
} + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=False, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + @mock.patch.object(NeptuneAnalyticsHook, "get_waiter") + def test_delete_endpoint_wait_for_completion(self, mock_hook_get_waiter, mock_conn): + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "DELETING", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + wait_for_completion=True, + ) + operator.execute(None) + + mock_hook_get_waiter.assert_called_once_with("private_graph_endpoint_deleted") + + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_delete_endpoint_failed_status(self, mock_conn): + from airflow.providers.common.compat.sdk import AirflowException + + mock_conn.delete_private_graph_endpoint.return_value = { + "status": "FAILED", + "vpcEndpointId": ENDPOINT_ID, + "vpcId": VPC_ID, + } + + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + with pytest.raises(AirflowException, match=f"Failed to delete private endpoint {ENDPOINT_ID}"): + operator.execute(None) + + def test_execute_complete_success(self): + operator = NeptuneDeletePrivateGraphEndpointOperator( + task_id="test_task", + graph_identifier=GRAPH_ID, + vpc_id=VPC_ID, + ) + + event = { + "status": "success", + "endpoint_id": ENDPOINT_ID, + } + + operator.execute_complete(None, event) + + # Verify the method completes without error and logs the endpoint_id + + +class TestNeptuneDeleteGraphOperator: + @mock.patch.object(NeptuneAnalyticsHook, "conn") + def test_init_defaults(self, mock_conn): + mock_conn.delete_graph.return_value = { + "id": GRAPH_ID, + "name": GRAPH_NAME, + "status": 
"DELETING", + } + + operator = NeptuneDeleteGraphOperator( + task_id="test_task", + graph_id=GRAPH_ID, + skip_snapshot=True, + ) + + assert operator.graph_id == GRAPH_ID + assert operator.skip_snapshot is True + assert operator.wait_for_completion is True + assert operator.waiter_delay == 30 + assert operator.waiter_max_attempts == 60 + + operator.execute(None) + + mock_conn.delete_graph.assert_called_once_with( + graphIdentifier=GRAPH_ID, + skipSnapshot=True, + ) Review Comment: I think this is a hallucination. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
