This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new ba228e1aa fix: handle cast to dictionary vector introduced by case when (#2044)
ba228e1aa is described below

commit ba228e1aa82e3c0058543e2412265623c285cd36
Author: Parth Chandra <par...@apache.org>
AuthorDate: Tue Sep 2 09:12:26 2025 -0700

    fix: handle cast to dictionary vector introduced by case when (#2044)
---
 native/spark-expr/src/conversion_funcs/cast.rs  | 70 ++++++++++++++++++----
 .../scala/org/apache/comet/CometCastSuite.scala | 10 ++++
 2 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 1cf061ab2..8f33bf912 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{DictionaryArray, StringArray, StructArray};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
 use arrow::{
     array::{
         cast::AsArray,
@@ -41,7 +41,8 @@ use arrow::{
 };
 use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
 use datafusion::common::{
-    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
+    cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
+    ScalarValue,
 };
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
@@ -867,6 +868,40 @@ pub fn spark_cast(
     }
 }
 
+// copied from datafusion common scalar/mod.rs
+fn dict_from_values<K: ArrowDictionaryKeyType>(
+    values_array: ArrayRef,
+) -> datafusion::common::Result<ArrayRef> {
+    // Create a key array with `size` elements of 0..array_len for all
+    // non-null value elements
+    let key_array: PrimitiveArray<K> = (0..values_array.len())
+        .map(|index| {
+            if values_array.is_valid(index) {
+                let native_index = K::Native::from_usize(index).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not create index of type {} from value {}",
+                        K::DATA_TYPE,
+                        index
+                    ))
+                })?;
+                Ok(Some(native_index))
+            } else {
+                Ok(None)
+            }
+        })
+        .collect::<datafusion::common::Result<Vec<_>>>()?
+        .into_iter()
+        .collect();
+
+    // create a new DictionaryArray
+    //
+    // Note: this path could be made faster by using the ArrayData
+    // APIs and skipping validation, if it every comes up in
+    // performance traces.
+    let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
+    Ok(Arc::new(dict_array))
+}
+
 fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
@@ -896,18 +931,33 @@ fn cast_array(
                 .downcast_ref::<DictionaryArray<Int32Type>>()
                 .expect("Expected a dictionary array");
 
-            let casted_dictionary = DictionaryArray::<Int32Type>::new(
-                dict_array.keys().clone(),
-                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
-            );
-
             let casted_result = match to_type {
-                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
-                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
+                Dictionary(_, to_value_type) => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_value_type, cast_options)?,
+                    );
+                    Arc::new(casted_dictionary.clone())
+                }
+                _ => {
+                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                        dict_array.keys().clone(),
+                        cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
+                    );
+                    take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?
+                }
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
         }
-        _ => array,
+        _ => {
+            if let Dictionary(_, _) = to_type {
+                let dict_array = dict_from_values::<Int32Type>(array)?;
+                let casted_result = cast_array(dict_array, to_type, cast_options)?;
+                return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
+            } else {
+                array
+            }
+        }
     };
     let from_type = array.data_type();
     let eval_mode = cast_options.eval_mode;
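
A note on the new helper: `dict_from_values` gives every valid slot i the key i and every null slot a null key, with no de-duplication, so the resulting dictionary reads back exactly like its values array. The sketch below is illustrative only and not part of the commit; it assumes just the `arrow` crate and builds the same shape by hand:

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, DictionaryArray, Int32Array, StringArray};
    use arrow::datatypes::Int32Type;

    fn main() {
        // Illustrative sketch, not Comet code. Values array as produced
        // upstream: duplicates allowed, slot 2 is null.
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("one"),
            Some("one"),
            None,
            Some(""),
        ]));

        // One key per slot: key i for each valid slot, a null key for each
        // null slot. No de-duplication is attempted.
        let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(3)]);

        // try_new validates that every key is in bounds for the values array.
        let dict = DictionaryArray::<Int32Type>::try_new(keys, values).unwrap();

        assert_eq!(dict.len(), 4);
        assert!(dict.is_null(2));
    }

`DictionaryArray::try_new` performs the key validation that the copied comment notes could be skipped via the ArrayData APIs if it ever showed up in performance traces.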
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 070d6d781..33aadcf15 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -1063,6 +1063,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(gen.generateLongs(dataSize)).toDF("a")
   }
 
+  // https://github.com/apache/datafusion-comet/issues/2038
+  test("test implicit cast to dictionary with case when and dictionary type") {
+    withSQLConf("parquet.enable.dictionary" -> "true") {
+      withParquetTable((0 until 10000).map(i => (i < 5000, "one")), "tbl") {
+        val df = spark.sql("select case when (_1 = true) then _2 else '' end as aaa from tbl")
+        checkSparkAnswerAndOperator(df)
+      }
+    }
+  }
+
   private def generateDecimalsPrecision10Scale2(): DataFrame = {
     val values = Seq(
       BigDecimal("-99999999.999"),
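
The test added above drives the fixed path from SQL: with Parquet dictionary encoding enabled, the `then _2` branch stays dictionary-encoded while the `else ''` branch materializes as a flat string array, so the plan asks for a cast from a non-dictionary array to a dictionary type, which is what the new `_ =>` arm in `cast_array` now handles through `dict_from_values`. As a minimal sketch of that flat-to-dictionary cast, here is the equivalent using arrow's generic `cast` kernel (illustrative only; Comet instead recurses through its own `cast_array` so that Spark cast semantics apply to the values):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use arrow::compute::cast;
    use arrow::datatypes::DataType;

    fn main() {
        // Illustrative sketch, not Comet code. The `else ''` branch
        // materializes as a plain, non-dictionary string array.
        let flat: ArrayRef = Arc::new(StringArray::from(vec!["one", "", "one"]));

        // The plan wants the dictionary type carried by the other branch.
        let to_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

        // Arrow's generic cast kernel accepts flat -> dictionary.
        let dict = cast(&flat, &to_type).unwrap();
        assert_eq!(dict.data_type(), &to_type);
    }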