This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new c51f644  [SPARK-37392][SQL] Fix the performance bug when inferring constraints for Generate
c51f644 is described below

commit c51f6449d38d30d0bff22df895dca515898a520b
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Dec 8 13:04:40 2021 +0800

    [SPARK-37392][SQL] Fix the performance bug when inferring constraints for Generate
    
    This is a performance regression since Spark 3.1, caused by https://issues.apache.org/jira/browse/SPARK-32295
    
    If you run the query from the JIRA ticket:
    ```
    Seq(
      (1, "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", 
"x", "x", "x", "x", "x", "x")
    ).toDF()
      .checkpoint() // or save and reload to truncate lineage
      .createOrReplaceTempView("sub")
    
    session.sql("""
      SELECT
        *
      FROM
      (
        SELECT
          EXPLODE( ARRAY( * ) ) result
        FROM
        (
          SELECT
            _1 a, _2 b, _3 c, _4 d, _5 e, _6 f, _7 g, _8 h, _9 i, _10 j, _11 k, _12 l, _13 m, _14 n, _15 o, _16 p, _17 q, _18 r, _19 s, _20 t, _21 u
          FROM
            sub
        )
      )
      WHERE
        result != ''
      """).show()
    ```
    you will hit an OOM. The reason:
    1. We infer additional predicates from `Generate`. In this case, it's `size(array(cast(_1#21 as string), _2#22, _3#23, ...)) > 0`.
    2. Because of the cast, the `ConstantFolding` rule can't optimize away this `size(array(...))`.
    3. We end up with a plan containing this fragment:
    ```
       +- Project [_1#21 AS a#106, _2#22 AS b#107, _3#23 AS c#108, _4#24 AS d#109, _5#25 AS e#110, _6#26 AS f#111, _7#27 AS g#112, _8#28 AS h#113, _9#29 AS i#114, _10#30 AS j#115, _11#31 AS k#116, _12#32 AS l#117, _13#33 AS m#118, _14#34 AS n#119, _15#35 AS o#120, _16#36 AS p#121, _17#37 AS q#122, _18#38 AS r#123, _19#39 AS s#124, _20#40 AS t#125, _21#41 AS u#126]
          +- Filter (size(array(cast(_1#21 as string), _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41), true) > 0)
             +- LogicalRDD [_1#21, _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41]
    ```
    When calculating the constraints of the `Project`, we generate around 2^20 expressions, due to this code:
    ```
    var allConstraints = child.constraints
    projectList.foreach {
      case a @ Alias(l: Literal, _) =>
        allConstraints += EqualNullSafe(a.toAttribute, l)
      case a @ Alias(e, _) =>
        // For every alias in `projectList`, replace the reference in constraints by its attribute.
        allConstraints ++= allConstraints.map(_ transform {
          case expr: Expression if expr.semanticEquals(e) =>
            a.toAttribute
        })
        allConstraints += EqualNullSafe(e, a.toAttribute)
      case _ => // Don't change.
    }
    ```
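
    Each alias that appears in an existing constraint doubles `allConstraints`, so 21 aliased columns can produce on the order of 2^21 entries. Here is a minimal sketch of the blow-up in plain Scala (not Spark internals; the string model is illustrative):
    ```
    // Model constraints as strings over columns _01.._10: each "alias" step
    // copies every constraint that mentions the aliased column, so the set
    // doubles once per alias.
    var constraints =
      Set((1 to 10).map(i => f"_$i%02d").mkString("size(array(", ", ", ")) > 0"))
    for (i <- 1 to 10) {
      val (ref, alias) = (f"_$i%02d", f"c$i%02d")
      constraints ++= constraints.collect {
        case expr if expr.contains(ref) => expr.replace(ref, alias)
      }
    }
    println(constraints.size) // 1024 = 2^10; with 21 columns this reaches 2^21
    ```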
    
    There are three issues here:
    1. We may infer complicated predicates from `Generate`.
    2. The `ConstantFolding` rule is too conservative. At least `Cast` has no side effect when ANSI mode is off (see the sketch after this list).
    3. When calculating constraints, we should have an upper bound to avoid generating too many expressions.
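
    To see why the `Cast` case must be gated on ANSI mode: with ANSI off, a failing cast returns NULL, so folding it at optimization time is safe; with ANSI on, it throws at runtime, and folding it eagerly could fail a query that would never have evaluated the cast. A small spark-shell sketch (the exact exception type varies across Spark versions):
    ```
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('oops' AS INT)").show() // NULL: safe to constant-fold
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT CAST('oops' AS INT)").show() // throws a NumberFormatException
    ```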
    
    This fixes the first two issues and leaves the third for the future.
    
    This fixes a performance issue and introduces no user-facing change.

    Tested with new unit tests, and by running the query from the JIRA ticket locally.
    
    Closes #34823 from cloud-fan/perf.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 1fac7a9d9992b7c120f325cdfa6a935b52c7f3bc)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 41 +++++----
 .../spark/sql/catalyst/optimizer/expressions.scala |  1 +
 .../optimizer/InferFiltersFromGenerateSuite.scala  | 98 ++++++++++------------
 3 files changed, 67 insertions(+), 73 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 99b5240..e39fa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,25 +893,30 @@ object TransposeWindow extends Rule[LogicalPlan] {
  * by this [[Generate]] can be removed earlier - before joins and in data sources.
  */
 object InferFiltersFromGenerate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    // This rule does not infer filters from foldable expressions to avoid constant filters
-    // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
-    // then the idempotence will break.
-    case generate @ Generate(e, _, _, _, _, _)
-      if !e.deterministic || e.children.forall(_.foldable) ||
-        e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
-
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
-      // Exclude child's constraints to guarantee idempotency
-      val inferredFilters = ExpressionSet(
-        Seq(
-          GreaterThan(Size(g.children.head), Literal(0)),
-          IsNotNull(g.children.head)
-        )
-      ) -- generate.child.constraints
-
-      if (inferredFilters.nonEmpty) {
-        generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+      assert(g.children.length == 1)
+      val input = g.children.head
+      // Generating extra predicates here has overheads/risks:
+      //   - We may evaluate expensive input expressions multiple times.
+      //   - We may infer too many constraints later.
+      //   - The input expression may fail to be evaluated under ANSI mode. If we reorder the
+      //     predicates and evaluate the input expression first, we may fail the query unexpectedly.
+      // To be safe, here we only generate extra predicates if the input is an attribute.
+      // Note that, foldable input is also excluded here, to avoid constant filters like
+      // 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and then the
+      // idempotence will break.
+      if (input.isInstanceOf[Attribute]) {
+        // Exclude child's constraints to guarantee idempotency
+        val inferredFilters = ExpressionSet(
+          Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
+        ) -- generate.child.constraints
+
+        if (inferredFilters.nonEmpty) {
+          generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+        } else {
+          generate
+        }
       } else {
         generate
       }
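
A hedged sketch of the narrowed rule, written against the test suite's DSL shown further below (`testRelation`, `Explode`, and `CreateArray` are as in InferFiltersFromGenerateSuite; the comments state expected outcomes, not verified output):
```
// Input is a plain attribute: filters are still inferred.
val attrQuery = testRelation.generate(Explode('a)).analyze
// Optimize.execute(attrQuery) inserts Filter(size('a) > 0 && isnotnull('a)).

// Input is a constructed expression: the rule now leaves the plan untouched.
val exprQuery = testRelation.generate(Explode(CreateArray(Seq('c1)))).analyze
// Optimize.execute(exprQuery) returns the plan unchanged.
```
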
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index d989753..78098c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -46,6 +46,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
   private def hasNoSideEffect(e: Expression): Boolean = e match {
     case _: Attribute => true
     case _: Literal => true
+    case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
     case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
     case _ => false
   }
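
With this change (a sketch, assuming ANSI mode is off), the inferred predicate from the example above becomes fully foldable, because a Cast whose child has no side effect now passes hasNoSideEffect:
```
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Hypothetical Catalyst tree mirroring the inferred predicate; with ANSI off,
// ConstantFolding can now evaluate it down to Literal(true).
val pred = GreaterThan(
  Size(CreateArray(Seq(Cast(Literal(1), StringType), Literal("x")))),
  Literal(0))
```
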
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
index 800d37e..61ab4f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
@@ -18,10 +18,8 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,7 +34,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
   val testRelation = LocalRelation('a.array(StructType(Seq(
     StructField("x", IntegerType),
     StructField("y", IntegerType)
-  ))), 'c1.string, 'c2.string)
+  ))), 'c1.string, 'c2.string, 'c3.int)
 
   Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
     val generator = f('a)
@@ -74,63 +72,53 @@ class InferFiltersFromGenerateSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
     }
-  }
 
-  // setup rules to test inferFilters with ConstantFolding to make sure
-  // the Filter rule added in inferFilters is removed again when doing
-  // explode with CreateArray/CreateMap
-  object OptimizeInferAndConstantFold extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("AnalysisNodes", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Infer Filters", Once, InferFiltersFromGenerate) ::
-      Batch("ConstantFolding after", FixedPoint(4),
-        ConstantFolding,
-        NullPropagation,
-        PruneFilters) :: Nil
+    val generatorWithFromJson = f(JsonToStructs(
+      ArrayType(new StructType().add("s", "string")),
+      Map.empty,
+      'c1))
+    test("SPARK-37392: Don't infer filters from " + generatorWithFromJson) {
+      val originalQuery = testRelation.generate(generatorWithFromJson).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+
+    val returnSchema = ArrayType(StructType(Seq(
+      StructField("x", IntegerType),
+      StructField("y", StringType)
+    )))
+    val fakeUDF = ScalaUDF(
+      (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
+      returnSchema, 'c3 :: Nil, Nil)
+    val generatorWithUDF = f(fakeUDF)
+    test("SPARK-36715: Don't infer filters from " + generatorWithUDF) {
+      val originalQuery = testRelation.generate(generatorWithUDF).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
   }
 
   Seq(Explode(_), PosExplode(_)).foreach { f =>
-     val createArrayExplode = f(CreateArray(Seq('c1)))
-     test("SPARK-33544: Don't infer filters from CreateArray " + 
createArrayExplode) {
-       val originalQuery = testRelation.generate(createArrayExplode).analyze
-       val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-       comparePlans(optimized, originalQuery)
-     }
-     val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
-     test("SPARK-33544: Don't infer filters from CreateMap " + 
createMapExplode) {
-       val originalQuery = testRelation.generate(createMapExplode).analyze
-       val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-       comparePlans(optimized, originalQuery)
-     }
-   }
-
-   Seq(Inline(_)).foreach { f =>
-     val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
-     test("SPARK-33544: Don't infer filters from CreateArray " + 
createArrayStructExplode) {
-       val originalQuery = 
testRelation.generate(createArrayStructExplode).analyze
-       val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-       comparePlans(optimized, originalQuery)
-     }
-   }
+    val createArrayExplode = f(CreateArray(Seq('c1)))
+    test("SPARK-33544: Don't infer filters from " + createArrayExplode) {
+      val originalQuery = testRelation.generate(createArrayExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+    val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
+    test("SPARK-33544: Don't infer filters from " + createMapExplode) {
+      val originalQuery = testRelation.generate(createMapExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }
 
-  test("SPARK-36715: Don't infer filters from udf") {
-    Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
-      val returnSchema = ArrayType(StructType(Seq(
-        StructField("x", IntegerType),
-        StructField("y", StringType)
-      )))
-      val fakeUDF = ScalaUDF(
-        (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
-        returnSchema, Literal(8) :: Nil,
-        Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
-      val generator = f(fakeUDF)
-      val originalQuery = OneRowRelation().generate(generator).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      val correctAnswer = OneRowRelation()
-        .generate(generator)
-        .analyze
-      comparePlans(optimized, correctAnswer)
+  Seq(Inline(_)).foreach { f =>
+    val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
+    test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) {
+      val originalQuery = testRelation.generate(createArrayStructExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
     }
   }
 }
