This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 34ac7de89711 [SPARK-48536][PYTHON][CONNECT] Cache user specified schema in applyInPandas and applyInArrow
34ac7de89711 is described below

commit 34ac7de897115caada7330aed32f03aca4796299
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jun 5 20:42:00 2024 +0800

    [SPARK-48536][PYTHON][CONNECT] Cache user specified schema in applyInPandas and applyInArrow
    
    ### What changes were proposed in this pull request?
    Cache the user-specified schema in applyInPandas and applyInArrow.
    
    ### Why are the changes needed?
    To avoid extra RPCs: when the caller passes a StructType, the returned DataFrame already knows its schema and does not need another analyze round trip to the server.
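    A minimal sketch of the intended effect (illustrative only; it assumes a Spark Connect session bound to `spark` and mirrors the tests added below):

    ```python
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    def normalize(pdf):
        v = pdf.v
        return pdf.assign(v=(v - v.mean()) / v.std())

    schema = StructType(
        [
            StructField("id", LongType(), True),
            StructField("v", DoubleType(), True),
        ]
    )

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))
    res = df.groupby("id").applyInPandas(normalize, schema)

    # Because `schema` is a StructType, it is kept as the result's cached schema,
    # so reading `res.schema` can be answered locally instead of via an extra RPC.
    print(res.schema)
    ```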
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46877 from zhengruifeng/cache_schema_apply_in_x.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/group.py                |  20 ++-
 .../connect/test_connect_dataframe_property.py     | 145 ++++++++++++++++++++-
 2 files changed, 160 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 2a5bb5939a3f..85806b1a265b 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -301,7 +301,7 @@ class GroupedData:
             evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.GroupMap(
                 child=self._df._plan,
                 grouping_cols=self._grouping_cols,
@@ -310,6 +310,9 @@ class GroupedData:
             ),
             session=self._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
 
@@ -370,7 +373,7 @@ class GroupedData:
             evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.GroupMap(
                 child=self._df._plan,
                 grouping_cols=self._grouping_cols,
@@ -379,6 +382,9 @@ class GroupedData:
             ),
             session=self._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__
 
@@ -410,7 +416,7 @@ class PandasCogroupedOps:
             evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.CoGroupMap(
                 input=self._gd1._df._plan,
                 input_grouping_cols=self._gd1._grouping_cols,
@@ -420,6 +426,9 @@ class PandasCogroupedOps:
             ),
             session=self._gd1._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
 
@@ -436,7 +445,7 @@ class PandasCogroupedOps:
             evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.CoGroupMap(
                 input=self._gd1._df._plan,
                 input_grouping_cols=self._gd1._grouping_cols,
@@ -446,6 +455,9 @@ class PandasCogroupedOps:
             ),
             session=self._gd1._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 6abf6303b7ca..f80f4509a7ce 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,7 +17,7 @@
 
 import unittest
 
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
 from pyspark.sql.utils import is_remote
 
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -30,6 +30,7 @@ from pyspark.testing.sqlutils import (
 
 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc
 
 if have_pandas:
     import pandas as pd
@@ -127,6 +128,148 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf1.schema, sdf1.schema)
         self.assertEqual(cdf1.collect(), sdf1.collect())
 
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_group_apply_in_pandas(self):
+        data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+        cdf = self.connect.createDataFrame(data, ("id", "v"))
+        sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+        def normalize(pdf):
+            v = pdf.v
+            return pdf.assign(v=(v - v.mean()) / v.std())
+
+        schema = StructType(
+            [
+                StructField("id", LongType(), True),
+                StructField("v", DoubleType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
+            self.assertEqual(cdf1._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
+
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_group_apply_in_arrow(self):
+        data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+        cdf = self.connect.createDataFrame(data, ("id", "v"))
+        sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+        def normalize(table):
+            v = table.column("v")
+            norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
+            return table.set_column(1, "v", norm)
+
+        schema = StructType(
+            [
+                StructField("id", LongType(), True),
+                StructField("v", DoubleType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.groupby("id").applyInArrow(normalize, schema)
+            self.assertEqual(cdf1._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf1 = sdf.groupby("id").applyInArrow(normalize, schema)
+
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_cogroup_apply_in_pandas(self):
+        data1 = [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)]
+        data2 = [(20000101, 1, "x"), (20000101, 2, "y")]
+
+        cdf1 = self.connect.createDataFrame(data1, ("time", "id", "v1"))
+        sdf1 = self.spark.createDataFrame(data1, ("time", "id", "v1"))
+        cdf2 = self.connect.createDataFrame(data2, ("time", "id", "v2"))
+        sdf2 = self.spark.createDataFrame(data2, ("time", "id", "v2"))
+
+        def asof_join(left, right):
+            return pd.merge_asof(left, right, on="time", by="id")
+
+        schema = StructType(
+            [
+                StructField("time", IntegerType(), True),
+                StructField("id", IntegerType(), True),
+                StructField("v1", DoubleType(), True),
+                StructField("v2", StringType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInPandas(asof_join, schema)
+            self.assertEqual(cdf3._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema)
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_cogroup_apply_in_arrow(self):
+        data1 = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+        data2 = [(1, "x"), (2, "y")]
+
+        cdf1 = self.connect.createDataFrame(data1, ("id", "v1"))
+        sdf1 = self.spark.createDataFrame(data1, ("id", "v1"))
+        cdf2 = self.connect.createDataFrame(data2, ("id", "v2"))
+        sdf2 = self.spark.createDataFrame(data2, ("id", "v2"))
+
+        def summarize(left, right):
+            return pa.Table.from_pydict(
+                {
+                    "left": [left.num_rows],
+                    "right": [right.num_rows],
+                }
+            )
+
+        schema = StructType(
+            [
+                StructField("left", LongType(), True),
+                StructField("right", LongType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInArrow(summarize, schema)
+            self.assertEqual(cdf3._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema)
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
 
 if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_connect_dataframe_property import *  # noqa: F401

