This is an automated email from the ASF dual-hosted git repository. taragolis pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 13209fc354 Allow specification of buffer length for GCS to Samba (#38373) 13209fc354 is described below commit 13209fc35496bc40c283d440802333fd7bf17d84 Author: Collin McNulty <collin.mcnu...@gmail.com> AuthorDate: Fri Apr 5 05:07:26 2024 -0500 Allow specification of buffer length for GCS to Samba (#38373) * Allow specification of buffer length for GCS to Samba Co-Authored-By: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com> * Pass buffer_size more cleanly * Allow specification of buffer length for GCS to Samba Co-Authored-By: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com> * Pass buffer_size more cleanly * Cleanup tests to account for new parameter * Update test_gcs_to_samba.py * Add test and fix non-wildcard path --------- Co-authored-by: jslepicka-apex <110119914+jslepicka-a...@users.noreply.github.com> --- airflow/providers/samba/hooks/samba.py | 15 ++++-- airflow/providers/samba/transfers/gcs_to_samba.py | 16 +++++-- .../providers/samba/transfers/test_gcs_to_samba.py | 53 ++++++++++++++++++++-- 3 files changed, 74 insertions(+), 10 deletions(-) diff --git a/airflow/providers/samba/hooks/samba.py b/airflow/providers/samba/hooks/samba.py index 535ec267cc..895c885d92 100644 --- a/airflow/providers/samba/hooks/samba.py +++ b/airflow/providers/samba/hooks/samba.py @@ -245,10 +245,19 @@ class SambaHook(BaseHook): **self._conn_kwargs, ) - def push_from_local(self, destination_filepath: str, local_filepath: str): - """Push local file to samba server.""" + def push_from_local(self, destination_filepath: str, local_filepath: str, buffer_size: int | None = None): + """ + Push local file to samba server. + + :param destination_filepath: the samba location to push to + :param local_filepath: the file to push + :param buffer_size: + size in bytes of the individual chunks of file to send. Larger values may + speed up large file transfers + """ + extra_args = (buffer_size,) if buffer_size else () with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g: - copyfileobj(f, g) + copyfileobj(f, g, *extra_args) @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: diff --git a/airflow/providers/samba/transfers/gcs_to_samba.py b/airflow/providers/samba/transfers/gcs_to_samba.py index fb1cb6ad98..bddc038b73 100644 --- a/airflow/providers/samba/transfers/gcs_to_samba.py +++ b/airflow/providers/samba/transfers/gcs_to_samba.py @@ -93,6 +93,9 @@ class GCSToSambaOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param buffer_size: Optional specification of the size in bytes of the chunks sent to + Samba. Larger buffer lengths may decrease the time to upload large files. The default + length is determined by shutil, which is 64 KB. """ template_fields: Sequence[str] = ( @@ -114,6 +117,7 @@ class GCSToSambaOperator(BaseOperator): gcp_conn_id: str = "google_cloud_default", samba_conn_id: str = "samba_default", impersonation_chain: str | Sequence[str] | None = None, + buffer_size: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -127,6 +131,7 @@ class GCSToSambaOperator(BaseOperator): self.samba_conn_id = samba_conn_id self.impersonation_chain = impersonation_chain self.sftp_dirs = None + self.buffer_size = buffer_size def execute(self, context: Context): gcs_hook = GCSHook( @@ -154,12 +159,16 @@ class GCSToSambaOperator(BaseOperator): for source_object in objects: destination_path = self._resolve_destination_path(source_object, prefix=prefix_dirname) - self._copy_single_object(gcs_hook, samba_hook, source_object, destination_path) + self._copy_single_object( + gcs_hook, samba_hook, source_object, destination_path, self.buffer_size + ) self.log.info("Done. Uploaded '%d' files to %s", len(objects), self.destination_path) else: destination_path = self._resolve_destination_path(self.source_object) - self._copy_single_object(gcs_hook, samba_hook, self.source_object, destination_path) + self._copy_single_object( + gcs_hook, samba_hook, self.source_object, destination_path, self.buffer_size + ) self.log.info("Done. Uploaded '%s' file to %s", self.source_object, destination_path) def _resolve_destination_path(self, source_object: str, prefix: str | None = None) -> str: @@ -176,6 +185,7 @@ class GCSToSambaOperator(BaseOperator): samba_hook: SambaHook, source_object: str, destination_path: str, + buffer_size: int | None = None, ) -> None: """Copy single object.""" self.log.info( @@ -194,7 +204,7 @@ class GCSToSambaOperator(BaseOperator): object_name=source_object, filename=tmp.name, ) - samba_hook.push_from_local(destination_path, tmp.name) + samba_hook.push_from_local(destination_path, tmp.name, buffer_size=buffer_size) if self.move_object: self.log.info("Executing delete of gs://%s/%s", self.source_bucket, source_object) diff --git a/tests/providers/samba/transfers/test_gcs_to_samba.py b/tests/providers/samba/transfers/test_gcs_to_samba.py index 100fde5f7d..f335d78423 100644 --- a/tests/providers/samba/transfers/test_gcs_to_samba.py +++ b/tests/providers/samba/transfers/test_gcs_to_samba.py @@ -70,7 +70,7 @@ class TestGoogleCloudStorageToSambaOperator: bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY ) samba_hook_mock.return_value.push_from_local.assert_called_with( - os.path.join(DESTINATION_SMB, target_object), mock.ANY + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None ) gcs_hook_mock.return_value.delete.assert_not_called() @@ -114,7 +114,52 @@ class TestGoogleCloudStorageToSambaOperator: bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY ) samba_hook_mock.return_value.push_from_local.assert_called_with( - os.path.join(DESTINATION_SMB, target_object), mock.ANY + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None + ) + gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object) + + @pytest.mark.parametrize( + "source_object, target_object, keep_directory_structure", + [ + ("folder/test_object.txt", "folder/test_object.txt", True), + ("folder/subfolder/test_object.txt", "folder/subfolder/test_object.txt", True), + ("folder/test_object.txt", "test_object.txt", False), + ("folder/subfolder/test_object.txt", "test_object.txt", False), + ], + ) + @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.GCSHook") + @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.SambaHook") + def test_execute_adjust_buffer_size( + self, + samba_hook_mock, + gcs_hook_mock, + source_object, + target_object, + keep_directory_structure, + ): + operator = GCSToSambaOperator( + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_object=source_object, + destination_path=DESTINATION_SMB, + keep_directory_structure=keep_directory_structure, + move_object=True, + gcp_conn_id=GCP_CONN_ID, + samba_conn_id=SAMBA_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + buffer_size=128000, + ) + operator.execute(None) + gcs_hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + samba_hook_mock.assert_called_once_with(samba_conn_id=SAMBA_CONN_ID) + gcs_hook_mock.return_value.download.assert_called_with( + bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY + ) + samba_hook_mock.return_value.push_from_local.assert_called_with( + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=128000 ) gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object) @@ -201,7 +246,7 @@ class TestGoogleCloudStorageToSambaOperator: ) samba_hook_mock.return_value.push_from_local.assert_has_calls( [ - mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY) + mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None) for target_object in target_objects ] ) @@ -290,7 +335,7 @@ class TestGoogleCloudStorageToSambaOperator: ) samba_hook_mock.return_value.push_from_local.assert_has_calls( [ - mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY) + mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None) for target_object in target_objects ] )