This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3dda58a  [SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.
3dda58a is described below

commit 3dda58af2b7f42beab736d856bf17b4d35c8866c
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sat Dec 15 00:23:28 2018 +0800

    [SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.
    
    ## What changes were proposed in this pull request?
    
    When using a higher-order function with the same variable name as an existing column in `Filter`, or in another operator that uses `Analyzer.resolveExpressionBottomUp` during resolution, e.g.:
    
    ```scala
    val df = Seq(
      (Seq(1, 9, 8, 7), 1, 2),
      (Seq(5, 9, 7), 2, 2),
      (Seq.empty, 3, 2),
      (null, 4, 2)
    ).toDF("i", "x", "d")
    
    checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
      Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
    checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
      Seq(Row(1)))
    ```
    
    the following exception happens:
    
    ```
    java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression
      at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
      at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
      at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
      at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
      at scala.collection.TraversableLike.map(TraversableLike.scala:237)
      at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
      at scala.collection.AbstractTraversable.map(Traversable.scala:108)
      at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147)
      at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
      at scala.collection.immutable.List.foreach(List.scala:392)
      at scala.collection.TraversableLike.map(TraversableLike.scala:237)
      at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
      at scala.collection.immutable.List.map(List.scala:298)
      at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145)
      at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145)
      at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369)
      at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369)
      at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176)
      at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176)
      at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369)
      at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387)
      at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190)
      at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185)
      at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369)
      at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source)
      at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216)
      at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215)

    ...
    ```
    
    because the `UnresolvedAttribute`s in the `LambdaFunction` are unexpectedly resolved by the rule.
    
    This PR introduces a placeholder expression, `UnresolvedNamedLambdaVariable`, for lambda variables to prevent the unexpected resolution.
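    
    As a minimal sketch of why the placeholder helps (the classes below are hypothetical, simplified stand-ins for the Catalyst expression nodes, not the actual implementation): a bottom-up resolution rule pattern-matches on `UnresolvedAttribute`, so when lambda parameters were represented by that same class, the rule bound them to columns of the same name; a dedicated node type falls through the match untouched.
    
    ```scala
    object ResolutionSketch {
      // Hypothetical, simplified stand-ins for the Catalyst expression nodes.
      sealed trait Expr
      case class UnresolvedAttribute(name: String) extends Expr
      case class UnresolvedNamedLambdaVariable(name: String) extends Expr // placeholder
      case class BoundColumn(name: String) extends Expr
      case class Lambda(body: Expr, args: Seq[String]) extends Expr
    
      // A resolver in the spirit of Analyzer.resolveExpressionBottomUp: it only
      // rewrites UnresolvedAttribute, so lambda variables, being a distinct node
      // type, are no longer accidentally bound to the relation's columns.
      def resolveBottomUp(e: Expr, columns: Set[String]): Expr = e match {
        case UnresolvedAttribute(n) if columns(n) => BoundColumn(n)
        case Lambda(body, args) => Lambda(resolveBottomUp(body, columns), args)
        case other => other
      }
    }
    ```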
    
    ## How was this patch tested?
    
    Added a new test and modified some existing tests.
    
    Closes #23320 from ueshin/issues/SPARK-26370/hof_resolution.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/analysis/higherOrderFunctions.scala   |  5 +++--
 .../expressions/higherOrderFunctions.scala         | 26 ++++++++++++++++++++--
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  7 ++++--
 .../analysis/ResolveLambdaVariablesSuite.scala     | 10 +++++----
 .../ReplaceNullWithFalseInPredicateSuite.scala     | 14 +++++++-----
 .../catalyst/parser/ExpressionParserSuite.scala    |  6 +++--
 .../results/typeCoercion/native/mapZipWith.sql.out |  4 ++--
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 20 +++++++++++++++++
 8 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index a8a7bbd..1cd7f41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
       val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
       l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
 
-    case u @ UnresolvedAttribute(name +: nestedFields) =>
+    case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
       parentLambdaMap.get(canonicalizer(name)) match {
         case Some(lambda) =>
           nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
             ExtractValue(expr, Literal(fieldName), conf.resolver)
           }
-        case None => u
+        case None =>
+          UnresolvedAttribute(u.nameParts)
       }
 
     case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a8639d2..7141b6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,13 +22,35 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
 /**
+ * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+  extends LeafExpression with NamedExpression with Unevaluable {
+
+  override def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+  override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+  override lazy val resolved = false
+
+  override def toString: String = s"lambda '$name"
+
+  override def sql: String = name
+}
+
+/**
  * A named lambda variable.
  */
 case class NamedLambdaVariable(
@@ -79,7 +101,7 @@ case class LambdaFunction(
 
 object LambdaFunction {
   val identity: LambdaFunction = {
-    val id = UnresolvedAttribute.quoted("id")
+    val id = UnresolvedNamedLambdaVariable(Seq("id"))
     LambdaFunction(id, Seq(id))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 672bffc..8959f78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
     val arguments = ctx.IDENTIFIER().asScala.map { name =>
-      UnresolvedAttribute.quoted(name.getText)
+      UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
     }
-    LambdaFunction(expression(ctx.expression), arguments)
+    val function = expression(ctx.expression).transformUp {
+      case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+    }
+    LambdaFunction(function, arguments)
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
index c4171c7..a5847ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
     comparePlans(Analyzer.execute(plan(e1)), plan(e2))
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("resolution - no op") {
     checkExpression(key, key)
   }
 
   test("resolution - simple") {
-    val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
     val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
     checkExpression(in, out)
   }
 
   test("resolution - nested") {
     val in = ArrayTransform(values2, LambdaFunction(
-      ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
     val out = ArrayTransform(values2, LambdaFunction(
       ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
     checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
 
   test("fail - name collisions") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("arguments should not have names that are semantically 
the same"))
   }
 
   test("fail - lambda arguments") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("does not match the number of arguments expected"))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
index ee0d04d..748075b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
     testProjection(originalExpr = column, expectedExpr = column)
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("replace nulls in lambda function of ArrayFilter") {
-    testHigherOrderFunc('a, ArrayFilter, Seq('e))
+    testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
   }
 
   test("replace nulls in lambda function of ArrayExists") {
-    testHigherOrderFunc('a, ArrayExists, Seq('e))
+    testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
   }
 
   test("replace nulls in lambda function of MapFilter") {
-    testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+    testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
   }
 
   test("inability to replace nulls in arbitrary higher-order function") {
     val lambdaFunc = LambdaFunction(
-      function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
-      arguments = Seq[NamedExpression]('e))
+      function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+      arguments = Seq[NamedExpression](lv('e)))
     val column = ArrayTransform('a, lambdaFunc)
     testProjection(originalExpr = column, expectedExpr = column)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index b4df22c..8bcc69d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
     intercept("foo(a x)", "extraneous input 'x'")
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("lambda functions") {
-    assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
-    assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 
'y.attr)))
+    assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+    assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), 
lv('y))))
   }
 
   test("window function expressions") {
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 3574009..86a578c 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
 
 
 -- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query 8 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
 
 
 -- !query 9
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e6d1a03..b7fc957 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
     assert(ex.getMessage.contains("Cannot use null as map key"))
   }
+
+  test("SPARK-26370: Fix resolution of higher-order function for the same 
identifier") {
+    val df = Seq(
+      (Seq(1, 9, 8, 7), 1, 2),
+      (Seq(5, 9, 7), 2, 2),
+      (Seq.empty, 3, 2),
+      (null, 4, 2)
+    ).toDF("i", "x", "d")
+
+    checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+      Seq(
+        Row(1, true),
+        Row(2, false),
+        Row(3, false),
+        Row(4, null)))
+    checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+    checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(1)))
+  }
 }
 
 object DataFrameFunctionsSuite {

