This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13209fc354 Allow specification of buffer length for GCS to Samba (#38373)
13209fc354 is described below

commit 13209fc35496bc40c283d440802333fd7bf17d84
Author: Collin McNulty <collin.mcnu...@gmail.com>
AuthorDate: Fri Apr 5 05:07:26 2024 -0500

    Allow specification of buffer length for GCS to Samba (#38373)
    
    * Allow specification of buffer length for GCS to Samba
    
    Co-Authored-By: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com>
    
    * Pass buffer_size more cleanly
    
    * Allow specification of buffer length for GCS to Samba
    
    Co-Authored-By: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com>
    
    * Pass buffer_size more cleanly
    
    * Cleanup tests to account for new parameter
    
    * Update test_gcs_to_samba.py
    
    * Add test and fix non-wildcard path
    
    ---------
    
    Co-authored-by: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com>
---
 airflow/providers/samba/hooks/samba.py             | 15 ++++--
 airflow/providers/samba/transfers/gcs_to_samba.py  | 16 +++++--
 .../providers/samba/transfers/test_gcs_to_samba.py | 53 ++++++++++++++++++++--
 3 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/samba/hooks/samba.py 
b/airflow/providers/samba/hooks/samba.py
index 535ec267cc..895c885d92 100644
--- a/airflow/providers/samba/hooks/samba.py
+++ b/airflow/providers/samba/hooks/samba.py
@@ -245,10 +245,19 @@ class SambaHook(BaseHook):
             **self._conn_kwargs,
         )
 
-    def push_from_local(self, destination_filepath: str, local_filepath: str):
-        """Push local file to samba server."""
+    def push_from_local(self, destination_filepath: str, local_filepath: str, 
buffer_size: int | None = None):
+        """
+        Push local file to samba server.
+
+        :param destination_filepath: the samba location to push to
+        :param local_filepath: the file to push
+        :param buffer_size:
+            size in bytes of the individual chunks of file to send. Larger 
values may
+            speed up large file transfers
+        """
+        extra_args = (buffer_size,) if buffer_size else ()
         with open(local_filepath, "rb") as f, 
self.open_file(destination_filepath, mode="wb") as g:
-            copyfileobj(f, g)
+            copyfileobj(f, g, *extra_args)
 
     @classmethod
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
diff --git a/airflow/providers/samba/transfers/gcs_to_samba.py 
b/airflow/providers/samba/transfers/gcs_to_samba.py
index fb1cb6ad98..bddc038b73 100644
--- a/airflow/providers/samba/transfers/gcs_to_samba.py
+++ b/airflow/providers/samba/transfers/gcs_to_samba.py
@@ -93,6 +93,9 @@ class GCSToSambaOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param buffer_size: Optional specification of the size in bytes of the 
chunks sent to
+        Samba. Larger buffer lengths may decrease the time to upload large 
files. The default
+        length is determined by shutil, which is 64 KB.
     """
 
     template_fields: Sequence[str] = (
@@ -114,6 +117,7 @@ class GCSToSambaOperator(BaseOperator):
         gcp_conn_id: str = "google_cloud_default",
         samba_conn_id: str = "samba_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        buffer_size: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -127,6 +131,7 @@ class GCSToSambaOperator(BaseOperator):
         self.samba_conn_id = samba_conn_id
         self.impersonation_chain = impersonation_chain
         self.sftp_dirs = None
+        self.buffer_size = buffer_size
 
     def execute(self, context: Context):
         gcs_hook = GCSHook(
@@ -154,12 +159,16 @@ class GCSToSambaOperator(BaseOperator):
 
             for source_object in objects:
                 destination_path = 
self._resolve_destination_path(source_object, prefix=prefix_dirname)
-                self._copy_single_object(gcs_hook, samba_hook, source_object, 
destination_path)
+                self._copy_single_object(
+                    gcs_hook, samba_hook, source_object, destination_path, 
self.buffer_size
+                )
 
             self.log.info("Done. Uploaded '%d' files to %s", len(objects), 
self.destination_path)
         else:
             destination_path = 
self._resolve_destination_path(self.source_object)
-            self._copy_single_object(gcs_hook, samba_hook, self.source_object, 
destination_path)
+            self._copy_single_object(
+                gcs_hook, samba_hook, self.source_object, destination_path, 
self.buffer_size
+            )
             self.log.info("Done. Uploaded '%s' file to %s", 
self.source_object, destination_path)
 
     def _resolve_destination_path(self, source_object: str, prefix: str | None 
= None) -> str:
@@ -176,6 +185,7 @@ class GCSToSambaOperator(BaseOperator):
         samba_hook: SambaHook,
         source_object: str,
         destination_path: str,
+        buffer_size: int | None = None,
     ) -> None:
         """Copy single object."""
         self.log.info(
@@ -194,7 +204,7 @@ class GCSToSambaOperator(BaseOperator):
                 object_name=source_object,
                 filename=tmp.name,
             )
-            samba_hook.push_from_local(destination_path, tmp.name)
+            samba_hook.push_from_local(destination_path, tmp.name, 
buffer_size=buffer_size)
 
         if self.move_object:
             self.log.info("Executing delete of gs://%s/%s", 
self.source_bucket, source_object)
diff --git a/tests/providers/samba/transfers/test_gcs_to_samba.py 
b/tests/providers/samba/transfers/test_gcs_to_samba.py
index 100fde5f7d..f335d78423 100644
--- a/tests/providers/samba/transfers/test_gcs_to_samba.py
+++ b/tests/providers/samba/transfers/test_gcs_to_samba.py
@@ -70,7 +70,7 @@ class TestGoogleCloudStorageToSambaOperator:
             bucket_name=TEST_BUCKET, object_name=source_object, 
filename=mock.ANY
         )
         samba_hook_mock.return_value.push_from_local.assert_called_with(
-            os.path.join(DESTINATION_SMB, target_object), mock.ANY
+            os.path.join(DESTINATION_SMB, target_object), mock.ANY, 
buffer_size=None
         )
         gcs_hook_mock.return_value.delete.assert_not_called()
 
@@ -114,7 +114,52 @@ class TestGoogleCloudStorageToSambaOperator:
             bucket_name=TEST_BUCKET, object_name=source_object, 
filename=mock.ANY
         )
         samba_hook_mock.return_value.push_from_local.assert_called_with(
-            os.path.join(DESTINATION_SMB, target_object), mock.ANY
+            os.path.join(DESTINATION_SMB, target_object), mock.ANY, 
buffer_size=None
+        )
+        gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, 
source_object)
+
+    @pytest.mark.parametrize(
+        "source_object, target_object, keep_directory_structure",
+        [
+            ("folder/test_object.txt", "folder/test_object.txt", True),
+            ("folder/subfolder/test_object.txt", 
"folder/subfolder/test_object.txt", True),
+            ("folder/test_object.txt", "test_object.txt", False),
+            ("folder/subfolder/test_object.txt", "test_object.txt", False),
+        ],
+    )
+    @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.GCSHook")
+    @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.SambaHook")
+    def test_execute_adjust_buffer_size(
+        self,
+        samba_hook_mock,
+        gcs_hook_mock,
+        source_object,
+        target_object,
+        keep_directory_structure,
+    ):
+        operator = GCSToSambaOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=source_object,
+            destination_path=DESTINATION_SMB,
+            keep_directory_structure=keep_directory_structure,
+            move_object=True,
+            gcp_conn_id=GCP_CONN_ID,
+            samba_conn_id=SAMBA_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            buffer_size=128000,
+        )
+        operator.execute(None)
+        gcs_hook_mock.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        samba_hook_mock.assert_called_once_with(samba_conn_id=SAMBA_CONN_ID)
+        gcs_hook_mock.return_value.download.assert_called_with(
+            bucket_name=TEST_BUCKET, object_name=source_object, 
filename=mock.ANY
+        )
+        samba_hook_mock.return_value.push_from_local.assert_called_with(
+            os.path.join(DESTINATION_SMB, target_object), mock.ANY, 
buffer_size=128000
         )
         gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, 
source_object)
 
@@ -201,7 +246,7 @@ class TestGoogleCloudStorageToSambaOperator:
         )
         samba_hook_mock.return_value.push_from_local.assert_has_calls(
             [
-                mock.call(os.path.join(DESTINATION_SMB, target_object), 
mock.ANY)
+                mock.call(os.path.join(DESTINATION_SMB, target_object), 
mock.ANY, buffer_size=None)
                 for target_object in target_objects
             ]
         )
@@ -290,7 +335,7 @@ class TestGoogleCloudStorageToSambaOperator:
         )
         samba_hook_mock.return_value.push_from_local.assert_has_calls(
             [
-                mock.call(os.path.join(DESTINATION_SMB, target_object), 
mock.ANY)
+                mock.call(os.path.join(DESTINATION_SMB, target_object), 
mock.ANY, buffer_size=None)
                 for target_object in target_objects
             ]
         )

Reply via email to