This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6b43f0384f [SPARK-39040][SQL] Respect NaNvl in EquivalentExpressions for expression elimination
f6b43f0384f is described below

commit f6b43f0384f9681b963f52a53759c521f6ac11d5
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri Apr 29 12:35:44 2022 +0800

    [SPARK-39040][SQL] Respect NaNvl in EquivalentExpressions for expression elimination
    
    ### What changes were proposed in this pull request?
    
    Respect NaNvl in EquivalentExpressions for expression elimination.
    
    ### Why are the changes needed?
    
    For example, the following query fails:
    ```sql
    set spark.sql.ansi.enabled=true;
    set spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConstantFolding;
    SELECT nanvl(1, 1/0 + 1/0);
    ```
    ```
    org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 4) (10.221.98.68 executor driver): org.apache.spark.SparkArithmeticException: divide by zero. To return NULL instead, use 'try_divide'. If necessary set spark.sql.ansi.enabled to false (except for ANSI interval type) to bypass this error.
    == SQL(line 1, position 17) ==
    select nanvl(1 , 1/0 + 1/0)
                     ^^^
        at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:151)
    ```
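    Subexpression elimination must respect the evaluation order of conditional expressions. NaNvl, like If, CaseWhen and Coalesce, always evaluates its first child (the left side). It evaluates the right side only when the left side is NaN. Here the left side is 1, so `1/0 + 1/0` is never needed. However, `1/0` appears twice in the right side, so it was extracted as a common subexpression and evaluated eagerly, which raised the error. After this fix, expressions that occur only in the right child of NaNvl are no longer hoisted, and the query above no longer fails.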
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this is a bug fix: queries like the one above no longer fail under ANSI mode.
    
    ### How was this patch tested?
    
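    Added a unit test to SubexpressionEliminationSuite and an ANSI SQL query test (inputs/ansi/conditional-functions.sql with its expected output).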
    
    Closes #36376 from ulysses-you/SPARK-39040.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/EquivalentExpressions.scala        |  4 ++++
 .../SubexpressionEliminationSuite.scala            | 14 +++++++++++
 .../inputs/ansi/conditional-functions.sql          |  6 +++++
 .../results/ansi/conditional-functions.sql.out     | 27 ++++++++++++++++++++++
 4 files changed, 51 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 472b6e871e7..e826de75fb5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -135,11 +135,14 @@ class EquivalentExpressions {
   //                will always get accessed.
   //   4. Coalesce: it's also a conditional expression, we should only recurse 
into the first
   //                children, because others may not get accessed.
+  //   5. NaNvl: it's a conditional expression, we can only guarantee the left 
child can be always
+  //             accessed. And if we hit the left child, the right will not be 
accessed.
   private def childrenToRecurse(expr: Expression): Seq[Expression] = expr 
match {
     case _: CodegenFallback => Nil
     case i: If => i.predicate :: Nil
     case c: CaseWhen => c.children.head :: Nil
     case c: Coalesce => c.children.head :: Nil
+    case n: NaNvl => n.left :: Nil
     case other => other.children
   }
 
@@ -173,6 +176,7 @@ class EquivalentExpressions {
     // If there is only one child, the first child is already covered by
     // `childrenToRecurse` and we should exclude it here.
     case c: Coalesce if c.children.length > 1 => Seq(c.children)
+    case n: NaNvl => Seq(n.children)
     case _ => Nil
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 4ad5c92f47c..2375d3ed35f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -447,6 +447,20 @@ class SubexpressionEliminationSuite extends SparkFunSuite 
with ExpressionEvalHel
     // So if `p` is replaced by subexpression, the literal will be reused.
     assert(code.value.toString == "((Decimal) references[0] /* literal */)")
   }
+
+  test("SPARK-39040: Respect NaNvl in EquivalentExpressions for expression 
elimination") {
+    val add = Add(Literal(1), Literal(0))
+    val n1 = NaNvl(Literal(1.0d), Add(add, add))
+    val e1 = new EquivalentExpressions
+    e1.addExprTree(n1)
+    assert(e1.getCommonSubexpressions.isEmpty)
+
+    val n2 = NaNvl(add, add)
+    val e2 = new EquivalentExpressions
+    e2.addExprTree(n2)
+    assert(e2.getCommonSubexpressions.size == 1)
+    assert(e2.getCommonSubexpressions.head == add)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql
new file mode 100644
index 00000000000..5c548b1e9c4
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql
@@ -0,0 +1,6 @@
+-- Tests for conditional functions
+CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS 
t(c1, c2);
+
+SELECT nanvl(c1, c1/c2 + c1/c2) FROM t;
+
+DROP TABLE IF EXISTS t;
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out
new file mode 100644
index 00000000000..d3af659fc48
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out
@@ -0,0 +1,27 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 3
+
+
+-- !query
+CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS 
t(c1, c2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT nanvl(c1, c1/c2 + c1/c2) FROM t
+-- !query schema
+struct<nanvl(c1, ((c1 / c2) + (c1 / c2))):double>
+-- !query output
+1.0
+2.0
+
+
+-- !query
+DROP TABLE IF EXISTS t
+-- !query schema
+struct<>
+-- !query output
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to