This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cd672b09ac6  [SPARK-45162][SQL] Support maps and array parameters constructed via `call_function`
cd672b09ac6 is described below

commit cd672b09ac69724cd99dc12c9bb49dd117025be1
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Sep 14 11:31:56 2023 +0300

[SPARK-45162][SQL] Support maps and array parameters constructed via `call_function`

### What changes were proposed in this pull request?
In this PR, I propose to move the `BindParameters` rule from the `Substitution` batch to the `Resolution` batch, and to change the types of the `args` parameters of `NameParameterizedQuery` and `PosParameterizedQuery` to `Iterable`s so that the argument expressions are resolved by the analyzer.

### Why are the changes needed?
The parameterized `sql()` already allows maps/arrays/structs constructed by functions such as `map()`, `array()`, and `struct()`, but the same functions invoked via `call_function` are not supported:
```scala
scala> sql("SELECT element_at(:mapParam, 'a')", Map("mapParam" -> call_function("map", lit("a"), lit(1))))
org.apache.spark.sql.catalyst.ExtendedAnalysisException: [UNBOUND_SQL_PARAMETER] Found the unbound parameter: mapParam. Please, fix `args` and provide a mapping of the parameter to a SQL literal.; line 1 pos 18;
```

### Does this PR introduce _any_ user-facing change?
No, it should not, since it fixes an issue; only user code that depends on the exact error message could be affected. After the changes:
```scala
scala> sql("SELECT element_at(:mapParam, 'a')", Map("mapParam" -> call_function("map", lit("a"), lit(1)))).show(false)
+------------------------+
|element_at(map(a, 1), a)|
+------------------------+
|1                       |
+------------------------+
```

### How was this patch tested?
By running the new tests:
```
$ build/sbt "test:testOnly *ParametersSuite"
```

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #42894 from MaxGekk/fix-parameterized-sql-unresolved.

Authored-by: Max Gekk <max.g...@gmail.com>
Signed-off-by: Max Gekk <max.g...@gmail.com>
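To make the behavior concrete, here is a short usage sketch (assuming a Spark 3.5+ session named `spark`; the `:arrParam` query is illustrative and not taken from the patch) covering both parameter shapes named in the title, each constructed via `call_function`:
```scala
import org.apache.spark.sql.functions.{call_function, lit}

// A map parameter built through call_function, the case this patch fixes.
spark.sql(
  "SELECT element_at(:mapParam, 'a')",
  Map("mapParam" -> call_function("map", lit("a"), lit(1)))).show(false)

// An array parameter built the same way (element_at on arrays is 1-based).
spark.sql(
  "SELECT element_at(:arrParam, 1)",
  Map("arrParam" -> call_function("array", lit(10), lit(20)))).show(false)
```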
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  2 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 .../spark/sql/catalyst/analysis/parameters.scala   | 28 +++++++++++++++++-----
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  4 ++--
 .../org/apache/spark/sql/ParametersSuite.scala     | 19 ++++++++++++---
 5 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24dee006f0b..74a8ff290eb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -269,7 +269,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       if (!args.isEmpty) {
         NameParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
       } else if (!posArgs.isEmpty) {
-        PosParameterizedQuery(parsedPlan, posArgs.asScala.map(transformLiteral).toArray)
+        PosParameterizedQuery(parsedPlan, posArgs.asScala.map(transformLiteral).toSeq)
       } else {
         parsedPlan
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e15b9730111..6491a4eea95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -260,7 +260,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // at the beginning of analysis.
       OptimizeUpdateFields,
       CTESubstitution,
-      BindParameters,
       WindowsSubstitution,
       EliminateUnions,
       SubstituteUnresolvedOrdinals),
@@ -322,6 +321,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       RewriteDeleteFromTable ::
       RewriteUpdateTable ::
       RewriteMergeIntoTable ::
+      BindParameters ::
       typeCoercionRules ++
       Seq(
         ResolveWithCTE,
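The next file is the heart of the change: `NameParameterizedQuery` stops storing its arguments as a `Map[String, Expression]` and instead keeps names and values in parallel `Seq`s. The likely-relevant detail (the sketch below is a simplified model, not the actual Catalyst code): Catalyst's tree traversal flattens `Iterable` constructor fields when collecting a node's expressions, but the `(key, value)` tuples of a `Map` are left opaque, so expressions stored as map values never reached the resolution rules:
```scala
// Simplified model of expression collection over a node's constructor fields.
sealed trait Expr
case class Lit(value: Int) extends Expr

def collectExprs(node: Product): Seq[Expr] =
  node.productIterator.flatMap {
    case e: Expr         => Seq(e)
    case it: Iterable[_] => it.collect { case e: Expr => e }
    case _               => Nil
  }.toSeq

case class MapArgs(args: Map[String, Expr])                // old shape
case class SeqArgs(names: Seq[String], values: Seq[Expr])  // new shape

collectExprs(MapArgs(Map("a" -> Lit(1))))     // Nil: Lit(1) hides in a tuple
collectExprs(SeqArgs(Seq("a"), Seq(Lit(1))))  // Seq(Lit(1)): visible to rules
```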
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index 13404797490..a6072dcdd2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -68,22 +68,33 @@ abstract class ParameterizedQuery(child: LogicalPlan) extends UnresolvedUnaryNod
  * The logical plan representing a parameterized query with named parameters.
  *
  * @param child The parameterized logical plan.
- * @param args The map of parameter names to its literal values.
+ * @param argNames Argument names.
+ * @param argValues A sequence of argument values matched to argument names `argNames`.
  */
-case class NameParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
+case class NameParameterizedQuery(
+    child: LogicalPlan,
+    argNames: Seq[String],
+    argValues: Seq[Expression])
   extends ParameterizedQuery(child) {
-  assert(args.nonEmpty)
+  assert(argNames.nonEmpty && argValues.nonEmpty)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
     copy(child = newChild)
 }
 
+object NameParameterizedQuery {
+  def apply(child: LogicalPlan, args: Map[String, Expression]): NameParameterizedQuery = {
+    val argsSeq = args.toSeq
+    new NameParameterizedQuery(child, argsSeq.map(_._1), argsSeq.map(_._2))
+  }
+}
+
 /**
  * The logical plan representing a parameterized query with positional parameters.
  *
  * @param child The parameterized logical plan.
  * @param args The literal values of positional parameters.
  */
-case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
+case class PosParameterizedQuery(child: LogicalPlan, args: Seq[Expression])
   extends ParameterizedQuery(child) {
   assert(args.nonEmpty)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
@@ -124,8 +135,13 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
     plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
       // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
       // relations are not children of `UnresolvedWith`.
-      case NameParameterizedQuery(child, args)
-          if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_._2.resolved) =>
+      case NameParameterizedQuery(child, argNames, argValues)
+          if !child.containsPattern(UNRESOLVED_WITH) && argValues.forall(_.resolved) =>
+        if (argNames.length != argValues.length) {
+          throw SparkException.internalError(s"The number of argument names ${argNames.length} " +
+            s"must be equal to the number of argument values ${argValues.length}.")
+        }
+        val args = argNames.zip(argValues).toMap
         checkArgs(args)
         bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f8fedc0500c..97ba471dc21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1432,7 +1432,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     CTERelationDef.curId.set(0)
     val actual1 = PosParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT ?"),
-      args = Array(Literal(10))).analyze
+      args = Seq(Literal(10))).analyze
     CTERelationDef.curId.set(0)
     val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
     comparePlans(actual1, expected1)
@@ -1440,7 +1440,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     CTERelationDef.curId.set(0)
     val actual2 = PosParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < ?"),
-      args = Array(Literal(20), Literal(10))).analyze
+      args = Seq(Literal(20), Literal(10))).analyze
     CTERelationDef.curId.set(0)
     val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
     comparePlans(actual2, expected2)
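Two details above are worth noting: the companion `apply` keeps existing `Map`-based call sites (such as the one in `SparkConnectPlanner`) compiling unchanged, and `BindParameters` re-pairs the two sequences behind an internal-error guard. A sketch of that round trip (names hypothetical, reusing the simplified `Expr` from the previous sketch):
```scala
case class NamedQuery(argNames: Seq[String], argValues: Seq[Expr])

object NamedQuery {
  // Convenience constructor mirroring the old Map-based signature. Deriving
  // both sequences from one `toSeq` call fixes a single, shared ordering.
  def apply(args: Map[String, Expr]): NamedQuery = {
    val argsSeq = args.toSeq
    NamedQuery(argsSeq.map(_._1), argsSeq.map(_._2))
  }
}

// Re-pairing at binding time, with the same length check the rule performs.
def toArgMap(q: NamedQuery): Map[String, Expr] = {
  require(q.argNames.length == q.argValues.length,
    s"${q.argNames.length} names must match ${q.argValues.length} values")
  q.argNames.zip(q.argValues).toMap
}
```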
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 6e361e70bd9..2a24f0cc399 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -21,7 +21,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}
 
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.functions.{array, lit, map, map_from_arrays, map_from_entries, str_to_map, struct}
+import org.apache.spark.sql.functions.{array, call_function, lit, map, map_from_arrays, map_from_entries, str_to_map, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -535,17 +535,29 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
     def fromArr(keys: Array[_], values: Array[_]): Column = {
       map_from_arrays(Column(Literal(keys)), Column(Literal(values)))
     }
+    def callFromArr(keys: Array[_], values: Array[_]): Column = {
+      call_function("map_from_arrays", Column(Literal(keys)), Column(Literal(values)))
+    }
     def createMap(keys: Array[_], values: Array[_]): Column = {
       val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v))))
       map(zipped.map { case (k, v) => Seq(k, v) }.flatten: _*)
     }
+    def callMap(keys: Array[_], values: Array[_]): Column = {
+      val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v))))
+      call_function("map", zipped.map { case (k, v) => Seq(k, v) }.flatten: _*)
+    }
     def fromEntries(keys: Array[_], values: Array[_]): Column = {
       val structures = keys.zip(values)
         .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))}
       map_from_entries(array(structures: _*))
     }
+    def callFromEntries(keys: Array[_], values: Array[_]): Column = {
+      val structures = keys.zip(values)
+        .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))}
+      call_function("map_from_entries", call_function("array", structures: _*))
+    }
 
-    Seq(fromArr(_, _), createMap(_, _)).foreach { f =>
+    Seq(fromArr(_, _), createMap(_, _), callFromArr(_, _), callMap(_, _)).foreach { f =>
       checkAnswer(
         spark.sql("SELECT map_contains_key(:mapParam, 0)",
           Map("mapParam" -> f(Array.empty[Int], Array.empty[String]))),
@@ -555,7 +567,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
           Array(f(Array.empty[String], Array.empty[Double]))),
         Row(false))
     }
-    Seq(fromArr(_, _), createMap(_, _), fromEntries(_, _)).foreach { f =>
+    Seq(fromArr(_, _), createMap(_, _), fromEntries(_, _),
+      callFromArr(_, _), callMap(_, _), callFromEntries(_, _)).foreach { f =>
       checkAnswer(
         spark.sql("SELECT element_at(:mapParam, 'a')",
           Map("mapParam" -> f(Array("a"), Array(0)))),
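The new tests follow the suite's table-driven pattern: each map-building helper, builtin or `call_function`-based, runs through the same assertions, so the `call_function` variants must behave identically to the builtin ones. A minimal standalone version of one such check (assuming a suite mixing in `QueryTest` with `SharedSparkSession`, as `ParametersSuite` does):
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{call_function, lit}

test("map constructed via call_function binds as a named parameter") {
  checkAnswer(
    spark.sql(
      "SELECT element_at(:mapParam, 'a')",
      Map("mapParam" -> call_function("map", lit("a"), lit(1)))),
    Row(1))
}
```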