This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0b3d9544c93 [SPARK-40836][CONNECT] AnalyzeResult should use struct for schema 0b3d9544c93 is described below commit 0b3d9544c934c0c21609cd2c1a08687333c7e0ca Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Tue Oct 25 15:31:17 2022 +0800 [SPARK-40836][CONNECT] AnalyzeResult should use struct for schema ### What changes were proposed in this pull request? This PR replaces the column names and column types with a schema (which is a struct). ### Why are the changes needed? Before this PR, AnalyzeResult separates column names and column types. However, these two can be combined to form a schema which is a struct. This PR will simplify that proto message. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38301 from amaliujia/return_schema_use_struct. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../src/main/protobuf/spark/connect/base.proto | 6 +-- .../connect/planner/DataTypeProtoConverter.scala | 19 ++++++- .../sql/connect/service/SparkConnectService.scala | 47 +++++++++--------- .../connect/planner/SparkConnectServiceSuite.scala | 58 ++++++++++++++++++++++ python/pyspark/sql/connect/client.py | 47 ++++++++++++++++-- python/pyspark/sql/connect/dataframe.py | 26 +++++++++- python/pyspark/sql/connect/proto/base_pb2.py | 51 +++++++++---------- python/pyspark/sql/connect/proto/base_pb2.pyi | 27 +++------- .../sql/tests/connect/test_connect_basic.py | 10 ++++ 9 files changed, 212 insertions(+), 79 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index dff1734335e..b376515bf1a 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -22,6 +22,7 @@ package spark.connect; import "google/protobuf/any.proto"; import 
"spark/connect/commands.proto"; import "spark/connect/relations.proto"; +import "spark/connect/types.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -116,11 +117,10 @@ message Response { // reason about the performance. message AnalyzeResponse { string client_id = 1; - repeated string column_names = 2; - repeated string column_types = 3; + DataType schema = 2; // The extended explain string as produced by Spark. - string explain_string = 4; + string explain_string = 3; } // Main interface for the SparkConnect service. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index da3adce43ba..0ee90b5e8fb 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._ import org.apache.spark.connect.proto import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType} /** * This object offers methods to convert to/from connect proto to catalyst types. 
@@ -50,11 +50,28 @@ object DataTypeProtoConverter { proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build() case StringType => proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build() + case LongType => + proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build() + case struct: StructType => + toConnectProtoStructType(struct) case _ => throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } } + def toConnectProtoStructType(schema: StructType): proto.DataType = { + val struct = proto.DataType.Struct.newBuilder() + for (structField <- schema.fields) { + struct.addFields( + proto.DataType.StructField + .newBuilder() + .setName(structField.name) + .setType(toConnectProtoType(structField.dataType)) + .setNullable(structField.nullable)) + } + proto.DataType.newBuilder().setStruct(struct).build() + } + def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = { mode match { case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 20776a29eda..5841017e5bb 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder import io.grpc.{Server, Status} @@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT -import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} import org.apache.spark.sql.execution.ExtendedMode /** @@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean) request: Request, responseObserver: StreamObserver[AnalyzeResponse]): Unit = { try { + if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { + responseObserver.onError( + new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported for analysis.")) + } val session = SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session - - val logicalPlan = request.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.ROOT => - new SparkConnectPlanner(request.getPlan.getRoot, session).transform() - case _ => - responseObserver.onError( - new UnsupportedOperationException( - s"${request.getPlan.getOpTypeCase} not supported for analysis.")) - return - } - val ds = Dataset.ofRows(session, logicalPlan) - val explainString = ds.queryExecution.explainString(ExtendedMode) - - val resp = proto.AnalyzeResponse - .newBuilder() - .setExplainString(explainString) - .setClientId(request.getClientId) - - resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava) - resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava) - responseObserver.onNext(resp.build()) + val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session) + response.setClientId(request.getClientId) + responseObserver.onNext(response.build()) responseObserver.onCompleted() } catch { case e: Throwable => @@ -120,6 +105,20 @@ class SparkConnectService(debug: Boolean) Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) } } + + def handleAnalyzePlanRequest( + relation: proto.Relation, + session: SparkSession): proto.AnalyzeResponse.Builder = { + val 
logicalPlan = new SparkConnectPlanner(relation, session).transform() + + val ds = Dataset.ofRows(session, logicalPlan) + val explainString = ds.queryExecution.explainString(ExtendedMode) + + val response = proto.AnalyzeResponse + .newBuilder() + .setExplainString(explainString) + response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema)) + } } /** diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala new file mode 100644 index 00000000000..4be8d1705b9 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.planner + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Testing Connect Service implementation. 
+ */ +class SparkConnectServiceSuite extends SharedSparkSession { + + test("Test schema in analyze response") { + withTable("test") { + spark.sql(""" + | CREATE TABLE test (col1 INT, col2 STRING) + | USING parquet + |""".stripMargin) + + val instance = new SparkConnectService(false) + val relation = proto.Relation + .newBuilder() + .setRead( + proto.Read + .newBuilder() + .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build()) + .build()) + .build() + + val response = instance.handleAnalyzePlanRequest(relation, spark) + + assert(response.getSchema.hasStruct) + val schema = response.getSchema.getStruct + assert(schema.getFieldsCount == 2) + assert( + schema.getFields(0).getName == "col1" + && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32) + assert( + schema.getFields(1).getName == "col2" + && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING) + } + } +} diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 0ae075521c6..f4b6d2ec302 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -33,6 +33,7 @@ from pyspark import cloudpickle from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.plan import SQL +from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType from typing import Optional, Any, Union @@ -91,14 +92,13 @@ class PlanMetrics: class AnalyzeResult: - def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str): - self.cols = cols - self.types = types + def __init__(self, schema: pb2.DataType, explain: str): + self.schema = schema self.explain_string = explain @classmethod def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": - return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string) + return AnalyzeResult(pb.schema, pb.explain_string) class 
RemoteSparkSession(object): @@ -151,7 +151,44 @@ class RemoteSparkSession(object): req.plan.CopyFrom(plan) return self._execute_and_fetch(req) - def analyze(self, plan: pb2.Plan) -> AnalyzeResult: + def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: + if schema.HasField("struct"): + structFields = [] + for proto_field in schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + elif schema.HasField("i64"): + return LongType() + elif schema.HasField("string"): + return StringType() + else: + raise Exception("Only support long, string, struct conversion") + + def schema(self, plan: pb2.Plan) -> StructType: + proto_schema = self._analyze(plan).schema + # Server side should populate the struct field which is the schema. + assert proto_schema.HasField("struct") + structFields = [] + for proto_field in proto_schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + + def explain_string(self, plan: pb2.Plan) -> str: + return self._analyze(plan).explain_string + + def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: req = pb2.Request() req.user_context.user_id = self._user_id req.plan.CopyFrom(plan) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 2b7e3d52039..bf9ed83615b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -34,6 +34,7 @@ from pyspark.sql.connect.column import ( Expression, LiteralExpression, ) +from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString @@ -96,7 +97,7 @@ class DataFrame(object): of the DataFrame with the changes applied. 
""" - def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None): + def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None @@ -315,11 +316,32 @@ class DataFrame(object): query = self._plan.to_proto(self._session) return self._session._to_pandas(query) + def schema(self) -> StructType: + """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`StructType` + """ + if self._schema is None: + if self._plan is not None: + query = self._plan.to_proto(self._session) + if self._session is None: + raise Exception("Cannot analyze without RemoteSparkSession.") + self._schema = self._session.schema(query) + return self._schema + else: + raise Exception("Empty plan.") + else: + return self._schema + def explain(self) -> str: if self._plan is not None: query = self._plan.to_proto(self._session) if self._session is None: raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.analyze(query).explain_string + return self._session.explain_string(query) else: return "" diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 408872dbb66..eb9ecc9157f 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -31,10 +31,11 @@ _sym_db = _symbol_database.Default() from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.co [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_contex [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -45,28 +46,28 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" - _PLAN._serialized_start = 131 - _PLAN._serialized_end = 247 - _REQUEST._serialized_start = 250 - _REQUEST._serialized_end = 524 - _REQUEST_USERCONTEXT._serialized_start = 402 - _REQUEST_USERCONTEXT._serialized_end = 524 - _RESPONSE._serialized_start = 527 - _RESPONSE._serialized_end = 1495 - _RESPONSE_ARROWBATCH._serialized_start = 756 - _RESPONSE_ARROWBATCH._serialized_end = 931 - _RESPONSE_JSONBATCH._serialized_start = 933 - _RESPONSE_JSONBATCH._serialized_end = 993 - _RESPONSE_METRICS._serialized_start = 996 - _RESPONSE_METRICS._serialized_end = 1480 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080 - 
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480 - _ANALYZERESPONSE._serialized_start = 1498 - _ANALYZERESPONSE._serialized_end = 1653 - _SPARKCONNECTSERVICE._serialized_start = 1656 - _SPARKCONNECTSERVICE._serialized_end = 1818 + _PLAN._serialized_start = 158 + _PLAN._serialized_end = 274 + _REQUEST._serialized_start = 277 + _REQUEST._serialized_end = 551 + _REQUEST_USERCONTEXT._serialized_start = 429 + _REQUEST_USERCONTEXT._serialized_end = 551 + _RESPONSE._serialized_start = 554 + _RESPONSE._serialized_end = 1522 + _RESPONSE_ARROWBATCH._serialized_start = 783 + _RESPONSE_ARROWBATCH._serialized_end = 958 + _RESPONSE_JSONBATCH._serialized_start = 960 + _RESPONSE_JSONBATCH._serialized_end = 1020 + _RESPONSE_METRICS._serialized_start = 1023 + _RESPONSE_METRICS._serialized_end = 1507 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507 + _ANALYZERESPONSE._serialized_start = 1525 + _ANALYZERESPONSE._serialized_end = 1659 + _SPARKCONNECTSERVICE._serialized_start = 1662 + _SPARKCONNECTSERVICE._serialized_end = 1824 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index bb3a6578cf7..5ffd7701b44 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -41,6 +41,7 @@ import google.protobuf.internal.containers import 
google.protobuf.message import pyspark.sql.connect.proto.commands_pb2 import pyspark.sql.connect.proto.relations_pb2 +import pyspark.sql.connect.proto.types_pb2 import sys if sys.version_info >= (3, 8): @@ -401,39 +402,27 @@ class AnalyzeResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor CLIENT_ID_FIELD_NUMBER: builtins.int - COLUMN_NAMES_FIELD_NUMBER: builtins.int - COLUMN_TYPES_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int EXPLAIN_STRING_FIELD_NUMBER: builtins.int client_id: builtins.str @property - def column_names( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - @property - def column_types( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... explain_string: builtins.str """The extended explain string as produced by Spark.""" def __init__( self, *, client_id: builtins.str = ..., - column_names: collections.abc.Iterable[builtins.str] | None = ..., - column_types: collections.abc.Iterable[builtins.str] | None = ..., + schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., explain_string: builtins.str = ..., ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["schema", b"schema"] + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "client_id", - b"client_id", - "column_names", - b"column_names", - "column_types", - b"column_types", - "explain_string", - b"explain_string", + "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema" ], ) -> None: ... 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f6988a1d120..459b05cc37a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -22,6 +22,7 @@ import tempfile import pandas from pyspark.sql import SparkSession, Row +from pyspark.sql.types import StructType, StructField, LongType, StringType from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit @@ -97,6 +98,15 @@ class SparkConnectTests(SparkConnectSQLTestCase): result = df.explain() self.assertGreater(len(result), 0) + def test_schema(self): + schema = self.connect.read.table(self.tbl_name).schema() + self.assertEqual( + StructType( + [StructField("id", LongType(), True), StructField("name", StringType(), True)] + ), + schema, + ) + def test_simple_binary_expressions(self): """Test complex expression""" df = self.connect.read.table(self.tbl_name) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org