This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d3cf9310cc9 [SPARK-40222][SQL] Numeric try_add/try_divide/try_subtract/try_multiply should throw error from their children d3cf9310cc9 is described below commit d3cf9310cc93315ca25f52bdc47fde91909dad99 Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Fri Aug 26 14:19:41 2022 -0700 [SPARK-40222][SQL] Numeric try_add/try_divide/try_subtract/try_multiply should throw error from their children ### What changes were proposed in this pull request? Similar to https://issues.apache.org/jira/browse/SPARK-40054, we should refactor the try_add/try_subtract/try_multiply/try_divide functions so that errors thrown by their child expressions are surfaced instead of being silently turned into NULL. Spark SQL allows arithmetic operations between Number/Date/Timestamp/CalendarInterval/AnsiInterval (see the rule [ResolveBinaryArithmetic](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala) for details). Some of these combinations can throw exceptions too: Date + CalendarInterval; Date + AnsiInterval; Timestamp + AnsiInterval; Date - CalendarInterval; Date - AnsiInterval; Timestamp - AnsiInterval; Number * CalendarInterval; Number * AnsiInterval; CalendarInterval / Number; AnsiInterval / Number. This JIRA covers only the cases where both input data types are numeric. I will open JIRA tickets for the datetime arithmetic operations once this one is merged. ### Why are the changes needed? To fix the semantics of try_add/try_divide/try_subtract/try_multiply. ### Does this PR introduce _any_ user-facing change? Yes. After this change, errors thrown by the children of the try_add/try_divide/try_subtract/try_multiply functions are surfaced instead of being ignored. ### How was this patch tested? Existing unit tests + new unit tests. Closes #37663 from gengliangwang/newTryArithmetic. 
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 33 +-- .../spark/sql/catalyst/expressions/TryEval.scala | 38 ++- .../catalyst/expressions/aggregate/Average.scala | 4 +- .../sql/catalyst/expressions/aggregate/Sum.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 95 +++++-- .../catalyst/expressions/bitwiseExpressions.scala | 6 +- .../expressions/ArithmeticExpressionSuite.scala | 24 +- .../sql/catalyst/expressions/TryCastSuite.scala | 2 +- .../sql/catalyst/expressions/TryEvalSuite.scala | 24 +- .../sql/catalyst/util/V2ExpressionBuilder.scala | 10 +- .../resources/sql-tests/inputs/try_arithmetic.sql | 12 + .../sql-tests/results/ansi/try_arithmetic.sql.out | 280 +++++++++++++++++++++ .../sql-tests/results/try_arithmetic.sql.out | 96 +++++++ .../spark/sql/SparkSessionExtensionSuite.scala | 2 +- .../connector/functions/V2FunctionBenchmark.scala | 4 +- .../sql/expressions/ExpressionInfoSuite.scala | 7 +- 16 files changed, 562 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 669857b6a11..820202ef9c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -377,7 +377,7 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(BINARY_ARITHMETIC), ruleId) { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(BINARY_ARITHMETIC), ruleId) { - case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, r.dataType) match { + case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match { case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r)) case (DateType, _: 
DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r) case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l)) @@ -394,23 +394,25 @@ class Analyzer(override val catalogManager: CatalogManager) a.copy(left = Cast(a.left, a.right.dataType)) case (_: AnsiIntervalType, _: NullType) => a.copy(right = Cast(a.right, a.left.dataType)) - case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f) + case (DateType, CalendarIntervalType) => + DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI) case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType) - case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f) + case (CalendarIntervalType, DateType) => + DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI) case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType) case (DateType, dt) if dt != StringType => DateAdd(l, r) case (dt, DateType) if dt != StringType => DateAdd(r, l) case _ => a } - case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, r.dataType) match { + case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match { case (DateType, DayTimeIntervalType(DAY, DAY)) => - DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), f)) + DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI)) case (DateType, _: DayTimeIntervalType) => - DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, f))) + DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI))) case (DateType, _: YearMonthIntervalType) => - DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f))) + DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f))) + DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == 
EvalMode.ANSI))) case (CalendarIntervalType, CalendarIntervalType) | (_: DayTimeIntervalType, _: DayTimeIntervalType) => s case (_: NullType, _: AnsiIntervalType) => @@ -418,26 +420,27 @@ class Analyzer(override val catalogManager: CatalogManager) case (_: AnsiIntervalType, _: NullType) => s.copy(right = Cast(s.right, s.left.dataType)) case (DateType, CalendarIntervalType) => - DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f)) + DatetimeSub(l, r, DateAddInterval(l, + UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI)) case (_, CalendarIntervalType | _: DayTimeIntervalType) => - Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType) + Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType) case _ if AnyTimestampType.unapply(l) || AnyTimestampType.unapply(r) => SubtractTimestamps(l, r) case (_, DateType) => SubtractDates(l, r) case (DateType, dt) if dt != StringType => DateSub(l, r) case _ => s } - case m @ Multiply(l, r, f) if m.childrenResolved => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => MultiplyInterval(l, r, f) - case (_, CalendarIntervalType) => MultiplyInterval(r, l, f) + case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match { + case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI) + case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r) case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l) case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r) case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l) case _ => m } - case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => DivideInterval(l, r, f) + case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match { + case (CalendarIntervalType, _) => 
DivideInterval(l, r, mode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r) case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r) case _ => d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index c179c83befb..a23f4f61943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, NumericType} case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -77,8 +77,13 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran // scalastyle:on line.size.limit case class TryAdd(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Add(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Add(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. 
+ case _ => TryEval(Add(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_add" @@ -110,8 +115,13 @@ case class TryAdd(left: Expression, right: Expression, replacement: Expression) // scalastyle:on line.size.limit case class TryDivide(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Divide(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Divide(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. + case _ => TryEval(Divide(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_divide" @@ -144,8 +154,13 @@ case class TryDivide(left: Expression, right: Expression, replacement: Expressio group = "math_funcs") case class TrySubtract(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Subtract(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Subtract(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. 
+ case _ => TryEval(Subtract(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_subtract" @@ -171,8 +186,13 @@ case class TrySubtract(left: Expression, right: Expression, replacement: Express group = "math_funcs") case class TryMultiply(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Multiply(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Multiply(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. + case _ => TryEval(Multiply(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_multiply" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 36ffcd8f764..9bc2891ae5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -69,7 +69,7 @@ abstract class AverageBase protected def add(left: Expression, right: Expression): Expression = left.dataType match { case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType) - case _ => Add(left, right, useAnsiAdd) + case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd)) } override lazy val aggBufferAttributes = sum :: count :: Nil @@ -103,7 +103,7 @@ abstract class AverageBase If(EqualTo(count, Literal(0L)), Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) case _ => - Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) + Divide(sum.cast(resultType), count.cast(resultType), EvalMode.LEGACY) 
} protected def getUpdateExpressions: Seq[Expression] = Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 869a27c6161..db8bec7c931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -65,7 +65,7 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate private def add(left: Expression, right: Expression): Expression = left.dataType match { case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType) - case _ => Add(left, right, useAnsiAdd) + case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd)) } override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 24ac685eace..45e0ec876d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -214,7 +214,14 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant with SupportQueryContext { - protected val failOnError: Boolean + protected val evalMode: EvalMode.Value + + protected def failOnError: Boolean = evalMode match { + // The TRY mode executes as if it would fail on errors, except that it would capture the errors + // and return null results. 
+ case EvalMode.ANSI | EvalMode.TRY => true + case _ => false + } override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match { case (l: DecimalType, r: DecimalType) if inputType.acceptsType(l) && inputType.acceptsType(r) => @@ -240,11 +247,11 @@ abstract class BinaryArithmetic extends BinaryOperator s"${getClass.getSimpleName} must override `resultDecimalType`.") } - override def nullable: Boolean = super.nullable || { + override def nullable: Boolean = super.nullable || evalMode == EvalMode.TRY || { if (left.dataType.isInstanceOf[DecimalType]) { // For decimal arithmetic, we may return null even if both inputs are not null, if overflow // happens and this `failOnError` flag is false. - !failOnError + evalMode != EvalMode.ANSI } else { // For non-decimal arithmetic, the calculation always return non-null result when inputs are // not null. If overflow happens, we return either the overflowed value or fail. @@ -349,6 +356,49 @@ abstract class BinaryArithmetic extends BinaryOperator """.stripMargin }) } + + override def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: (String, String) => String): ExprCode = { + if (evalMode == EvalMode.TRY) { + val tryBlock: (String, String) => String = (eval1, eval2) => { + s""" + |try { + | ${f(eval1, eval2)} + |} catch (Exception e) { + | ${ev.isNull} = true; + |} + |""".stripMargin + } + super.nullSafeCodeGen(ctx, ev, tryBlock) + } else { + super.nullSafeCodeGen(ctx, ev, f) + } + } + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + if (evalMode == EvalMode.TRY) { + try { + nullSafeEval(value1, value2) + } catch { + case _: Exception => + null + } + } else { + nullSafeEval(value1, value2) + } + } + } + } } object BinaryArithmetic { @@ -367,9 +417,10 @@ object BinaryArithmetic { case class Add( left: Expression, right: Expression, - 
failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -436,9 +487,10 @@ case class Add( case class Subtract( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -511,9 +563,10 @@ case class Subtract( case class Multiply( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -698,9 +751,14 @@ trait DivModLike extends BinaryArithmetic { case class Divide( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { + + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + // `try_divide` has 
exactly the same behavior as the legacy divide, so here it only executes + // the error code path when `evalMode` is `ANSI`. + protected override def failOnError: Boolean = evalMode == EvalMode.ANSI override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -762,9 +820,10 @@ case class Divide( case class IntegralDivide( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = this(left, right, + EvalMode.fromSQLConf(SQLConf.get)) override def checkDivideOverflow: Boolean = left.dataType match { case LongType if failOnError => true @@ -835,9 +894,10 @@ case class IntegralDivide( case class Remainder( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -912,9 +972,10 @@ case class Remainder( case class Pmod( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 57ab9e2773e..a178500fba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types._ group = "bitwise_funcs") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -77,7 +77,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme group = "bitwise_funcs") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -116,7 +116,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet group = "bitwise_funcs") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 2bfa072a13a..63862ee3553 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -95,7 +95,7 @@ class ArithmeticExpressionSuite extends 
SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Add(maxValue, maxValue, failOnError = true) + val expr = Add(maxValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -180,7 +180,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Subtract(minValue, maxValue, failOnError = true) + val expr = Subtract(minValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -219,7 +219,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Multiply(maxValue, maxValue, failOnError = true) + val expr = Multiply(maxValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -264,7 +264,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), failOnError = true) + val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -320,7 +320,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper withOrigin(o) { val expr = IntegralDivide( - Literal(Long.MinValue, LongType), Literal(right, LongType), failOnError = true) + Literal(Long.MinValue, LongType), Literal(right, LongType), EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -367,7 +367,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expression = exprBuilder(Literal(1L, LongType), Literal(0L, LongType), true) + val expression = exprBuilder(Literal(1L, LongType), Literal(0L, LongType), EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expression, EmptyRow, query) } } @@ -760,24 +760,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("SPARK-34677: exact add and subtract of day-time and year-month intervals") { - Seq(true, false).foreach { failOnError => + Seq(EvalMode.ANSI, EvalMode.LEGACY).foreach { evalMode => checkExceptionInExpression[ArithmeticException]( UnaryMinus( Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), - failOnError), + evalMode == EvalMode.ANSI), "overflow") checkExceptionInExpression[ArithmeticException]( Subtract( Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), Literal.create(Period.ofMonths(10), YearMonthIntervalType()), - failOnError + evalMode ), "overflow") checkExceptionInExpression[ArithmeticException]( Add( Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType()), Literal.create(Period.ofMonths(10), YearMonthIntervalType()), - failOnError + evalMode ), "overflow") @@ -785,14 +785,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Subtract( Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType()), Literal.create(Duration.ofDays(10), DayTimeIntervalType()), - failOnError + evalMode ), "overflow") checkExceptionInExpression[ArithmeticException]( Add( Literal.create(Duration.ofDays(106751991), DayTimeIntervalType()), Literal.create(Duration.ofDays(10), DayTimeIntervalType()), - failOnError + evalMode ), "overflow") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index 4dc7f87d19d..9ead0756635 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -104,7 +104,7 @@ class TryCastThrowExceptionSuite extends SparkFunSuite with ExpressionEvalHelper // The method checkExceptionInExpression is overridden in TryCastSuite, so here we have a // new test suite for testing exceptions from the child of `try_cast()`. test("TryCast should not catch the exception from it's child") { - val child = Divide(Literal(1.0), Literal(0.0), failOnError = true) + val child = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI) checkExceptionInExpression[Exception]( Cast(child, StringType, None, EvalMode.TRY), "Division by zero") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala index 1eccd46d960..780a2692e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala @@ -28,7 +28,7 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Add(left, right, failOnError = true)) + val input = Add(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -41,7 +41,7 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Divide(left, right, failOnError = true)) + val input = Divide(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -54,7 +54,7 @@ class TryEvalSuite extends SparkFunSuite with 
ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Subtract(left, right, failOnError = true)) + val input = Subtract(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -67,8 +67,24 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Multiply(left, right, failOnError = true)) + val input = Multiply(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } + + test("Throw exceptions from children") { + val failingChild = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI) + Seq( + Add(failingChild, Literal(1.0), EvalMode.TRY), + Add(Literal(1.0), failingChild, EvalMode.TRY), + Subtract(failingChild, Literal(1.0), EvalMode.TRY), + Subtract(Literal(1.0), failingChild, EvalMode.TRY), + Multiply(failingChild, Literal(1.0), EvalMode.TRY), + Multiply(Literal(1.0), failingChild, EvalMode.TRY), + Divide(failingChild, Literal(1.0), EvalMode.TRY), + Divide(Literal(1.0), failingChild, EvalMode.TRY) + ).foreach { expr => + checkExceptionInExpression[ArithmeticException](expr, "DIVIDE_BY_ZERO") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 749c8791da9..947a5e9f383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -35,11 +35,11 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { private def canTranslate(b: BinaryOperator) = b match { case _: BinaryComparison => true case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true - case add: Add => add.failOnError - case sub: Subtract => sub.failOnError - case mul: Multiply => mul.failOnError - case div: 
Divide => div.failOnError - case r: Remainder => r.failOnError + case add: Add => add.evalMode == EvalMode.ANSI + case sub: Subtract => sub.evalMode == EvalMode.ANSI + case mul: Multiply => mul.evalMode == EvalMode.ANSI + case div: Divide => div.evalMode == EvalMode.ANSI + case r: Remainder => r.evalMode == EvalMode.ANSI case _ => false } diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql index 586680f5507..55907b6701e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql @@ -4,6 +4,9 @@ SELECT try_add(2147483647, 1); SELECT try_add(-2147483648, -1); SELECT try_add(9223372036854775807L, 1); SELECT try_add(-9223372036854775808L, -1); +SELECT try_add(1, (2147483647 + 1)); +SELECT try_add(1L, (9223372036854775807L + 1L)); +SELECT try_add(1, 1.0 / 0.0); -- Date + Integer SELECT try_add(date'2021-01-01', 1); @@ -32,6 +35,9 @@ SELECT try_add(interval 106751991 day, interval 3 day); SELECT try_divide(1, 0.5); SELECT try_divide(1, 0); SELECT try_divide(0, 0); +SELECT try_divide(1, (2147483647 + 1)); +SELECT try_divide(1L, (9223372036854775807L + 1L)); +SELECT try_divide(1, 1.0 / 0.0); -- Interval / Numeric SELECT try_divide(interval 2 year, 2); @@ -47,6 +53,9 @@ SELECT try_subtract(2147483647, -1); SELECT try_subtract(-2147483648, 1); SELECT try_subtract(9223372036854775807L, -1); SELECT try_subtract(-9223372036854775808L, 1); +SELECT try_subtract(1, (2147483647 + 1)); +SELECT try_subtract(1L, (9223372036854775807L + 1L)); +SELECT try_subtract(1, 1.0 / 0.0); -- Interval - Interval SELECT try_subtract(interval 2 year, interval 3 year); @@ -60,6 +69,9 @@ SELECT try_multiply(2147483647, -2); SELECT try_multiply(-2147483648, 2); SELECT try_multiply(9223372036854775807L, 2); SELECT try_multiply(-9223372036854775808L, -2); +SELECT try_multiply(1, (2147483647 + 1)); +SELECT try_multiply(1L, 
(9223372036854775807L + 1L)); +SELECT try_multiply(1, 1.0 / 0.0); -- Interval * Numeric SELECT try_multiply(interval 2 year, 2); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out index 8622b97a205..914ee064c51 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out @@ -39,6 +39,76 @@ struct<try_add(-9223372036854775808, -1):bigint> NULL +-- !query +SELECT try_add(1, (2147483647 + 1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "integer overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 20, + "stopIndex" : 33, + "fragment" : "2147483647 + 1" + } ] +} + + +-- !query +SELECT try_add(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "long overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 45, + "fragment" : "9223372036854775807L + 1L" + } ] +} + + +-- !query +SELECT try_add(1, 1.0 / 0.0) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + 
"startIndex" : 19, + "stopIndex" : 27, + "fragment" : "1.0 / 0.0" + } ] +} + + -- !query SELECT try_add(date'2021-01-01', 1) -- !query schema @@ -184,6 +254,76 @@ struct<try_divide(0, 0):double> NULL +-- !query +SELECT try_divide(1, (2147483647 + 1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "integer overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 23, + "stopIndex" : 36, + "fragment" : "2147483647 + 1" + } ] +} + + +-- !query +SELECT try_divide(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "long overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 24, + "stopIndex" : 48, + "fragment" : "9223372036854775807L + 1L" + } ] +} + + +-- !query +SELECT try_divide(1, 1.0 / 0.0) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 22, + "stopIndex" : 30, + "fragment" : "1.0 / 0.0" + } ] +} + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema @@ -272,6 +412,76 @@ struct<try_subtract(-9223372036854775808, 1):bigint> NULL +-- !query +SELECT try_subtract(1, (2147483647 + 1)) +-- !query schema +struct<> +-- !query output 
+org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "integer overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 38, + "fragment" : "2147483647 + 1" + } ] +} + + +-- !query +SELECT try_subtract(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "long overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 26, + "stopIndex" : 50, + "fragment" : "9223372036854775807L + 1L" + } ] +} + + +-- !query +SELECT try_subtract(1, 1.0 / 0.0) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 24, + "stopIndex" : 32, + "fragment" : "1.0 / 0.0" + } ] +} + + -- !query SELECT try_subtract(interval 2 year, interval 3 year) -- !query schema @@ -344,6 +554,76 @@ struct<try_multiply(-9223372036854775808, -2):bigint> NULL +-- !query +SELECT try_multiply(1, (2147483647 + 1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "integer overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" 
: [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 38, + "fragment" : "2147483647 + 1" + } ] +} + + +-- !query +SELECT try_multiply(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "message" : "long overflow", + "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", + "config" : "spark.sql.ansi.enabled" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 26, + "stopIndex" : 50, + "fragment" : "9223372036854775807L + 1L" + } ] +} + + +-- !query +SELECT try_multiply(1, 1.0 / 0.0) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 24, + "stopIndex" : 32, + "fragment" : "1.0 / 0.0" + } ] +} + + -- !query SELECT try_multiply(interval 2 year, 2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out index 8622b97a205..50bbafedd08 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out @@ -39,6 +39,30 @@ struct<try_add(-9223372036854775808, -1):bigint> NULL +-- !query +SELECT try_add(1, (2147483647 + 1)) +-- !query schema +struct<try_add(1, (2147483647 + 1)):int> +-- !query output +-2147483647 + + +-- !query +SELECT try_add(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<try_add(1, (9223372036854775807 + 1)):bigint> +-- !query output +-9223372036854775807 + + +-- !query +SELECT try_add(1, 1.0 / 0.0) +-- !query schema +struct<try_add(1, (1.0 / 
0.0)):decimal(9,6)> +-- !query output +NULL + + -- !query SELECT try_add(date'2021-01-01', 1) -- !query schema @@ -184,6 +208,30 @@ struct<try_divide(0, 0):double> NULL +-- !query +SELECT try_divide(1, (2147483647 + 1)) +-- !query schema +struct<try_divide(1, (2147483647 + 1)):double> +-- !query output +-4.6566128730773926E-10 + + +-- !query +SELECT try_divide(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<try_divide(1, (9223372036854775807 + 1)):double> +-- !query output +-1.0842021724855044E-19 + + +-- !query +SELECT try_divide(1, 1.0 / 0.0) +-- !query schema +struct<try_divide(1, (1.0 / 0.0)):decimal(16,9)> +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema @@ -272,6 +320,30 @@ struct<try_subtract(-9223372036854775808, 1):bigint> NULL +-- !query +SELECT try_subtract(1, (2147483647 + 1)) +-- !query schema +struct<try_subtract(1, (2147483647 + 1)):int> +-- !query output +NULL + + +-- !query +SELECT try_subtract(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<try_subtract(1, (9223372036854775807 + 1)):bigint> +-- !query output +NULL + + +-- !query +SELECT try_subtract(1, 1.0 / 0.0) +-- !query schema +struct<try_subtract(1, (1.0 / 0.0)):decimal(9,6)> +-- !query output +NULL + + -- !query SELECT try_subtract(interval 2 year, interval 3 year) -- !query schema @@ -344,6 +416,30 @@ struct<try_multiply(-9223372036854775808, -2):bigint> NULL +-- !query +SELECT try_multiply(1, (2147483647 + 1)) +-- !query schema +struct<try_multiply(1, (2147483647 + 1)):int> +-- !query output +-2147483648 + + +-- !query +SELECT try_multiply(1L, (9223372036854775807L + 1L)) +-- !query schema +struct<try_multiply(1, (9223372036854775807 + 1)):bigint> +-- !query output +-9223372036854775808 + + +-- !query +SELECT try_multiply(1, 1.0 / 0.0) +-- !query schema +struct<try_multiply(1, (1.0 / 0.0)):decimal(10,6)> +-- !query output +NULL + + -- !query SELECT try_multiply(interval 2 year, 2) -- !query schema diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 102c971d6fd..bcdb66bab33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -735,7 +735,7 @@ class BrokenColumnarAdd( left: ColumnarExpression, right: ColumnarExpression, failOnError: Boolean = false) - extends Add(left, right, failOnError) with ColumnarExpression { + extends Add(left, right, EvalMode.fromBoolean(failOnError)) with ColumnarExpression { override def supportsColumnar(): Boolean = left.supportsColumnar && right.supportsColumnar diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index 38f016c2b63..d9c3848d3b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression} +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog} @@ -104,7 +104,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { left: Expression, right: Expression, override val nullable: Boolean) extends BinaryArithmetic { - override protected val failOnError: Boolean = false + protected override val 
evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = NumericType override def symbol: String = "+" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 101315ccb77..106802a54c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -228,11 +228,6 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // Do not check these expressions, because these expressions override the eval method val ignoreSet = Set( - // Extend NullIntolerant and avoid evaluating input1 if input2 is 0 - classOf[IntegralDivide], - classOf[Divide], - classOf[Remainder], - classOf[Pmod], // Throws an exception, even if input is null classOf[RaiseError] ) @@ -242,6 +237,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { .filterNot(c => ignoreSet.exists(_.getName.equals(c))) .map(name => Utils.classForName(name)) .filterNot(classOf[NonSQLExpression].isAssignableFrom) + // BinaryArithmetic overrides the eval method + .filterNot(classOf[BinaryArithmetic].isAssignableFrom) exprTypesToCheck.foreach { superClass => candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org