This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9ee2c753b98 [SPARK-39923][SQL] Multiple query contexts in Spark exceptions 9ee2c753b98 is described below commit 9ee2c753b98b290fab9b2ec1f02d90c7c9441271 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Mon Aug 1 13:40:22 2022 +0500 [SPARK-39923][SQL] Multiple query contexts in Spark exceptions ### What changes were proposed in this pull request? 1. Replace `Option[QueryContext]` with `Array[QueryContext]` in Spark exceptions like in `SparkRuntimeException`. 2. Pass `SQLQueryContext` to `QueryExecutionErrors` functions instead of `Option[SQLQueryContext]`. 3. Add the methods `getContextOrNull()` and `getContextOrNullCode()` to `SupportQueryContext` to get a SQL query context or `null` (if it is missing) of an expression. ### Why are the changes needed? 1. The changes will allow chaining multiple error contexts in Spark's exception. For instance, if a user's query refers to a view v1, v1 refers to another view v2, and v2 does a division. The error contexts will be: sql fragment of v2 that does division -> sql fragment of v1 that refers to v2 -> sql fragment of your query that refers to v1. 2. Passing `SQLQueryContext` to `QueryExecutionErrors` directly simplifies codegen code because it allows avoiding the construction of Scala objects like `scala.None`. ### Does this PR introduce _any_ user-facing change? Yes, this PR changes user-facing exceptions. ### How was this patch tested? By running the modified test suites: ``` $ build/sbt "test:testOnly *DecimalExpressionSuite" ``` and potentially affected tests: ``` $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite" ``` Closes #37343 from MaxGekk/array-as-query-context. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../scala/org/apache/spark/SparkException.scala | 28 ++++----- .../spark/sql/catalyst/expressions/Cast.scala | 67 +++++++++++----------- .../sql/catalyst/expressions/Expression.scala | 10 ++++ .../catalyst/expressions/aggregate/Average.scala | 6 +- .../sql/catalyst/expressions/aggregate/Sum.scala | 23 ++++---- .../sql/catalyst/expressions/arithmetic.scala | 48 +++++++++------- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeExtractors.scala | 8 +-- .../catalyst/expressions/decimalExpressions.scala | 32 ++++------- .../catalyst/expressions/intervalExpressions.scala | 16 +++--- .../sql/catalyst/expressions/mathExpressions.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 5 +- .../spark/sql/catalyst/util/DateTimeUtils.scala | 10 ++-- .../spark/sql/catalyst/util/IntervalUtils.scala | 2 +- .../apache/spark/sql/catalyst/util/MathUtils.scala | 14 ++--- .../spark/sql/catalyst/util/UTF8StringUtils.scala | 10 ++-- .../apache/spark/sql/errors/QueryErrorsBase.scala | 9 ++- .../spark/sql/errors/QueryExecutionErrors.scala | 54 ++++++++--------- .../scala/org/apache/spark/sql/types/Decimal.scala | 4 +- .../expressions/DecimalExpressionSuite.scala | 2 +- 20 files changed, 182 insertions(+), 172 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index d6add48ffb1..6548a114d41 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -119,7 +119,7 @@ private[spark] class SparkArithmeticException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - context: Option[QueryContext], + context: Array[QueryContext], summary: String) extends ArithmeticException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, 
summary)) @@ -128,7 +128,7 @@ private[spark] class SparkArithmeticException( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** @@ -195,7 +195,7 @@ private[spark] class SparkDateTimeException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - context: Option[QueryContext], + context: Array[QueryContext], summary: String) extends DateTimeException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) @@ -204,7 +204,7 @@ private[spark] class SparkDateTimeException( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** @@ -244,7 +244,7 @@ private[spark] class SparkNumberFormatException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - context: Option[QueryContext], + context: Array[QueryContext], summary: String) extends NumberFormatException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) @@ -253,7 +253,7 @@ private[spark] class SparkNumberFormatException( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** @@ -323,7 +323,7 @@ private[spark] class SparkRuntimeException( errorSubClass: Option[String] = None, 
messageParameters: Array[String], cause: Throwable = null, - context: Option[QueryContext] = None, + context: Array[QueryContext] = Array.empty, summary: String = "") extends RuntimeException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary), @@ -334,7 +334,7 @@ private[spark] class SparkRuntimeException( errorSubClass: String, messageParameters: Array[String], cause: Throwable, - context: Option[QueryContext]) + context: Array[QueryContext]) = this(errorClass = errorClass, errorSubClass = Some(errorSubClass), messageParameters = messageParameters, @@ -348,12 +348,12 @@ private[spark] class SparkRuntimeException( errorSubClass = Some(errorSubClass), messageParameters = messageParameters, cause = null, - context = None) + context = Array.empty[QueryContext]) override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** @@ -379,7 +379,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - context: Option[QueryContext], + context: Array[QueryContext], summary: String) extends ArrayIndexOutOfBoundsException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) @@ -388,7 +388,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** @@ -420,7 +420,7 @@ private[spark] class SparkNoSuchElementException( errorClass: String, 
errorSubClass: Option[String] = None, messageParameters: Array[String], - context: Option[QueryContext], + context: Array[QueryContext], summary: String) extends NoSuchElementException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) @@ -429,7 +429,7 @@ private[spark] class SparkNoSuchElementException( override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull - override def getQueryContext: Array[QueryContext] = context.toArray + override def getQueryContext: Array[QueryContext] = context } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0ba651b5650..f740ecd9dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -653,7 +653,7 @@ case class Cast( false } else { if (ansiEnabled) { - throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, queryContext) + throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull()) } else { null } @@ -685,7 +685,7 @@ case class Cast( case StringType => buildCast[UTF8String](_, utfs => { if (ansiEnabled) { - DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, queryContext) + DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, getContextOrNull()) } else { DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull } @@ -710,14 +710,14 @@ case class Cast( // TimestampWritable.doubleToTimestamp case DoubleType => if (ansiEnabled) { - buildCast[Double](_, d => doubleToTimestampAnsi(d, queryContext)) + buildCast[Double](_, d => doubleToTimestampAnsi(d, getContextOrNull())) } else { buildCast[Double](_, d => doubleToTimestamp(d)) } // TimestampWritable.floatToTimestamp case FloatType => if 
(ansiEnabled) { - buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, queryContext)) + buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, getContextOrNull())) } else { buildCast[Float](_, f => doubleToTimestamp(f.toDouble)) } @@ -727,7 +727,7 @@ case class Cast( case StringType => buildCast[UTF8String](_, utfs => { if (ansiEnabled) { - DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, queryContext) + DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, getContextOrNull()) } else { DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull } @@ -760,7 +760,7 @@ case class Cast( private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => if (ansiEnabled) { - buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, queryContext)) + buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, getContextOrNull())) } else { buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull) } @@ -817,7 +817,7 @@ case class Cast( // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType if ansiEnabled => - buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, queryContext)) + buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, getContextOrNull())) case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) @@ -840,7 +840,7 @@ case class Cast( // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType if ansiEnabled => - buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, queryContext)) + buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, getContextOrNull())) case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) @@ -872,7 +872,7 @@ case class Cast( // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case 
StringType if ansiEnabled => - buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, queryContext)) + buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, getContextOrNull())) case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toShort(result)) { @@ -919,7 +919,7 @@ case class Cast( // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType if ansiEnabled => - buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, queryContext)) + buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, getContextOrNull())) case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toByte(result)) { @@ -986,7 +986,7 @@ case class Cast( null } else { throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - value, decimalType.precision, decimalType.scale, queryContext) + value, decimalType.precision, decimalType.scale, getContextOrNull()) } } } @@ -999,7 +999,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: Option[SQLQueryContext]): Decimal = + context: SQLQueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) @@ -1012,17 +1012,17 @@ case class Cast( }) case StringType if ansiEnabled => buildCast[UTF8String](_, - s => changePrecision(Decimal.fromStringANSI(s, target, queryContext), target)) + s => changePrecision(Decimal.fromStringANSI(s, target, getContextOrNull()), target)) case BooleanType => buildCast[Boolean](_, - b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, queryContext)) + b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, getContextOrNull())) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => - b => toPrecision(b.asInstanceOf[Decimal], target, queryContext) + b => toPrecision(b.asInstanceOf[Decimal], target, getContextOrNull()) case t: IntegralType => b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) case x: FractionalType => @@ -1055,7 +1055,7 @@ case class Cast( val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) if(ansiEnabled && d == null) { throw QueryExecutionErrors.invalidInputInCastToNumberError( - DoubleType, s, queryContext) + DoubleType, s, getContextOrNull()) } else { d } @@ -1081,7 +1081,7 @@ case class Cast( val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) if (ansiEnabled && f == null) { throw QueryExecutionErrors.invalidInputInCastToNumberError( - FloatType, s, queryContext) + FloatType, s, getContextOrNull()) } else { f } @@ -1196,10 +1196,6 @@ case class Cast( } } - def errorContextCode(codegenContext: CodegenContext): String = { - codegenContext.addReferenceObj("errCtx", queryContext) - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) @@ -1512,7 +1508,7 @@ case class Cast( val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) (c, evPrim, evNull) => if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) code""" $evPrim = $dateTimeUtilsCls.stringToDateAnsi($c, $errorContext); """ @@ -1556,12 +1552,13 @@ case class Cast( |$evPrim = $d; """.stripMargin } else { + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) val overflowCode = if (nullOnOverflow) { s"$evNull = true;" } else { s""" |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - | $d, ${decimalType.precision}, ${decimalType.scale}, ${errorContextCode(ctx)}); + | $d, 
${decimalType.precision}, ${decimalType.scale}, $errorContextCode); """.stripMargin } code""" @@ -1602,7 +1599,7 @@ case class Cast( } """ case StringType if ansiEnabled => - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) val toType = ctx.addReferenceObj("toType", target) (c, evPrim, evNull) => code""" @@ -1679,7 +1676,7 @@ case class Cast( val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) code""" $evPrim = $dateTimeUtilsCls.stringToTimestampAnsi($c, $zid, $errorContext); """ @@ -1718,7 +1715,7 @@ case class Cast( case DoubleType => (c, evPrim, evNull) => if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi($c, $errorContext);" } else { code""" @@ -1732,7 +1729,7 @@ case class Cast( case FloatType => (c, evPrim, evNull) => if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi((double)$c, $errorContext);" } else { code""" @@ -1752,7 +1749,7 @@ case class Cast( val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) code""" $evPrim = $dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c, $errorContext); """ @@ -1869,7 +1866,7 @@ case class Cast( val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => val castFailureCode = if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) s"throw 
QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);" } else { s"$evNull = true;" @@ -2004,7 +2001,7 @@ case class Cast( private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, $errorContext);" case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) @@ -2041,7 +2038,7 @@ case class Cast( ctx: CodegenContext): CastFunction = from match { case StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, $errorContext);" case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) @@ -2076,7 +2073,7 @@ case class Cast( private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, $errorContext);" case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) @@ -2111,7 +2108,7 @@ case class Cast( private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val 
errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, $errorContext);" case StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) @@ -2148,7 +2145,7 @@ case class Cast( val floatStr = ctx.freshVariable("floatStr", StringType) (c, evPrim, evNull) => val handleNull = if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" + s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);" } else { @@ -2186,7 +2183,7 @@ case class Cast( val doubleStr = ctx.freshVariable("doubleStr", StringType) (c, evPrim, evNull) => val handleNull = if (ansiEnabled) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" + s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);" } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d623357b9da..261d9a0cb63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -597,6 +597,16 @@ trait SupportQueryContext extends Expression with Serializable { def initQueryContext(): Option[SQLQueryContext] + def getContextOrNull(): SQLQueryContext = queryContext.getOrElse(null) + + def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { + if (withErrorContext && queryContext.isDefined) { + ctx.addReferenceObj("errCtx", queryContext.get) + } else { + "null" + } + } + // Note: Even though query contexts are serialized to executors, it will be regenerated from an // 
empty "Origin" during rule transforms since "Origin"s are not serialized to executors // for better performance. Thus, we need to copy the original query context during diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index b749dfdaea1..36ffcd8f764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -86,7 +86,7 @@ abstract class AverageBase // If all input are nulls, count will be 0 and we will get null after the division. // We can't directly use `/` as it throws an exception under ansi mode. - protected def getEvaluateExpression(context: Option[SQLQueryContext]) = child.dataType match { + protected def getEvaluateExpression(context: SQLQueryContext = null) = child.dataType match { case _: DecimalType => If(EqualTo(count, Literal(0L)), Literal(null, resultType), @@ -141,7 +141,7 @@ case class Average( override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions - override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) + override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull()) override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) { Some(origin.context) @@ -206,7 +206,7 @@ case class TryAverage(child: Expression) extends AverageBase { } override lazy val evaluateExpression: Expression = { - addTryEvalIfNeeded(getEvaluateExpression(None)) + addTryEvalIfNeeded(getEvaluateExpression()) } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 
9230bd9bf44..e8492c0e5dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -148,14 +148,15 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum. */ - protected def getEvaluateExpression( - context: Option[SQLQueryContext]): Expression = resultType match { - case d: DecimalType => - val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context) - If(isEmpty, Literal.create(null, resultType), checkOverflowInSum) - case _ if shouldTrackIsEmpty => - If(isEmpty, Literal.create(null, resultType), sum) - case _ => sum + protected def getEvaluateExpression(context: SQLQueryContext = null): Expression = { + resultType match { + case d: DecimalType => + val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context) + If(isEmpty, Literal.create(null, resultType), checkOverflowInSum) + case _ if shouldTrackIsEmpty => + If(isEmpty, Literal.create(null, resultType), sum) + case _ => sum + } } // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods @@ -192,7 +193,7 @@ case class Sum( override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions - override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) + override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull()) override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) { Some(origin.context) @@ -255,9 +256,9 @@ case class TrySum(child: Expression) extends SumBase(child) { override lazy val evaluateExpression: Expression = if (useAnsiAdd) { - TryEval(getEvaluateExpression(None)) + TryEval(getEvaluateExpression()) } else { - getEvaluateExpression(None) + getEvaluateExpression() } override protected def 
withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7bbe5d15b91..86e6e6d7323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -265,7 +265,7 @@ abstract class BinaryArithmetic extends BinaryOperator } protected def checkDecimalOverflow(value: Decimal, precision: Int, scale: Int): Decimal = { - value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, queryContext) + value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, getContextOrNull()) } /** Name of the function for this expression on a [[Decimal]] type. */ @@ -285,11 +285,7 @@ abstract class BinaryArithmetic extends BinaryOperator override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case DecimalType.Fixed(precision, scale) => - val errorContextCode = if (failOnError) { - ctx.addReferenceObj("errCtx", queryContext) - } else { - "scala.None$.MODULE$" - } + val errorContextCode = getContextOrNullCode(ctx, failOnError) val updateIsNull = if (failOnError) { "" } else { @@ -334,7 +330,7 @@ abstract class BinaryArithmetic extends BinaryOperator }) case IntegerType | LongType if failOnError && exactMathMethod.isDefined => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") s""" |${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2, $errorContext); @@ -414,9 +410,9 @@ case class Add( case _: YearMonthIntervalType => MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _: IntegerType if failOnError => - 
MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext) + MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], getContextOrNull()) case _: LongType if failOnError => - MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext) + MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull()) case _ => numeric.plus(input1, input2) } @@ -483,9 +479,15 @@ case class Subtract( case _: YearMonthIntervalType => MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _: IntegerType if failOnError => - MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext) + MathUtils.subtractExact( + input1.asInstanceOf[Int], + input2.asInstanceOf[Int], + getContextOrNull()) case _: LongType if failOnError => - MathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext) + MathUtils.subtractExact( + input1.asInstanceOf[Long], + input2.asInstanceOf[Long], + getContextOrNull()) case _ => numeric.minus(input1, input2) } @@ -539,9 +541,15 @@ case class Multiply( case DecimalType.Fixed(precision, scale) => checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale) case _: IntegerType if failOnError => - MathUtils.multiplyExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext) + MathUtils.multiplyExact( + input1.asInstanceOf[Int], + input2.asInstanceOf[Int], + getContextOrNull()) case _: LongType if failOnError => - MathUtils.multiplyExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext) + MathUtils.multiplyExact( + input1.asInstanceOf[Long], + input2.asInstanceOf[Long], + getContextOrNull()) case _ => numeric.times(input1, input2) } @@ -578,10 +586,10 @@ trait DivModLike extends BinaryArithmetic { } else { if (isZero(input2)) { // when we reach here, failOnError must be true. 
- throw QueryExecutionErrors.divideByZeroError(queryContext) + throw QueryExecutionErrors.divideByZeroError(getContextOrNull()) } if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) { - throw QueryExecutionErrors.overflowInIntegralDivideError(queryContext) + throw QueryExecutionErrors.overflowInIntegralDivideError(getContextOrNull()) } evalOperation(input1, input2) } @@ -603,11 +611,7 @@ trait DivModLike extends BinaryArithmetic { s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val errorContextCode = if (failOnError) { - ctx.addReferenceObj("errCtx", queryContext) - } else { - "scala.None$.MODULE$" - } + val errorContextCode = getContextOrNullCode(ctx, failOnError) val operation = super.dataType match { case DecimalType.Fixed(precision, scale) => val decimalValue = ctx.freshName("decimalValue") @@ -962,7 +966,7 @@ case class Pmod( } else { if (isZero(input2)) { // when we reach here, failOnError must bet true. - throw QueryExecutionErrors.divideByZeroError(queryContext) + throw QueryExecutionErrors.divideByZeroError(getContextOrNull()) } pmodFunc(input1, input2) } @@ -979,7 +983,7 @@ case class Pmod( } val remainder = ctx.freshName("remainder") val javaType = CodeGenerator.javaType(dataType) - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) val result = dataType match { case DecimalType.Fixed(precision, scale) => val decimalAdd = "$plus" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 098b3a88084..ae23775b62d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2174,7 +2174,7 @@ case class ElementAt( if (array.numElements() < 
math.abs(index)) { if (failOnError) { throw QueryExecutionErrors.invalidElementAtIndexError( - index, array.numElements(), queryContext) + index, array.numElements(), getContextOrNull()) } else { defaultValueOutOfBound match { case Some(value) => value.eval() @@ -2216,7 +2216,7 @@ case class ElementAt( } val indexOutOfBoundBranch = if (failOnError) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) // scalastyle:off line.size.limit s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements(), $errorContext);" // scalastyle:on line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index fedfcfb978f..b6cbb1d0005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,7 +268,7 @@ case class GetArrayItem( if (index >= baseValue.numElements() || index < 0) { if (failOnError) { throw QueryExecutionErrors.invalidArrayIndexError( - index, baseValue.numElements, queryContext) + index, baseValue.numElements, getContextOrNull()) } else { null } @@ -292,7 +292,7 @@ case class GetArrayItem( } val indexOutOfBoundBranch = if (failOnError) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) // scalastyle:off line.size.limit s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);" // scalastyle:on line.size.limit @@ -380,7 +380,7 @@ trait GetMapValueUtil if (!found) { if (failOnError) { - throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, queryContext) + throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, getContextOrNull()) } else { 
null } @@ -413,7 +413,7 @@ trait GetMapValueUtil } val keyJavaType = CodeGenerator.javaType(keyType) - lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext) + lazy val errorContext = getContextOrNullCode(ctx) val keyDt = ctx.addReferenceObj("keyType", keyType, keyType.getClass.getName) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val keyNotFoundBranch = if (failOnError) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index e672fffda19..37e3dd5ea89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -123,14 +123,10 @@ case class CheckOverflow( dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - queryContext) + getContextOrNull()) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContextCode = if (nullOnOverflow) { - "scala.None$.MODULE$" - } else { - ctx.addReferenceObj("errCtx", queryContext) - } + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) nullSafeCodeGen(ctx, ev, eval => { // scalastyle:off line.size.limit s""" @@ -161,7 +157,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: Option[SQLQueryContext] = None) extends UnaryExpression { + context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -182,11 +178,7 @@ case class CheckOverflowInSum( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val errorContextCode = if (nullOnOverflow) { - "scala.None$.MODULE$" - } else { - ctx.addReferenceObj("errCtx", context) - } + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) 
val nullHandling = if (nullOnOverflow) { "" } else { @@ -216,6 +208,8 @@ case class CheckOverflowInSum( override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) + + override def initQueryContext(): Option[SQLQueryContext] = Option(context) } /** @@ -261,12 +255,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: Option[SQLQueryContext], + context: SQLQueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = context + override def initQueryContext(): Option[SQLQueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { @@ -275,22 +269,18 @@ case class DecimalDivideWithOverflowCheck( if (nullOnOverflow) { null } else { - throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext) + throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull()) } } else { val value2 = right.eval(input) dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, value2).asInstanceOf[Decimal] .toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - queryContext) + getContextOrNull()) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContextCode = if (nullOnOverflow) { - "scala.None$.MODULE$" - } else { - ctx.addReferenceObj("errCtx", queryContext) - } + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) val nullHandling = if (nullOnOverflow) { "" } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 17a2714c611..f7ec82de11b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: Option[SQLQueryContext]): Unit = { + context: SQLQueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.overflowInIntegralDivideError(context) @@ -615,7 +615,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: Option[SQLQueryContext]): Unit = dataType match { + context: SQLQueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) @@ -665,13 +665,13 @@ case class DivideYMInterval( override def nullSafeEval(interval: Any, num: Any): Any = { checkDivideOverflow( - interval.asInstanceOf[Int], Int.MinValue, right, num, Some(origin.context)) - divideByZeroCheck(right.dataType, num, Some(origin.context)) + interval.asInstanceOf[Int], Int.MinValue, right, num, origin.context) + divideByZeroCheck(right.dataType, num, origin.context) evalFunc(interval.asInstanceOf[Int], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context)) + val errorContext = ctx.addReferenceObj("errCtx", origin.context) right.dataType match { case t: IntegralType => val math = t match { @@ -743,13 +743,13 @@ case class DivideDTInterval( override def nullSafeEval(interval: Any, num: Any): Any = { checkDivideOverflow( - interval.asInstanceOf[Long], Long.MinValue, right, num, Some(origin.context)) - 
divideByZeroCheck(right.dataType, num, Some(origin.context)) + interval.asInstanceOf[Long], Long.MinValue, right, num, origin.context) + divideByZeroCheck(right.dataType, num, origin.context) evalFunc(interval.asInstanceOf[Long], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context)) + val errorContext = ctx.addReferenceObj("errCtx", origin.context) right.dataType match { case _: IntegralType => val math = classOf[LongMath].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 55ff36e9863..dfbc041b259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1520,7 +1520,7 @@ abstract class RoundBase(child: Expression, scale: Expression, if (_scale >= 0) { s""" ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, - Decimal.$modeStr(), true, scala.None$$.MODULE$$); + Decimal.$modeStr(), true, null); ${ev.isNull} = ${ev.value} == null;""" } else { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 815eb8977b6..d4504c36e4e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -296,7 +296,8 @@ case class Elt( val index = indexObj.asInstanceOf[Int] if (index <= 0 || index > inputExprs.length) { if (failOnError) { - throw QueryExecutionErrors.invalidArrayIndexError(index, inputExprs.length, queryContext) + throw 
QueryExecutionErrors.invalidArrayIndexError( + index, inputExprs.length, getContextOrNull()) } else { null } @@ -348,7 +349,7 @@ case class Elt( }.mkString) val indexOutOfBoundBranch = if (failOnError) { - val errorContext = ctx.addReferenceObj("errCtx", queryContext) + val errorContext = getContextOrNullCode(ctx) // scalastyle:off line.size.limit s""" |if (!$indexMatched) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 172c2e54034..af0666a98fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -468,14 +468,14 @@ object DateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: Option[SQLQueryContext] = None): Long = { + context: SQLQueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( s, StringType, TimestampType, context) } } - def doubleToTimestampAnsi(d: Double, context: Option[SQLQueryContext]): Long = { + def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( d, DoubleType, TimestampType, context) @@ -527,7 +527,7 @@ object DateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: Option[SQLQueryContext]): Long = { + context: SQLQueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( s, StringType, TimestampNTZType, context) @@ -646,7 +646,9 @@ object DateTimeUtils { } } - def stringToDateAnsi(s: UTF8String, context: Option[SQLQueryContext] = None): Int = { + def stringToDateAnsi( + s: UTF8String, + context: SQLQueryContext = null): Int = { 
stringToDate(s).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( s, StringType, DateType, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index b4695062c08..f2c4236ad7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -733,7 +733,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value or divided by zero */ def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = { - if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(None) + if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(null) fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 6cb3616d4e7..e79e483076d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = { + def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = 
withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = { + def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = { + def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { private def withOverflow[A]( f: => A, hint: String = "", - context: Option[SQLQueryContext] = None): A = { + context: SQLQueryContext = null): A = { try { f } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index 503c0e181ca..f7800469c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: Option[SQLQueryContext]): Long = + def toLongExact(s: 
UTF8String, context: SQLQueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: Option[SQLQueryContext]): Int = + def toIntExact(s: UTF8String, context: SQLQueryContext): Int = withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: Option[SQLQueryContext]): Short = + def toShortExact(s: UTF8String, context: SQLQueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: Option[SQLQueryContext]): Byte = + def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: Option[SQLQueryContext], + context: SQLQueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala index 9617f7d4b0f..4785073f80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.errors import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.trees.SQLQueryContext @@ -97,7 +98,11 @@ private[sql] trait QueryErrorsBase { quoteByDefault(toPrettySQL(e)) } - def getSummary(context: Option[SQLQueryContext]): String = { - context.map(_.summary).getOrElse("") + def getSummary(sqlContext: SQLQueryContext): String = { + if (sqlContext == null) "" else sqlContext.summary + } + + def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { + if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) } } diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index bad95afa139..3644e7c0df8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -89,7 +89,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(from), toSQLType(to), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = None, + context = Array.empty, summary = "") } @@ -103,7 +103,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(from), toSQLType(to), toSQLId(columnName)), - context = None, + context = Array.empty, summary = "" ) } @@ -112,7 +112,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: Option[SQLQueryContext] = None): ArithmeticException = { + context: SQLQueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION", messageParameters = Array( @@ -120,7 +120,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { decimalPrecision.toString, decimalScale.toString, toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } @@ -128,7 +128,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { value: Any, from: DataType, to: DataType, - context: Option[SQLQueryContext]): Throwable = { + context: SQLQueryContext): Throwable = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -136,13 +136,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(from), toSQLType(to), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = 
getSummary(context)) } def invalidInputSyntaxForBooleanError( s: UTF8String, - context: Option[SQLQueryContext]): SparkRuntimeException = { + context: SQLQueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -150,14 +150,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(StringType), toSQLType(BooleanType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: Option[SQLQueryContext]): SparkNumberFormatException = { + context: SQLQueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -165,7 +165,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(StringType), toSQLType(to), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } @@ -196,40 +196,40 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Array(funcCls, inputTypes, outputType), e) } - def divideByZeroError(context: Option[SQLQueryContext]): ArithmeticException = { + def divideByZeroError(context: SQLQueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Array(toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: Option[SQLQueryContext]): ArithmeticException = { + def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Array.empty, - context = context, + context = getQueryContext(context), summary = getSummary(context)) } def 
invalidArrayIndexError( index: Int, numElements: Int, - context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = { + context: SQLQueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", messageParameters = Array( toSQLValue(index, IntegerType), toSQLValue(numElements, IntegerType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } def invalidElementAtIndexError( index: Int, numElements: Int, - context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = { + context: SQLQueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = @@ -237,20 +237,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLValue(index, IntegerType), toSQLValue(numElements, IntegerType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } def mapKeyNotExistError( key: Any, dataType: DataType, - context: Option[SQLQueryContext]): NoSuchElementException = { + context: SQLQueryContext): NoSuchElementException = { new SparkNoSuchElementException( errorClass = "MAP_KEY_DOES_NOT_EXIST", messageParameters = Array( toSQLValue(key, dataType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } @@ -259,7 +259,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { errorClass = "INVALID_FRACTION_OF_SECOND", errorSubClass = None, Array(toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = None, + context = Array.empty, summary = "") } @@ -268,7 +268,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { errorClass = "CANNOT_PARSE_TIMESTAMP", errorSubClass = None, Array(e.getMessage, toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = 
None, + context = Array.empty, summary = "") } @@ -294,11 +294,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: Option[SQLQueryContext]): ArithmeticException = { + def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: Option[SQLQueryContext]): ArithmeticException = { + def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } @@ -514,14 +514,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: Option[SQLQueryContext] = None): ArithmeticException = { + context: SQLQueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" new SparkArithmeticException( errorClass = "ARITHMETIC_OVERFLOW", messageParameters = Array(message, alternative, SQLConf.ANSI_ENABLED.key), - context = context, + context = getQueryContext(context), summary = getSummary(context)) } @@ -2061,7 +2061,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Array( s"add ${toSQLValue(amount, IntegerType)} $unit to " + s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"), - context = None, + context = Array.empty, summary = "") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 00172f69fda..aa683a06a8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -367,7 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: Option[SQLQueryContext] = None): Decimal = { + context: SQLQueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -632,7 +632,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: Option[SQLQueryContext] = None): Decimal = { + context: SQLQueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index d96ca4b87f0..513a62dc7f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -91,7 +91,7 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkExceptionInExpression[ArithmeticException](expr1, query) val expr2 = CheckOverflowInSum( - Literal(d), DecimalType(4, 3), false, context = Some(origin.context)) + Literal(d), DecimalType(4, 3), false, context = origin.context) checkExceptionInExpression[ArithmeticException](expr2, query) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org