Rich-T-kid commented on code in PR #22983:
URL: https://github.com/apache/datafusion/pull/22983#discussion_r3429146634


##########
datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs:
##########
@@ -0,0 +1,894 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::mem::size_of;
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, 
Int32Array,
+    Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array,
+    UInt32Array, UInt64Array,
+};
+use arrow::datatypes::{
+    ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, 
SchemaRef,
+    UInt8Type, UInt16Type, UInt32Type, UInt64Type,
+};
+use arrow::downcast_dictionary_array;
+use datafusion_common::hash_utils::{RandomState, combine_hashes, 
create_hashes};
+use datafusion_common::{Result, internal_datafusion_err};
+use datafusion_execution::memory_pool::proxy::HashTableAllocExt;
+use datafusion_expr::EmitTo;
+use hashbrown::hash_table::HashTable;
+
+use crate::aggregates::group_values::GroupValues;
+
+/// Caches the hashes for one dictionary column's values array.
+///
+/// The hashes are rebuilt only when the `Arc` pointer changes, i.e. when a
+/// new values array arrives (see `ColumnCache::update`).
+struct ColumnCache {
+    /// The values array the cached hashes were computed from. Holding the
+    /// `Arc` keeps it alive; it is compared with `Arc::ptr_eq` to detect
+    /// staleness.
+    values: ArrayRef,
+    /// `value_hashes[k]` = hash of the value at dictionary index `k`.
+    value_hashes: Vec<u64>,
+}
+
+impl ColumnCache {
+    /// Creates a cache that holds a zero-length `NullArray` placeholder and
+    /// no hashes.
+    fn empty() -> Self {
+        Self {
+            values: Arc::new(NullArray::new(0)),
+            value_hashes: Vec::new(),
+        }
+    }
+
+    /// Recomputes the cached hashes for `new_values`.
+    ///
+    /// Does nothing when `new_values` is the same `Arc` allocation as the
+    /// array the cache was last built from.
+    ///
+    /// # Errors
+    /// Propagates any error returned by `create_hashes`.
+    fn update(&mut self, new_values: ArrayRef, random_state: &RandomState) -> Result<()> {
+        if Arc::ptr_eq(&new_values, &self.values) {
+            return Ok(());
+        }
+        // Reuse the allocation; capacity only grows when a larger values
+        // array arrives.
+        self.value_hashes.clear();
+        self.value_hashes.resize(new_values.len(), 0u64);
+        // `Arc::clone` makes the cheap ref-count bump explicit (the file's
+        // convention, and what clippy's `clone_on_ref_ptr` asks for).
+        create_hashes(
+            &[Arc::clone(&new_values)],
+            random_state,
+            &mut self.value_hashes,
+        )?;
+        self.values = new_values;
+        Ok(())
+    }
+
+    /// Bytes allocated for the hash buffer.
+    ///
+    /// Uses `capacity()` rather than `len()`: `update` and `clear_shrink`
+    /// deliberately keep the allocation around, and that retained memory
+    /// must still be reported.
+    fn size(&self) -> usize {
+        self.value_hashes.capacity() * size_of::<u64>()
+    }
+
+    /// Drops the cached values array and hashes, shrinking the hash buffer's
+    /// capacity towards `shrink_to` elements.
+    fn clear_shrink(&mut self, shrink_to: usize) {
+        self.values = Arc::new(NullArray::new(0));
+        self.value_hashes.clear();
+        self.value_hashes.shrink_to(shrink_to);
+    }
+}
+
+/// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns.
+///
+/// Every distinct group is stored as one encoded byte row inside
+/// `row_buffer`; `map` finds a group from the combined hash of its
+/// per-column value hashes.
+pub struct GroupDictionaryColumn {
+    /// Schema of the group-by columns; validated in `new` to contain only
+    /// dictionary fields.
+    schema: SchemaRef,
+    /// One value-hash cache per group-by column, indexed by column position.
+    col_caches: Vec<ColumnCache>,
+    /// `(row_hash, group_id)`.  Multiple entries may share the same hash
+    /// value; byte-level comparison is used to resolve collisions.
+    map: HashTable<(u64, usize)>,
+    /// Tracked allocation size of `map` in bytes, updated on every insert and
+    /// shrink.
+    map_size: usize,
+    /// All group rows packed back-to-back into a single contiguous buffer.
+    ///
+    /// CSR-style layout: `row_offsets[g]` is the start of group `g` and
+    /// `row_offsets[g+1]` is its end.  The last group has no `g+1` entry; its
+    /// end is `row_buffer.len()`.
+    row_buffer: Vec<u8>,
+    /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`.
+    row_offsets: Vec<usize>,
+    /// Reused scratch buffer for encoding the current row.
+    row_scratch: Vec<u8>,
+    /// Decodes stored rows back into arrays when groups are emitted.
+    row_decoder: RowSetDecoder,
+    /// Hash seed handed to the column caches when hashing dictionary values.
+    random_state: RandomState,
+}
+
+/// Reports whether `schema` consists solely of `DataType::Dictionary` fields.
+///
+/// A schema without any fields yields `true`.
+pub fn all_dictionary_schema(schema: &arrow::datatypes::Schema) -> bool {
+    for field in schema.fields() {
+        if !matches!(field.data_type(), DataType::Dictionary(_, _)) {
+            return false;
+        }
+    }
+    true
+}
+
+/// Tells whether `data_type` is a dictionary value type this implementation
+/// accepts: `Utf8`, or a `List` whose element type is `Utf8`.
+fn is_supported_value_type(data_type: &DataType) -> bool {
+    match data_type {
+        DataType::Utf8 => true,
+        DataType::List(item) => item.data_type() == &DataType::Utf8,
+        _ => false,
+    }
+}
+
+impl GroupDictionaryColumn {
+    /// Builds an empty group table for the group-by columns described by
+    /// `schema`.
+    ///
+    /// # Errors
+    /// Returns an internal error when `schema` has fewer than two fields,
+    /// when a field is not `DataType::Dictionary`, or when a dictionary's
+    /// value type is rejected by `is_supported_value_type`.
+    pub fn new(schema: SchemaRef) -> Result<Self> {
+        let n_cols = schema.fields().len();
+        if n_cols < 2 {
+            return Err(internal_datafusion_err!(
+                "GroupDictionaryColumn requires at least 2 columns, got {}",
+                n_cols
+            ));
+        }
+        for field in schema.fields() {
+            match field.data_type() {
+                DataType::Dictionary(_, value_type) => {
+                    if !is_supported_value_type(value_type) {
+                        return Err(internal_datafusion_err!(
+                            "GroupDictionaryColumn: unsupported dictionary value type \
+                             '{}' in column '{}'",
+                            value_type,
+                            field.name()
+                        ));
+                    }
+                }
+                _ => {
+                    return Err(internal_datafusion_err!(
+                        "GroupDictionaryColumn requires all columns to be Dictionary, \
+                         but '{}' has type {}",
+                        field.name(),
+                        field.data_type()
+                    ));
+                }
+            }
+        }
+        let row_decoder = RowSetDecoder::new(Arc::clone(&schema));
+        let map: HashTable<(u64, usize)> = HashTable::with_capacity(128);
+        // The table is pre-allocated, so account for that memory up front
+        // using the same capacity-based formula as `clear_shrink`, instead
+        // of reporting 0 bytes for a table that already owns an allocation.
+        let map_size = map.capacity() * size_of::<(u64, usize)>();
+        Ok(Self {
+            schema,
+            col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(),
+            map,
+            map_size,
+            row_buffer: Vec::new(),
+            row_offsets: Vec::new(),
+            row_scratch: Vec::new(),
+            row_decoder,
+            random_state: crate::aggregates::AGGREGATION_HASH_SEED,
+        })
+    }
+}
+
+/// Returns a new handle to the values array of the dictionary column `col`.
+///
+/// # Panics
+/// Hits `unreachable!` when `col` is not a dictionary array; callers rely on
+/// the schema validation performed in `GroupDictionaryColumn::new`.
+fn dict_values_array(col: &dyn Array) -> ArrayRef {
+    downcast_dictionary_array!(
+        // `Arc::clone` spells out that only the ref-count is bumped.
+        col => Arc::clone(col.values()),
+        _ => unreachable!("schema validated in GroupDictionaryColumn::new")
+    )
+}
+
+// The iterator is boxed because every dictionary key type (signed or
+// unsigned, 8 to 64 bits) yields a different concrete iterator type.
+/// Yields, row by row, the dictionary key of `col` as a `usize`, or `None`
+/// where the key slot is null.
+fn fill_keys(col: &dyn Array) -> Box<dyn Iterator<Item = Option<usize>> + '_> {
+    downcast_dictionary_array!(
+        col => Box::new(
+            col.keys()
+                .iter()
+                .map(|slot| slot.map(|native| native.as_usize())),
+        ),
+        _ => unreachable!("schema validated in GroupDictionaryColumn::new")
+    )
+}
+
+impl GroupValues for GroupDictionaryColumn {
+    /// Assigns a group id to every row of `cols` and writes the ids to
+    /// `groups` (which is cleared first).
+    ///
+    /// A row whose encoded bytes equal an already-stored group reuses that
+    /// group's id; any other row is appended as a new group.
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        debug_assert_eq!(cols.len(), self.schema.fields().len());
+        groups.clear();
+
+        if cols.is_empty() || cols[0].is_empty() {
+            return Ok(());
+        }
+        let n_rows = cols[0].len();
+
+        // Refresh the per-column value hashes; this is a no-op for a column
+        // whose values `Arc` has not changed since the previous batch.
+        for (col_idx, col) in cols.iter().enumerate() {
+            self.col_caches[col_idx]
+                .update(dict_values_array(col.as_ref()), &self.random_state)?;
+        }
+
+        // Downcast once per column; advance with .next() per row to avoid a
+        // per-row downcast.
+        let mut key_iters: Vec<_> =
+            cols.iter().map(|col| fill_keys(col.as_ref())).collect();
+
+        groups.reserve(n_rows);
+
+        for _row in 0..n_rows {
+            let mut combined_hash = 0u64;
+            self.row_scratch.clear();
+
+            for (key_iter, cache) in key_iters.iter_mut().zip(&self.col_caches) {
+                let key = key_iter
+                    .next()
+                    .expect("all group columns have the same number of rows");
+                // A null key contributes a hash of 0.
+                let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]);
+                combined_hash = combine_hashes(combined_hash, value_hash);
+                encode_value(key, cache.values.as_ref(), &mut self.row_scratch);
+            }
+
+            let found = {
+                let row_scratch = self.row_scratch.as_slice();
+                let row_buffer = self.row_buffer.as_slice();
+                let row_offsets = self.row_offsets.as_slice();
+                self.map
+                    .find(combined_hash, |&(stored_hash, group_id)| {
+                        stored_hash == combined_hash && {
+                            // last group has no g+1 entry
+                            let end = row_offsets
+                                .get(group_id + 1)
+                                .copied()
+                                .unwrap_or(row_buffer.len());
+                            row_buffer[row_offsets[group_id]..end] == *row_scratch
+                        }
+                    })
+                    .map(|&(_, group_id)| group_id)
+            };
+
+            let group_id = match found {
+                Some(existing_id) => existing_id,
+                None => {
+                    let new_id = self.row_offsets.len();
+                    self.row_offsets.push(self.row_buffer.len());
+                    self.row_buffer.extend_from_slice(&self.row_scratch);
+                    self.map.insert_accounted(
+                        (combined_hash, new_id),
+                        |(stored_hash, _)| *stored_hash,
+                        &mut self.map_size,
+                    );
+                    new_id
+                }
+            };
+
+            groups.push(group_id);
+        }
+
+        Ok(())
+    }
+
+    /// Bytes tracked for the hash table, the row storage, the scratch buffer
+    /// and the per-column hash caches.
+    ///
+    /// The vectors are measured by `capacity()`, not `len()`, so memory that
+    /// stays allocated after a `clear()` is still reported.
+    fn size(&self) -> usize {
+        let cache_bytes: usize = self.col_caches.iter().map(|c| c.size()).sum();
+        self.map_size
+            + self.row_buffer.capacity()
+            + self.row_offsets.capacity() * size_of::<usize>()
+            + self.row_scratch.capacity()
+            + cache_bytes
+    }
+
+    /// Returns `true` when no group has been stored.
+    fn is_empty(&self) -> bool {
+        self.row_offsets.is_empty()
+    }
+
+    /// Number of groups currently stored.
+    fn len(&self) -> usize {
+        self.row_offsets.len()
+    }
+
+    /// Emits the first `n` groups (`EmitTo::First(n)`, capped at the number
+    /// of stored groups) or all of them (`EmitTo::All`) as one dictionary
+    /// array per group-by column, and removes them from the table.
+    ///
+    /// An empty table takes the same path as a non-empty one, so the
+    /// decoder output is always wrapped into dictionary arrays and the
+    /// returned columns keep the dictionary types of the schema even when
+    /// there are no groups.
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let n_total = self.row_offsets.len();
+        let n_emit = match emit_to {
+            EmitTo::All => n_total,
+            EmitTo::First(n) => n.min(n_total),
+        };
+
+        for row_idx in 0..n_emit {
+            let start = self.row_offsets[row_idx];
+            // The last stored group has no following offset; it ends at the
+            // end of the buffer.
+            let end = self
+                .row_offsets
+                .get(row_idx + 1)
+                .copied()
+                .unwrap_or(self.row_buffer.len());
+            self.row_decoder.decode(&self.row_buffer[start..end]);
+        }
+        let inner = self.row_decoder.finish();
+        let arrays: Vec<ArrayRef> = inner
+            .into_iter()
+            .zip(self.schema.fields())
+            .map(|(values, field)| match field.data_type() {
+                DataType::Dictionary(key_type, _) => wrap_as_dictionary(
+                    values,
+                    make_sequential_keys(n_emit, key_type),
+                    key_type,
+                ),
+                _ => unreachable!("schema validated in GroupDictionaryColumn::new"),
+            })
+            .collect();
+
+        if n_emit == n_total {
+            self.row_buffer.clear();
+            self.row_offsets.clear();
+            self.map.clear();
+            // `clear` keeps the table's allocation, so keep accounting for
+            // it (same formula as `clear_shrink`) instead of resetting to 0.
+            self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
+        } else {
+            let retain_start = self.row_offsets[n_emit];
+            self.row_offsets.drain(0..n_emit);
+            for offset in &mut self.row_offsets {
+                *offset -= retain_start;
+            }
+            self.row_buffer.drain(0..retain_start);
+            // TODO: avoid this full pass over the table if possible
+            // (VecDeque?).
+            // Shift the remaining group ids in place; `retain` gives `&mut`
+            // access, so no rehashing occurs.
+            self.map.retain(|(_, gid)| {
+                if *gid < n_emit {
+                    return false;
+                }
+                *gid -= n_emit;
+                true
+            });
+        }
+
+        Ok(arrays)
+    }
+
+    /// Removes all groups and shrinks the hash table, the offsets vector and
+    /// the per-column caches towards `num_rows` entries.
+    fn clear_shrink(&mut self, num_rows: usize) {
+        self.map.clear();
+        // The table is empty at this point, so the hasher closure is never
+        // invoked.
+        self.map.shrink_to(num_rows, |_| 0);
+        self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
+        self.row_buffer.clear();
+        self.row_offsets.clear();
+        self.row_offsets.shrink_to(num_rows);
+        for cache in &mut self.col_caches {
+            cache.clear_shrink(num_rows);
+        }
+    }
+}
+
+// ── encoding / decoding 
───────────────────────────────────────────────────────
+
+/// Wire format per column:
+///   null:              `[0x00]`
+///   non-null scalar:   `[0x01][len: u32 LE][utf8_bytes…]`
+///   non-null list:     `[0x01][content_len: u32 LE][n: u32 LE][elem…]`
+///                      where each elem is `[0x00]` (null) or `[0x01][len: 
u32 LE][utf8_bytes…]`
+fn encode_value(key: Option<usize>, values: &dyn Array, buf: &mut Vec<u8>) {
+    let key_idx = match key {
+        None => {
+            buf.push(0);
+            return;
+        }
+        Some(k) => k,
+    };
+    if values.is_null(key_idx) {
+        buf.push(0);
+        return;
+    }
+    buf.push(1);
+    match values.data_type() {
+        DataType::Utf8 => {
+            let bytes = values.as_string::<i32>().value(key_idx).as_bytes();

Review Comment:
   This cast isn't needed. `.value().as_bytes()` should be fine, similar to 
#21765 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to