This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f6b43f0384f [SPARK-39040][SQL] Respect NaNvl in EquivalentExpressions for expression elimination f6b43f0384f is described below commit f6b43f0384f9681b963f52a53759c521f6ac11d5 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Fri Apr 29 12:35:44 2022 +0800 [SPARK-39040][SQL] Respect NaNvl in EquivalentExpressions for expression elimination ### What changes were proposed in this pull request? Respect NaNvl in EquivalentExpressions for expression elimination. ### Why are the changes needed? For example, the following query fails, even though the right child of `nanvl` never needs to be evaluated here: ```sql set spark.sql.ansi.enabled=true; set spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConstantFolding; SELECT nanvl(1, 1/0 + 1/0); ``` ```sql org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 4) (10.221.98.68 executor driver): org.apache.spark.SparkArithmeticException: divide by zero. To return NULL instead, use 'try_divide'. If necessary set spark.sql.ansi.enabled to false (except for ANSI interval type) to bypass this error. == SQL(line 1, position 17) == select nanvl(1 , 1/0 + 1/0) ^^^ at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:151) ``` We should respect the evaluation order of conditional expressions, which always evaluate the predicate branch first, so the query above should not fail. ### Does this PR introduce _any_ user-facing change? Yes, this is a bug fix. ### How was this patch tested? Added tests. Closes #36376 from ulysses-you/SPARK-39040. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/EquivalentExpressions.scala | 4 ++++ .../SubexpressionEliminationSuite.scala | 14 +++++++++++ .../inputs/ansi/conditional-functions.sql | 6 +++++ .../results/ansi/conditional-functions.sql.out | 27 ++++++++++++++++++++++ 4 files changed, 51 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 472b6e871e7..e826de75fb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -135,11 +135,14 @@ class EquivalentExpressions { // will always get accessed. // 4. Coalesce: it's also a conditional expression, we should only recurse into the first // children, because others may not get accessed. + // 5. NaNvl: it's a conditional expression, we can only guarantee the left child can be always + // accessed. And if we hit the left child, the right will not be accessed. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil case i: If => i.predicate :: Nil case c: CaseWhen => c.children.head :: Nil case c: Coalesce => c.children.head :: Nil + case n: NaNvl => n.left :: Nil case other => other.children } @@ -173,6 +176,7 @@ class EquivalentExpressions { // If there is only one child, the first child is already covered by // `childrenToRecurse` and we should exclude it here. 
case c: Coalesce if c.children.length > 1 => Seq(c.children) + case n: NaNvl => Seq(n.children) case _ => Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 4ad5c92f47c..2375d3ed35f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -447,6 +447,20 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // So if `p` is replaced by subexpression, the literal will be reused. assert(code.value.toString == "((Decimal) references[0] /* literal */)") } + + test("SPARK-39040: Respect NaNvl in EquivalentExpressions for expression elimination") { + val add = Add(Literal(1), Literal(0)) + val n1 = NaNvl(Literal(1.0d), Add(add, add)) + val e1 = new EquivalentExpressions + e1.addExprTree(n1) + assert(e1.getCommonSubexpressions.isEmpty) + + val n2 = NaNvl(add, add) + val e2 = new EquivalentExpressions + e2.addExprTree(n2) + assert(e2.getCommonSubexpressions.size == 1) + assert(e2.getCommonSubexpressions.head == add) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql new file mode 100644 index 00000000000..5c548b1e9c4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/conditional-functions.sql @@ -0,0 +1,6 @@ +-- Tests for conditional functions +CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS t(c1, c2); + +SELECT nanvl(c1, c1/c2 + c1/c2) FROM t; + +DROP TABLE IF EXISTS t; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out new file mode 100644 index 00000000000..d3af659fc48 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out @@ -0,0 +1,27 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query +CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS t(c1, c2) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT nanvl(c1, c1/c2 + c1/c2) FROM t +-- !query schema +struct<nanvl(c1, ((c1 / c2) + (c1 / c2))):double> +-- !query output +1.0 +2.0 + + +-- !query +DROP TABLE IF EXISTS t +-- !query schema +struct<> +-- !query output + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org