This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new a7cf6cf24 fix: Fall back to Spark for MakeDecimal with unsupported input type (#2815)
a7cf6cf24 is described below

commit a7cf6cf24920d3f6ff516708a703cb9dff312861
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 10 16:57:00 2025 -0700

    fix: Fall back to Spark for MakeDecimal with unsupported input type (#2815)
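
In short: Comet's native MakeDecimal kernel only handles LongType (Int64) input, but the serde previously accepted MakeDecimal over any child type, so other input types could fail at runtime inside the native engine. This commit adds a getSupportLevel check so that non-long inputs fall back to Spark, and hardens the native kernel to return an explicit internal error for non-Int64 arrays.

A minimal repro sketch from the DataFrame API (hypothetical, assuming a Spark 3.x session with Comet enabled; Spark 4.0 needs the ExpressionColumnNode wrapper shown in the shims below):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.MakeDecimal

    // c1 is IntegerType, not LongType, so Comet now falls back to Spark
    // here instead of failing natively with "Expected Int64 but found ...".
    val df = spark.range(1).selectExpr("cast(123456 as int) as c1")
    val result = df.withColumn("d", new Column(MakeDecimal(df.col("c1").expr, 3, 0)))
    result.show()
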
---
 .../src/math_funcs/internal/make_decimal.rs        | 25 ++++++++------
 .../apache/comet/serde/decimalExpressions.scala    |  8 +++++
 .../org/apache/comet/CometExpressionSuite.scala    | 40 ++++++++++++++++++++++
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 6 files changed, 77 insertions(+), 14 deletions(-)

diff --git a/native/spark-expr/src/math_funcs/internal/make_decimal.rs b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
index 8feba54f5..338317520 100644
--- a/native/spark-expr/src/math_funcs/internal/make_decimal.rs
+++ b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
@@ -40,18 +40,21 @@ pub fn spark_make_decimal(
             ))),
             sv => internal_err!("Expected Int64 but found {sv:?}"),
         },
-        ColumnarValue::Array(a) => {
-            let arr = a.as_primitive::<Int64Type>();
-            let mut result = Decimal128Builder::new();
-            for v in arr.into_iter() {
-                result.append_option(long_to_decimal(&v, precision))
-            }
-            let result_type = DataType::Decimal128(precision, scale);
+        ColumnarValue::Array(a) => match a.data_type() {
+            DataType::Int64 => {
+                let arr = a.as_primitive::<Int64Type>();
+                let mut result = Decimal128Builder::new();
+                for v in arr.into_iter() {
+                    result.append_option(long_to_decimal(&v, precision))
+                }
+                let result_type = DataType::Decimal128(precision, scale);
 
-            Ok(ColumnarValue::Array(Arc::new(
-                result.finish().with_data_type(result_type),
-            )))
-        }
+                Ok(ColumnarValue::Array(Arc::new(
+                    result.finish().with_data_type(result_type),
+                )))
+            }
+            av => internal_err!("Expected Int64 but found {av:?}"),
+        },
     }
 }
 
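Note on the native side of the change: the array branch of spark_make_decimal now mirrors the existing scalar branch, matching on the array's data type and returning an internal error for anything other than Int64. Even if an unsupported type slipped past the Scala serde check, the failure is now an explicit error rather than a panic from the Int64 downcast in as_primitive.
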
diff --git a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
index c606d1ac5..880f01742 100644
--- a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
@@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
 }
 
 object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {
+
+  override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
+    expr.child.dataType match {
+      case LongType => Compatible()
+      case other => Unsupported(Some(s"Unsupported input data type: $other"))
+    }
+  }
+
   override def convert(
       expr: MakeDecimal,
       inputs: Seq[Attribute],
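
For context, getSupportLevel is the hook a CometExpressionSerde uses to declare whether it can handle a given expression instance; returning Unsupported with a reason makes Comet fall back to Spark and report that reason. A minimal sketch of the same guard pattern for another expression (illustrative only, not part of this commit; the convert signature and ExprOuterClass are assumed to match what this file already uses):

    import org.apache.spark.sql.catalyst.expressions.{Abs, Attribute}
    import org.apache.spark.sql.types.NumericType

    object CometExampleAbs extends CometExpressionSerde[Abs] {

      // Only claim support for inputs the native kernel actually handles;
      // everything else falls back to Spark with a diagnostic reason.
      override def getSupportLevel(expr: Abs): SupportLevel = {
        expr.child.dataType match {
          case _: NumericType => Compatible()
          case other => Unsupported(Some(s"Unsupported input data type: $other"))
        }
      }

      override def convert(
          expr: Abs,
          inputs: Seq[Attribute],
          binding: Boolean): Option[ExprOuterClass.Expr] = {
        None // a real serde would build the native protobuf expression here
      }
    }
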
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 4fb4b02fa..c6d505691 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -3111,4 +3111,44 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         CometConcat.unsupportedReason)
     }
   }
+
+  // https://github.com/apache/datafusion-comet/issues/2813
+  test("make decimal using DataFrame API - integer") {
+    withTable("t1") {
+      sql("create table t1 using parquet as select 123456 as c1 from range(1)")
+
+      withSQLConf(
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        SQLConf.ANSI_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+
+        val df = sql("select * from t1")
+        val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
+        val df1 = df.withColumn("result", makeDecimalColumn)
+
+        checkSparkAnswerAndFallbackReason(df1, "Unsupported input data type: IntegerType")
+      }
+    }
+  }
+
+  test("make decimal using DataFrame API - long") {
+    withTable("t1") {
+      sql("create table t1 using parquet as select cast(123456 as long) as c1 
from range(1)")
+
+      withSQLConf(
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        SQLConf.ANSI_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+
+        val df = sql("select * from t1")
+        val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
+        val df1 = df.withColumn("result", makeDecimalColumn)
+
+        checkSparkAnswerAndOperator(df1)
+      }
+    }
+  }
+
 }
diff --git a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
index b8ecfacb3..a7dfb4264 100644
--- a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
@@ -46,4 +46,8 @@ trait ShimCometTestBase {
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.logicalPlan
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
 }
diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
index f2b419556..7f22494ad 100644
--- a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
@@ -47,4 +47,8 @@ trait ShimCometTestBase {
     df.logicalPlan
   }
 
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
+
 }
diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
index 8fb2e6970..5ad454322 100644
--- a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}
 
@@ -47,4 +47,8 @@ trait ShimCometTestBase {
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.queryExecution.analyzed
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(ExpressionColumnNode.apply(MakeDecimal(child, precision, scale, true)))
+  }
 }
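
A note on the shim design: Spark 3.4 and 3.5 still expose a Column constructor that accepts a catalyst Expression directly, while Spark 4.0's classic API wraps expressions in a ColumnNode, hence the ExpressionColumnNode indirection and the explicit fourth MakeDecimal argument (nullOnOverflow). Hiding the construction behind createMakeDecimalColumn lets the same test body in CometExpressionSuite compile unchanged against all three Spark versions.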


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
