This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 800ed1457d7 Add Google Cloud VertexAI and Translate datasets import 
data verification (#51364)
800ed1457d7 is described below

commit 800ed1457d7812ee717d4b15a7c2fad30e839f40
Author: olegkachur-e <[email protected]>
AuthorDate: Mon Jun 30 23:13:52 2025 +0200

    Add Google Cloud VertexAI and Translate datasets import data verification 
(#51364)
    
    For the:
    - Google Cloud VertexAI datasets.
    - Google Cloud Translation native model datasets.
    
    Co-authored-by: Oleg Kachur <[email protected]>
---
 .../providers/google/cloud/hooks/translate.py      |  2 +-
 .../providers/google/cloud/links/translate.py      |  2 +-
 .../providers/google/cloud/operators/translate.py  | 32 ++++++++++++++-
 .../google/cloud/operators/vertex_ai/dataset.py    | 46 +++++++++++++++++++++-
 .../unit/google/cloud/operators/test_translate.py  | 17 +++++++-
 .../unit/google/cloud/operators/test_vertex_ai.py  | 20 ++++++++--
 6 files changed, 110 insertions(+), 9 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/hooks/translate.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
index 57b7846f023..c5f39819ba0 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
@@ -429,7 +429,7 @@ class TranslateHook(GoogleBaseHook, OperationHelper):
         project_id: str,
         location: str,
         retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | _MethodDefault = DEFAULT,
+        timeout: float | None | _MethodDefault = DEFAULT,
         metadata: Sequence[tuple[str, str]] = (),
     ) -> automl_translation.Dataset:
         """
diff --git 
a/providers/google/src/airflow/providers/google/cloud/links/translate.py 
b/providers/google/src/airflow/providers/google/cloud/links/translate.py
index bbd8febe37a..29ed09df4e5 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/translate.py
@@ -149,7 +149,7 @@ class TranslationNativeDatasetLink(BaseGoogleLink):
     """
 
     name = "Translation Native Dataset"
-    key = "translation_naive_dataset"
+    key = "translation_native_dataset"
     format_str = TRANSLATION_NATIVE_DATASET_LINK
 
 
diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/translate.py 
b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
index ba6b0c338ab..1a112f71636 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
@@ -37,6 +37,7 @@ from airflow.providers.google.cloud.links.translate import (
     TranslationNativeDatasetLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import 
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import 
DatasetImportDataResultsCheckHelper
 from airflow.providers.google.common.hooks.base_google import 
PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
@@ -575,7 +576,7 @@ class 
TranslateDatasetsListOperator(GoogleCloudBaseOperator):
         return result_ids
 
 
-class TranslateImportDataOperator(GoogleCloudBaseOperator):
+class TranslateImportDataOperator(GoogleCloudBaseOperator, 
DatasetImportDataResultsCheckHelper):
     """
     Import data to the translation dataset.
 
@@ -602,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param raise_for_empty_result: Raise an error if no additional data has 
been populated after the import.
     """
 
     template_fields: Sequence[str] = (
@@ -627,6 +629,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         retry: Retry | _MethodDefault = DEFAULT,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -639,9 +642,21 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         self.retry = retry
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result
 
     def execute(self, context: Context):
         hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
         self.log.info("Importing data to dataset...")
         operation = hook.import_dataset_data(
             dataset_id=self.dataset_id,
@@ -660,7 +675,22 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
             location=self.location,
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, 
initial_dataset_size, result_dataset_size)
         self.log.info("Importing data finished!")
+        return {"total_imported": int(result_dataset_size) - 
int(initial_dataset_size)}
 
 
 class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
 
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index a2c0c79eb0f..fc7ea9ba854 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -26,6 +26,7 @@ from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, 
ImportDataConfig
 
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
 from airflow.providers.google.cloud.links.vertex_ai import 
VertexAIDatasetLink, VertexAIDatasetListLink
 from airflow.providers.google.cloud.operators.cloud_base import 
GoogleCloudBaseOperator
@@ -335,7 +336,21 @@ class ExportDataOperator(GoogleCloudBaseOperator):
         self.log.info("Export was done successfully")
 
 
-class ImportDataOperator(GoogleCloudBaseOperator):
+class DatasetImportDataResultsCheckHelper:
+    """Helper utils to verify import dataset data results."""
+
+    @staticmethod
+    def _get_number_of_ds_items(dataset, total_key_name):
+        number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
+        return number_of_items
+
+    @staticmethod
+    def _raise_for_empty_import_result(dataset_id, initial_size, 
size_after_import):
+        if int(size_after_import) - int(initial_size) <= 0:
+            raise AirflowException(f"Empty results of data import for the 
dataset_id {dataset_id}.")
+
+
+class ImportDataOperator(GoogleCloudBaseOperator, 
DatasetImportDataResultsCheckHelper):
     """
     Imports data into a Dataset.
 
@@ -356,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param raise_for_empty_result: Raise an error if no additional data has 
been populated after the import.
     """
 
     template_fields = ("region", "dataset_id", "project_id", 
"impersonation_chain")
@@ -372,6 +388,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -384,13 +401,24 @@ class ImportDataOperator(GoogleCloudBaseOperator):
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result
 
     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
         self.log.info("Importing data: %s", self.dataset_id)
         operation = hook.import_data(
             project_id=self.project_id,
@@ -402,7 +430,21 @@ class ImportDataOperator(GoogleCloudBaseOperator):
             metadata=self.metadata,
         )
         hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, 
initial_dataset_size, result_dataset_size)
         self.log.info("Import was done successfully")
+        return {"total_data_items_imported": int(result_dataset_size) - 
int(initial_dataset_size)}
 
 
 class ListDatasetsOperator(GoogleCloudBaseOperator):
diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_translate.py 
b/providers/google/tests/unit/google/cloud/operators/test_translate.py
index 4e20d8ee2ac..5bb259d01cc 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_translate.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_translate.py
@@ -22,6 +22,7 @@ from unittest import mock
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud.translate_v3.types import (
     BatchTranslateDocumentResponse,
+    Dataset,
     TranslateDocumentResponse,
     automl_translation,
     translation_service,
@@ -331,6 +332,19 @@ class TestTranslateImportData:
             "input_files": [{"usage": "UNASSIGNED", "gcs_source": 
{"input_uri": "import data gcs path"}}]
         }
         mock_hook.return_value.import_dataset_data.return_value = 
mock.MagicMock()
+
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "example_count": None,
+            "source_language_code": "en",
+            "target_language_code": "es",
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "example_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "example_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), 
Dataset(FINAL_DS)]
         op = TranslateImportDataOperator(
             task_id="task_id",
             dataset_id=DATASET_ID,
@@ -343,7 +357,7 @@ class TestTranslateImportData:
             retry=DEFAULT,
         )
         context = mock.MagicMock()
-        op.execute(context=context)
+        res = op.execute(context=context)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -363,6 +377,7 @@ class TestTranslateImportData:
             location=LOCATION,
             project_id=PROJECT_ID,
         )
+        assert res["total_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE
 
 
 class TestTranslateDeleteData:
diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py 
b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
index 460e8b24847..1e9ecf2b569 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
@@ -26,6 +26,7 @@ pytest.importorskip("google.cloud.aiplatform_v1")
 
 from google.api_core.gapic_v1.method import DEFAULT
 from google.api_core.retry import Retry
+from google.cloud.aiplatform_v1.types.dataset import Dataset
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
@@ -1362,9 +1363,8 @@ class TestVertexAIExportDataOperator:
 
 
 class TestVertexAIImportDataOperator:
-    @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
     @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    def test_execute(self, mock_hook):
         op = ImportDataOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -1377,7 +1377,20 @@ class TestVertexAIImportDataOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={})
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "display_name": "VertexAI dataset",
+            "data_item_count": None,
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "data_item_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "data_item_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), 
Dataset(FINAL_DS)]
+
+        res = op.execute(context={})
+
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.import_data.assert_called_once_with(
             region=GCP_LOCATION,
@@ -1388,6 +1401,7 @@ class TestVertexAIImportDataOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        assert res["total_data_items_imported"] == FINAL_DS_SIZE - 
INITIAL_DS_SIZE
 
 
 class TestVertexAIListDatasetsOperator:

Reply via email to