This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c3348af0972 Add parallel download and upload support to 
`GCSTimeSpanFileTransformOperator` by introducing `max_download_workers` and 
`max_upload_workers` (defaulting to 1) and executing both phases via 
`ThreadPoolExecutor` when configured. The implementation reuses a single 
`storage.Client` per phase, filters out directory placeholder objects before 
download, and preserves existing error propagation and `*_continue_on_fail` 
semantics. (#62196)
c3348af0972 is described below

commit c3348af09722cd07645b33d586f3e8665182c4cb
Author: SameerMesiah97 <[email protected]>
AuthorDate: Mon Mar 16 00:15:54 2026 +0000

    Add parallel download and upload support to 
`GCSTimeSpanFileTransformOperator` by introducing `max_download_workers` and 
`max_upload_workers` (defaulting to 1) and executing both phases via 
`ThreadPoolExecutor` when configured. The implementation reuses a single 
`storage.Client` per phase, filters out directory placeholder objects before 
download, and preserves existing error propagation and `*_continue_on_fail` 
semantics. (#62196)
    
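    Unit tests have been updated and extended to cover worker validation, 
parallel execution paths, and failure handling. The existing 
`gcs_transform_timespan` system test DAG has also been extended with a parallel 
execution case to validate the behavior against real GCS.
    
    [Reviewer note, not part of the original commit message: in the diff below, 
the download and upload paths call `blob.download_to_filename` and 
`blob.upload_from_filename` directly instead of `GCSHook.download` and 
`GCSHook.upload`, so `download_num_attempts` and `upload_num_attempts` are no 
longer passed to the transfer calls even though they remain documented 
parameters. The diff also adds two `print` calls to a 
`TestGCSDeleteObjectsOperator` test that look like leftover debugging output.]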
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../providers/google/cloud/operators/gcs.py        | 147 ++++++---
 .../cloud/gcs/example_gcs_transform_timespan.py    |  16 +
 .../tests/unit/google/cloud/operators/test_gcs.py  | 330 +++++++++++++++++++--
 3 files changed, 436 insertions(+), 57 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py 
b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
index 1abd3122c10..1cfb40ff0cb 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
@@ -24,6 +24,7 @@ import subprocess
 import sys
 import warnings
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import TYPE_CHECKING
@@ -681,6 +682,10 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
     data from source, transform it and write the output to the local
     destination file.
 
+    Downloads and uploads can be executed in parallel by configuring
+    ``max_download_workers`` and ``max_upload_workers``. By default,
+    execution is sequential.
+
     :param source_bucket: The bucket to fetch data from. (templated)
     :param source_prefix: Prefix string which filters objects whose name begin 
with
            this prefix. Can interpolate logical date and time components. 
(templated)
@@ -722,6 +727,10 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
     :param upload_continue_on_fail: With this set to true, if an upload fails 
the task does not error out
         but will still continue.
     :param upload_num_attempts: Number of attempts to try to upload a single 
file.
+    :param max_download_workers: Maximum number of worker threads to use for 
parallel downloads.
+        Must be greater than or equal to 1. Defaults to 1 (sequential 
execution).
+    :param max_upload_workers: Maximum number of worker threads to use for 
parallel uploads.
+        Must be greater than or equal to 1. Defaults to 1 (sequential 
execution).
     """
 
     template_fields: Sequence[str] = (
@@ -765,6 +774,8 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         download_num_attempts: int = 1,
         upload_continue_on_fail: bool | None = False,
         upload_num_attempts: int = 1,
+        max_download_workers: int = 1,
+        max_upload_workers: int = 1,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -787,6 +798,15 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         self.upload_continue_on_fail = upload_continue_on_fail
         self.upload_num_attempts = upload_num_attempts
 
+        if max_download_workers < 1:
+            raise ValueError("max_download_workers must be >= 1")
+
+        if max_upload_workers < 1:
+            raise ValueError("max_upload_workers must be >= 1")
+
+        self.max_download_workers = max_download_workers
+        self.max_upload_workers = max_upload_workers
+
         self._source_prefix_interp: str | None = None
         self._destination_prefix_interp: str | None = None
 
@@ -838,41 +858,67 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         )
 
         # Fetch list of files.
-        blobs_to_transform = source_hook.list_by_timespan(
-            bucket_name=self.source_bucket,
-            prefix=self._source_prefix_interp,
-            timespan_start=timespan_start,
-            timespan_end=timespan_end,
-        )
+
+        blobs_to_transform = [
+            blob
+            for blob in source_hook.list_by_timespan(
+                bucket_name=self.source_bucket,
+                prefix=self._source_prefix_interp,
+                timespan_start=timespan_start,
+                timespan_end=timespan_end,
+            )
+            # Filter out "directory" placeholders (GCS objects ending with '/')
+            # to avoid attempting to download non-file blobs.
+            if not blob.endswith("/")
+        ]
 
         with TemporaryDirectory() as temp_input_dir, TemporaryDirectory() as 
temp_output_dir:
             temp_input_dir_path = Path(temp_input_dir)
             temp_output_dir_path = Path(temp_output_dir)
 
-            # TODO: download in parallel.
-            for blob_to_transform in blobs_to_transform:
-                destination_file = temp_input_dir_path / blob_to_transform
+            self.log.info(
+                "Downloading %d files using %d workers",
+                len(blobs_to_transform),
+                self.max_download_workers,
+            )
+
+            # Get storage client once (storage.Client is thread-safe for 
concurrent requests).
+            client = source_hook.get_conn()
+
+            def _download(blob_name: str):
+
+                bucket = client.bucket(bucket_name=self.source_bucket)
+                blob = bucket.blob(blob_name=blob_name, 
chunk_size=self.chunk_size)
+
+                destination_file = temp_input_dir_path / blob_name
                 destination_file.parent.mkdir(parents=True, exist_ok=True)
-                try:
-                    source_hook.download(
-                        bucket_name=self.source_bucket,
-                        object_name=blob_to_transform,
-                        filename=str(destination_file),
-                        chunk_size=self.chunk_size,
-                        num_max_attempts=self.download_num_attempts,
-                    )
-                except GoogleCloudError:
-                    if not self.download_continue_on_fail:
-                        raise
+
+                blob.download_to_filename(filename=str(destination_file))
+
+                return blob_name
+
+            with ThreadPoolExecutor(max_workers=self.max_download_workers) as 
executor:
+                futures = {executor.submit(_download, blob): blob for blob in 
blobs_to_transform}
+
+                for future in as_completed(futures):
+                    blob = futures[future]
+                    try:
+                        future.result()
+                    except GoogleCloudError as e:
+                        if not self.download_continue_on_fail:
+                            raise
+                        self.log.warning("Download failed for %s: %s", blob, e)
 
             self.log.info("Starting the transformation")
             cmd = [self.transform_script] if isinstance(self.transform_script, 
str) else self.transform_script
+
             cmd += [
                 str(temp_input_dir_path),
                 str(temp_output_dir_path),
                 timespan_start.replace(microsecond=0).isoformat(),
                 timespan_end.replace(microsecond=0).isoformat(),
             ]
+
             with subprocess.Popen(
                 args=cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, 
close_fds=True
             ) as process:
@@ -887,32 +933,57 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             self.log.info("Transformation succeeded. Output temporarily 
located at %s", temp_output_dir_path)
 
-            files_uploaded = []
+            upload_candidates = [f for f in temp_output_dir_path.glob("**/*") 
if f.is_file()]
 
-            # TODO: upload in parallel.
-            for upload_file in temp_output_dir_path.glob("**/*"):
-                if upload_file.is_dir():
-                    continue
+            self.log.info(
+                "Uploading %d files using %d workers",
+                len(upload_candidates),
+                self.max_upload_workers,
+            )
 
+            destination_hook = GCSHook(
+                gcp_conn_id=self.destination_gcp_conn_id,
+                impersonation_chain=self.destination_impersonation_chain,
+            )
+
+            # Get storage client once (storage.Client is thread-safe for 
concurrent requests).
+            client = destination_hook.get_conn()
+
+            def _upload(upload_file: Path):
+
+                bucket = client.bucket(bucket_name=self.destination_bucket)
+
+                # Preserve directory structure relative to the output temp 
directory.
                 upload_file_name = 
str(upload_file.relative_to(temp_output_dir_path))
 
                 if self._destination_prefix_interp is not None:
                     upload_file_name = 
f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}"
 
-                self.log.info("Uploading file %s to %s", upload_file, 
upload_file_name)
+                blob = bucket.blob(blob_name=upload_file_name, 
chunk_size=self.chunk_size)
 
-                try:
-                    destination_hook.upload(
-                        bucket_name=self.destination_bucket,
-                        object_name=upload_file_name,
-                        filename=str(upload_file),
-                        chunk_size=self.chunk_size,
-                        num_max_attempts=self.upload_num_attempts,
-                    )
-                    files_uploaded.append(str(upload_file_name))
-                except GoogleCloudError:
-                    if not self.upload_continue_on_fail:
-                        raise
+                blob.upload_from_filename(
+                    filename=str(upload_file),
+                )
+
+                return upload_file_name
+
+            files_uploaded: list[str] = []
+
+            with ThreadPoolExecutor(max_workers=self.max_upload_workers) as 
executor:
+                futures = {
+                    executor.submit(_upload, upload_file): str(upload_file)
+                    for upload_file in upload_candidates
+                }
+
+                for future in as_completed(futures):
+                    upload_file = futures[future]
+                    try:
+                        uploaded_name = future.result()
+                        files_uploaded.append(uploaded_name)
+                    except GoogleCloudError as e:
+                        if not self.upload_continue_on_fail:
+                            raise
+                        self.log.warning("Upload failed for %s: %s", 
upload_file, e)
 
             return files_uploaded
 
diff --git 
a/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
 
b/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
index 8b9b74d3615..b0b175f5bf9 100644
--- 
a/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
+++ 
b/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
@@ -57,6 +57,8 @@ SOURCE_GCP_CONN_ID = DESTINATION_GCP_CONN_ID = 
"google_cloud_default"
 FILE_NAME = "example_upload.txt"
 SOURCE_PREFIX = "timespan_source"
 DESTINATION_PREFIX = "timespan_destination"
+DESTINATION_PREFIX_PARALLEL = "timespan_destination_parallel"
+
 UPLOAD_FILE_PATH = f"gcs/{FILE_NAME}"
 
 TRANSFORM_SCRIPT_PATH = str(Path(__file__).parent / "resources" / 
"transform_timespan.py")
@@ -101,6 +103,19 @@ with DAG(
         destination_gcp_conn_id=DESTINATION_GCP_CONN_ID,
         transform_script=["python", TRANSFORM_SCRIPT_PATH],
     )
+
+    gcs_timespan_transform_files_parallel = GCSTimeSpanFileTransformOperator(
+        task_id="gcs_timespan_transform_files_parallel",
+        source_bucket=BUCKET_NAME_SRC,
+        source_prefix=SOURCE_PREFIX,
+        source_gcp_conn_id=SOURCE_GCP_CONN_ID,
+        destination_bucket=BUCKET_NAME_DST,
+        destination_prefix=DESTINATION_PREFIX_PARALLEL,
+        destination_gcp_conn_id=DESTINATION_GCP_CONN_ID,
+        transform_script=["python", TRANSFORM_SCRIPT_PATH],
+        max_download_workers=2,
+        max_upload_workers=2,
+    )
     # [END howto_operator_gcs_timespan_file_transform_operator_Task]
 
     delete_bucket_src = GCSDeleteBucketOperator(
@@ -121,6 +136,7 @@ with DAG(
         copy_file,
         # TEST BODY
         gcs_timespan_transform_files_task,
+        gcs_timespan_transform_files_parallel,
         # TEST TEARDOWN
         [delete_bucket_src, delete_bucket_dst, check_openlineage_events],
     )
diff --git a/providers/google/tests/unit/google/cloud/operators/test_gcs.py 
b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
index 48e9561ac53..ea8edea51e7 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_gcs.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
@@ -22,6 +22,7 @@ from pathlib import Path
 from unittest import mock
 
 import pytest
+from google.cloud.exceptions import GoogleCloudError
 
 from airflow.providers.common.compat.openlineage.facet import (
     Dataset,
@@ -226,6 +227,8 @@ class TestGCSDeleteObjectsOperator:
         assert len(lineage.outputs) == 0
         assert all(element in lineage.inputs for element in expected_inputs)
         assert all(element in expected_inputs for element in lineage.inputs)
+        print("EXPECTED:", expected_inputs)
+        print("ACTUAL:", lineage.inputs)
 
 
 class TestGoogleCloudStorageListOperator:
@@ -372,6 +375,13 @@ class 
TestGCSTimeSpanFileTransformOperatorDateInterpolation:
 
 
 class TestGCSTimeSpanFileTransformOperator:
+    def _setup_gcs_client_chain(self, mock_hook):
+        mock_client = mock.MagicMock()
+        mock_hook.return_value.get_conn.return_value = mock_client
+        mock_bucket = mock_client.bucket.return_value
+        mock_blob = mock_bucket.blob.return_value
+        return mock_client, mock_bucket, mock_blob
+
     
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
@@ -410,6 +420,8 @@ class TestGCSTimeSpanFileTransformOperator:
             f"{source_prefix}/{file2}",
         ]
 
+        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+
         mock_proc = mock.MagicMock()
         mock_proc.returncode = 0
         mock_proc.stdout.readline = lambda: b""
@@ -432,7 +444,10 @@ class TestGCSTimeSpanFileTransformOperator:
             transform_script=transform_script,
         )
 
-        with mock.patch.object(Path, "glob") as path_glob:
+        with (
+            mock.patch.object(Path, "glob") as path_glob,
+            mock.patch.object(Path, "is_file", return_value=True),
+        ):
             path_glob.return_value.__iter__.return_value = [
                 Path(f"{destination}/{file1}"),
                 Path(f"{destination}/{file2}"),
@@ -446,21 +461,29 @@ class TestGCSTimeSpanFileTransformOperator:
             prefix=source_prefix,
         )
 
-        mock_hook.return_value.download.assert_has_calls(
+        mock_client.bucket.assert_has_calls(
+            [
+                mock.call(bucket_name=source_bucket),
+                mock.call(bucket_name=source_bucket),
+            ],
+            any_order=True,
+        )
+
+        mock_bucket.blob.assert_has_calls(
+            [
+                mock.call(blob_name=f"{source_prefix}/{file1}", 
chunk_size=None),
+                mock.call(blob_name=f"{source_prefix}/{file2}", 
chunk_size=None),
+            ],
+            any_order=True,
+        )
+
+        mock_blob.download_to_filename.assert_has_calls(
             [
                 mock.call(
-                    bucket_name=source_bucket,
-                    object_name=f"{source_prefix}/{file1}",
                     filename=f"{source}/{source_prefix}/{file1}",
-                    chunk_size=None,
-                    num_max_attempts=1,
                 ),
                 mock.call(
-                    bucket_name=source_bucket,
-                    object_name=f"{source_prefix}/{file2}",
                     filename=f"{source}/{source_prefix}/{file2}",
-                    chunk_size=None,
-                    num_max_attempts=1,
                 ),
             ]
         )
@@ -478,21 +501,29 @@ class TestGCSTimeSpanFileTransformOperator:
             close_fds=True,
         )
 
-        mock_hook.return_value.upload.assert_has_calls(
+        mock_client.bucket.assert_has_calls(
+            [
+                mock.call(bucket_name=destination_bucket),
+                mock.call(bucket_name=destination_bucket),
+            ],
+            any_order=True,
+        )
+
+        mock_bucket.blob.assert_has_calls(
+            [
+                mock.call(blob_name=f"{destination_prefix}/{file1}", 
chunk_size=None),
+                mock.call(blob_name=f"{destination_prefix}/{file2}", 
chunk_size=None),
+            ],
+            any_order=True,
+        )
+
+        mock_blob.upload_from_filename.assert_has_calls(
             [
                 mock.call(
-                    bucket_name=destination_bucket,
                     filename=f"{destination}/{file1}",
-                    object_name=f"{destination_prefix}/{file1}",
-                    chunk_size=None,
-                    num_max_attempts=1,
                 ),
                 mock.call(
-                    bucket_name=destination_bucket,
                     filename=f"{destination}/{file2}",
-                    object_name=f"{destination_prefix}/{file2}",
-                    chunk_size=None,
-                    num_max_attempts=1,
                 ),
             ]
         )
@@ -628,6 +659,267 @@ class TestGCSTimeSpanFileTransformOperator:
         assert all(element in lineage.outputs for element in outputs)
         assert all(element in outputs for element in lineage.outputs)
 
+    @pytest.mark.parametrize(
+        ("workers", "should_raise"),
+        [
+            (0, True),
+            (1, False),
+            (2, False),
+        ],
+    )
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_download_worker_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        if should_raise:
+            with pytest.raises(ValueError, match="max_download_workers must be 
>= 1"):
+                GCSTimeSpanFileTransformOperator(
+                    task_id="test",
+                    source_bucket="bucket",
+                    source_prefix="prefix",
+                    source_gcp_conn_id="",
+                    destination_bucket="dest",
+                    destination_prefix="dest",
+                    destination_gcp_conn_id="",
+                    transform_script="script.py",
+                    max_download_workers=workers,
+                )
+            return
+
+        mock_tempdir.return_value.__enter__.side_effect = ["source", 
"destination"]
+        mock_hook.return_value.list_by_timespan.return_value = ["file1", 
"file2"]
+
+        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_download_workers=workers,
+        )
+
+        op.execute(context=context)
+
+        assert mock_blob.download_to_filename.call_count == 2
+
+    @pytest.mark.parametrize("continue_on_fail", [False, True])
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_download_failure_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        mock_tempdir.return_value.__enter__.side_effect = ["source", 
"destination"]
+        mock_hook.return_value.list_by_timespan.return_value = ["file1"]
+
+        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+
+        mock_blob.download_to_filename.side_effect = GoogleCloudError("fail")
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_download_workers=2,
+            download_continue_on_fail=continue_on_fail,
+        )
+
+        if continue_on_fail:
+            op.execute(context=context)
+        else:
+            with pytest.raises(GoogleCloudError):
+                op.execute(context=context)
+
+    @pytest.mark.parametrize(
+        ("workers", "should_raise"),
+        [
+            (0, True),
+            (1, False),
+            (2, False),
+        ],
+    )
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_upload_worker_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        if should_raise:
+            with pytest.raises(ValueError, match="max_upload_workers must be 
>= 1"):
+                GCSTimeSpanFileTransformOperator(
+                    task_id="test",
+                    source_bucket="bucket",
+                    source_prefix="prefix",
+                    source_gcp_conn_id="",
+                    destination_bucket="dest",
+                    destination_prefix="dest",
+                    destination_gcp_conn_id="",
+                    transform_script="script.py",
+                    max_upload_workers=workers,
+                )
+            return
+
+        mock_tempdir.return_value.__enter__.side_effect = ["source", 
"destination"]
+        mock_hook.return_value.list_by_timespan.return_value = []
+
+        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_upload_workers=workers,
+        )
+
+        with (
+            mock.patch.object(Path, "glob") as path_glob,
+            mock.patch.object(Path, "is_file", return_value=True),
+        ):
+            path_glob.return_value.__iter__.return_value = [
+                Path("destination/a.txt"),
+                Path("destination/b.txt"),
+            ]
+
+            op.execute(context=context)
+
+        assert mock_blob.upload_from_filename.call_count == 2
+
+    @pytest.mark.parametrize("continue_on_fail", [False, True])
+    
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_upload_failure_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        mock_tempdir.return_value.__enter__.side_effect = ["source", 
"destination"]
+        mock_hook.return_value.list_by_timespan.return_value = []
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        mock_client, mock_bucket, mock_blob = 
self._setup_gcs_client_chain(mock_hook)
+
+        mock_blob.upload_from_filename.side_effect = GoogleCloudError("fail")
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_upload_workers=2,
+            upload_continue_on_fail=continue_on_fail,
+        )
+
+        with (
+            mock.patch.object(Path, "glob") as path_glob,
+            mock.patch.object(Path, "is_file", return_value=True),
+        ):
+            path_glob.return_value.__iter__.return_value = 
[Path("destination/a.txt")]
+
+            if continue_on_fail:
+                op.execute(context=context)
+            else:
+                with pytest.raises(GoogleCloudError):
+                    op.execute(context=context)
+
 
 class TestGCSDeleteBucketOperator:
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")

Reply via email to