This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d26e48404a0 [SPARK-40926][CONNECT] Refactor server side tests to only use DataFrame API d26e48404a0 is described below commit d26e48404a07528029f952517b39c037337f0134 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Fri Oct 28 13:15:07 2022 +0800 [SPARK-40926][CONNECT] Refactor server side tests to only use DataFrame API ### What changes were proposed in this pull request? This PR migrates all existing proto tests to be DataFrame API based. ### Why are the changes needed? 1. The goal for proto tests is to test the capability of representing DataFrames by the Connect proto. So comparing with DataFrame API is more accurate. 2. Some Connect plan executions require a SparkSession anyway. We can unify all tests into one suite by only using the DataFrame API (e.g. we can merge `SparkConnectDeduplicateSuite.scala` into `SparkConnectProtoSuite.scala`). 3. This also enables the possibility that we can also test results (not only plans) in the future. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UT. Closes #38406 from amaliujia/refactor_server_tests. 
Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/connect/planner/SparkConnectPlanner.scala | 13 +- .../planner/SparkConnectDeduplicateSuite.scala | 68 ------- .../connect/planner/SparkConnectPlannerSuite.scala | 8 +- .../connect/planner/SparkConnectProtoSuite.scala | 219 ++++++++++----------- 4 files changed, 112 insertions(+), 196 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 53abf2e7709..ebdb5a447b1 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -64,7 +64,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) case proto.Relation.RelTypeCase.LOCAL_RELATION => - transformLocalRelation(rel.getLocalRelation) + transformLocalRelation(rel.getLocalRelation, common) case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") @@ -122,9 +122,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } } - private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = { + private def transformLocalRelation( + rel: proto.LocalRelation, + common: Option[proto.RelationCommon]): LogicalPlan = { val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq - new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes) + val relation = new 
org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes) + if (common.nonEmpty && common.get.getAlias.nonEmpty) { + logical.SubqueryAlias(identifier = common.get.getAlias, child = relation) + } else { + relation + } } private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala deleted file mode 100644 index 88af60581ba..00000000000 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.connect.planner - -import org.apache.spark.sql.{Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} - -/** - * [[SparkConnectPlanTestWithSparkSession]] contains a SparkSession for the connect planner. 
- * - * It is not recommended to use Catalyst DSL along with this trait because `SharedSparkSession` - * has also defined implicits over Catalyst LogicalPlan which will cause ambiguity with the - * implicits defined in Catalyst DSL. - */ -trait SparkConnectPlanTestWithSparkSession extends SharedSparkSession with SparkConnectPlanTest { - override def getSession(): SparkSession = spark -} - -class SparkConnectDeduplicateSuite extends SparkConnectPlanTestWithSparkSession { - lazy val connectTestRelation = createLocalRelationProto( - Seq( - AttributeReference("id", IntegerType)(), - AttributeReference("key", StringType)(), - AttributeReference("value", StringType)())) - - lazy val sparkTestRelation = { - spark.createDataFrame( - new java.util.ArrayList[Row](), - StructType( - Seq( - StructField("id", IntegerType), - StructField("key", StringType), - StructField("value", StringType)))) - } - - test("Test basic deduplicate") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.plans._ - Dataset.ofRows(spark, transform(connectTestRelation.distinct())) - } - - val sparkPlan = sparkTestRelation.distinct() - comparePlans(connectPlan.queryExecution.analyzed, sparkPlan.queryExecution.analyzed, false) - - val connectPlan2 = { - import org.apache.spark.sql.connect.dsl.plans._ - Dataset.ofRows(spark, transform(connectTestRelation.deduplicate(Seq("key", "value")))) - } - val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("key", "value")) - comparePlans(connectPlan2.queryExecution.analyzed, sparkPlan2.queryExecution.analyzed, false) - } -} diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 6fc47e07c59..49072982c00 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -22,20 +22,18 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.UnresolvedStar -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSparkSession /** * Testing trait for SparkConnect tests with some helper methods to make it easier to create new * test cases. */ -trait SparkConnectPlanTest { - - def getSession(): SparkSession = None.orNull +trait SparkConnectPlanTest extends SharedSparkSession { def transform(rel: proto.Relation): LogicalPlan = { - new SparkConnectPlanner(rel, getSession()).transform() + new SparkConnectPlanner(rel, spark).transform() } def readRel: proto.Relation = diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 0325b6573bd..a38b1951eb2 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -18,10 +18,16 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter, UsingJoin} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.AttributeReference 
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.dsl.expressions._ +import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} /** * This suite is based on connect DSL and test that given same dataframe operations, whether @@ -30,81 +36,61 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation */ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { - lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string)) + lazy val connectTestRelation = + createLocalRelationProto( + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) - lazy val connectTestRelation2 = createLocalRelationProto( - Seq($"key".int, $"value".int, $"name".string)) + lazy val connectTestRelation2 = + createLocalRelationProto( + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) - lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string) + lazy val sparkTestRelation: DataFrame = + spark.createDataFrame( + new java.util.ArrayList[Row](), + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))) - lazy val sparkTestRelation2: LocalRelation = - LocalRelation($"key".int, $"value".int, $"name".string) + lazy val sparkTestRelation2: DataFrame = + spark.createDataFrame( + new java.util.ArrayList[Row](), + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))) test("Basic select") { - val connectPlan = { - // TODO: Scala only allows one implicit per scope so we keep proto implicit imports in - // this scope. 
Need to find a better way to make two implicits work in the same scope. - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.select("id".protoAttr)) - } - val sparkPlan = sparkTestRelation.select($"id") - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + val connectPlan = connectTestRelation.select("id".protoAttr) + val sparkPlan = sparkTestRelation.select("id") + comparePlans(connectPlan, sparkPlan) } test("UnresolvedFunction resolution.") { - { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - assertThrows[IllegalArgumentException] { - transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr)))) - } + assertThrows[IllegalArgumentException] { + transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr)))) } - val connectPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform( - connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr)))) - } + val connectPlan = + connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))) assertThrows[UnsupportedOperationException] { - connectPlan.analyze + analyzePlan(transform(connectPlan)) } - val validPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr)))) - } - assert(validPlan.analyze != null) + val validPlan = connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr))) + assert(analyzePlan(transform(validPlan)) != null) } test("Basic filter") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.where("id".protoAttr < 0)) - } - - val 
sparkPlan = sparkTestRelation.where($"id" < 0).analyze - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + val connectPlan = connectTestRelation.where("id".protoAttr < 0) + val sparkPlan = sparkTestRelation.where(Column("id") < 0) + comparePlans(connectPlan, sparkPlan) } test("Basic joins with different join types") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.join(connectTestRelation2)) - } + val connectPlan = connectTestRelation.join(connectTestRelation2) val sparkPlan = sparkTestRelation.join(sparkTestRelation2) - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + comparePlans(connectPlan, sparkPlan) + + val connectPlan2 = connectTestRelation.join(connectTestRelation2) + val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2) + comparePlans(connectPlan2, sparkPlan2) - val connectPlan2 = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.join(connectTestRelation2, condition = None)) - } - val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None) - comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false) for ((t, y) <- Seq( (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), @@ -112,99 +98,79 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), (JoinType.JOIN_TYPE_INNER, Inner))) { - val connectPlan3 = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.join(connectTestRelation2, t)) - } - val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y) - comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false) - } - val connectPlan4 = { - import org.apache.spark.sql.connect.dsl.plans._ - transform( - connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name"))) + val connectPlan3 = 
connectTestRelation.join(connectTestRelation2, t, Seq("id")) + val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, Seq("id"), y.toString) + comparePlans(connectPlan3, sparkPlan3) } - val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, UsingJoin(Inner, Seq("name"))) - comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false) + + val connectPlan4 = + connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name")) + val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, Seq("name"), Inner.toString) + comparePlans(connectPlan4, sparkPlan4) } test("Test sample") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.sample(0, 0.2, false, 1)) - } - val sparkPlan = sparkTestRelation.sample(0, 0.2, false, 1) - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + val connectPlan = connectTestRelation.sample(0, 0.2, false, 1) + val sparkPlan = sparkTestRelation.sample(false, 0.2 - 0, 1) + comparePlans(connectPlan, sparkPlan) } test("column alias") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.select("id".protoAttr.as("id2"))) - } - val sparkPlan = sparkTestRelation.select($"id".as("id2")) - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + val connectPlan = connectTestRelation.select("id".protoAttr.as("id2")) + val sparkPlan = sparkTestRelation.select(Column("id").alias("id2")) + comparePlans(connectPlan, sparkPlan) } test("Aggregate with more than 1 grouping expressions") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)()) + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val connectPlan = + connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)() + val 
sparkPlan = + sparkTestRelation.groupBy(Column("id"), Column("name")).agg(Map.empty[String, String]) + comparePlans(connectPlan, sparkPlan) } - val sparkPlan = sparkTestRelation.groupBy($"id", $"name")() - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } test("Test as(alias: String)") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.as("target_table")) - } - + val connectPlan = connectTestRelation.as("target_table") val sparkPlan = sparkTestRelation.as("target_table") - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + comparePlans(connectPlan, sparkPlan) } test("Test StructType in LocalRelation") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.expressions._ - transform(createLocalRelationProtoByQualifiedAttributes(Seq("a".struct("id".int)))) - } - val sparkPlan = LocalRelation($"a".struct($"id".int)) - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + val connectPlan = createLocalRelationProtoByQualifiedAttributes(Seq("a".struct("id".int))) + val sparkPlan = + LocalRelation(AttributeReference("a", StructType(Seq(StructField("id", IntegerType))))()) + comparePlans(connectPlan, sparkPlan) } test("Test limit offset") { - val connectPlan = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.limit(10)) - } + val connectPlan = connectTestRelation.limit(10) val sparkPlan = sparkTestRelation.limit(10) - comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + comparePlans(connectPlan, sparkPlan) - val connectPlan2 = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.offset(2)) - } + val connectPlan2 = connectTestRelation.offset(2) val sparkPlan2 = sparkTestRelation.offset(2) - comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false) + comparePlans(connectPlan2, sparkPlan2) - val connectPlan3 = { - import org.apache.spark.sql.connect.dsl.plans._ - 
transform(connectTestRelation.limit(10).offset(2)) - } + val connectPlan3 = connectTestRelation.limit(10).offset(2) val sparkPlan3 = sparkTestRelation.limit(10).offset(2) - comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false) + comparePlans(connectPlan3, sparkPlan3) - val connectPlan4 = { - import org.apache.spark.sql.connect.dsl.plans._ - transform(connectTestRelation.offset(2).limit(10)) - } + val connectPlan4 = connectTestRelation.offset(2).limit(10) val sparkPlan4 = sparkTestRelation.offset(2).limit(10) - comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false) + comparePlans(connectPlan4, sparkPlan4) + } + + test("Test basic deduplicate") { + val connectPlan = connectTestRelation.distinct() + val sparkPlan = sparkTestRelation.distinct() + comparePlans(connectPlan, sparkPlan) + + val connectPlan2 = connectTestRelation.deduplicate(Seq("id", "name")) + val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("id", "name")) + comparePlans(connectPlan2, sparkPlan2) } private def createLocalRelationProtoByQualifiedAttributes( @@ -215,4 +181,17 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } + + // This is a function for testing only. This is used when the plan is ready and it only waits + // analyzer to analyze attribute references within the plan. 
+ private def analyzePlan(plan: LogicalPlan): LogicalPlan = { + val connectAnalyzed = analysis.SimpleAnalyzer.execute(plan) + analysis.SimpleAnalyzer.checkAnalysis(connectAnalyzed) + connectAnalyzed + } + + private def comparePlans(connectPlan: proto.Relation, sparkPlan: DataFrame): Unit = { + val connectAnalyzed = analyzePlan(transform(connectPlan)) + comparePlans(connectAnalyzed, sparkPlan.queryExecution.analyzed, false) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org