This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 3e9f8503a chore: Cast module refactor boolean module (#3491)
3e9f8503a is described below
commit 3e9f8503a3d80b856025604dcbf7323258299cb7
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Feb 19 13:57:01 2026 -0800
chore: Cast module refactor boolean module (#3491)
---
native/spark-expr/Cargo.toml | 4 +
native/spark-expr/benches/cast_from_boolean.rs | 89 ++++++++++
native/spark-expr/src/conversion_funcs/boolean.rs | 196 ++++++++++++++++++++++
native/spark-expr/src/conversion_funcs/cast.rs | 138 ++-------------
native/spark-expr/src/conversion_funcs/mod.rs | 2 +
native/spark-expr/src/conversion_funcs/utils.rs | 128 ++++++++++++++
6 files changed, 430 insertions(+), 127 deletions(-)
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 63e1c0476..e7c238f7e 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -99,3 +99,7 @@ harness = false
[[test]]
name = "test_udf_registration"
path = "tests/spark_expr_reg.rs"
+
+[[bench]]
+name = "cast_from_boolean"
+harness = false
diff --git a/native/spark-expr/benches/cast_from_boolean.rs
b/native/spark-expr/benches/cast_from_boolean.rs
new file mode 100644
index 000000000..dbb986df9
--- /dev/null
+++ b/native/spark-expr/benches/cast_from_boolean.rs
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{BooleanBuilder, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::expressions::Column;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let expr = Arc::new(Column::new("a", 0));
+ let boolean_batch = create_boolean_batch();
+ let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC",
false);
+ let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8,
spark_cast_options.clone());
+ let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16,
spark_cast_options.clone());
+ let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32,
spark_cast_options.clone());
+ let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64,
spark_cast_options.clone());
+ let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32,
spark_cast_options.clone());
+ let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64,
spark_cast_options.clone());
+ let cast_to_str = Cast::new(expr.clone(), DataType::Utf8,
spark_cast_options.clone());
+ let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4),
spark_cast_options);
+
+ let mut group = c.benchmark_group("cast_bool".to_string());
+ group.bench_function("i8", |b| {
+ b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("i16", |b| {
+ b.iter(|| cast_to_i16.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("i32", |b| {
+ b.iter(|| cast_to_i32.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("i64", |b| {
+ b.iter(|| cast_to_i64.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("f32", |b| {
+ b.iter(|| cast_to_f32.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("f64", |b| {
+ b.iter(|| cast_to_f64.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("str", |b| {
+ b.iter(|| cast_to_str.evaluate(&boolean_batch).unwrap());
+ });
+ group.bench_function("decimal", |b| {
+ b.iter(|| cast_to_decimal.evaluate(&boolean_batch).unwrap());
+ });
+}
+
+fn create_boolean_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean,
true)]));
+ let mut b = BooleanBuilder::with_capacity(1000);
+ for i in 0..1000 {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<bool>());
+ }
+ }
+ let array = b.finish();
+ RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs
b/native/spark-expr/src/conversion_funcs/boolean.rs
new file mode 100644
index 000000000..db288fa32
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::SparkResult;
+use arrow::array::{ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
+ use DataType::*;
+ matches!(
+ to_type,
+ Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
+ )
+}
+
+// Handles the one boolean cast that DataFusion's cast() does not perform in a
+// Spark-compatible way: boolean -> Decimal128.
+pub fn cast_boolean_to_decimal(
+ array: &ArrayRef,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let bool_array = array.as_boolean();
+ let scaled_val = 10_i128.pow(scale as u32);
+ let result: Decimal128Array = bool_array
+ .iter()
+ .map(|v| v.map(|b| if b { scaled_val } else { 0 }))
+ .collect();
+ Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::cast::cast_array;
+ use crate::{EvalMode, SparkCastOptions};
+ use arrow::array::{
+ Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
Int32Array,
+ Int64Array, Int8Array, StringArray,
+ };
+ use arrow::datatypes::DataType::Decimal128;
+ use std::sync::Arc;
+
+ fn test_input_bool_array() -> ArrayRef {
+ Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]))
+ }
+
+ fn test_input_spark_opts() -> SparkCastOptions {
+ SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false)
+ }
+
+ #[test]
+ fn test_is_df_cast_from_bool_spark_compatible() {
+ assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Boolean));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int8));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int16));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int32));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int64));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float32));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float64));
+ assert!(is_df_cast_from_bool_spark_compatible(&DataType::Utf8));
+ assert!(!is_df_cast_from_bool_spark_compatible(
+ &DataType::Decimal128(10, 4)
+ ));
+ assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Null));
+ }
+
+ #[test]
+ fn test_bool_to_int8_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Int8,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Int8Array>().unwrap();
+ assert_eq!(arr.value(0), 1);
+ assert_eq!(arr.value(1), 0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_int16_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Int16,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Int16Array>().unwrap();
+ assert_eq!(arr.value(0), 1);
+ assert_eq!(arr.value(1), 0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_int32_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Int32,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert_eq!(arr.value(0), 1);
+ assert_eq!(arr.value(1), 0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_int64_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Int64,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Int64Array>().unwrap();
+ assert_eq!(arr.value(0), 1);
+ assert_eq!(arr.value(1), 0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_float32_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Float32,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Float32Array>().unwrap();
+ assert_eq!(arr.value(0), 1.0);
+ assert_eq!(arr.value(1), 0.0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_float64_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Float64,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert_eq!(arr.value(0), 1.0);
+ assert_eq!(arr.value(1), 0.0);
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_string_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &DataType::Utf8,
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+ assert_eq!(arr.value(0), "true");
+ assert_eq!(arr.value(1), "false");
+ assert!(arr.is_null(2));
+ }
+
+ #[test]
+ fn test_bool_to_decimal_cast() {
+ let result = cast_array(
+ test_input_bool_array(),
+ &Decimal128(10, 4),
+ &test_input_spark_opts(),
+ )
+ .unwrap();
+ let expected_arr = Decimal128Array::from(vec![10000_i128, 0_i128])
+ .with_precision_and_scale(10, 4)
+ .unwrap();
+ let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ assert_eq!(arr.value(0), expected_arr.value(0));
+ assert_eq!(arr.value(1), expected_arr.value(1));
+ assert!(arr.is_null(2));
+ }
+}
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs
b/native/spark-expr/src/conversion_funcs/cast.rs
index f5ab83b8a..004668b8f 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -15,6 +15,11 @@
// specific language governing permissions and limitations
// under the License.
+use crate::conversion_funcs::boolean::{
+ cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
+};
+use crate::conversion_funcs::utils::spark_cast_postprocess;
+use crate::conversion_funcs::utils::{cast_overflow, invalid_value};
use crate::utils::array_with_timezone;
use crate::EvalMode::Legacy;
use crate::{timezone, BinaryOutputStyle};
@@ -37,7 +42,7 @@ use arrow::{
GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
OffsetSizeTrait,
PrimitiveArray,
},
- compute::{cast_with_options, take, unary, CastOptions},
+ compute::{cast_with_options, take, CastOptions},
datatypes::{
is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type,
Float32Type,
Float64Type, Int64Type, TimestampMicrosecondType,
@@ -48,16 +53,10 @@ use arrow::{
};
use base64::prelude::*;
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
-use datafusion::common::{
- cast::as_generic_string_array, internal_err, DataFusionError, Result as
DataFusionResult,
- ScalarValue,
-};
+use datafusion::common::{internal_err, DataFusionError, Result as
DataFusionResult, ScalarValue};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
-use num::{
- cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub,
Integer, ToPrimitive,
- Zero,
-};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer,
ToPrimitive, Zero};
use regex::Regex;
use std::str::FromStr;
use std::{
@@ -70,7 +69,7 @@ use std::{
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
-const MICROS_PER_SECOND: i64 = 1000000;
+pub(crate) const MICROS_PER_SECOND: i64 = 1000000;
static CAST_OPTIONS: CastOptions = CastOptions {
safe: true,
@@ -776,7 +775,7 @@ fn dict_from_values<K: ArrowDictionaryKeyType>(
Ok(Arc::new(dict_array))
}
-fn cast_array(
+pub(crate) fn cast_array(
array: ArrayRef,
to_type: &DataType,
cast_options: &SparkCastOptions,
@@ -1018,16 +1017,6 @@ fn cast_date_to_timestamp(
))
}
-fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) ->
SparkResult<ArrayRef> {
- let bool_array = array.as_boolean();
- let scaled_val = 10_i128.pow(scale as u32);
- let result: Decimal128Array = bool_array
- .iter()
- .map(|v| v.map(|b| if b { scaled_val } else { 0 }))
- .collect();
- Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
-}
-
fn cast_string_to_float(
array: &ArrayRef,
to_type: &DataType,
@@ -1186,16 +1175,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType,
to_type: &DataType) -> b
DataType::Null => {
matches!(to_type, DataType::List(_))
}
- DataType::Boolean => matches!(
- to_type,
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- | DataType::Utf8
- ),
+ DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
=> {
matches!(
to_type,
@@ -2437,24 +2417,6 @@ fn parse_decimal_str(
Ok((final_mantissa, final_scale))
}
-#[inline]
-fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
- SparkError::CastInvalidValue {
- value: value.to_string(),
- from_type: from_type.to_string(),
- to_type: to_type.to_string(),
- }
-}
-
-#[inline]
-fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
- SparkError::CastOverFlow {
- value: value.to_string(),
- from_type: from_type.to_string(),
- to_type: to_type.to_string(),
- }
-}
-
impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
@@ -2852,84 +2814,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) ->
SparkResult<Option<i32>>
}
}
-/// This takes for special casting cases of Spark. E.g., Timestamp to Long.
-/// This function runs as a post process of the DataFusion cast(). By the time
it arrives here,
-/// Dictionary arrays are already unpacked by the DataFusion cast() since
Spark cannot specify
-/// Dictionary as to_type. The from_type is taken before the DataFusion cast()
runs in
-/// expressions/cast.rs, so it can be still Dictionary.
-fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type:
&DataType) -> ArrayRef {
- match (from_type, to_type) {
- (DataType::Timestamp(_, _), DataType::Int64) => {
- // See Spark's `Cast` expression
- unary_dyn::<_, Int64Type>(&array, |v| div_floor(v,
MICROS_PER_SECOND)).unwrap()
- }
- (DataType::Dictionary(_, value_type), DataType::Int64)
- if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
- {
- // See Spark's `Cast` expression
- unary_dyn::<_, Int64Type>(&array, |v| div_floor(v,
MICROS_PER_SECOND)).unwrap()
- }
- (DataType::Timestamp(_, _), DataType::Utf8) =>
remove_trailing_zeroes(array),
- (DataType::Dictionary(_, value_type), DataType::Utf8)
- if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
- {
- remove_trailing_zeroes(array)
- }
- _ => array,
- }
-}
-
-/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
-fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
-where
- T: ArrowPrimitiveType,
- F: Fn(T::Native) -> T::Native,
-{
- if let Some(d) = array.as_any_dictionary_opt() {
- let new_values = unary_dyn::<F, T>(d.values(), op)?;
- return Ok(Arc::new(d.with_values(Arc::new(new_values))));
- }
-
- match array.as_primitive_opt::<T>() {
- Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
- Ok(Arc::new(unary::<T, F, T>(
- array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
- op,
- )))
- }
- _ => Err(ArrowError::NotYetImplemented(format!(
- "Cannot perform unary operation of type {} on array of type {}",
- T::DATA_TYPE,
- array.data_type()
- ))),
- }
-}
-
-/// Remove any trailing zeroes in the string if they occur after in the
fractional seconds,
-/// to match Spark behavior
-/// example:
-/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
-/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
-/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
-/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
-/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
-fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
- let string_array = as_generic_string_array::<i32>(&array).unwrap();
- let result = string_array
- .iter()
- .map(|s| s.map(trim_end))
- .collect::<GenericStringArray<i32>>();
- Arc::new(result) as ArrayRef
-}
-
-fn trim_end(s: &str) -> &str {
- if s.rfind('.').is_some() {
- s.trim_end_matches('0')
- } else {
- s
- }
-}
-
#[cfg(test)]
mod tests {
use arrow::array::StringArray;
diff --git a/native/spark-expr/src/conversion_funcs/mod.rs
b/native/spark-expr/src/conversion_funcs/mod.rs
index f2c6f7ca3..190c11520 100644
--- a/native/spark-expr/src/conversion_funcs/mod.rs
+++ b/native/spark-expr/src/conversion_funcs/mod.rs
@@ -15,4 +15,6 @@
// specific language governing permissions and limitations
// under the License.
+mod boolean;
pub mod cast;
+mod utils;
diff --git a/native/spark-expr/src/conversion_funcs/utils.rs
b/native/spark-expr/src/conversion_funcs/utils.rs
new file mode 100644
index 000000000..8b8d974ff
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/utils.rs
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::cast::MICROS_PER_SECOND;
+use crate::SparkError;
+use arrow::array::{
+ Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray,
PrimitiveArray,
+};
+use arrow::compute::unary;
+use arrow::datatypes::{DataType, Int64Type};
+use arrow::error::ArrowError;
+use datafusion::common::cast::as_generic_string_array;
+use num::integer::div_floor;
+use std::sync::Arc;
+
+/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
+pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native) -> T::Native,
+{
+ if let Some(d) = array.as_any_dictionary_opt() {
+ let new_values = unary_dyn::<F, T>(d.values(), op)?;
+ return Ok(Arc::new(d.with_values(Arc::new(new_values))));
+ }
+
+ match array.as_primitive_opt::<T>() {
+ Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
+ Ok(Arc::new(unary::<T, F, T>(
+ array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
+ op,
+ )))
+ }
+ _ => Err(ArrowError::NotYetImplemented(format!(
+ "Cannot perform unary operation of type {} on array of type {}",
+ T::DATA_TYPE,
+ array.data_type()
+ ))),
+ }
+}
+
+/// This takes care of special casting cases in Spark, e.g., Timestamp to Long.
+/// This function runs as a post process of the DataFusion cast(). By the time
it arrives here,
+/// Dictionary arrays are already unpacked by the DataFusion cast() since
Spark cannot specify
+/// Dictionary as to_type. The from_type is taken before the DataFusion cast()
runs in
+/// expressions/cast.rs, so it can be still Dictionary.
+pub fn spark_cast_postprocess(
+ array: ArrayRef,
+ from_type: &DataType,
+ to_type: &DataType,
+) -> ArrayRef {
+ match (from_type, to_type) {
+ (DataType::Timestamp(_, _), DataType::Int64) => {
+ // See Spark's `Cast` expression
+ unary_dyn::<_, Int64Type>(&array, |v| div_floor(v,
MICROS_PER_SECOND)).unwrap()
+ }
+ (DataType::Dictionary(_, value_type), DataType::Int64)
+ if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
+ {
+ // See Spark's `Cast` expression
+ unary_dyn::<_, Int64Type>(&array, |v| div_floor(v,
MICROS_PER_SECOND)).unwrap()
+ }
+ (DataType::Timestamp(_, _), DataType::Utf8) =>
remove_trailing_zeroes(array),
+ (DataType::Dictionary(_, value_type), DataType::Utf8)
+ if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
+ {
+ remove_trailing_zeroes(array)
+ }
+ _ => array,
+ }
+}
+
+/// Remove any trailing zeroes in the string if they occur in the fractional seconds,
+/// to match Spark behavior
+/// example:
+/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
+/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
+/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
+/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
+/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
+fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
+ let string_array = as_generic_string_array::<i32>(&array).unwrap();
+ let result = string_array
+ .iter()
+ .map(|s| s.map(trim_end))
+ .collect::<GenericStringArray<i32>>();
+ Arc::new(result) as ArrayRef
+}
+
+fn trim_end(s: &str) -> &str {
+ if s.rfind('.').is_some() {
+ s.trim_end_matches('0')
+ } else {
+ s
+ }
+}
+
+#[inline]
+pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) ->
SparkError {
+ SparkError::CastOverFlow {
+ value: value.to_string(),
+ from_type: from_type.to_string(),
+ to_type: to_type.to_string(),
+ }
+}
+
+#[inline]
+pub fn invalid_value(value: &str, from_type: &str, to_type: &str) ->
SparkError {
+ SparkError::CastInvalidValue {
+ value: value.to_string(),
+ from_type: from_type.to_string(),
+ to_type: to_type.to_string(),
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]