Jefffrey commented on code in PR #18649:
URL: https://github.com/apache/datafusion/pull/18649#discussion_r2518699965


##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -18,36 +18,33 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
+use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
 use arrow::compute;
 use arrow::datatypes::{
     ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
 };
-use datafusion_common::{plan_err, Result};
+use datafusion_common::types::{
+    logical_int16, logical_int32, logical_int64, logical_int8, logical_uint16,
+    logical_uint32, logical_uint64, logical_uint8, NativeType,
+};
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{internal_err, Result};
 use datafusion_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    TypeSignatureClass, Volatility,
 };
 use datafusion_functions::utils::make_scalar_function;
 
-use crate::function::error_utils::{
-    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
-};
-
-/// Performs a bitwise left shift on each element of the `value` array by the corresponding amount in the `shift` array.
-/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
-///
-/// # Arguments
-/// * `value` - The array of values to shift.
-/// * `shift` - The array of shift amounts (must be Int32).
-///
-/// # Returns
-/// A new array with the shifted values.
-fn shift_left<T: ArrowPrimitiveType>(
+/// Bitwise left shift on elements in `value` by corresponding `shift` amount.
+/// The shift amount is normalized to the bit width of the type, matching Spark/Java
+/// semantics for negative and large shifts.
+fn shift_left<T>(
     value: &PrimitiveArray<T>,
-    shift: &PrimitiveArray<Int32Type>,
+    shift: &Int32Array,
 ) -> Result<PrimitiveArray<T>>
 where
-    T::Native: ArrowNativeType + std::ops::Shl<i32, Output = T::Native>,
+    T: ArrowPrimitiveType,
+    T::Native: std::ops::Shl<i32, Output = T::Native>,

Review Comment:
   Rewrote some comments to be more succinct, and also cleaned up some function signatures (removed some unused bounds, used type aliases, etc.)



##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -149,589 +136,162 @@ where
     Ok(result)
 }
 
-trait BitShiftUDF: ScalarUDFImpl {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
-        value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
-    ) -> Result<PrimitiveArray<T>>
-    where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
-            + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>;
-
-    fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
-        let value_array = arrays[0].as_ref();
-        let shift_array = arrays[1].as_ref();
-
-        // Ensure shift array is Int32
-        let shift_array = if shift_array.data_type() != &DataType::Int32 {
-            return plan_err!("{} shift amount must be Int32", self.name());
-        } else {
-            shift_array.as_primitive::<Int32Type>()
-        };
-
-        match value_array.data_type() {
-            DataType::Int32 => {
-                let value_array = value_array.as_primitive::<Int32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::Int64 => {
-                let value_array = value_array.as_primitive::<Int64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt32 => {
-                let value_array = value_array.as_primitive::<UInt32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt64 => {
-                let value_array = value_array.as_primitive::<UInt64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            _ => {
-                plan_err!(
-                    "{} function does not support data type: {}",
-                    self.name(),
-                    value_array.data_type()
-                )
-            }
-        }
-    }
-}
+fn shift_inner(
+    arrays: &[ArrayRef],
+    name: &str,
+    bit_shift_type: BitShiftType,
+) -> Result<ArrayRef> {
+    let [value_array, shift_array] = take_function_args(name, arrays)?;
 
-fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result<Vec<DataType>> {
-    if arg_types.len() != 2 {
-        return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
-    }
-    if !arg_types[0].is_integer() && !arg_types[0].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[0],
-        ));
-    }
-    if !arg_types[1].is_integer() && !arg_types[1].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[1],
-        ));
+    if value_array.data_type().is_null() || shift_array.data_type().is_null() {
+        return Ok(Arc::new(Int32Array::new_null(value_array.len())));
     }
 
-    // Coerce smaller integer types to Int32
-    let coerced_first = match &arg_types[0] {
-        DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32,
-        DataType::UInt8 | DataType::UInt16 => DataType::UInt32,
-        _ => arg_types[0].clone(),
-    };
+    let shift_array = shift_array.as_primitive::<Int32Type>();
 
-    Ok(vec![coerced_first, DataType::Int32])
-}
-
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftLeft {
-    signature: Signature,
-}
-
-impl Default for SparkShiftLeft {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl SparkShiftLeft {
-    pub fn new() -> Self {
-        Self {
-            signature: Signature::user_defined(Volatility::Immutable),
-        }
-    }
-}
-
-impl BitShiftUDF for SparkShiftLeft {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
+    fn shift<T>(
         value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
+        shift: &Int32Array,
+        bit_shift_type: BitShiftType,
     ) -> Result<PrimitiveArray<T>>
     where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
+        T: ArrowPrimitiveType,
+        T::Native: std::ops::Shl<i32, Output = T::Native>
             + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>,
+            + UShr,
     {
-        shift_left(value, shift)
-    }
-}
-
-impl ScalarUDFImpl for SparkShiftLeft {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn name(&self) -> &str {
-        "shiftleft"
-    }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        bit_shift_coerce_types(arg_types, "shiftleft")
-    }
-
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types.len() != 2 {
-            return plan_err!("shiftleft expects exactly 2 arguments");
+        match bit_shift_type {
+            BitShiftType::Left => shift_left(value, shift),
+            BitShiftType::Right => shift_right(value, shift),
+            BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
         }
-        // Return type is the same as the first argument (the value to shift)
-        Ok(arg_types[0].clone())
     }
 
-    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        if args.args.len() != 2 {
-            return plan_err!("shiftleft expects exactly 2 arguments");
+    match value_array.data_type() {
+        DataType::Int32 => {
+            let value_array = value_array.as_primitive::<Int32Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::Int64 => {
+            let value_array = value_array.as_primitive::<Int64Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::UInt32 => {
+            let value_array = value_array.as_primitive::<UInt32Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::UInt64 => {
+            let value_array = value_array.as_primitive::<UInt64Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        dt => {
+            internal_err!("{name} function does not support data type: {dt}")
         }
-        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
-        make_scalar_function(inner, vec![])(&args.args)
     }
 }
 
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftRightUnsigned {
-    signature: Signature,
+#[derive(Debug, Hash, Copy, Clone, Eq, PartialEq)]
+enum BitShiftType {
+    Left,
+    Right,
+    RightUnsigned,
 }
 
-impl Default for SparkShiftRightUnsigned {
-    fn default() -> Self {
-        Self::new()
-    }
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkBitShift {
+    signature: Signature,
+    name: &'static str,
+    bit_shift_type: BitShiftType,
 }
 
-impl SparkShiftRightUnsigned {
-    pub fn new() -> Self {
+impl SparkBitShift {
+    fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
+        let shift_amount = Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_int32()),
+            vec![TypeSignatureClass::Integer],
+            NativeType::Int32,
+        );
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
+            signature: Signature::one_of(

Review Comment:
   The new coercion-based `Signature` is constructed here.



##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -149,589 +136,162 @@ where
     Ok(result)
 }
 
-trait BitShiftUDF: ScalarUDFImpl {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
-        value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
-    ) -> Result<PrimitiveArray<T>>
-    where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
-            + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>;
-
-    fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
-        let value_array = arrays[0].as_ref();
-        let shift_array = arrays[1].as_ref();
-
-        // Ensure shift array is Int32
-        let shift_array = if shift_array.data_type() != &DataType::Int32 {
-            return plan_err!("{} shift amount must be Int32", self.name());
-        } else {
-            shift_array.as_primitive::<Int32Type>()
-        };
-
-        match value_array.data_type() {
-            DataType::Int32 => {
-                let value_array = value_array.as_primitive::<Int32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::Int64 => {
-                let value_array = value_array.as_primitive::<Int64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt32 => {
-                let value_array = value_array.as_primitive::<UInt32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt64 => {
-                let value_array = value_array.as_primitive::<UInt64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            _ => {
-                plan_err!(
-                    "{} function does not support data type: {}",
-                    self.name(),
-                    value_array.data_type()
-                )
-            }
-        }
-    }
-}
+fn shift_inner(
+    arrays: &[ArrayRef],
+    name: &str,
+    bit_shift_type: BitShiftType,
+) -> Result<ArrayRef> {
+    let [value_array, shift_array] = take_function_args(name, arrays)?;
 
-fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result<Vec<DataType>> {
-    if arg_types.len() != 2 {
-        return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
-    }
-    if !arg_types[0].is_integer() && !arg_types[0].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[0],
-        ));
-    }
-    if !arg_types[1].is_integer() && !arg_types[1].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[1],
-        ));
+    if value_array.data_type().is_null() || shift_array.data_type().is_null() {
+        return Ok(Arc::new(Int32Array::new_null(value_array.len())));
     }
 
-    // Coerce smaller integer types to Int32
-    let coerced_first = match &arg_types[0] {
-        DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32,
-        DataType::UInt8 | DataType::UInt16 => DataType::UInt32,
-        _ => arg_types[0].clone(),
-    };
+    let shift_array = shift_array.as_primitive::<Int32Type>();
 
-    Ok(vec![coerced_first, DataType::Int32])
-}
-
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftLeft {
-    signature: Signature,
-}
-
-impl Default for SparkShiftLeft {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl SparkShiftLeft {
-    pub fn new() -> Self {
-        Self {
-            signature: Signature::user_defined(Volatility::Immutable),
-        }
-    }
-}
-
-impl BitShiftUDF for SparkShiftLeft {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
+    fn shift<T>(
         value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
+        shift: &Int32Array,
+        bit_shift_type: BitShiftType,
     ) -> Result<PrimitiveArray<T>>
     where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
+        T: ArrowPrimitiveType,
+        T::Native: std::ops::Shl<i32, Output = T::Native>
             + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>,
+            + UShr,
     {
-        shift_left(value, shift)
-    }
-}
-
-impl ScalarUDFImpl for SparkShiftLeft {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn name(&self) -> &str {
-        "shiftleft"
-    }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        bit_shift_coerce_types(arg_types, "shiftleft")
-    }
-
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types.len() != 2 {
-            return plan_err!("shiftleft expects exactly 2 arguments");
+        match bit_shift_type {
+            BitShiftType::Left => shift_left(value, shift),
+            BitShiftType::Right => shift_right(value, shift),
+            BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
         }
-        // Return type is the same as the first argument (the value to shift)
-        Ok(arg_types[0].clone())
     }
 
-    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        if args.args.len() != 2 {
-            return plan_err!("shiftleft expects exactly 2 arguments");
+    match value_array.data_type() {
+        DataType::Int32 => {
+            let value_array = value_array.as_primitive::<Int32Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::Int64 => {
+            let value_array = value_array.as_primitive::<Int64Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::UInt32 => {
+            let value_array = value_array.as_primitive::<UInt32Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        DataType::UInt64 => {
+            let value_array = value_array.as_primitive::<UInt64Type>();
+            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
+        }
+        dt => {
+            internal_err!("{name} function does not support data type: {dt}")
         }
-        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
-        make_scalar_function(inner, vec![])(&args.args)
     }
 }
 
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftRightUnsigned {
-    signature: Signature,
+#[derive(Debug, Hash, Copy, Clone, Eq, PartialEq)]
+enum BitShiftType {
+    Left,
+    Right,
+    RightUnsigned,
 }
 
-impl Default for SparkShiftRightUnsigned {
-    fn default() -> Self {
-        Self::new()
-    }
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkBitShift {
+    signature: Signature,
+    name: &'static str,
+    bit_shift_type: BitShiftType,
 }
 
-impl SparkShiftRightUnsigned {
-    pub fn new() -> Self {
+impl SparkBitShift {
+    fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
+        let shift_amount = Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_int32()),
+            vec![TypeSignatureClass::Integer],
+            NativeType::Int32,
+        );
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
+            signature: Signature::one_of(
+                vec![
+                    // Upcast small ints to 32bit
+                    TypeSignature::Coercible(vec![
+                        Coercion::new_implicit(
+                            TypeSignatureClass::Native(logical_int32()),
+                            vec![
+                                TypeSignatureClass::Native(logical_int8()),
+                                TypeSignatureClass::Native(logical_int16()),
+                            ],
+                            NativeType::Int32,
+                        ),
+                        shift_amount.clone(),
+                    ]),
+                    TypeSignature::Coercible(vec![
+                        Coercion::new_implicit(
+                            TypeSignatureClass::Native(logical_uint32()),
+                            vec![
+                                TypeSignatureClass::Native(logical_uint8()),
+                                TypeSignatureClass::Native(logical_uint16()),
+                            ],
+                            NativeType::UInt32,
+                        ),
+                        shift_amount.clone(),
+                    ]),
+                    // Otherwise accept direct 64 bit integers
+                    TypeSignature::Coercible(vec![
+                        Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
+                        shift_amount.clone(),
+                    ]),
+                    TypeSignature::Coercible(vec![
+                        Coercion::new_exact(TypeSignatureClass::Native(logical_uint64())),
+                        shift_amount.clone(),
+                    ]),
+                ],
+                Volatility::Immutable,
+            ),
+            name,
+            bit_shift_type,
         }
     }
-}
 
-impl BitShiftUDF for SparkShiftRightUnsigned {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
-        value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
-    ) -> Result<PrimitiveArray<T>>
-    where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
-            + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>,
-    {
-        shift_right_unsigned(value, shift)
+    pub fn left() -> Self {
+        Self::new("shiftleft", BitShiftType::Left)
     }
-}
 
-impl ScalarUDFImpl for SparkShiftRightUnsigned {
-    fn as_any(&self) -> &dyn Any {
-        self
+    pub fn right() -> Self {
+        Self::new("shiftright", BitShiftType::Right)
     }
 
-    fn name(&self) -> &str {
-        "shiftrightunsigned"
+    pub fn right_unsigned() -> Self {
+        Self::new("shiftrightunsigned", BitShiftType::RightUnsigned)
     }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        bit_shift_coerce_types(arg_types, "shiftrightunsigned")
-    }
-
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types.len() != 2 {
-            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
-        }
-        // Return type is the same as the first argument (the value to shift)
-        Ok(arg_types[0].clone())
-    }
-
-    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        if args.args.len() != 2 {
-            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
-        }
-        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
-        make_scalar_function(inner, vec![])(&args.args)
-    }
-}
-
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftRight {
-    signature: Signature,
 }
 
-impl Default for SparkShiftRight {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl SparkShiftRight {
-    pub fn new() -> Self {
-        Self {
-            signature: Signature::user_defined(Volatility::Immutable),
-        }
-    }
-}
-
-impl BitShiftUDF for SparkShiftRight {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
-        value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
-    ) -> Result<PrimitiveArray<T>>
-    where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
-            + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>,
-    {
-        shift_right(value, shift)
-    }
-}
-
-impl ScalarUDFImpl for SparkShiftRight {
+impl ScalarUDFImpl for SparkBitShift {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
     fn name(&self) -> &str {
-        "shiftright"
+        self.name
     }
 
     fn signature(&self) -> &Signature {
         &self.signature
     }
 
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        bit_shift_coerce_types(arg_types, "shiftright")
-    }
-
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types.len() != 2 {
-            return plan_err!("shiftright expects exactly 2 arguments");
+        if arg_types[0].is_null() {
+            Ok(DataType::Int32)
+        } else {
+            Ok(arg_types[0].clone())
         }
-        // Return type is the same as the first argument (the value to shift)
-        Ok(arg_types[0].clone())
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        if args.args.len() != 2 {
-            return plan_err!("shiftright expects exactly 2 arguments");
-        }
-        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
+        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> {
+            shift_inner(arr, self.name(), self.bit_shift_type)
+        };
         make_scalar_function(inner, vec![])(&args.args)
     }
 }
-
-#[cfg(test)]
-mod tests {

Review Comment:
   Moved these tests all to SLTs (sqllogictests)



##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -149,589 +136,162 @@ where
     Ok(result)
 }
 
-trait BitShiftUDF: ScalarUDFImpl {
-    fn shift<T: ArrowPrimitiveType>(
-        &self,
-        value: &PrimitiveArray<T>,
-        shift: &PrimitiveArray<Int32Type>,
-    ) -> Result<PrimitiveArray<T>>
-    where
-        T::Native: ArrowNativeType
-            + std::ops::Shl<i32, Output = T::Native>
-            + std::ops::Shr<i32, Output = T::Native>
-            + UShr<i32>;
-
-    fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
-        let value_array = arrays[0].as_ref();
-        let shift_array = arrays[1].as_ref();
-
-        // Ensure shift array is Int32
-        let shift_array = if shift_array.data_type() != &DataType::Int32 {
-            return plan_err!("{} shift amount must be Int32", self.name());
-        } else {
-            shift_array.as_primitive::<Int32Type>()
-        };
-
-        match value_array.data_type() {
-            DataType::Int32 => {
-                let value_array = value_array.as_primitive::<Int32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::Int64 => {
-                let value_array = value_array.as_primitive::<Int64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt32 => {
-                let value_array = value_array.as_primitive::<UInt32Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            DataType::UInt64 => {
-                let value_array = value_array.as_primitive::<UInt64Type>();
-                Ok(Arc::new(self.shift(value_array, shift_array)?))
-            }
-            _ => {
-                plan_err!(
-                    "{} function does not support data type: {}",
-                    self.name(),
-                    value_array.data_type()
-                )
-            }
-        }
-    }
-}
+fn shift_inner(
+    arrays: &[ArrayRef],
+    name: &str,
+    bit_shift_type: BitShiftType,
+) -> Result<ArrayRef> {
+    let [value_array, shift_array] = take_function_args(name, arrays)?;
 
-fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result<Vec<DataType>> {
-    if arg_types.len() != 2 {
-        return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
-    }
-    if !arg_types[0].is_integer() && !arg_types[0].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[0],
-        ));
-    }
-    if !arg_types[1].is_integer() && !arg_types[1].is_null() {
-        return Err(unsupported_data_type_exec_err(
-            func,
-            "Integer Type",
-            &arg_types[1],
-        ));
+    if value_array.data_type().is_null() || shift_array.data_type().is_null() {
+        return Ok(Arc::new(Int32Array::new_null(value_array.len())));
     }
 
-    // Coerce smaller integer types to Int32
-    let coerced_first = match &arg_types[0] {
-        DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32,
-        DataType::UInt8 | DataType::UInt16 => DataType::UInt32,
-        _ => arg_types[0].clone(),
-    };
+    let shift_array = shift_array.as_primitive::<Int32Type>();
 
-    Ok(vec![coerced_first, DataType::Int32])
-}
-
-#[derive(Debug, Hash, Eq, PartialEq)]
-pub struct SparkShiftLeft {

Review Comment:
   Folded these structs into a single common `SparkBitShift` struct
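   
   As a rough sketch of what this buys at the call site (the constructor names come from this diff, but the surrounding registration code below is illustrative only, not from this PR, and the import path for `SparkBitShift` is omitted):
   
   ```rust
   use datafusion_expr::ScalarUDF;
   
   // One impl type, three named UDFs, instead of three near-identical structs.
   let shift_left = ScalarUDF::new_from_impl(SparkBitShift::left());
   let shift_right = ScalarUDF::new_from_impl(SparkBitShift::right());
   let shift_right_unsigned = ScalarUDF::new_from_impl(SparkBitShift::right_unsigned());
   
   assert_eq!(shift_left.name(), "shiftleft");
   assert_eq!(shift_right.name(), "shiftright");
   assert_eq!(shift_right_unsigned.name(), "shiftrightunsigned");
   ```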



##########
datafusion/functions/src/macros.rs:
##########
@@ -97,6 +98,21 @@ macro_rules! make_udf_function {
             std::sync::Arc::clone(&INSTANCE)
         }
     };
+    ($UDF:ty, $NAME:ident, $CTOR:path) => {

Review Comment:
   This is to accommodate being able to use a single struct (e.g. SparkBitShift) for multiple different functions, similar to how we allow for window functions:
   
   https://github.com/apache/datafusion/blob/e661b33ee3b96d045fa8cd2533c2a54c07ac7488/datafusion/functions-window/src/macros.rs#L96-L116
   
   I had a bit of difficulty trying to reduce it to something simpler like:
   
   ```rust
   macro_rules! make_udf_function {
       ($UDF:ty, $NAME:ident) => {
           make_udf_function!($UDF, $NAME, $UDF::new); // Error on :: token
       };
       ($UDF:ty, $NAME:ident, $CTOR:path) => {
           #[allow(rustdoc::redundant_explicit_links)]
           #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))]
           pub fn $NAME() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
               // Singleton instance of the function
               static INSTANCE: std::sync::LazyLock<
                   std::sync::Arc<datafusion_expr::ScalarUDF>,
               > = std::sync::LazyLock::new(|| {
                   std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
                       $CTOR(),
                   ))
               });
               std::sync::Arc::clone(&INSTANCE)
           }
       };
   }
   ```
   
   to reduce duplication. For now I've just kept this minor duplication of the arms.
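   
   For context, a minimal sketch of how the three-argument arm might be invoked for the combined struct (the `$NAME` identifiers below are illustrative; the actual invocations aren't shown in this hunk):
   
   ```rust
   // Three-arg form: pass an explicit constructor path, so a single impl type
   // (here `SparkBitShift`) can back several differently-named UDFs.
   make_udf_function!(SparkBitShift, shift_left, SparkBitShift::left);
   make_udf_function!(SparkBitShift, shift_right, SparkBitShift::right);
   make_udf_function!(SparkBitShift, shift_right_unsigned, SparkBitShift::right_unsigned);
   ```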



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

