Rich-T-kid commented on code in PR #22983: URL: https://github.com/apache/datafusion/pull/22983#discussion_r3428439590
########## datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs: ########## @@ -0,0 +1,894 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, Int32Array, + Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array, + UInt32Array, UInt64Array, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, SchemaRef, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow::downcast_dictionary_array; +use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_execution::memory_pool::proxy::HashTableAllocExt; +use datafusion_expr::EmitTo; +use hashbrown::hash_table::HashTable; + +use crate::aggregates::group_values::GroupValues; + +/// Caches the hashes for one dictionary column's values array. +/// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). +struct ColumnCache { + /// Keeps the values `Arc` alive and is compared with `Arc::ptr_eq` to detect staleness. 
+ values: ArrayRef, + /// `value_hashes[k]` = hash of the value at dictionary index `k`. + value_hashes: Vec<u64>, +} + +impl ColumnCache { + fn empty() -> Self { + Self { + values: Arc::new(NullArray::new(0)), + value_hashes: vec![], + } + } + + fn update(&mut self, new_values: ArrayRef, random_state: &RandomState) -> Result<()> { + if Arc::ptr_eq(&new_values, &self.values) { + return Ok(()); + } + let num_values = new_values.len(); + // Reuse the allocation; only grows capacity when a larger values array arrives. + self.value_hashes.clear(); + self.value_hashes.resize(num_values, 0u64); + create_hashes(&[new_values.clone()], random_state, &mut self.value_hashes)?; + self.values = new_values; + Ok(()) + } + + fn size(&self) -> usize { + self.value_hashes.len() * size_of::<u64>() + } + + fn clear_shrink(&mut self, shrink_to: usize) { + self.values = Arc::new(NullArray::new(0)); + self.value_hashes.clear(); + self.value_hashes.shrink_to(shrink_to); + } +} + +/// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. +pub struct GroupDictionaryColumn { + schema: SchemaRef, + col_caches: Vec<ColumnCache>, + /// `(row_hash, group_id)`. Multiple entries may share the same hash value; + /// byte-level comparison is used to resolve collisions. + map: HashTable<(u64, usize)>, + /// Tracked allocation size of `map` in bytes, updated on every insert and shrink. + map_size: usize, + /// All group rows packed back-to-back into a single contiguous buffer. + /// + /// CSR-style layout: `row_offsets[g]` is the start of group `g` and + /// `row_offsets[g+1]` is its end. The last group has no `g+1` entry; its + /// end is `row_buffer.len()`. + row_buffer: Vec<u8>, + /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`. + row_offsets: Vec<usize>, + /// Reused scratch buffer for encoding the current row. 
+ row_scratch: Vec<u8>, + row_decoder: RowSetDecoder, + random_state: RandomState, +} + +/// Returns `true` when every field in `schema` is `DataType::Dictionary`. +pub fn all_dictionary_schema(schema: &arrow::datatypes::Schema) -> bool { + schema + .fields() + .iter() + .all(|field| matches!(field.data_type(), DataType::Dictionary(_, _))) +} + +fn is_supported_value_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Utf8) + || matches!(data_type, DataType::List(f) if f.data_type() == &DataType::Utf8) +} + +impl GroupDictionaryColumn { + pub fn new(schema: SchemaRef) -> Result<Self> { + if schema.fields().len() < 2 { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires at least 2 columns, got {}", + schema.fields().len() + )); + } + for field in schema.fields() { + match field.data_type() { + DataType::Dictionary(_, value_type) => { + if !is_supported_value_type(value_type) { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn: unsupported dictionary value type \ + '{}' in column '{}'", + value_type, + field.name() + )); + } + } + _ => { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires all columns to be Dictionary, \ + but '{}' has type {}", + field.name(), + field.data_type() + )); + } + } + } + let n_cols = schema.fields().len(); + let row_decoder = RowSetDecoder::new(Arc::clone(&schema)); + Ok(Self { + schema, + col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(), + map: HashTable::with_capacity(128), + map_size: 0, + row_buffer: Vec::new(), + row_offsets: Vec::new(), + row_scratch: Vec::new(), + row_decoder, + random_state: crate::aggregates::AGGREGATION_HASH_SEED, + }) + } +} + +fn dict_values_array(col: &dyn Array) -> ArrayRef { + downcast_dictionary_array!( + col => col.values().clone(), + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +// Box is required: different key widths (Int8/Int16/Int32/Int64) produce different concrete iterator types. 
+fn fill_keys(col: &dyn Array) -> Box<dyn Iterator<Item = Option<usize>> + '_> { + downcast_dictionary_array!( + col => { + let keys = col.keys(); + Box::new((0..keys.len()).map(move |row_idx| { + if keys.is_valid(row_idx) { + Some(keys.value(row_idx).as_usize()) + } else { + None Review Comment: This looks a bit odd: the `.map()` is called on the iterator that the range `0..keys.len()` produces, and the closure then indexes back into `keys` by `row_idx`, rather than being called on an iterator over `keys` itself. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
