This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 7a27ea7  [SPARK-36715][SQL] InferFiltersFromGenerate should not infer filter for udf
7a27ea7 is described below

commit 7a27ea7382e3460b43dcc36dab6f31b2a0a87565
Author: Fu Chen <cfmcgr...@gmail.com>
AuthorDate: Tue Sep 14 09:26:11 2021 +0900

    [SPARK-36715][SQL] InferFiltersFromGenerate should not infer filter for udf
    
    ### What changes were proposed in this pull request?
    
    Fix a bug in InferFiltersFromGenerate: the rule should not infer filters for a `Generate` node when the generator's children contain an expression that is an instance of `org.apache.spark.sql.catalyst.expressions.UserDefinedExpression`.
    Before this PR, the following case threw an exception.
    
    ```scala
    spark.udf.register("vec", (i: Int) => (0 until i).toArray)
    sql("select explode(vec(8)) as c1").show
    ```
    
    ```
    Once strategy's idempotence is broken for batch Infer Filters
     GlobalLimit 21                                                        GlobalLimit 21
     +- LocalLimit 21                                                      +- LocalLimit 21
        +- Project [cast(c1#3 as string) AS c1#12]                            +- Project [cast(c1#3 as string) AS c1#12]
           +- Generate explode(vec(8)), false, [c1#3]                            +- Generate explode(vec(8)), false, [c1#3]
              +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))            +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
    !            +- OneRowRelation                                                     +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
    !                                                                                     +- OneRowRelation
    
    java.lang.RuntimeException:
    Once strategy's idempotence is broken for batch Infer Filters
     GlobalLimit 21                                                        GlobalLimit 21
     +- LocalLimit 21                                                      +- LocalLimit 21
        +- Project [cast(c1#3 as string) AS c1#12]                            +- Project [cast(c1#3 as string) AS c1#12]
           +- Generate explode(vec(8)), false, [c1#3]                            +- Generate explode(vec(8)), false, [c1#3]
              +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))            +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
    !            +- OneRowRelation                                                     +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
    !                                                                                     +- OneRowRelation
    
        at org.apache.spark.sql.errors.QueryExecutionErrors$.onceStrategyIdempotenceIsBrokenForBatchError(QueryExecutionErrors.scala:1200)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.checkBatchIdempotence(RuleExecutor.scala:168)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:254)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:200)
        at scala.collection.immutable.List.foreach(List.scala:431)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:200)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:179)
        at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88)
        at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:179)
        at org.apache.spark.sql.execution.QueryExecution.$anonfun$optimizedPlan$1(QueryExecution.scala:138)
        at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
        at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:196)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
        at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:196)
        at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:134)
        at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:130)
        at org.apache.spark.sql.execution.QueryExecution.assertOptimized(QueryExecution.scala:148)
        at org.apache.spark.sql.execution.QueryExecution.$anonfun$executedPlan$1(QueryExecution.scala:166)
        at org.apache.spark.sql.execution.QueryExecution.withCteMap(QueryExecution.scala:73)
        at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:163)
        at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:163)
        at org.apache.spark.sql.execution.QueryExecution.simpleString(QueryExecution.scala:214)
        at org.apache.spark.sql.execution.QueryExecution.org$apache$spark$sql$execution$QueryExecution$$explainString(QueryExecution.scala:259)
        at org.apache.spark.sql.execution.QueryExecution.explainString(QueryExecution.scala:228)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:98)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
        at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3731)
        at org.apache.spark.sql.Dataset.head(Dataset.scala:2755)
        at org.apache.spark.sql.Dataset.take(Dataset.scala:2962)
        at org.apache.spark.sql.Dataset.getRows(Dataset.scala:288)
        at org.apache.spark.sql.Dataset.showString(Dataset.scala:327)
        at org.apache.spark.sql.Dataset.show(Dataset.scala:807)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this is a bug fix only.
    
    ### How was this patch tested?
    
    Added a unit test to `InferFiltersFromGenerateSuite`.
    
    Closes #33956 from cfmcgrady/SPARK-36715.
    
    Authored-by: Fu Chen <cfmcgr...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 52c5ff20ca132653f505040a4dff522b136d2626)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  3 ++-
 .../optimizer/InferFiltersFromGenerateSuite.scala  | 24 +++++++++++++++++++++-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6584401..99b5240 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -898,7 +898,8 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
     // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
     // then the idempotence will break.
     case generate @ Generate(e, _, _, _, _, _)
-      if !e.deterministic || e.children.forall(_.foldable) => generate
+      if !e.deterministic || e.children.forall(_.foldable) ||
+        e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
 
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
       // Exclude child's constraints to guarantee idempotency
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
index 93a1d41..800d37e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
 
 class InferFiltersFromGenerateSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -111,4 +113,24 @@ class InferFiltersFromGenerateSuite extends PlanTest {
        comparePlans(optimized, originalQuery)
      }
    }
+
+  test("SPARK-36715: Don't infer filters from udf") {
+    Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
+      val returnSchema = ArrayType(StructType(Seq(
+        StructField("x", IntegerType),
+        StructField("y", StringType)
+      )))
+      val fakeUDF = ScalaUDF(
+        (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
+        returnSchema, Literal(8) :: Nil,
+        Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
+      val generator = f(fakeUDF)
+      val originalQuery = OneRowRelation().generate(generator).analyze
+      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
+      val correctAnswer = OneRowRelation()
+        .generate(generator)
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to