This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d378b35f54b [SPARK-45868][CONNECT] Make sure `spark.table` uses the same parser as vanilla Spark
d378b35f54b is described below

commit d378b35f54b853d91e13e5def8a5bf2c7c06ff32
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Nov 15 13:43:55 2023 -0800

    [SPARK-45868][CONNECT] Make sure `spark.table` uses the same parser as vanilla Spark
    
    ### What changes were proposed in this pull request?
    Make sure `spark.table` uses the same parser as vanilla Spark.
    
    ### Why are the changes needed?
    To be consistent with vanilla Spark:

    https://github.com/apache/spark/blob/9d93b7112a31965447a34301889f90d14578e628/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala#L714-L720
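
    For reference, the linked lines route the table name through the session's SQL parser rather than a hard-coded `CatalystSqlParser`. A rough, abridged sketch of that vanilla `DataFrameReader.table` code path (not a verbatim copy; surrounding details may differ):

    ```scala
    def table(tableName: String): DataFrame = {
      assertNoSpecifiedSchema("table")
      // The name is parsed with the session's configured parser, so parser
      // extensions injected via SparkSessionExtensions take effect here.
      val multipartIdentifier =
        sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
      Dataset.ofRows(sparkSession, UnresolvedRelation(multipartIdentifier,
        new CaseInsensitiveStringMap(extraOptions.toMap.asJava)))
    }
    ```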
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43741 from zhengruifeng/connect_read_table_parser.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 +-
 .../SparkConnectWithSessionExtensionSuite.scala    | 82 ++++++++++++++++++++++
 2 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 637ed09798a..d4e5e34c61a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncode
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
@@ -1227,7 +1227,8 @@ class SparkConnectPlanner(
     rel.getReadTypeCase match {
       case proto.Read.ReadTypeCase.NAMED_TABLE =>
         val multipartIdentifier =
-          CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
+          session.sessionState.sqlParser
+            .parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
         UnresolvedRelation(
           multipartIdentifier,
           new CaseInsensitiveStringMap(rel.getNamedTable.getOptionsMap),
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala
new file mode 100644
index 00000000000..37c7fe25097
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.types.{DataType, StructType}
+
+class SparkConnectWithSessionExtensionSuite extends SparkFunSuite {
+
+  case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
+    override def parsePlan(sqlText: String): LogicalPlan =
+      delegate.parsePlan(sqlText)
+
+    override def parseExpression(sqlText: String): Expression =
+      delegate.parseExpression(sqlText)
+
+    override def parseTableIdentifier(sqlText: String): TableIdentifier =
+      delegate.parseTableIdentifier(sqlText)
+
+    override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+      delegate.parseFunctionIdentifier(sqlText)
+
+    override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+      delegate.parseMultipartIdentifier(sqlText) :+ "FROM_MY_PARSER"
+
+    override def parseTableSchema(sqlText: String): StructType =
+      delegate.parseTableSchema(sqlText)
+
+    override def parseDataType(sqlText: String): DataType =
+      delegate.parseDataType(sqlText)
+
+    override def parseQuery(sqlText: String): LogicalPlan =
+      delegate.parseQuery(sqlText)
+  }
+
+  test("Parse table name with test parser") {
+    val spark = SparkSession
+      .builder()
+      .master("local[1]")
+      .withExtensions(extension => extension.injectParser(MyParser))
+      .getOrCreate()
+
+    val read = proto.Read.newBuilder().build()
+    val readWithTable = read.toBuilder
+      .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("name").build())
+      .build()
+    val rel = proto.Relation.newBuilder.setRead(readWithTable).build()
+
+    val res = new SparkConnectPlanner(SessionHolder.forTesting(spark)).transformRelation(rel)
+
+    assert(res !== null)
+    assert(res.nodeName === "UnresolvedRelation")
+    assert(
+      res.asInstanceOf[UnresolvedRelation].multipartIdentifier ===
+        Seq("name", "FROM_MY_PARSER"))
+
+    spark.stop()
+  }
+}

