coderfender commented on code in PR #20555: URL: https://github.com/apache/datafusion/pull/20555#discussion_r2907577014
########## datafusion/spark/src/function/conversion/cast.rs: ########## @@ -0,0 +1,633 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, TimestampMicrosecondBuilder}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, + Int64Type, TimeUnit, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + Result as DataFusionResult, ScalarValue, exec_err, internal_err, +}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +const MICROS_PER_SECOND: i64 = 1_000_000; + +/// Convert seconds to microseconds, saturating at i64::MIN/i64::MAX on overflow +#[inline] +fn secs_to_micros(secs: i64) -> i64 { + secs.saturating_mul(MICROS_PER_SECOND) +} + +/// Spark-compatible `cast` function for type conversions +/// +/// This implements Spark's CAST expression with a target type parameter +/// +/// # Usage +/// ```sql +/// SELECT spark_cast(value, 'timestamp') +/// ``` +/// +/// # Currently supported conversions +/// - Int8/Int16/Int32/Int64 -> Timestamp (target_type = 'timestamp') +/// 
The integer value is interpreted as seconds since the Unix epoch (1970-01-01 00:00:00 UTC) +/// and converted to a timestamp with microsecond precision (matches Spark's specification) +/// +/// # Overflow behavior
/// Uses saturating multiplication to handle overflow - values that would overflow +/// i64 when multiplied by 1,000,000 are clamped to i64::MAX or i64::MIN +/// +/// # References +/// - <https://spark.apache.org/docs/latest/api/sql/index.html#cast> +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCast { + signature: Signature, +} + +impl Default for SparkCast { + fn default() -> Self { + Self::new() + } +} + +impl SparkCast { + pub fn new() -> Self { + Self { + // First arg: value to cast (only ints for now with potential to add further support later) + // Second arg: target datatype as a Utf8 string literal (e.g. 'timestamp') + signature: Signature::one_of( + vec![TypeSignature::Any(2)], + Volatility::Immutable, + ), + } + } +} + +/// Parse a target-type string into a DataType +fn parse_target_type(type_str: &str) -> DataFusionResult<DataType> { + match type_str.to_lowercase().as_str() { + // additional target types can be added here in the future + "timestamp" => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)), + other => exec_err!( + "Unsupported spark_cast target type '{}'. 
Supported types: timestamp", + other + ), + } +} + +/// Extract target type string from scalar arguments +fn get_target_type_from_scalar_args( + scalar_args: &[Option<&ScalarValue>], +) -> DataFusionResult<DataType> { + let [_, type_arg] = take_function_args("spark_cast", scalar_args)?; + + match type_arg { + Some(ScalarValue::Utf8(Some(s))) | Some(ScalarValue::LargeUtf8(Some(s))) => { + parse_target_type(s) + } + _ => exec_err!( + "spark_cast requires second argument to be a string of target data type ex: timestamp" + ), + } +} + +fn cast_int_to_timestamp<T: ArrowPrimitiveType>( + array: &ArrayRef, + timezone: Option<Arc<str>>, +) -> DataFusionResult<ArrayRef> +where + T::Native: Into<i64>, +{ + let arr = array.as_primitive::<T>(); + let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + // Spark semantics: saturate to i64::MIN/i64::MAX instead of overflowing + let micros = secs_to_micros(arr.value(i).into()); + builder.append_value(micros); + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(timezone))) +} + +impl ScalarUDFImpl for SparkCast { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "spark_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: ReturnFieldArgs, + ) -> DataFusionResult<FieldRef> { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let target_type = get_target_type_from_scalar_args(args.scalar_arguments)?; + Ok(Arc::new(Field::new(self.name(), target_type, nullable))) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> DataFusionResult<ColumnarValue> { + let target_type = args.return_field.data_type(); + // Use session timezone, fallback to UTC if not set Review Comment: Return type does not
provide session info (which makes it difficult to access timezone info) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
