This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new c51f644  [SPARK-37392][SQL] Fix the performance bug when inferring constraints for Generate
c51f644 is described below

commit c51f6449d38d30d0bff22df895dca515898a520b
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Dec 8 13:04:40 2021 +0800

    [SPARK-37392][SQL] Fix the performance bug when inferring constraints for Generate

    ### What changes were proposed in this pull request?

    This is a performance regression since Spark 3.1, caused by
    https://issues.apache.org/jira/browse/SPARK-32295

    If you run the query in the JIRA ticket
    ```
    Seq(
      (1, "x", "x", "x", "x", "x", "x", "x", "x", "x", "x",
        "x", "x", "x", "x", "x", "x", "x", "x", "x", "x")
    ).toDF()
      .checkpoint() // or save and reload to truncate lineage
      .createOrReplaceTempView("sub")

    session.sql("""
      SELECT *
      FROM (
        SELECT EXPLODE( ARRAY( * ) ) result
        FROM (
          SELECT
            _1 a, _2 b, _3 c, _4 d, _5 e, _6 f, _7 g, _8 h, _9 i, _10 j, _11 k,
            _12 l, _13 m, _14 n, _15 o, _16 p, _17 q, _18 r, _19 s, _20 t, _21 u
          FROM sub
        )
      )
      WHERE result != ''
    """).show()
    ```
    you will hit OOM. The reason is that:
    1. We infer additional predicates with `Generate`. In this case, it's
       `size(array(cast(_1#21 as string), _2#22, _3#23, ...)) > 0`.
    2. Because of the cast, the `ConstantFolding` rule can't optimize this `size(array(...))`.
    3. We end up with a plan containing this part:
    ```
    +- Project [_1#21 AS a#106, _2#22 AS b#107, _3#23 AS c#108, _4#24 AS d#109, _5#25 AS e#110, _6#26 AS f#111, _7#27 AS g#112, _8#28 AS h#113, _9#29 AS i#114, _10#30 AS j#115, _11#31 AS k#116, _12#32 AS l#117, _13#33 AS m#118, _14#34 AS n#119, _15#35 AS o#120, _16#36 AS p#121, _17#37 AS q#122, _18#38 AS r#123, _19#39 AS s#124, _20#40 AS t#125, _21#41 AS u#126]
       +- Filter (size(array(cast(_1#21 as string), _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41), true) > 0)
          +- LogicalRDD [_1#21, _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41]
    ```
    When calculating the constraints of the `Project`, we generate around 2^20 expressions, due to this code:
    ```
    var allConstraints = child.constraints
    projectList.foreach {
      case a @ Alias(l: Literal, _) =>
        allConstraints += EqualNullSafe(a.toAttribute, l)
      case a @ Alias(e, _) =>
        // For every alias in `projectList`, replace the reference in constraints by its attribute.
        allConstraints ++= allConstraints.map(_ transform {
          case expr: Expression if expr.semanticEquals(e) =>
            a.toAttribute
        })
        allConstraints += EqualNullSafe(e, a.toAttribute)
      case _ => // Don't change.
    }
    ```
    There are 3 issues here:
    1. We may infer complicated predicates from `Generate`.
    2. The `ConstantFolding` rule is too conservative. At least `Cast` has no side effect when ANSI mode is off.
    3. When calculating constraints, we should have an upper bound to avoid generating too many expressions.

    This fixes the first 2 issues, and leaves the third one for the future.

    ### Why are the changes needed?

    fix a performance issue

    ### Does this PR introduce _any_ user-facing change?

    no

    ### How was this patch tested?

    new tests, and running the query in the JIRA ticket locally

    Closes #34823 from cloud-fan/perf.

    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 1fac7a9d9992b7c120f325cdfa6a935b52c7f3bc)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
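
To see why the substitution loop quoted above is exponential, here is a minimal, self-contained Scala sketch (not part of the patch; all names are made up). Each constraint is modeled as just the set of attribute names it references; because the loop keeps the original constraint and also adds a rewritten copy for every alias, n aliases over a constraint that mentions all n columns yield ~2^n variants:

```
object ConstraintBlowupSketch extends App {
  val n = 16 // the repro above has 21 aliased columns, i.e. ~2^21 constraints
  val attrs = (1 to n).map(i => s"_$i")

  // One inferred constraint referencing every column, standing in for
  // size(array(_1, ..., _n)) > 0 from the plan above.
  var constraints: Set[Set[String]] = Set(attrs.toSet)

  for (attr <- attrs) {
    // Mirrors `allConstraints ++= allConstraints.map(_ transform ...)`:
    // every constraint mentioning `attr` is kept AND copied with the
    // attribute replaced by its alias, doubling that constraint's variants.
    constraints ++= constraints.collect {
      case c if c.contains(attr) => c - attr + s"${attr}_alias"
    }
    println(s"after aliasing $attr: ${constraints.size} constraints")
  }
  // Prints 2, 4, 8, ..., 65536: exponential in the number of aliases.
}
```
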
 .../spark/sql/catalyst/optimizer/Optimizer.scala    | 41 +++++----
 .../spark/sql/catalyst/optimizer/expressions.scala  |  1 +
 .../optimizer/InferFiltersFromGenerateSuite.scala   | 98 ++++++++++------------
 3 files changed, 67 insertions(+), 73 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 99b5240..e39fa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,25 +893,30 @@ object TransposeWindow extends Rule[LogicalPlan] {
  * by this [[Generate]] can be removed earlier - before joins and in data sources.
  */
 object InferFiltersFromGenerate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    // This rule does not infer filters from foldable expressions to avoid constant filters
-    // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
-    // then the idempotence will break.
-    case generate @ Generate(e, _, _, _, _, _)
-      if !e.deterministic || e.children.forall(_.foldable) ||
-        e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
-
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
-      // Exclude child's constraints to guarantee idempotency
-      val inferredFilters = ExpressionSet(
-        Seq(
-          GreaterThan(Size(g.children.head), Literal(0)),
-          IsNotNull(g.children.head)
-        )
-      ) -- generate.child.constraints
-
-      if (inferredFilters.nonEmpty) {
-        generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+      assert(g.children.length == 1)
+      val input = g.children.head
+      // Generating extra predicates here has overheads/risks:
+      // - We may evaluate expensive input expressions multiple times.
+      // - We may infer too many constraints later.
+      // - The input expression may fail to be evaluated under ANSI mode. If we reorder the
+      //   predicates and evaluate the input expression first, we may fail the query unexpectedly.
+      // To be safe, here we only generate extra predicates if the input is an attribute.
+      // Note that, foldable input is also excluded here, to avoid constant filters like
+      // 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and then the
+      // idempotence will break.
+      if (input.isInstanceOf[Attribute]) {
+        // Exclude child's constraints to guarantee idempotency
+        val inferredFilters = ExpressionSet(
+          Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
+        ) -- generate.child.constraints
+
+        if (inferredFilters.nonEmpty) {
+          generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+        } else {
+          generate
+        }
       } else {
         generate
       }
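
For intuition about the `InferFiltersFromGenerate` change above, a spark-shell-style sketch (illustrative only; the DataFrame and column names are made up, and `spark.implicits._` is assumed in scope): the filter is still inferred when the generator input is a bare attribute, but no longer when the input is a constructed array.

```
import org.apache.spark.sql.functions.{array, explode}

val df = Seq((Seq(1, 2, 3), "x")).toDF("arr", "s")

// Input is a plain attribute: the optimized plan still shows an inferred
// Filter ((size(arr) > 0) AND isnotnull(arr)) below the Generate.
df.select(explode($"arr")).explain(true)

// Input is a constructed array: after this patch no filter is inferred, so
// no hard-to-fold size(array(...)) predicate reaches constraint propagation.
df.select(explode(array($"s"))).explain(true)
```
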
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index d989753..78098c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -46,6 +46,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
   private def hasNoSideEffect(e: Expression): Boolean = e match {
     case _: Attribute => true
     case _: Literal => true
+    case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
     case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
     case _ => false
   }
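
A rough illustration of what the `ConstantFolding` change buys (again spark-shell style, not from the patch): `ConstantFolding` rewrites `size(CreateArray(...))` to the known array length when every child is side-effect-free, and with ANSI mode off a `Cast` over a side-effect-free child now qualifies, so predicates like the one in the plan above fold away.

```
import org.apache.spark.sql.functions.{array, size}

spark.conf.set("spark.sql.ansi.enabled", "false")
val df = Seq((1, "x")).toDF("i", "s")

// Before the fix, cast(i as string) blocked hasNoSideEffect, so the predicate
// survived as size(array(cast(i as string), s), true) > 0. With the fix, the
// size(...) call folds to the literal 2, then 2 > 0 folds to true, and
// PruneFilters removes the filter from the optimized plan entirely.
df.filter(size(array($"i".cast("string"), $"s")) > 0).explain(true)
```
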
StructField("y", StringType) + ))) + val fakeUDF = ScalaUDF( + (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))), + returnSchema, 'c3 :: Nil, Nil) + val generatorWithUDF = f(fakeUDF) + test("SPARK-36715: Don't infer filters from " + generatorWithUDF) { + val originalQuery = testRelation.generate(generatorWithUDF).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } } Seq(Explode(_), PosExplode(_)).foreach { f => - val createArrayExplode = f(CreateArray(Seq('c1))) - test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) { - val originalQuery = testRelation.generate(createArrayExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - val createMapExplode = f(CreateMap(Seq('c1, 'c2))) - test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) { - val originalQuery = testRelation.generate(createMapExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - } - - Seq(Inline(_)).foreach { f => - val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) - test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) { - val originalQuery = testRelation.generate(createArrayStructExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - } + val createArrayExplode = f(CreateArray(Seq('c1))) + test("SPARK-33544: Don't infer filters from " + createArrayExplode) { + val originalQuery = testRelation.generate(createArrayExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + val createMapExplode = f(CreateMap(Seq('c1, 'c2))) + test("SPARK-33544: Don't infer filters from " + createMapExplode) { + val originalQuery = testRelation.generate(createMapExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + } - test("SPARK-36715: Don't infer filters from udf") { - Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f => - val returnSchema = ArrayType(StructType(Seq( - StructField("x", IntegerType), - StructField("y", StringType) - ))) - val fakeUDF = ScalaUDF( - (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))), - returnSchema, Literal(8) :: Nil, - Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil) - val generator = f(fakeUDF) - val originalQuery = OneRowRelation().generate(generator).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - val correctAnswer = OneRowRelation() - .generate(generator) - .analyze - comparePlans(optimized, correctAnswer) + Seq(Inline(_)).foreach { f => + val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) + test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) { + val originalQuery = testRelation.generate(createArrayStructExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org