This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new de3673e02b6 [SPARK-39105][SQL] Add ConditionalExpression trait
de3673e02b6 is described below

commit de3673e02b620c00c741ae313ab5f40e56603515
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri May 6 15:41:53 2022 +0800

    [SPARK-39105][SQL] Add ConditionalExpression trait
    
    ### What changes were proposed in this pull request?
    
    Add a `ConditionalExpression` trait.
    
    ### Why are the changes needed?
    
    For developers, if a custom conditional-like expression contains common
    subexpressions, the evaluation order may change, because Spark pulls out and
    evaluates common subexpressions first during execution. That can evaluate a
    branch that would otherwise never be reached.
    
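    Adding a `ConditionalExpression` trait gives developers a single place to declare
    which children are always evaluated (`alwaysEvaluatedInputs`) and which groups of
    branches are guaranteed to have at least one branch hit (`branchGroups`). This lets
    subexpression elimination respect the evaluation order of custom conditional expressions.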
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. It only adds a new trait.
    
    ### How was this patch tested?
    
    Passed existing tests.
    
    Closes #36455 from ulysses-you/SPARK-39105.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit fa86a078bb7d57d7dbd48095fb06059a9bdd6c2e)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/EquivalentExpressions.scala        | 44 ++--------------------
 .../sql/catalyst/expressions/Expression.scala      | 17 +++++++++
 .../expressions/conditionalExpressions.scala       | 40 +++++++++++++++++++-
 .../sql/catalyst/expressions/nullExpressions.scala | 26 ++++++++++++-
 4 files changed, 82 insertions(+), 45 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index e826de75fb5..2bbde304c28 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -127,22 +127,10 @@ class EquivalentExpressions {
 
   // There are some special expressions that we should not recurse into all of 
its children.
   //   1. CodegenFallback: it's children will not be used to generate code 
(call eval() instead)
-  //   2. If: common subexpressions will always be evaluated at the beginning, 
but the true and
-  //          false expressions in `If` may not get accessed, according to the 
predicate
-  //          expression. We should only recurse into the predicate expression.
-  //   3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in 
a certain
-  //                condition. We should only recurse into the first condition 
expression as it
-  //                will always get accessed.
-  //   4. Coalesce: it's also a conditional expression, we should only recurse 
into the first
-  //                children, because others may not get accessed.
-  //   5. NaNvl: it's a conditional expression, we can only guarantee the left 
child can be always
-  //             accessed. And if we hit the left child, the right will not be 
accessed.
+  //   2. ConditionalExpression: use its children that will always be 
evaluated.
   private def childrenToRecurse(expr: Expression): Seq[Expression] = expr 
match {
     case _: CodegenFallback => Nil
-    case i: If => i.predicate :: Nil
-    case c: CaseWhen => c.children.head :: Nil
-    case c: Coalesce => c.children.head :: Nil
-    case n: NaNvl => n.left :: Nil
+    case c: ConditionalExpression => c.alwaysEvaluatedInputs
     case other => other.children
   }
 
@@ -150,33 +138,7 @@ class EquivalentExpressions {
   // recursively add the common expressions shared between all of its children.
   private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] 
= expr match {
     case _: CodegenFallback => Nil
-    case i: If => Seq(Seq(i.trueValue, i.falseValue))
-    case c: CaseWhen =>
-      // We look at subexpressions in conditions and values of `CaseWhen` 
separately. It is
-      // because a subexpression in conditions will be run no matter which 
condition is matched
-      // if it is shared among conditions, but it doesn't need to be shared in 
values. Similarly,
-      // a subexpression among values doesn't need to be in conditions because 
no matter which
-      // condition is true, it will be evaluated.
-      val conditions = if (c.branches.length > 1) {
-        c.branches.map(_._1)
-      } else {
-        // If there is only one branch, the first condition is already covered 
by
-        // `childrenToRecurse` and we should exclude it here.
-        Nil
-      }
-      // For an expression to be in all branch values of a CaseWhen statement, 
it must also be in
-      // the elseValue.
-      val values = if (c.elseValue.nonEmpty) {
-        c.branches.map(_._2) ++ c.elseValue
-      } else {
-        Nil
-      }
-
-      Seq(conditions, values)
-    // If there is only one child, the first child is already covered by
-    // `childrenToRecurse` and we should exclude it here.
-    case c: Coalesce if c.children.length > 1 => Seq(c.children)
-    case n: NaNvl => Seq(n.children)
+    case c: ConditionalExpression => c.branchGroups
     case _ => Nil
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b5695e8c872..30b6773ce1c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -454,6 +454,23 @@ trait Nondeterministic extends Expression {
   protected def evalInternal(input: InternalRow): Any
 }
 
+/**
+ * An expression that contains conditional expression branches, so not all 
branches will be hit.
+ * All optimization should be careful with the evaluation order.
+ */
+trait ConditionalExpression extends Expression {
+  /**
+   * Return the children expressions which can always be hit at runtime.
+   */
+  def alwaysEvaluatedInputs: Seq[Expression]
+
+  /**
+   * Return groups of branches. For each group, at least one branch will be 
hit at runtime,
+   * so that we can eagerly evaluate the common expressions of a group.
+   */
+  def branchGroups: Seq[Seq[Expression]]
+}
+
 /**
  * An expression that contains mutable state. A stateful expression is always 
non-deterministic
  * because the results it produces during evaluation are not only dependent on 
the given input
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 3e356f1e8a2..5dacabd646d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
   group = "conditional_funcs")
 // scalastyle:on line.size.limit
 case class If(predicate: Expression, trueValue: Expression, falseValue: 
Expression)
-  extends ComplexTypeMergingExpression with TernaryLike[Expression] {
+  extends ComplexTypeMergingExpression with ConditionalExpression with 
TernaryLike[Expression] {
 
   @transient
   override lazy val inputTypesForMerging: Seq[DataType] = {
@@ -48,6 +48,12 @@ case class If(predicate: Expression, trueValue: Expression, 
falseValue: Expressi
   override def second: Expression = trueValue
   override def third: Expression = falseValue
   override def nullable: Boolean = trueValue.nullable || falseValue.nullable
+  /**
+   * Only the condition expression will always be evaluated.
+   */
+  override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil
+
+  override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, 
falseValue))
 
   final override val nodePatterns : Seq[TreePattern] = Seq(IF)
 
@@ -138,7 +144,7 @@ case class If(predicate: Expression, trueValue: Expression, 
falseValue: Expressi
 case class CaseWhen(
     branches: Seq[(Expression, Expression)],
     elseValue: Option[Expression] = None)
-  extends ComplexTypeMergingExpression with Serializable {
+  extends ComplexTypeMergingExpression with ConditionalExpression {
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 
:: Nil) ++ elseValue
 
@@ -179,6 +185,36 @@ case class CaseWhen(
     }
   }
 
+  /**
+   * Like `If`, the children of `CaseWhen` only get accessed in a certain 
condition.
+   * We should only return the first condition expression as it will always 
get accessed.
+   */
+  override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil
+
+  override def branchGroups: Seq[Seq[Expression]] = {
+    // We look at subexpressions in conditions and values of `CaseWhen` 
separately. It is
+    // because a subexpression in conditions will be run no matter which 
condition is matched
+    // if it is shared among conditions, but it doesn't need to be shared in 
values. Similarly,
+    // a subexpression among values doesn't need to be in conditions because 
no matter which
+    // condition is true, it will be evaluated.
+    val conditions = if (branches.length > 1) {
+      branches.map(_._1)
+    } else {
+      // If there is only one branch, the first condition is already covered by
+      // `alwaysEvaluatedInputs` and we should exclude it here.
+      Nil
+    }
+    // For an expression to be in all branch values of a CaseWhen statement, 
it must also be in
+    // the elseValue.
+    val values = if (elseValue.nonEmpty) {
+      branches.map(_._2) ++ elseValue
+    } else {
+      Nil
+    }
+
+    Seq(conditions, values)
+  }
+
   override def eval(input: InternalRow): Any = {
     var i = 0
     val size = branches.size
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 3c6a9b8e780..8f59ab5b249 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -47,7 +47,8 @@ import org.apache.spark.sql.types._
   since = "1.0.0",
   group = "conditional_funcs")
 // scalastyle:on line.size.limit
-case class Coalesce(children: Seq[Expression]) extends 
ComplexTypeMergingExpression {
+case class Coalesce(children: Seq[Expression])
+  extends ComplexTypeMergingExpression with ConditionalExpression {
 
   /** Coalesce is nullable if all of its children are nullable, or if it has 
no children. */
   override def nullable: Boolean = children.forall(_.nullable)
@@ -66,6 +67,19 @@ case class Coalesce(children: Seq[Expression]) extends 
ComplexTypeMergingExpress
     }
   }
 
+  /**
+   * We should only return the first child, because others may not get 
accessed.
+   */
+  override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil
+
+  override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
+    // If there is only one child, the first child is already covered by
+    // `alwaysEvaluatedInputs` and we should exclude it here.
+    Seq(children)
+  } else {
+    Nil
+  }
+
   override def eval(input: InternalRow): Any = {
     var result: Any = null
     val childIterator = children.iterator
@@ -261,13 +275,21 @@ case class IsNaN(child: Expression) extends 
UnaryExpression
   since = "1.5.0",
   group = "conditional_funcs")
 case class NaNvl(left: Expression, right: Expression)
-    extends BinaryExpression with ImplicitCastInputTypes {
+    extends BinaryExpression with ConditionalExpression with 
ImplicitCastInputTypes {
 
   override def dataType: DataType = left.dataType
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, 
FloatType))
 
+  /**
+   * We can only guarantee the left child can be always accessed. If we hit 
the left child,
+   * the right child will not be accessed.
+   */
+  override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil
+
+  override def branchGroups: Seq[Seq[Expression]] = Seq(children)
+
   override def eval(input: InternalRow): Any = {
     val value = left.eval(input)
     if (value == null) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to