This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 35fa5e6716e  [SPARK-41271][SQL] Support parameterized SQL queries by `sql()`
35fa5e6716e is described below

commit 35fa5e6716e59b004851b61f7fbfbdace15f46b7
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Dec 15 09:14:46 2022 +0300

    [SPARK-41271][SQL] Support parameterized SQL queries by `sql()`

    ### What changes were proposed in this pull request?
    In this PR, I propose to extend the SparkSession API with a new overload of the `sql` method:
    ```scala
    def sql(sqlText: String, args: Map[String, String]): DataFrame
    ```
    which accepts a map in which:
    - keys are parameter names,
    - values are SQL literal values.

    The first argument, `sqlText`, may contain named parameters in the positions of constants such as literal values. For example:
    ```scala
    spark.sql(
      sqlText = "SELECT * FROM tbl WHERE date > :startDate LIMIT :maxRows",
      args = Map(
        "startDate" -> "DATE'2022-12-01'",
        "maxRows" -> "100"))
    ```
    The new `sql()` method parses the input SQL statement together with the provided parameter values, and replaces the named parameters with the literal values. It then eagerly runs DDL/DML commands, but not SELECT queries.

    Closes #38712

    ### Why are the changes needed?
    1. To improve the user experience with Spark SQL when:
       - Using Spark as a remote service (microservice).
       - Writing SQL code that powers reports, dashboards, charts and other data presentation solutions that need to account for criteria modifiable by users through an interface.
       - Building a generic integration layer based on the SQL API. The goal is to expose managed data to a wide application ecosystem with a microservice architecture. It is only natural in such a setup to ask for modular and reusable SQL code that can be executed repeatedly with different parameter values.
    2. To achieve feature parity with other systems that support named parameters:
       - Redshift: https://docs.aws.amazon.com/redshift/latest/mgmt/data-api.html#data-api-calling
       - BigQuery: https://cloud.google.com/bigquery/docs/parameterized-queries#api
       - MS DBSQL: https://learn.microsoft.com/en-us/azure/databricks/sql/user/queries/query-parameters

    ### Does this PR introduce _any_ user-facing change?
    No, this is an extension of the existing APIs.

    ### How was this patch tested?
    By running the new tests:
    ```
    $ build/sbt "core/testOnly *SparkThrowableSuite"
    $ build/sbt "test:testOnly *PlanParserSuite"
    $ build/sbt "test:testOnly *AnalysisSuite"
    $ build/sbt "test:testOnly *ParametersSuite"
    ```
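    For illustration, a minimal, self-contained sketch of the new API in action follows. It is hypothetical usage, not part of this patch: the temp view `tbl`, its rows, and the app name are made up.
    ```scala
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("parameterized-sql-sketch")
      .getOrCreate()

    // Hypothetical data to query.
    spark.sql(
      "CREATE OR REPLACE TEMP VIEW tbl AS SELECT * FROM VALUES " +
        "(DATE'2022-11-30', 1), (DATE'2022-12-02', 2), (DATE'2022-12-03', 3) AS t(date, id)")

    // The args values are SQL literal text; they are parsed and substituted for the
    // named parameters before analysis, so only rows after 2022-12-01 are returned.
    spark.sql(
      sqlText = "SELECT * FROM tbl WHERE date > :startDate LIMIT :maxRows",
      args = Map("startDate" -> "DATE'2022-12-01'", "maxRows" -> "100"))
      .show()

    spark.stop()
    ```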
    Closes #38864 from MaxGekk/parameterized-sql-2.

    Lead-authored-by: Max Gekk <max.g...@gmail.com>
    Co-authored-by: Maxim Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   | 10 +++
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  1 +
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  5 ++
 .../sql/catalyst/expressions/parameters.scala      | 64 ++++++++++++++++++
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  7 ++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 14 ++++
 .../sql/catalyst/parser/PlanParserSuite.scala      | 26 ++++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  | 40 +++++++++--
 .../org/apache/spark/sql/ParametersSuite.scala     | 78 ++++++++++++++++++++++
 .../org/apache/spark/sql/test/SQLTestUtils.scala   |  2 +-
 .../benchmark/InsertIntoHiveTableBenchmark.scala   |  4 +-
 .../ObjectHashAggregateExecBenchmark.scala         |  4 +-
 13 files changed, 246 insertions(+), 10 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index f66d6998e26..b7bf07a0e48 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -806,6 +806,11 @@
       }
     }
   },
+  "INVALID_SQL_ARG" : {
+    "message" : [
+      "The argument <name> of `sql()` is invalid. Consider replacing it with a SQL literal."
+    ]
+  },
   "INVALID_SQL_SYNTAX" : {
     "message" : [
       "Invalid SQL syntax: <inputString>"
@@ -1147,6 +1152,11 @@
       "Unable to convert SQL type <toType> to Protobuf type <protobufType>."
     ]
   },
+  "UNBOUND_SQL_PARAMETER" : {
+    "message" : [
+      "Found the unbound parameter: <name>. Please fix `args` and provide a mapping of the parameter to a SQL literal."
+    ]
+  },
   "UNCLOSED_BRACKETED_COMMENT" : {
     "message" : [
       "Found an unclosed bracketed comment. Please, append */ at the end of the comment."
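For context, the new error classes surface as `AnalysisException`s carrying the error class name. A hypothetical sketch, assuming an active `SparkSession` named `spark` and the two-argument `sql()` introduced below:
```scala
import org.apache.spark.sql.AnalysisException

try {
  // No value is supplied for :abc, so analysis fails with UNBOUND_SQL_PARAMETER.
  spark.sql("SELECT :abc", Map.empty[String, String]).collect()
} catch {
  case e: AnalysisException =>
    println(e.getErrorClass) // prints: UNBOUND_SQL_PARAMETER
}
```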
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 21747a0a021..078a9939116 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -930,6 +930,7 @@ primaryExpression
 constant
     : NULL #nullLiteral
+    | COLON identifier #parameterLiteral
     | interval #intervalLiteral
     | identifier stringLit #typeConstructor
    | number #numericLiteral

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 5303364710c..11b2d6671c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -325,6 +325,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             errorClass = "_LEGACY_ERROR_TEMP_2413",
             messageParameters = Map("argName" -> e.prettyName))

+        case p: Parameter =>
+          p.failAnalysis(
+            errorClass = "UNBOUND_SQL_PARAMETER",
+            messageParameters = Map("name" -> toSQLId(p.name)))
+
         case _ =>
       })

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
new file mode 100644
index 00000000000..fae2b9a1a9f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.DataType
+
+/**
+ * An expression that represents a named parameter, to be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class Parameter(name: String) extends LeafExpression with Unevaluable {
+  override lazy val resolved: Boolean = false
+
+  private def unboundError(methodName: String): Nothing = {
+    throw SparkException.internalError(
+      s"Cannot call `$methodName()` on the unbound parameter `$name`.")
+  }
+  override def dataType: DataType = unboundError("dataType")
+  override def nullable: Boolean = unboundError("nullable")
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+}
+
+
+/**
+ * Finds all named parameters in the given plan and substitutes them with literal values
+ * from `args`.
+ */
+object Parameter extends QueryErrorsBase {
+  def bind(plan: LogicalPlan, args: Map[String, Expression]): LogicalPlan = {
+    if (args.nonEmpty) {
+      args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+        expr.failAnalysis(
+          errorClass = "INVALID_SQL_ARG",
+          messageParameters = Map("name" -> toSQLId(name)))
+      }
+      plan.transformAllExpressionsWithPruning(_.containsPattern(PARAMETER)) {
+        case Parameter(name) if args.contains(name) => args(name)
+      }
+    } else {
+      plan
+    }
+  }
+}

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index da25702e1f2..545d5d97d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -4837,4 +4837,11 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
   override def visitTimestampdiff(ctx: TimestampdiffContext): Expression = withOrigin(ctx) {
     TimestampDiff(ctx.unit.getText, expression(ctx.startTimestamp), expression(ctx.endTimestamp))
   }
+
+  /**
+   * Create a named parameter that represents a literal with an unbound value and an unknown type.
+   */
+  override def visitParameterLiteral(ctx: ParameterLiteralContext): Expression = withOrigin(ctx) {
+    Parameter(ctx.identifier().getText)
+  }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 1a8ad7c7d62..ab3f8726815 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -71,6 +71,7 @@ object TreePattern extends Enumeration {
   val NULL_LITERAL: Value = Value
   val SERIALIZE_FROM_OBJECT: Value = Value
   val OUTER_REFERENCE: Value = Value
+  val PARAMETER: Value = Value
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val PYTHON_UDF: Value = Value

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 8b303ec3bb1..b7cb7fa59ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1295,4 +1295,18 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     assertAnalysisSuccess(finalPlan)
   }
+
+  test("SPARK-41271: bind named parameters to literals") {
+    comparePlans(
+      Parameter.bind(
+        plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
+        args = Map("limitA" -> Literal(10))),
+      parsePlan("SELECT * FROM a LIMIT 10"))
+    // Ignore unused arguments
+    comparePlans(
+      Parameter.bind(
+        plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
+        args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
+      parsePlan("SELECT c FROM a WHERE c < 20"))
+  }
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 11590e465c2..035e6231178 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -1568,4 +1568,30 @@ class PlanParserSuite extends AnalysisTest {
       .toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
     )
   }
+
+  test("SPARK-41271: parsing of named parameters") {
+    comparePlans(
+      parsePlan("SELECT :param_1"),
+      Project(UnresolvedAlias(Parameter("param_1"), None) :: Nil, OneRowRelation()))
+    comparePlans(
+      parsePlan("SELECT abs(:1Abc)"),
+      Project(UnresolvedAlias(
+        UnresolvedFunction(
+          "abs" :: Nil,
+          Parameter("1Abc") :: Nil,
+          isDistinct = false), None) :: Nil,
+        OneRowRelation()))
+    comparePlans(
+      parsePlan("SELECT * FROM a LIMIT :limitA"),
+      table("a").select(star()).limit(Parameter("limitA")))
+    // Invalid empty name and invalid symbol in a name
+    checkError(
+      exception = parseException("SELECT :-"),
+      errorClass = "PARSE_SYNTAX_ERROR",
+      parameters = Map("error" -> "'-'", "hint" -> ""))
+    checkError(
+      exception = parseException("SELECT :"),
+      errorClass = "PARSE_SYNTAX_ERROR",
+      parameters = Map("error" -> "end of input", "hint" -> ""))
+  }
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3d9f1679957..adbe593ac56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -609,19 +609,49 @@ class SparkSession private(
    * ----------------- */

   /**
-   * Executes a SQL query using Spark, returning the result as a `DataFrame`.
+   * Executes a SQL query, substituting named parameters with the given arguments,
+   * and returns the result as a `DataFrame`.
    * This API eagerly runs DDL/DML commands, but not for SELECT queries.
    *
-   * @since 2.0.0
+   * @param sqlText A SQL statement with named parameters to execute.
+   * @param args A map of parameter names to literal values.
+   *
+   * @since 3.4.0
    */
-  def sql(sqlText: String): DataFrame = withActive {
+  @Experimental
+  def sql(sqlText: String, args: Map[String, String]): DataFrame = withActive {
     val tracker = new QueryPlanningTracker
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
-      sessionState.sqlParser.parsePlan(sqlText)
+      val parser = sessionState.sqlParser
+      val parsedArgs = args.mapValues(parser.parseExpression).toMap
+      Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
     }
     Dataset.ofRows(self, plan, tracker)
   }

+  /**
+   * Executes a SQL query, substituting named parameters with the given arguments,
+   * and returns the result as a `DataFrame`.
+   * This API eagerly runs DDL/DML commands, but not for SELECT queries.
+   *
+   * @param sqlText A SQL statement with named parameters to execute.
+   * @param args A map of parameter names to literal values.
+   *
+   * @since 3.4.0
+   */
+  @Experimental
+  def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = {
+    sql(sqlText, args.asScala.toMap)
+  }
+
+  /**
+   * Executes a SQL query using Spark, returning the result as a `DataFrame`.
+   * This API eagerly runs DDL/DML commands, but not for SELECT queries.
+   *
+   * @since 2.0.0
+   */
+  def sql(sqlText: String): DataFrame = sql(sqlText, Map.empty[String, String])
+
   /**
    * Execute an arbitrary string command inside an external execution engine rather than Spark.
    * This could be useful when user wants to execute some commands out of Spark. For

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
new file mode 100644
index 00000000000..668a1e4ad7d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ParametersSuite extends QueryTest with SharedSparkSession {
+
+  test("bind parameters") {
+    val sqlText =
+      """
+        |SELECT id, id % :div as c0
+        |FROM VALUES (0), (1), (2), (3), (4), (5), (6), (7), (8), (9) AS t(id)
+        |WHERE id < :constA
+        |""".stripMargin
+    val args = Map("div" -> "3", "constA" -> "4L")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(0, 0) :: Row(1, 1) :: Row(2, 2) :: Row(3, 0) :: Nil)
+
+    checkAnswer(
+      spark.sql("""SELECT contains('Spark \'SQL\'', :subStr)""", Map("subStr" -> "'SQL'")),
+      Row(true))
+  }
+
+  test("non-substituted parameters") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select :abc, :def", Map("abc" -> "1"))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "`def`"),
+      context = ExpectedContext(
+        fragment = ":def",
+        start = 13,
+        stop = 16))
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select :abc").collect()
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "`abc`"),
+      context = ExpectedContext(
+        fragment = ":abc",
+        start = 7,
+        stop = 10))
+  }
+
+  test("non-literal argument of `sql()`") {
+    Seq("col1 + 1", "CAST('100' AS INT)", "map('a', 1, 'b', 2)", "array(1)").foreach { arg =>
+      checkError(
+        exception = intercept[AnalysisException] {
+          spark.sql("SELECT :param1 FROM VALUES (1) AS t(col1)", Map("param1" -> arg))
+        },
+        errorClass = "INVALID_SQL_ARG",
+        parameters = Map("name" -> "`param1`"),
+        context = ExpectedContext(
+          fragment = arg,
+          start = 0,
+          stop = arg.length - 1))
+    }
+  }
+}

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index ae425419c54..dd55fcfe42c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -229,7 +229,7 @@ private[sql] trait SQLTestUtilsBase
   protected def sparkContext = spark.sparkContext

   // Shorthand for running a query using our SQLContext
-  protected lazy val sql = spark.sql _
+  protected lazy val sql: String => DataFrame = spark.sql _

   /**
    * A helper object for importing SQL implicits.
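A note on why this shorthand, and the two benchmark shorthands below, now need an explicit type: once `sql` is overloaded, eta-expansion of `spark.sql _` no longer resolves to a unique method, so an expected function type must select the overload. A minimal standalone sketch of the ambiguity, using illustrative stand-in definitions rather than the real SparkSession API:
```scala
object EtaExpansionSketch {
  // Stand-ins for the unary and binary `sql` overloads (illustrative only).
  def sql(sqlText: String): String = s"ran: $sqlText"
  def sql(sqlText: String, args: Map[String, String]): String = s"ran: $sqlText with $args"

  // val shorthand = sql _  // no longer compiles: ambiguous overloaded reference

  // An explicit function type picks the unary overload, mirroring the patch:
  val shorthand: String => String = sql _

  def main(args: Array[String]): Unit =
    println(shorthand("SELECT 1")) // prints: ran: SELECT 1
}
```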
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala
index 76345985698..b64b7823acd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.benchmark

 import org.apache.spark.benchmark.Benchmark
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.hive.test.TestHive

 /**
@@ -40,7 +40,7 @@ object InsertIntoHiveTableBenchmark extends SqlBasedBenchmark {
   val tempView = "temp"
   val numRows = 1024 * 10

-  val sql = spark.sql _
+  val sql: String => DataFrame = spark.sql _

   // scalastyle:off hadoopconfiguration
   private val hadoopConf = spark.sparkContext.hadoopConfiguration

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
index 5d0a5ce0957..1a4700e7445 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
@@ -22,7 +22,7 @@ import scala.concurrent.duration._

 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox

 import org.apache.spark.benchmark.Benchmark
-import org.apache.spark.sql.{Column, SparkSession}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
 import org.apache.spark.sql.hive.execution.TestingTypedCount
@@ -46,7 +46,7 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark {

   override def getSparkSession: SparkSession = TestHive.sparkSession

-  private val sql = spark.sql _
+  private val sql: String => DataFrame = spark.sql _

   import spark.implicits._

   private def hiveUDAFvsSparkAF(N: Int): Unit = {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org