This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cd672b09ac6 [SPARK-45162][SQL] Support maps and array parameters constructed via `call_function`
cd672b09ac6 is described below

commit cd672b09ac69724cd99dc12c9bb49dd117025be1
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Sep 14 11:31:56 2023 +0300

    [SPARK-45162][SQL] Support maps and array parameters constructed via `call_function`
    
    ### What changes were proposed in this pull request?
    In this PR, I propose to move the `BindParameters` rule from the `Substitution` batch to the `Resolution` batch, and to change the `args` parameters of `NameParameterizedQuery` and `PosParameterizedQuery` to sequences of expressions, so that argument expressions are resolved by the analyzer.
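
    Illustrative only (not part of the patch): a minimal sketch of what the analyzer now receives, assuming the Catalyst parser and expression classes shown in the imports:
    ```scala
    import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery
    import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal}
    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    // The parameter value stays an expression (here a CreateMap) rather than a
    // pre-built literal; it is resolved in the Resolution batch together with
    // the rest of the query.
    val parsedPlan = CatalystSqlParser.parsePlan("SELECT element_at(:mapParam, 'a')")
    val parameterized = NameParameterizedQuery(
      parsedPlan,
      argNames = Seq("mapParam"),
      argValues = Seq(CreateMap(Seq(Literal("a"), Literal(1)))))
    ```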
    
    ### Why are the changes needed?
    Currently, the parameterized `sql()` accepts map/array/struct values constructed by functions like `map()`, `array()`, and `struct()`, but the same functions invoked via `call_function` are not supported, because parameters are bound in the `Substitution` batch, before such unresolved function calls have been resolved:
    ```scala
    scala> sql("SELECT element_at(:mapParam, 'a')", Map("mapParam" -> 
call_function("map", lit("a"), lit(1))))
    org.apache.spark.sql.catalyst.ExtendedAnalysisException: 
[UNBOUND_SQL_PARAMETER] Found the unbound parameter: mapParam. Please, fix 
`args` and provide a mapping of the parameter to a SQL literal.; line 1 pos 18;
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No, it should not, since it fixes an issue. Users are affected only if their code depends on the error message.
    
    After the changes:
    ```scala
    scala> sql("SELECT element_at(:mapParam, 'a')", Map("mapParam" -> 
call_function("map", lit("a"), lit(1)))).show(false)
    +------------------------+
    |element_at(map(a, 1), a)|
    +------------------------+
    |1                       |
    +------------------------+
    ```
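
    A positional parameter built via `call_function` should behave the same way (illustrative only, not shown in the patch output; assumes the `?` placeholder and the `Array` overload of `sql()`):
    ```scala
    scala> sql("SELECT element_at(?, 'a')", Array(call_function("map", lit("a"), lit(1)))).show(false)
    ```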
    
    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "test:testOnly *ParametersSuite"
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42894 from MaxGekk/fix-parameterized-sql-unresolved.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  2 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 .../spark/sql/catalyst/analysis/parameters.scala   | 28 +++++++++++++++++-----
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  4 ++--
 .../org/apache/spark/sql/ParametersSuite.scala     | 19 ++++++++++++---
 5 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24dee006f0b..74a8ff290eb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -269,7 +269,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     if (!args.isEmpty) {
       NameParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
     } else if (!posArgs.isEmpty) {
-      PosParameterizedQuery(parsedPlan, posArgs.asScala.map(transformLiteral).toArray)
+      PosParameterizedQuery(parsedPlan, posArgs.asScala.map(transformLiteral).toSeq)
     } else {
       parsedPlan
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e15b9730111..6491a4eea95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -260,7 +260,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // at the beginning of analysis.
       OptimizeUpdateFields,
       CTESubstitution,
-      BindParameters,
       WindowsSubstitution,
       EliminateUnions,
       SubstituteUnresolvedOrdinals),
@@ -322,6 +321,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       RewriteDeleteFromTable ::
       RewriteUpdateTable ::
       RewriteMergeIntoTable ::
+      BindParameters ::
       typeCoercionRules ++
       Seq(
         ResolveWithCTE,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index 13404797490..a6072dcdd2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -68,22 +68,33 @@ abstract class ParameterizedQuery(child: LogicalPlan) extends UnresolvedUnaryNod
  * The logical plan representing a parameterized query with named parameters.
  *
  * @param child The parameterized logical plan.
- * @param args The map of parameter names to its literal values.
+ * @param argNames Argument names.
+ * @param argValues A sequence of argument values matched to argument names `argNames`.
  */
-case class NameParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
+case class NameParameterizedQuery(
+    child: LogicalPlan,
+    argNames: Seq[String],
+    argValues: Seq[Expression])
   extends ParameterizedQuery(child) {
-  assert(args.nonEmpty)
+  assert(argNames.nonEmpty && argValues.nonEmpty)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
     copy(child = newChild)
 }
 
+object NameParameterizedQuery {
+  def apply(child: LogicalPlan, args: Map[String, Expression]): NameParameterizedQuery = {
+    val argsSeq = args.toSeq
+    new NameParameterizedQuery(child, argsSeq.map(_._1), argsSeq.map(_._2))
+  }
+}
+
 /**
  * The logical plan representing a parameterized query with positional parameters.
  *
  * @param child The parameterized logical plan.
  * @param args The literal values of positional parameters.
  */
-case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
+case class PosParameterizedQuery(child: LogicalPlan, args: Seq[Expression])
   extends ParameterizedQuery(child) {
   assert(args.nonEmpty)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
@@ -124,8 +135,13 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
     plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
       // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
       // relations are not children of `UnresolvedWith`.
-      case NameParameterizedQuery(child, args)
-        if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_._2.resolved) =>
+      case NameParameterizedQuery(child, argNames, argValues)
+        if !child.containsPattern(UNRESOLVED_WITH) && argValues.forall(_.resolved) =>
+        if (argNames.length != argValues.length) {
+          throw SparkException.internalError(s"The number of argument names ${argNames.length} " +
+            s"must be equal to the number of argument values ${argValues.length}.")
+        }
+        val args = argNames.zip(argValues).toMap
         checkArgs(args)
         bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f8fedc0500c..97ba471dc21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1432,7 +1432,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     CTERelationDef.curId.set(0)
     val actual1 = PosParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT ?"),
-      args = Array(Literal(10))).analyze
+      args = Seq(Literal(10))).analyze
     CTERelationDef.curId.set(0)
     val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
     comparePlans(actual1, expected1)
@@ -1440,7 +1440,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     CTERelationDef.curId.set(0)
     val actual2 = PosParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < ?"),
-      args = Array(Literal(20), Literal(10))).analyze
+      args = Seq(Literal(20), Literal(10))).analyze
     CTERelationDef.curId.set(0)
     val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
     comparePlans(actual2, expected2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 6e361e70bd9..2a24f0cc399 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -21,7 +21,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}
 
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.functions.{array, lit, map, map_from_arrays, map_from_entries, str_to_map, struct}
+import org.apache.spark.sql.functions.{array, call_function, lit, map, map_from_arrays, map_from_entries, str_to_map, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -535,17 +535,29 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
     def fromArr(keys: Array[_], values: Array[_]): Column = {
       map_from_arrays(Column(Literal(keys)), Column(Literal(values)))
     }
+    def callFromArr(keys: Array[_], values: Array[_]): Column = {
+      call_function("map_from_arrays", Column(Literal(keys)), 
Column(Literal(values)))
+    }
     def createMap(keys: Array[_], values: Array[_]): Column = {
       val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v))))
       map(zipped.map { case (k, v) => Seq(k, v) }.flatten: _*)
     }
+    def callMap(keys: Array[_], values: Array[_]): Column = {
+      val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v))))
+      call_function("map", zipped.map { case (k, v) => Seq(k, v) }.flatten: _*)
+    }
     def fromEntries(keys: Array[_], values: Array[_]): Column = {
       val structures = keys.zip(values)
         .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))}
       map_from_entries(array(structures: _*))
     }
+    def callFromEntries(keys: Array[_], values: Array[_]): Column = {
+      val structures = keys.zip(values)
+        .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))}
+      call_function("map_from_entries", call_function("array", structures: _*))
+    }
 
-    Seq(fromArr(_, _), createMap(_, _)).foreach { f =>
+    Seq(fromArr(_, _), createMap(_, _), callFromArr(_, _), callMap(_, _)).foreach { f =>
       checkAnswer(
         spark.sql("SELECT map_contains_key(:mapParam, 0)",
           Map("mapParam" -> f(Array.empty[Int], Array.empty[String]))),
@@ -555,7 +567,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
           Array(f(Array.empty[String], Array.empty[Double]))),
         Row(false))
     }
-    Seq(fromArr(_, _), createMap(_, _), fromEntries(_, _)).foreach { f =>
+    Seq(fromArr(_, _), createMap(_, _), fromEntries(_, _),
+      callFromArr(_, _), callMap(_, _), callFromEntries(_, _)).foreach { f =>
       checkAnswer(
         spark.sql("SELECT element_at(:mapParam, 'a')",
           Map("mapParam" -> f(Array("a"), Array(0)))),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
