This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new de3673e02b6 [SPARK-39105][SQL] Add ConditionalExpression trait de3673e02b6 is described below commit de3673e02b620c00c741ae313ab5f40e56603515 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Fri May 6 15:41:53 2022 +0800 [SPARK-39105][SQL] Add ConditionalExpression trait ### What changes were proposed in this pull request? Add `ConditionalExpression` trait. ### Why are the changes needed? For developers, if a custom conditional-like expression contains a common subexpression, its evaluation order may change, because Spark pulls out and evaluates common subexpressions first during execution. Adding the `ConditionalExpression` trait lets developers declare which children are always evaluated (`alwaysEvaluatedInputs`) and which groups of branches may share eagerly evaluated common subexpressions (`branchGroups`), instead of relying on a hard-coded list of expressions in `EquivalentExpressions`. ### Does this PR introduce _any_ user-facing change? No, this only adds a new trait. ### How was this patch tested? Passes existing tests. Closes #36455 from ulysses-you/SPARK-39105. Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit fa86a078bb7d57d7dbd48095fb06059a9bdd6c2e) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/EquivalentExpressions.scala | 44 ++-------------------- .../sql/catalyst/expressions/Expression.scala | 17 +++++++++ .../expressions/conditionalExpressions.scala | 40 +++++++++++++++++++- .../sql/catalyst/expressions/nullExpressions.scala | 26 ++++++++++++- 4 files changed, 82 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index e826de75fb5..2bbde304c28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -127,22 +127,10 @@ class EquivalentExpressions { // There are some
special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. If: common subexpressions will always be evaluated at the beginning, but the true and - // false expressions in `If` may not get accessed, according to the predicate - // expression. We should only recurse into the predicate expression. - // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain - // condition. We should only recurse into the first condition expression as it - // will always get accessed. - // 4. Coalesce: it's also a conditional expression, we should only recurse into the first - // children, because others may not get accessed. - // 5. NaNvl: it's a conditional expression, we can only guarantee the left child can be always - // accessed. And if we hit the left child, the right will not be accessed. + // 2. ConditionalExpression: use its children that will always be evaluated. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil - case i: If => i.predicate :: Nil - case c: CaseWhen => c.children.head :: Nil - case c: Coalesce => c.children.head :: Nil - case n: NaNvl => n.left :: Nil + case c: ConditionalExpression => c.alwaysEvaluatedInputs case other => other.children } @@ -150,33 +138,7 @@ class EquivalentExpressions { // recursively add the common expressions shared between all of its children. private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { case _: CodegenFallback => Nil - case i: If => Seq(Seq(i.trueValue, i.falseValue)) - case c: CaseWhen => - // We look at subexpressions in conditions and values of `CaseWhen` separately. It is - // because a subexpression in conditions will be run no matter which condition is matched - // if it is shared among conditions, but it doesn't need to be shared in values. 
Similarly, - // a subexpression among values doesn't need to be in conditions because no matter which - // condition is true, it will be evaluated. - val conditions = if (c.branches.length > 1) { - c.branches.map(_._1) - } else { - // If there is only one branch, the first condition is already covered by - // `childrenToRecurse` and we should exclude it here. - Nil - } - // For an expression to be in all branch values of a CaseWhen statement, it must also be in - // the elseValue. - val values = if (c.elseValue.nonEmpty) { - c.branches.map(_._2) ++ c.elseValue - } else { - Nil - } - - Seq(conditions, values) - // If there is only one child, the first child is already covered by - // `childrenToRecurse` and we should exclude it here. - case c: Coalesce if c.children.length > 1 => Seq(c.children) - case n: NaNvl => Seq(n.children) + case c: ConditionalExpression => c.branchGroups case _ => Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b5695e8c872..30b6773ce1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -454,6 +454,23 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains conditional expression branches, so not all branches will be hit. + * All optimization should be careful with the evaluation order. + */ +trait ConditionalExpression extends Expression { + /** + * Return the children expressions which can always be hit at runtime. + */ + def alwaysEvaluatedInputs: Seq[Expression] + + /** + * Return groups of branches. For each group, at least one branch will be hit at runtime, + * so that we can eagerly evaluate the common expressions of a group. 
+ */ + def branchGroups: Seq[Seq[Expression]] +} + /** * An expression that contains mutable state. A stateful expression is always non-deterministic * because the results it produces during evaluation are not only dependent on the given input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 3e356f1e8a2..5dacabd646d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ group = "conditional_funcs") // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends ComplexTypeMergingExpression with TernaryLike[Expression] { + extends ComplexTypeMergingExpression with ConditionalExpression with TernaryLike[Expression] { @transient override lazy val inputTypesForMerging: Seq[DataType] = { @@ -48,6 +48,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def second: Expression = trueValue override def third: Expression = falseValue override def nullable: Boolean = trueValue.nullable || falseValue.nullable + /** + * Only the condition expression will always be evaluated. 
+ */ + override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil + + override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue)) final override val nodePatterns : Seq[TreePattern] = Seq(IF) @@ -138,7 +144,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends ComplexTypeMergingExpression with Serializable { + extends ComplexTypeMergingExpression with ConditionalExpression { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -179,6 +185,36 @@ case class CaseWhen( } } + /** + * Like `If`, the children of `CaseWhen` only get accessed in a certain condition. + * We should only return the first condition expression as it will always get accessed. + */ + override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil + + override def branchGroups: Seq[Seq[Expression]] = { + // We look at subexpressions in conditions and values of `CaseWhen` separately. It is + // because a subexpression in conditions will be run no matter which condition is matched + // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, + // a subexpression among values doesn't need to be in conditions because no matter which + // condition is true, it will be evaluated. + val conditions = if (branches.length > 1) { + branches.map(_._1) + } else { + // If there is only one branch, the first condition is already covered by + // `alwaysEvaluatedInputs` and we should exclude it here. + Nil + } + // For an expression to be in all branch values of a CaseWhen statement, it must also be in + // the elseValue. 
+ val values = if (elseValue.nonEmpty) { + branches.map(_._2) ++ elseValue + } else { + Nil + } + + Seq(conditions, values) + } + override def eval(input: InternalRow): Any = { var i = 0 val size = branches.size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 3c6a9b8e780..8f59ab5b249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -47,7 +47,8 @@ import org.apache.spark.sql.types._ since = "1.0.0", group = "conditional_funcs") // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { +case class Coalesce(children: Seq[Expression]) + extends ComplexTypeMergingExpression with ConditionalExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = children.forall(_.nullable) @@ -66,6 +67,19 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress } } + /** + * We should only return the first child, because others may not get accessed. + */ + override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil + + override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) { + // If there is only one child, the first child is already covered by + // `alwaysEvaluatedInputs` and we should exclude it here. 
+ Seq(children) + } else { + Nil + } + override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator @@ -261,13 +275,21 @@ case class IsNaN(child: Expression) extends UnaryExpression since = "1.5.0", group = "conditional_funcs") case class NaNvl(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ConditionalExpression with ImplicitCastInputTypes { override def dataType: DataType = left.dataType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType)) + /** + * We can only guarantee the left child can be always accessed. If we hit the left child, + * the right child will not be accessed. + */ + override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil + + override def branchGroups: Seq[Seq[Expression]] = Seq(children) + override def eval(input: InternalRow): Any = { val value = left.eval(input) if (value == null) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org