This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5be06ca96e7 Attempt best-effort cancellation of pending futures on 
`GoogleCloudError` and cap worker count to the number of files in 
`GCSTimeSpanFileTransformOperator`. This avoids scheduling unnecessary work 
during failures and prevents over-provisioning threads for small batches. 
Existing failure semantics are preserved (`*_continue_on_fail` unchanged). 
Updated tests to validate cancellation behaviour and worker cap. (#64511)
5be06ca96e7 is described below

commit 5be06ca96e7aa3f40f5f339c9e6bb07b34c3c1c1
Author: SameerMesiah97 <[email protected]>
AuthorDate: Wed Apr 1 13:47:51 2026 +0100

    Attempt best-effort cancellation of pending futures on `GoogleCloudError` 
and cap worker count to the number of files in 
`GCSTimeSpanFileTransformOperator`. This avoids scheduling unnecessary work 
during failures and prevents over-provisioning threads for small batches. 
Existing failure semantics are preserved (`*_continue_on_fail` unchanged). 
Updated tests to validate cancellation behaviour and worker cap. (#64511)
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../providers/google/cloud/operators/gcs.py        | 26 ++++++--
 .../tests/unit/google/cloud/operators/test_gcs.py  | 75 +++++++++++++++++++---
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py 
b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
index ffe58f8db96..bc5b044b60f 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
@@ -876,10 +876,13 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
             temp_input_dir_path = Path(temp_input_dir)
             temp_output_dir_path = Path(temp_output_dir)
 
+            num_downloads = len(blobs_to_transform)
+            download_workers = min(self.max_download_workers, num_downloads) 
if num_downloads > 0 else 1
+
             self.log.info(
                 "Downloading %d files using %d workers",
-                len(blobs_to_transform),
-                self.max_download_workers,
+                num_downloads,
+                download_workers,
             )
 
             # Get storage client once (storage.Client is thread-safe for 
concurrent requests).
@@ -897,7 +900,7 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
                 return blob_name
 
-            with ThreadPoolExecutor(max_workers=self.max_download_workers) as 
executor:
+            with ThreadPoolExecutor(max_workers=download_workers) as executor:
                 futures = {executor.submit(_download, blob): blob for blob in 
blobs_to_transform}
 
                 for future in as_completed(futures):
@@ -906,6 +909,10 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                         future.result()
                     except GoogleCloudError as e:
                         if not self.download_continue_on_fail:
+                            # Attempt to cancel pending futures to reduce 
unnecessary work.
+                            # Note: futures already running cannot be 
cancelled.
+                            for f in futures:
+                                f.cancel()
                             raise
                         self.log.warning("Download failed for %s: %s", blob, e)
 
@@ -935,10 +942,13 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             upload_candidates = [f for f in temp_output_dir_path.glob("**/*") 
if f.is_file()]
 
+            num_uploads = len(upload_candidates)
+            upload_workers = min(self.max_upload_workers, num_uploads) if 
num_uploads > 0 else 1
+
             self.log.info(
                 "Uploading %d files using %d workers",
-                len(upload_candidates),
-                self.max_upload_workers,
+                num_uploads,
+                upload_workers,
             )
 
             destination_hook = GCSHook(
@@ -969,7 +979,7 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             files_uploaded: list[str] = []
 
-            with ThreadPoolExecutor(max_workers=self.max_upload_workers) as 
executor:
+            with ThreadPoolExecutor(max_workers=upload_workers) as executor:
                 futures = {
                     executor.submit(_upload, upload_file): str(upload_file)
                     for upload_file in upload_candidates
@@ -982,6 +992,10 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                         files_uploaded.append(uploaded_name)
                     except GoogleCloudError as e:
                         if not self.upload_continue_on_fail:
+                            # Attempt to cancel pending futures to reduce 
unnecessary work.
+                            # Note: futures already running cannot be 
cancelled.
+                            for f in futures:
+                                f.cancel()
                             raise
                         self.log.warning("Upload failed for %s: %s", 
upload_file, e)
 
diff --git a/providers/google/tests/unit/google/cloud/operators/test_gcs.py 
b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
index ea8edea51e7..1ecd31754e3 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_gcs.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from unittest import mock
@@ -665,13 +666,15 @@ class TestGCSTimeSpanFileTransformOperator:
             (0, True),
             (1, False),
             (2, False),
+            (10, False),
         ],
     )
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.ThreadPoolExecutor", 
wraps=ThreadPoolExecutor)
     
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
     def test_parallel_download_worker_behavior(
-        self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise
+        self, mock_hook, mock_subprocess, mock_tempdir, mock_executor, 
workers, should_raise
     ):
         timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
         timespan_end = timespan_start + timedelta(hours=1)
@@ -729,12 +732,17 @@ class TestGCSTimeSpanFileTransformOperator:
 
         assert mock_blob.download_to_filename.call_count == 2
 
+        expected_workers = min(workers, 2)
+        mock_executor.assert_any_call(max_workers=expected_workers)
+
     @pytest.mark.parametrize("continue_on_fail", [False, True])
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.as_completed")
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.ThreadPoolExecutor")
     
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
     def test_parallel_download_failure_behavior(
-        self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail
+        self, mock_hook, mock_subprocess, mock_tempdir, mock_executor, 
mock_as_completed, continue_on_fail
     ):
         timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
         timespan_end = timespan_start + timedelta(hours=1)
@@ -748,11 +756,28 @@ class TestGCSTimeSpanFileTransformOperator:
         }
 
         mock_tempdir.return_value.__enter__.side_effect = ["source", 
"destination"]
-        mock_hook.return_value.list_by_timespan.return_value = ["file1"]
 
-        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+        mock_hook.return_value.list_by_timespan.return_value = ["file1", 
"file2"]
+
+        self._setup_gcs_client_chain(mock_hook)
+
+        failing_future = mock.Mock()
+        failing_future.result.side_effect = GoogleCloudError("fail")
+        failing_future.cancel = mock.Mock()
 
-        mock_blob.download_to_filename.side_effect = GoogleCloudError("fail")
+        other_future = mock.Mock()
+        other_future.result.return_value = None
+        other_future.cancel = mock.Mock()
+
+        mock_executor.return_value.__enter__.return_value.submit.side_effect = 
[
+            failing_future,
+            other_future,
+        ]
+
+        # Force deterministic completion order for futures
+        # Note: this does not reflect true as_completed behaviour but allows
+        # us to validate cancellation logic.
+        mock_as_completed.side_effect = lambda futures: list(futures.keys())
 
         mock_proc = mock.MagicMock()
         mock_proc.returncode = 0
@@ -782,19 +807,23 @@ class TestGCSTimeSpanFileTransformOperator:
             with pytest.raises(GoogleCloudError):
                 op.execute(context=context)
 
+            other_future.cancel.assert_called()
+
     @pytest.mark.parametrize(
         ("workers", "should_raise"),
         [
             (0, True),
             (1, False),
             (2, False),
+            (10, False),
         ],
     )
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.ThreadPoolExecutor", 
wraps=ThreadPoolExecutor)
     
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
     def test_parallel_upload_worker_behavior(
-        self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise
+        self, mock_hook, mock_subprocess, mock_tempdir, mock_executor, 
workers, should_raise
     ):
         timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
         timespan_end = timespan_start + timedelta(hours=1)
@@ -861,12 +890,17 @@ class TestGCSTimeSpanFileTransformOperator:
 
         assert mock_blob.upload_from_filename.call_count == 2
 
+        expected_workers = min(workers, 2)
+        mock_executor.assert_any_call(max_workers=expected_workers)
+
     @pytest.mark.parametrize("continue_on_fail", [False, True])
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.as_completed")
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.ThreadPoolExecutor")
     
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
     def test_parallel_upload_failure_behavior(
-        self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail
+        self, mock_hook, mock_subprocess, mock_tempdir, mock_executor, 
mock_as_completed, continue_on_fail
     ):
         timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
         timespan_end = timespan_start + timedelta(hours=1)
@@ -891,9 +925,25 @@ class TestGCSTimeSpanFileTransformOperator:
         mock_subprocess.PIPE = "pipe"
         mock_subprocess.STDOUT = "stdout"
 
-        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+        self._setup_gcs_client_chain(mock_hook)
+
+        failing_future = mock.Mock()
+        failing_future.result.side_effect = GoogleCloudError("fail")
+        failing_future.cancel = mock.Mock()
+
+        other_future = mock.Mock()
+        other_future.result.return_value = None
+        other_future.cancel = mock.Mock()
+
+        mock_executor.return_value.__enter__.return_value.submit.side_effect = 
[
+            failing_future,
+            other_future,
+        ]
 
-        mock_blob.upload_from_filename.side_effect = GoogleCloudError("fail")
+        # Force deterministic completion order for futures
+        # Note: this does not reflect true as_completed behaviour but allows
+        # us to validate cancellation logic.
+        mock_as_completed.side_effect = lambda futures: list(futures.keys())
 
         op = GCSTimeSpanFileTransformOperator(
             task_id="test",
@@ -912,7 +962,10 @@ class TestGCSTimeSpanFileTransformOperator:
             mock.patch.object(Path, "glob") as path_glob,
             mock.patch.object(Path, "is_file", return_value=True),
         ):
-            path_glob.return_value.__iter__.return_value = 
[Path("destination/a.txt")]
+            path_glob.return_value.__iter__.return_value = [
+                Path("destination/a.txt"),
+                Path("destination/b.txt"),
+            ]
 
             if continue_on_fail:
                 op.execute(context=context)
@@ -920,6 +973,8 @@ class TestGCSTimeSpanFileTransformOperator:
                 with pytest.raises(GoogleCloudError):
                     op.execute(context=context)
 
+                other_future.cancel.assert_called()
+
 
 class TestGCSDeleteBucketOperator:
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")

Reply via email to