This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d8d604bc07b [SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate alternatives

d8d604bc07b is described below

commit d8d604bc07bc3b8c98f73c4b10f93cb4eb7113be
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Tue Jan 17 20:58:37 2023 +0800

[SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate alternatives

### What changes were proposed in this pull request?
This PR introduces the `TreeNode.multiTransformDown()` and `TreeNode.multiTransformDownWithPruning()` methods, which recursively transform a `TreeNode` (and so a tree) into multiple alternatives.
These functions are particularly useful when we want to transform an expression through a projection in which subexpressions can be aliased by multiple different attributes.

E.g. if we have a partitioning expression `HashPartitioning(a + b)` and a `Project` node that aliases `a` as `a1` and `a2` and `b` as `b1` and `b2`, we can easily generate a stream of alternative transformations of the original partitioning:
```
// This is a simplified test, some arguments are missing to make it concise
val partitioning = HashPartitioning(Add(a, b))
val aliases: Map[Expression, Seq[Attribute]] = ... // collect the alias map from project
val s = partitioning.multiTransformDown {
  case e: Expression if aliases.contains(e.canonicalized) => aliases(e.canonicalized).toStream
}
s // Stream(HashPartitioning(Add(a1, b1)), HashPartitioning(Add(a2, b1)), HashPartitioning(Add(a1, b2)), HashPartitioning(Add(a2, b2)))
```
The result of `multiTransformDown` is a lazy stream, so the caller can limit the number of alternatives generated as needed.

### Why are the changes needed?
`TreeNode.multiTransformDown()` is a useful helper method for generating the alternative forms of an expression, such as the aliased partitionings above.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New UTs are added.

Closes #38034 from peter-toth/SPARK-40599-multitransform.
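To make the alias example above concrete without a Spark build, here is a minimal Scala sketch of the same pre-order semantics on a toy tree. `MultiTransformSketch`, `Expr`, `Attr` and `Add` are hypothetical stand-ins rather than Catalyst classes, and the sketch deliberately leaves out what the real `multiTransformDownWithPruning` also handles (the `cond`/`ruleId` pruning, tag copying and origin tracking). It keeps `Stream` to mirror the commit, although `Stream` has been deprecated in favor of `LazyList` since Scala 2.13. The last value reproduces the second `multiTransformDown` Javadoc example from the diff, using attribute names in place of literals.

```scala
// Minimal sketch of multiTransformDown semantics on a hypothetical toy tree.
// Expr, Attr and Add are illustrative stand-ins, not Catalyst classes.
object MultiTransformSketch {
  sealed trait Expr {
    def children: Seq[Expr]
    def withNewChildren(newChildren: Seq[Expr]): Expr
  }

  final case class Attr(name: String) extends Expr {
    def children: Seq[Expr] = Nil
    def withNewChildren(newChildren: Seq[Expr]): Expr = this
  }

  final case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
    def withNewChildren(newChildren: Seq[Expr]): Expr = Add(newChildren(0), newChildren(1))
  }

  // Lazy cartesian product of the children's alternatives, built with the same
  // foldRight as generateChildrenSeq: the first child's alternatives vary fastest.
  def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] =
    childrenStreams.foldRight(Stream(Seq.empty[T])) { (childrenStream, childrenSeqStream) =>
      for {
        childrenSeq <- childrenSeqStream
        child <- childrenStream
      } yield child +: childrenSeq
    }

  // Pre-order: apply the rule to the node first (keeping the node itself when the rule
  // does not match), then expand the children of every alternative. An empty result
  // from the rule prunes the node, which empties every combination that contains it.
  def multiTransformDown(e: Expr)(rule: PartialFunction[Expr, Stream[Expr]]): Stream[Expr] =
    rule.applyOrElse(e, (_: Expr) => Stream(e)).flatMap { alt =>
      if (alt.children.nonEmpty) {
        generateChildrenSeq(alt.children.map(c => multiTransformDown(c)(rule)))
          .map(alt.withNewChildren)
      } else {
        Stream(alt)
      }
    }

  // `a` is aliased as `a1` and `a2`, `b` as `b1` and `b2`, as in the Project example.
  val aliases: Map[Expr, Seq[Expr]] = Map(
    Attr("a") -> Seq(Attr("a1"), Attr("a2")),
    Attr("b") -> Seq(Attr("b1"), Attr("b2")))

  val alternatives: List[Expr] = multiTransformDown(Add(Attr("a"), Attr("b"))) {
    case e if aliases.contains(e) => aliases(e).toStream
  }.toList

  // Returning Stream.empty for `a` prunes it, so no alternative survives.
  val pruned: List[Expr] = multiTransformDown(Add(Attr("a"), Attr("b"))) {
    case Attr("a") => Stream.empty
  }.toList

  // Listing `Add(a, b)` itself among its own alternatives keeps its children's alternatives.
  val withOriginal: List[Expr] = multiTransformDown(Add(Attr("a"), Attr("b"))) {
    case e @ Add(Attr("a"), Attr("b")) => Stream(Attr("11"), Attr("12"), Attr("21"), Attr("22"), e)
    case Attr("a") => Stream(Attr("1"), Attr("2"))
    case Attr("b") => Stream(Attr("10"), Attr("20"))
  }.toList
}
```

With these aliases, `alternatives` should come out as `Add(a1, b1)`, `Add(a2, b1)`, `Add(a1, b2)`, `Add(a2, b2)`: the first child's alternatives vary fastest, which is also the order the new `TreeNodeSuite` tests expect. `withOriginal` yields the four replacement leaves followed by `Add(1, 10)`, `Add(2, 10)`, `Add(1, 20)`, `Add(2, 20)`, in line with the Javadoc example.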
Authored-by: Peter Toth <peter.t...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 128 +++++++++++++++++++++
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 104 +++++++++++++++++
 2 files changed, 232 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 9510aa4d9e7..dc64e5e2560 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     }
   }
 
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * @param rule a function used to generate alternatives for a node
+   * @return the stream of alternatives
+   */
+  def multiTransformDown(
+      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * As it is very easy to generate an enormous number of alternatives when the input tree is huge
+   * or when the rule returns many alternatives for many nodes, this function returns the
+   * alternatives as a lazy `Stream` so that the caller can limit the number of alternatives
+   * generated as needed.
+   *
+   * The rule should either not apply or return a one-element stream of the original node to
+   * indicate that the original node without any transformation is a valid alternative.
+   *
+   * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
+   * case `multiTransformDown()` returns an empty `Stream`.
+   *
+   * Please consider the following examples of `input.multiTransformDown(rule)`:
+   *
+   * We have an input expression:
+   *   `Add(a, b)`
+   *
+   * 1.
+   * We have a simple rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22)`
+   *
+   * 2.
+   * In the previous example, if we want to generate alternatives of `a` and `b` too, then we
+   * need to explicitly add the original `Add(a, b)` expression to the rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * @param rule   a function used to generate alternatives for a node
+   * @param cond   a Lambda expression to prune tree traversals. If `cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+   *               has been marked as ineffective on a TreeNode T, skips processing T and its
+   *               subtree. Do not pass it if the rule is not purely functional and reads a
+   *               varying initial state for different invocations.
+   * @return the stream of alternatives
+   */
+  def multiTransformDownWithPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return Stream(this)
+    }
+
+    // We could return `Stream(this)` if the `rule` doesn't apply and handle both
+    // - the rule doesn't apply
+    // - and the rule returns a one-element `Stream(originalNode)`
+    // cases together.
+    // But, unfortunately, there doesn't seem to be a way to match on a one-element stream
+    // without eagerly computing the head of its tail, which contradicts the purpose of only
+    // taking the necessary elements from the alternatives (the "multiTransformDown is lazy"
+    // test case in `TreeNodeSuite` would fail). Note that this behaviour also has a downside: we
+    // can only mark the rule as ineffective on the original node if the rule didn't match.
+    var ruleApplied = true
+    val afterRules = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse(this, (_: BaseType) => {
+        ruleApplied = false
+        Stream.empty
+      })
+    }
+
+    val afterRulesStream = if (afterRules.isEmpty) {
+      if (ruleApplied) {
+        // If the rule returned an empty stream of alternatives then prune
+        Stream.empty
+      } else {
+        // If the rule was not applied then keep the original node
+        this.markRuleAsIneffective(ruleId)
+        Stream(this)
+      }
+    } else {
+      // If the rule was applied then use the returned alternatives
+      afterRules.map { afterRule =>
+        if (this fastEquals afterRule) {
+          this
+        } else {
+          afterRule.copyTagsFrom(this)
+          afterRule
+        }
+      }
+    }
+
+    afterRulesStream.flatMap { afterRule =>
+      if (afterRule.containsChild.nonEmpty) {
+        generateChildrenSeq(
+          afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
+          .map(afterRule.withNewChildren)
+      } else {
+        Stream(afterRule)
+      }
+    }
+  }
+
+  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
+    childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) =>
+      for {
+        childrenSeq <- childrenSeqStream
+        child <- childrenStream
+      } yield child +: childrenSeq
+    )
+  }
+
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes in `children`.
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 286d3dddae6..ac28917675e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -977,4 +977,108 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(origin.context.summary.isEmpty)
     }
   }
+
+  private def newErrorAfterStream(es: Expression*) = {
+    es.toStream.append(
+      throw new NoSuchElementException("Stream should not return more elements")
+    )
+  }
+
+  test("multiTransformDown generates all alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), Literal(300))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Literal(300))
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), cd)
+    assert(transformed === expected)
+  }
+
+  test("multiTransformDown is lazy") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => newErrorAfterStream(Literal(10))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected = for {
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, Literal(10)), Literal(100))
+    // We don't access alternatives for `b` after 10 and for `c` after 100
+    assert(transformed.take(3) == expected)
+    intercept[NoSuchElementException] {
+      transformed.take(3 + 1).toList
+    }
+
+    val transformed2 = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected2 = for {
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), Literal(100))
+    // We don't access alternatives for `c` after 100
+    assert(transformed2.take(3 * 3) === expected2)
+    intercept[NoSuchElementException] {
+      transformed2.take(3 * 3 + 1).toList
+    }
+  }
+
+  test("multiTransformDown rule return this") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
+      case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
+      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), a)
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
+      b <- Seq(Literal(10), Literal(20), Literal("b"))
+      a <- Seq(Literal(1), Literal(2), Literal("a"))
+    } yield Add(Add(a, b), cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
+    "transformed and itself is in the alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
+        Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
+      case StringLiteral("a") => Stream(Literal(1), Literal(2))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200))
+      ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
+        (for {
+          b <- Seq(Literal(10), Literal(20))
+          a <- Seq(Literal(1), Literal(2))
+        } yield Add(a, b))
+    } yield Add(ab, cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown can prune") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream.empty
+    }
+    assert(transformed.isEmpty)
+
+    val transformed2 = e.multiTransformDown {
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+    }
+    assert(transformed2.isEmpty)
+  }
 }