This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c3348af0972 Add parallel download and upload support to
`GCSTimeSpanFileTransformOperator` by introducing `max_download_workers` and
`max_upload_workers` (defaulting to 1) and executing both phases via
`ThreadPoolExecutor` when configured. The implementation reuses a single
`storage.Client` per phase, filters out directory placeholder objects before
download, and preserves existing error propagation and `*_continue_on_fail`
semantics. (#62196)
c3348af0972 is described below
commit c3348af09722cd07645b33d586f3e8665182c4cb
Author: SameerMesiah97 <[email protected]>
AuthorDate: Mon Mar 16 00:15:54 2026 +0000
Add parallel download and upload support to
`GCSTimeSpanFileTransformOperator` by introducing `max_download_workers` and
`max_upload_workers` (defaulting to 1) and executing both phases via
`ThreadPoolExecutor` when configured. The implementation reuses a single
`storage.Client` per phase, filters out directory placeholder objects before
download, and preserves existing error propagation and `*_continue_on_fail`
semantics. (#62196)
Unit tests have been updated and extended to cover worker validation,
parallel execution paths, and failure handling. The existing
`gcs_transform_timespan` system test DAG has also been extended with a parallel
execution case to validate the behavior against real GCS.
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../providers/google/cloud/operators/gcs.py | 147 ++++++---
.../cloud/gcs/example_gcs_transform_timespan.py | 16 +
.../tests/unit/google/cloud/operators/test_gcs.py | 330 +++++++++++++++++++--
3 files changed, 436 insertions(+), 57 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
index 1abd3122c10..1cfb40ff0cb 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
@@ -24,6 +24,8 @@ import subprocess
 import sys
+import time
 import warnings
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import TYPE_CHECKING
@@ -681,6 +682,10 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
data from source, transform it and write the output to the local
destination file.
+ Downloads and uploads can be executed in parallel by configuring
+ ``max_download_workers`` and ``max_upload_workers``. By default,
+ execution is sequential (a single worker thread per phase).
+
:param source_bucket: The bucket to fetch data from. (templated)
:param source_prefix: Prefix string which filters objects whose name begin
with
this prefix. Can interpolate logical date and time components.
(templated)
@@ -722,6 +727,10 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
:param upload_continue_on_fail: With this set to true, if an upload fails
the task does not error out
but will still continue.
:param upload_num_attempts: Number of attempts to try to upload a single
file.
+ :param max_download_workers: Maximum number of worker threads to use for
parallel downloads.
+ Must be greater than or equal to 1 (``ValueError`` otherwise). Defaults to 1 (sequential
execution).
+ :param max_upload_workers: Maximum number of worker threads to use for
parallel uploads.
+ Must be greater than or equal to 1 (``ValueError`` otherwise). Defaults to 1 (sequential
execution).
"""
template_fields: Sequence[str] = (
@@ -765,6 +774,8 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
download_num_attempts: int = 1,
upload_continue_on_fail: bool | None = False,
upload_num_attempts: int = 1,
+ max_download_workers: int = 1,  # thread-pool size for the download phase (1 = sequential)
+ max_upload_workers: int = 1,  # thread-pool size for the upload phase (1 = sequential)
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -787,6 +798,15 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
self.upload_continue_on_fail = upload_continue_on_fail
self.upload_num_attempts = upload_num_attempts
+ if max_download_workers < 1:  # validate eagerly, at construction rather than inside execute()
+ raise ValueError("max_download_workers must be >= 1")
+
+ if max_upload_workers < 1:  # same eager check for the upload pool
+ raise ValueError("max_upload_workers must be >= 1")
+
+ self.max_download_workers = max_download_workers  # ThreadPoolExecutor max_workers for downloads
+ self.max_upload_workers = max_upload_workers  # ThreadPoolExecutor max_workers for uploads
+
self._source_prefix_interp: str | None = None
self._destination_prefix_interp: str | None = None
@@ -838,41 +858,91 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
             )
 
         # Fetch list of files.
-        blobs_to_transform = source_hook.list_by_timespan(
-            bucket_name=self.source_bucket,
-            prefix=self._source_prefix_interp,
-            timespan_start=timespan_start,
-            timespan_end=timespan_end,
-        )
+        blobs_to_transform = [
+            blob
+            for blob in source_hook.list_by_timespan(
+                bucket_name=self.source_bucket,
+                prefix=self._source_prefix_interp,
+                timespan_start=timespan_start,
+                timespan_end=timespan_end,
+            )
+            # Filter out "directory" placeholders (GCS objects ending with '/')
+            # to avoid attempting to download non-file blobs.
+            if not blob.endswith("/")
+        ]
 
         with TemporaryDirectory() as temp_input_dir, TemporaryDirectory() as temp_output_dir:
             temp_input_dir_path = Path(temp_input_dir)
             temp_output_dir_path = Path(temp_output_dir)
 
-            # TODO: download in parallel.
-            for blob_to_transform in blobs_to_transform:
-                destination_file = temp_input_dir_path / blob_to_transform
-                destination_file.parent.mkdir(parents=True, exist_ok=True)
-                try:
-                    source_hook.download(
-                        bucket_name=self.source_bucket,
-                        object_name=blob_to_transform,
-                        filename=str(destination_file),
-                        chunk_size=self.chunk_size,
-                        num_max_attempts=self.download_num_attempts,
-                    )
-                except GoogleCloudError:
-                    if not self.download_continue_on_fail:
-                        raise
+            self.log.info(
+                "Downloading %d files using %d workers",
+                len(blobs_to_transform),
+                self.max_download_workers,
+            )
+
+            # One storage client, shared by every download worker thread.
+            client = source_hook.get_conn()
+            download_attempts = max(1, self.download_num_attempts or 1)
+
+            def _download(blob_name: str) -> str:
+                """Download one object into the input temp dir; retry on ``GoogleCloudError``."""
+                destination_file = temp_input_dir_path / blob_name
+                destination_file.parent.mkdir(parents=True, exist_ok=True)
+
+                bucket = client.bucket(bucket_name=self.source_bucket)
+                blob = bucket.blob(blob_name=blob_name, chunk_size=self.chunk_size)
+
+                # Honour ``download_num_attempts`` per object, as the sequential loop did.
+                for attempt in range(1, download_attempts + 1):
+                    try:
+                        blob.download_to_filename(filename=str(destination_file))
+                        break
+                    except GoogleCloudError:
+                        if attempt == download_attempts:
+                            raise
+                        self.log.warning(
+                            "Download of %s failed (attempt %d of %d), retrying.",
+                            blob_name,
+                            attempt,
+                            download_attempts,
+                        )
+                        # Exponential back-off before the next attempt.
+                        time.sleep(2**attempt)
+
+                return blob_name
+
+            with ThreadPoolExecutor(max_workers=self.max_download_workers) as executor:
+                futures = {executor.submit(_download, blob): blob for blob in blobs_to_transform}
+
+                for future in as_completed(futures):
+                    blob = futures[future]
+                    try:
+                        future.result()
+                    except GoogleCloudError as e:
+                        if not self.download_continue_on_fail:
+                            # Keep queued downloads from starting; running ones finish on executor exit.
+                            for pending in futures:
+                                pending.cancel()
+                            raise
+                        self.log.warning("Download failed for %s: %s", blob, e)
 
             self.log.info("Starting the transformation")
-            cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script
+            # Copy the configured command: the ``cmd += [...]`` below would otherwise append
+            # the runtime arguments to ``self.transform_script`` itself when it is a list.
+            cmd = (
+                [self.transform_script]
+                if isinstance(self.transform_script, str)
+                else list(self.transform_script)
+            )
+
             cmd += [
                 str(temp_input_dir_path),
                 str(temp_output_dir_path),
                 timespan_start.replace(microsecond=0).isoformat(),
                 timespan_end.replace(microsecond=0).isoformat(),
             ]
+
             with subprocess.Popen(
                 args=cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True
             ) as process:
@@ -887,32 +933,74 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)
 
-            files_uploaded = []
+            upload_candidates = [f for f in temp_output_dir_path.glob("**/*") if f.is_file()]
 
-            # TODO: upload in parallel.
-            for upload_file in temp_output_dir_path.glob("**/*"):
-                if upload_file.is_dir():
-                    continue
+            self.log.info(
+                "Uploading %d files using %d workers",
+                len(upload_candidates),
+                self.max_upload_workers,
+            )
+
+            # Reuse the existing ``destination_hook`` rather than building a second one;
+            # its storage client is shared by every upload worker thread.
+            client = destination_hook.get_conn()
+            upload_attempts = max(1, self.upload_num_attempts or 1)
+
+            def _upload(upload_file: Path) -> str:
+                """Upload one output file; retry on ``GoogleCloudError``; return its object name."""
+                bucket = client.bucket(bucket_name=self.destination_bucket)
 
+                # Preserve directory structure relative to the output temp directory.
                 upload_file_name = str(upload_file.relative_to(temp_output_dir_path))
 
                 if self._destination_prefix_interp is not None:
                     upload_file_name = f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}"
 
                 self.log.info("Uploading file %s to %s", upload_file, upload_file_name)
+                blob = bucket.blob(blob_name=upload_file_name, chunk_size=self.chunk_size)
 
-                try:
-                    destination_hook.upload(
-                        bucket_name=self.destination_bucket,
-                        object_name=upload_file_name,
-                        filename=str(upload_file),
-                        chunk_size=self.chunk_size,
-                        num_max_attempts=self.upload_num_attempts,
-                    )
-                    files_uploaded.append(str(upload_file_name))
-                except GoogleCloudError:
-                    if not self.upload_continue_on_fail:
-                        raise
+                # Honour ``upload_num_attempts`` per file, as the sequential loop did.
+                for attempt in range(1, upload_attempts + 1):
+                    try:
+                        blob.upload_from_filename(filename=str(upload_file))
+                        break
+                    except GoogleCloudError:
+                        if attempt == upload_attempts:
+                            raise
+                        self.log.warning(
+                            "Upload of %s failed (attempt %d of %d), retrying.",
+                            upload_file,
+                            attempt,
+                            upload_attempts,
+                        )
+                        # Exponential back-off before the next attempt.
+                        time.sleep(2**attempt)
+
+                return upload_file_name
+
+            # Collect results by candidate index so the returned list keeps the glob order
+            # (as the sequential loop did) instead of thread completion order.
+            uploaded_by_index: dict[int, str] = {}
+
+            with ThreadPoolExecutor(max_workers=self.max_upload_workers) as executor:
+                futures = {
+                    executor.submit(_upload, upload_file): index
+                    for index, upload_file in enumerate(upload_candidates)
+                }
+
+                for future in as_completed(futures):
+                    index = futures[future]
+                    try:
+                        uploaded_by_index[index] = future.result()
+                    except GoogleCloudError as e:
+                        if not self.upload_continue_on_fail:
+                            # Keep queued uploads from starting; running ones finish on executor exit.
+                            for pending in futures:
+                                pending.cancel()
+                            raise
+                        self.log.warning("Upload failed for %s: %s", upload_candidates[index], e)
+
+            files_uploaded = [uploaded_by_index[i] for i in sorted(uploaded_by_index)]
 
             return files_uploaded
 
diff --git
a/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
b/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
index 8b9b74d3615..b0b175f5bf9 100644
---
a/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
+++
b/providers/google/tests/system/google/cloud/gcs/example_gcs_transform_timespan.py
@@ -57,6 +57,8 @@ SOURCE_GCP_CONN_ID = DESTINATION_GCP_CONN_ID =
"google_cloud_default"
FILE_NAME = "example_upload.txt"
SOURCE_PREFIX = "timespan_source"
DESTINATION_PREFIX = "timespan_destination"
+DESTINATION_PREFIX_PARALLEL = "timespan_destination_parallel"  # distinct from DESTINATION_PREFIX so the parallel run writes elsewhere
+
UPLOAD_FILE_PATH = f"gcs/{FILE_NAME}"
TRANSFORM_SCRIPT_PATH = str(Path(__file__).parent / "resources" /
"transform_timespan.py")
@@ -101,6 +103,19 @@ with DAG(
destination_gcp_conn_id=DESTINATION_GCP_CONN_ID,
transform_script=["python", TRANSFORM_SCRIPT_PATH],
)
+
+ gcs_timespan_transform_files_parallel = GCSTimeSpanFileTransformOperator(
+ task_id="gcs_timespan_transform_files_parallel",
+ source_bucket=BUCKET_NAME_SRC,
+ source_prefix=SOURCE_PREFIX,
+ source_gcp_conn_id=SOURCE_GCP_CONN_ID,
+ destination_bucket=BUCKET_NAME_DST,
+ destination_prefix=DESTINATION_PREFIX_PARALLEL,
+ destination_gcp_conn_id=DESTINATION_GCP_CONN_ID,
+ transform_script=["python", TRANSFORM_SCRIPT_PATH],
+ max_download_workers=2,  # >1 so the multi-worker thread-pool path is exercised
+ max_upload_workers=2,  # likewise for uploads
+ )
# [END howto_operator_gcs_timespan_file_transform_operator_Task]
delete_bucket_src = GCSDeleteBucketOperator(
@@ -121,6 +136,7 @@ with DAG(
copy_file,
# TEST BODY
gcs_timespan_transform_files_task,
+ gcs_timespan_transform_files_parallel,
# TEST TEARDOWN
[delete_bucket_src, delete_bucket_dst, check_openlineage_events],
)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_gcs.py
b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
index 48e9561ac53..ea8edea51e7 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_gcs.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_gcs.py
@@ -22,6 +22,7 @@ from pathlib import Path
from unittest import mock
import pytest
+from google.cloud.exceptions import GoogleCloudError
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
@@ -372,6 +375,13 @@ class
TestGCSTimeSpanFileTransformOperatorDateInterpolation:
class TestGCSTimeSpanFileTransformOperator:
+ def _setup_gcs_client_chain(self, mock_hook):  # wire GCSHook().get_conn() -> bucket() -> blob() mocks
+ mock_client = mock.MagicMock()
+ mock_hook.return_value.get_conn.return_value = mock_client
+ mock_bucket = mock_client.bucket.return_value  # same object for every bucket() call
+ mock_blob = mock_bucket.blob.return_value  # same object for every blob() call
+ return mock_client, mock_bucket, mock_blob
+
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
@mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
@@ -410,6 +420,8 @@ class TestGCSTimeSpanFileTransformOperator:
f"{source_prefix}/{file2}",
]
+ mock_client, mock_bucket, mock_blob =
self._setup_gcs_client_chain(mock_hook)  # one mock chain serves both source and destination hooks
+
mock_proc = mock.MagicMock()
mock_proc.returncode = 0
mock_proc.stdout.readline = lambda: b""
@@ -432,7 +444,10 @@ class TestGCSTimeSpanFileTransformOperator:
transform_script=transform_script,
)
- with mock.patch.object(Path, "glob") as path_glob:
+ with (
+ mock.patch.object(Path, "glob") as path_glob,
+ mock.patch.object(Path, "is_file", return_value=True),  # fake glob paths must pass the is_file() filter
+ ):
path_glob.return_value.__iter__.return_value = [
Path(f"{destination}/{file1}"),
Path(f"{destination}/{file2}"),
@@ -446,21 +461,29 @@ class TestGCSTimeSpanFileTransformOperator:
prefix=source_prefix,
)
- mock_hook.return_value.download.assert_has_calls(
+ mock_client.bucket.assert_has_calls(
+ [
+ mock.call(bucket_name=source_bucket),
+ mock.call(bucket_name=source_bucket),
+ ],
+ any_order=True,
+ )
+
+ mock_bucket.blob.assert_has_calls(
+ [
+ mock.call(blob_name=f"{source_prefix}/{file1}",
chunk_size=None),
+ mock.call(blob_name=f"{source_prefix}/{file2}",
chunk_size=None),
+ ],
+ any_order=True,
+ )
+
+ mock_blob.download_to_filename.assert_has_calls(  # ordered: the default single worker runs jobs in submission order
[
mock.call(
- bucket_name=source_bucket,
- object_name=f"{source_prefix}/{file1}",
filename=f"{source}/{source_prefix}/{file1}",
- chunk_size=None,
- num_max_attempts=1,
),
mock.call(
- bucket_name=source_bucket,
- object_name=f"{source_prefix}/{file2}",
filename=f"{source}/{source_prefix}/{file2}",
- chunk_size=None,
- num_max_attempts=1,
),
]
)
@@ -478,21 +501,29 @@ class TestGCSTimeSpanFileTransformOperator:
close_fds=True,
)
- mock_hook.return_value.upload.assert_has_calls(
+ mock_client.bucket.assert_has_calls(
+ [
+ mock.call(bucket_name=destination_bucket),
+ mock.call(bucket_name=destination_bucket),
+ ],
+ any_order=True,
+ )
+
+ mock_bucket.blob.assert_has_calls(
+ [
+ mock.call(blob_name=f"{destination_prefix}/{file1}",
chunk_size=None),
+ mock.call(blob_name=f"{destination_prefix}/{file2}",
chunk_size=None),
+ ],
+ any_order=True,
+ )
+
+ mock_blob.upload_from_filename.assert_has_calls(  # ordered: the default single worker runs jobs in submission order
[
mock.call(
- bucket_name=destination_bucket,
filename=f"{destination}/{file1}",
- object_name=f"{destination_prefix}/{file1}",
- chunk_size=None,
- num_max_attempts=1,
),
mock.call(
- bucket_name=destination_bucket,
filename=f"{destination}/{file2}",
- object_name=f"{destination_prefix}/{file2}",
- chunk_size=None,
- num_max_attempts=1,
),
]
)
@@ -628,6 +659,276 @@ class TestGCSTimeSpanFileTransformOperator:
         assert all(element in lineage.outputs for element in outputs)
         assert all(element in outputs for element in lineage.outputs)
 
+    @pytest.mark.parametrize(
+        ("workers", "should_raise"),
+        [
+            (0, True),
+            (1, False),
+            (2, False),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_download_worker_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise, tmp_path
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        if should_raise:
+            with pytest.raises(ValueError, match="max_download_workers must be >= 1"):
+                GCSTimeSpanFileTransformOperator(
+                    task_id="test",
+                    source_bucket="bucket",
+                    source_prefix="prefix",
+                    source_gcp_conn_id="",
+                    destination_bucket="dest",
+                    destination_prefix="dest",
+                    destination_gcp_conn_id="",
+                    transform_script="script.py",
+                    max_download_workers=workers,
+                )
+            return
+
+        # Real dirs under tmp_path: the download step calls mkdir() under the input dir.
+        mock_tempdir.return_value.__enter__.side_effect = [
+            str(tmp_path / "source"),
+            str(tmp_path / "destination"),
+        ]
+        mock_hook.return_value.list_by_timespan.return_value = ["file1", "file2"]
+
+        mock_client, mock_bucket, mock_blob = self._setup_gcs_client_chain(mock_hook)
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_download_workers=workers,
+        )
+
+        op.execute(context=context)
+
+        assert mock_blob.download_to_filename.call_count == 2
+
+    @pytest.mark.parametrize("continue_on_fail", [False, True])
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+    @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+    def test_parallel_download_failure_behavior(
+        self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail, tmp_path
+    ):
+        timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+        timespan_end = timespan_start + timedelta(hours=1)
+
+        context = {
+            "logical_date": timespan_start,
+            "data_interval_start": timespan_start,
+            "data_interval_end": timespan_end,
+            "ti": mock.Mock(),
+            "task": mock.MagicMock(),
+        }
+
+        # Real dirs under tmp_path: the download step calls mkdir() under the input dir.
+        mock_tempdir.return_value.__enter__.side_effect = [
+            str(tmp_path / "source"),
+            str(tmp_path / "destination"),
+        ]
+        mock_hook.return_value.list_by_timespan.return_value = ["file1"]
+
+        mock_client, mock_bucket, mock_blob = self._setup_gcs_client_chain(mock_hook)
+
+        mock_blob.download_to_filename.side_effect = GoogleCloudError("fail")
+
+        mock_proc = mock.MagicMock()
+        mock_proc.returncode = 0
+        mock_proc.stdout.readline = lambda: b""
+        mock_proc.wait.return_value = None
+
+        mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+        mock_subprocess.PIPE = "pipe"
+        mock_subprocess.STDOUT = "stdout"
+
+        op = GCSTimeSpanFileTransformOperator(
+            task_id="test",
+            source_bucket="bucket",
+            source_prefix="prefix",
+            source_gcp_conn_id="",
+            destination_bucket="dest",
+            destination_prefix="dest",
+            destination_gcp_conn_id="",
+            transform_script="script.py",
+            max_download_workers=2,
+            download_continue_on_fail=continue_on_fail,
+        )
+
+        if continue_on_fail:
+            op.execute(context=context)
+        else:
+            with pytest.raises(GoogleCloudError):
+                op.execute(context=context)
+
+        # Default download_num_attempts=1: exactly one attempt for the single object.
+        mock_blob.download_to_filename.assert_called_once()
+
+ @pytest.mark.parametrize(
+ ("workers", "should_raise"),
+ [
+ (0, True),
+ (1, False),
+ (2, False),
+ ],
+ )
+
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+ @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+ @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+ def test_parallel_upload_worker_behavior(
+ self, mock_hook, mock_subprocess, mock_tempdir, workers, should_raise
+ ):
+ timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+ timespan_end = timespan_start + timedelta(hours=1)
+
+ context = {
+ "logical_date": timespan_start,
+ "data_interval_start": timespan_start,
+ "data_interval_end": timespan_end,
+ "ti": mock.Mock(),
+ "task": mock.MagicMock(),
+ }
+
+ if should_raise:
+ with pytest.raises(ValueError, match="max_upload_workers must be
>= 1"):
+ GCSTimeSpanFileTransformOperator(
+ task_id="test",
+ source_bucket="bucket",
+ source_prefix="prefix",
+ source_gcp_conn_id="",
+ destination_bucket="dest",
+ destination_prefix="dest",
+ destination_gcp_conn_id="",
+ transform_script="script.py",
+ max_upload_workers=workers,
+ )
+ return
+
+ mock_tempdir.return_value.__enter__.side_effect = ["source",
"destination"]
+ mock_hook.return_value.list_by_timespan.return_value = []  # no source objects, so only uploads hit the blob mock
+
+ mock_client, mock_bucket, mock_blob =
self._setup_gcs_client_chain(mock_hook)
+
+ mock_proc = mock.MagicMock()
+ mock_proc.returncode = 0
+ mock_proc.stdout.readline = lambda: b""
+ mock_proc.wait.return_value = None
+
+ mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+ mock_subprocess.PIPE = "pipe"
+ mock_subprocess.STDOUT = "stdout"
+
+ op = GCSTimeSpanFileTransformOperator(
+ task_id="test",
+ source_bucket="bucket",
+ source_prefix="prefix",
+ source_gcp_conn_id="",
+ destination_bucket="dest",
+ destination_prefix="dest",
+ destination_gcp_conn_id="",
+ transform_script="script.py",
+ max_upload_workers=workers,
+ )
+
+ with (
+ mock.patch.object(Path, "glob") as path_glob,
+ mock.patch.object(Path, "is_file", return_value=True),
+ ):
+ path_glob.return_value.__iter__.return_value = [
+ Path("destination/a.txt"),
+ Path("destination/b.txt"),
+ ]
+
+ op.execute(context=context)
+
+ assert mock_blob.upload_from_filename.call_count == 2  # one upload per globbed file
+
+ @pytest.mark.parametrize("continue_on_fail", [False, True])
+
@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
+ @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
+ @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
+ def test_parallel_upload_failure_behavior(
+ self, mock_hook, mock_subprocess, mock_tempdir, continue_on_fail
+ ):
+ timespan_start = datetime(2015, 2, 1, tzinfo=timezone.utc)
+ timespan_end = timespan_start + timedelta(hours=1)
+
+ context = {
+ "logical_date": timespan_start,
+ "data_interval_start": timespan_start,
+ "data_interval_end": timespan_end,
+ "ti": mock.Mock(),
+ "task": mock.MagicMock(),
+ }
+
+ mock_tempdir.return_value.__enter__.side_effect = ["source",
"destination"]
+ mock_hook.return_value.list_by_timespan.return_value = []  # no source objects, so only uploads hit the blob mock
+
+ mock_proc = mock.MagicMock()
+ mock_proc.returncode = 0
+ mock_proc.stdout.readline = lambda: b""
+ mock_proc.wait.return_value = None
+
+ mock_subprocess.Popen.return_value.__enter__.return_value = mock_proc
+ mock_subprocess.PIPE = "pipe"
+ mock_subprocess.STDOUT = "stdout"
+
+ mock_client, mock_bucket, mock_blob =
self._setup_gcs_client_chain(mock_hook)
+
+ mock_blob.upload_from_filename.side_effect = GoogleCloudError("fail")  # every upload attempt fails
+
+ op = GCSTimeSpanFileTransformOperator(
+ task_id="test",
+ source_bucket="bucket",
+ source_prefix="prefix",
+ source_gcp_conn_id="",
+ destination_bucket="dest",
+ destination_prefix="dest",
+ destination_gcp_conn_id="",
+ transform_script="script.py",
+ max_upload_workers=2,
+ upload_continue_on_fail=continue_on_fail,
+ )
+
+ with (
+ mock.patch.object(Path, "glob") as path_glob,
+ mock.patch.object(Path, "is_file", return_value=True),
+ ):
+ path_glob.return_value.__iter__.return_value =
[Path("destination/a.txt")]
+
+ if continue_on_fail:
+ op.execute(context=context)
+ else:
+ with pytest.raises(GoogleCloudError):
+ op.execute(context=context)
+
class TestGCSDeleteBucketOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")