This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f48e0b22b6 Impl `convert_to_state` for `GroupsAccumulatorAdapter` (faster median for high cardinality aggregates) (#11827)
f48e0b22b6 is described below

commit f48e0b22b6cc84f37f3c14b74008fcca02712d26
Author: kamille <[email protected]>
AuthorDate: Sun Sep 15 20:00:31 2024 +0800

    Impl `convert_to_state` for `GroupsAccumulatorAdapter` (faster median for high cardinality aggregates) (#11827)
    
    * make a draft for `convert_to_state` in `GroupsAccumulatorAdapter`.
    
    * tmp
    
    * use filter nulls to impl quick filter for some arrays.
    
    * add unique group by test for `median`, `approx_median`, `approx_distinct`.
    
    * add normal cases & nullable cases for `median`, `approx_median`, `approx_distinct`.
    
    * add filter cases for `median`, `approx_median`, `approx_distinct`.
    
    * fix clippy.
    
    * fix fmt.
    
    * add todo.
    
    * fix comments.
    
    * fall back to the filter kernel in the general case.
    
    * remove unused imports.
    
    * remove unused Array.
---
 .../src/aggregate/groups_accumulator.rs            |  55 +++++-
 .../test_files/aggregate_skip_partial.slt          | 220 +++++++++++++++++++++
 2 files changed, 269 insertions(+), 6 deletions(-)

diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
index 1c97d22ec7..b5eb36c3fa 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
@@ -207,8 +207,11 @@ impl GroupsAccumulatorAdapter {
             let state = &mut self.states[group_idx];
             sizes_pre += state.size();
 
-            let values_to_accumulate =
-                slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?;
+            let values_to_accumulate = slice_and_maybe_filter(
+                &values,
+                opt_filter.as_ref().map(|f| f.as_boolean()),
+                offsets,
+            )?;
             (f)(state.accumulator.as_mut(), &values_to_accumulate)?;
 
             // clear out the state so they are empty for next
@@ -290,6 +293,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
         result
     }
 
+    // Returns the intermediate state arrays for the groups selected by `emit_to`.
     fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
         let vec_size_pre = self.states.allocated_size();
         let states = emit_to.take_needed(&mut self.states);
@@ -348,6 +352,46 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
     fn size(&self) -> usize {
         self.allocation_bytes
     }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let num_rows = values[0].len();
+
+        // Treat every input row as its own group
+        let mut results = vec![];
+        for row_idx in 0..num_rows {
+            // Create the empty accumulator for converting
+            let mut converted_accumulator = (self.factory)()?;
+
+            // Feed this single (optionally filtered) row in and read back its state
+            let values_to_accumulate =
+                slice_and_maybe_filter(values, opt_filter, &[row_idx, row_idx 
+ 1])?;
+            converted_accumulator.update_batch(&values_to_accumulate)?;
+            let states = converted_accumulator.state()?;
+
+            // Ensure `results` has one column per state value the accumulator produced
+            results.resize_with(states.len(), || Vec::with_capacity(num_rows));
+
+            // Add the states to results
+            for (idx, state_val) in states.into_iter().enumerate() {
+                results[idx].push(state_val);
+            }
+        }
+
+        let arrays = results
+            .into_iter()
+            .map(ScalarValue::iter_to_array)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(arrays)
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
 }
 
 /// Extension trait for [`Vec`] to account for allocations.
@@ -384,7 +428,7 @@ fn get_filter_at_indices(
 // Copied from physical-plan
 pub(crate) fn slice_and_maybe_filter(
     aggr_array: &[ArrayRef],
-    filter_opt: Option<&ArrayRef>,
+    filter_opt: Option<&BooleanArray>,
     offsets: &[usize],
 ) -> Result<Vec<ArrayRef>> {
     let (offset, length) = (offsets[0], offsets[1] - offsets[0]);
@@ -394,13 +438,12 @@ pub(crate) fn slice_and_maybe_filter(
         .collect();
 
     if let Some(f) = filter_opt {
-        let filter_array = f.slice(offset, length);
-        let filter_array = filter_array.as_boolean();
+        let filter = f.slice(offset, length);
 
         sliced_arrays
             .iter()
             .map(|array| {
-                compute::filter(array, filter_array).map_err(|e| 
arrow_datafusion_err!(e))
+                compute::filter(&array, &filter).map_err(|e| 
arrow_datafusion_err!(e))
             })
             .collect()
     } else {
diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
index ab1c7e78f1..a2e51cffac 100644
--- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
+++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
@@ -133,6 +133,51 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
 -2117946883 d -2117946883 NULL NULL NULL
 -2098805236 c -2098805236 NULL NULL NULL
 
+query ITIIII
+SELECT c5, c1,
+       MEDIAN(c5),
+       MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
+       MEDIAN(c5) FILTER (WHERE c1 = 'b'),
+       MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
+FROM aggregate_test_100
+GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
+----
+-2141999138 c -2141999138 NULL NULL NULL
+-2141451704 a -2141451704 -2141451704 NULL NULL
+-2138770630 b -2138770630 NULL -2138770630 NULL
+-2117946883 d -2117946883 NULL NULL NULL
+-2098805236 c -2098805236 NULL NULL NULL
+
+query ITIIII
+SELECT c5, c1,
+       APPROX_MEDIAN(c5),
+       APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
+       APPROX_MEDIAN(c5) FILTER (WHERE c1 = 'b'),
+       APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE 
c1 = 'b')
+FROM aggregate_test_100
+GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
+----
+-2141999138 c -2141999138 NULL NULL NULL
+-2141451704 a -2141451704 -2141451704 NULL NULL
+-2138770630 b -2138770630 NULL -2138770630 NULL
+-2117946883 d -2117946883 NULL NULL NULL
+-2098805236 c -2098805236 NULL NULL NULL
+
+query ITIIII
+SELECT c5, c1,
+       APPROX_DISTINCT(c5),
+       APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
+       APPROX_DISTINCT(c5) FILTER (WHERE c1 = 'b'),
+       APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE 
c1 = 'b')
+FROM aggregate_test_100
+GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
+----
+-2141999138 c 1 0 0 0
+-2141451704 a 1 1 0 0
+-2138770630 b 1 0 1 0
+-2117946883 d 1 0 0 0
+-2098805236 c 1 0 0 0
+
 # FIXME: add bool_and(v3) column when issue fixed
 # ISSUE https://github.com/apache/datafusion/issues/11846
 query TBBB rowsort
@@ -222,6 +267,36 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
 4 16155718643 9.531112968922
 5 6449337880 7.074412226677
 
+# Test median for int / float
+query IIR
+SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER 
BY c2;
+----
+1 23971150 0.5922606
+2 -562486880 0.43422085
+3 240273900 0.40199697
+4 762932956 0.48515016
+5 604973998 0.49842384
+
+# Test approx_median for int / float
+query IIR
+SELECT c2, approx_median(c5), approx_median(c11) FROM aggregate_test_100 GROUP 
BY c2 ORDER BY c2;
+----
+1 191655437 0.59926736
+2 -587831330 0.43230486
+3 240273900 0.40199697
+4 762932956 0.48515016
+5 593204320 0.5156586
+
+# Test approx_distinct for varchar / int
+query III
+SELECT c2, approx_distinct(c1), approx_distinct(c5) FROM aggregate_test_100 
GROUP BY c2 ORDER BY c2;
+----
+1 5 22
+2 5 22
+3 5 19
+4 5 23
+5 5 14
+
 # Test count with nullable fields
 query III
 SELECT c2, count(c3), count(c11) FROM aggregate_test_100_null GROUP BY c2 
ORDER BY c2;
@@ -252,6 +327,36 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
 4 29 9.531112968922
 5 -194 7.074412226677
 
+# Test median with nullable fields
+query IIR
+SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 
ORDER BY c2;
+----
+1 12 0.6067944
+2 1 0.46076488
+3 14 0.40154034
+4 -17 0.48515016
+5 -35 0.5536642
+
+# Test approx_median with nullable fields
+query IIR
+SELECT c2, approx_median(c3), approx_median(c11) FROM aggregate_test_100_null 
GROUP BY c2 ORDER BY c2;
+----
+1 12 0.6067944
+2 1 0.46076488
+3 14 0.40154034
+4 -7 0.48515016
+5 -39 0.5536642
+
+# Test approx_distinct with nullable fields
+query II
+SELECT c2, approx_distinct(c3) FROM aggregate_test_100_null GROUP BY c2 ORDER 
BY c2;
+----
+1 19
+2 16
+3 13
+4 16
+5 12
+
 # Test avg for tinyint / float
 query TRR
 SELECT
@@ -338,6 +443,48 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
 4 417
 5 284
 
+# Test approx_distinct with filter
+query III
+SELECT
+  c2,
+  approx_distinct(c3) FILTER (WHERE c3 > 0),
+  approx_distinct(c3) FILTER (WHERE c11 > 10)
+FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
+----
+1 13 0
+2 12 0
+3 11 0
+4 13 0
+5 5 0
+
+# Test median with filter
+query III
+SELECT
+  c2,
+  median(c3) FILTER (WHERE c3 > 0),
+  median(c3) FILTER (WHERE c3 < 0)
+FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
+----
+1 57 -56
+2 52 -60
+3 71 -74
+4 65 -69
+5 64 -59
+
+# Test approx_median with filter
+query III
+SELECT
+  c2,
+  approx_median(c3) FILTER (WHERE c3 > 0),
+  approx_median(c3) FILTER (WHERE c3 < 0)
+FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
+----
+1 57 -56
+2 52 -60
+3 71 -76
+4 65 -64
+5 64 -59
+
 # Test count with nullable fields and filter
 query III
 SELECT c2,
@@ -421,6 +568,79 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
 4 -171 56 2.10740506649 1.939846396446
 5 -86 -76 1.8741710186 1.600569307804
 
+# Test approx_distinct with nullable fields and filter
+query II
+SELECT c2,
+       approx_distinct(c3) FILTER (WHERE c5 > 0)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 11
+2 6
+3 6
+4 11
+5 8
+
+# Test approx_distinct with nullable fields and nullable filter
+query II
+SELECT c2,
+       approx_distinct(c3) FILTER (WHERE c11 > 0.5)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 10
+2 6
+3 3
+4 3
+5 6
+
+# Test median with nullable fields and filter
+query IIR
+SELECT c2,
+       median(c3) FILTER (WHERE c5 > 0),
+       median(c11) FILTER (WHERE c5 < 0)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 -5 0.6623719
+2 15 0.52930677
+3 13 0.32792538
+4 -38 0.49774808
+5 -18 0.49842384
+
+# Test median with nullable fields and nullable filter
+query II
+SELECT c2,
+       median(c3) FILTER (WHERE c11 > 0.5)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 33
+2 -29
+3 22
+4 -90
+5 -22
+
+# Test approx_median with nullable fields and filter
+query IIR
+SELECT c2,
+       approx_median(c3) FILTER (WHERE c5 > 0),
+       approx_median(c11) FILTER (WHERE c5 < 0)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 -5 0.6623719
+2 12 0.52930677
+3 13 0.32792538
+4 -38 0.49774808
+5 -21 0.47652745
+
+# Test approx_median with nullable fields and nullable filter
+query II
+SELECT c2,
+       approx_median(c3) FILTER (WHERE c11 > 0.5)
+FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
+----
+1 35
+2 -29
+3 22
+4 -90
+5 -32
 
 statement ok
 DROP TABLE aggregate_test_100_null;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to