This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new e83a98603e iterate through blobs before checking prefixes (#36202) e83a98603e is described below commit e83a98603ef15c7d57910c482ba75eb76ed79553 Author: Wei Lee <weilee...@gmail.com> AuthorDate: Thu Dec 14 20:09:53 2023 +0530 iterate through blobs before checking prefixes (#36202) * fix(providers/google): iterate through blobs before checking prefixes According to https://github.com/googleapis/python-storage/blob/v2.14.0/google/cloud/storage/client.py#L1213-L1217, the prefixes are not returned until the blobs are consumed * test(providers/google): add test cases to check gcs.list result --- airflow/providers/google/cloud/hooks/gcs.py | 18 ++++++----- tests/providers/google/cloud/hooks/test_gcs.py | 42 +++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 45a202124d..02055583ce 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -821,12 +821,13 @@ class GCSHook(GoogleBaseHook): delimiter=delimiter, versions=versions, ) - list(blobs) + + blob_names = [blob.name for blob in blobs] if blobs.prefixes: ids.extend(blobs.prefixes) else: - ids.extend(blob.name for blob in blobs) + ids.extend(blob_names) page_token = blobs.next_page_token if page_token is None: @@ -933,16 +934,17 @@ class GCSHook(GoogleBaseHook): delimiter=delimiter, versions=versions, ) - list(blobs) + + blob_names = [ + blob.name + for blob in blobs + if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end + ] if blobs.prefixes: ids.extend(blobs.prefixes) else: - ids.extend( - blob.name - for blob in blobs - if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end - ) + ids.extend(blob_names) page_token = blobs.next_page_token if page_token is None: diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 33df98e37b..825a357d39 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -21,6 +21,7 @@ import copy import logging import os import re +from collections import namedtuple from datetime import datetime, timedelta from io import BytesIO from unittest import mock @@ -799,14 +800,26 @@ class TestGCSHook: ) @pytest.mark.parametrize( - "prefix, result", + "prefix, blob_names, returned_prefixes, call_args, result", ( ( "prefix", + ["prefix"], + None, + [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)], + ["prefix"], + ), + ( + "prefix", + ["prefix"], + {"prefix,"}, [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)], + ["prefix,"], ), ( ["prefix", "prefix_2"], + ["prefix", "prefix2"], + None, [ mock.call( delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None @@ -815,19 +828,38 @@ class TestGCSHook: delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None ), ], + ["prefix", "prefix2"], ), ), ) @mock.patch(GCS_STRING.format("GCSHook.get_conn")) - def test_list__delimiter(self, mock_service, prefix, result): - mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None + def test_list__delimiter(self, mock_service, prefix, blob_names, returned_prefixes, call_args, result): + Blob = namedtuple("Blob", ["name"]) + + class BlobsIterator: + def __init__(self): + self._item_iter = (Blob(name=name) for name in blob_names) + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self._item_iter) + except StopIteration: + self.prefixes = returned_prefixes + self.next_page_token = None + raise + + mock_service.return_value.bucket.return_value.list_blobs.return_value = BlobsIterator() with pytest.deprecated_call(): - self.gcs_hook.list( + blobs = self.gcs_hook.list( bucket_name="test_bucket", prefix=prefix, delimiter=",", ) - assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result + assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == call_args + assert blobs == result @mock.patch(GCS_STRING.format("GCSHook.get_conn")) @mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")