This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 27c03e5741a [SPARK-39175][SQL] Provide runtime error query context for Cast when WSCG is off
27c03e5741a is described below

commit 27c03e5741af25b7afacac727865e23f60ce61fa
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Fri May 13 17:46:33 2022 +0800

    [SPARK-39175][SQL] Provide runtime error query context for Cast when WSCG is off
    
    ### What changes were proposed in this pull request?
    
    Similar to https://github.com/apache/spark/pull/36525, this PR provides the
    runtime error query context for the Cast expression when whole-stage codegen
    (WSCG) is off.
    
    ### Why are the changes needed?
    
    Enhance the runtime error query context of the Cast expression. After this
    change, the query context is also included in the error message when
    whole-stage codegen is not available.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    A new unit test in SQLQuerySuite.
    
    Closes #36535 from gengliangwang/fixCastContext.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
    (cherry picked from commit cdd33e83c3919c4475e2e1ef387acb604bea81b9)
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 67 ++++++++++++----------
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 27 ++++++++-
 2 files changed, 64 insertions(+), 30 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 335a34514c2..17d571a70f2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -277,7 +277,10 @@ object Cast {
   }
 }
 
-abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression 
with NullIntolerant {
+abstract class CastBase extends UnaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with SupportQueryContext {
 
   def child: Expression
 
@@ -307,6 +310,12 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
 
   protected def ansiEnabled: Boolean
 
+  override def initQueryContext(): String = if (ansiEnabled) {
+    origin.context
+  } else {
+    ""
+  }
+
   // When this cast involves TimeZone, it's only resolved if the timeZoneId is 
set;
   // Otherwise behave like Expression.resolved.
   override lazy val resolved: Boolean =
@@ -467,7 +476,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
           false
         } else {
           if (ansiEnabled) {
-            throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, 
origin.context)
+            throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, 
queryContext)
           } else {
             null
           }
@@ -499,7 +508,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     case StringType =>
       buildCast[UTF8String](_, utfs => {
         if (ansiEnabled) {
-          DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, origin.context)
+          DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, queryContext)
         } else {
           DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull
         }
@@ -524,14 +533,14 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     // TimestampWritable.doubleToTimestamp
     case DoubleType =>
       if (ansiEnabled) {
-        buildCast[Double](_, d => doubleToTimestampAnsi(d, origin.context))
+        buildCast[Double](_, d => doubleToTimestampAnsi(d, queryContext))
       } else {
         buildCast[Double](_, d => doubleToTimestamp(d))
       }
     // TimestampWritable.floatToTimestamp
     case FloatType =>
       if (ansiEnabled) {
-        buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, 
origin.context))
+        buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, 
queryContext))
       } else {
         buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
       }
@@ -541,7 +550,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     case StringType =>
       buildCast[UTF8String](_, utfs => {
         if (ansiEnabled) {
-          DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, 
origin.context)
+          DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, 
queryContext)
         } else {
           DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull
         }
@@ -574,7 +583,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   private[this] def castToDate(from: DataType): Any => Any = from match {
     case StringType =>
       if (ansiEnabled) {
-        buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, 
origin.context))
+        buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, 
queryContext))
       } else {
         buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull)
       }
@@ -631,7 +640,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   // LongConverter
   private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, 
origin.context))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, 
queryContext))
     case StringType =>
       val result = new LongWrapper()
       buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else 
null)
@@ -654,7 +663,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   // IntConverter
   private[this] def castToInt(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, 
origin.context))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, 
queryContext))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else 
null)
@@ -686,7 +695,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   // ShortConverter
   private[this] def castToShort(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, 
origin.context))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, 
queryContext))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toShort(result)) {
@@ -733,7 +742,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   // ByteConverter
   private[this] def castToByte(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, 
origin.context))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, 
queryContext))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toByte(result)) {
@@ -793,7 +802,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
         null
       } else {
         throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-          value, decimalType.precision, decimalType.scale, origin.context)
+          value, decimalType.precision, decimalType.scale, queryContext)
       }
     }
   }
@@ -816,7 +825,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
       })
     case StringType if ansiEnabled =>
       buildCast[UTF8String](_,
-        s => changePrecision(Decimal.fromStringANSI(s, target, 
origin.context), target))
+        s => changePrecision(Decimal.fromStringANSI(s, target, queryContext), 
target))
     case BooleanType =>
       buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else 
Decimal.ZERO, target))
     case DateType =>
@@ -846,7 +855,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
             val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
             if(ansiEnabled && d == null) {
               throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
-                DoubleType, s, origin.context)
+                DoubleType, s, queryContext)
             } else {
               d
             }
@@ -872,7 +881,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
             val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
             if (ansiEnabled && f == null) {
               throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
-                FloatType, s, origin.context)
+                FloatType, s, queryContext)
             } else {
               f
             }
@@ -988,7 +997,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   }
 
   def errorContextCode(codegenContext: CodegenContext): String = {
-    codegenContext.addReferenceObj("errCtx", origin.context)
+    codegenContext.addReferenceObj("errCtx", queryContext)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -1298,7 +1307,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
         val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]])
         (c, evPrim, evNull) =>
           if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
             code"""
               $evPrim = $dateTimeUtilsCls.stringToDateAnsi($c, $errorContext);
             """
@@ -1377,7 +1386,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
               }
           """
       case StringType if ansiEnabled =>
-        val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+        val errorContext = ctx.addReferenceObj("errCtx", queryContext)
         val toType = ctx.addReferenceObj("toType", target)
         (c, evPrim, evNull) =>
           code"""
@@ -1438,7 +1447,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
       val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
           code"""
             $evPrim = $dateTimeUtilsCls.stringToTimestampAnsi($c, $zid, 
$errorContext);
            """
@@ -1477,7 +1486,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     case DoubleType =>
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
           code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi($c, 
$errorContext);"
         } else {
           code"""
@@ -1491,7 +1500,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     case FloatType =>
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
           code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi((double)$c, 
$errorContext);"
         } else {
           code"""
@@ -1511,7 +1520,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
       val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
           code"""
             $evPrim = 
$dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c, $errorContext);
            """
@@ -1628,7 +1637,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
       val stringUtils = 
inline"${StringUtils.getClass.getName.stripSuffix("$")}"
       (c, evPrim, evNull) =>
         val castFailureCode = if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
           s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, 
$errorContext);"
         } else {
           s"$evNull = true;"
@@ -1763,7 +1772,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): 
CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = 
UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, 
$errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", 
classOf[UTF8String.IntWrapper])
@@ -1800,7 +1809,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
       ctx: CodegenContext): CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = 
UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, 
$errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", 
classOf[UTF8String.IntWrapper])
@@ -1835,7 +1844,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   private[this] def castToIntCode(from: DataType, ctx: CodegenContext): 
CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = 
UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, 
$errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", 
classOf[UTF8String.IntWrapper])
@@ -1870,7 +1879,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   private[this] def castToLongCode(from: DataType, ctx: CodegenContext): 
CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = 
UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, 
$errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("longWrapper", 
classOf[UTF8String.LongWrapper])
@@ -1907,7 +1916,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
         val floatStr = ctx.freshVariable("floatStr", StringType)
         (c, evPrim, evNull) =>
           val handleNull = if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
             s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
               s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, 
$errorContext);"
           } else {
@@ -1945,7 +1954,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
         val doubleStr = ctx.freshVariable("doubleStr", StringType)
         (c, evPrim, evNull) =>
           val handleNull = if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
             s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
               s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, 
$errorContext);"
           } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f099d3c015c..68db57ea364 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4359,7 +4359,8 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
     }
   }
 
-  test("SPARK-39166: Query context should be serialized to executors when WSCG 
is off") {
+  test("SPARK-39166: Query context of binary arithmetic should be serialized 
to executors" +
+    " when WSCG is off") {
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
       SQLConf.ANSI_ENABLED.key -> "true") {
       withTable("t") {
@@ -4379,6 +4380,30 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
     }
   }
 
+  test("SPARK-39175: Query context of Cast should be serialized to executors" +
+    " when WSCG is off") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+      SQLConf.ANSI_ENABLED.key -> "true") {
+      withTable("t") {
+        sql("create table t(s string) using parquet")
+        sql("insert into t values('a')")
+        Seq(
+          "select cast(s as int) from t",
+          "select cast(s as long) from t",
+          "select cast(s as double) from t",
+          "select cast(s as decimal(10, 2)) from t",
+          "select cast(s as date) from t",
+          "select cast(s as timestamp) from t",
+          "select cast(s as boolean) from t").foreach { query =>
+          val msg = intercept[SparkException] {
+            sql(query).collect()
+          }.getMessage
+          assert(msg.contains(query))
+        }
+      }
+    }
+  }
+
   test("SPARK-38589: try_avg should return null if overflow happens before 
merging") {
     val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
       .map(Period.ofMonths)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to