alamb commented on code in PR #7192:
URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1317578819
##########
datafusion/common/src/config.rs:
##########
@@ -380,6 +380,10 @@ config_namespace! {
/// repartitioning to increase parallelism to leverage more CPU cores
pub enable_round_robin_repartition: bool, default = true
+ /// When set to true, the optimizer will attempt to perform limit
operations
Review Comment:
👍 for including an escape hatch
##########
datafusion/core/src/physical_plan/aggregates/topk/hash_table.rs:
##########
@@ -0,0 +1,438 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked
by index
+
+use crate::physical_plan::aggregates::group_values::primitive::HashValue;
+use crate::physical_plan::aggregates::topk::heap::Comparable;
+use ahash::RandomState;
+use arrow::datatypes::i256;
+use arrow_array::builder::PrimitiveBuilder;
+use arrow_array::cast::AsArray;
+use arrow_array::{
+ downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray,
StringArray,
+};
+use arrow_schema::DataType;
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use half::f16;
+use hashbrown::raw::RawTable;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+/// A "type alias" for Keys which are stored in our map
+pub trait KeyType: Clone + Comparable + Debug {}
+
+impl<T> KeyType for T where T: Clone + Comparable + Debug {}
+
+/// An entry in our hash table that:
+/// 1. memoizes the hash
+/// 2. contains the key (ID)
+/// 3. contains the value (heap_idx - an index into the corresponding heap)
+pub struct HashTableItem<ID: KeyType> {
+ hash: u64,
+ pub id: ID,
+ pub heap_idx: usize,
+}
+
+/// A custom wrapper around `hashbrown::RawTable` that:
+/// 1. limits the number of entries to the top K
+/// 2. Allocates a capacity greater than top K to maintain a low-fill factor
and prevent resizing
+/// 3. Tracks indexes to allow corresponding heap to refer to entries by index
vs hash
+/// 4. Catches resize events to allow the corresponding heap to update it's
indexes
+struct TopKHashTable<ID: KeyType> {
+ map: RawTable<HashTableItem<ID>>,
+ limit: usize,
+}
+
+/// An interface to hide the generic type signature of TopKHashTable behind
arrow arrays
+pub trait ArrowHashTable {
+ fn set_batch(&mut self, ids: ArrayRef);
+ fn len(&self) -> usize;
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable from binary
heap
+ // Soundness: the caller must provide valid indexes
+ unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable from binary
heap
+ // Soundness: the caller must provide a valid index
+ unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
+ fn drain(&mut self) -> (ArrayRef, Vec<usize>);
+
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable from binary
heap
+ // Soundness: the caller must provide valid indexes
+ unsafe fn find_or_insert(
+ &mut self,
+ row_idx: usize,
+ replace_idx: usize,
+ map: &mut Vec<(usize, usize)>,
+ ) -> (usize, bool);
+}
+
+// An implementation of ArrowHashTable for String keys
+pub struct StringHashTable {
+ owned: ArrayRef,
+ map: TopKHashTable<Option<String>>,
+ rnd: RandomState,
+}
+
+// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key
+struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
+where
+ Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+{
+ owned: ArrayRef,
+ map: TopKHashTable<Option<VAL::Native>>,
+ rnd: RandomState,
+}
+
+impl StringHashTable {
+ pub fn new(limit: usize) -> Self {
+ let vals: Vec<&str> = Vec::new();
+ let owned = Arc::new(StringArray::from(vals));
+ Self {
+ owned,
+ map: TopKHashTable::new(limit, limit * 10),
+ rnd: ahash::RandomState::default(),
+ }
+ }
+}
+
+impl ArrowHashTable for StringHashTable {
+ fn set_batch(&mut self, ids: ArrayRef) {
+ self.owned = ids;
+ }
+
+ fn len(&self) -> usize {
+ self.map.len()
+ }
+
+ unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+ self.map.update_heap_idx(mapper);
+ }
+
+ unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+ self.map.heap_idx_at(map_idx)
+ }
+
+ fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
+ let mut rows = self.map.drain();
+ rows.sort_by(|a, b| a.0.comp(&b.0));
+ let (ids, heap_idxs): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
+ let ids = Arc::new(StringArray::from(ids));
+ (ids, heap_idxs)
+ }
+
+ unsafe fn find_or_insert(
+ &mut self,
+ row_idx: usize,
+ replace_idx: usize,
+ mapper: &mut Vec<(usize, usize)>,
+ ) -> (usize, bool) {
+ let ids = self
+ .owned
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .expect("StringArray required");
+ let id = if ids.is_null(row_idx) {
+ None
+ } else {
+ Some(ids.value(row_idx))
+ };
+
+ let hash = self.rnd.hash_one(id);
+ if let Some(map_idx) = self
+ .map
+ .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str()))
+ {
+ return (map_idx, false);
+ }
+
+ // we're full and this is a better value, so remove the worst
+ let heap_idx = self.map.remove_if_full(replace_idx);
+
+ // add the new group
+ let id = id.map(|id| id.to_string());
+ let map_idx = self.map.insert(hash, id, heap_idx, mapper);
+ (map_idx, true)
+ }
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
+where
+ Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+ Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
+{
+ pub fn new(limit: usize) -> Self {
+ let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
+ Self {
+ owned,
+ map: TopKHashTable::new(limit, limit * 10),
+ rnd: ahash::RandomState::default(),
+ }
+ }
+}
+
+impl<VAL: ArrowPrimitiveType> ArrowHashTable for PrimitiveHashTable<VAL>
+where
+ Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+ Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
+{
+ fn set_batch(&mut self, ids: ArrayRef) {
+ self.owned = ids;
+ }
+
+ fn len(&self) -> usize {
+ self.map.len()
+ }
+
+ unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+ self.map.update_heap_idx(mapper);
+ }
+
+ unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+ self.map.heap_idx_at(map_idx)
+ }
+
+ fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
+ let mut rows = self.map.drain();
+ rows.sort_by(|a, b| a.0.comp(&b.0));
+ let (ids, heap_idxs): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
+ let mut builder: PrimitiveBuilder<VAL> =
PrimitiveArray::builder(ids.len());
+ for id in ids.into_iter() {
+ match id {
+ None => builder.append_null(),
+ Some(id) => builder.append_value(id),
+ }
+ }
+ let ids = Arc::new(builder.finish());
+ (ids, heap_idxs)
+ }
+
+ unsafe fn find_or_insert(
+ &mut self,
+ row_idx: usize,
+ replace_idx: usize,
+ mapper: &mut Vec<(usize, usize)>,
+ ) -> (usize, bool) {
+ let ids = self.owned.as_primitive::<VAL>();
+ let id: Option<VAL::Native> = if ids.is_null(row_idx) {
+ None
+ } else {
+ Some(ids.value(row_idx))
+ };
+
+ let hash: u64 = id.hash(&self.rnd);
+ if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
+ return (map_idx, false);
+ }
+
+ // we're full and this is a better value, so remove the worst
+ let heap_idx = self.map.remove_if_full(replace_idx);
+
+ // add the new group
+ let map_idx = self.map.insert(hash, id, heap_idx, mapper);
+ (map_idx, true)
+ }
+}
+
+impl<ID: KeyType> TopKHashTable<ID> {
+ pub fn new(limit: usize, capacity: usize) -> Self {
+ Self {
+ map: RawTable::with_capacity(capacity),
+ limit,
+ }
+ }
+
+ pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) ->
Option<usize> {
+ let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable from
binary heap
+ // Soundness: getting the index of a bucket we just found
+ let idx = unsafe { self.map.bucket_index(&bucket) };
+ Some(idx)
+ }
+
+ pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+ let bucket = unsafe { self.map.bucket(map_idx) };
+ bucket.as_ref().heap_idx
+ }
+
+ pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
+ if self.map.len() >= self.limit {
+ self.map.erase(self.map.bucket(replace_idx));
+ 0 // if full, always replace top node
+ } else {
+ self.map.len() // if we're not full, always append to end
+ }
+ }
+
+ unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+ for (m, h) in mapper {
+ self.map.bucket(*m).as_mut().heap_idx = *h
+ }
+ }
+
+ pub fn insert(
+ &mut self,
+ hash: u64,
+ id: ID,
+ heap_idx: usize,
+ mapper: &mut Vec<(usize, usize)>,
+ ) -> usize {
+ let mi = HashTableItem::new(hash, id, heap_idx);
+ let bucket = self.map.try_insert_no_grow(hash, mi);
+ let bucket = match bucket {
+ Ok(bucket) => bucket,
+ Err(new_item) => {
+ let bucket = self.map.insert(hash, new_item, |mi| mi.hash);
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable
from binary heap
+ // Soundness: we're getting indexes of buckets, not
dereferencing them
+ unsafe {
+ for bucket in self.map.iter() {
+ let heap_idx = bucket.as_ref().heap_idx;
+ let map_idx = self.map.bucket_index(&bucket);
+ mapper.push((heap_idx, map_idx));
+ }
+ }
+ bucket
+ }
+ };
+ // JUSTIFICATION
+ // Benefit: ~15% speedup + required to index into RawTable from
binary heap
+ // Soundness: we're getting indexes of buckets, not dereferencing them
+ unsafe { self.map.bucket_index(&bucket) }
+ }
+
+ pub fn len(&self) -> usize {
+ self.map.len()
+ }
+
+ pub fn drain(&mut self) -> Vec<(ID, usize)> {
+ self.map.drain().map(|mi| (mi.id, mi.heap_idx)).collect()
+ }
+}
+
+impl<ID: KeyType> HashTableItem<ID> {
+ pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self {
+ Self { hash, id, heap_idx }
+ }
+}
+
+#[allow(dead_code)]
Review Comment:
Why `dead_code`? Perhaps we can remove this prior to merging?
##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -2291,7 +2291,131 @@ false
true
NULL
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4),
input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
Review Comment:
I do think it is important to have an end-to-end test that actually limits
the number of values coming out. As I mentioned here, I think this test only
has 4 distinct groups, so a `limit 4` doesn't actually do any limiting.
##########
datafusion/common/src/config.rs:
##########
@@ -380,6 +380,10 @@ config_namespace! {
/// repartitioning to increase parallelism to leverage more CPU cores
pub enable_round_robin_repartition: bool, default = true
+ /// When set to true, the optimizer will attempt to perform limit
operations
Review Comment:
👍 for including an escape hatch
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]