This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d7cef588d6 SqlToS3Operator - Add feature to partition SQL table (#30460)
d7cef588d6 is described below

commit d7cef588d6f6a749bd5e8fbf3153a275f4120ee8
Author: Utkarsh Sharma <utkarshar...@gmail.com>
AuthorDate: Wed Apr 19 04:49:49 2023 +0530

    SqlToS3Operator - Add feature to partition SQL table (#30460)
---
 .../providers/amazon/aws/transfers/sql_to_s3.py    | 32 +++++---
 .../transfer/sql_to_s3.rst                         | 13 +++
 .../amazon/aws/transfers/test_sql_to_s3.py         | 94 ++++++++++++++++++++++
 .../providers/amazon/aws/example_sql_to_s3.py      | 13 +++
 4 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index a8b5a9cd1c..8cee9b6cff 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -80,6 +80,7 @@ class SqlToS3Operator(BaseOperator):
                 CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
+    :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
 
     template_fields: Sequence[str] = (
@@ -107,6 +108,7 @@ class SqlToS3Operator(BaseOperator):
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
         pd_kwargs: dict | None = None,
+        groupby_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -119,6 +121,7 @@ class SqlToS3Operator(BaseOperator):
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.groupby_kwargs = groupby_kwargs or {}
 
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -170,15 +173,26 @@ class SqlToS3Operator(BaseOperator):
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
 
-        with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
-
-            self.log.info("Writing data to temp file")
-            getattr(data_df, file_options.function)(tmp_file.name, **self.pd_kwargs)
-
-            self.log.info("Uploading data to S3")
-            s3_conn.load_file(
-                filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace
-            )
+        for group_name, df in self._partition_dataframe(df=data_df):
+            with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
+
+                self.log.info("Writing data to temp file")
+                getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)
+
+                self.log.info("Uploading data to S3")
+                object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+                s3_conn.load_file(
+                    filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+                )
+
+    def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
+        """Partition dataframe using pandas groupby() method"""
+        if not self.groupby_kwargs:
+            yield "", df
+        else:
+            grouped_df = df.groupby(**self.groupby_kwargs)
+            for group_label in grouped_df.groups.keys():
+                yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
diff --git a/docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst b/docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst
index 31223be0f6..e01d415942 100644
--- a/docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst
+++ b/docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst
@@ -50,6 +50,19 @@ Example usage:
     :start-after: [START howto_transfer_sql_to_s3]
     :end-before: [END howto_transfer_sql_to_s3]
 
+Grouping
+--------
+
+We can group the data in the table by passing the ``groupby_kwargs`` param. This param accepts a ``dict`` which will be passed to pandas `groupby() <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.groupby.html#pandas.DataFrame.groupby>`_ as kwargs.
+
+Example usage:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sql_to_s3.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_transfer_sql_to_s3_with_groupby_param]
+    :end-before: [END howto_transfer_sql_to_s3_with_groupby_param]
+
 Reference
 ---------
 
diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
index 84db615eeb..a0e4e6f603 100644
--- a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -175,3 +175,97 @@ class TestSqlToS3Operator:
                 file_format="invalid_format",
                 dag=None,
             )
+
+    def test_with_groupby_kwarg(self):
+        """
+        Test operator when the groupby_kwargs is specified
+        """
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
+
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "header": False},
+            groupby_kwargs={"by": "Team"},
+            dag=None,
+        )
+        example = {
+            "Team": ["Australia", "Australia", "India", "India"],
+            "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+            "Runs": [345, 490, 672, 560],
+        }
+
+        df = pd.DataFrame(example)
+        data = []
+        for group_name, df in op._partition_dataframe(df):
+            data.append((group_name, df))
+        data.sort(key=lambda d: d[0])
+        team, df = data[0]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["Australia", "Australia"],
+                    "Player": ["Ricky", "David Warner"],
+                    "Runs": [345, 490],
+                }
+            )
+        )
+        team, df = data[1]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["India", "India"],
+                    "Player": ["Virat Kohli", "Rohit Sharma"],
+                    "Runs": [672, 560],
+                }
+            )
+        )
+
+    def test_without_groupby_kwarg(self):
+        """
+        Test operator when the groupby_kwargs is not specified
+        """
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
+
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "header": False},
+            dag=None,
+        )
+        example = {
+            "Team": ["Australia", "Australia", "India", "India"],
+            "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+            "Runs": [345, 490, 672, 560],
+        }
+
+        df = pd.DataFrame(example)
+        data = []
+        for group_name, df in op._partition_dataframe(df):
+            data.append((group_name, df))
+
+        assert len(data) == 1
+        team, df = data[0]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["Australia", "Australia", "India", "India"],
+                    "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+                    "Runs": [345, 490, 672, 560],
+                }
+            )
+        )
diff --git a/tests/system/providers/amazon/aws/example_sql_to_s3.py b/tests/system/providers/amazon/aws/example_sql_to_s3.py
index f983881a3f..bda1e4c9df 100644
--- a/tests/system/providers/amazon/aws/example_sql_to_s3.py
+++ b/tests/system/providers/amazon/aws/example_sql_to_s3.py
@@ -173,6 +173,18 @@ with DAG(
     )
     # [END howto_transfer_sql_to_s3]
 
+    # [START howto_transfer_sql_to_s3_with_groupby_param]
+    sql_to_s3_task_with_groupby = SqlToS3Operator(
+        task_id="sql_to_s3_with_groupby_task",
+        sql_conn_id=conn_id_name,
+        query=SQL_QUERY,
+        s3_bucket=bucket_name,
+        s3_key=key,
+        replace=True,
+        groupby_kwargs={"by": "color"},
+    )
+    # [END howto_transfer_sql_to_s3_with_groupby_param]
+
     delete_bucket = S3DeleteBucketOperator(
         task_id="delete_bucket",
         bucket_name=bucket_name,
@@ -202,6 +214,7 @@ with DAG(
         insert_data,
         # TEST BODY
         sql_to_s3_task,
+        sql_to_s3_task_with_groupby,
         # TEST TEARDOWN
         delete_bucket,
         delete_cluster,

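For reference, a minimal standalone sketch of the partitioning behaviour added above: the ``groupby_kwargs`` dict is forwarded to pandas ``groupby()`` and each resulting group is uploaded under ``<s3_key>_<group label>``. This sketch assumes only pandas is installed; the sample data and the key prefix are illustrative, not taken from the commit.

import pandas as pd

# Illustrative data; the operator itself builds the DataFrame from the SQL query result.
df = pd.DataFrame(
    {
        "Team": ["Australia", "Australia", "India", "India"],
        "Runs": [345, 490, 672, 560],
    }
)

groupby_kwargs = {"by": "Team"}  # same dict that would be passed to SqlToS3Operator
grouped = df.groupby(**groupby_kwargs)
for group_label in grouped.groups:
    part = grouped.get_group(group_label).reset_index(drop=True)
    # The operator writes each partition to its own temp file and uploads it
    # under f"{s3_key}_{group_label}".
    print(f"my-key_{group_label}")
    print(part)

When ``groupby_kwargs`` is empty or omitted, the generator yields the whole DataFrame once with an empty label, so the previous single-object upload behaviour is preserved.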