This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new ba11017691d Adding SageMaker Transform extra link (#45677) ba11017691d is described below commit ba11017691d97c987d1eb169651549a6bcb82770 Author: ellisms <114107920+elli...@users.noreply.github.com> AuthorDate: Wed Jan 15 12:47:41 2025 -0500 Adding SageMaker Transform extra link (#45677) * Adding SageMaker Transform extra link * Fixed link error; added test case * Removed unnecessary aws_conn_id causing db_tests error --- .../providers/amazon/aws/links/sagemaker.py | 27 ++++++++++++++++ .../providers/amazon/aws/operators/sagemaker.py | 19 ++++++++++++ .../src/airflow/providers/amazon/provider.yaml | 1 + providers/tests/amazon/aws/links/test_sagemaker.py | 36 ++++++++++++++++++++++ .../aws/operators/test_sagemaker_transform.py | 29 ++++++++++++++++- 5 files changed, 111 insertions(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/amazon/aws/links/sagemaker.py b/providers/src/airflow/providers/amazon/aws/links/sagemaker.py new file mode 100644 index 00000000000..c5d0deb72e6 --- /dev/null +++ b/providers/src/airflow/providers/amazon/aws/links/sagemaker.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class SageMakerTransformJobLink(BaseAwsLink): + """Helper class for constructing AWS Transform Run Details Link.""" + + name = "Amazon SageMaker Transform Job Details" + key = "sagemaker_transform_job_details" + format_str = BASE_AWS_CONSOLE_LINK + "/sagemaker/home?region={region_name}#/transform-jobs/{job_name}" diff --git a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py index 76432ae7f3b..981c79f1960 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -19,6 +19,7 @@ from __future__ import annotations import datetime import json import time +import urllib from collections.abc import Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar @@ -34,6 +35,7 @@ from airflow.providers.amazon.aws.hooks.sagemaker import ( SageMakerHook, secondary_training_status_message, ) +from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink from airflow.providers.amazon.aws.triggers.sagemaker import ( SageMakerPipelineTrigger, SageMakerTrigger, @@ -659,6 +661,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator): :return Dict: Returns The ARN of the model created in Amazon SageMaker. 
""" + operator_extra_links = (SageMakerTransformJobLink(),) + def __init__( self, *, @@ -765,6 +769,21 @@ class SageMakerTransformOperator(SageMakerBaseOperator): if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Sagemaker transform Job creation failed: {response}") + transform_job_url = SageMakerTransformJobLink.format_str.format( + aws_domain=SageMakerTransformJobLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""), + ) + SageMakerTransformJobLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""), + ) + + self.log.info("You can monitor this SageMaker Transform job at %s", transform_job_url) + if self.deferrable and self.wait_for_completion: response = self.hook.describe_transform_job(transform_config["TransformJobName"]) status = response["TransformJobStatus"] diff --git a/providers/src/airflow/providers/amazon/provider.yaml b/providers/src/airflow/providers/amazon/provider.yaml index df76336c5ac..d89532b848c 100644 --- a/providers/src/airflow/providers/amazon/provider.yaml +++ b/providers/src/airflow/providers/amazon/provider.yaml @@ -884,6 +884,7 @@ extra-links: - airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink - airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink + - airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink - airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink - airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink diff --git a/providers/tests/amazon/aws/links/test_sagemaker.py b/providers/tests/amazon/aws/links/test_sagemaker.py new file mode 100644 index 00000000000..d656b3559bc --- /dev/null +++ 
b/providers/tests/amazon/aws/links/test_sagemaker.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink + +from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + + +class TestSageMakerTransformDetailsLink(BaseAwsLinksTestCase): + link_class = SageMakerTransformJobLink + + def test_extra_link(self): + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/sagemaker/home" + "?region=us-east-1#/transform-jobs/test_job_name" + ), + region_name="us-east-1", + aws_partition="aws", + job_name="test_job_name", + ) diff --git a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py index 9f3bac20fb4..76f89cb710e 100644 --- a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py +++ b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py @@ -25,6 +25,7 @@ from botocore.exceptions import ClientError from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from 
airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger @@ -75,7 +76,6 @@ class TestSageMakerTransformOperator: def setup_method(self): self.sagemaker = SageMakerTransformOperator( task_id="test_sagemaker_operator", - aws_conn_id="sagemaker_test_id", config=copy.deepcopy(CONFIG), wait_for_completion=False, check_interval=5, @@ -128,6 +128,33 @@ class TestSageMakerTransformOperator: max_ingestion_time=None, ) + @mock.patch.object(SageMakerHook, "describe_transform_job") + @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "describe_model") + @mock.patch.object(SageMakerHook, "create_transform_job") + # @mock.patch.object(sagemaker, "serialize", return_value="") + def test_log_correct_url(self, mock_transform, __, ___, mock_desc): + region = "us-east-1" + job_name = CONFIG["Transform"]["TransformJobName"] + mock_desc.side_effect = [ + ClientError({"Error": {"Code": "ValidationException"}}, "op"), + {"ModelName": "model_name"}, + ] + mock_transform.return_value = { + "TransformJobArn": "test_arn", + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + aws_domain = SageMakerTransformJobLink.get_aws_domain("aws") + job_run_url = ( + f"https://console.{aws_domain}/sagemaker/home?region={region}#/transform-jobs/{job_name}" + ) + + with mock.patch.object(self.sagemaker.log, "info") as mock_log_info: + self.sagemaker.execute(None) + # assert job_run_id == JOB_RUN_ID + mock_log_info.assert_any_call("You can monitor this SageMaker Transform job at %s", job_run_url) + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_model") @mock.patch.object(SageMakerHook, "create_transform_job")