This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d3cf9310cc9 [SPARK-40222][SQL] Numeric try_add/try_divide/try_subtract/try_multiply should throw error from their children
d3cf9310cc9 is described below

commit d3cf9310cc93315ca25f52bdc47fde91909dad99
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Fri Aug 26 14:19:41 2022 -0700

    [SPARK-40222][SQL] Numeric try_add/try_divide/try_subtract/try_multiply should throw error from their children
    
    ### What changes were proposed in this pull request?
    
    Similar to https://issues.apache.org/jira/browse/SPARK-40054, we should refactor the try_add/try_subtract/try_multiply/try_divide functions so that errors from their children are shown instead of ignored.
    Spark SQL allows arithmetic operations between Number/Date/Timestamp/CalendarInterval/AnsiInterval (see the rule [ResolveBinaryArithmetic](https://github.com/databricks/runtime/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L501) for details). Some of these combinations can throw exceptions as well:
    
    - Date + CalendarInterval
    - Date + AnsiInterval
    - Timestamp + AnsiInterval
    - Date - CalendarInterval
    - Date - AnsiInterval
    - Timestamp - AnsiInterval
    - Number * CalendarInterval
    - Number * AnsiInterval
    - CalendarInterval / Number
    - AnsiInterval / Number
    
    This Jira covers the cases where both input data types are numeric. I will open follow-up Jira tickets for the datetime arithmetic operations once this one is merged.
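    
    For the numeric case, the fix boils down to a dispatch in the try_* constructors, condensed below from the TryEval.scala changes in this commit (the helper name is illustrative; in the diff it is the auxiliary constructor of `TryAdd`):
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.NumericType
    
    // Before: the whole subtree was wrapped, so an error thrown by `left` or
    // `right` was swallowed into NULL as well:
    //   TryEval(Add(left, right, failOnError = true))
    
    // After: for numeric inputs only the Add itself runs in TRY mode, so
    // errors raised by the children still propagate.
    def tryAddReplacement(left: Expression, right: Expression): Expression =
      (left.dataType, right.dataType) match {
        case (_: NumericType, _: NumericType) => Add(left, right, EvalMode.TRY)
        // Datetime/interval combinations keep the TryEval wrapper for now.
        case _ => TryEval(Add(left, right, EvalMode.ANSI))
      }
    ```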
    
    ### Why are the changes needed?
    
    Fix the semantics of try_add/try_divide/try_subtract/try_multiply.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After these changes, errors from the children of the try_add/try_divide/try_subtract/try_multiply functions are shown instead of ignored.
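    
    A minimal sketch of the difference, assuming a `SparkSession` named `spark` and ANSI mode on (the expected results come from the updated golden files in this patch):
    
    ```scala
    // The overflow in the child is now reported instead of swallowed:
    spark.sql("SELECT try_add(1, (2147483647 + 1))").show()
    // before this patch: NULL
    // after: SparkArithmeticException with error class ARITHMETIC_OVERFLOW
    
    // try_add's own overflow is still tolerated, as before:
    spark.sql("SELECT try_add(2147483647, 1)").show()  // NULL
    ```
    
    With ANSI mode off, the child Add overflows silently (legacy wrap-around semantics), so the first query returns -2147483647 rather than failing.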
    
    ### How was this patch tested?
    
    Existing UT + new UT
    
    Closes #37663 from gengliangwang/newTryArithmetic.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  33 +--
 .../spark/sql/catalyst/expressions/TryEval.scala   |  38 ++-
 .../catalyst/expressions/aggregate/Average.scala   |   4 +-
 .../sql/catalyst/expressions/aggregate/Sum.scala   |   2 +-
 .../sql/catalyst/expressions/arithmetic.scala      |  95 +++++--
 .../catalyst/expressions/bitwiseExpressions.scala  |   6 +-
 .../expressions/ArithmeticExpressionSuite.scala    |  24 +-
 .../sql/catalyst/expressions/TryCastSuite.scala    |   2 +-
 .../sql/catalyst/expressions/TryEvalSuite.scala    |  24 +-
 .../sql/catalyst/util/V2ExpressionBuilder.scala    |  10 +-
 .../resources/sql-tests/inputs/try_arithmetic.sql  |  12 +
 .../sql-tests/results/ansi/try_arithmetic.sql.out  | 280 +++++++++++++++++++++
 .../sql-tests/results/try_arithmetic.sql.out       |  96 +++++++
 .../spark/sql/SparkSessionExtensionSuite.scala     |   2 +-
 .../connector/functions/V2FunctionBenchmark.scala  |   4 +-
 .../sql/expressions/ExpressionInfoSuite.scala      |   7 +-
 16 files changed, 562 insertions(+), 77 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 669857b6a11..820202ef9c5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -377,7 +377,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       _.containsPattern(BINARY_ARITHMETIC), ruleId) {
       case p: LogicalPlan => p.transformExpressionsUpWithPruning(
         _.containsPattern(BINARY_ARITHMETIC), ruleId) {
-        case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, 
r.dataType) match {
+        case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, 
r.dataType) match {
           case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, 
ExtractANSIIntervalDays(r))
           case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, 
TimestampType), r)
           case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, 
ExtractANSIIntervalDays(l))
@@ -394,23 +394,25 @@ class Analyzer(override val catalogManager: 
CatalogManager)
             a.copy(left = Cast(a.left, a.right.dataType))
           case (_: AnsiIntervalType, _: NullType) =>
             a.copy(right = Cast(a.right, a.left.dataType))
-          case (DateType, CalendarIntervalType) => DateAddInterval(l, r, 
ansiEnabled = f)
+          case (DateType, CalendarIntervalType) =>
+            DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
           case (_, CalendarIntervalType | _: DayTimeIntervalType) => 
Cast(TimeAdd(l, r), l.dataType)
-          case (CalendarIntervalType, DateType) => DateAddInterval(r, l, 
ansiEnabled = f)
+          case (CalendarIntervalType, DateType) =>
+            DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
           case (CalendarIntervalType | _: DayTimeIntervalType, _) => 
Cast(TimeAdd(r, l), r.dataType)
           case (DateType, dt) if dt != StringType => DateAdd(l, r)
           case (dt, DateType) if dt != StringType => DateAdd(r, l)
           case _ => a
         }
-        case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, 
r.dataType) match {
+        case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, 
r.dataType) match {
           case (DateType, DayTimeIntervalType(DAY, DAY)) =>
-            DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), f))
+            DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == 
EvalMode.ANSI))
           case (DateType, _: DayTimeIntervalType) =>
-            DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, 
f)))
+            DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, 
mode == EvalMode.ANSI)))
           case (DateType, _: YearMonthIntervalType) =>
-            DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f)))
+            DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == 
EvalMode.ANSI)))
           case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
-            DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
+            DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == 
EvalMode.ANSI)))
           case (CalendarIntervalType, CalendarIntervalType) |
                (_: DayTimeIntervalType, _: DayTimeIntervalType) => s
           case (_: NullType, _: AnsiIntervalType) =>
@@ -418,26 +420,27 @@ class Analyzer(override val catalogManager: 
CatalogManager)
           case (_: AnsiIntervalType, _: NullType) =>
             s.copy(right = Cast(s.right, s.left.dataType))
           case (DateType, CalendarIntervalType) =>
-            DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled 
= f))
+            DatetimeSub(l, r, DateAddInterval(l,
+              UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == 
EvalMode.ANSI))
           case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
-            Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType)
+            Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == 
EvalMode.ANSI))), l.dataType)
           case _ if AnyTimestampType.unapply(l) || AnyTimestampType.unapply(r) 
=>
             SubtractTimestamps(l, r)
           case (_, DateType) => SubtractDates(l, r)
           case (DateType, dt) if dt != StringType => DateSub(l, r)
           case _ => s
         }
-        case m @ Multiply(l, r, f) if m.childrenResolved => (l.dataType, 
r.dataType) match {
-          case (CalendarIntervalType, _) => MultiplyInterval(l, r, f)
-          case (_, CalendarIntervalType) => MultiplyInterval(r, l, f)
+        case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, 
r.dataType) match {
+          case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == 
EvalMode.ANSI)
+          case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == 
EvalMode.ANSI)
           case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
           case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
           case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
           case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
           case _ => m
         }
-        case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, 
r.dataType) match {
-          case (CalendarIntervalType, _) => DivideInterval(l, r, f)
+        case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, 
r.dataType) match {
+          case (CalendarIntervalType, _) => DivideInterval(l, r, mode == 
EvalMode.ANSI)
           case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
           case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
           case _ => d
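
The `EvalMode` enum referenced throughout this hunk is not part of this diff; it was introduced with SPARK-40054. Reconstructed from its usage here (a sketch, not the authoritative definition), it looks roughly like:

```scala
import org.apache.spark.sql.internal.SQLConf

object EvalMode extends Enumeration {
  // LEGACY: permissive pre-ANSI semantics; ANSI: fail on runtime errors;
  // TRY: run the ANSI code path but turn runtime errors into NULL results.
  val LEGACY, ANSI, TRY = Value

  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) ANSI else LEGACY

  def fromSQLConf(conf: SQLConf): Value = fromBoolean(conf.ansiEnabled)
}
```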
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
index c179c83befb..a23f4f61943 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, NumericType}
 
 case class TryEval(child: Expression) extends UnaryExpression with 
NullIntolerant {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -77,8 +77,13 @@ case class TryEval(child: Expression) extends 
UnaryExpression with NullIntoleran
 // scalastyle:on line.size.limit
 case class TryAdd(left: Expression, right: Expression, replacement: Expression)
     extends RuntimeReplaceable with InheritAnalysisRules {
-  def this(left: Expression, right: Expression) =
-    this(left, right, TryEval(Add(left, right, failOnError = true)))
+  def this(left: Expression, right: Expression) = this(left, right,
+    (left.dataType, right.dataType) match {
+      case (_: NumericType, _: NumericType) => Add(left, right, EvalMode.TRY)
+      // TODO: support TRY eval mode on datetime arithmetic expressions.
+      case _ => TryEval(Add(left, right, EvalMode.ANSI))
+    }
+  )
 
   override def prettyName: String = "try_add"
 
@@ -110,8 +115,13 @@ case class TryAdd(left: Expression, right: Expression, 
replacement: Expression)
 // scalastyle:on line.size.limit
 case class TryDivide(left: Expression, right: Expression, replacement: 
Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
-  def this(left: Expression, right: Expression) =
-    this(left, right, TryEval(Divide(left, right, failOnError = true)))
+  def this(left: Expression, right: Expression) = this(left, right,
+    (left.dataType, right.dataType) match {
+      case (_: NumericType, _: NumericType) => Divide(left, right, 
EvalMode.TRY)
+      // TODO: support TRY eval mode on datetime arithmetic expressions.
+      case _ => TryEval(Divide(left, right, EvalMode.ANSI))
+    }
+  )
 
   override def prettyName: String = "try_divide"
 
@@ -144,8 +154,13 @@ case class TryDivide(left: Expression, right: Expression, 
replacement: Expressio
   group = "math_funcs")
 case class TrySubtract(left: Expression, right: Expression, replacement: 
Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
-  def this(left: Expression, right: Expression) =
-    this(left, right, TryEval(Subtract(left, right, failOnError = true)))
+  def this(left: Expression, right: Expression) = this(left, right,
+    (left.dataType, right.dataType) match {
+      case (_: NumericType, _: NumericType) => Subtract(left, right, 
EvalMode.TRY)
+      // TODO: support TRY eval mode on datetime arithmetic expressions.
+      case _ => TryEval(Subtract(left, right, EvalMode.ANSI))
+    }
+  )
 
   override def prettyName: String = "try_subtract"
 
@@ -171,8 +186,13 @@ case class TrySubtract(left: Expression, right: 
Expression, replacement: Express
   group = "math_funcs")
 case class TryMultiply(left: Expression, right: Expression, replacement: 
Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
-  def this(left: Expression, right: Expression) =
-    this(left, right, TryEval(Multiply(left, right, failOnError = true)))
+  def this(left: Expression, right: Expression) = this(left, right,
+    (left.dataType, right.dataType) match {
+      case (_: NumericType, _: NumericType) => Multiply(left, right, 
EvalMode.TRY)
+      // TODO: support TRY eval mode on datetime arithmetic expressions.
+      case _ => TryEval(Multiply(left, right, EvalMode.ANSI))
+    }
+  )
 
   override def prettyName: String = "try_multiply"
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 36ffcd8f764..9bc2891ae5e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,7 +69,7 @@ abstract class AverageBase
 
   protected def add(left: Expression, right: Expression): Expression = 
left.dataType match {
     case _: DecimalType => DecimalAddNoOverflowCheck(left, right, 
left.dataType)
-    case _ => Add(left, right, useAnsiAdd)
+    case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd))
   }
 
   override lazy val aggBufferAttributes = sum :: count :: Nil
@@ -103,7 +103,7 @@ abstract class AverageBase
       If(EqualTo(count, Literal(0L)),
         Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
     case _ =>
-      Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
+      Divide(sum.cast(resultType), count.cast(resultType), EvalMode.LEGACY)
   }
 
   protected def getUpdateExpressions: Seq[Expression] = Seq(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 869a27c6161..db8bec7c931 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -65,7 +65,7 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
 
   private def add(left: Expression, right: Expression): Expression = 
left.dataType match {
     case _: DecimalType => DecimalAddNoOverflowCheck(left, right, 
left.dataType)
-    case _ => Add(left, right, useAnsiAdd)
+    case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd))
   }
 
   override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 24ac685eace..45e0ec876d1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -214,7 +214,14 @@ case class Abs(child: Expression, failOnError: Boolean = 
SQLConf.get.ansiEnabled
 abstract class BinaryArithmetic extends BinaryOperator
   with NullIntolerant with SupportQueryContext {
 
-  protected val failOnError: Boolean
+  protected val evalMode: EvalMode.Value
+
+  protected def failOnError: Boolean = evalMode match {
+    // The TRY mode executes as if it would fail on errors, except that it 
would capture the errors
+    // and return null results.
+    case EvalMode.ANSI | EvalMode.TRY => true
+    case _ => false
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = (left.dataType, 
right.dataType) match {
     case (l: DecimalType, r: DecimalType) if inputType.acceptsType(l) && 
inputType.acceptsType(r) =>
@@ -240,11 +247,11 @@ abstract class BinaryArithmetic extends BinaryOperator
       s"${getClass.getSimpleName} must override `resultDecimalType`.")
   }
 
-  override def nullable: Boolean = super.nullable || {
+  override def nullable: Boolean = super.nullable || evalMode == EvalMode.TRY 
|| {
     if (left.dataType.isInstanceOf[DecimalType]) {
       // For decimal arithmetic, we may return null even if both inputs are 
not null, if overflow
       // happens and this `failOnError` flag is false.
-      !failOnError
+      evalMode != EvalMode.ANSI
     } else {
       // For non-decimal arithmetic, the calculation always return non-null 
result when inputs are
       // not null. If overflow happens, we return either the overflowed value 
or fail.
@@ -349,6 +356,49 @@ abstract class BinaryArithmetic extends BinaryOperator
          """.stripMargin
       })
   }
+
+  override def nullSafeCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: (String, String) => String): ExprCode = {
+    if (evalMode == EvalMode.TRY) {
+      val tryBlock: (String, String) => String = (eval1, eval2) => {
+        s"""
+           |try {
+           | ${f(eval1, eval2)}
+           |} catch (Exception e) {
+           | ${ev.isNull} = true;
+           |}
+           |""".stripMargin
+      }
+      super.nullSafeCodeGen(ctx, ev, tryBlock)
+    } else {
+      super.nullSafeCodeGen(ctx, ev, f)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      val value2 = right.eval(input)
+      if (value2 == null) {
+        null
+      } else {
+        if (evalMode == EvalMode.TRY) {
+          try {
+            nullSafeEval(value1, value2)
+          } catch {
+            case _: Exception =>
+              null
+          }
+        } else {
+          nullSafeEval(value1, value2)
+        }
+      }
+    }
+  }
 }
 
 object BinaryArithmetic {
@@ -367,9 +417,10 @@ object BinaryArithmetic {
 case class Add(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
BinaryArithmetic {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -436,9 +487,10 @@ case class Add(
 case class Subtract(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
BinaryArithmetic {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -511,9 +563,10 @@ case class Subtract(
 case class Multiply(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
BinaryArithmetic {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = NumericType
 
@@ -698,9 +751,14 @@ trait DivModLike extends BinaryArithmetic {
 case class Divide(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
DivModLike {
+
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  // `try_divide` has exactly the same behavior as the legacy divide, so here 
it only executes
+  // the error code path when `evalMode` is `ANSI`.
+  protected override def failOnError: Boolean = evalMode == EvalMode.ANSI
 
   override def inputType: AbstractDataType = TypeCollection(DoubleType, 
DecimalType)
 
@@ -762,9 +820,10 @@ case class Divide(
 case class IntegralDivide(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
DivModLike {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) = this(left, right,
+    EvalMode.fromSQLConf(SQLConf.get))
 
   override def checkDivideOverflow: Boolean = left.dataType match {
     case LongType if failOnError => true
@@ -835,9 +894,10 @@ case class IntegralDivide(
 case class Remainder(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
DivModLike {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = NumericType
 
@@ -912,9 +972,10 @@ case class Remainder(
 case class Pmod(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends 
BinaryArithmetic {
 
-  def this(left: Expression, right: Expression) = this(left, right, 
SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) =
+    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
 
   override def toString: String = s"pmod($left, $right)"
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 57ab9e2773e..a178500fba8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types._
   group = "bitwise_funcs")
 case class BitwiseAnd(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  protected override val failOnError: Boolean = false
+  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
 
   override def inputType: AbstractDataType = IntegralType
 
@@ -77,7 +77,7 @@ case class BitwiseAnd(left: Expression, right: Expression) 
extends BinaryArithme
   group = "bitwise_funcs")
 case class BitwiseOr(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  protected override val failOnError: Boolean = false
+  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
 
   override def inputType: AbstractDataType = IntegralType
 
@@ -116,7 +116,7 @@ case class BitwiseOr(left: Expression, right: Expression) 
extends BinaryArithmet
   group = "bitwise_funcs")
 case class BitwiseXor(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  protected override val failOnError: Boolean = false
+  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
 
   override def inputType: AbstractDataType = IntegralType
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 2bfa072a13a..63862ee3553 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -95,7 +95,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
-        val expr = Add(maxValue, maxValue, failOnError = true)
+        val expr = Add(maxValue, maxValue, EvalMode.ANSI)
         checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query)
       }
     }
@@ -180,7 +180,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
-        val expr = Subtract(minValue, maxValue, failOnError = true)
+        val expr = Subtract(minValue, maxValue, EvalMode.ANSI)
         checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query)
       }
     }
@@ -219,7 +219,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
-        val expr = Multiply(maxValue, maxValue, failOnError = true)
+        val expr = Multiply(maxValue, maxValue, EvalMode.ANSI)
         checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query)
       }
     }
@@ -264,7 +264,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
       stopIndex = Some(7 + query.length -1),
       sqlText = Some(s"select $query"))
     withOrigin(o) {
-      val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), 
failOnError = true)
+      val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), 
EvalMode.ANSI)
       checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query)
     }
   }
@@ -320,7 +320,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
       withOrigin(o) {
         val expr =
           IntegralDivide(
-            Literal(Long.MinValue, LongType), Literal(right, LongType), 
failOnError = true)
+            Literal(Long.MinValue, LongType), Literal(right, LongType), 
EvalMode.ANSI)
         checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query)
       }
     }
@@ -367,7 +367,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
-        val expression = exprBuilder(Literal(1L, LongType), Literal(0L, 
LongType), true)
+        val expression = exprBuilder(Literal(1L, LongType), Literal(0L, 
LongType), EvalMode.ANSI)
         checkExceptionInExpression[ArithmeticException](expression, EmptyRow, 
query)
       }
     }
@@ -760,24 +760,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite 
with ExpressionEvalHelper
   }
 
   test("SPARK-34677: exact add and subtract of day-time and year-month 
intervals") {
-    Seq(true, false).foreach { failOnError =>
+    Seq(EvalMode.ANSI, EvalMode.LEGACY).foreach { evalMode =>
       checkExceptionInExpression[ArithmeticException](
         UnaryMinus(
           Literal.create(Period.ofMonths(Int.MinValue), 
YearMonthIntervalType()),
-          failOnError),
+          evalMode == EvalMode.ANSI),
         "overflow")
       checkExceptionInExpression[ArithmeticException](
         Subtract(
           Literal.create(Period.ofMonths(Int.MinValue), 
YearMonthIntervalType()),
           Literal.create(Period.ofMonths(10), YearMonthIntervalType()),
-          failOnError
+          evalMode
         ),
         "overflow")
       checkExceptionInExpression[ArithmeticException](
         Add(
           Literal.create(Period.ofMonths(Int.MaxValue), 
YearMonthIntervalType()),
           Literal.create(Period.ofMonths(10), YearMonthIntervalType()),
-          failOnError
+          evalMode
         ),
         "overflow")
 
@@ -785,14 +785,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite 
with ExpressionEvalHelper
         Subtract(
           Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType()),
           Literal.create(Duration.ofDays(10), DayTimeIntervalType()),
-          failOnError
+          evalMode
         ),
         "overflow")
       checkExceptionInExpression[ArithmeticException](
         Add(
           Literal.create(Duration.ofDays(106751991), DayTimeIntervalType()),
           Literal.create(Duration.ofDays(10), DayTimeIntervalType()),
-          failOnError
+          evalMode
         ),
         "overflow")
     }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
index 4dc7f87d19d..9ead0756635 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
@@ -104,7 +104,7 @@ class TryCastThrowExceptionSuite extends SparkFunSuite with 
ExpressionEvalHelper
   // The method checkExceptionInExpression is overridden in TryCastSuite, so 
here we have a
   // new test suite for testing exceptions from the child of `try_cast()`.
   test("TryCast should not catch the exception from it's child") {
-    val child = Divide(Literal(1.0), Literal(0.0), failOnError = true)
+    val child = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI)
     checkExceptionInExpression[Exception](
       Cast(child, StringType, None, EvalMode.TRY),
       "Division by zero")
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala
index 1eccd46d960..780a2692e87 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala
@@ -28,7 +28,7 @@ class TryEvalSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     ).foreach { case (a, b, expected) =>
       val left = Literal(a)
       val right = Literal(b)
-      val input = TryEval(Add(left, right, failOnError = true))
+      val input = Add(left, right, EvalMode.TRY)
       checkEvaluation(input, expected)
     }
   }
@@ -41,7 +41,7 @@ class TryEvalSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     ).foreach { case (a, b, expected) =>
       val left = Literal(a)
       val right = Literal(b)
-      val input = TryEval(Divide(left, right, failOnError = true))
+      val input = Divide(left, right, EvalMode.TRY)
       checkEvaluation(input, expected)
     }
   }
@@ -54,7 +54,7 @@ class TryEvalSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     ).foreach { case (a, b, expected) =>
       val left = Literal(a)
       val right = Literal(b)
-      val input = TryEval(Subtract(left, right, failOnError = true))
+      val input = Subtract(left, right, EvalMode.TRY)
       checkEvaluation(input, expected)
     }
   }
@@ -67,8 +67,24 @@ class TryEvalSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     ).foreach { case (a, b, expected) =>
       val left = Literal(a)
       val right = Literal(b)
-      val input = TryEval(Multiply(left, right, failOnError = true))
+      val input = Multiply(left, right, EvalMode.TRY)
       checkEvaluation(input, expected)
     }
   }
+
+  test("Throw exceptions from children") {
+    val failingChild = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI)
+    Seq(
+      Add(failingChild, Literal(1.0), EvalMode.TRY),
+      Add(Literal(1.0), failingChild, EvalMode.TRY),
+      Subtract(failingChild, Literal(1.0), EvalMode.TRY),
+      Subtract(Literal(1.0), failingChild, EvalMode.TRY),
+      Multiply(failingChild, Literal(1.0), EvalMode.TRY),
+      Multiply(Literal(1.0), failingChild, EvalMode.TRY),
+      Divide(failingChild, Literal(1.0), EvalMode.TRY),
+      Divide(Literal(1.0), failingChild, EvalMode.TRY)
+    ).foreach { expr =>
+      checkExceptionInExpression[ArithmeticException](expr, "DIVIDE_BY_ZERO")
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 749c8791da9..947a5e9f383 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -35,11 +35,11 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
   private def canTranslate(b: BinaryOperator) = b match {
     case _: BinaryComparison => true
     case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true
-    case add: Add => add.failOnError
-    case sub: Subtract => sub.failOnError
-    case mul: Multiply => mul.failOnError
-    case div: Divide => div.failOnError
-    case r: Remainder => r.failOnError
+    case add: Add => add.evalMode == EvalMode.ANSI
+    case sub: Subtract => sub.evalMode == EvalMode.ANSI
+    case mul: Multiply => mul.evalMode == EvalMode.ANSI
+    case div: Divide => div.evalMode == EvalMode.ANSI
+    case r: Remainder => r.evalMode == EvalMode.ANSI
     case _ => false
   }
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql 
b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
index 586680f5507..55907b6701e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
@@ -4,6 +4,9 @@ SELECT try_add(2147483647, 1);
 SELECT try_add(-2147483648, -1);
 SELECT try_add(9223372036854775807L, 1);
 SELECT try_add(-9223372036854775808L, -1);
+SELECT try_add(1, (2147483647 + 1));
+SELECT try_add(1L, (9223372036854775807L + 1L));
+SELECT try_add(1, 1.0 / 0.0);
 
 -- Date + Integer
 SELECT try_add(date'2021-01-01', 1);
@@ -32,6 +35,9 @@ SELECT try_add(interval 106751991 day, interval 3 day);
 SELECT try_divide(1, 0.5);
 SELECT try_divide(1, 0);
 SELECT try_divide(0, 0);
+SELECT try_divide(1, (2147483647 + 1));
+SELECT try_divide(1L, (9223372036854775807L + 1L));
+SELECT try_divide(1, 1.0 / 0.0);
 
 -- Interval / Numeric
 SELECT try_divide(interval 2 year, 2);
@@ -47,6 +53,9 @@ SELECT try_subtract(2147483647, -1);
 SELECT try_subtract(-2147483648, 1);
 SELECT try_subtract(9223372036854775807L, -1);
 SELECT try_subtract(-9223372036854775808L, 1);
+SELECT try_subtract(1, (2147483647 + 1));
+SELECT try_subtract(1L, (9223372036854775807L + 1L));
+SELECT try_subtract(1, 1.0 / 0.0);
 
 -- Interval - Interval
 SELECT try_subtract(interval 2 year, interval 3 year);
@@ -60,6 +69,9 @@ SELECT try_multiply(2147483647, -2);
 SELECT try_multiply(-2147483648, 2);
 SELECT try_multiply(9223372036854775807L, 2);
 SELECT try_multiply(-9223372036854775808L, -2);
+SELECT try_multiply(1, (2147483647 + 1));
+SELECT try_multiply(1L, (9223372036854775807L + 1L));
+SELECT try_multiply(1, 1.0 / 0.0);
 
 -- Interval * Numeric
 SELECT try_multiply(interval 2 year, 2);
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
index 8622b97a205..914ee064c51 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
@@ -39,6 +39,76 @@ struct<try_add(-9223372036854775808, -1):bigint>
 NULL
 
 
+-- !query
+SELECT try_add(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 20,
+    "stopIndex" : 33,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_add(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 21,
+    "stopIndex" : 45,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_add(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 19,
+    "stopIndex" : 27,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_add(date'2021-01-01', 1)
 -- !query schema
@@ -184,6 +254,76 @@ struct<try_divide(0, 0):double>
 NULL
 
 
+-- !query
+SELECT try_divide(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 23,
+    "stopIndex" : 36,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_divide(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 48,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_divide(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 22,
+    "stopIndex" : 30,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -272,6 +412,76 @@ struct<try_subtract(-9223372036854775808, 1):bigint>
 NULL
 
 
+-- !query
+SELECT try_subtract(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 25,
+    "stopIndex" : 38,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_subtract(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 26,
+    "stopIndex" : 50,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_subtract(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 32,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_subtract(interval 2 year, interval 3 year)
 -- !query schema
@@ -344,6 +554,76 @@ struct<try_multiply(-9223372036854775808, -2):bigint>
 NULL
 
 
+-- !query
+SELECT try_multiply(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 25,
+    "stopIndex" : 38,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_multiply(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL 
instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 26,
+    "stopIndex" : 50,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_multiply(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 32,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_multiply(interval 2 year, 2)
 -- !query schema
diff --git 
a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out 
b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
index 8622b97a205..50bbafedd08 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
@@ -39,6 +39,30 @@ struct<try_add(-9223372036854775808, -1):bigint>
 NULL
 
 
+-- !query
+SELECT try_add(1, (2147483647 + 1))
+-- !query schema
+struct<try_add(1, (2147483647 + 1)):int>
+-- !query output
+-2147483647
+
+
+-- !query
+SELECT try_add(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<try_add(1, (9223372036854775807 + 1)):bigint>
+-- !query output
+-9223372036854775807
+
+
+-- !query
+SELECT try_add(1, 1.0 / 0.0)
+-- !query schema
+struct<try_add(1, (1.0 / 0.0)):decimal(9,6)>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_add(date'2021-01-01', 1)
 -- !query schema
@@ -184,6 +208,30 @@ struct<try_divide(0, 0):double>
 NULL
 
 
+-- !query
+SELECT try_divide(1, (2147483647 + 1))
+-- !query schema
+struct<try_divide(1, (2147483647 + 1)):double>
+-- !query output
+-4.6566128730773926E-10
+
+
+-- !query
+SELECT try_divide(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<try_divide(1, (9223372036854775807 + 1)):double>
+-- !query output
+-1.0842021724855044E-19
+
+
+-- !query
+SELECT try_divide(1, 1.0 / 0.0)
+-- !query schema
+struct<try_divide(1, (1.0 / 0.0)):decimal(16,9)>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -272,6 +320,30 @@ struct<try_subtract(-9223372036854775808, 1):bigint>
 NULL
 
 
+-- !query
+SELECT try_subtract(1, (2147483647 + 1))
+-- !query schema
+struct<try_subtract(1, (2147483647 + 1)):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_subtract(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<try_subtract(1, (9223372036854775807 + 1)):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_subtract(1, 1.0 / 0.0)
+-- !query schema
+struct<try_subtract(1, (1.0 / 0.0)):decimal(9,6)>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_subtract(interval 2 year, interval 3 year)
 -- !query schema
@@ -344,6 +416,30 @@ struct<try_multiply(-9223372036854775808, -2):bigint>
 NULL
 
 
+-- !query
+SELECT try_multiply(1, (2147483647 + 1))
+-- !query schema
+struct<try_multiply(1, (2147483647 + 1)):int>
+-- !query output
+-2147483648
+
+
+-- !query
+SELECT try_multiply(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<try_multiply(1, (9223372036854775807 + 1)):bigint>
+-- !query output
+-9223372036854775808
+
+
+-- !query
+SELECT try_multiply(1, 1.0 / 0.0)
+-- !query schema
+struct<try_multiply(1, (1.0 / 0.0)):decimal(10,6)>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_multiply(interval 2 year, 2)
 -- !query schema
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 102c971d6fd..bcdb66bab33 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -735,7 +735,7 @@ class BrokenColumnarAdd(
     left: ColumnarExpression,
     right: ColumnarExpression,
     failOnError: Boolean = false)
-  extends Add(left, right, failOnError) with ColumnarExpression {
+  extends Add(left, right, EvalMode.fromBoolean(failOnError)) with 
ColumnarExpression {
 
   override def supportsColumnar(): Boolean = left.supportsColumnar && 
right.supportsColumnar
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
index 38f016c2b63..d9c3848d3b6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
@@ -24,7 +24,7 @@ import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, 
Expression}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}
@@ -104,7 +104,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark {
       left: Expression,
       right: Expression,
       override val nullable: Boolean) extends BinaryArithmetic {
-    override protected val failOnError: Boolean = false
+    protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "+"
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 101315ccb77..106802a54c9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -228,11 +228,6 @@ class ExpressionInfoSuite extends SparkFunSuite with 
SharedSparkSession {
 
     // Do not check these expressions, because these expressions override the 
eval method
     val ignoreSet = Set(
-      // Extend NullIntolerant and avoid evaluating input1 if input2 is 0
-      classOf[IntegralDivide],
-      classOf[Divide],
-      classOf[Remainder],
-      classOf[Pmod],
       // Throws an exception, even if input is null
       classOf[RaiseError]
     )
@@ -242,6 +237,8 @@ class ExpressionInfoSuite extends SparkFunSuite with 
SharedSparkSession {
       .filterNot(c => ignoreSet.exists(_.getName.equals(c)))
       .map(name => Utils.classForName(name))
       .filterNot(classOf[NonSQLExpression].isAssignableFrom)
+      // BinaryArithmetic overrides the eval method
+      .filterNot(classOf[BinaryArithmetic].isAssignableFrom)
 
     exprTypesToCheck.foreach { superClass =>
       candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { 
clazz =>

