This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 62708db [SPARK-32635][SQL][2.4] Fix foldable propagation 62708db is described below commit 62708db4f90f652cd9bc73998ac5f1e949bd41ac Author: Peter Toth <peter.t...@gmail.com> AuthorDate: Fri Sep 18 10:28:30 2020 -0700 [SPARK-32635][SQL][2.4] Fix foldable propagation ### What changes were proposed in this pull request? This PR rewrites the `FoldablePropagation` rule so that it replaces attribute references in a node only with foldables coming from that node's children. Before this PR, in the case of this example (with the setting `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation`): ```scala val a = Seq("1").toDF("col1").withColumn("col2", lit("1")) val b = Seq("2").toDF("col1").withColumn("col2", lit("2")) val aub = a.union(b) val c = aub.filter($"col1" === "2").cache() val d = Seq("2").toDF("col4") val r = d.join(aub, $"col2" === $"col4").select("col4") val l = c.select("col2") val df = l.join(r, $"col2" === $"col4", "LeftOuter") df.show() ``` foldable propagation happens incorrectly (the plans are shown side by side: the left column is the plan before the rule, the right column is the plan after it, and `!` marks the line that changed): ``` Join LeftOuter, (col2#6 = col4#34) Join LeftOuter, (col2#6 = col4#34) !:- Project [col2#6] :- Project [1 AS col2#6] : +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas) : +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas) : +- Union : +- Union : :- *(1) Project [value#1 AS col1#4, 1 AS col2#6] : :- *(1) Project [value#1 AS col1#4, 1 AS col2#6] : : +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2)) : : +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2)) : : +- *(1) LocalTableScan [value#1] : : +- *(1) LocalTableScan [value#1] : +- *(2) Project [value#10 AS col1#13, 2 AS col2#15] : +- *(2) Project [value#10 AS col1#13, 2 AS col2#15] : +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2)) : +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2)) : +- *(2) LocalTableScan [value#10] : +- *(2) 
LocalTableScan [value#10] +- Project [col4#34] +- Project [col4#34] +- Join Inner, (col2#6 = col4#34) +- Join Inner, (col2#6 = col4#34) :- Project [value#31 AS col4#34] :- Project [value#31 AS col4#34] : +- LocalRelation [value#31] : +- LocalRelation [value#31] +- Project [col2#6] +- Project [col2#6] +- Union false, false +- Union false, false :- Project [1 AS col2#6] :- Project [1 AS col2#6] : +- LocalRelation [value#1] : +- LocalRelation [value#1] +- Project [2 AS col2#15] +- Project [2 AS col2#15] +- LocalRelation [value#10] +- LocalRelation [value#10] ``` and so the result is wrong: ``` +----+----+ |col2|col4| +----+----+ | 1|null| +----+----+ ``` After this PR, foldable propagation no longer happens incorrectly and the result is correct: ``` +----+----+ |col2|col4| +----+----+ | 2| 2| +----+----+ ``` ### Why are the changes needed? To fix a correctness issue. ### Does this PR introduce _any_ user-facing change? Yes, it fixes a correctness issue. ### How was this patch tested? Existing and new unit tests. Closes #29805 from peter-toth/SPARK-32635-fix-foldable-propagation-2.4. 
Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../sql/catalyst/expressions/AttributeMap.scala | 2 + .../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++--------- .../optimizer/FoldablePropagationSuite.scala | 12 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++ 4 files changed, 98 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 9f4a0f2..1e8f8ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -26,6 +26,8 @@ object AttributeMap { def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } + + def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty) } class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index be0e702..6c6f6313 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -595,59 +595,82 @@ object NullPropagation extends Rule[LogicalPlan] { */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - var foldableMap = AttributeMap(plan.flatMap { - case Project(projectList, _) => projectList.collect { - case a: Alias if a.child.foldable => (a.toAttribute, a) - } - case _ => Nil - }) - val replaceFoldable: PartialFunction[Expression, Expression] = { - case a: AttributeReference if foldableMap.contains(a) => foldableMap(a) 
+ CleanupAliases(propagateFoldables(plan)._1) + } + + private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = { + plan match { + case p: Project => + val (newChild, foldableMap) = propagateFoldables(p.child) + val newProject = + replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], foldableMap) + val newFoldableMap = AttributeMap(newProject.projectList.collect { + case a: Alias if a.child.foldable => (a.toAttribute, a) + }) + (newProject, newFoldableMap) + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case e: Expand => + val (newChild, foldableMap) = propagateFoldables(e.child) + val expandWithNewChildren = e.withNewChildren(Seq(newChild)).asInstanceOf[Expand] + val newExpand = if (foldableMap.isEmpty) { + expandWithNewChildren + } else { + val newProjections = expandWithNewChildren.projections.map(_.map(_.transform { + case a: AttributeReference if foldableMap.contains(a) => foldableMap(a) + })) + if (newProjections == expandWithNewChildren.projections) { + expandWithNewChildren + } else { + expandWithNewChildren.copy(projections = newProjections) + } + } + (newExpand, foldableMap) + + case u: UnaryNode if canPropagateFoldables(u) => + val (newChild, foldableMap) = propagateFoldables(u.child) + val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), foldableMap) + (newU, foldableMap) + + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. + // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes + // of outer join. 
+ case j: Join => + val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip + val foldableMap = AttributeMap( + foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values).toSeq) + val newJoin = + replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap) + val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => newJoin.right.output + case RightOuter => newJoin.left.output + case FullOuter => newJoin.left.output ++ newJoin.right.output + }) + val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + (newJoin, newFoldableMap) + + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case o => + val newOther = o.mapChildren(propagateFoldables(_)._1) + (newOther, AttributeMap.empty) } + } + private def replaceFoldable(plan: LogicalPlan, foldableMap: AttributeMap[Alias]): plan.type = { if (foldableMap.isEmpty) { plan } else { - CleanupAliases(plan.transformUp { - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => - u.transformExpressions(replaceFoldable) - - // Join derives the output attributes from its child while they are actually not the - // same attributes. For example, the output of outer join is not always picked from its - // children, but can also be null. We should exclude these miss-derived attributes when - // propagating the foldable expressions. - // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes - // of outer join. 
- case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => - val newJoin = j.transformExpressions(replaceFoldable) - val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { - case _: InnerLike | LeftExistence(_) => Nil - case LeftOuter => right.output - case RightOuter => left.output - case FullOuter => left.output ++ right.output - }) - foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { - case (attr, _) => missDerivedAttrsSet.contains(attr) - }.toSeq) - newJoin - - // We can not replace the attributes in `Expand.output`. If there are other non-leaf - // operators that have the `output` field, we should put them here too. - case expand: Expand if foldableMap.nonEmpty => - expand.copy(projections = expand.projections.map { projection => - projection.map(_.transform(replaceFoldable)) - }) - - // For other plans, they are not safe to apply foldable propagation, and they should not - // propagate foldable expressions from children. - case other if foldableMap.nonEmpty => - val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) - foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { - case (attr, _) => childrenOutputSet.contains(attr) - }.toSeq) - other - }) + plan transformExpressions { + case a: AttributeReference if foldableMap.contains(a) => foldableMap(a) + } } } @@ -655,7 +678,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { * Whitelist of all [[UnaryNode]]s for which allow foldable propagation. */ private def canPropagateFoldables(u: UnaryNode): Boolean = u match { - case _: Project => true + // Handling `Project` is moved to `propagateFoldables`. 
case _: Filter => true case _: SubqueryAlias => true case _: Aggregate => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index c288446..2c45199 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -180,4 +180,16 @@ class FoldablePropagationSuite extends PlanTest { .select((Literal(1) + 3).as('res)).analyze comparePlans(optimized, correctAnswer) } + + test("SPARK-32635: Replace references with foldables coming only from the node's children") { + val leftExpression = 'a.int + val left = LocalRelation(leftExpression).select('a) + val rightExpression = Alias(Literal(2), "a")(leftExpression.exprId) + val right = LocalRelation('b.int).select('b, rightExpression).select('b) + val join = left.join(right, joinType = LeftOuter, condition = Some('b === 'a)) + + val query = join.analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e7d55ee..037cf23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2720,6 +2720,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val nestedDecArray = Array(decSpark) checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) } + + test("SPARK-32635: Replace references with foldables coming only from the node's children") { + val a = Seq("1").toDF("col1").withColumn("col2", lit("1")) + val b = Seq("2").toDF("col1").withColumn("col2", lit("2")) + val aub = a.union(b) + val c = aub.filter($"col1" 
=== "2").cache() + val d = Seq("2").toDF("col4") + val r = d.join(aub, $"col2" === $"col4").select("col4") + val l = c.select("col2") + val df = l.join(r, $"col2" === $"col4", "LeftOuter") + checkAnswer(df, Row("2", "2")) + } } case class GroupByKey(a: Int, b: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org