This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 93e0acbf7d9  [SPARK-43971][CONNECT][PYTHON] Support Python's createDataFrame in streaming manner
93e0acbf7d9 is described below

commit 93e0acbf7d9fcf3422860b2a5d39379bebf7bc43
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Sat Jun 10 01:25:04 2023 +0300

    [SPARK-43971][CONNECT][PYTHON] Support Python's createDataFrame in streaming manner

    ### What changes were proposed in this pull request?
    In this PR, I propose to transfer a local relation from **the Python connect client** to the server in a streaming manner when it exceeds the size defined by the SQL config `spark.sql.session.localRelationCacheThreshold`. The implementation is similar to https://github.com/apache/spark/pull/40827. In particular:
    1. The client applies the `sha256` function over **the proto form** of the local relation;
    2. It checks the presence of the relation at the server side by sending the relation hash to the server;
    3. If the server doesn't have the local relation, the client transfers it as an artifact with the name `cache/<sha256>`;
    4. Once the relation is already present at the server, or has just been transferred, the client transforms the logical plan by replacing the `LocalRelation` node with a `CachedLocalRelation` holding the hash (see the sketch after this message);
    5. The server, in turn, converts `CachedLocalRelation` back to `LocalRelation` by retrieving the relation body from its local cache.

    ### Why are the changes needed?
    To fix the failure when creating a large dataframe from a local collection:
    ```python
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
        status = StatusCode.RESOURCE_EXHAUSTED
        details = "Sent message larger than max (134218508 vs. 134217728)"
        debug_error_string = "UNKNOWN:Error received from peer localhost:50982 {grpc_message:"Sent message larger than max (134218508 vs. 134217728)", grpc_status:8, created_time:"2023-06-09T15:34:08.362797+03:00"}
    ```

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    By running the new test:
    ```
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_streaming_local_relation'
    ```

    Closes #41537 from MaxGekk/streaming-createDataFrame-python-4.

    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
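Before the patch itself, a minimal, hedged sketch of the client-side flow from steps 1-4, distilled from the `session.py` hunk in the diff below. The standalone function `plan_for_local_data` and its arguments are illustrative stand-ins for the `SparkSession.createDataFrame` internals (schema handling and error handling omitted), not code from the commit:

```python
import pyarrow as pa

from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.plan import CachedLocalRelation, LocalRelation, LogicalPlan

CACHE_THRESHOLD_CONF = "spark.sql.session.localRelationCacheThreshold"


def plan_for_local_data(client: SparkConnectClient, table: pa.Table) -> LogicalPlan:
    # Build the ordinary in-plan relation first.
    local_relation = LocalRelation(table)

    # Ask the server for the configured threshold; get_configs returns a
    # tuple of optional string values, one per requested key.
    cache_threshold = client.get_configs(CACHE_THRESHOLD_CONF)
    if cache_threshold[0] is not None and int(cache_threshold[0]) <= table.nbytes:
        # Large data: upload the serialized relation once as a cache artifact
        # and reference it by hash instead of inlining the Arrow batches.
        relation_hash = client.cache_artifact(local_relation.serialize(client))
        return CachedLocalRelation(relation_hash)

    # Small data keeps the old path: Arrow batches embedded in the plan proto.
    return local_relation
```

Because byte-identical relations hash to the same artifact, repeated `createDataFrame` calls with the same data reuse one server-side cache entry; the new test exercises exactly this by creating the same DataFrame twice.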
---
 python/pyspark/sql/connect/client/core.py    |  3 ++
 python/pyspark/sql/connect/plan.py           | 34 ++++++++++++++++++++++
 python/pyspark/sql/connect/session.py        | 26 +++++++++++++++--
 .../sql/tests/connect/test_connect_basic.py  | 19 ++++++++++++
 4 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 25e395356d5..7368521259a 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1257,6 +1257,9 @@ class SparkConnectClient(object):
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
         self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
 
+    def cache_artifact(self, blob: bytes) -> str:
+        return self._artifact_manager.cache_artifact(blob)
+
 
 class RetryState:
     """
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index fc8b37b102c..406f65080d1 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -363,6 +363,10 @@ class LocalRelation(LogicalPlan):
         plan.local_relation.schema = self._schema
         return plan
 
+    def serialize(self, session: "SparkConnectClient") -> bytes:
+        p = self.plan(session)
+        return bytes(p.local_relation.SerializeToString())
+
     def print(self, indent: int = 0) -> str:
         return f"{' ' * indent}<LocalRelation>\n"
 
@@ -374,6 +378,36 @@
     """
 
 
+class CachedLocalRelation(LogicalPlan):
+    """Creates a CachedLocalRelation plan object based on a hash of a LocalRelation."""
+
+    def __init__(self, hash: str) -> None:
+        super().__init__(None)
+
+        self._hash = hash
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        clr = plan.cached_local_relation
+
+        if session._user_id:
+            clr.userId = session._user_id
+        clr.sessionId = session._session_id
+        clr.hash = self._hash
+
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        return f"{' ' * indent}<CachedLocalRelation>\n"
+
+    def _repr_html_(self) -> str:
+        return """
+        <ul>
+            <li><b>CachedLocalRelation</b></li>
+        </ul>
+        """
+
+
 class ShowString(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool
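To make the `cache/<sha256>` naming from the description concrete, here is a tiny self-contained illustration. The real naming and transfer happen inside the client's artifact manager (behind `cache_artifact` above), so the helper below is hypothetical, and the hex encoding of the digest is an assumption of this sketch:

```python
import hashlib


def cache_artifact_name(serialized_relation: bytes) -> str:
    # Step 1 of the flow: sha256 over the serialized proto of the relation.
    # Hex encoding of the digest is an assumption of this sketch.
    return "cache/" + hashlib.sha256(serialized_relation).hexdigest()


# Byte-identical relations map to the same artifact name, which is what lets
# the server answer "already cached" instead of receiving a second upload.
assert cache_artifact_name(b"demo") == cache_artifact_name(b"demo")
```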
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 5877181963c..2b35ca3d7ea 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -51,7 +51,14 @@ from pyspark import SparkContext, SparkConf, __version__
 from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
 from pyspark.sql.connect.conf import RuntimeConf
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
+from pyspark.sql.connect.plan import (
+    SQL,
+    Range,
+    LocalRelation,
+    LogicalPlan,
+    CachedLocalRelation,
+    CachedRelation,
+)
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
@@ -466,10 +473,16 @@ class SparkSession:
         )
 
         if _schema is not None:
-            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
+            local_relation = LocalRelation(_table, schema=_schema.json())
         else:
-            df = DataFrame.withPlan(LocalRelation(_table), self)
+            local_relation = LocalRelation(_table)
+
+        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        plan: LogicalPlan = local_relation
+        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+            plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
+        df = DataFrame.withPlan(plan, self)
         if _cols is not None and len(_cols) > 0:
             df = df.toDF(*_cols)
         return df
@@ -643,6 +656,13 @@
 
     addArtifact = addArtifacts
 
+    def _cache_local_relation(self, local_relation: LocalRelation) -> str:
+        """
+        Cache the local relation at the server side if it has not been cached yet.
+        """
+        serialized = local_relation.serialize(self._client)
+        return self._client.cache_artifact(serialized)
+
     def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
         """
         Copy file from local to cloud storage file system.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a5139603919..18a7d8f19b4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,7 +19,9 @@ import array
 import datetime
 import os
 import unittest
+import random
 import shutil
+import string
 import tempfile
 
 from collections import defaultdict
@@ -649,6 +651,23 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(sdf.schema, cdf.schema)
         self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
+    def test_streaming_local_relation(self):
+        threshold_conf = "spark.sql.session.localRelationCacheThreshold"
+        old_threshold = self.connect.conf.get(threshold_conf)
+        threshold = 1024 * 1024
+        self.connect.conf.set(threshold_conf, threshold)
+        try:
+            suffix = "abcdef"
+            letters = string.ascii_lowercase
+            str = "".join(random.choice(letters) for i in range(threshold)) + suffix
+            data = [[0, str], [1, str]]
+            for i in range(0, 2):
+                cdf = self.connect.createDataFrame(data, ["a", "b"])
+                self.assert_eq(cdf.count(), len(data))
+                self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
+        finally:
+            self.connect.conf.set(threshold_conf, old_threshold)
+
     def test_with_atom_type(self):
         for data in [[(1), (2), (3)], [1, 2, 3]]:
             for schema in ["long", "int", "short"]:
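Finally, a hedged usage sketch modeled on the new test above: lowering the threshold makes even modest local data take the streaming path transparently. The `sc://localhost` address is a placeholder for a running Spark Connect server, and the comments state expectations rather than guarantees:

```python
import string

from pyspark.sql import SparkSession

# Placeholder address of a running Spark Connect server.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Lower the threshold so this example's data is cached server-side.
spark.conf.set("spark.sql.session.localRelationCacheThreshold", 1024 * 1024)

# Roughly 1 MiB of text per row pushes the Arrow table over the threshold.
payload = string.ascii_lowercase * (1024 * 1024 // 26 + 1)
df = spark.createDataFrame([(0, payload), (1, payload)], ["a", "b"])

# Should print 2; the relation should now travel as a cache artifact
# instead of an inline gRPC message that could exceed the size limit.
print(df.count())
```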