alamb commented on code in PR #6103:
URL: https://github.com/apache/arrow-datafusion/pull/6103#discussion_r1176844965
##########
datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs:
##########
@@ -506,31 +513,159 @@ pub(crate) fn subtract_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+ left: &DictionaryArray<K>,
+ right: &DictionaryArray<K>,
+ op: F,
+) -> Result<PrimitiveArray<T>>
+where
+ K: ArrowDictionaryKeyType + ArrowNumericType,
+ T: ArrowNumericType,
+ F: Fn(T::Native, T::Native) -> T::Native,
+{
+ if left.len() != right.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Cannot perform operation on arrays of different length ({}, {})",
+ left.len(),
+ right.len()
+ )));
+ }
+
+ // Safety justification: Since the inputs are valid Arrow arrays, all values are
+ // valid indexes into the dictionary (which is verified during construction)
+
+ let left_iter = unsafe {
+ left.values()
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .unwrap()
+ .take_iter_unchecked(left.keys_iter())
+ };
+
+ let right_iter = unsafe {
+ right
+ .values()
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .unwrap()
+ .take_iter_unchecked(right.keys_iter())
+ };
Review Comment:
```suggestion
let left_iter = unsafe {
left.values()
.as_primitive::<T>()
.take_iter_unchecked(left.keys_iter())
};
let right_iter = unsafe {
right
.values()
.as_primitive::<T>()
.take_iter_unchecked(right.keys_iter())
};
```
##########
datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs:
##########
@@ -391,31 +398,159 @@ pub(crate) fn subtract_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
Review Comment:
I couldn't find any reference to `multiply_fixed_point_dyn` in arrow-rs (see the
search below) -- is there an existing ticket tracking it?
https://github.com/search?q=repo%3Aapache%2Farrow-rs%20multiply_fixed_point_dyn&type=code
##########
datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs:
##########
@@ -506,31 +513,159 @@ pub(crate) fn subtract_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+ left: &DictionaryArray<K>,
+ right: &DictionaryArray<K>,
+ op: F,
+) -> Result<PrimitiveArray<T>>
+where
+ K: ArrowDictionaryKeyType + ArrowNumericType,
+ T: ArrowNumericType,
+ F: Fn(T::Native, T::Native) -> T::Native,
+{
+ if left.len() != right.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Cannot perform operation on arrays of different length ({}, {})",
+ left.len(),
+ right.len()
+ )));
+ }
+
+ // Safety justification: Since the inputs are valid Arrow arrays, all values are
+ // valid indexes into the dictionary (which is verified during construction)
+
+ let left_iter = unsafe {
+ left.values()
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .unwrap()
+ .take_iter_unchecked(left.keys_iter())
+ };
+
+ let right_iter = unsafe {
+ right
+ .values()
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .unwrap()
+ .take_iter_unchecked(right.keys_iter())
+ };
+
+ let result = left_iter
+ .zip(right_iter)
+ .map(|(left_value, right_value)| {
+ if let (Some(left), Some(right)) = (left_value, right_value) {
+ Some(op(left, right))
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ Ok(result)
+}
+
+/// Divide a decimal native value by given divisor and round the result.
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
+where
+ I: DecimalType,
+ I::Native: ArrowNativeTypeOp,
+{
+ let d = input.div_wrapping(div);
+ let r = input.mod_wrapping(div);
+
+ let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+ let half_neg = half.neg_wrapping();
+ // Round result
+ match input >= I::Native::ZERO {
+ true if r >= half => d.add_wrapping(I::Native::ONE),
+ false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+ _ => d,
+ }
+}
+
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn multiply_fixed_point_dyn(
left: &dyn Array,
right: &dyn Array,
- result_type: &DataType,
+ required_scale: i8,
) -> Result<ArrayRef> {
- let (precision, scale) = get_precision_scale(result_type)?;
+ match (left.data_type(), right.data_type()) {
+ (
+ DataType::Dictionary(_, lhs_value_type),
+ DataType::Dictionary(_, rhs_value_type),
+ ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+ && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
+ {
+ downcast_dictionary_array!(
+ left => match left.values().data_type() {
+ DataType::Decimal128(_, _) => {
+ let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
+ let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
- let op_type = decimal_op_mathematics_type(
- &Operator::Multiply,
- left.data_type(),
- left.data_type(),
- )
- .unwrap();
- let (_, op_scale) = get_precision_scale(&op_type)?;
+ let product_scale = lhs_precision_scale.1 + rhs_precision_scale.1;
+ let precision = min(lhs_precision_scale.0 + rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION);
- let array = multiply_dyn(left, right)?;
- if op_scale > scale {
- let div = 10_i128.pow((op_scale - scale) as u32);
- let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
- decimal_array_with_precision_scale(array, precision, scale)
- } else {
- decimal_array_with_precision_scale(array, precision, scale)
+ if required_scale == product_scale {
+ return Ok(multiply_dyn(left, right)?.as_primitive::<Decimal128Type>().clone()
+ .with_precision_and_scale(precision, required_scale).map(|a| Arc::new(a) as ArrayRef)?);
+ }
+
+ if required_scale > product_scale {
+ return Err(DataFusionError::Internal(format!(
+ "Required scale {} is greater than product scale {}",
+ required_scale, product_scale
+ )));
+ }
+
+ let divisor =
+ i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
+
+ let right = as_dictionary_array::<_>(right);
+
+ let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
+ let a = i256::from_i128(a);
+ let b = i256::from_i128(b);
+
+ let mut mul = a.wrapping_mul(b);
+ mul = divide_and_round::<Decimal256Type>(mul, divisor);
+ mul.as_i128()
+ }).map(|a| a.with_precision_and_scale(precision, required_scale).unwrap())?;
+
+ Ok(Arc::new(array))
+ }
+ t => unreachable!("Unsupported dictionary value type {}", t),
+ },
+ t => unreachable!("Unsupported data type {}", t),
+ )
+ }
+ (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
+ let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+ Ok(multiply_fixed_point(left, right, required_scale)
+ .map(|a| Arc::new(a) as ArrayRef)?)
+ }
+ (_, _) => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {}, {}",
+ left.data_type(),
+ right.data_type()
+ ))),
}
}
+pub(crate) fn multiply_dyn_decimal(
Review Comment:
It would be great to eventually get rid of the `decimal` variants of the
`dyn` kernels and simply use `mul_dyn`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]