alamb commented on code in PR #10226:
URL: https://github.com/apache/datafusion/pull/10226#discussion_r1580921553


##########
datafusion/physical-expr/src/aggregate/median.rs:
##########
@@ -196,6 +184,192 @@ impl<T: ArrowNumericType> Accumulator for 
MedianAccumulator<T> {
     }
 }
 
+/// MEDIAN(DISTINCT) aggregate expression. Similar to MEDIAN but computes after taking
+/// all unique values. This may use a lot of memory if the cardinality is high.
+#[derive(Debug)]
+pub struct DistinctMedian {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
+}
+
+impl DistinctMedian {
+    /// Create a new MEDIAN(DISTINCT) aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            data_type,
+        }
+    }
+}
+
+impl AggregateExpr for DistinctMedian {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        use arrow_array::types::*;
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(DistinctMedianAccumulator::<$t> {
+                    data_type: $dt.clone(),
+                    distinct_values: Default::default(),
+                }))
+            };
+        }
+        let dt = &self.data_type;
+        downcast_integer! {
+            dt => (helper, dt),
+            DataType::Float16 => helper!(Float16Type, dt),
+            DataType::Float32 => helper!(Float32Type, dt),
+            DataType::Float64 => helper!(Float64Type, dt),
+            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
+            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "DistinctMedianAccumulator not supported for {} with {}",
+                self.name(),
+                self.data_type
+            ))),
+        }
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        // Intermediate state is a list of the unique elements we have
+        // collected so far
+        let field = Field::new("item", self.data_type.clone(), true);
+        let data_type = DataType::List(Arc::new(field));
+
+        Ok(vec![Field::new(
+            format_state_name(&self.name, "distinct_median"),
+            data_type,
+            true,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for DistinctMedian {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.data_type == x.data_type
+                    && self.expr.eq(&x.expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// The distinct median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of scalar values updated by
+/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
+/// in the final evaluation step so that we avoid expensive conversions and
+/// allocations during `update_batch`.
+struct DistinctMedianAccumulator<T: ArrowNumericType> {
+    data_type: DataType,
+    distinct_values: HashSet<Hashable<T::Native>>,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for DistinctMedianAccumulator<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "DistinctMedianAccumulator({})", self.data_type)
+    }
+}
+
+impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let all_values = self
+            .distinct_values
+            .iter()
+            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), 
&self.data_type))
+            .collect::<Result<Vec<_>>>()?;
+
+        let arr = ScalarValue::new_list(&all_values, &self.data_type);
+        Ok(vec![ScalarValue::List(arr)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let array = values[0].as_primitive::<T>();
+        match array.nulls().filter(|x| x.null_count() > 0) {

Review Comment:
   My point in 
https://github.com/apache/datafusion/pull/10226#discussion_r1579978178 was that 
the way this code checks for the absence of nulls seems overly obscure to me.
   
   I think the code could look like
   
   ```rust
           if array.null_count() > 0 {
               for val in array.iter() {
                   if let Some(value) = val {
                       self.distinct_values.insert(Hashable(value));
                   }
               }
           } else {
               array.values().iter().for_each(|x| {
                   self.distinct_values.insert(Hashable(*x));
               });
           }
   ```
   
   But I also don't think it is a big deal.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to