This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ba11017691d Adding SageMaker Transform extra link (#45677)
ba11017691d is described below
commit ba11017691d97c987d1eb169651549a6bcb82770
Author: ellisms <[email protected]>
AuthorDate: Wed Jan 15 12:47:41 2025 -0500
Adding SageMaker Transform extra link (#45677)
* Adding SageMaker Transform extra link
* Fixed link error; added test case
* Removed unnecessary aws_conn_id causing db_tests error
---
.../providers/amazon/aws/links/sagemaker.py | 27 ++++++++++++++++
.../providers/amazon/aws/operators/sagemaker.py | 19 ++++++++++++
.../src/airflow/providers/amazon/provider.yaml | 1 +
providers/tests/amazon/aws/links/test_sagemaker.py | 36 ++++++++++++++++++++++
.../aws/operators/test_sagemaker_transform.py | 29 ++++++++++++++++-
5 files changed, 111 insertions(+), 1 deletion(-)
diff --git a/providers/src/airflow/providers/amazon/aws/links/sagemaker.py
b/providers/src/airflow/providers/amazon/aws/links/sagemaker.py
new file mode 100644
index 00000000000..c5d0deb72e6
--- /dev/null
+++ b/providers/src/airflow/providers/amazon/aws/links/sagemaker.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK,
BaseAwsLink
+
+
+class SageMakerTransformJobLink(BaseAwsLink):
+    """Helper class for constructing the Amazon SageMaker Transform Job Details console link."""
+
+    name = "Amazon SageMaker Transform Job Details"
+    key = "sagemaker_transform_job_details"
+    format_str = BASE_AWS_CONSOLE_LINK + "/sagemaker/home?region={region_name}#/transform-jobs/{job_name}"
diff --git a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py
b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py
index 76432ae7f3b..981c79f1960 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import datetime
 import json
 import time
+import urllib.parse
 from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, ClassVar
@@ -34,6 +35,7 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
     SageMakerHook,
     secondary_training_status_message,
 )
+from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
     SageMakerTrigger,
@@ -659,6 +661,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the model created in Amazon SageMaker.
     """
 
+    operator_extra_links = (SageMakerTransformJobLink(),)
+
     def __init__(
         self,
         *,
@@ -765,6 +769,21 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
 
+        quoted_job_name = urllib.parse.quote(transform_config["TransformJobName"], safe="")
+        transform_job_url = SageMakerTransformJobLink.format_str.format(
+            aws_domain=SageMakerTransformJobLink.get_aws_domain(self.hook.conn_partition),
+            region_name=self.hook.conn_region_name,
+            job_name=quoted_job_name,
+        )
+        SageMakerTransformJobLink.persist(
+            context=context,
+            operator=self,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
+            job_name=quoted_job_name,
+        )
+        self.log.info("You can monitor this SageMaker Transform job at %s", transform_job_url)
+
         if self.deferrable and self.wait_for_completion:
             response = self.hook.describe_transform_job(transform_config["TransformJobName"])
             status = response["TransformJobStatus"]
diff --git a/providers/src/airflow/providers/amazon/provider.yaml
b/providers/src/airflow/providers/amazon/provider.yaml
index df76336c5ac..d89532b848c 100644
--- a/providers/src/airflow/providers/amazon/provider.yaml
+++ b/providers/src/airflow/providers/amazon/provider.yaml
@@ -884,6 +884,7 @@ extra-links:
- airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink
- airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
+ - airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink
- airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink
-
airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink
diff --git a/providers/tests/amazon/aws/links/test_sagemaker.py
b/providers/tests/amazon/aws/links/test_sagemaker.py
new file mode 100644
index 00000000000..d656b3559bc
--- /dev/null
+++ b/providers/tests/amazon/aws/links/test_sagemaker.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.sagemaker import
SageMakerTransformJobLink
+
+from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase
+
+
+class TestSageMakerTransformJobLink(BaseAwsLinksTestCase):
+    link_class = SageMakerTransformJobLink
+
+    def test_extra_link(self):
+        self.assert_extra_link_url(
+            expected_url=(
+                "https://console.aws.amazon.com/sagemaker/home"
+                "?region=us-east-1#/transform-jobs/test_job_name"
+            ),
+            region_name="us-east-1",
+            aws_partition="aws",
+            job_name="test_job_name",
+        )
diff --git a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py
b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py
index 9f3bac20fb4..76f89cb710e 100644
--- a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py
+++ b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py
@@ -25,6 +25,7 @@ from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.links.sagemaker import
SageMakerTransformJobLink
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerTransformOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
@@ -75,7 +76,6 @@ class TestSageMakerTransformOperator:
def setup_method(self):
self.sagemaker = SageMakerTransformOperator(
task_id="test_sagemaker_operator",
- aws_conn_id="sagemaker_test_id",
config=copy.deepcopy(CONFIG),
wait_for_completion=False,
check_interval=5,
@@ -128,6 +128,33 @@ class TestSageMakerTransformOperator:
             max_ingestion_time=None,
         )
 
+    @mock.patch.object(SageMakerHook, "describe_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    def test_log_correct_url(self, mock_transform, __, ___, mock_desc):
+        """execute() logs the console URL of the transform job it created."""
+        region = "us-east-1"
+        job_name = CONFIG["Transform"]["TransformJobName"]
+        mock_desc.side_effect = [
+            ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+            {"ModelName": "model_name"},
+        ]
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+
+        aws_domain = SageMakerTransformJobLink.get_aws_domain("aws")
+        job_run_url = (
+            f"https://console.{aws_domain}/sagemaker/home?region={region}#/transform-jobs/{job_name}"
+        )
+
+        with mock.patch.object(self.sagemaker.log, "info") as mock_log_info:
+            self.sagemaker.execute(None)
+        mock_transform.assert_called_once()
+        mock_log_info.assert_any_call("You can monitor this SageMaker Transform job at %s", job_run_url)
+
     @mock.patch.object(SageMakerHook, "describe_transform_job")
     @mock.patch.object(SageMakerHook, "create_model")
     @mock.patch.object(SageMakerHook, "create_transform_job")