This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 487bb773e924 [SPARK-52134] Move execution logic to SqlScriptingExecution and enable Spark Connect path
487bb773e924 is described below

commit 487bb773e924a928caa5708561ca67943a810aa1
Author: David Milicevic <david.milice...@databricks.com>
AuthorDate: Fri May 16 11:06:40 2025 +0200

    [SPARK-52134] Move execution logic to SqlScriptingExecution and enable Spark Connect path

    ### What changes were proposed in this pull request?
    Move the script execution from `SparkSession#sql` to `QueryExecution#lazyAnalyzed`. This allows `QueryExecution` to receive the original parsed logical plan for scripting, which is used to detect script execution in Spark Connect and treat scripts as commands.
    Move the `executeSqlScript` logic from `SparkSession` to the `SqlScriptingExecution` companion object.

    ### Why are the changes needed?
    SQL Scripting improvements.

    ### Does this PR introduce any user-facing change?
    No. This PR enables new functionality (execution through Spark Connect), but the results remain the same.

    ### How was this patch tested?
    Existing tests confirm that the refactoring of the execution logic does not change behavior. A test was added to confirm that execution through Spark Connect does not fail.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #50895 from davidm-db/execute_sql_script_refactor.

    Lead-authored-by: David Milicevic <david.milice...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Co-authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 02389836a2ed425d43ea3240374e048f9a28636e)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala | 7 +-
 .../spark/sql/connect/SparkConnectServerTest.scala | 57 +++++++++++
 .../service/SparkConnectServiceE2ESuite.scala | 21 +++++
 .../apache/spark/sql/classic/SparkSession.scala | 104 ++++-----------------
 .../spark/sql/execution/QueryExecution.scala | 19 +++-
 .../sql/scripting/SqlScriptingExecution.scala | 52 ++++++++++-
 6 files changed, 164 insertions(+), 96 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7b9223508ec9..16b8c1afe7e2 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...] 
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, CompoundBody, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAc [...] import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -2614,8 +2614,9 @@ class SparkConnectPlanner( s"SQL command expects either a SQL or a WithRelations, but got $other") } - // Check if commands have been executed. + // Check if command or SQL Script has been executed. val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult] + val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody] val rows = df.logicalPlan match { case lr: LocalRelation => lr.data case cr: CommandResult => cr.rows @@ -2627,7 +2628,7 @@ class SparkConnectPlanner( val result = SqlCommandResult.newBuilder() // Only filled when isCommand val metrics = ExecutePlanResponse.Metrics.newBuilder() - if (isCommand) { + if (isCommand || isSqlScript) { // Convert the results to Arrow. val schema = df.schema val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 76c88d515ec0..92f64875d337 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -16,16 +16,19 @@ */ package org.apache.spark.sql.connect +import java.io.ByteArrayInputStream import java.util.{TimeZone, UUID} import scala.reflect.runtime.universe.TypeTag import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.ipc.ArrowStreamReader import org.scalatest.concurrent.{Eventually, TimeLimits} import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState} import org.apache.spark.sql.connect.client.arrow.ArrowSerializer @@ -143,6 +146,21 @@ trait SparkConnectServerTest extends SharedSparkSession { proto.Plan.newBuilder().setRoot(dsl.sql(query)).build() } + protected def buildSqlCommandPlan(sqlCommand: String) = { + proto.Plan + .newBuilder() + .setCommand( + proto.Command + .newBuilder() + .setSqlCommand( + proto.SqlCommand + .newBuilder() + .setSql(sqlCommand) + .build()) + .build()) + .build() + } + protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = { val encoder = ScalaReflection.encoderFor[A] val arrowData = @@ -305,4 +323,43 @@ trait SparkConnectServerTest extends SharedSparkSession { val plan = buildPlan(query) runQuery(plan, queryTimeout, iterSleep) } + + protected def checkSqlCommandResponse( + result: 
ExecutePlanResponse.SqlCommandResult, + expected: Seq[Seq[Any]]): Unit = { + // Extract the serialized Arrow data as a byte array. + val dataBytes = result.getRelation.getLocalRelation.getData.toByteArray + + // Create an ArrowStreamReader to deserialize the data. + val allocator = new RootAllocator(Long.MaxValue) + val inputStream = new ByteArrayInputStream(dataBytes) + val reader = new ArrowStreamReader(inputStream, allocator) + + try { + // Read the schema and data. + val root = reader.getVectorSchemaRoot + // Load the first batch of data. + reader.loadNextBatch() + + // Get dimensions. + val rowCount = root.getRowCount + val colCount = root.getFieldVectors.size + assert(rowCount == expected.length, "Row count mismatch") + assert(colCount == expected.head.length, "Column count mismatch") + + // Compare to expected. + for (i <- 0 until rowCount) { + for (j <- 0 until colCount) { + val col = root.getFieldVectors.get(j) + val value = col.getObject(i) + print(value) + assert(value == expected(i)(j), s"Value mismatch at ($i, $j)") + } + } + } finally { + // Clean up resources. + reader.close() + allocator.close() + } + } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index 3337bb9b8ea4..d127f2e5a4cd 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -33,6 +33,27 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest { // were all already in the buffer. val BIG_ENOUGH_QUERY = "select * from range(1000000)" + test("SQL Script over Spark Connect.") { + val sessionId = UUID.randomUUID.toString() + val userId = "ScriptUser" + val sqlScriptText = + """BEGIN + |IF 1 = 1 THEN + | SELECT 1; + |ELSE + | SELECT 2; + |END IF; + |END + """.stripMargin + withClient(sessionId = sessionId, userId = userId) { client => + // this will create the session, and then ReleaseSession at the end of withClient. + val enableSqlScripting = client.execute(buildPlan("SET spark.sql.scripting.enabled=true")) + enableSqlScripting.hasNext // trigger execution + val query = client.execute(buildSqlCommandPlan(sqlScriptText)) + checkSqlCommandResponse(query.next().getSqlCommandResult, Seq(Seq(1))) + } + } + test("Execute is sent eagerly to the server upon iterator creation") { // This behavior changed with grpc upgrade from 1.56.0 to 1.59.0. // Testing to be aware of future changes. 
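
For context, the new end-to-end test above corresponds roughly to the following client-side usage of the Spark Connect Scala client. This is a minimal sketch rather than code from this commit: the endpoint sc://localhost:15002 and the wrapper object are assumptions, and spark.sql.scripting.enabled must be set because SQL scripting is still gated behind that flag.

import org.apache.spark.sql.SparkSession

object SqlScriptOverConnect {
  def main(args: Array[String]): Unit = {
    // Connect to a running Spark Connect server (endpoint assumed).
    val spark = SparkSession.builder()
      .remote("sc://localhost:15002")
      .getOrCreate()

    // SQL scripting is gated behind this flag.
    spark.conf.set("spark.sql.scripting.enabled", "true")

    // The whole script is submitted as a single SqlCommand; with this change the
    // server executes it and returns the last statement's result like a command result.
    val result = spark.sql(
      """BEGIN
        |  IF 1 = 1 THEN
        |    SELECT 1;
        |  ELSE
        |    SELECT 2;
        |  END IF;
        |END""".stripMargin)

    result.show() // expected to print a single row with value 1
  }
}
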
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index 034a3359da94..0015d7ff99e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -42,10 +42,9 @@ import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range} -import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, Range} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions @@ -56,7 +55,6 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.scripting.SqlScriptingExecution import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager @@ -432,50 +430,6 @@ class SparkSession private( | Everything else | * ----------------- */ - /** - * Executes given script and return the result of the last statement. - * If script contains no queries, an empty `DataFrame` is returned. - * - * @param script A SQL script to execute. - * @param args A map of parameter names to SQL literal expressions. - * - * @return The result as a `DataFrame`. - */ - private def executeSqlScript( - script: CompoundBody, - args: Map[String, Expression] = Map.empty): DataFrame = { - val sse = new SqlScriptingExecution(script, this, args) - sse.withLocalVariableManager { - var result: Option[Seq[Row]] = None - - // We must execute returned df before calling sse.getNextResult again because sse.hasNext - // advances the script execution and executes all statements until the next result. We must - // collect results immediately to maintain execution order. - // This ensures we respect the contract of SqlScriptingExecution API. - var df: Option[DataFrame] = sse.getNextResult - var resultSchema: Option[StructType] = None - while (df.isDefined) { - sse.withErrorHandling { - // Collect results from the current DataFrame. - result = Some(df.get.collect().toSeq) - resultSchema = Some(df.get.schema) - } - df = sse.getNextResult - } - - if (result.isEmpty) { - emptyDataFrame - } else { - // If `result` is defined, then `resultSchema` must be defined as well. - assert(resultSchema.isDefined) - - val attributes = DataTypeUtils.toAttributes(resultSchema.get) - Dataset.ofRows( - self, LocalRelation.fromExternalRows(attributes, result.get)) - } - } - } - /** * Executes a SQL query substituting positional parameters by the given arguments, * returning the result as a `DataFrame`. 
@@ -495,30 +449,17 @@ class SparkSession private( withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - parsedPlan match { - case compoundBody: CompoundBody => - if (args.nonEmpty) { - // Positional parameters are not supported for SQL scripting. - throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting() - } - compoundBody - case logicalPlan: LogicalPlan => - if (args.nonEmpty) { - PosParameterizedQuery(logicalPlan, args.map(lit(_).expr).toImmutableArraySeq) - } else { - logicalPlan - } + if (args.nonEmpty) { + if (parsedPlan.isInstanceOf[CompoundBody]) { + // Positional parameters are not supported for SQL scripting. + throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting() + } + PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq) + } else { + parsedPlan } } - - plan match { - case compoundBody: CompoundBody => - // Execute the SQL script. - executeSqlScript(compoundBody) - case logicalPlan: LogicalPlan => - // Execute the standalone SQL statement. - Dataset.ofRows(self, plan, tracker) - } + Dataset.ofRows(self, plan, tracker) } /** @inheritdoc */ @@ -549,26 +490,13 @@ class SparkSession private( withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - parsedPlan match { - case compoundBody: CompoundBody => - compoundBody - case logicalPlan: LogicalPlan => - if (args.nonEmpty) { - NameParameterizedQuery(logicalPlan, args.transform((_, v) => lit(v).expr)) - } else { - logicalPlan - } + if (args.nonEmpty) { + NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr)) + } else { + parsedPlan } } - - plan match { - case compoundBody: CompoundBody => - // Execute the SQL script. - executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr)) - case logicalPlan: LogicalPlan => - // Execute the standalone SQL statement. 
- Dataset.ofRows(self, plan, tracker) - } + Dataset.ofRows(self, plan, tracker) } /** @inheritdoc */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 87cafa58d5fa..071a267fd6ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -31,10 +31,10 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{LazyExpression, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString @@ -46,6 +46,7 @@ import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.scripting.SqlScriptingExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{LazyTry, Utils} import org.apache.spark.util.ArrayImplicits._ @@ -93,16 +94,26 @@ class QueryExecution( } private val lazyAnalyzed = LazyTry { + val withScriptExecuted = logical match { + // Execute the SQL script. Script doesn't need to go through the analyzer as Spark will run + // each statement as individual query. + case NameParameterizedQuery(compoundBody: CompoundBody, argNames, argValues) => + val args = argNames.zip(argValues).toMap + SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody, args) + case compoundBody: CompoundBody => + SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody) + case _ => logical + } try { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. 
- sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) + sparkSession.sessionState.analyzer.executeAndCheck(withScriptExecuted, tracker) } tracker.setAnalyzed(plan) plan } catch { case NonFatal(e) => - tracker.setAnalysisFailed(logical) + tracker.setAnalysisFailed(withScriptExecuted) throw e } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 68a5a60079e6..ee72e6c358bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkThrowable +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody} +import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.types.StructType /** * SQL scripting executor - executes script and returns result statements. @@ -178,3 +181,50 @@ class SqlScriptingExecution( } } } + +object SqlScriptingExecution { + + /** + * Executes given script and return the result of the last statement. + * If script contains no queries, an empty `DataFrame` is returned. + * + * @param script A SQL script to execute. + * @param args A map of parameter names to SQL literal expressions. + * @return The result as a `DataFrame`. + */ + def executeSqlScript( + session: SparkSession, + script: CompoundBody, + args: Map[String, Expression] = Map.empty): LogicalPlan = { + val sse = new SqlScriptingExecution(script, session, args) + sse.withLocalVariableManager { + var result: Option[Seq[Row]] = None + + // We must execute returned df before calling sse.getNextResult again because sse.hasNext + // advances the script execution and executes all statements until the next result. We must + // collect results immediately to maintain execution order. + // This ensures we respect the contract of SqlScriptingExecution API. + var df: Option[DataFrame] = sse.getNextResult + var resultSchema: Option[StructType] = None + while (df.isDefined) { + sse.withErrorHandling { + // Collect results from the current DataFrame. + result = Some(df.get.collect().toSeq) + resultSchema = Some(df.get.schema) + } + df = sse.getNextResult + } + + if (result.isEmpty) { + // Return empty LocalRelation. + LocalRelation.fromExternalRows(Seq.empty, Seq.empty) + } else { + // If `result` is defined, then `resultSchema` must be defined as well. + assert(resultSchema.isDefined) + + val attributes = DataTypeUtils.toAttributes(resultSchema.get) + LocalRelation.fromExternalRows(attributes, result.get) + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org