alamb commented on code in PR #4488:
URL: https://github.com/apache/arrow-datafusion/pull/4488#discussion_r1044345607


##########
datafusion/core/tests/sqllogictests/test_files/aggregate.slt:
##########
@@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan
 ----
 NaN
 
+# median_multi

Review Comment:
   I ported these tests to sqllogictest because much of the rest of the
   aggregate tests had already been ported there as well.



##########
datafusion/physical-expr/src/aggregate/median.rs:
##########
@@ -91,157 +91,124 @@ impl AggregateExpr for Median {
 }
 
 #[derive(Debug)]
+/// The median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of those scalars
 struct MedianAccumulator {
     data_type: DataType,
-    all_values: Vec<ArrayRef>,
-}
-
-macro_rules! median {
-    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
-        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
-        if combined.is_empty() {
-            return Ok(ScalarValue::Null);
-        }
-        let sorted = sort(&combined, None)?;
-        let array = as_primitive_array::<$TY>(&sorted)?;
-        let len = sorted.len();
-        let mid = len / 2;
-        if len % 2 == 0 {
-            Ok(ScalarValue::$SCALAR_TY(Some(
-                (array.value(mid - 1) + array.value(mid)) / $TWO,
-            )))
-        } else {
-            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
-        }
-    }};
+    all_values: Vec<ScalarValue>,
 }
 
 impl Accumulator for MedianAccumulator {
     fn state(&self) -> Result<Vec<AggregateState>> {
-        let mut vec: Vec<AggregateState> = self
-            .all_values
-            .iter()
-            .map(|v| AggregateState::Array(v.clone()))
-            .collect();
-        if vec.is_empty() {
-            match self.data_type {
-                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
-                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
-                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
-                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
-                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
-                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
-                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
-                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
-                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
-                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
-                _ => {
-                    return Err(DataFusionError::Execution(
-                        "unsupported data type for median".to_string(),
-                    ))
-                }
-            }
-        }
-        Ok(vec)
+        let state =
+            ScalarValue::new_list(Some(self.all_values.clone()), 
self.data_type.clone());
+        Ok(vec![AggregateState::Scalar(state)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let x = values[0].clone();
-        self.all_values.extend_from_slice(&[x]);
-        Ok(())
-    }
+        assert_eq!(values.len(), 1);
+        let array = &values[0];
 
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        for array in states {
-            self.all_values.extend_from_slice(&[array.clone()]);
+        self.all_values.reserve(self.all_values.len() + array.len());

Review Comment:
   For what it is worth, this assertion fails; the correct assertion is:
   
   ```rust
           assert_eq!(array.data_type(), &self.data_type);
   ```
   
   I have fixed this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to