This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 658681f3eb5 [SPARK-40930][CONNECT] Support Collect() in Python client 658681f3eb5 is described below commit 658681f3eb5b8f3226ac8d3793e2c1a065351b6c Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Wed Nov 2 10:01:42 2022 +0900 [SPARK-40930][CONNECT] Support Collect() in Python client ### What changes were proposed in this pull request? Before this PR, calling `collect()` threw an exception recommending `toPandas()` instead. With this PR, `collect()` returns a list of PySpark `Row` objects. ### Why are the changes needed? Improve API coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38409 from amaliujia/python_support_collect. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/dataframe.py | 13 ++++++++++--- python/pyspark/sql/tests/connect/test_connect_basic.py | 10 +++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index c7107a7e79f..b9ddb0db300 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -34,7 +34,10 @@ from pyspark.sql.connect.column import ( Expression, LiteralExpression, ) -from pyspark.sql.types import StructType +from pyspark.sql.types import ( + StructType, + Row, +) if TYPE_CHECKING: from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString @@ -317,8 +320,12 @@ class DataFrame(object): return self._plan.print() return "" - def collect(self) -> None: - raise NotImplementedError("Please use toPandas().") + def collect(self) -> List[Row]: + pdf = self.toPandas() + if pdf is not None: + return list(pdf.apply(lambda row: Row(**row), axis=1)) + else: + return [] def toPandas(self) -> Optional["pandas.DataFrame"]: if self._plan is None: diff --git 
a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index e9a06f9c545..0d3fc76134e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -73,7 +73,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): # Setup Remote Spark Session cls.connect = RemoteSparkSession(user_id="test_user") df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) - # Since we might create multiple Spark sessions, we need to creata global temporary view + # Since we might create multiple Spark sessions, we need to create global temporary view # that is specifically maintained in the "global_temp" schema. df.write.saveAsTable(cls.tbl_name) @@ -89,6 +89,14 @@ class SparkConnectTests(SparkConnectSQLTestCase): # Check that the limit is applied self.assertEqual(len(data.index), 10) + def test_collect(self): + df = self.connect.read.table(self.tbl_name) + data = df.limit(10).collect() + self.assertEqual(len(data), 10) + # Check Row has schema column names. + self.assertTrue("name" in data[0]) + self.assertTrue("id" in data[0]) + def test_simple_udf(self): def conv_udf(x) -> str: return "Martin" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org