This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 908adca2b22  [SPARK-41005][COLLECT][FOLLOWUP] Remove JSON code path and use `RDD.collect` in Arrow code path
908adca2b22 is described below

commit 908adca2b229b05b2ae0dd31cbaaa1fdcde16290
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Nov 22 09:43:44 2022 +0900

    [SPARK-41005][COLLECT][FOLLOWUP] Remove JSON code path and use `RDD.collect` in Arrow code path

    ### What changes were proposed in this pull request?
    1. Remove the JSON code path.
    2. Use `RDD.collect` in the Arrow code path, since the existing tests were already broken in the Arrow code path.
    3. Re-enable `test_fill_na`.

    ### Why are the changes needed?
    The existing Arrow code path is still problematic: it fails and falls back to the JSON code path, which changes the output data types of `test_fill_na`.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Re-enabled the test and added a unit test.

    Closes #38706 from zhengruifeng/collect_disable_json.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     |  14 +-
 .../service/SparkConnectStreamHandler.scala        | 156 ++-------------------
 python/pyspark/sql/connect/client.py               |   5 -
 python/pyspark/sql/connect/proto/base_pb2.py       |  41 ++----
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  51 +------
 .../sql/tests/connect/test_connect_basic.py        |  55 +++++++-
 6 files changed, 82 insertions(+), 240 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 66e27187153..277da6b2431 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -139,11 +139,7 @@ message ExecutePlanRequest {
 message ExecutePlanResponse {
   string client_id = 1;
 
-  // Result type
-  oneof result_type {
-    ArrowBatch arrow_batch = 2;
-    JSONBatch json_batch = 3;
-  }
+  ArrowBatch arrow_batch = 2;
 
   // Metrics for the query execution. Typically, this field is only present in the last
   // batch of results and then represent the overall state of the query execution.
@@ -155,14 +151,6 @@ message ExecutePlanResponse {
     bytes data = 2;
   }
 
-  // Message type when the result is returned as JSON. This is essentially a bulk wrapper
-  // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
-  // of `{col -> row}`.
-  message JSONBatch {
-    int64 row_count = 1;
-    bytes data = 2;
-  }
-
   message Metrics {
 
     repeated MetricObject metrics = 1;
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 50ff08f997c..092bdd00dc1 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,12 +18,10 @@ package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
 
-import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
 import org.apache.spark.internal.Logging
@@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
 
 class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
     extends Logging {
@@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(session)
     val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
-    try {
-      processAsArrowBatches(request.getClientId, dataframe)
-    } catch {
-      case e: Exception =>
-        logWarning(e.getMessage)
-        processAsJsonBatches(request.getClientId, dataframe)
-    }
-  }
-
-  def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
-    // Only process up to 10MB of data.
-    val sb = new StringBuilder
-    var rowCount = 0
-    dataframe.toJSON
-      .collect()
-      .foreach(row => {
-
-        // There are a few cases to cover here.
-        // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE
-        //    -> send the current batch and reset.
-        // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE
-        //    -> append the row to the buffer.
-        // 3. The row in question is larger than the MAX_BATCH_SIZE
-        //    -> fail the query.
-
-        // Case 3. - Fail
-        if (row.size > MAX_BATCH_SIZE) {
-          throw SparkException.internalError(
-            s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}")
-        }
-
-        // Case 1 - FLush and send.
-        if (sb.size + row.size > MAX_BATCH_SIZE) {
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.JSONBatch
-            .newBuilder()
-            .setData(ByteString.copyFromUtf8(sb.toString()))
-            .setRowCount(rowCount)
-            .build()
-          response.setJsonBatch(batch)
-          responseObserver.onNext(response.build())
-          sb.clear()
-          sb.append(row)
-          rowCount = 1
-        } else {
-          // Case 2 - Append.
-          // Make sure to put the newline delimiters only between items and not at the end.
-          if (rowCount > 0) {
-            sb.append("\n")
-          }
-          sb.append(row)
-          rowCount += 1
-        }
-      })
-
-    // If the last batch is not empty, send out the data to the client.
-    if (sb.size > 0) {
-      val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-      val batch = proto.ExecutePlanResponse.JSONBatch
-        .newBuilder()
-        .setData(ByteString.copyFromUtf8(sb.toString()))
-        .setRowCount(rowCount)
-        .build()
-      response.setJsonBatch(batch)
-      responseObserver.onNext(response.build())
-    }
-
-    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
-    responseObserver.onCompleted()
+    processAsArrowBatches(request.getClientId, dataframe)
   }
 
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
@@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     var numSent = 0
     if (numPartitions > 0) {
-      type Batch = (Array[Byte], Long)
-
       val batches = rows.mapPartitionsInternal(
         SparkConnectStreamHandler
          .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-      val signal = new Object
-      val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
-      var error: Throwable = null
-
-      val processPartition = (iter: Iterator[Batch]) => iter.toArray
-
-      // This callback is executed by the DAGScheduler thread.
-      // After fetching a partition, it inserts the partition into the Map, and then
-      // wakes up the main thread.
-      val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
-        signal.synchronized {
-          partitions(partitionId) = partition
-          signal.notify()
-        }
-        ()
-      }
-
-      val future = spark.sparkContext.submitJob(
-        rdd = batches,
-        processPartition = processPartition,
-        partitions = Seq.range(0, numPartitions),
-        resultHandler = resultHandler,
-        resultFunc = () => ())
-
-      // Collect errors and propagate them to the main thread.
-      future.onComplete { result =>
-        result.failed.foreach { throwable =>
-          signal.synchronized {
-            error = throwable
-            signal.notify()
-          }
-        }
-      }(ThreadUtils.sameThread)
-
-      // The main thread will wait until 0-th partition is available,
-      // then send it to client and wait for the next partition.
-      // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
-      // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
-      // tasks not related to scheduling. This is particularly important if there are
-      // multiple users or clients running code at the same time.
-      var currentPartitionId = 0
-      while (currentPartitionId < numPartitions) {
-        val partition = signal.synchronized {
-          var result = partitions.remove(currentPartitionId)
-          while (result.isEmpty && error == null) {
-            signal.wait()
-            result = partitions.remove(currentPartitionId)
-          }
-          error match {
-            case NonFatal(e) =>
-              responseObserver.onError(error)
-              logError("Error while processing query.", e)
-              return
-            case fatal: Throwable => throw fatal
-            case null => result.get
-          }
-        }
-
-        partition.foreach { case (bytes, count) =>
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.ArrowBatch
-            .newBuilder()
-            .setRowCount(count)
-            .setData(ByteString.copyFrom(bytes))
-            .build()
-          response.setArrowBatch(batch)
-          responseObserver.onNext(response.build())
-          numSent += 1
-        }
-
-        currentPartitionId += 1
+      batches.collect().foreach { case (bytes, count) =>
+        val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
+        val batch = proto.ExecutePlanResponse.ArrowBatch
+          .newBuilder()
+          .setRowCount(count)
+          .setData(ByteString.copyFrom(bytes))
+          .build()
+        response.setArrowBatch(batch)
+        responseObserver.onNext(response.build())
+        numSent += 1
       }
     }
 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 5bdf01afc99..fdcf34b7a47 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -16,7 +16,6 @@
 #
 
-import io
 import logging
 import os
 import typing
@@ -446,13 +445,9 @@ class RemoteSparkSession(object):
         return AnalyzeResult.fromProto(resp)
 
     def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
-        import pandas as pd
-
         if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
             with pa.ipc.open_stream(b.arrow_batch.data) as rd:
                 return rd.read_pandas()
-        elif b.json_batch is not None and len(b.json_batch.data) > 0:
-            return pd.read_json(io.BytesIO(b.json_batch.data), lines=True)
         return None
 
     def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 7d9f98b243e..daa1c25cc8f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
 )
 
 
@@ -48,7 +48,6 @@ _ANALYZEPLANRESPONSE = DESCRIPTOR.message_types_by_name["AnalyzePlanResponse"]
 _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"]
 _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"]
 _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"]
-_EXECUTEPLANRESPONSE_JSONBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["JSONBatch"]
 _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"]
 _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
     "MetricObject"
 ]
@@ -139,15 +138,6 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
             # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ArrowBatch)
         },
     ),
-    "JSONBatch": _reflection.GeneratedProtocolMessageType(
-        "JSONBatch",
-        (_message.Message,),
-        {
-            "DESCRIPTOR": _EXECUTEPLANRESPONSE_JSONBATCH,
-            "__module__": "spark.connect.base_pb2"
-            # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.JSONBatch)
-        },
-    ),
     "Metrics": _reflection.GeneratedProtocolMessageType(
         "Metrics",
         (_message.Message,),
@@ -191,7 +181,6 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(ExecutePlanResponse)
 _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch)
-_sym_db.RegisterMessage(ExecutePlanResponse.JSONBatch)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
@@ -219,19 +208,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST._serialized_start = 986
     _EXECUTEPLANREQUEST._serialized_end = 1193
     _EXECUTEPLANRESPONSE._serialized_start = 1196
-    _EXECUTEPLANRESPONSE._serialized_end = 2137
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540
-    _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542
-    _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122
-    _SPARKCONNECTSERVICE._serialized_start = 2140
-    _SPARKCONNECTSERVICE._serialized_end = 2339
+    _EXECUTEPLANRESPONSE._serialized_end = 1979
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979
+    _SPARKCONNECTSERVICE._serialized_start = 1982
+    _SPARKCONNECTSERVICE._serialized_end = 2181
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 18b70de57a3..64bb51d4c0b 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -401,28 +401,6 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
         ) -> None: ...
 
-    class JSONBatch(google.protobuf.message.Message):
-        """Message type when the result is returned as JSON. This is essentially a bulk wrapper
-        for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
-        of `{col -> row}`.
-        """
-
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        ROW_COUNT_FIELD_NUMBER: builtins.int
-        DATA_FIELD_NUMBER: builtins.int
-        row_count: builtins.int
-        data: builtins.bytes
-        def __init__(
-            self,
-            *,
-            row_count: builtins.int = ...,
-            data: builtins.bytes = ...,
-        ) -> None: ...
-        def ClearField(
-            self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
-        ) -> None: ...
-
     class Metrics(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -530,14 +508,11 @@ class ExecutePlanResponse(google.protobuf.message.Message):
     CLIENT_ID_FIELD_NUMBER: builtins.int
     ARROW_BATCH_FIELD_NUMBER: builtins.int
-    JSON_BATCH_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
    client_id: builtins.str
     @property
     def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
     @property
-    def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ...
-    @property
     def metrics(self) -> global___ExecutePlanResponse.Metrics:
         """Metrics for the query execution. Typically, this field is only present in the last
         batch of results and then represent the overall state of the query execution.
@@ -547,39 +522,17 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         *,
         client_id: builtins.str = ...,
         arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ...,
-        json_batch: global___ExecutePlanResponse.JSONBatch | None = ...,
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal[
-            "arrow_batch",
-            b"arrow_batch",
-            "json_batch",
-            b"json_batch",
-            "metrics",
-            b"metrics",
-            "result_type",
-            b"result_type",
-        ],
+        field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "arrow_batch",
-            b"arrow_batch",
-            "client_id",
-            b"client_id",
-            "json_batch",
-            b"json_batch",
-            "metrics",
-            b"metrics",
-            "result_type",
-            b"result_type",
+            "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics"
         ],
     ) -> None: ...
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
-    ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ...
 
 global___ExecutePlanResponse = ExecutePlanResponse
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d3de94a379f..9e7a5f2f4a5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -221,7 +221,60 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         with self.assertRaises(_MultiThreadedRendezvous):
             self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
-    @unittest.skip("test_fill_na is flaky")
+    def test_to_pandas(self):
+        # SPARK-41005: Test to pandas
+        query = """
+            SELECT * FROM VALUES
+            (false, 1, NULL),
+            (false, NULL, float(2.0)),
+            (NULL, 3, float(3.0))
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (1, 1, NULL),
+            (2, NULL, float(2.0)),
+            (3, 3, float(3.0))
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (double(1.0), 1, "1"),
+            (NULL, NULL, NULL),
+            (double(2.0), 3, "3")
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (float(1.0), double(1.0), 1, "1"),
+            (float(2.0), double(2.0), 2, "2"),
+            (float(3.0), double(3.0), 3, "3")
+            AS tab(a, b, c, d)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
     def test_fill_na(self):
         # SPARK-41128: Test fill na
         query = """
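
For readers skimming the diff, the server-side change boils down to: encode each partition into self-describing batches on the executors, `collect()` the encoded bytes to the driver, and stream them to the client in partition order. The following is a minimal, hedged sketch of that collect-and-stream pattern outside of Spark Connect, not the actual handler code: `sendBatch` is a hypothetical stand-in for `responseObserver.onNext(...)`, and UTF-8 text stands in for the Arrow IPC encoding that `ArrowConverters` performs in the real code path. It assumes a local Spark installation on the classpath.

```scala
import org.apache.spark.sql.SparkSession

// Sketch of the collect-and-stream pattern adopted by this commit: batches are
// encoded on the executors, then the already-serialized bytes are collected to
// the driver and pushed to the client one batch at a time, in partition order.
object CollectAndStreamSketch {

  // Hypothetical stand-in for responseObserver.onNext(ExecutePlanResponse).
  def sendBatch(bytes: Array[Byte], rowCount: Long): Unit =
    println(s"sending batch: $rowCount rows, ${bytes.length} bytes")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value")

    // Encode each partition into one (payload, rowCount) pair on the executors.
    // Plain newline-delimited JSON is used here instead of Arrow IPC to keep the
    // sketch dependency-free; the control flow is what matters.
    val batches = df.toJSON.rdd.mapPartitions { it =>
      val rows = it.toArray
      if (rows.isEmpty) Iterator.empty
      else Iterator.single((rows.mkString("\n").getBytes("UTF-8"), rows.length.toLong))
    }

    // RDD.collect() returns results in partition order, so the driver can simply
    // iterate and stream each batch without custom scheduler callbacks.
    batches.collect().foreach { case (bytes, count) => sendBatch(bytes, count) }

    spark.stop()
  }
}
```

The ordering guarantee of `collect()` (results come back grouped by partition, in partition order) is what lets the handler drop the previous `submitJob`/`resultHandler` machinery while still delivering batches to the client deterministically.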