This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e83a98603e iterate through blobs before checking prefixes (#36202)
e83a98603e is described below

commit e83a98603ef15c7d57910c482ba75eb76ed79553
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Thu Dec 14 20:09:53 2023 +0530

    iterate through blobs before checking prefixes (#36202)
    
    * fix(providers/google): iterate through blobs before checking prefixes
    
    According to https://github.com/googleapis/python-storage/blob/v2.14.0/google/cloud/storage/client.py#L1213-L1217,
    the prefixes are not returned until the blobs are consumed
    
    * test(providers/google): add test cases to check gcs.list result
---
 airflow/providers/google/cloud/hooks/gcs.py    | 18 ++++++-----
 tests/providers/google/cloud/hooks/test_gcs.py | 42 +++++++++++++++++++++++---
 2 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 45a202124d..02055583ce 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -821,12 +821,13 @@ class GCSHook(GoogleBaseHook):
                     delimiter=delimiter,
                     versions=versions,
                 )
-                list(blobs)
+
+            blob_names = [blob.name for blob in blobs]
 
             if blobs.prefixes:
                 ids.extend(blobs.prefixes)
             else:
-                ids.extend(blob.name for blob in blobs)
+                ids.extend(blob_names)
 
             page_token = blobs.next_page_token
             if page_token is None:
@@ -933,16 +934,17 @@ class GCSHook(GoogleBaseHook):
                     delimiter=delimiter,
                     versions=versions,
                 )
-                list(blobs)
+
+            blob_names = [
+                blob.name
+                for blob in blobs
+                if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) 
< timespan_end
+            ]
 
             if blobs.prefixes:
                 ids.extend(blobs.prefixes)
             else:
-                ids.extend(
-                    blob.name
-                    for blob in blobs
-                    if timespan_start <= 
blob.updated.replace(tzinfo=timezone.utc) < timespan_end
-                )
+                ids.extend(blob_names)
 
             page_token = blobs.next_page_token
             if page_token is None:
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index 33df98e37b..825a357d39 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -21,6 +21,7 @@ import copy
 import logging
 import os
 import re
+from collections import namedtuple
 from datetime import datetime, timedelta
 from io import BytesIO
 from unittest import mock
@@ -799,14 +800,26 @@ class TestGCSHook:
         )
 
     @pytest.mark.parametrize(
-        "prefix, result",
+        "prefix, blob_names, returned_prefixes, call_args, result",
         (
             (
                 "prefix",
+                ["prefix"],
+                None,
+                [mock.call(delimiter=",", prefix="prefix", versions=None, 
max_results=None, page_token=None)],
+                ["prefix"],
+            ),
+            (
+                "prefix",
+                ["prefix"],
+                {"prefix,"},
                 [mock.call(delimiter=",", prefix="prefix", versions=None, 
max_results=None, page_token=None)],
+                ["prefix,"],
             ),
             (
                 ["prefix", "prefix_2"],
+                ["prefix", "prefix2"],
+                None,
                 [
                     mock.call(
                         delimiter=",", prefix="prefix", versions=None, 
max_results=None, page_token=None
@@ -815,19 +828,38 @@ class TestGCSHook:
                         delimiter=",", prefix="prefix_2", versions=None, 
max_results=None, page_token=None
                     ),
                 ],
+                ["prefix", "prefix2"],
             ),
         ),
     )
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
-    def test_list__delimiter(self, mock_service, prefix, result):
-        
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token
 = None
+    def test_list__delimiter(self, mock_service, prefix, blob_names, 
returned_prefixes, call_args, result):
+        Blob = namedtuple("Blob", ["name"])
+
+        class BlobsIterator:
+            def __init__(self):
+                self._item_iter = (Blob(name=name) for name in blob_names)
+
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                try:
+                    return next(self._item_iter)
+                except StopIteration:
+                    self.prefixes = returned_prefixes
+                    self.next_page_token = None
+                    raise
+
+        mock_service.return_value.bucket.return_value.list_blobs.return_value 
= BlobsIterator()
         with pytest.deprecated_call():
-            self.gcs_hook.list(
+            blobs = self.gcs_hook.list(
                 bucket_name="test_bucket",
                 prefix=prefix,
                 delimiter=",",
             )
-        assert 
mock_service.return_value.bucket.return_value.list_blobs.call_args_list == 
result
+        assert 
mock_service.return_value.bucket.return_value.list_blobs.call_args_list == 
call_args
+        assert blobs == result
 
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     @mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")

Reply via email to