This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3fa9d46ec7 Refactor: Simplify code in providers/google (#33229)
3fa9d46ec7 is described below
commit 3fa9d46ec74ef8453fcf17fbd49280cb6fb37cef
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Tue Sep 12 21:24:17 2023 +0000
Refactor: Simplify code in providers/google (#33229)
Co-authored-by: Elad Kalif <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/providers/google/cloud/hooks/bigquery.py | 7 +---
airflow/providers/google/cloud/hooks/gcs.py | 48 ++++++++++------------
.../providers/google/cloud/operators/bigquery.py | 2 +-
.../google/cloud/operators/cloud_build.py | 2 +-
.../google/cloud/transfers/facebook_ads_to_gcs.py | 4 +-
.../providers/google/cloud/transfers/gcs_to_gcs.py | 4 +-
.../google/cloud/utils/bigquery_get_data.py | 10 ++---
.../google/cloud/utils/field_validator.py | 18 ++++----
8 files changed, 40 insertions(+), 55 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index c98e7e92ab..be486c6c8d 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2874,12 +2874,7 @@ class BigQueryCursor(BigQueryBaseCursor):
A sequence of sequences (e.g. a list of tuples) is returned.
"""
- result = []
- while True:
- one = self.fetchone()
- if one is None:
- break
- result.append(one)
+ result = list(iter(self.fetchone, None))
return result
def get_arraysize(self) -> int:
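Note: the list(iter(self.fetchone, None)) rewrite above uses the two-argument form of iter(callable, sentinel). A minimal standalone sketch of the pattern, where fetch_one is a hypothetical stand-in for the cursor's fetchone():
    # iter(callable, sentinel) calls the callable repeatedly until it returns
    # the sentinel; the sentinel itself is not included in the results.
    rows = [("a", 1), ("b", 2), ("c", 3)]
    row_iter = iter(rows)

    def fetch_one():
        # Hypothetical stand-in for BigQueryCursor.fetchone(): returns the
        # next row, or None when the result set is exhausted.
        return next(row_iter, None)

    result = list(iter(fetch_one, None))
    assert result == [("a", 1), ("b", 2), ("c", 3)]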
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index acc4ad688a..d8bb36037f 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -330,11 +330,16 @@ class GCSHook(GoogleBaseHook):
# TODO: future improvement check file size before downloading,
# to check for local space availability
- num_file_attempts = 0
+ if num_max_attempts is None:
+ num_max_attempts = 3
+
+ for attempt in range(num_max_attempts):
+ if attempt:
+ # Wait with exponential backoff scheme before retrying.
+ timeout_seconds = 2**attempt
+ time.sleep(timeout_seconds)
- while True:
try:
- num_file_attempts += 1
client = self.get_conn()
bucket = client.bucket(bucket_name, user_project=user_project)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
@@ -347,19 +352,17 @@ class GCSHook(GoogleBaseHook):
return blob.download_as_bytes()
except GoogleCloudError:
- if num_file_attempts == num_max_attempts:
+ if attempt == num_max_attempts - 1:
self.log.error(
"Download attempt of object: %s from %s has failed.
Attempt: %s, max %s.",
object_name,
bucket_name,
- num_file_attempts,
+ attempt,
num_max_attempts,
)
raise
-
- # Wait with exponential backoff scheme before retrying.
- timeout_seconds = 2 ** (num_file_attempts - 1)
- time.sleep(timeout_seconds)
+ else:
+ raise NotImplementedError # should not reach this, but makes mypy happy
def download_as_byte_array(
self,
@@ -826,15 +829,10 @@ class GCSHook(GoogleBaseHook):
versions=versions,
)
- blob_names = []
- for blob in blobs:
- blob_names.append(blob.name)
-
- prefixes = blobs.prefixes
- if prefixes:
- ids += list(prefixes)
+ if blobs.prefixes:
+ ids.extend(blobs.prefixes)
else:
- ids += blob_names
+ ids.extend(blob.name for blob in blobs)
page_token = blobs.next_page_token
if page_token is None:
@@ -942,16 +940,14 @@ class GCSHook(GoogleBaseHook):
versions=versions,
)
- blob_names = []
- for blob in blobs:
- if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end:
- blob_names.append(blob.name)
-
- prefixes = blobs.prefixes
- if prefixes:
- ids += list(prefixes)
+ if blobs.prefixes:
+ ids.extend(blobs.prefixes)
else:
- ids += blob_names
+ ids.extend(
+ blob.name
+ for blob in blobs
+ if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
+ )
page_token = blobs.next_page_token
if page_token is None:
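Note: the download retry in the first gcs.py hunk is a bounded retry loop with exponential backoff. A minimal sketch of the same structure under an assumed flaky_download() standing in for the GCS call (sleep times are scaled down here; the hook sleeps 2**attempt seconds):
    import time

    def flaky_download(attempt):
        # Hypothetical operation that fails on the first two attempts.
        if attempt < 2:
            raise OSError("transient failure")
        return b"payload"

    num_max_attempts = 3
    for attempt in range(num_max_attempts):
        if attempt:
            # Wait with exponential backoff before each retry.
            time.sleep(0.01 * 2**attempt)
        try:
            data = flaky_download(attempt)
            break
        except OSError:
            if attempt == num_max_attempts - 1:
                raise  # out of attempts: re-raise the last error
    else:
        raise NotImplementedError  # unreachable; mirrors the mypy hint in the hunk
    assert data == b"payload"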
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 4a17c41ba2..1ed333d9dc 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -329,7 +329,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
records = event["records"]
if not records:
raise AirflowException("The query returned empty results")
- elif not all(bool(r) for r in records):
+ elif not all(records):
self._raise_exception( # type: ignore[attr-defined]
f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
)
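Note: the change above relies on all() already testing each element's truthiness, so wrapping elements in bool() is redundant; a quick illustration:
    # Both forms fail at the first falsy value (0, "", None, empty tuple, ...).
    records = [(1,), (0,), (2,)]
    assert all(records) == all(bool(r) for r in records)
    assert not all([1, 0, 2])
    assert all([1, "x", [None]])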
diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py
index 9daacefa72..690a64f349 100644
--- a/airflow/providers/google/cloud/operators/cloud_build.py
+++ b/airflow/providers/google/cloud/operators/cloud_build.py
@@ -202,7 +202,7 @@ class CloudBuildCreateBuildOperator(GoogleCloudBaseOperator):
if not isinstance(self.build_raw, str):
return
with open(self.build_raw) as file:
- if any(self.build_raw.endswith(ext) for ext in [".yaml", ".yml"]):
+ if self.build_raw.endswith((".yaml", ".yml")):
self.build = yaml.safe_load(file.read())
if self.build_raw.endswith(".json"):
self.build = json.loads(file.read())
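Note: the simplification above works because str.endswith() accepts a tuple of suffixes. A small example with made-up file names:
    # endswith() with a tuple returns True if the string ends with any suffix.
    for name in ("cloudbuild.yaml", "cloudbuild.yml", "cloudbuild.json"):
        if name.endswith((".yaml", ".yml")):
            print(name, "-> parsed as YAML")
        elif name.endswith(".json"):
            print(name, "-> parsed as JSON")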
diff --git a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
index bc0dae153a..758bd818ca 100644
--- a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
@@ -226,7 +226,5 @@ class FacebookAdsReportToGcsOperator(BaseOperator):
def _transform_object_name_with_account_id(self, account_id: str):
directory_parts = self.object_name.split("/")
- directory_parts[len(directory_parts) - 1] = (
- account_id + "_" + directory_parts[len(directory_parts) - 1]
- )
+ directory_parts[-1] = f"{account_id}_{directory_parts[-1]}"
return "/".join(directory_parts)
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index b262a75e6a..ad7a41c23d 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -207,7 +207,7 @@ class GCSToGCSOperator(BaseOperator):
stacklevel=2,
)
self.source_object = source_object
- if source_objects and any([WILDCARD in obj for obj in source_objects]):
+ if source_objects and any(WILDCARD in obj for obj in source_objects):
warnings.warn(
"Usage of wildcard (*) in 'source_objects' is deprecated,
utilize 'match_glob' instead",
AirflowProviderDeprecationWarning,
@@ -429,7 +429,7 @@ class GCSToGCSOperator(BaseOperator):
# Check whether the prefix is a root directory for all the rest of objects.
_pref = prefix.rstrip("/")
is_directory = prefix.endswith("/") or all(
- [obj.replace(_pref, "", 1).startswith("/") for obj in source_objects]
+ obj.replace(_pref, "", 1).startswith("/") for obj in source_objects
)
if is_directory:
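Note: passing a generator expression instead of a list to any()/all(), as both gcs_to_gcs.py hunks do, avoids building an intermediate list and lets the check short-circuit. A small demonstration (check() is a made-up helper that records which elements were evaluated):
    def check(x, calls):
        calls.append(x)
        return x

    list_calls, gen_calls = [], []
    # A list comprehension evaluates every element before any() sees it...
    any([check(x, list_calls) for x in (1, 0, 2)])
    # ...a generator expression lets any() stop at the first truthy element.
    any(check(x, gen_calls) for x in (1, 0, 2))
    assert list_calls == [1, 0, 2]
    assert gen_calls == [1]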
diff --git a/airflow/providers/google/cloud/utils/bigquery_get_data.py b/airflow/providers/google/cloud/utils/bigquery_get_data.py
index 8fb61fc52c..d178aee963 100644
--- a/airflow/providers/google/cloud/utils/bigquery_get_data.py
+++ b/airflow/providers/google/cloud/utils/bigquery_get_data.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import itertools
from typing import TYPE_CHECKING
from google.cloud.bigquery.table import Row, RowIterator
@@ -38,14 +39,13 @@ def bigquery_get_data(
logger.info("Fetching Data from:")
logger.info("Dataset: %s ; Table: %s", dataset_id, table_id)
- i = 0
- while True:
+ for start_index in itertools.count(step=batch_size):
rows: list[Row] | RowIterator = big_query_hook.list_rows(
dataset_id=dataset_id,
table_id=table_id,
max_results=batch_size,
selected_fields=selected_fields,
- start_index=i * batch_size,
+ start_index=start_index,
)
if isinstance(rows, RowIterator):
@@ -55,8 +55,6 @@ def bigquery_get_data(
logger.info("Job Finished")
return
- logger.info("Total Extracted rows: %s", len(rows) + i * batch_size)
+ logger.info("Total Extracted rows: %s", len(rows) + start_index)
yield [row.values() for row in rows]
-
- i += 1
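Note: itertools.count(step=batch_size) yields 0, batch_size, 2*batch_size, ..., which replaces the manual counter in the paging loop above. A self-contained sketch over an in-memory stand-in for the table (list_rows below is hypothetical, not the hook's method):
    import itertools

    table = list(range(25))   # stand-in for the BigQuery table rows
    batch_size = 10

    def list_rows(start_index, max_results):
        # Hypothetical stand-in for BigQueryHook.list_rows().
        return table[start_index:start_index + max_results]

    for start_index in itertools.count(step=batch_size):
        rows = list_rows(start_index, batch_size)
        if not rows:
            break
        print("Total extracted rows:", len(rows) + start_index)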
diff --git a/airflow/providers/google/cloud/utils/field_validator.py b/airflow/providers/google/cloud/utils/field_validator.py
index 87aee5d7af..415351c69c 100644
--- a/airflow/providers/google/cloud/utils/field_validator.py
+++ b/airflow/providers/google/cloud/utils/field_validator.py
@@ -257,7 +257,7 @@ class GcpBodyFieldValidator(LoggingMixin):
self._validate_field(
validation_spec=child_validation_spec,
dictionary_to_validate=value, parent=full_field_path
)
- all_dict_keys = [spec["name"] for spec in children_validation_specs]
+ all_dict_keys = {spec["name"] for spec in children_validation_specs}
for field_name in value.keys():
if field_name not in all_dict_keys:
self.log.warning(
@@ -428,20 +428,18 @@ class GcpBodyFieldValidator(LoggingMixin):
raise GcpFieldValidationException(
f"There was an error when validating: body
'{body_to_validate}': '{e}'"
)
- all_field_names = [
+ all_field_names = {
spec["name"]
for spec in self._validation_specs
if spec.get("type") != "union" and spec.get("api_version") !=
self._api_version
- ]
+ }
all_union_fields = [spec for spec in self._validation_specs if spec.get("type") == "union"]
for union_field in all_union_fields:
- all_field_names.extend(
- [
- nested_union_spec["name"]
- for nested_union_spec in union_field["fields"]
- if nested_union_spec.get("type") != "union"
- and nested_union_spec.get("api_version") != self._api_version
- ]
+ all_field_names.update(
+ nested_union_spec["name"]
+ for nested_union_spec in union_field["fields"]
+ if nested_union_spec.get("type") != "union"
+ and nested_union_spec.get("api_version") != self._api_version
)
for field_name in body_to_validate.keys():
if field_name not in all_field_names:
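Note: the switch from a list to a set comprehension above matters because the collected names are only used for membership tests; a set gives average O(1) lookups and a natural update() for the union fields. A minimal sketch with made-up specs shaped like the ones the validator consumes:
    validation_specs = [
        {"name": "source"},
        {"name": "destination"},
        {"name": "mode", "type": "union",
         "fields": [{"name": "batch"}, {"name": "streaming"}]},
    ]
    all_field_names = {spec["name"] for spec in validation_specs if spec.get("type") != "union"}
    for union_field in (s for s in validation_specs if s.get("type") == "union"):
        all_field_names.update(f["name"] for f in union_field["fields"])
    assert "streaming" in all_field_names        # average O(1) membership test
    assert "unknown_field" not in all_field_names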