This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ad863ce510 Amazon Sagemaker Sample DAG and docs update (#23256)
ad863ce510 is described below
commit ad863ce51020992508c83c1dcf0624623b09bca6
Author: D. Ferruzzi <[email protected]>
AuthorDate: Sun May 8 14:37:51 2022 -0700
Amazon Sagemaker Sample DAG and docs update (#23256)
---
.../amazon/aws/example_dags/example_sagemaker.py | 504 ++++++++++++++++-----
.../aws/example_dags/example_sagemaker_endpoint.py | 230 ++++++++++
airflow/providers/amazon/aws/hooks/sagemaker.py | 45 +-
.../providers/amazon/aws/operators/sagemaker.py | 95 ++--
airflow/providers/amazon/aws/sensors/sagemaker.py | 65 +--
.../operators/sagemaker.rst | 203 +++++++--
docs/spelling_wordlist.txt | 1 +
7 files changed, 941 insertions(+), 202 deletions(-)
diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
index 6162b061db..df69013e1c 100644
--- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
@@ -14,164 +14,454 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import base64
+import os
+import subprocess
from datetime import datetime
-from os import environ
+from tempfile import NamedTemporaryFile
+
+import boto3
from airflow import DAG
+from airflow.decorators import task
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerDeleteModelOperator,
SageMakerModelOperator,
SageMakerProcessingOperator,
SageMakerTrainingOperator,
SageMakerTransformOperator,
+ SageMakerTuningOperator,
+)
+from airflow.providers.amazon.aws.sensors.sagemaker import (
+ SageMakerTrainingSensor,
+ SageMakerTransformSensor,
+ SageMakerTuningSensor,
+)
+
+# Project name will be used in naming the S3 buckets and various tasks.
+# The dataset used in this example is identifying varieties of the Iris flower.
+PROJECT_NAME = 'iris'
+TIMESTAMP = '{{ ts_nodash }}'
+
+S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket')
+RAW_DATA_S3_KEY = f'{PROJECT_NAME}/preprocessing/input.csv'
+INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data'
+TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results'
+PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform'
+
+PROCESSING_LOCAL_INPUT_PATH = '/opt/ml/processing/input'
+PROCESSING_LOCAL_OUTPUT_PATH = '/opt/ml/processing/output'
+
+MODEL_NAME = f'{PROJECT_NAME}-KNN-model'
+# Job names can't be reused, so appending a timestamp ensures it is unique.
+PROCESSING_JOB_NAME = f'{PROJECT_NAME}-processing-{TIMESTAMP}'
+TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}'
+TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}'
+TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}'
+
+ROLE_ARN = os.getenv(
+ 'SAGEMAKER_ROLE_ARN',
+ 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
)
+ECR_REPOSITORY = os.getenv('ECR_REPOSITORY',
'1234567890.dkr.ecr.us-west-2.amazonaws.com/process_data')
+REGION = ECR_REPOSITORY.split('.')[3]
+
+# For this example we are using a subset of Fisher's Iris Data Set.
+# The full dataset can be found at UC Irvine's machine learning repository:
+# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
+DATASET = """
+ 5.1,3.5,1.4,0.2,Iris-setosa
+ 4.9,3.0,1.4,0.2,Iris-setosa
+ 7.0,3.2,4.7,1.4,Iris-versicolor
+ 6.4,3.2,4.5,1.5,Iris-versicolor
+ 4.9,2.5,4.5,1.7,Iris-virginica
+ 7.3,2.9,6.3,1.8,Iris-virginica
+ """
+SAMPLE_SIZE = DATASET.count('\n') - 1
-MODEL_NAME = "sample_model"
-TRAINING_JOB_NAME = "sample_training"
-IMAGE_URI = environ.get("ECR_IMAGE_URI",
"123456789012.dkr.ecr.us-east-1.amazonaws.com/repo_name")
-S3_BUCKET = environ.get("BUCKET_NAME", "test-airflow-12345")
-ROLE = environ.get("SAGEMAKER_ROLE_ARN",
"arn:aws:iam::123456789012:role/role_name")
+# The URI of an Amazon-provided docker image for handling KNN model training.
This is a public ECR
+# repo cited in public SageMaker documentation, so the account number does not
need to be redacted.
+# For more info see:
https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title
+KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn'
+TASK_TIMEOUT = {'MaxRuntimeInSeconds': 6 * 60}
+
+RESOURCE_CONFIG = {
+ 'InstanceCount': 1,
+ 'InstanceType': 'ml.m5.large',
+ 'VolumeSizeInGB': 1,
+}
+
+TRAINING_DATA_SOURCE = {
+ 'CompressionType': 'None',
+ 'ContentType': 'text/csv',
+ 'DataSource': { # type: ignore
+ 'S3DataSource': {
+ 'S3DataDistributionType': 'FullyReplicated',
+ 'S3DataType': 'S3Prefix',
+ 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv',
+ }
+ },
+}
+
+# Define configs for processing, training, model creation, and batch transform
jobs
SAGEMAKER_PROCESSING_JOB_CONFIG = {
- "ProcessingJobName": "sample_processing_job",
- "ProcessingInputs": [
+ 'ProcessingJobName': PROCESSING_JOB_NAME,
+ 'RoleArn': f'{ROLE_ARN}',
+ 'ProcessingInputs': [
{
- "InputName": "input",
- "AppManaged": False,
- "S3Input": {
- "S3Uri": f"s3://{S3_BUCKET}/preprocessing/input/",
- "LocalPath": "/opt/ml/processing/input/",
- "S3DataType": "S3Prefix",
- "S3InputMode": "File",
- "S3DataDistributionType": "FullyReplicated",
- "S3CompressionType": "None",
+ 'InputName': 'input',
+ 'AppManaged': False,
+ 'S3Input': {
+ 'S3Uri': f's3://{S3_BUCKET}/{RAW_DATA_S3_KEY}',
+ 'LocalPath': PROCESSING_LOCAL_INPUT_PATH,
+ 'S3DataType': 'S3Prefix',
+ 'S3InputMode': 'File',
+ 'S3DataDistributionType': 'FullyReplicated',
+ 'S3CompressionType': 'None',
},
},
],
- "ProcessingOutputConfig": {
- "Outputs": [
+ 'ProcessingOutputConfig': {
+ 'Outputs': [
{
- "OutputName": "output",
- "S3Output": {
- "S3Uri": f"s3://{S3_BUCKET}/preprocessing/output/",
- "LocalPath": "/opt/ml/processing/output/",
- "S3UploadMode": "EndOfJob",
+ 'OutputName': 'output',
+ 'S3Output': {
+ 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}',
+ 'LocalPath': PROCESSING_LOCAL_OUTPUT_PATH,
+ 'S3UploadMode': 'EndOfJob',
},
- "AppManaged": False,
+ 'AppManaged': False,
}
]
},
- "ProcessingResources": {
- "ClusterConfig": {
- "InstanceCount": 1,
- "InstanceType": "ml.m5.large",
- "VolumeSizeInGB": 5,
- }
+ 'ProcessingResources': {
+ 'ClusterConfig': RESOURCE_CONFIG,
},
- "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
- "AppSpecification": {
- "ImageUri": f"{IMAGE_URI}",
- "ContainerEntrypoint": ["python3", "./preprocessing.py"],
+ 'StoppingCondition': TASK_TIMEOUT,
+ 'AppSpecification': {
+ 'ImageUri': ECR_REPOSITORY,
},
- "RoleArn": f"{ROLE}",
}
-SAGEMAKER_TRAINING_JOB_CONFIG = {
- "AlgorithmSpecification": {
- "TrainingImage": f"{IMAGE_URI}",
+TRAINING_CONFIG = {
+ 'TrainingJobName': TRAINING_JOB_NAME,
+ 'RoleArn': ROLE_ARN,
+ 'AlgorithmSpecification': {
+ "TrainingImage": KNN_IMAGE_URI,
"TrainingInputMode": "File",
},
- "InputDataConfig": [
- {
- "ChannelName": "config",
- "DataSource": {
- "S3DataSource": {
- "S3DataType": "S3Prefix",
- "S3Uri": f"s3://{S3_BUCKET}/config/",
- "S3DataDistributionType": "FullyReplicated",
- }
- },
- "CompressionType": "None",
- "RecordWrapperType": "None",
- },
- ],
- "OutputDataConfig": {
- "KmsKeyId": "",
- "S3OutputPath": f"s3://{S3_BUCKET}/training/",
- },
- "ResourceConfig": {
- "InstanceType": "ml.m5.large",
- "InstanceCount": 1,
- "VolumeSizeInGB": 5,
+ 'HyperParameters': {
+ 'predictor_type': 'classifier',
+ 'feature_dim': '4',
+ 'k': '3',
+ 'sample_size': str(SAMPLE_SIZE),
},
- "StoppingCondition": {"MaxRuntimeInSeconds": 6000},
- "RoleArn": f"{ROLE}",
- "EnableNetworkIsolation": False,
- "EnableInterContainerTrafficEncryption": False,
- "EnableManagedSpotTraining": False,
- "TrainingJobName": TRAINING_JOB_NAME,
-}
-
-SAGEMAKER_CREATE_MODEL_CONFIG = {
- "ModelName": MODEL_NAME,
- "Containers": [
+ 'InputDataConfig': [
{
- "Image": f"{IMAGE_URI}",
- "Mode": "SingleModel",
- "ModelDataUrl":
f"s3://{S3_BUCKET}/training/{TRAINING_JOB_NAME}/output/model.tar.gz",
+ 'ChannelName': 'train',
+ **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
}
],
- "ExecutionRoleArn": f"{ROLE}",
- "EnableNetworkIsolation": False,
+ 'OutputDataConfig': {'S3OutputPath':
f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'},
+ 'ResourceConfig': RESOURCE_CONFIG,
+ 'StoppingCondition': TASK_TIMEOUT,
}
-SAGEMAKER_INFERENCE_CONFIG = {
- "TransformJobName": "sample_transform_job",
- "ModelName": MODEL_NAME,
- "TransformInput": {
- "DataSource": {
- "S3DataSource": {
- "S3DataType": "S3Prefix",
- "S3Uri": f"s3://{S3_BUCKET}/config/config_date.yml",
+MODEL_CONFIG = {
+ 'ModelName': MODEL_NAME,
+ 'ExecutionRoleArn': ROLE_ARN,
+ 'PrimaryContainer': {
+ 'Mode': 'SingleModel',
+ 'Image': KNN_IMAGE_URI,
+ 'ModelDataUrl':
f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz',
+ },
+}
+
+TRANSFORM_CONFIG = {
+ 'TransformJobName': TRANSFORM_JOB_NAME,
+ 'ModelName': MODEL_NAME,
+ 'TransformInput': {
+ 'DataSource': {
+ 'S3DataSource': {
+ 'S3DataType': 'S3Prefix',
+ 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv',
}
},
- "ContentType": "application/x-yaml",
- "CompressionType": "None",
- "SplitType": "None",
+ 'SplitType': 'Line',
+ 'ContentType': 'text/csv',
+ },
+ 'TransformOutput': {'S3OutputPath':
f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'},
+ 'TransformResources': {
+ 'InstanceCount': 1,
+ 'InstanceType': 'ml.m5.large',
},
- "TransformOutput": {"S3OutputPath":
f"s3://{S3_BUCKET}/inferencing/output/"},
- "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1},
}
-# [START howto_operator_sagemaker]
+TUNING_CONFIG = {
+ 'HyperParameterTuningJobName': TUNING_JOB_NAME,
+ 'HyperParameterTuningJobConfig': {
+ 'Strategy': 'Bayesian',
+ 'HyperParameterTuningJobObjective': {
+ 'MetricName': 'test:accuracy',
+ 'Type': 'Maximize',
+ },
+ 'ResourceLimits': {
+ # You would bump these up in production as appropriate.
+ 'MaxNumberOfTrainingJobs': 1,
+ 'MaxParallelTrainingJobs': 1,
+ },
+ 'ParameterRanges': {
+ 'CategoricalParameterRanges': [],
+ 'IntegerParameterRanges': [
+ # Set the min and max values of the hyperparameters you want
to tune.
+ {
+ 'Name': 'k',
+ 'MinValue': '1',
+ "MaxValue": str(SAMPLE_SIZE),
+ },
+ {
+ 'Name': 'sample_size',
+ 'MinValue': '1',
+ 'MaxValue': str(SAMPLE_SIZE),
+ },
+ ],
+ },
+ },
+ 'TrainingJobDefinition': {
+ 'StaticHyperParameters': {
+ 'predictor_type': 'classifier',
+ 'feature_dim': '4',
+ },
+ 'AlgorithmSpecification': {'TrainingImage': KNN_IMAGE_URI,
'TrainingInputMode': 'File'},
+ 'InputDataConfig': [
+ {
+ 'ChannelName': 'train',
+ **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
+ },
+ {
+ 'ChannelName': 'test',
+ **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
+ },
+ ],
+ 'OutputDataConfig': {'S3OutputPath':
f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'},
+ 'ResourceConfig': RESOURCE_CONFIG,
+ 'StoppingCondition': TASK_TIMEOUT,
+ 'RoleArn': ROLE_ARN,
+ },
+}
+
+
+# This script will be the entrypoint for the docker image which will handle
preprocessing the raw data
+# NOTE: The following string must remain dedented as it is being written to a
file.
+PREPROCESS_SCRIPT = (
+ """
+import boto3
+import numpy as np
+import pandas as pd
+
+def main():
+ # Load the Iris dataset from {input_path}/input.csv, split it into
train/test
+ # subsets, and write them to {output_path}/ for the Processing Operator.
+
+ columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width',
'species']
+ iris = pd.read_csv('{input_path}/input.csv', names=columns)
+
+ # Process data
+ iris['species'] = iris['species'].replace({{'Iris-virginica': 0,
'Iris-versicolor': 1, 'Iris-setosa': 2}})
+ iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length',
'petal_width']]
+
+ # Split into test and train data
+ iris_train, iris_test = np.split(
+ iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 *
len(iris))]
+ )
+
+ # Remove the "answers" from the test set
+ iris_test.drop(['species'], axis=1, inplace=True)
+
+ # Write the splits to disk
+ iris_train.to_csv('{output_path}/train.csv', index=False, header=False)
+ iris_test.to_csv('{output_path}/test.csv', index=False, header=False)
+
+ print('Preprocessing Done.')
+
+if __name__ == "__main__":
+ main()
+
+ """
+).format(input_path=PROCESSING_LOCAL_INPUT_PATH,
output_path=PROCESSING_LOCAL_OUTPUT_PATH)
+
+
+@task
+def upload_dataset_to_s3():
+ """Uploads the provided dataset to a designated Amazon S3 bucket."""
+ S3Hook().load_string(
+ string_data=DATASET,
+ bucket_name=S3_BUCKET,
+ key=RAW_DATA_S3_KEY,
+ replace=True,
+ )
+
+
+@task
+def build_and_upload_docker_image():
+ """
+ We need a Docker image with the following requirements:
+ - Has numpy, pandas, requests, and boto3 installed
+ - Has our data preprocessing script mounted and set as the entry point
+ """
+
+ # Fetch and parse ECR Token to be used for the docker push
+ ecr_client = boto3.client('ecr', region_name=REGION)
+ token = ecr_client.get_authorization_token()
+ credentials =
(base64.b64decode(token['authorizationData'][0]['authorizationToken'])).decode('utf-8')
+ username, password = credentials.split(':')
+
+ with NamedTemporaryFile(mode='w+t') as preprocessing_script,
NamedTemporaryFile(mode='w+t') as dockerfile:
+
+ preprocessing_script.write(PREPROCESS_SCRIPT)
+ preprocessing_script.flush()
+
+ dockerfile.write(
+ f"""
+ FROM amazonlinux
+ COPY {preprocessing_script.name.split('/')[2]} /preprocessing.py
+ ADD credentials /credentials
+ ENV AWS_SHARED_CREDENTIALS_FILE=/credentials
+ RUN yum install python3 pip -y
+ RUN pip3 install boto3 pandas requests
+ CMD [ "python3", "/preprocessing.py"]
+ """
+ )
+ dockerfile.flush()
+
+ docker_build_and_push_commands = f"""
+ cp /root/.aws/credentials /tmp/credentials &&
+ docker build -f {dockerfile.name} -t {ECR_REPOSITORY} /tmp &&
+ rm /tmp/credentials &&
+ aws ecr get-login-password --region {REGION} |
+ docker login --username {username} --password {password}
{ECR_REPOSITORY} &&
+ docker push {ECR_REPOSITORY}
+ """
+ docker_build = subprocess.Popen(
+ docker_build_and_push_commands,
+ shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ _, err = docker_build.communicate()
+
+ if docker_build.returncode != 0:
+ raise RuntimeError(err)
+
+
+@task(trigger_rule='all_done')
+def cleanup():
+ # Delete S3 Artifacts
+ client = boto3.client('s3')
+ object_keys = [
+ key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET,
Prefix=PROJECT_NAME)['Contents']
+ ]
+ for key in object_keys:
+ client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key':
key}]})
+
+
with DAG(
- "sample_sagemaker_dag",
+ dag_id='example_sagemaker',
schedule_interval=None,
- start_date=datetime(2022, 2, 21),
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
catchup=False,
) as dag:
- sagemaker_processing_task = SageMakerProcessingOperator(
+
+ # [START howto_operator_sagemaker_processing]
+ preprocess_raw_data = SageMakerProcessingOperator(
+ task_id='preprocess_raw_data',
config=SAGEMAKER_PROCESSING_JOB_CONFIG,
- aws_conn_id="aws_default",
- task_id="sagemaker_preprocessing_task",
+ do_xcom_push=False,
)
+ # [END howto_operator_sagemaker_processing]
- training_task = SageMakerTrainingOperator(
- config=SAGEMAKER_TRAINING_JOB_CONFIG, aws_conn_id="aws_default",
task_id="sagemaker_training_task"
+ # [START howto_operator_sagemaker_training]
+ train_model = SageMakerTrainingOperator(
+ task_id='train_model',
+ config=TRAINING_CONFIG,
+ # Waits by default, setting as False to demonstrate the Sensor below.
+ wait_for_completion=False,
+ do_xcom_push=False,
)
+ # [END howto_operator_sagemaker_training]
- model_create_task = SageMakerModelOperator(
- config=SAGEMAKER_CREATE_MODEL_CONFIG, aws_conn_id="aws_default",
task_id="sagemaker_create_model_task"
+ # [START howto_operator_sagemaker_training_sensor]
+ await_training = SageMakerTrainingSensor(
+ task_id='await_training',
+ job_name=TRAINING_JOB_NAME,
)
+ # [END howto_operator_sagemaker_training_sensor]
- inference_task = SageMakerTransformOperator(
- config=SAGEMAKER_INFERENCE_CONFIG, aws_conn_id="aws_default",
task_id="sagemaker_inference_task"
+ # [START howto_operator_sagemaker_model]
+ create_model = SageMakerModelOperator(
+ task_id='create_model',
+ config=MODEL_CONFIG,
+ do_xcom_push=False,
)
+ # [END howto_operator_sagemaker_model]
- model_delete_task = SageMakerDeleteModelOperator(
- task_id="sagemaker_delete_model_task", config={'ModelName':
MODEL_NAME}, aws_conn_id="aws_default"
+ # [START howto_operator_sagemaker_tuning]
+ tune_model = SageMakerTuningOperator(
+ task_id='tune_model',
+ config=TUNING_CONFIG,
+ # Waits by default, setting as False to demonstrate the Sensor below.
+ wait_for_completion=False,
+ do_xcom_push=False,
)
+ # [END howto_operator_sagemaker_tuning]
- sagemaker_processing_task >> training_task >> model_create_task >>
inference_task >> model_delete_task
- # [END howto_operator_sagemaker]
+ # [START howto_operator_sagemaker_tuning_sensor]
+ await_tune = SageMakerTuningSensor(
+ task_id='await_tuning',
+ job_name=TUNING_JOB_NAME,
+ )
+ # [END howto_operator_sagemaker_tuning_sensor]
+
+ # [START howto_operator_sagemaker_transform]
+ test_model = SageMakerTransformOperator(
+ task_id='test_model',
+ config=TRANSFORM_CONFIG,
+ # Waits by default, setting as False to demonstrate the Sensor below.
+ wait_for_completion=False,
+ do_xcom_push=False,
+ )
+ # [END howto_operator_sagemaker_transform]
+
+ # [START howto_operator_sagemaker_transform_sensor]
+ await_transform = SageMakerTransformSensor(
+ task_id='await_transform',
+ job_name=TRANSFORM_JOB_NAME,
+ )
+ # [END howto_operator_sagemaker_transform_sensor]
+
+ # Trigger rule set to "all_done" so clean up will run regardless of
success on other tasks.
+ # [START howto_operator_sagemaker_delete_model]
+ delete_model = SageMakerDeleteModelOperator(
+ task_id='delete_model',
+ config={'ModelName': MODEL_NAME},
+ trigger_rule='all_done',
+ )
+ # [END howto_operator_sagemaker_delete_model]
+
+ (
+ upload_dataset_to_s3()
+ >> build_and_upload_docker_image()
+ >> preprocess_raw_data
+ >> train_model
+ >> await_training
+ >> create_model
+ >> tune_model
+ >> await_tune
+ >> test_model
+ >> await_transform
+ >> cleanup()
+ >> delete_model
+ )
diff --git
a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
new file mode 100644
index 0000000000..b4207a9b16
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
@@ -0,0 +1,230 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import os
+from datetime import datetime
+
+import boto3
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator
+from airflow.providers.amazon.aws.operators.sagemaker import (
+ SageMakerDeleteModelOperator,
+ SageMakerEndpointConfigOperator,
+ SageMakerEndpointOperator,
+ SageMakerModelOperator,
+ SageMakerTrainingOperator,
+)
+from airflow.providers.amazon.aws.sensors.sagemaker import
SageMakerEndpointSensor
+
+# Project name will be used in naming the S3 buckets and various tasks.
+# The dataset used in this example is identifying varieties of the Iris flower.
+PROJECT_NAME = 'iris'
+TIMESTAMP = '{{ ts_nodash }}'
+
+S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket')
+ROLE_ARN = os.getenv(
+ 'SAGEMAKER_ROLE_ARN',
+ 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
+)
+INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data'
+TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results'
+
+MODEL_NAME = f'{PROJECT_NAME}-KNN-model'
+ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint'
+# Job names can't be reused, so appending a timestamp ensures it is unique.
+ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}'
+TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}'
+
+# For an example of how to obtain the following train and test data, please see
+#
https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+TRAIN_DATA =
'0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n'
+SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5'
+
+# The URI of an Amazon-provided docker image for handling KNN model training.
This is a public ECR
+# repo cited in public SageMaker documentation, so the account number does not
need to be redacted.
+# For more info see:
https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title
+KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn'
+
+# Define configs for training, model creation, endpoint config, and deployment
+TRAINING_CONFIG = {
+ 'TrainingJobName': TRAINING_JOB_NAME,
+ 'RoleArn': ROLE_ARN,
+ 'AlgorithmSpecification': {
+ "TrainingImage": KNN_IMAGE_URI,
+ "TrainingInputMode": "File",
+ },
+ 'HyperParameters': {
+ 'predictor_type': 'classifier',
+ 'feature_dim': '4',
+ 'k': '3',
+ 'sample_size': '6',
+ },
+ 'InputDataConfig': [
+ {
+ 'ChannelName': 'train',
+ 'CompressionType': 'None',
+ 'ContentType': 'text/csv',
+ 'DataSource': {
+ 'S3DataSource': {
+ 'S3DataDistributionType': 'FullyReplicated',
+ 'S3DataType': 'S3Prefix',
+ 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv',
+ }
+ },
+ }
+ ],
+ 'OutputDataConfig': {'S3OutputPath':
f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'},
+ 'ResourceConfig': {
+ 'InstanceCount': 1,
+ 'InstanceType': 'ml.m5.large',
+ 'VolumeSizeInGB': 1,
+ },
+ 'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60},
+}
+
+MODEL_CONFIG = {
+ 'ModelName': MODEL_NAME,
+ 'ExecutionRoleArn': ROLE_ARN,
+ 'PrimaryContainer': {
+ 'Mode': 'SingleModel',
+ 'Image': KNN_IMAGE_URI,
+ 'ModelDataUrl':
f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz',
+ },
+}
+
+ENDPOINT_CONFIG_CONFIG = {
+ 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
+ 'ProductionVariants': [
+ {
+ 'VariantName': f'{PROJECT_NAME}-demo',
+ 'ModelName': MODEL_NAME,
+ 'InstanceType': 'ml.t2.medium',
+ 'InitialInstanceCount': 1,
+ },
+ ],
+}
+
+DEPLOY_ENDPOINT_CONFIG = {
+ 'EndpointName': ENDPOINT_NAME,
+ 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
+}
+
+
+@task
+def call_endpoint():
+ runtime = boto3.Session().client('sagemaker-runtime')
+
+ response = runtime.invoke_endpoint(
+ EndpointName=ENDPOINT_NAME,
+ ContentType='text/csv',
+ Body=SAMPLE_TEST_DATA,
+ )
+
+ return json.loads(response["Body"].read().decode())['predictions']
+
+
+@task(trigger_rule='all_done')
+def cleanup():
+ # Delete Endpoint and Endpoint Config
+ client = boto3.client('sagemaker')
+ endpoint_config_name =
client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName']
+ client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
+ client.delete_endpoint(EndpointName=ENDPOINT_NAME)
+
+ # Delete S3 Artifacts
+ client = boto3.client('s3')
+ object_keys = [
+ key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET,
Prefix=PROJECT_NAME)['Contents']
+ ]
+ for key in object_keys:
+ client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key':
key}]})
+
+
+with DAG(
+ dag_id='example_sagemaker_endpoint',
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+ catchup=False,
+) as dag:
+
+ upload_data = S3CreateObjectOperator(
+ task_id='upload_data',
+ s3_bucket=S3_BUCKET,
+ s3_key=f'{INPUT_DATA_S3_KEY}/train.csv',
+ data=TRAIN_DATA,
+ replace=True,
+ )
+
+ train_model = SageMakerTrainingOperator(
+ task_id='train_model',
+ config=TRAINING_CONFIG,
+ do_xcom_push=False,
+ )
+
+ create_model = SageMakerModelOperator(
+ task_id='create_model',
+ config=MODEL_CONFIG,
+ do_xcom_push=False,
+ )
+
+ # [START howto_operator_sagemaker_endpoint_config]
+ configure_endpoint = SageMakerEndpointConfigOperator(
+ task_id='configure_endpoint',
+ config=ENDPOINT_CONFIG_CONFIG,
+ do_xcom_push=False,
+ )
+ # [END howto_operator_sagemaker_endpoint_config]
+
+ # [START howto_operator_sagemaker_endpoint]
+ deploy_endpoint = SageMakerEndpointOperator(
+ task_id='deploy_endpoint',
+ config=DEPLOY_ENDPOINT_CONFIG,
+ # Waits by default, setting as False to demonstrate the Sensor below.
+ wait_for_completion=False,
+ do_xcom_push=False,
+ )
+ # [END howto_operator_sagemaker_endpoint]
+
+ # [START howto_operator_sagemaker_endpoint_sensor]
+ await_endpoint = SageMakerEndpointSensor(
+ task_id='await_endpoint',
+ endpoint_name=ENDPOINT_NAME,
+ do_xcom_push=False,
+ )
+ # [END howto_operator_sagemaker_endpoint_sensor]
+
+ # Trigger rule set to "all_done" so clean up will run regardless of
success on other tasks.
+ delete_model = SageMakerDeleteModelOperator(
+ task_id='delete_model',
+ config={'ModelName': MODEL_NAME},
+ trigger_rule='all_done',
+ )
+
+ (
+ upload_data
+ >> train_model
+ >> create_model
+ >> configure_endpoint
+ >> deploy_endpoint
+ >> await_endpoint
+ >> call_endpoint()
+ >> cleanup()
+ >> delete_model
+ )
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 73348c8555..2c8c28a738 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -310,7 +310,8 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Create a training job
+ Starts a model training job. After training completes, Amazon
SageMaker saves
+ the resulting model artifacts to an Amazon S3 location that you
specify.
:param config: the config for training
:param wait_for_completion: if the program should keep running until
job finishes
@@ -357,7 +358,11 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Create a tuning job
+ Starts a hyperparameter tuning job. A hyperparameter tuning job finds
the
+ best version of a model by running many training jobs on your dataset
using
+ the algorithm you choose and values for hyperparameters within ranges
that
+ you specify. It then chooses the hyperparameter values that result in
a model
+ that performs the best, as measured by an objective metric that you
choose.
:param config: the config for tuning
:param wait_for_completion: if the program should keep running until
job finishes
@@ -389,7 +394,8 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Create a transform job
+ Starts a transform job. A transform job uses a trained model to get
inferences
+ on a dataset and saves these results to an Amazon S3 location that you
specify.
:param config: the config for transform job
:param wait_for_completion: if the program should keep running until
job finishes
@@ -422,7 +428,10 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Create a processing job
+ Use Amazon SageMaker Processing to analyze data and evaluate machine
learning
+ models on Amazon SageMaker. With Processing, you can use a simplified,
managed
+ experience on SageMaker to run your data processing workloads, such as
feature
+ engineering, data validation, model evaluation, and model
interpretation.
:param config: the config for processing job
:param wait_for_completion: if the program should keep running until
job finishes
@@ -446,7 +455,10 @@ class SageMakerHook(AwsBaseHook):
def create_model(self, config: dict):
"""
- Create a model job
+ Creates a model in Amazon SageMaker. In the request, you name the
model and
+ describe a primary container. For the primary container, you specify
the Docker
+ image that contains inference code, artifacts (from prior training),
and a custom
+ environment map that the inference code uses when you deploy the model
for predictions.
:param config: the config for model
:return: A response to model creation
@@ -455,7 +467,14 @@ class SageMakerHook(AwsBaseHook):
def create_endpoint_config(self, config: dict):
"""
- Create an endpoint config
+ Creates an endpoint configuration that Amazon SageMaker hosting
+ services uses to deploy models. In the configuration, you identify
+ one or more models, created using the CreateModel API, to deploy and
+ the resources that you want Amazon SageMaker to provision.
+
+ .. seealso::
+
:class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_model`
+
:class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint`
:param config: the config for endpoint-config
:return: A response to endpoint config creation
@@ -470,7 +489,15 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Create an endpoint
+ When you create a serverless endpoint, SageMaker provisions and manages
+ the compute resources for you. Then, you can make inference requests to
+ the endpoint and receive model predictions in response. SageMaker
scales
+ the compute resources up and down as needed to handle your request
traffic.
+
+ Requires an Endpoint Config.
+ .. seealso::
+
:class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint_config`
+
:param config: the config for endpoint
:param wait_for_completion: if the program should keep running until
job finishes
@@ -501,7 +528,9 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: Optional[int] = None,
):
"""
- Update an endpoint
+ Deploys the new EndpointConfig specified in the request, switches to
using
+ the newly created endpoint, and then deletes resources provisioned for the
+ endpoint using the previous EndpointConfig (there is no availability
loss).
:param config: the config for endpoint
:param wait_for_completion: if the program should keep running until
job finishes
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 071c167e9a..11be2e7a83 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -105,14 +105,18 @@ class SageMakerBaseOperator(BaseOperator):
class SageMakerProcessingOperator(SageMakerBaseOperator):
- """Initiate a SageMaker processing job.
+ """
+ Use Amazon SageMaker Processing to analyze data and evaluate machine
learning
+ models on Amazon SageMaker. With Processing, you can use a simplified,
managed
+ experience on SageMaker to run your data processing workloads, such as
feature
+ engineering, data validation, model evaluation, and model interpretation.
- This operator returns The ARN of the processing job created in Amazon
SageMaker.
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerProcessingOperator`
:param config: The configuration necessary to start a processing job
(templated).
-
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_processing_job`
- :param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: If wait is set to True, the time interval, in
seconds,
that the operation waits to check the status of the processing job.
:param print_log: if the operator should print the cloudwatch log during
processing
@@ -123,13 +127,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
the operation does not timeout.
:param action_if_job_exists: Behaviour if the job name already exists.
Possible options are "increment"
(default) and "fail".
+ :return Dict: Returns The ARN of the processing job created in Amazon
SageMaker.
"""
def __init__(
self,
*,
config: dict,
- aws_conn_id: str,
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = 30,
@@ -137,7 +141,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
action_if_job_exists: str = 'increment',
**kwargs,
):
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(config=config, **kwargs)
if action_if_job_exists not in ('increment', 'fail'):
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and
'fail'. \
@@ -185,14 +189,20 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
"""
- Create a SageMaker endpoint config.
+ Creates an endpoint configuration that Amazon SageMaker hosting
+ services uses to deploy models. In the configuration, you identify
+ one or more models, created using the CreateModel API, to deploy and
+ the resources that you want Amazon SageMaker to provision.
- This operator returns The ARN of the endpoint config created in Amazon
SageMaker
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerEndpointConfigOperator`
:param config: The configuration necessary to create an endpoint config.
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_endpoint_config`
:param aws_conn_id: The AWS connection ID to use.
+ :return Dict: Returns The ARN of the endpoint config created in Amazon
SageMaker.
"""
integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
@@ -213,9 +223,16 @@ class
SageMakerEndpointConfigOperator(SageMakerBaseOperator):
class SageMakerEndpointOperator(SageMakerBaseOperator):
"""
- Create a SageMaker endpoint.
+ When you create a serverless endpoint, SageMaker provisions and manages
+ the compute resources for you. Then, you can make inference requests to
+ the endpoint and receive model predictions in response. SageMaker scales
+ the compute resources up and down as needed to handle your request traffic.
+
+ Requires an Endpoint Config.
- This operator returns The ARN of the endpoint created in Amazon SageMaker
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerEndpointOperator`
:param config:
The configuration necessary to create an endpoint.
@@ -242,13 +259,13 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
For details of the configuration parameter of endpoint_configuration
see
:py:meth:`SageMaker.Client.create_endpoint`
- :param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Whether the operator should wait until the
endpoint creation finishes.
:param check_interval: If wait is set to True, this is the time interval,
in seconds, that this operation
waits before polling the status of the endpoint creation.
:param max_ingestion_time: If wait is set to True, this operation fails if
the endpoint creation doesn't
finish within max_ingestion_time seconds. If you set this parameter to
None it never times out.
:param operation: Whether to create an endpoint or update an endpoint.
Must be either 'create' or 'update'.
+ :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
"""
def __init__(
@@ -331,9 +348,13 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
class SageMakerTransformOperator(SageMakerBaseOperator):
- """Initiate a SageMaker transform job.
+ """
+ Starts a transform job. A transform job uses a trained model to get
inferences
+ on a dataset and saves these results to an Amazon S3 location that you
specify.
- This operator returns The ARN of the model created in Amazon SageMaker.
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerTransformOperator`
:param config: The configuration necessary to start a transform job
(templated).
@@ -354,13 +375,13 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
For details of the configuration parameter of model_config, See:
:py:meth:`SageMaker.Client.create_model`
- :param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the transform job
finishes.
:param check_interval: If wait is set to True, the time interval, in
seconds,
that this operation waits to check the status of the transform job.
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds.
If you
set this parameter to None, the operation does not timeout.
+ :return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
def __init__(
@@ -422,21 +443,28 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
class SageMakerTuningOperator(SageMakerBaseOperator):
- """Initiate a SageMaker hyperparameter tuning job.
+ """
+ Starts a hyperparameter tuning job. A hyperparameter tuning job finds the
+ best version of a model by running many training jobs on your dataset using
+ the algorithm you choose and values for hyperparameters within ranges that
+ you specify. It then chooses the hyperparameter values that result in a
model
+ that performs the best, as measured by an objective metric that you choose.
- This operator returns The ARN of the tuning job created in Amazon
SageMaker.
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerTuningOperator`
:param config: The configuration necessary to start a tuning job
(templated).
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
- :param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the tuning job
finishes.
:param check_interval: If wait is set to True, the time interval, in
seconds,
that this operation waits to check the status of the tuning job.
:param max_ingestion_time: If wait is set to True, the operation fails
if the tuning job doesn't finish within max_ingestion_time seconds. If
you
set this parameter to None, the operation does not timeout.
+ :return Dict: Returns The ARN of the tuning job created in Amazon
SageMaker.
"""
integer_fields = [
@@ -487,14 +515,20 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
class SageMakerModelOperator(SageMakerBaseOperator):
- """Create a SageMaker model.
+ """
+ Creates a model in Amazon SageMaker. In the request, you name the model and
+ describe a primary container. For the primary container, you specify the
Docker
+ image that contains inference code, artifacts (from prior training), and a
custom
+ environment map that the inference code uses when you deploy the model for
predictions.
- This operator returns The ARN of the model created in Amazon SageMaker
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerModelOperator`
:param config: The configuration necessary to create a model.
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_model`
- :param aws_conn_id: The AWS connection ID to use.
+ :return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
def __init__(self, *, config, **kwargs):
@@ -518,14 +552,16 @@ class SageMakerModelOperator(SageMakerBaseOperator):
class SageMakerTrainingOperator(SageMakerBaseOperator):
"""
- Initiate a SageMaker training job.
+ Starts a model training job. After training completes, Amazon SageMaker
saves
+ the resulting model artifacts to an Amazon S3 location that you specify.
- This operator returns The ARN of the training job created in Amazon
SageMaker.
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerTrainingOperator`
:param config: The configuration necessary to start a training job
(templated).
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_training_job`
- :param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the training job
finishes.
:param print_log: if the operator should print the cloudwatch log during
training
@@ -539,6 +575,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
:param action_if_job_exists: Behaviour if the job name already exists.
Possible options are "increment"
(default) and "fail".
This is only relevant if check_if
+ :return Dict: Returns The ARN of the training job created in Amazon
SageMaker.
"""
integer_fields = [
@@ -611,19 +648,19 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
class SageMakerDeleteModelOperator(SageMakerBaseOperator):
- """Deletes a SageMaker model.
+ """
+ Deletes a SageMaker model.
- This operator deletes the Model entry created in SageMaker.
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerDeleteModelOperator`
:param config: The configuration necessary to delete the model.
-
For details of the configuration parameter see
:py:meth:`SageMaker.Client.delete_model`
- :param aws_conn_id: The AWS connection ID to use.
"""
- def __init__(self, *, config, aws_conn_id: str, **kwargs):
+ def __init__(self, *, config, **kwargs):
super().__init__(config=config, **kwargs)
- self.aws_conn_id = aws_conn_id
self.config = config
def execute(self, context: 'Context') -> Any:
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py
b/airflow/providers/amazon/aws/sensors/sagemaker.py
index 054b139cc2..3cf6dceef1 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -27,10 +27,10 @@ if TYPE_CHECKING:
class SageMakerBaseSensor(BaseSensorOperator):
- """Contains general sensor behavior for SageMaker.
+ """
+ Contains general sensor behavior for SageMaker.
- Subclasses should implement get_sagemaker_response()
- and state_from_response() methods.
+ Subclasses should implement get_sagemaker_response() and
state_from_response() methods.
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE
methods.
"""
@@ -84,13 +84,15 @@ class SageMakerBaseSensor(BaseSensorOperator):
class SageMakerEndpointSensor(SageMakerBaseSensor):
- """Asks for the state of the endpoint state until it reaches a
- terminal state.
- If it fails the sensor errors, the task fails.
-
+ """
+ Polls the endpoint state until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
- :param job_name: job_name of the endpoint instance to check the state of
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:SageMakerEndpointSensor`
+ :param endpoint_name: Name of the endpoint instance to watch.
"""
template_fields: Sequence[str] = ('endpoint_name',)
@@ -118,15 +120,15 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
class SageMakerTransformSensor(SageMakerBaseSensor):
- """Asks for the state of the transform state until it reaches a
- terminal state.
- The sensor will error if the job errors, throwing a
- AirflowException
- containing the failure reason.
+ """
+ Polls the transform job until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
- :param
- job_name: job_name of the transform job instance to check the state of
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:SageMakerTransformSensor`
+ :param job_name: Name of the transform job to watch.
"""
template_fields: Sequence[str] = ('job_name',)
@@ -154,16 +156,15 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
class SageMakerTuningSensor(SageMakerBaseSensor):
- """Asks for the state of the tuning state until it reaches a terminal
- state.
- The sensor will error if the job errors, throwing a
- AirflowException
- containing the failure reason.
-
- :param
- job_name: job_name of the tuning instance to check the state of
- :type
- job_name: str
+ """
+ Polls the tuning job until it reaches a terminal state.
+ Raises an AirflowException with the failure reason if a failed state is
reached.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:SageMakerTuningSensor`
+
+ :param job_name: Name of the tuning instance to watch.
"""
template_fields: Sequence[str] = ('job_name',)
@@ -191,14 +192,16 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
class SageMakerTrainingSensor(SageMakerBaseSensor):
- """Asks for the state of the training state until it reaches a
- terminal state.
- If it fails the sensor errors, failing the task.
-
+ """
+ Polls the training job until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
- :param job_name: name of the SageMaker training job to check the state of
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:SageMakerTrainingSensor`
- :param print_log: if the operator should print the cloudwatch log
+ :param job_name: Name of the training job to watch.
+ :param print_log: Prints the cloudwatch log if True; Defaults to True.
"""
template_fields: Sequence[str] = ('job_name',)
diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index f44d258a27..a31527d4ca 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -15,52 +15,201 @@
specific language governing permissions and limitations
under the License.
-
Amazon SageMaker Operators
-========================================
+==========================
+
+`Amazon SageMaker <https://docs.aws.amazon.com/sagemaker>`__ is a fully managed
+machine learning service. With Amazon SageMaker, data scientists and developers
+can quickly build and train machine learning models, and then deploy them into
a
+production-ready hosted environment.
+
+Airflow provides operators to create and interact with SageMaker Jobs.
Prerequisite Tasks
------------------
.. include:: _partials/prerequisite_tasks.rst
-Overview
---------
+Manage Amazon SageMaker Jobs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. _howto/operator:SageMakerProcessingOperator:
+
+Create an Amazon SageMaker Processing Job
+"""""""""""""""""""""""""""""""""""""""""
+
+To create an Amazon Sagemaker processing job to sanitize your dataset you can
use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_processing]
+ :end-before: [END howto_operator_sagemaker_processing]
+
+
+.. _howto/operator:SageMakerTrainingOperator:
+
+Create an Amazon SageMaker Training Job
+"""""""""""""""""""""""""""""""""""""""
+
+To create an Amazon Sagemaker training job you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_training]
+ :end-before: [END howto_operator_sagemaker_training]
+
+.. _howto/operator:SageMakerModelOperator:
+
+Create an Amazon SageMaker Model
+""""""""""""""""""""""""""""""""
+
+To create an Amazon Sagemaker model you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_model]
+ :end-before: [END howto_operator_sagemaker_model]
+
+.. _howto/operator:SageMakerTuningOperator:
+
+Start a Hyperparameter Tuning Job
+"""""""""""""""""""""""""""""""""
+
+To start a hyperparameter tuning job for an Amazon Sagemaker model you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_tuning]
+ :end-before: [END howto_operator_sagemaker_tuning]
+
+.. _howto/operator:SageMakerDeleteModelOperator:
+
+Delete an Amazon SageMaker Model
+""""""""""""""""""""""""""""""""
+
+To delete an Amazon Sagemaker model you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_delete_model]
+ :end-before: [END howto_operator_sagemaker_delete_model]
+
+.. _howto/operator:SageMakerTransformOperator:
+
+Create an Amazon SageMaker Transform Job
+""""""""""""""""""""""""""""""""""""""""
-Airflow to Amazon SageMaker integration provides several operators to create
and interact with
-SageMaker Jobs.
+To create an Amazon Sagemaker transform job you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`.
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator`
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator`
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator`
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`
- -
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator`
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_transform]
+ :end-before: [END howto_operator_sagemaker_transform]
-Purpose
-"""""""
+.. _howto/operator:SageMakerEndpointConfigOperator:
+
+Create an Amazon SageMaker Endpoint Config Job
+""""""""""""""""""""""""""""""""""""""""""""""
+
+To create an Amazon Sagemaker endpoint config job you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_endpoint_config]
+ :end-before: [END howto_operator_sagemaker_endpoint_config]
-This example DAG ``example_sagemaker.py`` uses
``SageMakerProcessingOperator``, ``SageMakerTrainingOperator``,
-``SageMakerModelOperator``, ``SageMakerDeleteModelOperator`` and
``SageMakerTransformOperator`` to
-create SageMaker processing job, run the training job,
-generate the models artifact in s3, create the model,
-, run SageMaker Batch inference and delete the model from SageMaker.
+.. _howto/operator:SageMakerEndpointOperator:
-Defining tasks
-""""""""""""""
+Create an Amazon SageMaker Endpoint Job
+"""""""""""""""""""""""""""""""""""""""
-In the following code we create a SageMaker processing,
-training, Sagemaker Model, batch transform job and
-then delete the model.
+To create an Amazon Sagemaker endpoint you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_endpoint]
+ :end-before: [END howto_operator_sagemaker_endpoint]
+
+
+Amazon SageMaker Sensors
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. _howto/sensor:SageMakerTrainingSensor:
+
+Amazon SageMaker Training Sensor
+""""""""""""""""""""""""""""""""
+
+To check the state of an Amazon Sagemaker training job until it reaches a
terminal state
+you can use
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTrainingSensor`.
.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
:language: python
- :start-after: [START howto_operator_sagemaker]
- :end-before: [END howto_operator_sagemaker]
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_training_sensor]
+ :end-before: [END howto_operator_sagemaker_training_sensor]
+
+.. _howto/sensor:SageMakerTransformSensor:
+
+Amazon SageMaker Transform Sensor
+"""""""""""""""""""""""""""""""""""
+
+To check the state of an Amazon Sagemaker transform job until it reaches a
terminal state
+you can use
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTransformSensor`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_transform_sensor]
+ :end-before: [END howto_operator_sagemaker_transform_sensor]
+
+.. _howto/sensor:SageMakerTuningSensor:
+
+Amazon SageMaker Tuning Sensor
+""""""""""""""""""""""""""""""
+
+To check the state of an Amazon Sagemaker hyperparameter tuning job until it
reaches a terminal state
+you can use
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTuningSensor`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_tuning_sensor]
+ :end-before: [END howto_operator_sagemaker_tuning_sensor]
+
+.. _howto/sensor:SageMakerEndpointSensor:
+
+Amazon SageMaker Endpoint Sensor
+""""""""""""""""""""""""""""""""
+
+To check the state of an Amazon Sagemaker endpoint until it
reaches a terminal state
+you can use
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerEndpointSensor`.
+
+.. exampleinclude::
/../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_endpoint_sensor]
+ :end-before: [END howto_operator_sagemaker_endpoint_sensor]
Reference
----------
+^^^^^^^^^
For further information, look at:
* `Boto3 Library Documentation for Sagemaker
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html>`__
+* `Amazon SageMaker Developer Guide
<https://docs.aws.amazon.com/sagemaker/latest/dg/whatis.html>`__
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index bb54cf2c2b..98e71c971d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -910,6 +910,7 @@ https
httpx
hvac
hyperparameter
+hyperparameters
iPython
iTerm
iam