This is an automated email from the ASF dual-hosted git repository. joshfell pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new d4fe325f84 Adding params. to create_auto_ml_forecasting_training_job in AutoMl hook (#39767) d4fe325f84 is described below commit d4fe325f8489aa19858b68ea42b71d99e80410a4 Author: Siddesh M G <108360375+vinay22...@users.noreply.github.com> AuthorDate: Sun May 26 19:08:18 2024 +0530 Adding params. to create_auto_ml_forecasting_training_job in AutoMl hook (#39767) * Update auto_ml.py Added window_stride_length & window_max_count * Update auto_ml.py * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> * Update test_vertex_ai.py * Update test_vertex_ai.py * Update auto_ml.py * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py Co-authored-by: Josh Fell <48934154+josh-f...@users.noreply.github.com> * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is> * Update auto_ml.py * Update test_vertex_ai.py --------- Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com> Co-authored-by: Josh Fell <48934154+josh-f...@users.noreply.github.com> Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is> --- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py | 8 ++++++++ airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py | 6 ++++++ tests/providers/google/cloud/operators/test_vertex_ai.py | 4 ++++ 3 files changed, 18 insertions(+) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py index 5c6d56529d..b1ad7d1a07 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py @@ -551,6 +551,8 @@ class AutoMLHook(GoogleBaseHook): is_default_version: bool | None = None, model_version_aliases: list[str] | None = None, model_version_description: str | None = None, + window_stride_length: int | None = None, + window_max_count: int | None = None, ) -> tuple[models.Model | None, str]: """ Create an AutoML Forecasting Training Job. @@ -703,6 +705,10 @@ class AutoMLHook(GoogleBaseHook): :param sync: Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + :param window_stride_length: Optional. Step length used to generate input examples. Every + ``window_stride_length`` rows will be used to generate a sliding window. + :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the + total row count is larger than this number, the input data will be randomly sampled to hit the count. """ if column_transformations: warnings.warn( @@ -758,6 +764,8 @@ class AutoMLHook(GoogleBaseHook): is_default_version=is_default_version, model_version_aliases=model_version_aliases, model_version_description=model_version_description, + window_stride_length=window_stride_length, + window_max_count=window_max_count, ) training_id = self.extract_training_id(self._job.resource_name) if model: diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 1475232012..7e3d8bb083 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -138,6 +138,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator): region: str, impersonation_chain: str | Sequence[str] | None = None, parent_model: str | None = None, + window_stride_length: int | None = None, + window_max_count: int | None = None, **kwargs, ) -> None: super().__init__( @@ -170,6 +172,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator): self.quantiles = quantiles self.validation_options = validation_options self.budget_milli_node_hours = budget_milli_node_hours + self.window_stride_length = window_stride_length + self.window_max_count = window_max_count def execute(self, context: Context): self.hook = AutoMLHook( @@ -220,6 +224,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator): model_display_name=self.model_display_name, model_labels=self.model_labels, sync=self.sync, + window_stride_length=self.window_stride_length, + window_max_count=self.window_max_count, ) if model: diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 3f8649f588..4b8264d615 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -1340,6 +1340,8 @@ class TestVertexAICreateAutoMLForecastingTrainingJobOperator: is_default_version=None, model_version_aliases=None, model_version_description=None, + window_stride_length=None, + window_max_count=None, ) @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset") @@ -1405,6 +1407,8 @@ class TestVertexAICreateAutoMLForecastingTrainingJobOperator: is_default_version=None, model_version_aliases=None, model_version_description=None, + window_stride_length=None, + window_max_count=None, )