martin-g commented on code in PR #18649:
URL: https://github.com/apache/datafusion/pull/18649#discussion_r2519923367
##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -149,589 +136,162 @@ where
Ok(result)
}
-trait BitShiftUDF: ScalarUDFImpl {
- fn shift<T: ArrowPrimitiveType>(
- &self,
- value: &PrimitiveArray<T>,
- shift: &PrimitiveArray<Int32Type>,
- ) -> Result<PrimitiveArray<T>>
- where
- T::Native: ArrowNativeType
- + std::ops::Shl<i32, Output = T::Native>
- + std::ops::Shr<i32, Output = T::Native>
- + UShr<i32>;
-
- fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
- let value_array = arrays[0].as_ref();
- let shift_array = arrays[1].as_ref();
-
- // Ensure shift array is Int32
- let shift_array = if shift_array.data_type() != &DataType::Int32 {
- return plan_err!("{} shift amount must be Int32", self.name());
- } else {
- shift_array.as_primitive::<Int32Type>()
- };
-
- match value_array.data_type() {
- DataType::Int32 => {
- let value_array = value_array.as_primitive::<Int32Type>();
- Ok(Arc::new(self.shift(value_array, shift_array)?))
- }
- DataType::Int64 => {
- let value_array = value_array.as_primitive::<Int64Type>();
- Ok(Arc::new(self.shift(value_array, shift_array)?))
- }
- DataType::UInt32 => {
- let value_array = value_array.as_primitive::<UInt32Type>();
- Ok(Arc::new(self.shift(value_array, shift_array)?))
- }
- DataType::UInt64 => {
- let value_array = value_array.as_primitive::<UInt64Type>();
- Ok(Arc::new(self.shift(value_array, shift_array)?))
- }
- _ => {
- plan_err!(
- "{} function does not support data type: {}",
- self.name(),
- value_array.data_type()
- )
- }
- }
- }
-}
+fn shift_inner(
+ arrays: &[ArrayRef],
+ name: &str,
+ bit_shift_type: BitShiftType,
+) -> Result<ArrayRef> {
+ let [value_array, shift_array] = take_function_args(name, arrays)?;
-fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) ->
Result<Vec<DataType>> {
- if arg_types.len() != 2 {
- return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
- }
- if !arg_types[0].is_integer() && !arg_types[0].is_null() {
- return Err(unsupported_data_type_exec_err(
- func,
- "Integer Type",
- &arg_types[0],
- ));
- }
- if !arg_types[1].is_integer() && !arg_types[1].is_null() {
- return Err(unsupported_data_type_exec_err(
- func,
- "Integer Type",
- &arg_types[1],
- ));
+ if value_array.data_type().is_null() || shift_array.data_type().is_null() {
+ return Ok(Arc::new(Int32Array::new_null(value_array.len())));
Review Comment:
Why always `Int32Array`?
If `shift_array.data_type().is_null()`, then I think you need to use the type
returned by `value_array.data_type()`, which could be Int64, for example.
If `value_array.data_type().is_null()`, then fall back to `Int32Array`.
##########
datafusion/spark/src/function/bitwise/bit_shift.rs:
##########
@@ -18,36 +18,33 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
+use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array,
PrimitiveArray};
use arrow::compute;
use arrow::datatypes::{
ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
Review Comment:
```suggestion
DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]