This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push: new 5bc06fd [SPARK-36130][SQL] UnwrapCastInBinaryComparison should skip In expression when in.list contains an expression that is not literal 5bc06fd is described below commit 5bc06fd7d9e756a7d40011e3cda57494859f692a Author: Fu Chen <cfmcgr...@gmail.com> AuthorDate: Wed Jul 14 15:57:10 2021 +0800 [SPARK-36130][SQL] UnwrapCastInBinaryComparison should skip In expression when in.list contains an expression that is not literal ### What changes were proposed in this pull request? Fix [comment](https://github.com/apache/spark/pull/32488#issuecomment-879315179) This PR fixes a bug in rule `UnwrapCastInBinaryComparison`: the rule should skip an `In` expression when `in.list` contains an expression that is not a literal. - In: Before this PR, the following example throws an exception. ```scala withTable("tbl") { sql("CREATE TABLE tbl (d decimal(33, 27)) USING PARQUET") sql("SELECT d FROM tbl WHERE d NOT IN (d + 1)") } ``` - InSet: All the elements in `inSet.hset` are guaranteed to be literals, so this is not an issue for `InSet`. https://github.com/apache/spark/blob/fbf53dee37129a493a4e5d5a007625b35f44fbda/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala#L264-L279 ### Does this PR introduce _any_ user-facing change? No, this is only a bug fix. ### How was this patch tested? New test. Closes #33335 from cfmcgrady/SPARK-36130. 
Authored-by: Fu Chen <cfmcgr...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 103d16e868e3caaa08401e0398c20b4a4574c6b7) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../optimizer/UnwrapCastInBinaryComparison.scala | 3 ++- .../optimizer/UnwrapCastInBinaryComparisonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index d5ff0fc..08c4cbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -141,8 +141,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // values. // 2. this rule only handles the case when both `fromExp` and value in `in.list` are of numeric // type. + // 3. this rule doesn't optimize In when `in.list` contains an expression that is not literal. case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ Seq(firstLit, _*)) - if canImplicitlyCast(fromExp, toType, firstLit.dataType) => + if canImplicitlyCast(fromExp, toType, firstLit.dataType) && in.inSetConvertible => // There are 3 kinds of literals in the list: // 1. 
null literals diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index e5df1ab..31f62cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -283,6 +283,25 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp ) } + test("SPARK-36130: unwrap In should skip when in.list contains an expression that " + + "is not literal") { + val add = Cast(f2, DoubleType) + 1.0d + val doubleLit = Literal.create(null, DoubleType) + assertEquivalent(In(Cast(f2, DoubleType), Seq(add)), In(Cast(f2, DoubleType), Seq(add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(doubleLit, add)), + In(Cast(f2, DoubleType), Seq(doubleLit, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add)), + In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(1.0d, add)), + In(Cast(f2, DoubleType), Seq(1.0d, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add)), + In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add))) + } + private def castInt(e: Expression): Expression = Cast(e, IntegerType) private def castDouble(e: Expression): Expression = Cast(e, DoubleType) private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org