alamb commented on code in PR #2516: URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r875922067
########## datafusion/physical-expr/src/aggregate/sum.rs: ########## @@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale( } } +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! union_arrays { + ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{ + let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?; + let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?; + let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE); + let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE); + + let chained = lhs_prim_array + .iter() + .chain(rhs_prim_array.iter()) + .collect::<$ARR_DTYPE>(); + + Arc::new(chained) + }}; +} + pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { - Ok(match (lhs, rhs) { - (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + let result = match (lhs.get_datatype(), rhs.get_datatype()) { + (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => { let max_precision = p1.max(p2); - if s1.eq(s2) { - // s1 = s2 - sum_decimal(v1, v2, max_precision, s1) - } else if s1.gt(s2) { - // s1 > s2 - sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) - } else { - // s1 < s2 - sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) + + match (lhs, rhs) { + ( + ScalarValue::Decimal128(v1, _, _), + ScalarValue::Decimal128(v2, _, _), + ) => { + Ok(if s1.eq(&s2) { + // s1 = s2 + sum_decimal(v1, v2, &max_precision, &s1) + } else if s1.gt(&s2) { + // s1 > s2 + sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2) + } else { + // s1 < s2 + sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1) + }) + } + _ => Err(DataFusionError::Internal( + "Internal state error on sum decimals ".to_string(), + )), } } - // float64 coerces 
everything to f64 - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) + (DataType::Float64, _) | (_, DataType::Float64) => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64"); + sum_batch(&data, &arrow::datatypes::DataType::Float64) } - (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - // float32 has no cast - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_sum!(lhs, rhs, Float32, f32) - } - // u64 coerces u* to u64 - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_sum!(lhs, rhs, UInt64, u64) + (DataType::Float32, _) | (_, DataType::Float32) => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32"); Review Comment: > implementing the add function (in scalar.rs or any other place) leads to boilerplate code, because there is no concise way to get a ScalarValue's underlying value. I think this is the best we can do -- I don't (really) mind boilerplate code, but there seem to be too many copies of almost the same boilerplate code.
I agree using `arrow::compute::sum` is likely not the right thing to do -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org