kaxil closed pull request #4270: [AIRFLOW-3434] Allows creating intermediate 
folders in SFTPOperator
URL: https://github.com/apache/incubator-airflow/pull/4270
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index 620d875f89..117bc55a8c 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -16,6 +16,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
+
 from airflow.contrib.hooks.ssh_hook import SSHHook
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -48,9 +50,28 @@ class SFTPOperator(BaseOperator):
     :param remote_filepath: remote file path to get or put. (templated)
     :type remote_filepath: str
     :param operation: specify operation 'get' or 'put', defaults to put
-    :type get: bool
+    :type operation: str
     :param confirm: specify if the SFTP operation should be confirmed, defaults to True
     :type confirm: bool
+    :param create_intermediate_dirs: create missing intermediate directories when
+        copying from remote to local and vice-versa. Default is False.
+
+        Example: The following task would copy ``file.txt`` to the remote host
+        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
+        don't exist. If the parameter is not passed it would error as the directory
+        does not exist. ::
+
+            put_file = SFTPOperator(
+                task_id="test_sftp",
+                ssh_conn_id="ssh_default",
+                local_filepath="/tmp/file.txt",
+                remote_filepath="/tmp/tmp1/tmp2/file.txt",
+                operation="put",
+                create_intermediate_dirs=True,
+                dag=dag
+            )
+
+    :type create_intermediate_dirs: bool
     """
     template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
 
@@ -63,6 +84,7 @@ def __init__(self,
                  remote_filepath=None,
                  operation=SFTPOperation.PUT,
                  confirm=True,
+                 create_intermediate_dirs=False,
                  *args,
                  **kwargs):
         super(SFTPOperator, self).__init__(*args, **kwargs)
@@ -73,6 +95,7 @@ def __init__(self,
         self.remote_filepath = remote_filepath
         self.operation = operation
         self.confirm = confirm
+        self.create_intermediate_dirs = create_intermediate_dirs
         if not (self.operation.lower() == SFTPOperation.GET or
                 self.operation.lower() == SFTPOperation.PUT):
             raise TypeError("unsupported operation value {0}, expected {1} or {2}"
@@ -101,11 +124,25 @@ def execute(self, context):
             with self.ssh_hook.get_conn() as ssh_client:
                 sftp_client = ssh_client.open_sftp()
                 if self.operation.lower() == SFTPOperation.GET:
+                    local_folder = os.path.dirname(self.local_filepath)
+                    if self.create_intermediate_dirs:
+                        # Create intermediate directories if they don't exist
+                        try:
+                            os.makedirs(local_folder)
+                        except OSError:
+                            if not os.path.isdir(local_folder):
+                                raise
                     file_msg = "from {0} to {1}".format(self.remote_filepath,
                                                         self.local_filepath)
                     self.log.debug("Starting to transfer %s", file_msg)
                     sftp_client.get(self.remote_filepath, self.local_filepath)
                 else:
+                    remote_folder = os.path.dirname(self.remote_filepath)
+                    if self.create_intermediate_dirs:
+                        _make_intermediate_dirs(
+                            sftp_client=sftp_client,
+                            remote_directory=remote_folder,
+                        )
                     file_msg = "from {0} to {1}".format(self.local_filepath,
                                                         self.remote_filepath)
                     self.log.debug("Starting to transfer file %s", file_msg)
@@ -118,3 +155,26 @@ def execute(self, context):
                                    .format(file_msg, str(e)))
 
         return None
+
+
+def _make_intermediate_dirs(sftp_client, remote_directory):
+    """
+    Create all the missing intermediate directories on a remote host.
+
+    :param sftp_client: A Paramiko SFTP client.
+    :param remote_directory: Absolute path of the directory containing the file
+    :return: None
+    """
+    if remote_directory == '/':
+        sftp_client.chdir('/')
+        return
+    if remote_directory == '':
+        return
+    try:
+        sftp_client.chdir(remote_directory)
+    except IOError:
+        dirname, basename = os.path.split(remote_directory.rstrip('/'))
+        _make_intermediate_dirs(sftp_client, dirname)
+        sftp_client.mkdir(basename)
+        sftp_client.chdir(basename)
+        return
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index 7a450c0844..66d793991d 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -62,12 +62,20 @@ def setUp(self):
         self.hook = hook
         self.dag = dag
         self.test_dir = "/tmp"
+        self.test_local_dir = "/tmp/tmp2"
+        self.test_remote_dir = "/tmp/tmp1"
         self.test_local_filename = 'test_local_file'
         self.test_remote_filename = 'test_remote_file'
         self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
                                                     self.test_local_filename)
+        # Local Filepath with Intermediate Directory
+        self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
+                                                            self.test_local_filename)
         self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
                                                      self.test_remote_filename)
+        # Remote Filepath with Intermediate Directory
+        self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
+                                                             self.test_remote_filename)
 
     def test_pickle_file_transfer_put(self):
         configuration.conf.set("core", "enable_xcom_pickling", "True")
@@ -85,6 +93,7 @@ def test_pickle_file_transfer_put(self):
             local_filepath=self.test_local_filepath,
             remote_filepath=self.test_remote_filepath,
             operation=SFTPOperation.PUT,
+            create_intermediate_dirs=True,
             dag=self.dag
         )
         self.assertIsNotNone(put_test_task)
@@ -106,6 +115,71 @@ def test_pickle_file_transfer_put(self):
             ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
             test_local_file_content)
 
+    def test_file_transfer_no_intermediate_dir_error_put(self):
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_local_file_content = \
+            b"This is local file content \n which is multiline " \
+            b"continuing....with other character\nanother line here \n this is last line"
+        # create a test file locally
+        with open(self.test_local_filepath, 'wb') as f:
+            f.write(test_local_file_content)
+
+        # Try to put test file to remote
+        # This should raise an error with "No such file" as the directory
+        # does not exist
+        with self.assertRaises(Exception) as error:
+            put_test_task = SFTPOperator(
+                task_id="test_sftp",
+                ssh_hook=self.hook,
+                local_filepath=self.test_local_filepath,
+                remote_filepath=self.test_remote_filepath_int_dir,
+                operation=SFTPOperation.PUT,
+                create_intermediate_dirs=False,
+                dag=self.dag
+            )
+            self.assertIsNotNone(put_test_task)
+            ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+            ti2.run()
+        self.assertIn('No such file', str(error.exception))
+
+    def test_file_transfer_with_intermediate_dir_put(self):
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_local_file_content = \
+            b"This is local file content \n which is multiline " \
+            b"continuing....with other character\nanother line here \n this is last line"
+        # create a test file locally
+        with open(self.test_local_filepath, 'wb') as f:
+            f.write(test_local_file_content)
+
+        # put test file to remote
+        put_test_task = SFTPOperator(
+            task_id="test_sftp",
+            ssh_hook=self.hook,
+            local_filepath=self.test_local_filepath,
+            remote_filepath=self.test_remote_filepath_int_dir,
+            operation=SFTPOperation.PUT,
+            create_intermediate_dirs=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(put_test_task)
+        ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+        ti2.run()
+
+        # check the remote file content
+        check_file_task = SSHOperator(
+            task_id="test_check_file",
+            ssh_hook=self.hook,
+            command="cat {0}".format(self.test_remote_filepath_int_dir),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(check_file_task)
+        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
+        ti3.run()
+        self.assertEqual(
+            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+            test_local_file_content)
+
     def test_json_file_transfer_put(self):
         configuration.conf.set("core", "enable_xcom_pickling", "False")
         test_local_file_content = \
@@ -220,6 +294,81 @@ def test_json_file_transfer_get(self):
         self.assertEqual(content_received.strip(),
                          test_remote_file_content.encode('utf-8').decode('utf-8'))
 
+    def test_file_transfer_no_intermediate_dir_error_get(self):
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_remote_file_content = \
+            "This is remote file content \n which is also multiline " \
+            "another line here \n this is last line. EOF"
+
+        # create a test file remotely
+        create_file_task = SSHOperator(
+            task_id="test_create_file",
+            ssh_hook=self.hook,
+            command="echo '{0}' > {1}".format(test_remote_file_content,
+                                              self.test_remote_filepath),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(create_file_task)
+        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+        ti1.run()
+
+        # Try to GET test file from remote
+        # This should raise an error with "No such file" as the directory
+        # does not exist
+        with self.assertRaises(Exception) as error:
+            get_test_task = SFTPOperator(
+                task_id="test_sftp",
+                ssh_hook=self.hook,
+                local_filepath=self.test_local_filepath_int_dir,
+                remote_filepath=self.test_remote_filepath,
+                operation=SFTPOperation.GET,
+                dag=self.dag
+            )
+            self.assertIsNotNone(get_test_task)
+            ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+            ti2.run()
+        self.assertIn('No such file', str(error.exception))
+
+    def test_file_transfer_with_intermediate_dir_error_get(self):
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_remote_file_content = \
+            "This is remote file content \n which is also multiline " \
+            "another line here \n this is last line. EOF"
+
+        # create a test file remotely
+        create_file_task = SSHOperator(
+            task_id="test_create_file",
+            ssh_hook=self.hook,
+            command="echo '{0}' > {1}".format(test_remote_file_content,
+                                              self.test_remote_filepath),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(create_file_task)
+        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+        ti1.run()
+
+        # get remote file to local
+        get_test_task = SFTPOperator(
+            task_id="test_sftp",
+            ssh_hook=self.hook,
+            local_filepath=self.test_local_filepath_int_dir,
+            remote_filepath=self.test_remote_filepath,
+            operation=SFTPOperation.GET,
+            create_intermediate_dirs=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(get_test_task)
+        ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+        ti2.run()
+
+        # test the received content
+        content_received = None
+        with open(self.test_local_filepath_int_dir, 'r') as f:
+            content_received = f.read()
+        self.assertEqual(content_received.strip(), test_remote_file_content)
+
     def test_arg_checking(self):
         from airflow.exceptions import AirflowException
         conn_id = "conn_id_for_testing"
@@ -288,22 +437,32 @@ def test_arg_checking(self):
     def delete_local_resource(self):
         if os.path.exists(self.test_local_filepath):
             os.remove(self.test_local_filepath)
+        if os.path.exists(self.test_local_filepath_int_dir):
+            os.remove(self.test_local_filepath_int_dir)
+        if os.path.exists(self.test_local_dir):
+            os.rmdir(self.test_local_dir)
 
     def delete_remote_resource(self):
-        # check the remote file content
-        remove_file_task = SSHOperator(
-            task_id="test_check_file",
-            ssh_hook=self.hook,
-            command="rm {0}".format(self.test_remote_filepath),
-            do_xcom_push=True,
-            dag=self.dag
-        )
-        self.assertIsNotNone(remove_file_task)
-        ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
-        ti3.run()
+        if os.path.exists(self.test_remote_filepath):
+            # remove the remote test file
+            remove_file_task = SSHOperator(
+                task_id="test_check_file",
+                ssh_hook=self.hook,
+                command="rm {0}".format(self.test_remote_filepath),
+                do_xcom_push=True,
+                dag=self.dag
+            )
+            self.assertIsNotNone(remove_file_task)
+            ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
+            ti3.run()
+        if os.path.exists(self.test_remote_filepath_int_dir):
+            os.remove(self.test_remote_filepath_int_dir)
+        if os.path.exists(self.test_remote_dir):
+            os.rmdir(self.test_remote_dir)
 
     def tearDown(self):
-        self.delete_local_resource() and self.delete_remote_resource()
+        self.delete_local_resource()
+        self.delete_remote_resource()
 
 
 if __name__ == '__main__':


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to