This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 34ac7de89711 [SPARK-48536][PYTHON][CONNECT] Cache user specified schema in applyInPandas and applyInArrow
34ac7de89711 is described below

commit 34ac7de897115caada7330aed32f03aca4796299
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jun 5 20:42:00 2024 +0800

    [SPARK-48536][PYTHON][CONNECT] Cache user specified schema in applyInPandas and applyInArrow

    ### What changes were proposed in this pull request?
    Cache user specified schema in applyInPandas and applyInArrow

    ### Why are the changes needed?
    to avoid extra RPCs

    ### Does this PR introduce _any_ user-facing change?
    no

    ### How was this patch tested?
    added tests

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #46877 from zhengruifeng/cache_schema_apply_in_x.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/group.py                |  20 ++-
 .../connect/test_connect_dataframe_property.py     | 145 ++++++++++++++++++++-
 2 files changed, 160 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 2a5bb5939a3f..85806b1a265b 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -301,7 +301,7 @@ class GroupedData:
             evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.GroupMap(
                 child=self._df._plan,
                 grouping_cols=self._grouping_cols,
@@ -310,6 +310,9 @@ class GroupedData:
             ),
             session=self._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
 
@@ -370,7 +373,7 @@ class GroupedData:
             evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.GroupMap(
                 child=self._df._plan,
                 grouping_cols=self._grouping_cols,
@@ -379,6 +382,9 @@ class GroupedData:
             ),
             session=self._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__
 
@@ -410,7 +416,7 @@ class PandasCogroupedOps:
             evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.CoGroupMap(
                 input=self._gd1._df._plan,
                 input_grouping_cols=self._gd1._grouping_cols,
@@ -420,6 +426,9 @@ class PandasCogroupedOps:
             ),
             session=self._gd1._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
 
@@ -436,7 +445,7 @@ class PandasCogroupedOps:
             evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         )
 
-        return DataFrame(
+        res = DataFrame(
             plan.CoGroupMap(
                 input=self._gd1._df._plan,
                 input_grouping_cols=self._gd1._grouping_cols,
@@ -446,6 +455,9 @@ class PandasCogroupedOps:
             ),
             session=self._gd1._df._session,
         )
+        if isinstance(schema, StructType):
+            res._cached_schema = schema
+        return res
 
     applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__

diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 6abf6303b7ca..f80f4509a7ce 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,7 +17,7 @@
 
 import unittest
 
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
 from pyspark.sql.utils import is_remote
 
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -30,6 +30,7 @@ from pyspark.testing.sqlutils import (
 
 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc
 
 if have_pandas:
     import pandas as pd
@@ -127,6 +128,148 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf1.schema, sdf1.schema)
         self.assertEqual(cdf1.collect(), sdf1.collect())
 
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_group_apply_in_pandas(self):
+        data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+        cdf = self.connect.createDataFrame(data, ("id", "v"))
+        sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+        def normalize(pdf):
+            v = pdf.v
+            return pdf.assign(v=(v - v.mean()) / v.std())
+
+        schema = StructType(
+            [
+                StructField("id", LongType(), True),
+                StructField("v", DoubleType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
+            self.assertEqual(cdf1._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
+
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_group_apply_in_arrow(self):
+        data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+        cdf = self.connect.createDataFrame(data, ("id", "v"))
+        sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+        def normalize(table):
+            v = table.column("v")
+            norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
+            return table.set_column(1, "v", norm)
+
+        schema = StructType(
+            [
+                StructField("id", LongType(), True),
+                StructField("v", DoubleType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.groupby("id").applyInArrow(normalize, schema)
+            self.assertEqual(cdf1._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf1 = sdf.groupby("id").applyInArrow(normalize, schema)
+
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_cogroup_apply_in_pandas(self):
+        data1 = [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)]
+        data2 = [(20000101, 1, "x"), (20000101, 2, "y")]
+
+        cdf1 = self.connect.createDataFrame(data1, ("time", "id", "v1"))
+        sdf1 = self.spark.createDataFrame(data1, ("time", "id", "v1"))
+        cdf2 = self.connect.createDataFrame(data2, ("time", "id", "v2"))
+        sdf2 = self.spark.createDataFrame(data2, ("time", "id", "v2"))
+
+        def asof_join(left, right):
+            return pd.merge_asof(left, right, on="time", by="id")
+
+        schema = StructType(
+            [
+                StructField("time", IntegerType(), True),
+                StructField("id", IntegerType(), True),
+                StructField("v1", DoubleType(), True),
+                StructField("v2", StringType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInPandas(asof_join, schema)
+            self.assertEqual(cdf3._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema)
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_cached_schema_cogroup_apply_in_arrow(self):
+        data1 = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+        data2 = [(1, "x"), (2, "y")]
+
+        cdf1 = self.connect.createDataFrame(data1, ("id", "v1"))
+        sdf1 = self.spark.createDataFrame(data1, ("id", "v1"))
+        cdf2 = self.connect.createDataFrame(data2, ("id", "v2"))
+        sdf2 = self.spark.createDataFrame(data2, ("id", "v2"))
+
+        def summarize(left, right):
+            return pa.Table.from_pydict(
+                {
+                    "left": [left.num_rows],
+                    "right": [right.num_rows],
+                }
+            )
+
+        schema = StructType(
+            [
+                StructField("left", LongType(), True),
+                StructField("right", LongType(), True),
+            ]
+        )
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInArrow(summarize, schema)
+            self.assertEqual(cdf3._cached_schema, schema)
+
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+            self.assertFalse(is_remote())
+            sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema)
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_dataframe_property import *  # noqa: F401
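For context, here is a minimal sketch (not part of the commit) of the round
trip this change avoids. It assumes a Spark Connect server reachable at
sc://localhost; the connection string, sample data, and normalize function
are illustrative only.

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    # Illustrative only: assumes a Spark Connect server at sc://localhost.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    schema = StructType(
        [
            StructField("id", LongType(), True),
            StructField("v", DoubleType(), True),
        ]
    )

    def normalize(pdf):
        v = pdf.v
        return pdf.assign(v=(v - v.mean()) / v.std())

    res = df.groupby("id").applyInPandas(normalize, schema)

    # Before this commit, reading res.schema sent an extra RPC to the server
    # even though the caller had just supplied the schema; with the
    # user-specified StructType cached on the client, the property is served
    # locally.
    assert res.schema == schema

The same applies to applyInArrow and to the cogrouped applyInPandas /
applyInArrow variants exercised by the new tests above.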