This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 5fc6327b feat: Implement Spark-compatible CAST float/double to string 
(#346)
5fc6327b is described below

commit 5fc6327bc74249767487fb727aec1e6ef99ba7a3
Author: RickestCode <[email protected]>
AuthorDate: Fri May 3 21:22:32 2024 +0200

    feat: Implement Spark-compatible CAST float/double to string (#346)
---
 core/src/execution/datafusion/expressions/cast.rs  | 105 ++++++++++++++++++++-
 .../scala/org/apache/comet/CometCastSuite.scala    |  30 +++++-
 2 files changed, 129 insertions(+), 6 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs 
b/core/src/execution/datafusion/expressions/cast.rs
index 7560e0c2..45859c5f 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -17,7 +17,7 @@
 
 use std::{
     any::Any,
-    fmt::{Display, Formatter},
+    fmt::{Debug, Display, Formatter},
     hash::{Hash, Hasher},
     sync::Arc,
 };
@@ -31,7 +31,8 @@ use arrow::{
 };
 use arrow_array::{
     types::{Int16Type, Int32Type, Int64Type, Int8Type},
-    Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, 
PrimitiveArray,
+    Array, ArrayRef, BooleanArray, Float32Array, Float64Array, 
GenericStringArray, OffsetSizeTrait,
+    PrimitiveArray,
 };
 use arrow_schema::{DataType, Schema};
 use chrono::{TimeZone, Timelike};
@@ -107,6 +108,74 @@ macro_rules! cast_utf8_to_timestamp {
     }};
 }
 
+macro_rules! cast_float_to_string {
+    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) 
=> {{
+
+        fn cast<OffsetSize>(
+            from: &dyn Array,
+            _eval_mode: EvalMode,
+        ) -> CometResult<ArrayRef>
+        where
+            OffsetSize: OffsetSizeTrait, {
+                let array = 
from.as_any().downcast_ref::<$output_type>().unwrap();
+
+                // If the absolute number is less than 10,000,000 and greater 
than or equal to 0.001, the
+                // result is expressed without scientific notation with at 
least one digit on either side of
+                // the decimal point. Otherwise, Spark uses a mantissa 
followed by E and an
+                // exponent. The mantissa has an optional leading minus sign 
followed by one digit to the
+                // left of the decimal point, and the minimal number of digits 
greater than zero to the
+                // right. The exponent has an optional leading minus sign.
+                // source: 
https://docs.databricks.com/en/sql/language-manual/functions/cast.html
+
+                const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
+                const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;
+
+                let output_array = array
+                    .iter()
+                    .map(|value| match value {
+                        Some(value) if value == <$type>::INFINITY => 
Ok(Some("Infinity".to_string())),
+                        Some(value) if value == <$type>::NEG_INFINITY => 
Ok(Some("-Infinity".to_string())),
+                        Some(value)
+                            if (value.abs() < UPPER_SCIENTIFIC_BOUND
+                                && value.abs() >= LOWER_SCIENTIFIC_BOUND)
+                                || value.abs() == 0.0 =>
+                        {
+                            let trailing_zero = if value.fract() == 0.0 { ".0" 
} else { "" };
+
+                            Ok(Some(format!("{value}{trailing_zero}")))
+                        }
+                        Some(value)
+                            if value.abs() >= UPPER_SCIENTIFIC_BOUND
+                                || value.abs() < LOWER_SCIENTIFIC_BOUND =>
+                        {
+                            let formatted = format!("{value:E}");
+
+                            if formatted.contains(".") {
+                                Ok(Some(formatted))
+                            } else {
+                                // `formatted` is already in scientific 
notation and can be split up by E
+                                // in order to add the missing trailing 0 
which gets removed for numbers with a fraction of 0.0
+                                let prepare_number: Vec<&str> = 
formatted.split("E").collect();
+
+                                let coefficient = prepare_number[0];
+
+                                let exponent = prepare_number[1];
+
+                                Ok(Some(format!("{coefficient}.0E{exponent}")))
+                            }
+                        }
+                        Some(value) => Ok(Some(value.to_string())),
+                        _ => Ok(None),
+                    })
+                    .collect::<Result<GenericStringArray<OffsetSize>, 
CometError>>()?;
+
+                Ok(Arc::new(output_array))
+            }
+
+        cast::<$offset_type>($from, $eval_mode)
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -185,6 +254,18 @@ impl Cast {
                     ),
                 }
             }
+            (DataType::Float64, DataType::Utf8) => {
+                Self::spark_cast_float64_to_utf8::<i32>(&array, 
self.eval_mode)?
+            }
+            (DataType::Float64, DataType::LargeUtf8) => {
+                Self::spark_cast_float64_to_utf8::<i64>(&array, 
self.eval_mode)?
+            }
+            (DataType::Float32, DataType::Utf8) => {
+                Self::spark_cast_float32_to_utf8::<i32>(&array, 
self.eval_mode)?
+            }
+            (DataType::Float32, DataType::LargeUtf8) => {
+                Self::spark_cast_float32_to_utf8::<i64>(&array, 
self.eval_mode)?
+            }
             _ => {
                 // when we have no Spark-specific casting we delegate to 
DataFusion
                 cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -248,6 +329,26 @@ impl Cast {
         Ok(cast_array)
     }
 
+    fn spark_cast_float64_to_utf8<OffsetSize>(
+        from: &dyn Array,
+        _eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
+    }
+
+    fn spark_cast_float32_to_utf8<OffsetSize>(
+        from: &dyn Array,
+        _eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
+    }
+
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index a31f4e68..3be7dcb6 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -329,9 +329,22 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast FloatType to StringType") {
+  test("cast FloatType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateFloats(), DataTypes.StringType)
+    val r = new Random(0)
+    val values = Seq(
+      Float.MaxValue,
+      Float.MinValue,
+      Float.NaN,
+      Float.PositiveInfinity,
+      Float.NegativeInfinity,
+      1.0f,
+      -1.0f,
+      Short.MinValue.toFloat,
+      Short.MaxValue.toFloat,
+      0.0f) ++
+      Range(0, dataSize).map(_ => r.nextFloat())
+    withNulls(values).toDF("a")
   }
 
   ignore("cast FloatType to TimestampType") {
@@ -374,9 +387,18 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast DoubleType to StringType") {
+  test("cast DoubleType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateDoubles(), DataTypes.StringType)
+    val r = new Random(0)
+    val values = Seq(
+      Double.MaxValue,
+      Double.MinValue,
+      Double.NaN,
+      Double.PositiveInfinity,
+      Double.NegativeInfinity,
+      0.0d) ++
+      Range(0, dataSize).map(_ => r.nextDouble())
+    withNulls(values).toDF("a")
   }
 
   ignore("cast DoubleType to TimestampType") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to