This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ba228e1aa fix: handle cast to dictionary vector introduced by case when (#2044)
ba228e1aa is described below

commit ba228e1aa82e3c0058543e2412265623c285cd36
Author: Parth Chandra <par...@apache.org>
AuthorDate: Tue Sep 2 09:12:26 2025 -0700

    fix: handle cast to dictionary vector introduced by case when (#2044)
---
 native/spark-expr/src/conversion_funcs/cast.rs     | 70 ++++++++++++++++++----
 .../scala/org/apache/comet/CometCastSuite.scala    | 10 ++++
 2 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 1cf061ab2..8f33bf912 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{DictionaryArray, StringArray, StructArray};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, 
Schema};
 use arrow::{
     array::{
         cast::AsArray,
@@ -41,7 +41,8 @@ use arrow::{
 };
 use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
 use datafusion::common::{
-    cast::as_generic_string_array, internal_err, Result as DataFusionResult, 
ScalarValue,
+    cast::as_generic_string_array, internal_err, DataFusionError, Result as 
DataFusionResult,
+    ScalarValue,
 };
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
@@ -867,6 +868,40 @@ pub fn spark_cast(
     }
 }
 
+// Copied from datafusion common scalar/mod.rs: wraps values in a DictionaryArray (key i -> value i).
+fn dict_from_values<K: ArrowDictionaryKeyType>(
+    values_array: ArrayRef,
+) -> datafusion::common::Result<ArrayRef> {
+    // One key per value: key i = i for each non-null value and a null key where the value
+    // is null; returns an Internal error if an index does not fit in K::Native.
+    let key_array: PrimitiveArray<K> = (0..values_array.len())
+        .map(|index| {
+            if values_array.is_valid(index) {
+                let native_index = K::Native::from_usize(index).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not create index of type {} from value {}",
+                        K::DATA_TYPE,
+                        index
+                    ))
+                })?;
+                Ok(Some(native_index))
+            } else {
+                Ok(None)
+            }
+        })
+        .collect::<datafusion::common::Result<Vec<_>>>()?
+        .into_iter()
+        .collect();
+
+    // Create the DictionaryArray over the original, un-deduplicated values.
+    //
+    // Note: this path could be made faster by using the ArrayData
+    // APIs and skipping validation, if it ever comes up in
+    // performance traces.
+    let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
+    Ok(Arc::new(dict_array))
+}
+
 fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
@@ -896,18 +931,33 @@ fn cast_array(
                 .downcast_ref::<DictionaryArray<Int32Type>>()
                 .expect("Expected a dictionary array");
 
-            let casted_dictionary = DictionaryArray::<Int32Type>::new(
-                dict_array.keys().clone(),
-                cast_array(Arc::clone(dict_array.values()), to_type, 
cast_options)?,
-            );
-
             let casted_result = match to_type {
-                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
-                _ => take(casted_dictionary.values().as_ref(), 
dict_array.keys(), None)?,
+                Dictionary(_, to_value_type) => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), 
to_value_type, cast_options)?,
+                    );
+                    Arc::new(casted_dictionary.clone())
+                }
+                _ => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_type, 
cast_options)?,
+                    );
+                    take(casted_dictionary.values().as_ref(), 
dict_array.keys(), None)?
+                }
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, 
to_type));
         }
-        _ => array,
+        _ => {
+            if let Dictionary(_, _) = to_type {
+                let dict_array = dict_from_values::<Int32Type>(array)?;
+                let casted_result = cast_array(dict_array, to_type, 
cast_options)?;
+                return Ok(spark_cast_postprocess(casted_result, &from_type, 
to_type));
+            } else {
+                array
+            }
+        }
     };
     let from_type = array.data_type();
     let eval_mode = cast_options.eval_mode;
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 070d6d781..33aadcf15 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -1063,6 +1063,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(gen.generateLongs(dataSize)).toDF("a")
   }
 
+  // Regression test for https://github.com/apache/datafusion-comet/issues/2038
+  test("test implicit cast to dictionary with case when and dictionary type") {
+    withSQLConf("parquet.enable.dictionary" -> "true") {
+      withParquetTable((0 until 10000).map(i => (i < 5000, "one")), "tbl") {
+        val df = spark.sql("select case when (_1 = true) then _2 else '' end as aaa from tbl")
+        checkSparkAnswerAndOperator(df)
+      }
+    }
+  }
+
+
   private def generateDecimalsPrecision10Scale2(): DataFrame = {
     val values = Seq(
       BigDecimal("-99999999.999"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to