kaxil closed pull request #4270: [AIRFLOW-3434] Allows creating intermediate folders in SFTPOperator. URL: https://github.com/apache/incubator-airflow/pull/4270
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index 620d875f89..117bc55a8c 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -16,6 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os + from airflow.contrib.hooks.ssh_hook import SSHHook from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -48,9 +50,28 @@ class SFTPOperator(BaseOperator): :param remote_filepath: remote file path to get or put. (templated) :type remote_filepath: str :param operation: specify operation 'get' or 'put', defaults to put - :type get: bool + :type operation: str :param confirm: specify if the SFTP operation should be confirmed, defaults to True :type confirm: bool + :param create_intermediate_dirs: create missing intermediate directories when + copying from remote to local and vice-versa. Default is False. + + Example: The following task would copy ``file.txt`` to the remote host + at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they + don't exist. If the parameter is not passed it would error as the directory + does not exist. 
:: + + put_file = SFTPOperator( + task_id="test_sftp", + ssh_conn="ssh_default", + local_filepath="/tmp/file.txt", + remote_filepath="/tmp/tmp1/tmp2/file.txt", + operation="put", + create_intermediate_dirs=True, + dag=dag + ) + + :type create_intermediate_dirs: bool """ template_fields = ('local_filepath', 'remote_filepath', 'remote_host') @@ -63,6 +84,7 @@ def __init__(self, remote_filepath=None, operation=SFTPOperation.PUT, confirm=True, + create_intermediate_dirs=False, *args, **kwargs): super(SFTPOperator, self).__init__(*args, **kwargs) @@ -73,6 +95,7 @@ def __init__(self, self.remote_filepath = remote_filepath self.operation = operation self.confirm = confirm + self.create_intermediate_dirs = create_intermediate_dirs if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): raise TypeError("unsupported operation value {0}, expected {1} or {2}" @@ -101,11 +124,25 @@ def execute(self, context): with self.ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() if self.operation.lower() == SFTPOperation.GET: + local_folder = os.path.dirname(self.local_filepath) + if self.create_intermediate_dirs: + # Create Intermediate Directories if it doesn't exist + try: + os.makedirs(local_folder) + except OSError: + if not os.path.isdir(local_folder): + raise file_msg = "from {0} to {1}".format(self.remote_filepath, self.local_filepath) self.log.debug("Starting to transfer %s", file_msg) sftp_client.get(self.remote_filepath, self.local_filepath) else: + remote_folder = os.path.dirname(self.remote_filepath) + if self.create_intermediate_dirs: + _make_intermediate_dirs( + sftp_client=sftp_client, + remote_directory=remote_folder, + ) file_msg = "from {0} to {1}".format(self.local_filepath, self.remote_filepath) self.log.debug("Starting to transfer file %s", file_msg) @@ -118,3 +155,26 @@ def execute(self, context): .format(file_msg, str(e))) return None + + +def _make_intermediate_dirs(sftp_client, remote_directory): 
+ """ + Create all the intermediate directories in a remote host + + :param sftp_client: A Paramiko SFTP client. + :param remote_directory: Absolute Path of the directory containing the file + :return: + """ + if remote_directory == '/': + sftp_client.chdir('/') + return + if remote_directory == '': + return + try: + sftp_client.chdir(remote_directory) + except IOError: + dirname, basename = os.path.split(remote_directory.rstrip('/')) + _make_intermediate_dirs(sftp_client, dirname) + sftp_client.mkdir(basename) + sftp_client.chdir(basename) + return diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 7a450c0844..66d793991d 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -62,12 +62,20 @@ def setUp(self): self.hook = hook self.dag = dag self.test_dir = "/tmp" + self.test_local_dir = "/tmp/tmp2" + self.test_remote_dir = "/tmp/tmp1" self.test_local_filename = 'test_local_file' self.test_remote_filename = 'test_remote_file' self.test_local_filepath = '{0}/{1}'.format(self.test_dir, self.test_local_filename) + # Local Filepath with Intermediate Directory + self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir, + self.test_local_filename) self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, self.test_remote_filename) + # Remote Filepath with Intermediate Directory + self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir, + self.test_remote_filename) def test_pickle_file_transfer_put(self): configuration.conf.set("core", "enable_xcom_pickling", "True") @@ -85,6 +93,7 @@ def test_pickle_file_transfer_put(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, + create_intermediate_dirs=True, dag=self.dag ) self.assertIsNotNone(put_test_task) @@ -106,6 +115,71 @@ def test_pickle_file_transfer_put(self): ti3.xcom_pull(task_ids='test_check_file', 
key='return_value').strip(), test_local_file_content) + def test_file_transfer_no_intermediate_dir_error_put(self): + configuration.conf.set("core", "enable_xcom_pickling", "True") + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as f: + f.write(test_local_file_content) + + # Try to put test file to remote + # This should raise an error with "No such file" as the directory + # does not exist + with self.assertRaises(Exception) as error: + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=False, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + self.assertIn('No such file', str(error.exception)) + + def test_file_transfer_with_intermediate_dir_put(self): + configuration.conf.set("core", "enable_xcom_pickling", "True") + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as f: + f.write(test_local_file_content) + + # put test file to remote + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=True, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # check the remote file content + check_file_task = SSHOperator( + task_id="test_check_file", + 
ssh_hook=self.hook, + command="cat {0}".format(self.test_remote_filepath_int_dir), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(check_file_task) + ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) + ti3.run() + self.assertEqual( + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), + test_local_file_content) + def test_json_file_transfer_put(self): configuration.conf.set("core", "enable_xcom_pickling", "False") test_local_file_content = \ @@ -220,6 +294,81 @@ def test_json_file_transfer_get(self): self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8')) + def test_file_transfer_no_intermediate_dir_error_get(self): + configuration.conf.set("core", "enable_xcom_pickling", "True") + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # Try to GET test file from remote + # This should raise an error with "No such file" as the directory + # does not exist + with self.assertRaises(Exception) as error: + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath_int_dir, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + self.assertIn('No such file', str(error.exception)) + + def test_file_transfer_with_intermediate_dir_error_get(self): + configuration.conf.set("core", 
"enable_xcom_pickling", "True") + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # get remote file to local + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath_int_dir, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + create_intermediate_dirs=True, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # test the received content + content_received = None + with open(self.test_local_filepath_int_dir, 'r') as f: + content_received = f.read() + self.assertEqual(content_received.strip(), test_remote_file_content) + def test_arg_checking(self): from airflow.exceptions import AirflowException conn_id = "conn_id_for_testing" @@ -288,22 +437,32 @@ def test_arg_checking(self): def delete_local_resource(self): if os.path.exists(self.test_local_filepath): os.remove(self.test_local_filepath) + if os.path.exists(self.test_local_filepath_int_dir): + os.remove(self.test_local_filepath_int_dir) + if os.path.exists(self.test_local_dir): + os.rmdir(self.test_local_dir) def delete_remote_resource(self): - # check the remote file content - remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="rm {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(remove_file_task) - ti3 = TaskInstance(task=remove_file_task, 
execution_date=timezone.utcnow()) - ti3.run() + if os.path.exists(self.test_remote_filepath): + # check the remote file content + remove_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="rm {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(remove_file_task) + ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) + ti3.run() + if os.path.exists(self.test_remote_filepath_int_dir): + os.remove(self.test_remote_filepath_int_dir) + if os.path.exists(self.test_remote_dir): + os.rmdir(self.test_remote_dir) def tearDown(self): - self.delete_local_resource() and self.delete_remote_resource() + self.delete_local_resource() + self.delete_remote_resource() if __name__ == '__main__': ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services