This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new a7cf6cf24  fix: Fall back to Spark for MakeDecimal with unsupported input type (#2815)
a7cf6cf24 is described below
commit a7cf6cf24920d3f6ff516708a703cb9dff312861
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 10 16:57:00 2025 -0700
fix: Fall back to Spark for MakeDecimal with unsupported input type (#2815)
---
 .../src/math_funcs/internal/make_decimal.rs        | 25 ++++++++------
 .../apache/comet/serde/decimalExpressions.scala    |  8 +++++
 .../org/apache/comet/CometExpressionSuite.scala    | 40 ++++++++++++++++++++++
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 .../org/apache/spark/sql/ShimCometTestBase.scala   |  6 +++-
 6 files changed, 77 insertions(+), 14 deletions(-)
diff --git a/native/spark-expr/src/math_funcs/internal/make_decimal.rs b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
index 8feba54f5..338317520 100644
--- a/native/spark-expr/src/math_funcs/internal/make_decimal.rs
+++ b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
@@ -40,18 +40,21 @@ pub fn spark_make_decimal(
             ))),
             sv => internal_err!("Expected Int64 but found {sv:?}"),
         },
-        ColumnarValue::Array(a) => {
-            let arr = a.as_primitive::<Int64Type>();
-            let mut result = Decimal128Builder::new();
-            for v in arr.into_iter() {
-                result.append_option(long_to_decimal(&v, precision))
-            }
-            let result_type = DataType::Decimal128(precision, scale);
+        ColumnarValue::Array(a) => match a.data_type() {
+            DataType::Int64 => {
+                let arr = a.as_primitive::<Int64Type>();
+                let mut result = Decimal128Builder::new();
+                for v in arr.into_iter() {
+                    result.append_option(long_to_decimal(&v, precision))
+                }
+                let result_type = DataType::Decimal128(precision, scale);
 
-            Ok(ColumnarValue::Array(Arc::new(
-                result.finish().with_data_type(result_type),
-            )))
-        }
+                Ok(ColumnarValue::Array(Arc::new(
+                    result.finish().with_data_type(result_type),
+                )))
+            }
+            av => internal_err!("Expected Int64 but found {av:?}"),
+        },
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
index c606d1ac5..880f01742 100644
--- a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
@@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
 }
 
 object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {
+
+  override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
+    expr.child.dataType match {
+      case LongType => Compatible()
+      case other => Unsupported(Some(s"Unsupported input data type: $other"))
+    }
+  }
+
   override def convert(
       expr: MakeDecimal,
       inputs: Seq[Attribute],
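A note for readers outside the Comet serde layer: when getSupportLevel returns Unsupported, the planner keeps Spark's own implementation of the expression rather than converting it to a native plan, and the reason string surfaces in the fallback explanation (the tests below assert on it). A minimal, self-contained sketch of that decision shape, using stand-in types rather than Comet's real classes:

    // Stand-in types for illustration only; Comet defines its own SupportLevel hierarchy.
    sealed trait SupportLevel
    case class Compatible(notes: Option[String] = None) extends SupportLevel
    case class Unsupported(notes: Option[String] = None) extends SupportLevel

    object SupportLevelDemo extends App {
      // Mirrors the shape of CometMakeDecimal.getSupportLevel: only a LongType child
      // is eligible for native execution; anything else names a fallback reason.
      def supportLevel(childType: String): SupportLevel = childType match {
        case "LongType" => Compatible()
        case other => Unsupported(Some(s"Unsupported input data type: $other"))
      }

      println(supportLevel("LongType"))    // Compatible(None): convert to native
      println(supportLevel("IntegerType")) // Unsupported(...): fall back to Spark
    }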
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 4fb4b02fa..c6d505691 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -3111,4 +3111,44 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         CometConcat.unsupportedReason)
     }
   }
+
+  // https://github.com/apache/datafusion-comet/issues/2813
+  test("make decimal using DataFrame API - integer") {
+    withTable("t1") {
+      sql("create table t1 using parquet as select 123456 as c1 from range(1)")
+
+      withSQLConf(
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        SQLConf.ANSI_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+
+        val df = sql("select * from t1")
+        val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
+        val df1 = df.withColumn("result", makeDecimalColumn)
+
+        checkSparkAnswerAndFallbackReason(df1, "Unsupported input data type: IntegerType")
+      }
+    }
+  }
+
+  test("make decimal using DataFrame API - long") {
+    withTable("t1") {
+      sql("create table t1 using parquet as select cast(123456 as long) as c1 from range(1)")
+
+      withSQLConf(
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        SQLConf.ANSI_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+
+        val df = sql("select * from t1")
+        val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
+        val df1 = df.withColumn("result", makeDecimalColumn)
+
+        checkSparkAnswerAndOperator(df1)
+      }
+    }
+  }
+
 }
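For context on what these tests exercise: MakeDecimal reinterprets an unscaled long as a Decimal(precision, scale), so the unscaled value 123456 does not fit in Decimal(3, 0) and, with ANSI off (nullOnOverflow = true), evaluates to null rather than raising an error. A rough spark-shell sketch of the same expression on Spark 3.x, where a Column can still wrap an Expression directly (an active SparkSession named spark is assumed):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.MakeDecimal
    import org.apache.spark.sql.functions.lit

    // Reinterpret the unscaled long 123456 as Decimal(3, 0); the value overflows
    // three digits of precision, so nullOnOverflow = true yields a null result.
    val d = new Column(MakeDecimal(lit(123456L).expr, 3, 0, true))
    spark.range(1).select(d.as("result")).show()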
diff --git a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
index b8ecfacb3..a7dfb4264 100644
--- a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
@@ -46,4 +46,8 @@ trait ShimCometTestBase {
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.logicalPlan
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
 }
diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
index f2b419556..7f22494ad 100644
--- a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
@@ -47,4 +47,8 @@ trait ShimCometTestBase {
     df.logicalPlan
   }
 
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
+
 }
diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
index 8fb2e6970..5ad454322 100644
--- a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}
@@ -47,4 +47,8 @@ trait ShimCometTestBase {
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.queryExecution.analyzed
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(ExpressionColumnNode.apply(MakeDecimal(child, precision, scale, true)))
+  }
 }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]