This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3fa9d46ec7 Refactor: Simplify code in providers/google (#33229)
3fa9d46ec7 is described below

commit 3fa9d46ec74ef8453fcf17fbd49280cb6fb37cef
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Tue Sep 12 21:24:17 2023 +0000

    Refactor: Simplify code in providers/google (#33229)
    
    
    
    Co-authored-by: Elad Kalif <[email protected]>
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/providers/google/cloud/hooks/bigquery.py   |  7 +---
 airflow/providers/google/cloud/hooks/gcs.py        | 48 ++++++++++------------
 .../providers/google/cloud/operators/bigquery.py   |  2 +-
 .../google/cloud/operators/cloud_build.py          |  2 +-
 .../google/cloud/transfers/facebook_ads_to_gcs.py  |  4 +-
 .../providers/google/cloud/transfers/gcs_to_gcs.py |  4 +-
 .../google/cloud/utils/bigquery_get_data.py        | 10 ++---
 .../google/cloud/utils/field_validator.py          | 18 ++++----
 8 files changed, 40 insertions(+), 55 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index c98e7e92ab..be486c6c8d 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2874,12 +2874,7 @@ class BigQueryCursor(BigQueryBaseCursor):
 
         A sequence of sequences (e.g. a list of tuples) is returned.
         """
-        result = []
-        while True:
-            one = self.fetchone()
-            if one is None:
-                break
-            result.append(one)
+        result = list(iter(self.fetchone, None))
         return result
 
     def get_arraysize(self) -> int:
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index acc4ad688a..d8bb36037f 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -330,11 +330,16 @@ class GCSHook(GoogleBaseHook):
         # TODO: future improvement check file size before downloading,
         #  to check for local space availability
 
-        num_file_attempts = 0
+        if num_max_attempts is None:
+            num_max_attempts = 3
+
+        for attempt in range(num_max_attempts):
+            if attempt:
+                # Wait with exponential backoff scheme before retrying.
+                timeout_seconds = 2**attempt
+                time.sleep(timeout_seconds)
 
-        while True:
             try:
-                num_file_attempts += 1
                 client = self.get_conn()
                 bucket = client.bucket(bucket_name, user_project=user_project)
                 blob = bucket.blob(blob_name=object_name, 
chunk_size=chunk_size)
@@ -347,19 +352,17 @@ class GCSHook(GoogleBaseHook):
                     return blob.download_as_bytes()
 
             except GoogleCloudError:
-                if num_file_attempts == num_max_attempts:
+                if attempt == num_max_attempts - 1:
                     self.log.error(
                         "Download attempt of object: %s from %s has failed. 
Attempt: %s, max %s.",
                         object_name,
                         bucket_name,
-                        num_file_attempts,
+                        attempt,
                         num_max_attempts,
                     )
                     raise
-
-                # Wait with exponential backoff scheme before retrying.
-                timeout_seconds = 2 ** (num_file_attempts - 1)
-                time.sleep(timeout_seconds)
+        else:
+            raise NotImplementedError  # should not reach this, but makes mypy 
happy
 
     def download_as_byte_array(
         self,
@@ -826,15 +829,10 @@ class GCSHook(GoogleBaseHook):
                     versions=versions,
                 )
 
-            blob_names = []
-            for blob in blobs:
-                blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(blob.name for blob in blobs)
 
             page_token = blobs.next_page_token
             if page_token is None:
@@ -942,16 +940,14 @@ class GCSHook(GoogleBaseHook):
                     versions=versions,
                 )
 
-            blob_names = []
-            for blob in blobs:
-                if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) 
< timespan_end:
-                    blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(
+                    blob.name
+                    for blob in blobs
+                    if timespan_start <= 
blob.updated.replace(tzinfo=timezone.utc) < timespan_end
+                )
 
             page_token = blobs.next_page_token
             if page_token is None:
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 4a17c41ba2..1ed333d9dc 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -329,7 +329,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
         records = event["records"]
         if not records:
             raise AirflowException("The query returned empty results")
-        elif not all(bool(r) for r in records):
+        elif not all(records):
             self._raise_exception(  # type: ignore[attr-defined]
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
             )
diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py
index 9daacefa72..690a64f349 100644
--- a/airflow/providers/google/cloud/operators/cloud_build.py
+++ b/airflow/providers/google/cloud/operators/cloud_build.py
@@ -202,7 +202,7 @@ class CloudBuildCreateBuildOperator(GoogleCloudBaseOperator):
         if not isinstance(self.build_raw, str):
             return
         with open(self.build_raw) as file:
-            if any(self.build_raw.endswith(ext) for ext in [".yaml", ".yml"]):
+            if self.build_raw.endswith((".yaml", ".yml")):
                 self.build = yaml.safe_load(file.read())
             if self.build_raw.endswith(".json"):
                 self.build = json.loads(file.read())
diff --git a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
index bc0dae153a..758bd818ca 100644
--- a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
@@ -226,7 +226,5 @@ class FacebookAdsReportToGcsOperator(BaseOperator):
 
     def _transform_object_name_with_account_id(self, account_id: str):
         directory_parts = self.object_name.split("/")
-        directory_parts[len(directory_parts) - 1] = (
-            account_id + "_" + directory_parts[len(directory_parts) - 1]
-        )
+        directory_parts[-1] = f"{account_id}_{directory_parts[-1]}"
         return "/".join(directory_parts)
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index b262a75e6a..ad7a41c23d 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -207,7 +207,7 @@ class GCSToGCSOperator(BaseOperator):
                 stacklevel=2,
             )
         self.source_object = source_object
-        if source_objects and any([WILDCARD in obj for obj in source_objects]):
+        if source_objects and any(WILDCARD in obj for obj in source_objects):
             warnings.warn(
                 "Usage of wildcard (*) in 'source_objects' is deprecated, 
utilize 'match_glob' instead",
                 AirflowProviderDeprecationWarning,
@@ -429,7 +429,7 @@ class GCSToGCSOperator(BaseOperator):
         # Check whether the prefix is a root directory for all the rest of 
objects.
         _pref = prefix.rstrip("/")
         is_directory = prefix.endswith("/") or all(
-            [obj.replace(_pref, "", 1).startswith("/") for obj in 
source_objects]
+            obj.replace(_pref, "", 1).startswith("/") for obj in source_objects
         )
 
         if is_directory:
diff --git a/airflow/providers/google/cloud/utils/bigquery_get_data.py b/airflow/providers/google/cloud/utils/bigquery_get_data.py
index 8fb61fc52c..d178aee963 100644
--- a/airflow/providers/google/cloud/utils/bigquery_get_data.py
+++ b/airflow/providers/google/cloud/utils/bigquery_get_data.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 from typing import TYPE_CHECKING
 
 from google.cloud.bigquery.table import Row, RowIterator
@@ -38,14 +39,13 @@ def bigquery_get_data(
     logger.info("Fetching Data from:")
     logger.info("Dataset: %s ; Table: %s", dataset_id, table_id)
 
-    i = 0
-    while True:
+    for start_index in itertools.count(step=batch_size):
         rows: list[Row] | RowIterator = big_query_hook.list_rows(
             dataset_id=dataset_id,
             table_id=table_id,
             max_results=batch_size,
             selected_fields=selected_fields,
-            start_index=i * batch_size,
+            start_index=start_index,
         )
 
         if isinstance(rows, RowIterator):
@@ -55,8 +55,6 @@ def bigquery_get_data(
             logger.info("Job Finished")
             return
 
-        logger.info("Total Extracted rows: %s", len(rows) + i * batch_size)
+        logger.info("Total Extracted rows: %s", len(rows) + start_index)
 
         yield [row.values() for row in rows]
-
-        i += 1
diff --git a/airflow/providers/google/cloud/utils/field_validator.py b/airflow/providers/google/cloud/utils/field_validator.py
index 87aee5d7af..415351c69c 100644
--- a/airflow/providers/google/cloud/utils/field_validator.py
+++ b/airflow/providers/google/cloud/utils/field_validator.py
@@ -257,7 +257,7 @@ class GcpBodyFieldValidator(LoggingMixin):
             self._validate_field(
                 validation_spec=child_validation_spec, 
dictionary_to_validate=value, parent=full_field_path
             )
-        all_dict_keys = [spec["name"] for spec in children_validation_specs]
+        all_dict_keys = {spec["name"] for spec in children_validation_specs}
         for field_name in value.keys():
             if field_name not in all_dict_keys:
                 self.log.warning(
@@ -428,20 +428,18 @@ class GcpBodyFieldValidator(LoggingMixin):
             raise GcpFieldValidationException(
                 f"There was an error when validating: body 
'{body_to_validate}': '{e}'"
             )
-        all_field_names = [
+        all_field_names = {
             spec["name"]
             for spec in self._validation_specs
             if spec.get("type") != "union" and spec.get("api_version") != 
self._api_version
-        ]
+        }
         all_union_fields = [spec for spec in self._validation_specs if 
spec.get("type") == "union"]
         for union_field in all_union_fields:
-            all_field_names.extend(
-                [
-                    nested_union_spec["name"]
-                    for nested_union_spec in union_field["fields"]
-                    if nested_union_spec.get("type") != "union"
-                    and nested_union_spec.get("api_version") != 
self._api_version
-                ]
+            all_field_names.update(
+                nested_union_spec["name"]
+                for nested_union_spec in union_field["fields"]
+                if nested_union_spec.get("type") != "union"
+                and nested_union_spec.get("api_version") != self._api_version
             )
         for field_name in body_to_validate.keys():
             if field_name not in all_field_names:

Reply via email to