This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 25d463c3e3 Deprecate AutoMLTrainModelOperator for NL (#34212)
25d463c3e3 is described below
commit 25d463c3e33f8628e1bcbe4dc6924693ec141dc0
Author: VladaZakharova <[email protected]>
AuthorDate: Mon Sep 11 12:53:08 2023 +0200
Deprecate AutoMLTrainModelOperator for NL (#34212)
---
airflow/providers/google/cloud/operators/automl.py | 21 +++-
.../operators/cloud/automl.rst | 14 ++-
.../example_automl_nl_text_classification.py | 121 ++++++++++++---------
.../automl/example_automl_nl_text_extraction.py | 117 +++++++++++---------
.../automl/example_automl_nl_text_sentiment.py | 115 ++++++++++++--------
5 files changed, 240 insertions(+), 148 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index 6fefa3081a..aee9fc9631 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -19,6 +19,7 @@
from __future__ import annotations
import ast
+import warnings
from typing import TYPE_CHECKING, Sequence, Tuple
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -31,6 +32,7 @@ from google.cloud.automl_v1beta1 import (
TableSpec,
)
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.links.automl import (
AutoMLDatasetLink,
@@ -53,6 +55,10 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
"""
Creates a Google Cloud AutoML model.
+ AutoMLTrainModelOperator for text prediction is deprecated. Please use
+ :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
+ instead.
+
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AutoMLTrainModelOperator`
@@ -102,7 +108,6 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
-
self.model = model
self.location = location
self.project_id = project_id
@@ -113,6 +118,20 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
self.impersonation_chain = impersonation_chain
def execute(self, context: Context):
+ # Output warning if running AutoML Natural Language prediction job
+ automl_nl_model_keys = [
+ "text_classification_model_metadata",
+ "text_extraction_model_metadata",
+ "text_sentiment_dataset_metadata",
+ ]
+ if any(key in automl_nl_model_keys for key in self.model):
+ warnings.warn(
+ "AutoMLTrainModelOperator for text prediction is deprecated.
All the functionality of legacy "
+ "AutoML Natural Language and new features are available on the
Vertex AI platform. "
+ "Please use `CreateAutoMLTextTrainingJobOperator`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
hook = CloudAutoMLHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
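
For quick reference, a minimal before/after sketch of the migration this warning points to, assembled from the example DAGs updated further down in this commit; task ids, display names, dataset ids, bucket and project values are illustrative placeholders only:

    # Deprecated pattern: AutoML Natural Language training through AutoMLTrainModelOperator
    from airflow.providers.google.cloud.operators.automl import AutoMLTrainModelOperator

    create_model = AutoMLTrainModelOperator(
        task_id="create_model",
        model={
            "display_name": "my-text-model",            # placeholder
            "dataset_id": "my-automl-dataset-id",       # placeholder
            "text_classification_model_metadata": {},   # one of the keys that now triggers the warning
        },
        location="us-central1",
    )

    # Recommended replacement: Vertex AI text training job
    from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
        CreateAutoMLTextTrainingJobOperator,
    )

    create_training_job = CreateAutoMLTextTrainingJobOperator(
        task_id="create_training_job",
        display_name="my-text-model",                   # placeholder
        prediction_type="classification",
        multi_label=False,
        dataset_id="my-vertex-ai-dataset-id",           # placeholder
        model_display_name="my-text-model",             # placeholder
        training_fraction_split=0.7,
        validation_fraction_split=0.2,
        test_fraction_split=0.1,
        sync=True,
        region="us-central1",
        project_id="my-project-id",                     # placeholder
    )
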
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
index 28821c4c6f..b283f51c07 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/automl.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -102,6 +102,16 @@ To create a Google AutoML model you can use
The operator will wait for the operation to complete. Additionally the operator
returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
+This operator is deprecated when used for text prediction and will be removed soon.
+All the functionality of legacy AutoML Natural Language and new features are available on the
+Vertex AI platform. Please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`.
+When running the Vertex AI operator for training, please ensure that your data is correctly stored in Vertex AI
+datasets. To create and import data to the dataset, please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
+and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
+
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
@@ -164,7 +174,7 @@ the model must be deployed.
Listing And Deleting Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-You can get a list of AutoML models using
+You can get a list of AutoML datasets using
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list
of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
@@ -174,7 +184,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
:start-after: [START howto_operator_list_dataset]
:end-before: [END howto_operator_list_dataset]
-To delete a model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
+To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
The delete operator also allows passing a list or a comma-separated string of dataset ids to be deleted.
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
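
In the same spirit, a rough sketch of creating and populating a Vertex AI text dataset with CreateDatasetOperator and ImportDataOperator, modelled on the classification example DAG below; the display name, bucket URI, and project id are placeholders only:

    from google.cloud.aiplatform import schema
    from google.protobuf.struct_pb2 import Value

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
        CreateDatasetOperator,
        ImportDataOperator,
    )

    create_dataset = CreateDatasetOperator(
        task_id="create_dataset",
        dataset={
            "display_name": "my-text-dataset",                    # placeholder
            "metadata_schema_uri": schema.dataset.metadata.text,
            "metadata": Value(string_value="my-text-dataset"),
        },
        region="us-central1",
        project_id="my-project-id",                               # placeholder
    )

    import_data = ImportDataOperator(
        task_id="import_data",
        dataset_id=create_dataset.output["dataset_id"],
        region="us-central1",
        project_id="my-project-id",                               # placeholder
        import_configs=[
            {
                "import_schema_uri": schema.dataset.ioformat.text.single_label_classification,
                "gcs_source": {"uris": ["gs://my-bucket/automl/classification.csv"]},  # placeholder
            }
        ],
    )

    create_dataset >> import_data
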
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
index 0a04b3b361..753c91dfd0 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
@@ -24,47 +24,54 @@ import os
from datetime import datetime
from typing import cast
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLCreateDatasetOperator,
- AutoMLDeleteDatasetOperator,
- AutoMLDeleteModelOperator,
- AutoMLDeployModelOperator,
- AutoMLImportDataOperator,
- AutoMLTrainModelOperator,
-)
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
GCSSynchronizeBucketsOperator,
)
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+ CreateAutoMLTextTrainingJobOperator,
+ DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+ CreateDatasetOperator,
+ DeleteDatasetOperator,
+ ImportDataOperator,
+)
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_cls"
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_cls"
GCP_AUTOML_LOCATION = "us-central1"
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
-MODEL_NAME = "text_clss_test_model"
-MODEL = {
- "display_name": MODEL_NAME,
- "text_classification_model_metadata": {},
-}
+TEXT_CLSS_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/classification.csv"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
DATASET = {
"display_name": DATASET_NAME,
- "text_classification_dataset_metadata": {"classification_type":
"MULTICLASS"},
+ "metadata_schema_uri": schema.dataset.metadata.text,
+ "metadata": Value(string_value="clss-dataset"),
}
-AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_classification.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
-
+DATA_CONFIG = [
+ {
+ "import_schema_uri":
schema.dataset.ioformat.text.single_label_classification,
+ "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+ },
+]
extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Text Classification
@@ -85,67 +92,77 @@ with models.DAG(
move_dataset_file = GCSSynchronizeBucketsOperator(
task_id="move_dataset_to_bucket",
source_bucket=RESOURCE_DATA_BUCKET,
- source_object="automl/datasets/text",
+ source_object="vertex-ai/automl/datasets/text",
destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
destination_object="automl",
recursive=True,
)
- create_dataset = AutoMLCreateDatasetOperator(
- task_id="create_dataset",
+ create_clss_dataset = CreateDatasetOperator(
+ task_id="create_clss_dataset",
dataset=DATASET,
- location=GCP_AUTOML_LOCATION,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
+ clss_dataset_id = create_clss_dataset.output["dataset_id"]
- dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
- MODEL["dataset_id"] = dataset_id
- import_dataset = AutoMLImportDataOperator(
- task_id="import_dataset",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- input_config=IMPORT_INPUT_CONFIG,
+ import_clss_dataset = ImportDataOperator(
+ task_id="import_clss_data",
+ dataset_id=clss_dataset_id,
+ region=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ import_configs=DATA_CONFIG,
)
- MODEL["dataset_id"] = dataset_id
-
- create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
- model_id = cast(str, XComArg(create_model, key="model_id"))
- deploy_model = AutoMLDeployModelOperator(
- task_id="deploy_model",
- model_id=model_id,
- location=GCP_AUTOML_LOCATION,
+ # [START howto_operator_automl_create_model]
+ create_clss_training_job = CreateAutoMLTextTrainingJobOperator(
+ task_id="create_clss_training_job",
+ display_name=TEXT_CLSS_DISPLAY_NAME,
+ prediction_type="classification",
+ multi_label=False,
+ dataset_id=clss_dataset_id,
+ model_display_name=MODEL_NAME,
+ training_fraction_split=0.7,
+ validation_fraction_split=0.2,
+ test_fraction_split=0.1,
+ sync=True,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
+ # [END howto_operator_automl_create_model]
+ model_id = cast(str, XComArg(create_clss_training_job, key="model_id"))
- delete_model = AutoMLDeleteModelOperator(
- task_id="delete_model",
- model_id=model_id,
- location=GCP_AUTOML_LOCATION,
+ delete_clss_training_job = DeleteAutoMLTrainingJobOperator(
+ task_id="delete_clss_training_job",
+ training_pipeline_id=create_clss_training_job.output["training_id"],
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
- delete_dataset = AutoMLDeleteDatasetOperator(
- task_id="delete_dataset",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
+ delete_clss_dataset = DeleteDatasetOperator(
+ task_id="delete_clss_dataset",
+ dataset_id=clss_dataset_id,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
delete_bucket = GCSDeleteBucketOperator(
- task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE
+ task_id="delete_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
)
(
# TEST SETUP
- [create_bucket >> move_dataset_file, create_dataset]
+ [create_bucket >> move_dataset_file, create_clss_dataset]
# TEST BODY
- >> import_dataset
- >> create_model
- >> deploy_model
+ >> import_clss_dataset
+ >> create_clss_training_job
# TEST TEARDOWN
- >> delete_model
- >> delete_dataset
+ >> delete_clss_training_job
+ >> delete_clss_dataset
>> delete_bucket
)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
index 260a7d84c6..06b22779a8 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
@@ -24,42 +24,52 @@ import os
from datetime import datetime
from typing import cast
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLCreateDatasetOperator,
- AutoMLDeleteDatasetOperator,
- AutoMLDeleteModelOperator,
- AutoMLImportDataOperator,
- AutoMLTrainModelOperator,
-)
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
GCSSynchronizeBucketsOperator,
)
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+ CreateAutoMLTextTrainingJobOperator,
+ DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+ CreateDatasetOperator,
+ DeleteDatasetOperator,
+ ImportDataOperator,
+)
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_extr"
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_extr"
GCP_AUTOML_LOCATION = "us-central1"
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
-
-DATASET_NAME = f"ds_extr_{ENV_ID}".replace("-", "_")
-DATASET = {"display_name": DATASET_NAME, "text_extraction_dataset_metadata":
{}}
-AUTOML_DATASET_BUCKET =
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_extraction.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
-
-MODEL_NAME = "entity_extr_test_model"
-MODEL = {
- "display_name": MODEL_NAME,
- "text_extraction_model_metadata": {},
+TEXT_EXTR_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/extraction.jsonl"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
+DATASET = {
+ "display_name": DATASET_NAME,
+ "metadata_schema_uri": schema.dataset.metadata.text,
+ "metadata": Value(string_value="extr-dataset"),
}
+DATA_CONFIG = [
+ {
+ "import_schema_uri": schema.dataset.ioformat.text.extraction,
+ "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+ },
+]
extract_object_id = CloudAutoMLHook.extract_object_id
@@ -80,51 +90,60 @@ with models.DAG(
)
move_dataset_file = GCSSynchronizeBucketsOperator(
- task_id="move_data_to_bucket",
+ task_id="move_dataset_to_bucket",
source_bucket=RESOURCE_DATA_BUCKET,
- source_object="automl/datasets/text",
+ source_object="vertex-ai/automl/datasets/text",
destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
destination_object="automl",
recursive=True,
)
- create_dataset = AutoMLCreateDatasetOperator(
- task_id="create_dataset",
+ create_extr_dataset = CreateDatasetOperator(
+ task_id="create_extr_dataset",
dataset=DATASET,
- location=GCP_AUTOML_LOCATION,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
+ extr_dataset_id = create_extr_dataset.output["dataset_id"]
- dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
- MODEL["dataset_id"] = dataset_id
- import_dataset = AutoMLImportDataOperator(
- task_id="import_dataset",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- input_config=IMPORT_INPUT_CONFIG,
+ import_extr_dataset = ImportDataOperator(
+ task_id="import_extr_data",
+ dataset_id=extr_dataset_id,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ import_configs=DATA_CONFIG,
)
- MODEL["dataset_id"] = dataset_id
- create_model = AutoMLTrainModelOperator(
- task_id="create_model",
- model=MODEL,
- location=GCP_AUTOML_LOCATION,
+ # [START howto_operator_automl_create_model]
+ create_extr_training_job = CreateAutoMLTextTrainingJobOperator(
+ task_id="create_extr_training_job",
+ display_name=TEXT_EXTR_DISPLAY_NAME,
+ prediction_type="extraction",
+ multi_label=False,
+ dataset_id=extr_dataset_id,
+ model_display_name=MODEL_NAME,
+ training_fraction_split=0.8,
+ validation_fraction_split=0.1,
+ test_fraction_split=0.1,
+ sync=True,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
- model_id = cast(str, XComArg(create_model, key="model_id"))
+ # [END howto_operator_automl_create_model]
+ model_id = cast(str, XComArg(create_extr_training_job, key="model_id"))
- delete_model_task = AutoMLDeleteModelOperator(
- task_id="delete_model_task",
- model_id=model_id,
- location=GCP_AUTOML_LOCATION,
+ delete_extr_training_job = DeleteAutoMLTrainingJobOperator(
+ task_id="delete_extr_training_job",
+ training_pipeline_id=create_extr_training_job.output["training_id"],
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
- delete_datasets_task = AutoMLDeleteDatasetOperator(
- task_id="delete_datasets_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
+ delete_extr_dataset = DeleteDatasetOperator(
+ task_id="delete_extr_dataset",
+ dataset_id=extr_dataset_id,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
trigger_rule=TriggerRule.ALL_DONE,
)
@@ -137,13 +156,13 @@ with models.DAG(
(
# TEST SETUP
- [create_bucket >> move_dataset_file, create_dataset]
+ [create_bucket >> move_dataset_file, create_extr_dataset]
# TEST BODY
- >> import_dataset
- >> create_model
+ >> import_extr_dataset
+ >> create_extr_training_job
# TEST TEARDOWN
- >> delete_model_task
- >> delete_datasets_task
+ >> delete_extr_training_job
+ >> delete_extr_dataset
>> delete_bucket
)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
index 3559339755..1529f07bc6 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
@@ -24,44 +24,53 @@ import os
from datetime import datetime
from typing import cast
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLCreateDatasetOperator,
- AutoMLDeleteDatasetOperator,
- AutoMLDeleteModelOperator,
- AutoMLImportDataOperator,
- AutoMLTrainModelOperator,
-)
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
GCSSynchronizeBucketsOperator,
)
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+ CreateAutoMLTextTrainingJobOperator,
+ DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+ CreateDatasetOperator,
+ DeleteDatasetOperator,
+ ImportDataOperator,
+)
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_sent"
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_sent"
GCP_AUTOML_LOCATION = "us-central1"
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
-MODEL_NAME = "text_sent_test_model"
-MODEL = {
- "display_name": MODEL_NAME,
- "text_sentiment_model_metadata": {},
-}
+TEXT_SENT_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/sentiment.csv"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
DATASET_NAME = f"ds_sent_{ENV_ID}".replace("-", "_")
DATASET = {
"display_name": DATASET_NAME,
- "text_sentiment_dataset_metadata": {"sentiment_max": 5},
+ "metadata_schema_uri": schema.dataset.metadata.text,
+ "metadata": Value(string_value="sent-dataset"),
}
-AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_sentiment.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+DATA_CONFIG = [
+ {
+ "import_schema_uri": schema.dataset.ioformat.text.sentiment,
+ "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+ },
+]
extract_object_id = CloudAutoMLHook.extract_object_id
@@ -84,43 +93,61 @@ with models.DAG(
move_dataset_file = GCSSynchronizeBucketsOperator(
task_id="move_dataset_to_bucket",
source_bucket=RESOURCE_DATA_BUCKET,
- source_object="automl/datasets/text",
+ source_object="vertex-ai/automl/datasets/text",
destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
destination_object="automl",
recursive=True,
)
- create_dataset = AutoMLCreateDatasetOperator(
- task_id="create_dataset", dataset=DATASET, location=GCP_AUTOML_LOCATION
+ create_sent_dataset = CreateDatasetOperator(
+ task_id="create_sent_dataset",
+ dataset=DATASET,
+ region=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
)
+ sent_dataset_id = create_sent_dataset.output["dataset_id"]
- dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
- MODEL["dataset_id"] = dataset_id
-
- import_dataset = AutoMLImportDataOperator(
- task_id="import_dataset",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- input_config=IMPORT_INPUT_CONFIG,
+ import_sent_dataset = ImportDataOperator(
+ task_id="import_sent_data",
+ dataset_id=sent_dataset_id,
+ region=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ import_configs=DATA_CONFIG,
)
- MODEL["dataset_id"] = dataset_id
-
- create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
- model_id = cast(str, XComArg(create_model, key="model_id"))
+ # [START howto_operator_automl_create_model]
+ create_sent_training_job = CreateAutoMLTextTrainingJobOperator(
+ task_id="create_sent_training_job",
+ display_name=TEXT_SENT_DISPLAY_NAME,
+ prediction_type="sentiment",
+ multi_label=False,
+ dataset_id=sent_dataset_id,
+ model_display_name=MODEL_NAME,
+ training_fraction_split=0.7,
+ validation_fraction_split=0.2,
+ test_fraction_split=0.1,
+ sentiment_max=5,
+ sync=True,
+ region=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_create_model]
+ model_id = cast(str, XComArg(create_sent_training_job, key="model_id"))
- delete_model = AutoMLDeleteModelOperator(
- task_id="delete_model",
- model_id=model_id,
- location=GCP_AUTOML_LOCATION,
+ delete_sent_training_job = DeleteAutoMLTrainingJobOperator(
+ task_id="delete_sent_training_job",
+ training_pipeline_id=create_sent_training_job.output["training_id"],
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
- delete_dataset = AutoMLDeleteDatasetOperator(
- task_id="delete_dataset",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
+ delete_sent_dataset = DeleteDatasetOperator(
+ task_id="delete_sent_dataset",
+ dataset_id=sent_dataset_id,
+ region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
delete_bucket = GCSDeleteBucketOperator(
@@ -131,13 +158,13 @@ with models.DAG(
(
# TEST SETUP
- [create_bucket >> move_dataset_file, create_dataset]
+ [create_bucket >> move_dataset_file, create_sent_dataset]
# TEST BODY
- >> import_dataset
- >> create_model
+ >> import_sent_dataset
+ >> create_sent_training_job
# TEST TEARDOWN
- >> delete_model
- >> delete_dataset
+ >> delete_sent_training_job
+ >> delete_sent_dataset
>> delete_bucket
)