vandonr-amz commented on code in PR #28472:
URL: https://github.com/apache/airflow/pull/28472#discussion_r1061820842


##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -958,3 +958,96 @@ def execute(self, context: Context):
             if group_created:
                 self.hook.conn.delete_model_package_group(ModelPackageGroupName=self.package_group_name)
             raise
+
+
+class SageMakerAutoMLOperator(SageMakerBaseOperator):
+    """
+    Creates an auto ML job, learning to predict the given column from the data provided through S3.
+    The learning output is written to the specified S3 location.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SageMakerAutoMLOperator`
+
+    :param job_name: Name of the job to create, needs to be unique within the account.
+    :param s3_input: The S3 location (folder or file) where to fetch the data.
+        By default, it expects csv with headers.
+    :param target_attribute: The name of the column containing the values to predict.
+    :param s3_output: The S3 folder where to write the model artifacts. Must be 128 characters or fewer.
+    :param role_arn: The ARN of the IAM role to use when interacting with S3.
+        Must have read access to the input, and write access to the output folder.
+    :param compressed_input: Set to True if the input is gzipped.
+    :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+    :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+        that name. No deployment made otherwise.
+    :param extras: Use this dictionary to set any input variable for job creation that is not
+        offered through the parameters of this function. The format is described in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+    :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+    :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
+
+    :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
+        the "BestCandidate" key in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+    """
+
+    template_fields: Sequence[str] = (
+        "job_name",
+        "s3_input",
+        "target_attribute",
+        "s3_output",
+        "role_arn",
+        "compressed_input",
+        "time_limit",
+        "autodeploy_endpoint_name",
+        "extras",
+        "wait_for_completion",
+        "check_interval",
+    )
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        s3_input: str,
+        target_attribute: str,
+        s3_output: str,
+        role_arn: str,
+        compressed_input: bool = False,
+        time_limit: int | None = None,
+        autodeploy_endpoint_name: str | None = None,
+        extras: dict | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        aws_conn_id: str = DEFAULT_CONN_ID,
+        config: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+        self.job_name = job_name
+        self.s3_input = s3_input
+        self.target_attribute = target_attribute
+        self.s3_output = s3_output
+        self.role_arn = role_arn
+        self.compressed_input = compressed_input
+        self.time_limit = time_limit
+        self.autodeploy_endpoint_name = autodeploy_endpoint_name
+        self.extras = extras
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+
+    def execute(self, context: Context) -> dict | None:
+        best = self.hook.create_auto_ml_job(
+            self.job_name,
+            self.s3_input,
+            self.target_attribute,
+            self.s3_output,
+            self.role_arn,
+            self.compressed_input,
+            self.time_limit,
+            self.autodeploy_endpoint_name,
+            self.extras,
+            self.wait_for_completion,
+            self.check_interval,
+        )
+        return best

Review Comment:
   Maybe I'm misunderstanding your point, but `do_xcom_push` doesn't need to be checked in the operator, because it is handled once for all operators here, if I read it correctly:
   
https://github.com/apache/airflow/blob/2a6fb076262cbba4c15eec9495e68f482d071247/airflow/models/taskinstance.py#L1585-L1588
   (It also makes sense that this piece of code isn't repeated in every operator.)
   
   Also, `best` comes from the hook, here:
   
https://github.com/apache/airflow/blob/2a6fb076262cbba4c15eec9495e68f482d071247/airflow/providers/amazon/aws/hooks/sagemaker.py#L1245-L1247
   So it should never be an empty dictionary: it's either a populated dictionary or None. I cannot omit the `return` statement because that makes the static checker angry, but returning None is effectively the same as not returning.
   
   I can add a check here that `best != {}` if that's what you meant?
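
   To illustrate the reasoning above, here is a minimal sketch of a centralized push check. It is a simplified stand-in, not the actual Airflow code: `FakeOperator`, `run_task`, and the `xcom_store` dictionary are hypothetical names, and the condition only mirrors my reading of the linked `taskinstance.py` lines (push when `do_xcom_push` is set and the result is not None).

```python
from __future__ import annotations

# Key name assumed for illustration; Airflow uses a constant for the return-value XCom.
XCOM_RETURN_KEY = "return_value"


class FakeOperator:
    """Hypothetical stand-in for an operator; not an Airflow class."""

    def __init__(self, result: dict | None, do_xcom_push: bool = True):
        self.result = result
        self.do_xcom_push = do_xcom_push

    def execute(self, context: dict) -> dict | None:
        # The operator simply returns whatever the hook produced.
        return self.result


def run_task(task: FakeOperator, xcom_store: dict) -> dict | None:
    """Simplified runner: the push decision lives here, once, for every operator."""
    result = task.execute(context={})
    # Only a None result (or do_xcom_push=False) suppresses the push.
    if task.do_xcom_push and result is not None:
        xcom_store[XCOM_RETURN_KEY] = result
    return result
```

   Under these assumptions, returning None from `execute` pushes nothing, while an empty dictionary would still be pushed, which is the only case an extra `!= {}` check in the operator would change.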



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
