This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 800ed1457d7 Add Google Cloud VertexAI and Translate datasets import
data verification (#51364)
800ed1457d7 is described below
commit 800ed1457d7812ee717d4b15a7c2fad30e839f40
Author: olegkachur-e <[email protected]>
AuthorDate: Mon Jun 30 23:13:52 2025 +0200
Add Google Cloud VertexAI and Translate datasets import data verification
(#51364)
For the:
- Google Cloud VertexAI datasets.
- Google Cloud Translation native model datasets.
Co-authored-by: Oleg Kachur <[email protected]>
---
.../providers/google/cloud/hooks/translate.py | 2 +-
.../providers/google/cloud/links/translate.py | 2 +-
.../providers/google/cloud/operators/translate.py | 32 ++++++++++++++-
.../google/cloud/operators/vertex_ai/dataset.py | 46 +++++++++++++++++++++-
.../unit/google/cloud/operators/test_translate.py | 17 +++++++-
.../unit/google/cloud/operators/test_vertex_ai.py | 20 ++++++++--
6 files changed, 110 insertions(+), 9 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
b/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
index 57b7846f023..c5f39819ba0 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/translate.py
@@ -429,7 +429,7 @@ class TranslateHook(GoogleBaseHook, OperationHelper):
project_id: str,
location: str,
retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | _MethodDefault = DEFAULT,
+ timeout: float | None | _MethodDefault = DEFAULT,
metadata: Sequence[tuple[str, str]] = (),
) -> automl_translation.Dataset:
"""
diff --git
a/providers/google/src/airflow/providers/google/cloud/links/translate.py
b/providers/google/src/airflow/providers/google/cloud/links/translate.py
index bbd8febe37a..29ed09df4e5 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/translate.py
@@ -149,7 +149,7 @@ class TranslationNativeDatasetLink(BaseGoogleLink):
"""
name = "Translation Native Dataset"
- key = "translation_naive_dataset"
+ key = "translation_native_dataset"
format_str = TRANSLATION_NATIVE_DATASET_LINK
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
index ba6b0c338ab..1a112f71636 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
@@ -37,6 +37,7 @@ from airflow.providers.google.cloud.links.translate import (
TranslationNativeDatasetLink,
)
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import
DatasetImportDataResultsCheckHelper
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
if TYPE_CHECKING:
@@ -575,7 +576,7 @@ class
TranslateDatasetsListOperator(GoogleCloudBaseOperator):
return result_ids
-class TranslateImportDataOperator(GoogleCloudBaseOperator):
+class TranslateImportDataOperator(GoogleCloudBaseOperator,
DatasetImportDataResultsCheckHelper):
"""
Import data to the translation dataset.
@@ -602,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param raise_for_empty_result: Raise an error if no additional data has
been populated after the import.
"""
template_fields: Sequence[str] = (
@@ -627,6 +629,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
retry: Retry | _MethodDefault = DEFAULT,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ raise_for_empty_result: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -639,9 +642,21 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
self.retry = retry
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ self.raise_for_empty_result = raise_for_empty_result
def execute(self, context: Context):
hook = TranslateHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
+ initial_dataset_size = self._get_number_of_ds_items(
+ dataset=hook.get_dataset(
+ dataset_id=self.dataset_id,
+ project_id=self.project_id,
+ location=self.location,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ ),
+ total_key_name="example_count",
+ )
self.log.info("Importing data to dataset...")
operation = hook.import_dataset_data(
dataset_id=self.dataset_id,
@@ -660,7 +675,22 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
location=self.location,
)
hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+
+ result_dataset_size = self._get_number_of_ds_items(
+ dataset=hook.get_dataset(
+ dataset_id=self.dataset_id,
+ project_id=self.project_id,
+ location=self.location,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ ),
+ total_key_name="example_count",
+ )
+ if self.raise_for_empty_result:
+ self._raise_for_empty_import_result(self.dataset_id,
initial_dataset_size, result_dataset_size)
self.log.info("Importing data finished!")
+ return {"total_imported": int(result_dataset_size) -
int(initial_dataset_size)}
class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index a2c0c79eb0f..fc7ea9ba854 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -26,6 +26,7 @@ from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig,
ImportDataConfig
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
from airflow.providers.google.cloud.links.vertex_ai import
VertexAIDatasetLink, VertexAIDatasetListLink
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
@@ -335,7 +336,21 @@ class ExportDataOperator(GoogleCloudBaseOperator):
self.log.info("Export was done successfully")
-class ImportDataOperator(GoogleCloudBaseOperator):
+class DatasetImportDataResultsCheckHelper:
+ """Helper utils to verify import dataset data results."""
+
+ @staticmethod
+ def _get_number_of_ds_items(dataset, total_key_name):
+ number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
+ return number_of_items
+
+ @staticmethod
+ def _raise_for_empty_import_result(dataset_id, initial_size,
size_after_import):
+ if int(size_after_import) - int(initial_size) <= 0:
+ raise AirflowException(f"Empty results of data import for the
dataset_id {dataset_id}.")
+
+
+class ImportDataOperator(GoogleCloudBaseOperator,
DatasetImportDataResultsCheckHelper):
"""
Imports data into a Dataset.
@@ -356,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param raise_for_empty_result: Raise an error if no additional data has
been populated after the import.
"""
template_fields = ("region", "dataset_id", "project_id",
"impersonation_chain")
@@ -372,6 +388,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ raise_for_empty_result: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -384,13 +401,24 @@ class ImportDataOperator(GoogleCloudBaseOperator):
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ self.raise_for_empty_result = raise_for_empty_result
def execute(self, context: Context):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
-
+ initial_dataset_size = self._get_number_of_ds_items(
+ dataset=hook.get_dataset(
+ dataset_id=self.dataset_id,
+ project_id=self.project_id,
+ region=self.region,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ ),
+ total_key_name="data_item_count",
+ )
self.log.info("Importing data: %s", self.dataset_id)
operation = hook.import_data(
project_id=self.project_id,
@@ -402,7 +430,21 @@ class ImportDataOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
)
hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ result_dataset_size = self._get_number_of_ds_items(
+ dataset=hook.get_dataset(
+ dataset_id=self.dataset_id,
+ project_id=self.project_id,
+ region=self.region,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ ),
+ total_key_name="data_item_count",
+ )
+ if self.raise_for_empty_result:
+ self._raise_for_empty_import_result(self.dataset_id,
initial_dataset_size, result_dataset_size)
self.log.info("Import was done successfully")
+ return {"total_data_items_imported": int(result_dataset_size) -
int(initial_dataset_size)}
class ListDatasetsOperator(GoogleCloudBaseOperator):
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_translate.py
b/providers/google/tests/unit/google/cloud/operators/test_translate.py
index 4e20d8ee2ac..5bb259d01cc 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_translate.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_translate.py
@@ -22,6 +22,7 @@ from unittest import mock
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.translate_v3.types import (
BatchTranslateDocumentResponse,
+ Dataset,
TranslateDocumentResponse,
automl_translation,
translation_service,
@@ -331,6 +332,19 @@ class TestTranslateImportData:
"input_files": [{"usage": "UNASSIGNED", "gcs_source":
{"input_uri": "import data gcs path"}}]
}
mock_hook.return_value.import_dataset_data.return_value =
mock.MagicMock()
+
+ SAMPLE_DATASET = {
+ "name": "sample_translation_dataset",
+ "example_count": None,
+ "source_language_code": "en",
+ "target_language_code": "es",
+ }
+ INITIAL_DS_SIZE = 1
+ FINAL_DS_SIZE = 101
+ INITIAL_DS = {**SAMPLE_DATASET, "example_count": INITIAL_DS_SIZE}
+ FINAL_DS = {**SAMPLE_DATASET, "example_count": FINAL_DS_SIZE}
+
+ mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS),
Dataset(FINAL_DS)]
op = TranslateImportDataOperator(
task_id="task_id",
dataset_id=DATASET_ID,
@@ -343,7 +357,7 @@ class TestTranslateImportData:
retry=DEFAULT,
)
context = mock.MagicMock()
- op.execute(context=context)
+ res = op.execute(context=context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -363,6 +377,7 @@ class TestTranslateImportData:
location=LOCATION,
project_id=PROJECT_ID,
)
+ assert res["total_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE
class TestTranslateDeleteData:
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
index 460e8b24847..1e9ecf2b569 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
@@ -26,6 +26,7 @@ pytest.importorskip("google.cloud.aiplatform_v1")
from google.api_core.gapic_v1.method import DEFAULT
from google.api_core.retry import Retry
+from google.cloud.aiplatform_v1.types.dataset import Dataset
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
@@ -1362,9 +1363,8 @@ class TestVertexAIExportDataOperator:
class TestVertexAIImportDataOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
@mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
+ def test_execute(self, mock_hook):
op = ImportDataOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -1377,7 +1377,20 @@ class TestVertexAIImportDataOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- op.execute(context={})
+ SAMPLE_DATASET = {
+ "name": "sample_translation_dataset",
+ "display_name": "VertexAI dataset",
+ "data_item_count": None,
+ }
+ INITIAL_DS_SIZE = 1
+ FINAL_DS_SIZE = 101
+ INITIAL_DS = {**SAMPLE_DATASET, "data_item_count": INITIAL_DS_SIZE}
+ FINAL_DS = {**SAMPLE_DATASET, "data_item_count": FINAL_DS_SIZE}
+
+ mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS),
Dataset(FINAL_DS)]
+
+ res = op.execute(context={})
+
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.import_data.assert_called_once_with(
region=GCP_LOCATION,
@@ -1388,6 +1401,7 @@ class TestVertexAIImportDataOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ assert res["total_data_items_imported"] == FINAL_DS_SIZE -
INITIAL_DS_SIZE
class TestVertexAIListDatasetsOperator: