alamb commented on code in PR #4488: URL: https://github.com/apache/arrow-datafusion/pull/4488#discussion_r1044345607
########## datafusion/core/tests/sqllogictests/test_files/aggregate.slt: ########## @@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan ---- NaN +# median_multi Review Comment: I ported the tests to sqllogictest as much of the rest of the aggregate tests had been ported too ########## datafusion/physical-expr/src/aggregate/median.rs: ########## @@ -91,157 +91,124 @@ impl AggregateExpr for Median { } #[derive(Debug)] +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of those scalars struct MedianAccumulator { data_type: DataType, - all_values: Vec<ArrayRef>, -} - -macro_rules! median { - ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{ - let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?; - if combined.is_empty() { - return Ok(ScalarValue::Null); - } - let sorted = sort(&combined, None)?; - let array = as_primitive_array::<$TY>(&sorted)?; - let len = sorted.len(); - let mid = len / 2; - if len % 2 == 0 { - Ok(ScalarValue::$SCALAR_TY(Some( - (array.value(mid - 1) + array.value(mid)) / $TWO, - ))) - } else { - Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid)))) - } - }}; + all_values: Vec<ScalarValue>, } impl Accumulator for MedianAccumulator { fn state(&self) -> Result<Vec<AggregateState>> { - let mut vec: Vec<AggregateState> = self - .all_values - .iter() - .map(|v| AggregateState::Array(v.clone())) - .collect(); - if vec.is_empty() { - match self.data_type { - DataType::UInt8 => vec.push(empty_array::<UInt8Type>()), - DataType::UInt16 => vec.push(empty_array::<UInt16Type>()), - DataType::UInt32 => vec.push(empty_array::<UInt32Type>()), - DataType::UInt64 => vec.push(empty_array::<UInt64Type>()), - DataType::Int8 => vec.push(empty_array::<Int8Type>()), - DataType::Int16 => vec.push(empty_array::<Int16Type>()), - DataType::Int32 => vec.push(empty_array::<Int32Type>()), - DataType::Int64 => vec.push(empty_array::<Int64Type>()), - 
DataType::Float32 => vec.push(empty_array::<Float32Type>()), - DataType::Float64 => vec.push(empty_array::<Float64Type>()), - _ => { - return Err(DataFusionError::Execution( - "unsupported data type for median".to_string(), - )) - } - } - } - Ok(vec) + let state = + ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone()); + Ok(vec![AggregateState::Scalar(state)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let x = values[0].clone(); - self.all_values.extend_from_slice(&[x]); - Ok(()) - } + assert_eq!(values.len(), 1); + let array = &values[0]; - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - for array in states { - self.all_values.extend_from_slice(&[array.clone()]); + self.all_values.reserve(self.all_values.len() + array.len()); Review Comment: For what it is worth, this assert fails -- the correct assertion is ```rust assert_eq!(array.data_type(), &self.data_type); ``` which I have fixed. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org