alamb commented on code in PR #8721:
URL: https://github.com/apache/arrow-datafusion/pull/8721#discussion_r1441692046


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -192,6 +262,164 @@ impl Accumulator for DistinctCountAccumulator {
     }
 }
 
+#[derive(Debug)]
+struct NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    values: HashSet<T::Native, RandomState>,
+}
+
+impl<T> NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    fn new() -> Self {
+        Self {
+            values: HashSet::default(),
+        }
+    }
+}
+
+impl<T> Accumulator for NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send + Debug,
+    T::Native: Eq + Hash,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+            self.values.iter().cloned(),
+        )) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;

Review Comment:
   👏 to @jayzhan211 for switching the native implementation of `ScalarValue::List` to use an `ArrayRef`



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use DataType::*;
+        use TimeUnit::*;
+
+        match &self.state_data_type {
+            Int8 => native_distinct_count_accumulator!(Int8Type),
+            Int16 => native_distinct_count_accumulator!(Int16Type),
+            Int32 => native_distinct_count_accumulator!(Int32Type),
+            Int64 => native_distinct_count_accumulator!(Int64Type),
+            UInt8 => native_distinct_count_accumulator!(UInt8Type),
+            UInt16 => native_distinct_count_accumulator!(UInt16Type),
+            UInt32 => native_distinct_count_accumulator!(UInt32Type),
+            UInt64 => native_distinct_count_accumulator!(UInt64Type),
+            Decimal128(_, _) => 
native_distinct_count_accumulator!(Decimal128Type),
+            Decimal256(_, _) => 
native_distinct_count_accumulator!(Decimal256Type),
+
+            Date32 => native_distinct_count_accumulator!(Date32Type),
+            Date64 => native_distinct_count_accumulator!(Date64Type),
+            Time32(Millisecond) => {
+                native_distinct_count_accumulator!(Time32MillisecondType)
+            }
+            Time32(Second) => {
+                native_distinct_count_accumulator!(Time32SecondType)
+            }
+            Time64(Microsecond) => {
+                native_distinct_count_accumulator!(Time64MicrosecondType)
+            }
+            Time64(Nanosecond) => {
+                native_distinct_count_accumulator!(Time64NanosecondType)
+            }
+            Timestamp(Microsecond, _) => {
+                native_distinct_count_accumulator!(TimestampMicrosecondType)
+            }
+            Timestamp(Millisecond, _) => {
+                native_distinct_count_accumulator!(TimestampMillisecondType)
+            }
+            Timestamp(Nanosecond, _) => {
+                native_distinct_count_accumulator!(TimestampNanosecondType)
+            }
+            Timestamp(Second, _) => {
+                native_distinct_count_accumulator!(TimestampSecondType)
+            }
+
+            Float16 => float_distinct_count_accumulator!(Float16Type),
+            Float32 => float_distinct_count_accumulator!(Float32Type),
+            Float64 => float_distinct_count_accumulator!(Float64Type),

Review Comment:
   This is similar to the idea in 
https://github.com/apache/arrow-datafusion/issues/7064
   
   Maybe we can eventually use the same data structure (specialized for storing string values without using a `String`).



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -192,6 +262,164 @@ impl Accumulator for DistinctCountAccumulator {
     }
 }
 
+#[derive(Debug)]
+struct NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    values: HashSet<T::Native, RandomState>,
+}
+
+impl<T> NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    fn new() -> Self {
+        Self {
+            values: HashSet::default(),
+        }
+    }
+}
+
+impl<T> Accumulator for NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send + Debug,
+    T::Native: Eq + Hash,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+            self.values.iter().cloned(),
+        )) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_primitive_array::<T>(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.values.insert(value);
+            }
+        });
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_primitive_array::<T>(&list)?;
+                self.values.extend(list.values())
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + std::mem::size_of_val(&self.values)
+            + (std::mem::size_of::<T::Native>() * self.values.capacity())

Review Comment:
   It seems like maybe we just need to add the overhead of the size of a hash for each entry? I poked around in the HashSet / HashTable implementation and it is not immediately clear how much additional overhead there is.
   
   Maybe something like
   ```rust
               + ((std::mem::size_of::<T::Native>() + std::mem::size_of::<u64>()) * self.values.capacity())
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to