This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 5bc06fd  [SPARK-36130][SQL] UnwrapCastInBinaryComparison should skip 
In expression when in.list contains an expression that is not literal
5bc06fd is described below

commit 5bc06fd7d9e756a7d40011e3cda57494859f692a
Author: Fu Chen <cfmcgr...@gmail.com>
AuthorDate: Wed Jul 14 15:57:10 2021 +0800

    [SPARK-36130][SQL] UnwrapCastInBinaryComparison should skip In expression 
when in.list contains an expression that is not literal
    
    ### What changes were proposed in this pull request?
    
    Fix 
[comment](https://github.com/apache/spark/pull/32488#issuecomment-879315179)
    This PR fixes a bug in the rule `UnwrapCastInBinaryComparison`: the rule 
should skip an `In` expression when `in.list` contains an expression that is 
not a literal.
    
    - In
    
    Before this PR, the following example threw an exception.
    ```scala
      withTable("tbl") {
        sql("CREATE TABLE tbl (d decimal(33, 27)) USING PARQUET")
        sql("SELECT d FROM tbl WHERE d NOT IN (d + 1)")
      }
    ```
    - InSet
    
    Since the analyzer guarantees that all the elements in `inSet.hset` are 
literals, this is not an issue for `InSet`.
    
    
https://github.com/apache/spark/blob/fbf53dee37129a493a4e5d5a007625b35f44fbda/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala#L264-L279
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, only bug fix.
    
    ### How was this patch tested?
    
    New test.
    
    Closes #33335 from cfmcgrady/SPARK-36130.
    
    Authored-by: Fu Chen <cfmcgr...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 103d16e868e3caaa08401e0398c20b4a4574c6b7)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala      |  3 ++-
 .../optimizer/UnwrapCastInBinaryComparisonSuite.scala | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index d5ff0fc..08c4cbf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -141,8 +141,9 @@ object UnwrapCastInBinaryComparison extends 
Rule[LogicalPlan] {
     // values.
     // 2. this rule only handles the case when both `fromExp` and value in 
`in.list` are of numeric
     // type.
+    // 3. this rule doesn't optimize In when `in.list` contains an expression 
that is not literal.
     case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ 
Seq(firstLit, _*))
-      if canImplicitlyCast(fromExp, toType, firstLit.dataType) =>
+      if canImplicitlyCast(fromExp, toType, firstLit.dataType) && 
in.inSetConvertible =>
 
       // There are 3 kinds of literals in the list:
       // 1. null literals
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index e5df1ab..31f62cf 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -283,6 +283,25 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest 
with ExpressionEvalHelp
     )
   }
 
+  test("SPARK-36130: unwrap In should skip when in.list contains an expression 
that " +
+    "is not literal") {
+    val add = Cast(f2, DoubleType) + 1.0d
+    val doubleLit = Literal.create(null, DoubleType)
+    assertEquivalent(In(Cast(f2, DoubleType), Seq(add)), In(Cast(f2, 
DoubleType), Seq(add)))
+    assertEquivalent(
+      In(Cast(f2, DoubleType), Seq(doubleLit, add)),
+      In(Cast(f2, DoubleType), Seq(doubleLit, add)))
+    assertEquivalent(
+      In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add)),
+      In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add)))
+    assertEquivalent(
+      In(Cast(f2, DoubleType), Seq(1.0d, add)),
+      In(Cast(f2, DoubleType), Seq(1.0d, add)))
+    assertEquivalent(
+      In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add)),
+      In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add)))
+  }
+
   private def castInt(e: Expression): Expression = Cast(e, IntegerType)
   private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
   private def castDecimal2(e: Expression): Expression = Cast(e, 
DecimalType(10, 4))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to