This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ba3c82c59 fix: Refactor arithmetic serde and fix correctness issues with EvalMode::TRY (#2018)
ba3c82c59 is described below

commit ba3c82c597055953f89c8af5b52f1a8fdaac0cfa
Author: Andy Grove <agr...@apache.org>
AuthorDate: Mon Jul 21 16:49:18 2025 +0100

    fix: Refactor arithmetic serde and fix correctness issues with EvalMode::TRY (#2018)
---
 native/core/src/execution/planner.rs               | 111 +++++---
 native/proto/src/proto/expr.proto                  |  14 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 197 +-------------
 .../scala/org/apache/comet/serde/arithmetic.scala  | 282 +++++++++++++++++++++
 .../org/apache/comet/CometExpressionSuite.scala    |   9 +
 5 files changed, 376 insertions(+), 237 deletions(-)
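
For context on the correctness issue: in Spark, EvalMode.TRY must yield NULL when arithmetic overflows, whereas LEGACY wraps and ANSI throws, so mapping a TRY expression onto a plain native add can silently return a wrapped value. Until native TRY support lands, the Scala serde below declines such expressions so they fall back to Spark. A minimal sketch of the Spark-side behavior (illustrative only, not part of this commit; assumes a local SparkSession and a made-up object name TryAddDemo):

    import org.apache.spark.sql.SparkSession

    object TryAddDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("try-add-demo").getOrCreate()
        import spark.implicits._

        Seq((Integer.MAX_VALUE, 1)).toDF("_1", "_2").createOrReplaceTempView("tbl")

        // LEGACY (ANSI off): the addition silently wraps to Integer.MIN_VALUE
        spark.conf.set("spark.sql.ansi.enabled", "false")
        spark.sql("SELECT _1 + _2 FROM tbl").show()

        // TRY: try_add returns NULL on overflow, which is the behavior a plain
        // native add cannot reproduce
        spark.sql("SELECT try_add(_1, _2) FROM tbl").show()

        spark.stop()
      }
    }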

diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index f76e10199..1b1c1ae57 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -229,45 +229,78 @@ impl PhysicalPlanner {
         input_schema: SchemaRef,
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
-            ExprStruct::Add(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Plus,
-                input_schema,
-            ),
-            ExprStruct::Subtract(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Minus,
-                input_schema,
-            ),
-            ExprStruct::Multiply(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Multiply,
-                input_schema,
-            ),
-            ExprStruct::Divide(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Divide,
-                input_schema,
-            ),
-            ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Divide,
-                input_schema,
-                BinaryExprOptions {
-                    is_integral_div: true,
-                },
-            ),
+            ExprStruct::Add(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/536
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Plus,
+                    input_schema,
+                )
+            }
+            ExprStruct::Subtract(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/535
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Minus,
+                    input_schema,
+                )
+            }
+            ExprStruct::Multiply(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/534
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Multiply,
+                    input_schema,
+                )
+            }
+            ExprStruct::Divide(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/533
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Divide,
+                    input_schema,
+                )
+            }
+            ExprStruct::IntegralDivide(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/533
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr_with_options(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Divide,
+                    input_schema,
+                    BinaryExprOptions {
+                        is_integral_div: true,
+                    },
+                )
+            }
             ExprStruct::Remainder(expr) => {
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                // TODO add support for EvalMode::TRY
+                // https://github.com/apache/datafusion-comet/issues/2021
                 let left =
                     self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
                 let right =
@@ -278,7 +311,7 @@ impl PhysicalPlanner {
                     right,
                     expr.return_type.as_ref().map(to_arrow_datatype).unwrap(),
                     input_schema,
-                    expr.fail_on_error,
+                    eval_mode == EvalMode::Ansi,
                     &self.session_ctx.state(),
                 );
                 result.map_err(|e| GeneralError(e.to_string()))
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 9f31beffd..bed0f3b9c 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -220,19 +220,19 @@ message Literal {
    bool is_null = 12;
 }
 
-message MathExpr {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
 enum EvalMode {
   LEGACY = 0;
   TRY = 1;
   ANSI = 2;
 }
 
+message MathExpr {
+  Expr left = 1;
+  Expr right = 2;
+  DataType return_type = 4;
+  EvalMode eval_mode = 5;
+}
+
 message Cast {
   Expr child = 1;
   DataType datatype = 2;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4e5631ed2..f0faa5d39 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,6 @@ import java.util.Locale
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
-import scala.math.min
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
@@ -67,6 +66,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Mapping of Spark expression class to Comet expression handler.
    */
   private val exprSerdeMap: Map[Class[_], CometExpressionSerde] = Map(
+    classOf[Add] -> CometAdd,
+    classOf[Subtract] -> CometSubtract,
+    classOf[Multiply] -> CometMultiply,
+    classOf[Divide] -> CometDivide,
+    classOf[IntegralDivide] -> CometIntegralDivide,
+    classOf[Remainder] -> CometRemainder,
     classOf[ArrayAppend] -> CometArrayAppend,
     classOf[ArrayContains] -> CometArrayContains,
     classOf[ArrayDistinct] -> CometArrayDistinct,
@@ -630,141 +635,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case c @ Cast(child, dt, timeZoneId, _) =>
         handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
 
-      case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          add.dataType,
-          add.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setAdd(mathExpr))
-
-      case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
-        withInfo(add, s"Unsupported datatype ${left.dataType}")
-        None
-
-      case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          sub.dataType,
-          sub.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setSubtract(mathExpr))
-
-      case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
-        withInfo(sub, s"Unsupported datatype ${left.dataType}")
-        None
-
-      case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          mul.dataType,
-          mul.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setMultiply(mathExpr))
-
-      case mul @ Multiply(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(mul, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case div @ Divide(left, right, _) if supportedDataType(left.dataType) =>
-        // Datafusion now throws an exception for dividing by zero
-        // See https://github.com/apache/arrow-datafusion/pull/6792
-        // For now, use NullIf to swap zeros with nulls.
-        val rightExpr = nullIfWhenPrimitive(right)
-
-        createMathExpression(
-          expr,
-          left,
-          rightExpr,
-          inputs,
-          binding,
-          div.dataType,
-          div.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setDivide(mathExpr))
-
-      case div @ Divide(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(div, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) =>
-        val rightExpr = nullIfWhenPrimitive(right)
-
-        val dataType = (left.dataType, right.dataType) match {
-          case (l: DecimalType, r: DecimalType) =>
-            // copy from IntegralDivide.resultDecimalType
-            val intDig = l.precision - l.scale + r.scale
-            DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
-          case _ => left.dataType
-        }
-
-        val divideExpr = createMathExpression(
-          expr,
-          left,
-          rightExpr,
-          inputs,
-          binding,
-          dataType,
-          div.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
-
-        if (divideExpr.isDefined) {
-          val childExpr = if (dataType.isInstanceOf[DecimalType]) {
-            // check overflow for decimal type
-            val builder = ExprOuterClass.CheckOverflow.newBuilder()
-            builder.setChild(divideExpr.get)
-            builder.setFailOnError(div.evalMode == EvalMode.ANSI)
-            builder.setDatatype(serializeDataType(dataType).get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCheckOverflow(builder)
-                .build())
-          } else {
-            divideExpr
-          }
-
-          // cast result to long
-          castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
-        } else {
-          None
-        }
-
-      case div @ IntegralDivide(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(div, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          rem.dataType,
-          rem.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setRemainder(mathExpr))
-
-      case rem @ Remainder(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(rem, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
       case EqualTo(left, right) =>
         createBinaryExpr(
           expr,
@@ -1962,42 +1832,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }
 
-  private def createMathExpression(
-      expr: Expression,
-      left: Expression,
-      right: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean,
-      dataType: DataType,
-      failOnError: Boolean,
-      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
-      : Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(left, inputs, binding)
-    val rightExpr = exprToProtoInternal(right, inputs, binding)
-
-    if (leftExpr.isDefined && rightExpr.isDefined) {
-      // create the generic MathExpr message
-      val builder = ExprOuterClass.MathExpr.newBuilder()
-      builder.setLeft(leftExpr.get)
-      builder.setRight(rightExpr.get)
-      builder.setFailOnError(failOnError)
-      serializeDataType(dataType).foreach { t =>
-        builder.setReturnType(t)
-      }
-      val inner = builder.build()
-      // call the user-supplied function to wrap MathExpr in a top-level Expr
-      // such as Expr.Add or Expr.Divide
-      Some(
-        f(
-          ExprOuterClass.Expr
-            .newBuilder(),
-          inner).build())
-    } else {
-      withInfo(expr, left, right)
-      None
-    }
-  }
-
   def in(
       expr: Expression,
       value: Expression,
@@ -2053,25 +1887,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
   }
 
-  private def isPrimitive(expression: Expression): Boolean = expression.dataType match {
-    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-        _: DoubleType | _: TimestampType | _: DateType | _: BooleanType | _: DecimalType =>
-      true
-    case _ => false
-  }
-
-  private def nullIfWhenPrimitive(expression: Expression): Expression =
-    if (isPrimitive(expression)) {
-      val zero = Literal.default(expression.dataType)
-      expression match {
-        case _: Literal if expression != zero => expression
-        case _ =>
-          If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression)
-      }
-    } else {
-      expression
-    }
-
   private def nullIfNegative(expression: Expression): Expression = {
     val zero = Literal.default(expression.dataType)
     If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
new file mode 100644
index 000000000..3a7a9f8fb
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.math.min
+
+import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract}
+import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.expressions.CometEvalMode
+import org.apache.comet.serde.QueryPlanSerde.{castToProto, evalModeToProto, exprToProtoInternal, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
+
+trait MathBase {
+  def createMathExpression(
+      expr: Expression,
+      left: Expression,
+      right: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      dataType: DataType,
+      evalMode: EvalMode.Value,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
+      : Option[ExprOuterClass.Expr] = {
+    val leftExpr = exprToProtoInternal(left, inputs, binding)
+    val rightExpr = exprToProtoInternal(right, inputs, binding)
+
+    if (leftExpr.isDefined && rightExpr.isDefined) {
+      // create the generic MathExpr message
+      val builder = ExprOuterClass.MathExpr.newBuilder()
+      builder.setLeft(leftExpr.get)
+      builder.setRight(rightExpr.get)
+      builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)))
+      serializeDataType(dataType).foreach { t =>
+        builder.setReturnType(t)
+      }
+      val inner = builder.build()
+      // call the user-supplied function to wrap MathExpr in a top-level Expr
+      // such as Expr.Add or Expr.Divide
+      Some(
+        f(
+          ExprOuterClass.Expr
+            .newBuilder(),
+          inner).build())
+    } else {
+      withInfo(expr, left, right)
+      None
+    }
+  }
+
+  def nullIfWhenPrimitive(expression: Expression): Expression = {
+    val zero = Literal.default(expression.dataType)
+    expression match {
+      case _: Literal if expression != zero => expression
+      case _ =>
+        If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression)
+    }
+  }
+
+  def supportedDataType(dt: DataType): Boolean = dt match {
+    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+        _: DoubleType | _: DecimalType =>
+      true
+    case _ =>
+      false
+  }
+
+}
+
+object CometAdd extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val add = expr.asInstanceOf[Add]
+    if (!supportedDataType(add.left.dataType)) {
+      withInfo(add, s"Unsupported datatype ${add.left.dataType}")
+      return None
+    }
+    if (add.evalMode == EvalMode.TRY) {
+      withInfo(add, s"Eval mode ${add.evalMode} is not supported")
+      return None
+    }
+    createMathExpression(
+      expr,
+      add.left,
+      add.right,
+      inputs,
+      binding,
+      add.dataType,
+      add.evalMode,
+      (builder, mathExpr) => builder.setAdd(mathExpr))
+  }
+}
+
+object CometSubtract extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val sub = expr.asInstanceOf[Subtract]
+    if (!supportedDataType(sub.left.dataType)) {
+      withInfo(sub, s"Unsupported datatype ${sub.left.dataType}")
+      return None
+    }
+    if (sub.evalMode == EvalMode.TRY) {
+      withInfo(sub, s"Eval mode ${sub.evalMode} is not supported")
+      return None
+    }
+    createMathExpression(
+      expr,
+      sub.left,
+      sub.right,
+      inputs,
+      binding,
+      sub.dataType,
+      sub.evalMode,
+      (builder, mathExpr) => builder.setSubtract(mathExpr))
+  }
+}
+
+object CometMultiply extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val mul = expr.asInstanceOf[Multiply]
+    if (!supportedDataType(mul.left.dataType)) {
+      withInfo(mul, s"Unsupported datatype ${mul.left.dataType}")
+      return None
+    }
+    if (mul.evalMode == EvalMode.TRY) {
+      withInfo(mul, s"Eval mode ${mul.evalMode} is not supported")
+      return None
+    }
+    createMathExpression(
+      expr,
+      mul.left,
+      mul.right,
+      inputs,
+      binding,
+      mul.dataType,
+      mul.evalMode,
+      (builder, mathExpr) => builder.setMultiply(mathExpr))
+  }
+}
+
+object CometDivide extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val div = expr.asInstanceOf[Divide]
+
+    // Datafusion now throws an exception for dividing by zero
+    // See https://github.com/apache/arrow-datafusion/pull/6792
+    // For now, use NullIf to swap zeros with nulls.
+    val rightExpr = nullIfWhenPrimitive(div.right)
+
+    if (!supportedDataType(div.left.dataType)) {
+      withInfo(div, s"Unsupported datatype ${div.left.dataType}")
+      return None
+    }
+    if (div.evalMode == EvalMode.TRY) {
+      withInfo(div, s"Eval mode ${div.evalMode} is not supported")
+      return None
+    }
+    createMathExpression(
+      expr,
+      div.left,
+      rightExpr,
+      inputs,
+      binding,
+      div.dataType,
+      div.evalMode,
+      (builder, mathExpr) => builder.setDivide(mathExpr))
+  }
+}
+
+object CometIntegralDivide extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val div = expr.asInstanceOf[IntegralDivide]
+    val rightExpr = nullIfWhenPrimitive(div.right)
+
+    if (!supportedDataType(div.left.dataType)) {
+      withInfo(div, s"Unsupported datatype ${div.left.dataType}")
+      return None
+    }
+    if (div.evalMode == EvalMode.TRY) {
+      withInfo(div, s"Eval mode ${div.evalMode} is not supported")
+      return None
+    }
+
+    val dataType = (div.left.dataType, div.right.dataType) match {
+      case (l: DecimalType, r: DecimalType) =>
+        // copy from IntegralDivide.resultDecimalType
+        val intDig = l.precision - l.scale + r.scale
+        DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
+      case _ => div.left.dataType
+    }
+
+    val divideExpr = createMathExpression(
+      expr,
+      div.left,
+      rightExpr,
+      inputs,
+      binding,
+      dataType,
+      div.evalMode,
+      (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
+
+    if (divideExpr.isDefined) {
+      val childExpr = if (dataType.isInstanceOf[DecimalType]) {
+        // check overflow for decimal type
+        val builder = ExprOuterClass.CheckOverflow.newBuilder()
+        builder.setChild(divideExpr.get)
+        builder.setFailOnError(div.evalMode == EvalMode.ANSI)
+        builder.setDatatype(serializeDataType(dataType).get)
+        Some(
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setCheckOverflow(builder)
+            .build())
+      } else {
+        divideExpr
+      }
+
+      // cast result to long
+      castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
+    } else {
+      None
+    }
+  }
+}
+
+object CometRemainder extends CometExpressionSerde with MathBase {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val remainder = expr.asInstanceOf[Remainder]
+    if (!supportedDataType(remainder.left.dataType)) {
+      withInfo(remainder, s"Unsupported datatype ${remainder.left.dataType}")
+      return None
+    }
+    if (remainder.evalMode == EvalMode.TRY) {
+      withInfo(remainder, s"Eval mode ${remainder.evalMode} is not supported")
+      return None
+    }
+
+    createMathExpression(
+      expr,
+      remainder.left,
+      remainder.right,
+      inputs,
+      binding,
+      remainder.dataType,
+      remainder.evalMode,
+      (builder, mathExpr) => builder.setRemainder(mathExpr))
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index a09f337f8..883f1d082 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -302,6 +302,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("try_add") {
+    // TODO: we need to implement more comprehensive tests for all try_ arithmetic functions
+    // https://github.com/apache/datafusion-comet/issues/2021
+    val data = Seq((Integer.MAX_VALUE, 1))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswer("SELECT try_add(_1, _2) FROM tbl")
+    }
+  }
+
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
