This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8914e49551 SqlToS3Operator: feat/ add max_rows_per_file parameter (#37055)
8914e49551 is described below

commit 8914e49551d8ae5ece7418950b011c1f338b4634
Author: Selim CHERGUI <selim.cher...@hotmail.com>
AuthorDate: Tue Jan 30 01:27:46 2024 +0100

    SqlToS3Operator: feat/ add max_rows_per_file parameter (#37055)
    
    
    ---------
    
    Co-authored-by: Selim Chergui <selim.cher...@setec.com>
    Co-authored-by: Jarek Potiuk <ja...@potiuk.com>
---
 .../providers/amazon/aws/transfers/sql_to_s3.py    | 34 ++++++++++++--
 .../amazon/aws/transfers/test_sql_to_s3.py         | 52 ++++++++++++++++++++++
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index c00784ad4a..f8691fa4a2 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
                 You can specify this argument if you want to use a different
                 CA cert bundle than the one used by botocore.
    :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
+    :param max_rows_per_file: (optional) the maximum number of rows written to each destination file. If the source data
+        is larger than that, it will be dispatched into multiple files.
+        Will be ignored if the ``groupby_kwargs`` argument is specified.
    :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
@@ -110,6 +113,7 @@ class SqlToS3Operator(BaseOperator):
         aws_conn_id: str = "aws_default",
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
+        max_rows_per_file: int = 0,
         pd_kwargs: dict | None = None,
         groupby_kwargs: dict | None = None,
         **kwargs,
@@ -124,12 +128,19 @@ class SqlToS3Operator(BaseOperator):
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.max_rows_per_file = max_rows_per_file
         self.groupby_kwargs = groupby_kwargs or {}
         self.sql_hook_params = sql_hook_params
 
         if "path_or_buf" in self.pd_kwargs:
            raise AirflowException("The argument path_or_buf is not allowed, please remove it")
 
+        if self.max_rows_per_file and self.groupby_kwargs:
+            raise AirflowException(
+                "SqlToS3Operator arguments max_rows_per_file and 
groupby_kwargs "
+                "can not be both specified. Please choose one."
+            )
+
         try:
             self.file_format = FILE_FORMAT[file_format.upper()]
         except KeyError:
@@ -177,10 +188,8 @@ class SqlToS3Operator(BaseOperator):
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
         self.log.info("Data from SQL obtained")
-
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
-
         for group_name, df in self._partition_dataframe(df=data_df):
            with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
                 self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@ class SqlToS3Operator(BaseOperator):
 
    def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
+        try:
+            import secrets
+            import string
+
+            import numpy as np
+        except ImportError:
+            pass
+        # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
+        # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
+        random_column_name = ""
+        if self.max_rows_per_file and not self.groupby_kwargs:
+            random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
+            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            self.groupby_kwargs = {"by": random_column_name}
         if not self.groupby_kwargs:
             yield "", df
             return
        for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
                 cast(str, group_label),
-                cast("pd.DataFrame", grouped_df.get_group(group_label).reset_index(drop=True)),
+                cast(
+                    "pd.DataFrame",
+                    grouped_df.get_group(group_label)
+                    .drop(random_column_name, axis=1, errors="ignore")
+                    .reset_index(drop=True),
+                ),
             )
 
     def _get_hook(self) -> DbApiHook:
diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
index cc56fd064a..feee688d46 100644
--- a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -271,6 +271,58 @@ class TestSqlToS3Operator:
             )
         )
 
+    def test_with_max_rows_per_file(self):
+        """
+        Test the operator when max_rows_per_file is specified.
+        """
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
+
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "header": False},
+            max_rows_per_file=3,
+            dag=None,
+        )
+        example = {
+            "Team": ["Australia", "Australia", "India", "India"],
+            "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+            "Runs": [345, 490, 672, 560],
+        }
+
+        df = pd.DataFrame(example)
+        data = []
+        for group_name, df in op._partition_dataframe(df):
+            data.append((group_name, df))
+        data.sort(key=lambda d: d[0])
+        team, df = data[0]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["Australia", "Australia", "India"],
+                    "Player": ["Ricky", "David Warner", "Virat Kohli"],
+                    "Runs": [345, 490, 672],
+                }
+            )
+        )
+        team, df = data[1]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["India"],
+                    "Player": ["Rohit Sharma"],
+                    "Runs": [560],
+                }
+            )
+        )
+
    @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
     def test_hook_params(self, mock_get_conn):
        mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")

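Also for illustration, a standalone sketch of the chunking idea used in _partition_dataframe above: integer-dividing each row's position by max_rows_per_file yields a chunk label, and grouping on those labels splits the frame into pieces of at most max_rows_per_file rows.

    import numpy as np
    import pandas as pd

    max_rows_per_file = 3
    df = pd.DataFrame({"value": range(7)})

    # Row positions 0..6 integer-divided by 3 give labels [0, 0, 0, 1, 1, 1, 2],
    # so grouping on them yields chunks of sizes 3, 3 and 1.
    labels = np.arange(len(df)) // max_rows_per_file
    for label, chunk in df.groupby(labels):
        print(label, len(chunk))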