This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ac105ccebf5 [SPARK-42286][SQL] Fallback to previous codegen code path for complex expr with CAST
ac105ccebf5 is described below

commit ac105ccebf5f144f2d506cbe102c362a195afa9a
Author: Runyao Chen <runyao.c...@databricks.com>
AuthorDate: Thu Feb 2 21:06:20 2023 -0800

    [SPARK-42286][SQL] Fallback to previous codegen code path for complex expr with CAST
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the internal error `Child is not Cast or ExpressionProxy of Cast` for a valid `CaseWhen` expr with `Cast` exprs in its branches.
    
    Specifically, SPARK-39865 introduced an improved error msg for overflow exceptions during table inserts. The improvement covers `Cast` exprs and `ExpressionProxy` exprs of `Cast`, but not `CaseWhen` and other complex exprs. The example below hits an internal error today.
    ```
    create table t1 as select x FROM values (1), (2), (3) as tab(x);
    create table t2 (x Decimal(9, 0));
    insert into t2 select 0 - (case when x = 1 then 1 else x end) from t1 where x = 1;
    ```
    To fix the query failure, we fall back to the previous handling when the expr is not a simple `Cast` expr or an `ExpressionProxy` of `Cast`.
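    
    In sketch form, the fix turns the `Cast` lookup into an `Option` so that non-`Cast` children fall back cleanly. Below is a minimal standalone Scala model of the pattern, using simplified stand-in types for illustration (hypothetical, not the actual catalyst classes; the real change is in the diff below):
    ```
    // Simplified stand-ins for the catalyst expression types (hypothetical).
    sealed trait Expr
    case class Cast(child: Expr) extends Expr
    case class ExpressionProxy(child: Expr) extends Expr
    case class CaseWhen(branches: Seq[Expr]) extends Expr
    
    // Return None for a non-Cast child instead of throwing an internal error,
    // so callers can fall back to the previous code path.
    def getCast(child: Expr): Option[Cast] = child match {
      case c: Cast                  => Some(c) // simple Cast: improved error msg
      case ExpressionProxy(c: Cast) => Some(c) // proxied Cast: still covered
      case _                        => None    // CaseWhen etc.: previous handling
    }
    ```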
    
    ### Why are the changes needed?
    
    To fix the query regression introduced in SPARK-39865.
    
    ### Does this PR introduce _any_ user-facing change?
    No. We just fall back to the previous error msg if the expression involving `Cast` is not a simple one.
    
    ### How was this patch tested?
    - Added a unit test.
    - Removed one test case for the `Child is not Cast or ExpressionProxy of Cast` internal error, since we no longer check whether the child has a `Cast` expression and instead fall back to the previous error message.
    
    Closes #39855 from RunyaoChen/fallback_cast_codegen.
    
    Authored-by: Runyao Chen <runyao.c...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 41 +++++++++++-----------
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 28 +--------------
 .../sql/errors/QueryExecutionAnsiErrorsSuite.scala | 37 ++++++++++++++++++-
 .../org/apache/spark/sql/sources/InsertSuite.scala | 10 ++++++
 4 files changed, 68 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6900aa873bb..ae0dc3dbf03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
 
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.SparkArithmeticException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -2511,40 +2511,41 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
  */
 case class CheckOverflowInTableInsert(child: Expression, columnName: String)
     extends UnaryExpression {
-  checkChild(child)
-
-  private def checkChild(child: Expression): Unit = child match {
-    case _: Cast =>
-    case ExpressionProxy(c, _, _) if c.isInstanceOf[Cast] =>
-    case _ =>
-      throw SparkException.internalError("Child is not Cast or ExpressionProxy of Cast")
-  }
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
-    checkChild(newChild)
     copy(child = newChild)
   }
 
-  private def getCast: Cast = child match {
+  private def getCast: Option[Cast] = child match {
     case c: Cast =>
-      c
-    case ExpressionProxy(c, _, _) =>
-      c.asInstanceOf[Cast]
+      Some(c)
+    case ExpressionProxy(c: Cast, _, _) =>
+      Some(c)
+    case _ => None
   }
 
   override def eval(input: InternalRow): Any = try {
     child.eval(input)
   } catch {
     case e: SparkArithmeticException =>
-      val cast = getCast
-      throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
-        cast.child.dataType,
-        cast.dataType,
-        columnName)
+      getCast match {
+        case Some(cast) =>
+          throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
+            cast.child.dataType,
+            cast.dataType,
+            columnName)
+        case None => throw e
+      }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val child = getCast
+    getCast match {
+      case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
+      case None => child.genCode(ctx)
+    }
+  }
+
+  def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
     val childGen = child.genCode(ctx)
     val exceptionClass = classOf[SparkArithmeticException].getCanonicalName
     val fromDt =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ca9f43adc1f..bad85ca4176 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -24,7 +24,7 @@ import java.util.{Calendar, Locale, TimeZone}
 
 import scala.collection.parallel.immutable.ParVector
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1391,30 +1391,4 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     assert(expr.sql == cast.sql)
     assert(expr.toString == cast.toString)
   }
-
-  test("SPARK-41991: CheckOverflowInTableInsert child must be Cast or 
ExpressionProxy of Cast") {
-    val runtime = new SubExprEvaluationRuntime(1)
-    val cast = Cast(Literal(1.0), IntegerType)
-    val expr = CheckOverflowInTableInsert(cast, "column_1")
-    val proxy = ExpressionProxy(Literal(1.0), 0, runtime)
-    checkError(
-      exception = intercept[SparkException] {
-        expr.withNewChildrenInternal(IndexedSeq(proxy))
-      },
-      errorClass = "INTERNAL_ERROR",
-      parameters = Map(
-        "message" -> "Child is not Cast or ExpressionProxy of Cast"
-      )
-    )
-
-    checkError(
-      exception = intercept[SparkException] {
-        expr.withNewChildrenInternal(IndexedSeq(Literal(1)))
-      },
-      errorClass = "INTERNAL_ERROR",
-      parameters = Map(
-        "message" -> "Child is not Cast or ExpressionProxy of Cast"
-      )
-    )
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index c85fbf84baa..45c7898dfa2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors
 
 import org.apache.spark._
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.ByteType
@@ -179,6 +179,41 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
     }
   }
 
+  test("SPARK-42286: CheckOverflowInTableInsert with CaseWhen should throw an 
exception") {
+    val caseWhen = CaseWhen(
+      Seq((Literal(true), Cast(Literal.apply(12345678901234567890D), ByteType))), None)
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        CheckOverflowInTableInsert(caseWhen, "col").eval(null)
+      }.asInstanceOf[SparkThrowable],
+      errorClass = "CAST_OVERFLOW",
+      parameters = Map("value" -> "1.2345678901234567E19D",
+        "sourceType" -> "\"DOUBLE\"",
+        "targetType" -> ("\"TINYINT\""),
+        "ansiConfig" -> ansiConf)
+    )
+  }
+
+  test("SPARK-42286: End-to-end query with Case When throwing CAST_OVERFLOW 
exception") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (x double) USING parquet")
+      sql("insert into t1 values (1.2345678901234567E19D)")
+      sql("CREATE TABLE t2 (x tinyint) USING parquet")
+      val insertCmd = "insert into t2 select 0 - (case when x = 1.2345678901234567E19D " +
+        "then 1.2345678901234567E19D else x end) from t1 where x = 1.2345678901234567E19D;"
+      checkError(
+        exception = intercept[SparkException] {
+          sql(insertCmd).collect()
+        }.getCause.getCause.asInstanceOf[SparkThrowable],
+        errorClass = "CAST_OVERFLOW",
+        parameters = Map("value" -> "-1.2345678901234567E19D",
+          "sourceType" -> "\"DOUBLE\"",
+          "targetType" -> "\"TINYINT\"",
+          "ansiConfig" -> ansiConf),
+        sqlState = "22003")
+    }
+  }
+
   test("SPARK-39981: interpreted CheckOverflowInTableInsert should throw an 
exception") {
     checkError(
       exception = intercept[SparkArithmeticException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index c34925ef1bf..cc1d4ab3fcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -2315,6 +2315,16 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-42286: Insert into a table select from case when with cast, 
positive test") {
+    withTable("t1", "t2") {
+      sql("create table t1 (x int) using parquet")
+      sql("insert into t1 values (1), (2)")
+      sql("create table t2 (x Decimal(9, 0)) using parquet")
+      sql("insert into t2 select 0 - (case when x = 1 then 1 else x end) from 
t1 where x = 1")
+      checkAnswer(spark.table("t2"), Row(-1))
+    }
+  }
 }
 
 class FileExistingTestFileSystem extends RawLocalFileSystem {

