This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d4fe325f84 Adding params to create_auto_ml_forecasting_training_job in AutoML hook (#39767)
d4fe325f84 is described below

commit d4fe325f8489aa19858b68ea42b71d99e80410a4
Author: Siddesh M G <108360375+vinay22...@users.noreply.github.com>
AuthorDate: Sun May 26 19:08:18 2024 +0530

    Adding params to create_auto_ml_forecasting_training_job in AutoML hook (#39767)
    
    * Update auto_ml.py
    
    Added window_stride_length & window_max_count
    
    * Update auto_ml.py
    
    * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    
    * Update test_vertex_ai.py
    
    * Update test_vertex_ai.py
    
    * Update auto_ml.py
    
    * Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
    
    Co-authored-by: Josh Fell <48934154+josh-f...@users.noreply.github.com>
    
    * Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
    
    Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is>
    
    * Update auto_ml.py
    
    * Update test_vertex_ai.py
    
    ---------
    
    Co-authored-by: Shahar Epstein <60007259+shah...@users.noreply.github.com>
    Co-authored-by: Josh Fell <48934154+josh-f...@users.noreply.github.com>
    Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is>
---
 airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py     | 8 ++++++++
 airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py | 6 ++++++
 tests/providers/google/cloud/operators/test_vertex_ai.py      | 4 ++++
 3 files changed, 18 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py 
b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
index 5c6d56529d..b1ad7d1a07 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -551,6 +551,8 @@ class AutoMLHook(GoogleBaseHook):
         is_default_version: bool | None = None,
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
     ) -> tuple[models.Model | None, str]:
         """
         Create an AutoML Forecasting Training Job.
@@ -703,6 +705,10 @@ class AutoMLHook(GoogleBaseHook):
         :param sync: Whether to execute this method synchronously. If False, 
this method will be executed in
             concurrent Future and any downstream object will be immediately 
returned and synced when the
             Future has completed.
+        :param window_stride_length: Optional. Step length used to generate 
input examples. Every
+            ``window_stride_length`` rows will be used to generate a sliding 
window.
+        :param window_max_count: Optional. Number of rows that should be used 
to generate input examples. If the
+            total row count is larger than this number, the input data will be 
randomly sampled to hit the count.
         """
         if column_transformations:
             warnings.warn(
@@ -758,6 +764,8 @@ class AutoMLHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            window_stride_length=window_stride_length,
+            window_max_count=window_max_count,
         )
         training_id = self.extract_training_id(self._job.resource_name)
         if model:
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py 
b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 1475232012..7e3d8bb083 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -138,6 +138,8 @@ class 
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         region: str,
         impersonation_chain: str | Sequence[str] | None = None,
         parent_model: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -170,6 +172,8 @@ class 
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         self.quantiles = quantiles
         self.validation_options = validation_options
         self.budget_milli_node_hours = budget_milli_node_hours
+        self.window_stride_length = window_stride_length
+        self.window_max_count = window_max_count
 
     def execute(self, context: Context):
         self.hook = AutoMLHook(
@@ -220,6 +224,8 @@ class 
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             model_display_name=self.model_display_name,
             model_labels=self.model_labels,
             sync=self.sync,
+            window_stride_length=self.window_stride_length,
+            window_max_count=self.window_max_count,
         )
 
         if model:
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py 
b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 3f8649f588..4b8264d615 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -1340,6 +1340,8 @@ class 
TestVertexAICreateAutoMLForecastingTrainingJobOperator:
             is_default_version=None,
             model_version_aliases=None,
             model_version_description=None,
+            window_stride_length=None,
+            window_max_count=None,
         )
 
     @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset")
@@ -1405,6 +1407,8 @@ class 
TestVertexAICreateAutoMLForecastingTrainingJobOperator:
             is_default_version=None,
             model_version_aliases=None,
             model_version_description=None,
+            window_stride_length=None,
+            window_max_count=None,
         )
 
 

Reply via email to