This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ac105ccebf5  [SPARK-42286][SQL] Fallback to previous codegen code path for complex expr with CAST
ac105ccebf5 is described below

commit ac105ccebf5f144f2d506cbe102c362a195afa9a
Author: Runyao Chen <runyao.c...@databricks.com>
AuthorDate: Thu Feb 2 21:06:20 2023 -0800

[SPARK-42286][SQL] Fallback to previous codegen code path for complex expr with CAST

### What changes were proposed in this pull request?

This PR fixes the internal error `Child is not Cast or ExpressionProxy of Cast` thrown for a valid `CaseWhen` expr with a `Cast` expr in its branches. Specifically, SPARK-39865 introduced an improved error message for the overflow exception thrown during table insert. The improvement covers the `Cast` expr and `ExpressionProxy` expr, but not `CaseWhen` and other complex exprs. The example below hits the internal error today.

```
create table t1 as select x FROM values (1), (2), (3) as tab(x);
create table t2 (x Decimal(9, 0));
insert into t2 select 0 - (case when x = 1 then 1 else x end) from t1 where x = 1;
```

To fix the query failure, we fall back to the previous handling when the expr is not a simple `Cast` expr or an `ExpressionProxy` of a `Cast`.

### Why are the changes needed?

To fix the query regression introduced in SPARK-39865.

### Does this PR introduce _any_ user-facing change?

No. We just fall back to the previous error message if the expression involving `Cast` is not a simple one.

### How was this patch tested?

- Added a unit test.
- Removed one test case for the `Child is not Cast or ExpressionProxy of Cast` internal error, since we no longer check whether the child contains a `Cast` expression and instead fall back to the previous error message.

Closes #39855 from RunyaoChen/fallback_cast_codegen.

Authored-by: Runyao Chen <runyao.c...@databricks.com>
Signed-off-by: Gengliang Wang <gengli...@apache.org>
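In short, the patch replaces the throwing child check with a non-throwing lookup: a helper returns `Some(cast)` only when the child is a plain `Cast` or an `ExpressionProxy` directly wrapping a `Cast`, and returns `None` otherwise, in which case `eval` rethrows the original exception and `doGenCode` falls back to the child's own codegen. A minimal standalone sketch of that decision logic follows; the case classes here are simplified stand-ins for illustration only, not Spark's actual expression types. The full change, including the codegen fallback, is in the diff below.

```
// Simplified stand-ins for illustration only -- NOT Spark's actual
// Cast / ExpressionProxy / CaseWhen expression classes.
sealed trait Expr
case class Literal(value: Any) extends Expr
case class Cast(child: Expr, targetType: String) extends Expr
case class ExpressionProxy(child: Expr) extends Expr
case class CaseWhen(branches: Seq[(Expr, Expr)]) extends Expr

object FallbackSketch {
  // Mirrors the shape of the new getCast helper: only a bare Cast, or a proxy
  // directly wrapping a Cast, gets the improved table-insert overflow message;
  // anything more complex (e.g. a CaseWhen around a Cast) returns None,
  // meaning "keep the previous error/codegen path".
  def getCast(child: Expr): Option[Cast] = child match {
    case c: Cast                  => Some(c)
    case ExpressionProxy(c: Cast) => Some(c)
    case _                        => None
  }

  def main(args: Array[String]): Unit = {
    val cast = Cast(Literal(1.0), "int")
    println(getCast(cast))                                  // Some(Cast(...)) -> improved message
    println(getCast(ExpressionProxy(cast)))                 // Some(Cast(...)) -> improved message
    println(getCast(CaseWhen(Seq((Literal(true), cast)))))  // None -> fall back to previous path
  }
}
```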
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 41 +++++++++++-----------
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 28 +--------------
 .../sql/errors/QueryExecutionAnsiErrorsSuite.scala | 37 ++++++++++++++++++-
 .../org/apache/spark/sql/sources/InsertSuite.scala | 10 ++++++
 4 files changed, 68 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6900aa873bb..ae0dc3dbf03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
 
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.SparkArithmeticException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -2511,40 +2511,41 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
  */
 case class CheckOverflowInTableInsert(child: Expression, columnName: String) extends UnaryExpression {
-  checkChild(child)
-
-  private def checkChild(child: Expression): Unit = child match {
-    case _: Cast =>
-    case ExpressionProxy(c, _, _) if c.isInstanceOf[Cast] =>
-    case _ =>
-      throw SparkException.internalError("Child is not Cast or ExpressionProxy of Cast")
-  }
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
-    checkChild(newChild)
     copy(child = newChild)
   }
 
-  private def getCast: Cast = child match {
+  private def getCast: Option[Cast] = child match {
     case c: Cast =>
-      c
-    case ExpressionProxy(c, _, _) =>
-      c.asInstanceOf[Cast]
+      Some(c)
+    case ExpressionProxy(c: Cast, _, _) =>
+      Some(c)
+    case _ => None
   }
 
   override def eval(input: InternalRow): Any = try {
     child.eval(input)
   } catch {
     case e: SparkArithmeticException =>
-      val cast = getCast
-      throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
-        cast.child.dataType,
-        cast.dataType,
-        columnName)
+      getCast match {
+        case Some(cast) =>
+          throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
+            cast.child.dataType,
+            cast.dataType,
+            columnName)
+        case None => throw e
+      }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val child = getCast
+    getCast match {
+      case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
+      case None => child.genCode(ctx)
+    }
+  }
+
+  def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
     val childGen = child.genCode(ctx)
     val exceptionClass = classOf[SparkArithmeticException].getCanonicalName
     val fromDt =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ca9f43adc1f..bad85ca4176 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -24,7 +24,7 @@ import java.util.{Calendar, Locale, TimeZone}
 
 import scala.collection.parallel.immutable.ParVector
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1391,30 +1391,4 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     assert(expr.sql == cast.sql)
     assert(expr.toString == cast.toString)
   }
-
-  test("SPARK-41991: CheckOverflowInTableInsert child must be Cast or ExpressionProxy of Cast") {
-    val runtime = new SubExprEvaluationRuntime(1)
-    val cast = Cast(Literal(1.0), IntegerType)
-    val expr = CheckOverflowInTableInsert(cast, "column_1")
-    val proxy = ExpressionProxy(Literal(1.0), 0, runtime)
-    checkError(
-      exception = intercept[SparkException] {
-        expr.withNewChildrenInternal(IndexedSeq(proxy))
-      },
-      errorClass = "INTERNAL_ERROR",
-      parameters = Map(
-        "message" -> "Child is not Cast or ExpressionProxy of Cast"
-      )
-    )
-
-    checkError(
-      exception = intercept[SparkException] {
-        expr.withNewChildrenInternal(IndexedSeq(Literal(1)))
-      },
-      errorClass = "INTERNAL_ERROR",
-      parameters = Map(
-        "message" -> "Child is not Cast or ExpressionProxy of Cast"
-      )
-    )
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index c85fbf84baa..45c7898dfa2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors
 
 import org.apache.spark._
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.ByteType
@@ -179,6 +179,41 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
     }
   }
 
+  test("SPARK-42286: CheckOverflowInTableInsert with CaseWhen should throw an exception") {
+    val caseWhen = CaseWhen(
+      Seq((Literal(true), Cast(Literal.apply(12345678901234567890D), ByteType))), None)
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        CheckOverflowInTableInsert(caseWhen, "col").eval(null)
+      }.asInstanceOf[SparkThrowable],
+      errorClass = "CAST_OVERFLOW",
+      parameters = Map("value" -> "1.2345678901234567E19D",
+        "sourceType" -> "\"DOUBLE\"",
+        "targetType" -> ("\"TINYINT\""),
+        "ansiConfig" -> ansiConf)
+    )
+  }
+
+  test("SPARK-42286: End-to-end query with Case When throwing CAST_OVERFLOW exception") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (x double) USING parquet")
+      sql("insert into t1 values (1.2345678901234567E19D)")
+      sql("CREATE TABLE t2 (x tinyint) USING parquet")
+      val insertCmd = "insert into t2 select 0 - (case when x = 1.2345678901234567E19D " +
+        "then 1.2345678901234567E19D else x end) from t1 where x = 1.2345678901234567E19D;"
+      checkError(
+        exception = intercept[SparkException] {
+          sql(insertCmd).collect()
+        }.getCause.getCause.asInstanceOf[SparkThrowable],
+        errorClass = "CAST_OVERFLOW",
+        parameters = Map("value" -> "-1.2345678901234567E19D",
+          "sourceType" -> "\"DOUBLE\"",
+          "targetType" -> "\"TINYINT\"",
+          "ansiConfig" -> ansiConf),
+        sqlState = "22003")
+    }
+  }
+
   test("SPARK-39981: interpreted CheckOverflowInTableInsert should throw an exception") {
     checkError(
       exception = intercept[SparkArithmeticException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index c34925ef1bf..cc1d4ab3fcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -2315,6 +2315,16 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-42286: Insert into a table select from case when with cast, positive test") {
+    withTable("t1", "t2") {
+      sql("create table t1 (x int) using parquet")
+      sql("insert into t1 values (1), (2)")
+      sql("create table t2 (x Decimal(9, 0)) using parquet")
+      sql("insert into t2 select 0 - (case when x = 1 then 1 else x end) from t1 where x = 1")
+      checkAnswer(spark.table("t2"), Row(-1))
+    }
+  }
 }
 
 class FileExistingTestFileSystem extends RawLocalFileSystem {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org