This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5fbf7e2e22c  [SPARK-41114][CONNECT][PYTHON][FOLLOW-UP] Python Client support for local data
5fbf7e2e22c is described below

commit 5fbf7e2e22c92f6a506e88ef6d5b5d5fea2447ea
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Mon Nov 28 09:41:03 2022 -0400

    [SPARK-41114][CONNECT][PYTHON][FOLLOW-UP] Python Client support for local data

    ### What changes were proposed in this pull request?
    Since the Spark Connect server now supports reading local data from the client, this patch implements the necessary changes in the Python client to support creating a DataFrame from a local Pandas DataFrame.

    ```
    import pandas
    pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    df = spark.createDataFrame(pdf)
    rows = df.filter(df.a == lit(3)).collect()
    self.assertTrue(len(rows) == 1)
    self.assertEqual(rows[0][0], 3)
    self.assertEqual(rows[0][1], "c")
    ```

    ### Why are the changes needed?
    Compatibility

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    UT

    Closes #38803 from grundprinzip/SPARK-41114.

    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala |  3 ++
 python/pyspark/sql/connect/plan.py                | 32 +++++++++++++++++++++-
 python/pyspark/sql/connect/session.py             | 30 ++++++++++++++++++++
 .../sql/tests/connect/test_connect_basic.py       | 14 ++++++++++
 4 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 40d4ecc7556..fa5a0068c68 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -278,6 +278,9 @@ class SparkConnectPlanner(session: SparkSession) {
     val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
       Iterator(rel.getData.toByteArray),
       TaskContext.get())
+    if (structType == null) {
+      throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+    }
     val attributes = structType.toAttributes
     val proj = UnsafeProjection.create(attributes, attributes)
     new logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
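A note on the guard just added above: the `local_relation.data` payload built by the Python client (see the `plan.py` diff below) is an Arrow IPC stream, which carries its schema as the first message of the stream. When no schema can be recovered from the payload, the planner has nothing to build attributes from, and that is exactly what the new `InvalidPlanInput` check reports. Below is a minimal sketch of the round-trip in plain pandas/pyarrow, independent of Spark; the variable names are illustrative only, not part of the patch:

```
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

# Client side: pandas -> Arrow table -> IPC stream bytes.
# This mirrors what LocalRelation.plan does in plan.py below.
sink = pa.BufferOutputStream()
table = pa.Table.from_pandas(pdf)
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches():
        writer.write_batch(batch)
data = sink.getvalue().to_pybytes()

# Server side, conceptually: the schema travels inside the stream,
# so nothing besides the raw bytes needs to be transferred.
restored = pa.ipc.open_stream(data).read_all()
assert restored.schema == table.schema
assert restored.to_pandas().equals(pdf)
```

Conversely, a payload without a readable schema message (for example, empty bytes passed to `pa.ipc.open_stream`) raises on the pyarrow side; the `structType == null` case in the Scala converter is the server-side counterpart of that failure.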
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7a57168fa73..805628cfe5b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -25,7 +25,8 @@ from typing import (
     TYPE_CHECKING,
     Mapping,
 )
-
+import pandas
+import pyarrow as pa
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import (
     Column,
@@ -177,6 +178,35 @@ class Read(LogicalPlan):
         """
 
 
+class LocalRelation(LogicalPlan):
+    """Creates a LocalRelation plan object based on a Pandas DataFrame."""
+
+    def __init__(self, pdf: "pandas.DataFrame") -> None:
+        super().__init__(None)
+        self._pdf = pdf
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        sink = pa.BufferOutputStream()
+        table = pa.Table.from_pandas(self._pdf)
+        with pa.ipc.new_stream(sink, table.schema) as writer:
+            for b in table.to_batches():
+                writer.write_batch(b)
+
+        plan = proto.Relation()
+        plan.local_relation.data = sink.getvalue().to_pybytes()
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        return f"{' ' * indent}<LocalRelation>\n"
+
+    def _repr_html_(self) -> str:
+        return """
+        <ul>
+            <li>LocalRelation</li>
+        </ul>
+        """
+
+
 class ShowString(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], numRows: int, truncate: int, vertical: bool
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 92f58140eac..c9b76cf47f9 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -17,6 +17,7 @@
 from threading import RLock
 from typing import Optional, Any, Union, Dict, cast, overload
 
+import pandas as pd
 import pyspark.sql.types
 from pyspark.sql.connect.client import SparkConnectClient
 
@@ -24,6 +25,7 @@
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.utils import to_str
+from . import plan
 from ._typing import OptionalPrimitiveType
 
@@ -205,6 +207,34 @@ class SparkSession(object):
         # Create the reader
         self.read = DataFrameReader(self)
 
+    def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
+        """
+        Creates a :class:`DataFrame` from a :class:`pandas.DataFrame`.
+
+        .. versionadded:: 3.4.0
+
+
+        Parameters
+        ----------
+        data : :class:`pandas.DataFrame`
+
+        Returns
+        -------
+        :class:`DataFrame`
+
+        Examples
+        --------
+        >>> import pandas
+        >>> pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
+        >>> self.connect.createDataFrame(pdf).collect()
+        [Row(a=1, b='a'), Row(a=2, b='b'), Row(a=3, b='c')]
+
+        """
+        assert data is not None
+        if len(data) == 0:
+            raise ValueError("Input data cannot be empty")
+        return DataFrame.withPlan(plan.LocalRelation(data), self)
+
     @property
     def client(self) -> "SparkConnectClient":
         """
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0b07a8328a1..97ba34d8269 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -131,6 +131,20 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         result = df.select(u(df.id)).toPandas()
         self.assertIsNotNone(result)
 
+    def test_with_local_data(self):
+        """SPARK-41114: Test creating a dataframe using local data"""
+        pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
+        df = self.connect.createDataFrame(pdf)
+        rows = df.filter(df.a == lit(3)).collect()
+        self.assertTrue(len(rows) == 1)
+        self.assertEqual(rows[0][0], 3)
+        self.assertEqual(rows[0][1], "c")
+
+        # Check correct behavior for empty DataFrame
+        pdf = pandas.DataFrame({"a": []})
+        with self.assertRaises(ValueError):
+            self.connect.createDataFrame(pdf)
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()
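For completeness, here is a sketch of poking at the new plan node without a running server. It leans on one detail visible in the diff: this `LocalRelation.plan` implementation never uses its client argument, so passing `None` is enough for offline inspection. This is an illustrative shortcut, not a supported usage pattern; the public entry point is `createDataFrame`, as exercised in the test above.

```
import pandas
import pyarrow as pa

from pyspark.sql.connect.plan import LocalRelation

pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
node = LocalRelation(pdf)

# Render the plan with the new print() method.
print(node.print())  # <LocalRelation>

# Build the proto.Relation message the client would send; the client
# argument is unused by this particular plan node, hence None.
relation = node.plan(None)

# Decode the Arrow IPC payload directly to confirm the frame survives.
table = pa.ipc.open_stream(relation.local_relation.data).read_all()
print(table.to_pandas())
```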