This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fd1e0d028cb  [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
fd1e0d028cb is described below

commit fd1e0d028cb7e26921cd66a421c00d7260092b23
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Nov 11 15:48:21 2022 +0800

    [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner

    ### What changes were proposed in this pull request?

    Previously, the Connect server split the handling of `Command` and `Relation` across two planners. As new APIs are added, however, there are cases where a `Command` takes a `Relation` as its input, so converting the `Command` still requires the `Relation` conversion logic. View creation is one such case, and SQL DDL and DML will generally follow the same pattern. This PR merges the handling of `Command` and `Relation` into a single planner; a short usage sketch of the merged planner follows at the end of this message.

    ### Why are the changes needed?

    Refactoring.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing unit tests.

    Closes #38604 from amaliujia/refactor-planners.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../command/SparkConnectCommandPlanner.scala       | 174 ---------------------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 152 +++++++++++++++++-
 .../sql/connect/service/SparkConnectService.scala  |   2 +-
 .../service/SparkConnectStreamHandler.scala        |   9 +-
 .../planner/SparkConnectCommandPlannerSuite.scala  | 160 -------------------
 .../connect/planner/SparkConnectPlannerSuite.scala |  25 +--
 .../connect/planner/SparkConnectProtoSuite.scala   | 128 ++++++++++++++-
 7 files changed, 291 insertions(+), 359 deletions(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala deleted file mode 100644 index 11090976c7f..00000000000 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.sql.connect.command - -import scala.collection.JavaConverters._ - -import com.google.common.collect.{Lists, Maps} - -import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} -import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.WriteOperation -import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView} -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.command.CreateViewCommand -import org.apache.spark.sql.execution.python.UserDefinedPythonFunction -import org.apache.spark.sql.types.StringType - -final case class InvalidCommandInput( - private val message: String = "", - private val cause: Throwable = null) - extends Exception(message, cause) - -class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) { - - lazy val pythonExec = - sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) - - def process(): Unit = { - command.getCommandTypeCase match { - case proto.Command.CommandTypeCase.CREATE_FUNCTION => - handleCreateScalarFunction(command.getCreateFunction) - case proto.Command.CommandTypeCase.WRITE_OPERATION => - handleWriteOperation(command.getWriteOperation) - case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW => - handleCreateViewCommand(command.getCreateDataframeView) - case _ => throw new UnsupportedOperationException(s"$command not supported.") - } - } - - /** - * This is a helper function that registers a new Python function in the SparkSession. - * - * Right now this function is very rudimentary and bare-bones just to showcase how it is - * possible to remotely serialize a Python function and execute it on the Spark cluster. If the - * Python version on the client and server diverge, the execution of the function that is - * serialized will most likely fail. - * - * @param cf - */ - def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = { - val function = SimplePythonFunction( - cf.getSerializedFunction.toByteArray, - Maps.newHashMap(), - Lists.newArrayList(), - pythonExec, - "3.9", // TODO(SPARK-40532) This needs to be an actual Python version. - Lists.newArrayList(), - null) - - val udf = UserDefinedPythonFunction( - cf.getPartsList.asScala.head, - function, - StringType, - PythonEvalType.SQL_BATCHED_UDF, - udfDeterministic = false) - - session.udf.registerPython(cf.getPartsList.asScala.head, udf) - } - - def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = { - val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView - - val tableIdentifier = - try { - session.sessionState.sqlParser.parseTableIdentifier(createView.getName) - } catch { - case _: ParseException => - throw QueryCompilationErrors.invalidViewNameError(createView.getName) - } - - val plan = CreateViewCommand( - name = tableIdentifier, - userSpecifiedColumns = Nil, - comment = None, - properties = Map.empty, - originalText = None, - plan = new SparkConnectPlanner(createView.getInput, session).transform(), - allowExisting = false, - replace = createView.getReplace, - viewType = viewType, - isAnalyzed = true) - - Dataset.ofRows(session, plan).queryExecution.commandExecuted - } - - /** - * Transforms the write operation and executes it. 
- * - * The input write operation contains a reference to the input plan and transforms it to the - * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the - * parameters of the WriteOperation into the corresponding methods calls. - * - * @param writeOperation - */ - def handleWriteOperation(writeOperation: WriteOperation): Unit = { - // Transform the input plan into the logical plan. - val planner = new SparkConnectPlanner(writeOperation.getInput, session) - val plan = planner.transform() - // And create a Dataset from the plan. - val dataset = Dataset.ofRows(session, logicalPlan = plan) - - val w = dataset.write - if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { - w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode)) - } - - if (writeOperation.getOptionsCount > 0) { - writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) } - } - - if (writeOperation.getSortColumnNamesCount > 0) { - val names = writeOperation.getSortColumnNamesList.asScala - w.sortBy(names.head, names.tail.toSeq: _*) - } - - if (writeOperation.hasBucketBy) { - val op = writeOperation.getBucketBy - val cols = op.getBucketColumnNamesList.asScala - if (op.getNumBuckets <= 0) { - throw InvalidCommandInput( - s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.") - } - w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*) - } - - if (writeOperation.getPartitioningColumnsCount > 0) { - val names = writeOperation.getPartitioningColumnsList.asScala - w.partitionBy(names.toSeq: _*) - } - - if (writeOperation.getSource != null) { - w.format(writeOperation.getSource) - } - - writeOperation.getSaveTypeCase match { - case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath) - case proto.WriteOperation.SaveTypeCase.TABLE_NAME => - w.saveAsTable(writeOperation.getTableName) - case _ => - throw new UnsupportedOperationException( - "WriteOperation:SaveTypeCase not supported " - + s"${writeOperation.getSaveTypeCase.getNumber}") - } - } - -} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index b91fef58a11..f8ccc7b62e7 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -19,18 +19,25 @@ package org.apache.spark.sql.connect.planner import scala.collection.JavaConverters._ +import com.google.common.collect.{Lists, Maps} + +import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.WriteOperation import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.AliasIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.CombineUnions -import 
org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.command.CreateViewCommand +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -39,14 +46,17 @@ final case class InvalidPlanInput( private val cause: Throwable = None.orNull) extends Exception(message, cause) -class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { +final case class InvalidCommandInput( + private val message: String = "", + private val cause: Throwable = null) + extends Exception(message, cause) - def transform(): LogicalPlan = { - transformRelation(plan) - } +class SparkConnectPlanner(session: SparkSession) { + lazy val pythonExec = + sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) // The root of the query plan is a relation and we apply the transformations to it. - private def transformRelation(rel: proto.Relation): LogicalPlan = { + def transformRelation(rel: proto.Relation): LogicalPlan = { rel.getRelTypeCase match { case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead) case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject) @@ -446,4 +456,132 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } } + def process(command: proto.Command): Unit = { + command.getCommandTypeCase match { + case proto.Command.CommandTypeCase.CREATE_FUNCTION => + handleCreateScalarFunction(command.getCreateFunction) + case proto.Command.CommandTypeCase.WRITE_OPERATION => + handleWriteOperation(command.getWriteOperation) + case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW => + handleCreateViewCommand(command.getCreateDataframeView) + case _ => throw new UnsupportedOperationException(s"$command not supported.") + } + } + + /** + * This is a helper function that registers a new Python function in the SparkSession. + * + * Right now this function is very rudimentary and bare-bones just to showcase how it is + * possible to remotely serialize a Python function and execute it on the Spark cluster. If the + * Python version on the client and server diverge, the execution of the function that is + * serialized will most likely fail. + * + * @param cf + */ + def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = { + val function = SimplePythonFunction( + cf.getSerializedFunction.toByteArray, + Maps.newHashMap(), + Lists.newArrayList(), + pythonExec, + "3.9", // TODO(SPARK-40532) This needs to be an actual Python version. 
+ Lists.newArrayList(), + null) + + val udf = UserDefinedPythonFunction( + cf.getPartsList.asScala.head, + function, + StringType, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = false) + + session.udf.registerPython(cf.getPartsList.asScala.head, udf) + } + + def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = { + val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView + + val tableIdentifier = + try { + session.sessionState.sqlParser.parseTableIdentifier(createView.getName) + } catch { + case _: ParseException => + throw QueryCompilationErrors.invalidViewNameError(createView.getName) + } + + val plan = CreateViewCommand( + name = tableIdentifier, + userSpecifiedColumns = Nil, + comment = None, + properties = Map.empty, + originalText = None, + plan = transformRelation(createView.getInput), + allowExisting = false, + replace = createView.getReplace, + viewType = viewType, + isAnalyzed = true) + + Dataset.ofRows(session, plan).queryExecution.commandExecuted + } + + /** + * Transforms the write operation and executes it. + * + * The input write operation contains a reference to the input plan and transforms it to the + * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the + * parameters of the WriteOperation into the corresponding methods calls. + * + * @param writeOperation + */ + def handleWriteOperation(writeOperation: WriteOperation): Unit = { + // Transform the input plan into the logical plan. + val planner = new SparkConnectPlanner(session) + val plan = planner.transformRelation(writeOperation.getInput) + // And create a Dataset from the plan. + val dataset = Dataset.ofRows(session, logicalPlan = plan) + + val w = dataset.write + if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { + w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode)) + } + + if (writeOperation.getOptionsCount > 0) { + writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) } + } + + if (writeOperation.getSortColumnNamesCount > 0) { + val names = writeOperation.getSortColumnNamesList.asScala + w.sortBy(names.head, names.tail.toSeq: _*) + } + + if (writeOperation.hasBucketBy) { + val op = writeOperation.getBucketBy + val cols = op.getBucketColumnNamesList.asScala + if (op.getNumBuckets <= 0) { + throw InvalidCommandInput( + s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.") + } + w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*) + } + + if (writeOperation.getPartitioningColumnsCount > 0) { + val names = writeOperation.getPartitioningColumnsList.asScala + w.partitionBy(names.toSeq: _*) + } + + if (writeOperation.getSource != null) { + w.format(writeOperation.getSource) + } + + writeOperation.getSaveTypeCase match { + case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath) + case proto.WriteOperation.SaveTypeCase.TABLE_NAME => + w.saveAsTable(writeOperation.getTableName) + case _ => + throw new UnsupportedOperationException( + "WriteOperation:SaveTypeCase not supported " + + s"${writeOperation.getSaveTypeCase.getNumber}") + } + } + } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index a1e70975da5..abbad51c601 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -106,7 +106,7 @@ class SparkConnectService(debug: Boolean) def handleAnalyzePlanRequest( relation: proto.Relation, session: SparkSession): proto.AnalyzeResponse.Builder = { - val logicalPlan = new SparkConnectPlanner(relation, session).transform() + val logicalPlan = new SparkConnectPlanner(session).transformRelation(relation) val ds = Dataset.ofRows(session, logicalPlan) val explainString = ds.queryExecution.explainString(ExtendedMode) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 9652fce5425..394d6477d73 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -27,7 +27,6 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{Request, Response} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} @@ -51,8 +50,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte def handlePlan(session: SparkSession, request: Request): Unit = { // Extract the plan from the request and convert it to a logical plan - val planner = new SparkConnectPlanner(request.getPlan.getRoot, session) - val dataframe = Dataset.ofRows(session, planner.transform()) + val planner = new SparkConnectPlanner(session) + val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) try { processAsArrowBatches(request.getClientId, dataframe) } catch { @@ -216,8 +215,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte def handleCommand(session: SparkSession, request: Request): Unit = { val command = request.getPlan.getCommand - val planner = new SparkConnectCommandPlanner(session, command) - planner.process() + val planner = new SparkConnectPlanner(session) + planner.process(command) responseObserver.onCompleted() } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala deleted file mode 100644 index 8ab8e0599fc..00000000000 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connect.planner - -import java.nio.file.{Files, Paths} - -import org.apache.spark.SparkClassNotFoundException -import org.apache.spark.connect.proto -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner} -import org.apache.spark.sql.connect.dsl.commands._ -import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} - -class SparkConnectCommandPlannerSuite - extends SQLTestUtils - with SparkConnectPlanTest - with SharedSparkSession { - - lazy val localRelation = createLocalRelationProto(Seq($"id".int)) - - def transform(cmd: proto.Command): Unit = { - new SparkConnectCommandPlanner(spark, cmd).process() - } - - test("Writes fails without path or table") { - assertThrows[UnsupportedOperationException] { - transform(localRelation.write()) - } - } - - test("Write fails with unknown table - AnalysisException") { - val cmd = readRel.write(tableName = Some("dest")) - assertThrows[AnalysisException] { - transform(cmd) - } - } - - test("Write with partitions") { - val cmd = localRelation.write( - tableName = Some("testtable"), - format = Some("parquet"), - partitionByCols = Seq("noid")) - assertThrows[AnalysisException] { - transform(cmd) - } - } - - test("Write with invalid bucketBy configuration") { - val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0)) - assertThrows[InvalidCommandInput] { - transform(cmd) - } - } - - test("Write to Path") { - withTempDir { f => - val cmd = localRelation.write( - format = Some("parquet"), - path = Some(f.getPath), - mode = Some("Overwrite")) - transform(cmd) - assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}") - } - } - - test("Write to Path with invalid input") { - // Wrong data source. - assertThrows[SparkClassNotFoundException]( - transform( - localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat")))) - - // Default data source not found. - assertThrows[SparkClassNotFoundException]( - transform(localRelation.write(path = Some("/tmp/tmppath")))) - } - - test("Write with sortBy") { - // Sort by existing column. - withTable("testtable") { - transform( - localRelation.write( - tableName = Some("testtable"), - format = Some("parquet"), - sortByColumns = Seq("id"), - bucketByCols = Seq("id"), - numBuckets = Some(10))) - } - - // Sort by non-existing column - assertThrows[AnalysisException]( - transform( - localRelation - .write( - tableName = Some("testtable"), - format = Some("parquet"), - sortByColumns = Seq("noid"), - bucketByCols = Seq("id"), - numBuckets = Some(10)))) - } - - test("Write to Table") { - withTable("testtable") { - val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable")) - transform(cmd) - // Check that we can find and drop the table. 
- spark.sql(s"select count(*) from testtable").collect() - } - } - - test("SaveMode conversion tests") { - assertThrows[IllegalArgumentException]( - DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED)) - - val combinations = Seq( - (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND), - (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE), - (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE), - (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS)) - combinations.foreach { a => - assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2) - assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1) - } - } - - test("Test CreateView") { - withView("view1", "view2", "view3", "view4") { - transform(localRelation.createView("view1", global = true, replace = true)) - assert(spark.catalog.tableExists("global_temp.view1")) - - transform(localRelation.createView("view2", global = false, replace = true)) - assert(spark.catalog.tableExists("view2")) - - transform(localRelation.createView("view3", global = true, replace = false)) - assertThrows[AnalysisException] { - transform(localRelation.createView("view3", global = true, replace = false)) - } - - transform(localRelation.createView("view4", global = false, replace = false)) - assertThrows[AnalysisException] { - transform(localRelation.createView("view4", global = false, replace = false)) - } - } - } -} diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index d2304581c3a..9e5fc41a0c6 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -33,7 +33,11 @@ import org.apache.spark.sql.test.SharedSparkSession trait SparkConnectPlanTest extends SharedSparkSession { def transform(rel: proto.Relation): logical.LogicalPlan = { - new SparkConnectPlanner(rel, spark).transform() + new SparkConnectPlanner(spark).transformRelation(rel) + } + + def transform(cmd: proto.Command): Unit = { + new SparkConnectPlanner(spark).process(cmd) } def readRel: proto.Relation = @@ -75,24 +79,23 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { test("Simple Limit") { assertThrows[IndexOutOfBoundsException] { - new SparkConnectPlanner( - proto.Relation.newBuilder - .setLimit(proto.Limit.newBuilder.setLimit(10)) - .build(), - None.orNull) - .transform() + new SparkConnectPlanner(None.orNull) + .transformRelation( + proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(10)) + .build()) } } test("InvalidInputs") { // No Relation Set intercept[IndexOutOfBoundsException]( - new SparkConnectPlanner(proto.Relation.newBuilder().build(), None.orNull).transform()) + new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build())) intercept[InvalidPlanInput]( - new SparkConnectPlanner( - proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build(), - None.orNull).transform()) + new SparkConnectPlanner(None.orNull) + .transformRelation( + proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build())) } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 5052b451047..53ea1988809 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -16,15 +16,19 @@ */ package org.apache.spark.sql.connect.planner +import java.nio.file.{Files, Paths} + +import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.dsl.MockRemoteSession +import org.apache.spark.sql.connect.dsl.commands._ import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ import org.apache.spark.sql.functions._ @@ -57,6 +61,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { new java.util.ArrayList[Row](), StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))) + lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)())) + test("Basic select") { val connectPlan = connectTestRelation.select("id".protoAttr) val sparkPlan = sparkTestRelation.select("id") @@ -303,6 +309,126 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { assert(e.getMessage.contains("Found duplicate column(s)")) } + test("Writes fails without path or table") { + assertThrows[UnsupportedOperationException] { + transform(localRelation.write()) + } + } + + test("Write fails with unknown table - AnalysisException") { + val cmd = readRel.write(tableName = Some("dest")) + assertThrows[AnalysisException] { + transform(cmd) + } + } + + test("Write with partitions") { + val cmd = localRelation.write( + tableName = Some("testtable"), + format = Some("parquet"), + partitionByCols = Seq("noid")) + assertThrows[AnalysisException] { + transform(cmd) + } + } + + test("Write with invalid bucketBy configuration") { + val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0)) + assertThrows[InvalidCommandInput] { + transform(cmd) + } + } + + test("Write to Path") { + withTempDir { f => + val cmd = localRelation.write( + format = Some("parquet"), + path = Some(f.getPath), + mode = Some("Overwrite")) + transform(cmd) + assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}") + } + } + + test("Write to Path with invalid input") { + // Wrong data source. + assertThrows[SparkClassNotFoundException]( + transform( + localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat")))) + + // Default data source not found. + assertThrows[SparkClassNotFoundException]( + transform(localRelation.write(path = Some("/tmp/tmppath")))) + } + + test("Write with sortBy") { + // Sort by existing column. 
+ withTable("testtable") { + transform( + localRelation.write( + tableName = Some("testtable"), + format = Some("parquet"), + sortByColumns = Seq("id"), + bucketByCols = Seq("id"), + numBuckets = Some(10))) + } + + // Sort by non-existing column + assertThrows[AnalysisException]( + transform( + localRelation + .write( + tableName = Some("testtable"), + format = Some("parquet"), + sortByColumns = Seq("noid"), + bucketByCols = Seq("id"), + numBuckets = Some(10)))) + } + + test("Write to Table") { + withTable("testtable") { + val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable")) + transform(cmd) + // Check that we can find and drop the table. + spark.sql(s"select count(*) from testtable").collect() + } + } + + test("SaveMode conversion tests") { + assertThrows[IllegalArgumentException]( + DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED)) + + val combinations = Seq( + (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND), + (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE), + (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE), + (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS)) + combinations.foreach { a => + assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2) + assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1) + } + } + + test("Test CreateView") { + withView("view1", "view2", "view3", "view4") { + transform(localRelation.createView("view1", global = true, replace = true)) + assert(spark.catalog.tableExists("global_temp.view1")) + + transform(localRelation.createView("view2", global = false, replace = true)) + assert(spark.catalog.tableExists("view2")) + + transform(localRelation.createView("view3", global = true, replace = false)) + assertThrows[AnalysisException] { + transform(localRelation.createView("view3", global = true, replace = false)) + } + + transform(localRelation.createView("view4", global = false, replace = false)) + assertThrows[AnalysisException] { + transform(localRelation.createView("view4", global = false, replace = false)) + } + } + } + private def createLocalRelationProtoByQualifiedAttributes( attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org