This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1b4048bf62d [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()` 1b4048bf62d is described below commit 1b4048bf62dddae7d324c4b12aa409a1bd456dc5 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Thu Jun 22 09:40:30 2023 +0300 [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()` ### What changes were proposed in this pull request? In the PR, I propose to extend SparkSession API and overload the `sql` method by: ```scala def sql(sqlText: String, args: Array[_]): DataFrame ``` which accepts an array of Java/Scala objects that can be converted to SQL literal expressions. And the first argument `sqlText` might have positional parameters in the positions of constants like literal values. A value can also be a `Column` of literal expression, in that case it is taken as is. For example: ```scala spark.sql( sqlText = "SELECT * FROM tbl WHERE date > ? LIMIT ?", args = Array(LocalDate.of(2023, 6, 15), 100)) ``` The new `sql()` method parses the input SQL statement and replaces the positional parameters by the literal values. ### Why are the changes needed? 1. To conform to the SQL standard and JDBC/ODBC protocol. 2. To improve user experience with Spark SQL via - Using Spark as remote service (microservice). - Write SQL code that will power reports, dashboards, charts and other data presentation solutions that need to account for criteria modifiable by users through an interface. - Build a generic integration layer based on the SQL API. The goal is to expose managed data to a wide application ecosystem with a microservice architecture. It is only natural in such a setup to ask for modular and reusable SQL code, that can be executed repeatedly with different parameter values. 3. To achieve feature parity with other systems that support positional parameters. ### Does this PR introduce _any_ user-facing change? No, the changes extend the existing API. 
### How was this patch tested? By running new tests: ``` $ build/sbt "test:testOnly *AnalysisSuite" $ build/sbt "test:testOnly *PlanParserSuite" $ build/sbt "test:testOnly *ParametersSuite" ``` and the affected test suites: ``` $ build/sbt "sql/testOnly *QueryExecutionErrorsSuite" ``` Closes #41568 from MaxGekk/parametrized-query-pos-param. Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../CheckConnectJvmClientCompatibility.scala | 2 + .../sql/connect/planner/SparkConnectPlanner.scala | 4 +- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 + .../spark/sql/catalyst/parser/SqlBaseParser.g4 | 5 +- .../spark/sql/catalyst/analysis/parameters.scala | 95 ++++++-- .../spark/sql/catalyst/parser/AstBuilder.scala | 14 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 22 +- .../sql/catalyst/parser/PlanParserSuite.scala | 25 +- .../scala/org/apache/spark/sql/SparkSession.scala | 34 ++- .../apache/spark/sql/JavaSparkSessionSuite.java | 28 +++ .../org/apache/spark/sql/ParametersSuite.scala | 265 +++++++++++++++++++-- .../sql/errors/QueryExecutionErrorsSuite.scala | 10 +- 12 files changed, 448 insertions(+), 57 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6b648fd152b..acc469672b4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"), + // TODO(SPARK-44068): Support positional parameters in Scala connect client + ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"), // RuntimeConfig ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"), diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6ee252d1a58..856d0f06ba4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -41,7 +41,7 @@ import org.apache.spark.ml.{functions => MLFunctions} import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, Row, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} @@ -253,7 +253,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) 
extends Logging { val parser = session.sessionState.sqlParser val parsedPlan = parser.parsePlan(sql.getQuery) if (!args.isEmpty) { - ParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap) + NameParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap) } else { parsedPlan } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index ecd5f5912fd..6c9b3a71266 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -445,6 +445,7 @@ COLON: ':'; ARROW: '->'; HENT_START: '/*+'; HENT_END: '*/'; +QUESTION: '?'; STRING_LITERAL : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 240310a426d..d1e672e9472 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -952,9 +952,10 @@ literalType constant : NULL #nullLiteral - | COLON identifier #parameterLiteral + | QUESTION #posParameterLiteral + | COLON identifier #namedParameterLiteral | interval #intervalLiteral - | literalType stringLit #typeConstructor + | literalType stringLit #typeConstructor | number #numericLiteral | booleanValue #booleanLiteral | stringLit+ #stringLiteral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index a00f9cec92c..2e3cabce24a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -25,12 +25,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types.DataType -/** - * The expression represents a named parameter that should be replaced by a literal. - * - * @param name The identifier of the parameter without the marker. - */ -case class Parameter(name: String) extends LeafExpression with Unevaluable { +sealed trait Parameter extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false private def unboundError(methodName: String): Nothing = { @@ -41,17 +36,56 @@ case class Parameter(name: String) extends LeafExpression with Unevaluable { override def nullable: Boolean = unboundError("nullable") final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER) + + def name: String +} + +/** + * The expression represents a named parameter that should be replaced by a literal. + * + * @param name The identifier of the parameter without the marker. + */ +case class NamedParameter(name: String) extends Parameter + +/** + * The expression represents a positional parameter that should be replaced by a literal. + * + * @param pos An unique position of the parameter in a SQL query text. + */ +case class PosParameter(pos: Int) extends Parameter { + override def name: String = s"_$pos" } /** * The logical plan representing a parameterized query. It will be removed during analysis after * the parameters are bind. */ -case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) - extends UnresolvedUnaryNode { +abstract class ParameterizedQuery(child: LogicalPlan) extends UnresolvedUnaryNode { + final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY) +} +/** + * The logical plan representing a parameterized query with named parameters. + * + * @param child The parameterized logical plan. 
+ * @param args The map of parameter names to its literal values. + */ +case class NameParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) + extends ParameterizedQuery(child) { + assert(args.nonEmpty) + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(child = newChild) +} + +/** + * The logical plan representing a parameterized query with positional parameters. + * + * @param child The parameterized logical plan. + * @param args The literal values of positional parameters. + */ +case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression]) + extends ParameterizedQuery(child) { assert(args.nonEmpty) - final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY) override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(child = newChild) } @@ -61,6 +95,20 @@ case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) * user-specified arguments. */ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { + private def checkArgs(args: Iterable[(String, Expression)]): Unit = { + args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) => + expr.failAnalysis( + errorClass = "INVALID_SQL_ARG", + messageParameters = Map("name" -> name)) + } + } + + private def bind(p: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = { + p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) (f orElse { + case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan)(f)) + }) + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (plan.containsPattern(PARAMETERIZED_QUERY)) { // One unresolved plan can have at most one ParameterizedQuery. 
@@ -71,23 +119,22 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) { // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE // relations are not children of `UnresolvedWith`. - case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) => - args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) => - expr.failAnalysis( - errorClass = "INVALID_SQL_ARG", - messageParameters = Map("name" -> name)) - } + case p @ NameParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) => + checkArgs(args) + bind(child) { case NamedParameter(name) if args.contains(name) => args(name) } + + case p @ PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) => + val indexedArgs = args.zipWithIndex + checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1))) + + val positions = scala.collection.mutable.Set.empty[Int] + bind(child) { case p @ PosParameter(pos) => positions.add(pos); p } + val posToIndex = positions.toSeq.sorted.zipWithIndex.toMap - def bind(p: LogicalPlan): LogicalPlan = { - p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) { - case Parameter(name) if args.contains(name) => - args(name) - case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan)) - } + bind(child) { + case PosParameter(pos) if posToIndex.contains(pos) && args.size > posToIndex(pos) => + args(posToIndex(pos)) } - val res = bind(child) - res.copyTagsFrom(p) - res case _ => plan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 07721424a86..ca62de12e7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5113,7 +5113,17 
@@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit /** * Create a named parameter which represents a literal with a non-bound value and unknown type. * */ - override def visitParameterLiteral(ctx: ParameterLiteralContext): Expression = withOrigin(ctx) { - Parameter(ctx.identifier().getText) + override def visitNamedParameterLiteral( + ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) { + NamedParameter(ctx.identifier().getText) + } + + /** + * Create a positional parameter which represents a literal + * with a non-bound value and unknown type. + * */ + override def visitPosParameterLiteral( + ctx: PosParameterLiteralContext): Expression = withOrigin(ctx) { + PosParameter(ctx.QUESTION().getSymbol.getStartIndex) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1e844e22bec..dae42453f0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1370,7 +1370,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-41271: bind named parameters to literals") { CTERelationDef.curId.set(0) - val actual1 = ParameterizedQuery( + val actual1 = NameParameterizedQuery( child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"), args = Map("limitA" -> Literal(10))).analyze CTERelationDef.curId.set(0) @@ -1378,7 +1378,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { comparePlans(actual1, expected1) // Ignore unused arguments CTERelationDef.curId.set(0) - val actual2 = ParameterizedQuery( + val actual2 = NameParameterizedQuery( child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"), args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze CTERelationDef.curId.set(0) @@ 
-1386,6 +1386,24 @@ class AnalysisSuite extends AnalysisTest with Matchers { comparePlans(actual2, expected2) } + test("SPARK-44066: bind positional parameters to literals") { + CTERelationDef.curId.set(0) + val actual1 = PosParameterizedQuery( + child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT ?"), + args = Array(Literal(10))).analyze + CTERelationDef.curId.set(0) + val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze + comparePlans(actual1, expected1) + // Ignore unused arguments + CTERelationDef.curId.set(0) + val actual2 = PosParameterizedQuery( + child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < ?"), + args = Array(Literal(20), Literal(10))).analyze + CTERelationDef.curId.set(0) + val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze + comparePlans(actual2, expected2) + } + test("SPARK-41489: type of filter expression should be a bool") { assertAnalysisErrorClass(parsePlan( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5a28ef847dc..ded8aaf7430 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkThrowable import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedParameter, PosParameter, 
RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc} import org.apache.spark.sql.catalyst.plans._ @@ -1630,18 +1630,18 @@ class PlanParserSuite extends AnalysisTest { test("SPARK-41271: parsing of named parameters") { comparePlans( parsePlan("SELECT :param_1"), - Project(UnresolvedAlias(Parameter("param_1"), None) :: Nil, OneRowRelation())) + Project(UnresolvedAlias(NamedParameter("param_1"), None) :: Nil, OneRowRelation())) comparePlans( parsePlan("SELECT abs(:1Abc)"), Project(UnresolvedAlias( UnresolvedFunction( "abs" :: Nil, - Parameter("1Abc") :: Nil, + NamedParameter("1Abc") :: Nil, isDistinct = false), None) :: Nil, OneRowRelation())) comparePlans( parsePlan("SELECT * FROM a LIMIT :limitA"), - table("a").select(star()).limit(Parameter("limitA"))) + table("a").select(star()).limit(NamedParameter("limitA"))) // Invalid empty name and invalid symbol in a name checkError( exception = parseException(s"SELECT :-"), @@ -1661,4 +1661,21 @@ class PlanParserSuite extends AnalysisTest { Seq(Literal("abc")) :: Nil).as("tbl").select($"interval") ) } + + test("SPARK-44066: parsing of positional parameters") { + comparePlans( + parsePlan("SELECT ?"), + Project(UnresolvedAlias(PosParameter(7), None) :: Nil, OneRowRelation())) + comparePlans( + parsePlan("SELECT abs(?)"), + Project(UnresolvedAlias( + UnresolvedFunction( + "abs" :: Nil, + PosParameter(11) :: Nil, + isDistinct = false), None) :: Nil, + OneRowRelation())) + comparePlans( + parsePlan("SELECT * FROM a LIMIT ?"), + table("a").select(star()).limit(PosParameter(22))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 642006fb8dc..2a1c2474bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} @@ -609,6 +609,36 @@ class SparkSession private( | Everything else | * ----------------- */ + /** + * Executes a SQL query substituting positional parameters by the given arguments, + * returning the result as a `DataFrame`. + * This API eagerly runs DDL/DML commands, but not for SELECT queries. + * + * @param sqlText A SQL statement with positional parameters to execute. + * @param args An array of Java/Scala objects that can be converted to + * SQL literal expressions. See + * <a href="https://spark.apache.org/docs/latest/sql-ref-datatypes.html"> + * Supported Data Types</a> for supported value types in Scala/Java. + * For example, 1, "Steven", LocalDate.of(2023, 4, 2). + * A value can be also a `Column` of literal expression, in that case + * it is taken as is. 
+ * + * @since 3.5.0 + */ + @Experimental + def sql(sqlText: String, args: Array[_]): DataFrame = withActive { + val tracker = new QueryPlanningTracker + val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { + val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) + if (args.nonEmpty) { + PosParameterizedQuery(parsedPlan, args.map(lit(_).expr)) + } else { + parsedPlan + } + } + Dataset.ofRows(self, plan, tracker) + } + /** * Executes a SQL query substituting named parameters by the given arguments, * returning the result as a `DataFrame`. @@ -632,7 +662,7 @@ class SparkSession private( val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) if (args.nonEmpty) { - ParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap) + NameParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap) } else { parsedPlan } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java index b1df377936d..0d6d773d930 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java @@ -18,11 +18,13 @@ package test.org.apache.spark.sql; import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSparkSession; import org.junit.After; import org.junit.Assert; import org.junit.Test; import java.util.HashMap; +import java.util.List; import java.util.Map; public class JavaSparkSessionSuite { @@ -54,4 +56,30 @@ public class JavaSparkSessionSuite { Assert.assertEquals(spark.conf().get(e.getKey()), e.getValue().toString()); } } + + @Test + public void testPositionalParameters() { + spark = new TestSparkSession(); + + int[] emptyArgs = {}; + List<Row> collected1 = spark.sql("select 'abc'", emptyArgs).collectAsList(); + Assert.assertEquals("abc", collected1.get(0).getString(0)); + + 
Object[] singleArg = new String[] { "abc" }; + List<Row> collected2 = spark.sql("select ?", singleArg).collectAsList(); + Assert.assertEquals("abc", collected2.get(0).getString(0)); + + int[] args = new int[] { 1, 2, 3 }; + List<Row> collected3 = spark.sql("select ?, ?, ?", args).collectAsList(); + Row r0 = collected3.get(0); + Assert.assertEquals(1, r0.getInt(0)); + Assert.assertEquals(2, r0.getInt(1)); + Assert.assertEquals(3, r0.getInt(2)); + + Object[] mixedArgs = new Object[] { 1, "abc" }; + List<Row> collected4 = spark.sql("select ?, ?", mixedArgs).collectAsList(); + Row r1 = collected4.get(0); + Assert.assertEquals(1, r1.getInt(0)); + Assert.assertEquals("abc", r1.getString(1)); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 985d0373c4f..725956e259b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.test.SharedSparkSession class ParametersSuite extends QueryTest with SharedSparkSession { - test("bind parameters") { + test("bind named parameters") { val sqlText = """ |SELECT id, id % :div as c0 @@ -42,6 +42,23 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(true)) } + test("bind positional parameters") { + val sqlText = + """ + |SELECT id, id % ? as c0 + |FROM VALUES (0), (1), (2), (3), (4), (5), (6), (7), (8), (9) AS t(id) + |WHERE id < ? 
+ |""".stripMargin + val args = Array(3, 4L) + checkAnswer( + spark.sql(sqlText, args), + Row(0, 0) :: Row(1, 1) :: Row(2, 2) :: Row(3, 0) :: Nil) + + checkAnswer( + spark.sql("""SELECT contains('Spark \'SQL\'', ?)""", Array("SQL")), + Row(true)) + } + test("parameter binding is case sensitive") { checkAnswer( spark.sql("SELECT :p, :P", Map("p" -> 1, "P" -> 2)), @@ -60,7 +77,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession { stop = 8)) } - test("parameters in CTE") { + test("named parameters in CTE") { val sqlText = """ |WITH w1 AS (SELECT :p1 AS p) @@ -72,7 +89,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(3)) } - test("parameters in nested CTE") { + test("positional parameters in CTE") { + val sqlText = + """ + |WITH w1 AS (SELECT ? AS p) + |SELECT p + ? FROM w1 + |""".stripMargin + val args = Array(1, 2) + checkAnswer( + spark.sql(sqlText, args), + Row(3)) + } + + test("named parameters in nested CTE") { val sqlText = """ |WITH w1 AS @@ -85,7 +114,20 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(6)) } - test("parameters in subquery expression") { + test("positional parameters in nested CTE") { + val sqlText = + """ + |WITH w1 AS + | (WITH w2 AS (SELECT ? AS p) SELECT p + ? AS p2 FROM w2) + |SELECT p2 + ? FROM w1 + |""".stripMargin + val args = Array(1, 2, 3) + checkAnswer( + spark.sql(sqlText, args), + Row(6)) + } + + test("named parameters in subquery expression") { val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2" val args = Map("p1" -> 1, "p2" -> 2) checkAnswer( @@ -93,7 +135,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(12)) } - test("parameters in nested subquery expression") { + test("positional parameters in subquery expression") { + val sqlText = "SELECT (SELECT max(id) + ? FROM range(10)) + ?" 
+ val args = Array(1, 2) + checkAnswer( + spark.sql(sqlText, args), + Row(12)) + } + + test("named parameters in nested subquery expression") { val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3" val args = Map("p1" -> 1, "p2" -> 2, "p3" -> 3) checkAnswer( @@ -101,7 +151,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(15)) } - test("parameters in subquery expression inside CTE") { + test("positional parameters in nested subquery expression") { + val sqlText = "SELECT (SELECT (SELECT max(id) + ? FROM range(10)) + ?) + ?" + val args = Array(1, 2, 3) + checkAnswer( + spark.sql(sqlText, args), + Row(15)) + } + + test("named parameters in subquery expression inside CTE") { val sqlText = """ |WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p) @@ -113,7 +171,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(15)) } - test("parameter in identifier clause") { + test("positional parameters in subquery expression inside CTE") { + val sqlText = + """ + |WITH w1 AS (SELECT (SELECT max(id) + ? FROM range(10)) + ? AS p) + |SELECT p + ? FROM w1 + |""".stripMargin + val args = Array(1, 2, 3) + checkAnswer( + spark.sql(sqlText, args), + Row(15)) + } + + test("named parameter in identifier clause") { val sqlText = "SELECT IDENTIFIER('T.' || :p1 || '1') FROM VALUES(1) T(c1)" val args = Map("p1" -> "c") @@ -122,7 +192,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession { Row(1)) } - test("parameter in identifier clause in DDL and utility commands") { + test("positional parameter in identifier clause") { + val sqlText = + "SELECT IDENTIFIER('T.' || ? 
|| '1') FROM VALUES(1) T(c1)" + val args = Array("c") + checkAnswer( + spark.sql(sqlText, args), + Row(1)) + } + + test("named parameter in identifier clause in DDL and utility commands") { spark.sql("CREATE VIEW IDENTIFIER(:p1)(c1) AS SELECT 1", args = Map("p1" -> "v")) spark.sql("ALTER VIEW IDENTIFIER(:p1) AS SELECT 2 AS c1", args = Map("p1" -> "v")) checkAnswer( @@ -131,7 +210,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession { spark.sql("DROP VIEW IDENTIFIER(:p1)", args = Map("p1" -> "v")) } - test("parameters in INSERT") { + test("positional parameter in identifier clause in DDL and utility commands") { + spark.sql("CREATE VIEW IDENTIFIER(?)(c1) AS SELECT 1", args = Array("v")) + spark.sql("ALTER VIEW IDENTIFIER(?) AS SELECT 2 AS c1", args = Array("v")) + checkAnswer( + spark.sql("SHOW COLUMNS FROM IDENTIFIER(?)", args = Array("v")), + Row("c1")) + spark.sql("DROP VIEW IDENTIFIER(?)", args = Array("v")) + } + + test("named parameters in INSERT") { withTable("t") { sql("CREATE TABLE t (col INT) USING json") spark.sql("INSERT INTO t SELECT :p", Map("p" -> 1)) @@ -139,7 +227,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession { } } - test("parameters not allowed in view body ") { + test("positional parameters in INSERT") { + withTable("t") { + sql("CREATE TABLE t (col INT) USING json") + spark.sql("INSERT INTO t SELECT ?", Array(1)) + checkAnswer(spark.table("t"), Row(1)) + } + } + + test("named parameters not allowed in view body ") { val sqlText = "CREATE VIEW v AS SELECT :p AS p" val args = Map("p" -> 1) checkError( @@ -154,7 +250,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession { stop = sqlText.length - 1)) } - test("parameters not allowed in view body - WITH and scalar subquery") { + test("positional parameters not allowed in view body ") { + val sqlText = "CREATE VIEW v AS SELECT ? 
AS p" + val args = Array(1) + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlText, args) + }, + errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "CREATE VIEW body"), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = sqlText.length - 1)) + } + + test("named parameters not allowed in view body - WITH and scalar subquery") { val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT :p) AS a) SELECT a FROM cte" val args = Map("p" -> 1) checkError( @@ -169,7 +280,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession { stop = sqlText.length - 1)) } - test("parameters not allowed in view body - nested WITH and EXIST") { + test("positional parameters not allowed in view body - WITH and scalar subquery") { + val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT ?) AS a) SELECT a FROM cte" + val args = Array(1) + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlText, args) + }, + errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "CREATE VIEW body"), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = sqlText.length - 1)) + } + + test("named parameters not allowed in view body - nested WITH and EXIST") { val sqlText = """CREATE VIEW v AS |SELECT a as a @@ -188,7 +314,26 @@ class ParametersSuite extends QueryTest with SharedSparkSession { stop = sqlText.length - 1)) } - test("non-substituted parameters") { + test("positional parameters not allowed in view body - nested WITH and EXIST") { + val sqlText = + """CREATE VIEW v AS + |SELECT a as a + |FROM (WITH cte(a) AS (SELECT CASE WHEN EXISTS(SELECT ?) 
THEN 1 END AS a) + |SELECT a FROM cte)""".stripMargin + val args = Array(1) + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlText, args) + }, + errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "CREATE VIEW body"), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = sqlText.length - 1)) + } + + test("non-substituted named parameters") { checkError( exception = intercept[AnalysisException] { spark.sql("select :abc, :def", Map("abc" -> 1)) @@ -211,7 +356,30 @@ class ParametersSuite extends QueryTest with SharedSparkSession { stop = 10)) } - test("literal argument of `sql()`") { + test("non-substituted positional parameters") { + checkError( + exception = intercept[AnalysisException] { + spark.sql("select ?, ?", Array(1)) + }, + errorClass = "UNBOUND_SQL_PARAMETER", + parameters = Map("name" -> "_10"), + context = ExpectedContext( + fragment = "?", + start = 10, + stop = 10)) + checkError( + exception = intercept[AnalysisException] { + sql("select ?").collect() + }, + errorClass = "UNBOUND_SQL_PARAMETER", + parameters = Map("name" -> "_7"), + context = ExpectedContext( + fragment = "?", + start = 7, + stop = 7)) + } + + test("literal argument of named parameter in `sql()`") { val sqlText = """SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s) |WHERE s = :player_name""".stripMargin @@ -249,4 +417,73 @@ class ParametersSuite extends QueryTest with SharedSparkSession { .toInstant) :: Nil) } } + + test("literal argument of positional parameter in `sql()`") { + val sqlText = + """SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s) + |WHERE s = ?""".stripMargin + checkAnswer( + spark.sql(sqlText, args = Array(lit("E'Twaun Moore"))), + Row("E'Twaun Moore") :: Nil) + checkAnswer( + spark.sql(sqlText, args = Array(lit("Vander Blue--comment"))), + Nil) + checkAnswer( + spark.sql(sqlText, args = 
Array(lit("Jeff /*__*/ Green"))), + Row("Jeff /*__*/ Green") :: Nil) + + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { + checkAnswer( + spark.sql( + sqlText = """ + |SELECT d + |FROM VALUES (DATE'1970-01-01'), (DATE'2023-12-31') AS t(d) + |WHERE d < ? + |""".stripMargin, + args = Array(lit(LocalDate.of(2023, 4, 1)))), + Row(LocalDate.of(1970, 1, 1)) :: Nil) + checkAnswer( + spark.sql( + sqlText = """ + |SELECT d + |FROM VALUES (TIMESTAMP_LTZ'1970-01-01 01:02:03 Europe/Amsterdam'), + | (TIMESTAMP_LTZ'2023-12-31 04:05:06 America/Los_Angeles') AS t(d) + |WHERE d < ? + |""".stripMargin, + args = Array(lit(Instant.parse("2023-04-01T00:00:00Z")))), + Row(LocalDateTime.of(1970, 1, 1, 1, 2, 3) + .atZone(ZoneId.of("Europe/Amsterdam")) + .toInstant) :: Nil) + } + } + + test("unused positional arguments") { + checkAnswer( + spark.sql("SELECT ?, ?", Array(1, "abc", 3.14f)), + Row(1, "abc")) + } + + test("mixing of positional and named parameters") { + checkError( + exception = intercept[AnalysisException] { + spark.sql("select :param1, ?", Map("param1" -> 1)) + }, + errorClass = "UNBOUND_SQL_PARAMETER", + parameters = Map("name" -> "_16"), + context = ExpectedContext( + fragment = "?", + start = 16, + stop = 16)) + + checkError( + exception = intercept[AnalysisException] { + spark.sql("select :param1, ?", Array(1)) + }, + errorClass = "UNBOUND_SQL_PARAMETER", + parameters = Map("name" -> "param1"), + context = ExpectedContext( + fragment = ":param1", + start = 7, + stop = 13)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 61b3610e64e..8f47b06d855 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -29,7 +29,7 @@ import org.mockito.Mockito.{mock, spy, when} import 
org.apache.spark._ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator} +import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator} import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext @@ -866,26 +866,26 @@ class QueryExecutionErrorsSuite test("INTERNAL_ERROR: Calling eval on Unevaluable expression") { val e = intercept[SparkException] { - Parameter("foo").eval() + NamedParameter("foo").eval() } checkError( exception = e, errorClass = "INTERNAL_ERROR", - parameters = Map("message" -> "Cannot evaluate expression: parameter(foo)"), + parameters = Map("message" -> "Cannot evaluate expression: namedparameter(foo)"), sqlState = "XX000") } test("INTERNAL_ERROR: Calling doGenCode on unresolved") { val e = intercept[SparkException] { val ctx = new CodegenContext - Grouping(Parameter("foo")).genCode(ctx) + Grouping(NamedParameter("foo")).genCode(ctx) } checkError( exception = e, errorClass = "INTERNAL_ERROR", parameters = Map( "message" -> ("Cannot generate code for expression: " + - "grouping(parameter(foo))")), + "grouping(namedparameter(foo))")), sqlState = "XX000") } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org