This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b3d9544c93 [SPARK-40836][CONNECT] AnalyzeResult should use struct for schema
0b3d9544c93 is described below

commit 0b3d9544c934c0c21609cd2c1a08687333c7e0ca
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Oct 25 15:31:17 2022 +0800

    [SPARK-40836][CONNECT] AnalyzeResult should use struct for schema
    
    ### What changes were proposed in this pull request?
    
    This PR replaces column names and column types with a schema (which is a struct).
    
    ### Why are the changes needed?
    
    Before this PR, AnalyzeResult kept column names and column types as separate fields. However, these two can be combined to form a schema, which is a struct. This PR simplifies that proto message accordingly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT
    
    Closes #38301 from amaliujia/return_schema_use_struct.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/protobuf/spark/connect/base.proto     |  6 +--
 .../connect/planner/DataTypeProtoConverter.scala   | 19 ++++++-
 .../sql/connect/service/SparkConnectService.scala  | 47 +++++++++---------
 .../connect/planner/SparkConnectServiceSuite.scala | 58 ++++++++++++++++++++++
 python/pyspark/sql/connect/client.py               | 47 ++++++++++++++++--
 python/pyspark/sql/connect/dataframe.py            | 26 +++++++++-
 python/pyspark/sql/connect/proto/base_pb2.py       | 51 +++++++++----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 27 +++-------
 .../sql/tests/connect/test_connect_basic.py        | 10 ++++
 9 files changed, 212 insertions(+), 79 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index dff1734335e..b376515bf1a 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -22,6 +22,7 @@ package spark.connect;
 import "google/protobuf/any.proto";
 import "spark/connect/commands.proto";
 import "spark/connect/relations.proto";
+import "spark/connect/types.proto";
 
 option java_multiple_files = true;
 option java_package = "org.apache.spark.connect.proto";
@@ -116,11 +117,10 @@ message Response {
 // reason about the performance.
 message AnalyzeResponse {
   string client_id = 1;
-  repeated string column_names = 2;
-  repeated string column_types = 3;
+  DataType schema = 2;
 
   // The extended explain string as produced by Spark.
-  string explain_string = 4;
+  string explain_string = 3;
 }
 
 // Main interface for the SparkConnect service.
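
The AnalyzeResponse change above is the core of this commit: the two parallel string lists are collapsed into a single DataType whose struct variant describes one field per column. A rough, non-authoritative client-side illustration of the difference (field names are taken from the proto diff above; the oneof name "kind" is inferred from DataType.KindCase used in the Scala test further down):

    # `resp` is assumed to be a deserialized spark.connect.AnalyzeResponse.

    # Before this commit: two parallel lists the client had to zip together.
    #   names = resp.column_names   # e.g. ["col1", "col2"]
    #   types = resp.column_types   # e.g. ["INT", "STRING"]

    # After this commit: one DataType whose `struct` field carries name,
    # type and nullability for every column.
    for field in resp.schema.struct.fields:
        print(field.name, field.type.WhichOneof("kind"), field.nullable)
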
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index da3adce43ba..0ee90b5e8fb 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
 
 /**
  * This object offers methods to convert to/from connect proto to catalyst types.
@@ -50,11 +50,28 @@ object DataTypeProtoConverter {
         proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
       case StringType =>
         proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+      case LongType =>
+        proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
+      case struct: StructType =>
+        toConnectProtoStructType(struct)
       case _ =>
         throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
     }
   }
 
+  def toConnectProtoStructType(schema: StructType): proto.DataType = {
+    val struct = proto.DataType.Struct.newBuilder()
+    for (structField <- schema.fields) {
+      struct.addFields(
+        proto.DataType.StructField
+          .newBuilder()
+          .setName(structField.name)
+          .setType(toConnectProtoType(structField.dataType))
+          .setNullable(structField.nullable))
+    }
+    proto.DataType.newBuilder().setStruct(struct).build()
+  }
+
   def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
     mode match {
       case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 20776a29eda..5841017e5bb 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.TimeUnit
 
-import scala.collection.JavaConverters._
-
 import com.google.common.base.Ticker
 import com.google.common.cache.CacheBuilder
 import io.grpc.{Server, Status}
@@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
-import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
 import org.apache.spark.sql.execution.ExtendedMode
 
 /**
@@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean)
       request: Request,
       responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
     try {
+      if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) {
+        responseObserver.onError(
+          new UnsupportedOperationException(
+            s"${request.getPlan.getOpTypeCase} not supported for analysis."))
+      }
       val session =
         SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
-
-      val logicalPlan = request.getPlan.getOpTypeCase match {
-        case proto.Plan.OpTypeCase.ROOT =>
-          new SparkConnectPlanner(request.getPlan.getRoot, session).transform()
-        case _ =>
-          responseObserver.onError(
-            new UnsupportedOperationException(
-              s"${request.getPlan.getOpTypeCase} not supported for analysis."))
-          return
-      }
-      val ds = Dataset.ofRows(session, logicalPlan)
-      val explainString = ds.queryExecution.explainString(ExtendedMode)
-
-      val resp = proto.AnalyzeResponse
-        .newBuilder()
-        .setExplainString(explainString)
-        .setClientId(request.getClientId)
-
-      resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava)
-      resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava)
-      responseObserver.onNext(resp.build())
+      val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session)
+      response.setClientId(request.getClientId)
+      responseObserver.onNext(response.build())
       responseObserver.onCompleted()
     } catch {
       case e: Throwable =>
@@ -120,6 +105,20 @@ class SparkConnectService(debug: Boolean)
             Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
     }
   }
+
+  def handleAnalyzePlanRequest(
+      relation: proto.Relation,
+      session: SparkSession): proto.AnalyzeResponse.Builder = {
+    val logicalPlan = new SparkConnectPlanner(relation, session).transform()
+
+    val ds = Dataset.ofRows(session, logicalPlan)
+    val explainString = ds.queryExecution.explainString(ExtendedMode)
+
+    val response = proto.AnalyzeResponse
+      .newBuilder()
+      .setExplainString(explainString)
+    response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
+  }
 }
 
 /**
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
new file mode 100644
index 00000000000..4be8d1705b9
--- /dev/null
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Testing Connect Service implementation.
+ */
+class SparkConnectServiceSuite extends SharedSparkSession {
+
+  test("Test schema in analyze response") {
+    withTable("test") {
+      spark.sql("""
+          | CREATE TABLE test (col1 INT, col2 STRING)
+          | USING parquet
+          |""".stripMargin)
+
+      val instance = new SparkConnectService(false)
+      val relation = proto.Relation
+        .newBuilder()
+        .setRead(
+          proto.Read
+            .newBuilder()
+            .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build())
+            .build())
+        .build()
+
+      val response = instance.handleAnalyzePlanRequest(relation, spark)
+
+      assert(response.getSchema.hasStruct)
+      val schema = response.getSchema.getStruct
+      assert(schema.getFieldsCount == 2)
+      assert(
+        schema.getFields(0).getName == "col1"
+          && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32)
+      assert(
+        schema.getFields(1).getName == "col2"
+          && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
+    }
+  }
+}
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 0ae075521c6..f4b6d2ec302 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -33,6 +33,7 @@ from pyspark import cloudpickle
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.plan import SQL
+from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
 
 from typing import Optional, Any, Union
 
@@ -91,14 +92,13 @@ class PlanMetrics:
 
 
 class AnalyzeResult:
-    def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str):
-        self.cols = cols
-        self.types = types
+    def __init__(self, schema: pb2.DataType, explain: str):
+        self.schema = schema
         self.explain_string = explain
 
     @classmethod
     def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
-        return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string)
+        return AnalyzeResult(pb.schema, pb.explain_string)
 
 
 class RemoteSparkSession(object):
@@ -151,7 +151,44 @@ class RemoteSparkSession(object):
         req.plan.CopyFrom(plan)
         return self._execute_and_fetch(req)
 
-    def analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+    def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
+        if schema.HasField("struct"):
+            structFields = []
+            for proto_field in schema.struct.fields:
+                structFields.append(
+                    StructField(
+                        proto_field.name,
+                        self._proto_schema_to_pyspark_schema(proto_field.type),
+                        proto_field.nullable,
+                    )
+                )
+            return StructType(structFields)
+        elif schema.HasField("i64"):
+            return LongType()
+        elif schema.HasField("string"):
+            return StringType()
+        else:
+            raise Exception("Only support long, string, struct conversion")
+
+    def schema(self, plan: pb2.Plan) -> StructType:
+        proto_schema = self._analyze(plan).schema
+        # Server side should populate the struct field which is the schema.
+        assert proto_schema.HasField("struct")
+        structFields = []
+        for proto_field in proto_schema.struct.fields:
+            structFields.append(
+                StructField(
+                    proto_field.name,
+                    self._proto_schema_to_pyspark_schema(proto_field.type),
+                    proto_field.nullable,
+                )
+            )
+        return StructType(structFields)
+
+    def explain_string(self, plan: pb2.Plan) -> str:
+        return self._analyze(plan).explain_string
+
+    def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
         req.plan.CopyFrom(plan)
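
The client-side entry points added above, schema(plan) and explain_string(plan), both ride on the private _analyze RPC; schema() additionally converts the returned DataType proto into a pyspark.sql.types.StructType via _proto_schema_to_pyspark_schema. A hedged usage sketch, assuming an already connected RemoteSparkSession named `session` and a table named "test" (everything else comes from the diff above):

    # Sketch only: `session` is an assumed, already constructed RemoteSparkSession.
    df = session.read.table("test")               # builds a plan; nothing executes yet
    plan = df._plan.to_proto(session)             # same conversion DataFrame.explain() uses

    struct_type = session.schema(plan)            # AnalyzeResponse.schema -> StructType
    explain_text = session.explain_string(plan)   # extended explain string from the server

    print(struct_type.fieldNames())
    print(explain_text)
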
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2b7e3d52039..bf9ed83615b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -34,6 +34,7 @@ from pyspark.sql.connect.column import (
     Expression,
     LiteralExpression,
 )
+from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
@@ -96,7 +97,7 @@ class DataFrame(object):
     of the DataFrame with the changes applied.
     """
 
-    def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None):
+    def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None):
         """Creates a new data frame"""
         self._schema = schema
         self._plan: Optional[plan.LogicalPlan] = None
@@ -315,11 +316,32 @@ class DataFrame(object):
         query = self._plan.to_proto(self._session)
         return self._session._to_pandas(query)
 
+    def schema(self) -> StructType:
+        """Returns the schema of this :class:`DataFrame` as a 
:class:`pyspark.sql.types.StructType`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`StructType`
+        """
+        if self._schema is None:
+            if self._plan is not None:
+                query = self._plan.to_proto(self._session)
+                if self._session is None:
+                    raise Exception("Cannot analyze without 
RemoteSparkSession.")
+                self._schema = self._session.schema(query)
+                return self._schema
+            else:
+                raise Exception("Empty plan.")
+        else:
+            return self._schema
+
     def explain(self) -> str:
         if self._plan is not None:
             query = self._plan.to_proto(self._session)
             if self._session is None:
                 raise Exception("Cannot analyze without RemoteSparkSession.")
-            return self._session.analyze(query).explain_string
+            return self._session.explain_string(query)
         else:
             return ""
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 408872dbb66..eb9ecc9157f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -31,10 +31,11 @@ _sym_db = _symbol_database.Default()
 from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
+from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.co [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_contex [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,28 +46,28 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
     _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
-    _PLAN._serialized_start = 131
-    _PLAN._serialized_end = 247
-    _REQUEST._serialized_start = 250
-    _REQUEST._serialized_end = 524
-    _REQUEST_USERCONTEXT._serialized_start = 402
-    _REQUEST_USERCONTEXT._serialized_end = 524
-    _RESPONSE._serialized_start = 527
-    _RESPONSE._serialized_end = 1495
-    _RESPONSE_ARROWBATCH._serialized_start = 756
-    _RESPONSE_ARROWBATCH._serialized_end = 931
-    _RESPONSE_JSONBATCH._serialized_start = 933
-    _RESPONSE_JSONBATCH._serialized_end = 993
-    _RESPONSE_METRICS._serialized_start = 996
-    _RESPONSE_METRICS._serialized_end = 1480
-    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080
-    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390
-    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392
-    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480
-    _ANALYZERESPONSE._serialized_start = 1498
-    _ANALYZERESPONSE._serialized_end = 1653
-    _SPARKCONNECTSERVICE._serialized_start = 1656
-    _SPARKCONNECTSERVICE._serialized_end = 1818
+    _PLAN._serialized_start = 158
+    _PLAN._serialized_end = 274
+    _REQUEST._serialized_start = 277
+    _REQUEST._serialized_end = 551
+    _REQUEST_USERCONTEXT._serialized_start = 429
+    _REQUEST_USERCONTEXT._serialized_end = 551
+    _RESPONSE._serialized_start = 554
+    _RESPONSE._serialized_end = 1522
+    _RESPONSE_ARROWBATCH._serialized_start = 783
+    _RESPONSE_ARROWBATCH._serialized_end = 958
+    _RESPONSE_JSONBATCH._serialized_start = 960
+    _RESPONSE_JSONBATCH._serialized_end = 1020
+    _RESPONSE_METRICS._serialized_start = 1023
+    _RESPONSE_METRICS._serialized_end = 1507
+    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107
+    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417
+    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419
+    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507
+    _ANALYZERESPONSE._serialized_start = 1525
+    _ANALYZERESPONSE._serialized_end = 1659
+    _SPARKCONNECTSERVICE._serialized_start = 1662
+    _SPARKCONNECTSERVICE._serialized_end = 1824
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index bb3a6578cf7..5ffd7701b44 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -41,6 +41,7 @@ import google.protobuf.internal.containers
 import google.protobuf.message
 import pyspark.sql.connect.proto.commands_pb2
 import pyspark.sql.connect.proto.relations_pb2
+import pyspark.sql.connect.proto.types_pb2
 import sys
 
 if sys.version_info >= (3, 8):
@@ -401,39 +402,27 @@ class AnalyzeResponse(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     CLIENT_ID_FIELD_NUMBER: builtins.int
-    COLUMN_NAMES_FIELD_NUMBER: builtins.int
-    COLUMN_TYPES_FIELD_NUMBER: builtins.int
+    SCHEMA_FIELD_NUMBER: builtins.int
     EXPLAIN_STRING_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
-    def column_names(
-        self,
-    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
-    @property
-    def column_types(
-        self,
-    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+    def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
     explain_string: builtins.str
     """The extended explain string as produced by Spark."""
     def __init__(
         self,
         *,
         client_id: builtins.str = ...,
-        column_names: collections.abc.Iterable[builtins.str] | None = ...,
-        column_types: collections.abc.Iterable[builtins.str] | None = ...,
+        schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
         explain_string: builtins.str = ...,
     ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["schema", b"schema"]
+    ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "client_id",
-            b"client_id",
-            "column_names",
-            b"column_names",
-            "column_types",
-            b"column_types",
-            "explain_string",
-            b"explain_string",
+            "client_id", b"client_id", "explain_string", b"explain_string", 
"schema", b"schema"
         ],
     ) -> None: ...
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f6988a1d120..459b05cc37a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -22,6 +22,7 @@ import tempfile
 import pandas
 
 from pyspark.sql import SparkSession, Row
+from pyspark.sql.types import StructType, StructField, LongType, StringType
 from pyspark.sql.connect.client import RemoteSparkSession
 from pyspark.sql.connect.function_builder import udf
 from pyspark.sql.connect.functions import lit
@@ -97,6 +98,15 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         result = df.explain()
         self.assertGreater(len(result), 0)
 
+    def test_schema(self):
+        schema = self.connect.read.table(self.tbl_name).schema()
+        self.assertEqual(
+            StructType(
+                [StructField("id", LongType(), True), StructField("name", 
StringType(), True)]
+            ),
+            schema,
+        )
+
     def test_simple_binary_expressions(self):
         """Test complex expression"""
         df = self.connect.read.table(self.tbl_name)

