This is an automated email from the ASF dual-hosted git repository.
shauryachats pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 065d50c3f77 Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns (#18760)
065d50c3f77 is described below
commit 065d50c3f77a4b1a5b2571ea92a1aa5c71036ba3
Author: tarun11Mavani <[email protected]>
AuthorDate: Tue Jun 30 04:00:44 2026 +0530
Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns (#18760)
* Extend FUNNEL_COUNT to support multiple CORRELATE_BY columns
Enable funnel analysis that tracks users through steps within a composite
key (e.g., per user per device category) by accepting multiple columns in
CORRELATE_BY(col1, col2, ...).
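To show what a composite key changes, here is a minimal standalone Java sketch. It is not Pinot code, and `CompositeFunnelSketch`, `Row`, and `Key` are hypothetical names. It assumes that the count reported for step n is the number of distinct correlation keys that matched every step from the first through n, and it computes that as a successive set intersection.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class CompositeFunnelSketch {
  // Hypothetical row shape: two correlate-by values plus one match flag per funnel step.
  record Row(String user, String device, boolean[] steps) {}

  // Composite correlation key; records provide value-based equals/hashCode.
  record Key(String user, String device) {}

  // Assumed semantics: counts[n] = number of distinct keys that matched steps 0..n.
  static long[] funnelCount(List<Row> rows, int numSteps) {
    if (numSteps == 0) {
      return new long[0];
    }
    List<Set<Key>> stepSets = new ArrayList<>();
    for (int n = 0; n < numSteps; n++) {
      stepSets.add(new HashSet<>());
    }
    for (Row row : rows) {
      Key key = new Key(row.user(), row.device());
      for (int n = 0; n < numSteps; n++) {
        if (row.steps()[n]) {
          stepSets.get(n).add(key);
        }
      }
    }
    long[] counts = new long[numSteps];
    Set<Key> surviving = new HashSet<>(stepSets.get(0));
    counts[0] = surviving.size();
    for (int n = 1; n < numSteps; n++) {
      surviving.retainAll(stepSets.get(n)); // keep only keys that also reached step n
      counts[n] = surviving.size();
    }
    return counts;
  }

  public static void main(String[] args) {
    List<Row> rows = List.of(
        new Row("alice", "mobile", new boolean[]{true, false}),
        new Row("alice", "desktop", new boolean[]{false, true}),
        new Row("bob", "mobile", new boolean[]{true, false}),
        new Row("bob", "mobile", new boolean[]{false, true}));
    System.out.println(Arrays.toString(funnelCount(rows, 2))); // [2, 1]
  }
}
```

Keyed on the user alone, both users would reach step 2. With the (user, device) key, alice's mobile start and desktop finish are different entities, so only bob's mobile journey completes the funnel.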
The single-key path is preserved as a zero-overhead fast path with separate
addSingleKey/addMultiKey abstract methods and dedicated aggregation loops,
ensuring no regression for existing single-column queries.
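A stripped-down sketch of that dispatch, assuming the dictionary IDs have already been read into `int[]` arrays. `DispatchSketch` and `StepSink` are hypothetical stand-ins for the strategy class and its `add`/`addMultiKey` hooks. The key-count branch runs once per block, so the single-key loop stays a plain array walk. The multi-key loop copies each row into one reused buffer, which is why a sink must not keep a reference to it.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class DispatchSketch {
  // Hypothetical callbacks standing in for add(...) and addMultiKey(...).
  interface StepSink {
    void addSingle(int step, int dictId);

    // rowDictIds is reused across rows: implementations must copy it to retain it.
    void addMulti(int step, int[] rowDictIds);
  }

  // correlationIds[k][i] = dict ID of key column k at row i; steps[n][i] > 0 means row i matched step n.
  static void aggregate(int length, int[][] correlationIds, int[][] steps, StepSink sink) {
    if (correlationIds.length == 1) {
      // Fast path: no per-row copying or tuple handling.
      int[] ids = correlationIds[0];
      for (int i = 0; i < length; i++) {
        for (int n = 0; n < steps.length; n++) {
          if (steps[n][i] > 0) {
            sink.addSingle(n, ids[i]);
          }
        }
      }
      return;
    }
    int[] rowDictIds = new int[correlationIds.length]; // allocated once per block, not per row
    for (int i = 0; i < length; i++) {
      for (int k = 0; k < correlationIds.length; k++) {
        rowDictIds[k] = correlationIds[k][i];
      }
      for (int n = 0; n < steps.length; n++) {
        if (steps[n][i] > 0) {
          sink.addMulti(n, rowDictIds);
        }
      }
    }
  }

  public static void main(String[] args) {
    List<String> events = new ArrayList<>();
    StepSink sink = new StepSink() {
      public void addSingle(int step, int dictId) {
        events.add(step + ":" + dictId);
      }

      public void addMulti(int step, int[] ids) {
        events.add(step + ":" + Arrays.toString(ids)); // copies the values via toString
      }
    };
    aggregate(2, new int[][]{{0, 1}, {5, 6}}, new int[][]{{1, 0}, {1, 1}}, sink);
    System.out.println(events); // [0:[0, 5], 1:[0, 5], 1:[1, 6]]
  }
}
```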
Multi-key composite ID mapping uses stride-based arithmetic when the product
of dictionary sizes fits in int, with a HashMap fallback for large key
spaces.
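A minimal Java sketch of the composite-ID scheme, assuming plain dictionary sizes (`int[] dictSizes`) in place of Pinot `Dictionary` objects. The class name `CompositeIdSketch` is hypothetical. Row-major strides give each tuple a unique ID in [0, product) when the product fits in an `int`. Otherwise, IDs are handed out sequentially on first sight, and a reverse list supports decoding.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class CompositeIdSketch {
  private final int numKeys;
  private final int[] strides;                     // set when the product of sizes fits in an int
  private final Map<List<Integer>, Integer> idMap; // fallback: tuple -> sequential id
  private final List<int[]> reverse;               // fallback: id -> tuple

  CompositeIdSketch(int[] dictSizes) {
    numKeys = dictSizes.length;
    long total = 1;
    for (int size : dictSizes) {
      total *= size;
      if (total > Integer.MAX_VALUE) {
        break; // stop early so the long product cannot overflow
      }
    }
    if (total <= Integer.MAX_VALUE) {
      // Row-major strides: the last column varies fastest.
      strides = new int[numKeys];
      strides[numKeys - 1] = 1;
      for (int k = numKeys - 2; k >= 0; k--) {
        strides[k] = strides[k + 1] * dictSizes[k + 1];
      }
      idMap = null;
      reverse = null;
    } else {
      strides = null;
      idMap = new HashMap<>();
      reverse = new ArrayList<>();
    }
  }

  boolean usesHashMap() {
    return idMap != null;
  }

  int encode(int[] dictIds) {
    if (strides != null) {
      int id = 0;
      for (int k = 0; k < numKeys; k++) {
        id += dictIds[k] * strides[k];
      }
      return id;
    }
    // Boxed key for brevity; the diff instead reuses a mutable IntArrayList for lookups.
    List<Integer> key = new ArrayList<>(numKeys);
    for (int dictId : dictIds) {
      key.add(dictId);
    }
    return idMap.computeIfAbsent(key, unused -> {
      reverse.add(dictIds.clone());
      return reverse.size() - 1;
    });
  }

  int[] decode(int compositeId) {
    if (strides == null) {
      return reverse.get(compositeId).clone();
    }
    int[] out = new int[numKeys];
    int remaining = compositeId;
    for (int k = 0; k < numKeys - 1; k++) {
      out[k] = remaining / strides[k];
      remaining %= strides[k];
    }
    out[numKeys - 1] = remaining;
    return out;
  }

  public static void main(String[] args) {
    CompositeIdSketch small = new CompositeIdSketch(new int[]{3, 4}); // strides [4, 1]
    int id = small.encode(new int[]{2, 3});
    System.out.println(id + " " + Arrays.toString(small.decode(id))); // 11 [2, 3]

    CompositeIdSketch large = new CompositeIdSketch(new int[]{100_000, 100_000}); // 10^10 > Integer.MAX_VALUE
    System.out.println(large.usesHashMap() + " " + large.encode(new int[]{99_999, 5})
        + " " + large.encode(new int[]{7, 7}) + " " + large.encode(new int[]{99_999, 5})); // true 0 1 0
  }
}
```

The stride path needs no per-segment state beyond the strides themselves. The fallback trades memory (one map entry and one reverse entry per distinct tuple actually seen) for correctness when the full key space is too large to number densely.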
Co-authored-by: Cursor <[email protected]>
* Remove benchmark file from PR
Benchmark was used for local validation only; not needed in the PR.
Co-authored-by: Cursor <[email protected]>
* Preserve original add() signature for backward compatibility
Keep the original `add(Dictionary, A, int, int)` abstract method unchanged.
The new multi-key method is added as `addMultiKey(A, int, Dictionary[], int[])`.
Co-authored-by: Cursor <[email protected]>
* Add tests for DictIdsWrapper HashMap fallback path and fix SortedAggregationResult double-count
- Add DictIdsWrapperTest covering the HashMap fallback path (large-cardinality
composite keys where the product of dictionary sizes exceeds Integer.MAX_VALUE):
path selection, sequential ID assignment, same-key idempotency,
key-order sensitivity, and round-trip for 2- and 3-column keys.
Also covers stride-path reverseCompositeId round-trip.
Add isHashMapPath() predicate to DictIdsWrapper for test introspection
(avoids widening _strides visibility).
- Add SortedAggregationResultTest with multi-key extraction scenarios.
- Fix SortedAggregationResult.extractResult(): clear _secondaryKeySteps after
flushMultiKeyGroup() so a second call (defensive) returns zeros rather than
double-counting the last open primary group.
* Clarify hash approximation in BitmapResultExtractionStrategy Javadoc
Add method-level doc on convertCompositeToValueBitmap linking the
multi-key .hashCode() usage to the existing single-key non-INT
approximation in convertToValueBitmap.
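The combining loop added in `convertCompositeToValueBitmap` is the standard `31 * hash + x` polynomial hash, so distinct composite keys can map to the same 32-bit value. The small standalone sketch below uses a hypothetical class name, and a `HashSet` stands in for the `RoaringBitmap`. It reproduces the loop for two INT columns and shows two different tuples collapsing into one entry.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class CompositeHashCollisionSketch {
  // Same 31 * hash + valueHash loop as convertCompositeToValueBitmap, specialized to INT
  // columns, where Integer.hashCode(v) == v.
  static int combine(int... values) {
    int hash = 1;
    for (int v : values) {
      hash = 31 * hash + Integer.hashCode(v);
    }
    return hash;
  }

  public static void main(String[] args) {
    int first = combine(0, 31);  // 31 * (31 * 1 + 0) + 31 = 992
    int second = combine(1, 0);  // 31 * (31 * 1 + 1) + 0 = 992
    // A HashSet stands in for the per-step bitmap of hashes.
    Set<Integer> step = new HashSet<>(List.of(first, second));
    System.out.println(first + " " + second + " -> " + step.size()); // 992 992 -> 1
  }
}
```

Within one step, a collision like this lowers the distinct count. If hashed steps are later intersected, a collision between an entity in one step and a different entity in another step could also make them look like the same key.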
* refactor(funnel): reduce allocations in sorted multi-key path and bitmap extraction
SortedAggregationResult: replace HashMap<IntArrayList, boolean[]> with
pre-allocated flat arrays and linear scan. Zero allocations in the hot
loop for typical workloads (1-5 secondary key combos per primary group).
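A hypothetical sketch of such a flat-array layout follows. The class and method names are assumptions and are not the code in `SortedAggregationResult`. Secondary-key tuples and their per-step flags live in two pre-allocated primitive arrays. Lookups are a linear scan, which is cheap when a primary group has only a handful of combinations, and clearing is O(1) because it only resets a size counter. The sketch also assumes that a combination counts toward step n only if it reached every step up to n.

```java
import java.util.Arrays;

class SecondaryComboBuffer {
  private final int numKeys;
  private final int numSteps;
  private int[] keys;      // combo i occupies keys[i * numKeys, (i + 1) * numKeys)
  private boolean[] steps; // combo i occupies steps[i * numSteps, (i + 1) * numSteps)
  private int size;

  SecondaryComboBuffer(int numKeys, int numSteps, int initialCapacity) {
    this.numKeys = numKeys;
    this.numSteps = numSteps;
    this.keys = new int[initialCapacity * numKeys];
    this.steps = new boolean[initialCapacity * numSteps];
  }

  /** Records that the tuple reached a step; allocates only when the arrays must grow. */
  void mark(int[] dictIds, int step) {
    int idx = indexOf(dictIds);
    if (idx < 0) {
      if ((size + 1) * numKeys > keys.length) {
        int newCapacity = Math.max(1, size * 2);
        keys = Arrays.copyOf(keys, newCapacity * numKeys);
        steps = Arrays.copyOf(steps, newCapacity * numSteps);
      }
      idx = size++;
      System.arraycopy(dictIds, 0, keys, idx * numKeys, numKeys);
      Arrays.fill(steps, idx * numSteps, (idx + 1) * numSteps, false); // reused slot may hold stale flags
    }
    steps[idx * numSteps + step] = true;
  }

  private int indexOf(int[] dictIds) {
    outer:
    for (int i = 0; i < size; i++) {
      int base = i * numKeys;
      for (int k = 0; k < numKeys; k++) {
        if (keys[base + k] != dictIds[k]) {
          continue outer;
        }
      }
      return i;
    }
    return -1;
  }

  /** Adds each tuple to stepCounts[n] if it reached steps 0..n, then empties the buffer. */
  void flushInto(long[] stepCounts) {
    for (int i = 0; i < size; i++) {
      int base = i * numSteps;
      for (int n = 0; n < numSteps && steps[base + n]; n++) {
        stepCounts[n]++;
      }
    }
    size = 0; // O(1) reset; a second flush without new marks adds nothing
  }

  public static void main(String[] args) {
    SecondaryComboBuffer buffer = new SecondaryComboBuffer(2, 2, 4);
    buffer.mark(new int[]{1, 2}, 0);
    buffer.mark(new int[]{1, 2}, 1);
    buffer.mark(new int[]{3, 4}, 1); // reached step 1 without step 0: not counted
    long[] counts = new long[2];
    buffer.flushInto(counts);
    buffer.flushInto(counts);
    System.out.println(Arrays.toString(counts)); // [1, 1]
  }
}
```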
BitmapResultExtractionStrategy: replace toCompositeString().hashCode()
with direct type-aware hash combining, avoiding StringBuilder/String
allocation per composite ID during extraction.
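Converting to value hashes is what makes per-segment results comparable, because dictionary IDs, and therefore composite IDs, are segment-local. The sketch below is a simplified model under assumed names (`ValueHashSketch`, with String-only dictionaries held as sorted arrays). The same (user, device) tuple gets different dictionary IDs in two segments but the same value-based hash. That hash is computed with the same combining loop shown in the diff, without building a composite string per ID.

```java
import java.util.Arrays;

class ValueHashSketch {
  // Hypothetical segment model: one sorted String[] dictionary per correlate-by column.
  static int[] lookup(String[][] dicts, String... values) {
    int[] ids = new int[values.length];
    for (int k = 0; k < values.length; k++) {
      ids[k] = Arrays.binarySearch(dicts[k], values[k]); // dict ID = position in sorted dictionary
    }
    return ids;
  }

  // Combines per-column value hashes directly, with no composite String built per ID.
  static int valueHash(String[][] dicts, int[] dictIds) {
    int hash = 1;
    for (int k = 0; k < dicts.length; k++) {
      hash = 31 * hash + dicts[k][dictIds[k]].hashCode();
    }
    return hash;
  }

  public static void main(String[] args) {
    String[][] segment1 = {{"alice", "bob"}, {"desktop", "mobile"}};
    String[][] segment2 = {{"bob", "carol"}, {"mobile", "tablet"}};
    int[] ids1 = lookup(segment1, "bob", "mobile");
    int[] ids2 = lookup(segment2, "bob", "mobile");
    System.out.println(Arrays.toString(ids1) + " vs " + Arrays.toString(ids2)); // [1, 1] vs [0, 0]
    System.out.println(valueHash(segment1, ids1) == valueHash(segment2, ids2)); // true
  }
}
```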
---------
Co-authored-by: Cursor <[email protected]>
---
.../function/funnel/AggregationStrategy.java | 181 ++++++++++++++++++---
.../function/funnel/BitmapAggregationStrategy.java | 10 ++
.../funnel/BitmapResultExtractionStrategy.java | 58 ++++++-
.../function/funnel/DictIdsWrapper.java | 136 +++++++++++++++-
.../FunnelCountAggregationFunctionFactory.java | 3 +
.../FunnelCountSortedAggregationFunction.java | 16 +-
.../funnel/SetResultExtractionStrategy.java | 30 +++-
.../function/funnel/SortedAggregationResult.java | 108 +++++++++++-
.../function/funnel/SortedAggregationStrategy.java | 12 ++
.../funnel/ThetaSketchAggregationStrategy.java | 19 ++-
.../function/funnel/DictIdsWrapperTest.java | 128 +++++++++++++++
.../funnel/SortedAggregationResultTest.java | 57 +++++++
.../integration/tests/custom/FunnelCountTest.java | 103 +++++++++++-
13 files changed, 811 insertions(+), 50 deletions(-)
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
index 99006c102ab..1448ea0b243 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/AggregationStrategy.java
@@ -37,12 +37,16 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
* There should be no assumptions beyond segment boundaries, different aggregation strategies may be utilized
* across different segments for a given query.
*
+ * <p>Supports both single-key and multi-key CORRELATE_BY. The single-key path is kept as a zero-overhead fast path
+ * (structurally identical to the original single-column implementation) to avoid any regression for existing queries.
+ *
* @param <A> Aggregation result accumulated across blocks within segment, kept by result holder.
*/
@ThreadSafe
public abstract class AggregationStrategy<A> {
protected final int _numSteps;
+ protected final int _numCorrelateByKeys;
private final List<ExpressionContext> _stepExpressions;
private final List<ExpressionContext> _correlateByExpressions;
private final ExpressionContext _primaryCorrelationCol;
@@ -52,13 +56,38 @@ public abstract class AggregationStrategy<A> {
_correlateByExpressions = correlateByExpressions;
_primaryCorrelationCol = _correlateByExpressions.get(0);
_numSteps = _stepExpressions.size();
+ _numCorrelateByKeys = _correlateByExpressions.size();
}
/**
- * Returns an aggregation result for this aggregation strategy to be kept in a result holder (aggregation only).
+ * Creates an aggregation result for single-key correlation.
*/
abstract A createAggregationResult(Dictionary dictionary);
+ /**
+ * Creates an aggregation result for multi-key correlation.
+ */
+ abstract A createAggregationResultMultiKey(Dictionary[] dictionaries);
+
+ public A getAggregationResult(Dictionary dictionary, AggregationResultHolder aggregationResultHolder) {
+ A aggResult = aggregationResultHolder.getResult();
+ if (aggResult == null) {
+ aggResult = createAggregationResult(dictionary);
+ aggregationResultHolder.setValue(aggResult);
+ }
+ return aggResult;
+ }
+
+ public A getAggregationResultMultiKey(Dictionary[] dictionaries,
+ AggregationResultHolder aggregationResultHolder) {
+ A aggResult = aggregationResultHolder.getResult();
+ if (aggResult == null) {
+ aggResult = createAggregationResultMultiKey(dictionaries);
+ aggregationResultHolder.setValue(aggResult);
+ }
+ return aggResult;
+ }
+
public A getAggregationResultGroupBy(Dictionary dictionary, GroupByResultHolder groupByResultHolder, int groupKey) {
A aggResult = groupByResultHolder.getResult(groupKey);
if (aggResult == null) {
@@ -68,11 +97,12 @@ public abstract class AggregationStrategy<A> {
return aggResult;
}
- public A getAggregationResult(Dictionary dictionary, AggregationResultHolder aggregationResultHolder) {
- A aggResult = aggregationResultHolder.getResult();
+ public A getAggregationResultGroupByMultiKey(Dictionary[] dictionaries, GroupByResultHolder groupByResultHolder,
+ int groupKey) {
+ A aggResult = groupByResultHolder.getResult(groupKey);
if (aggResult == null) {
- aggResult = createAggregationResult(dictionary);
- aggregationResultHolder.setValue(aggResult);
+ aggResult = createAggregationResultMultiKey(dictionaries);
+ groupByResultHolder.setValueForKey(groupKey, aggResult);
}
return aggResult;
}
@@ -82,10 +112,18 @@ public abstract class AggregationStrategy<A> {
*/
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- final Dictionary dictionary = getDictionary(blockValSetMap);
- final int[] correlationIds = getCorrelationIds(blockValSetMap);
final int[][] steps = getSteps(blockValSetMap);
+ if (_numCorrelateByKeys == 1) {
+ aggregateSingleKey(length, aggregationResultHolder, blockValSetMap, steps);
+ } else {
+ aggregateMultiKey(length, aggregationResultHolder, blockValSetMap, steps);
+ }
+ }
+ private void aggregateSingleKey(int length, AggregationResultHolder aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+ final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
final A aggResult = getAggregationResult(dictionary, aggregationResultHolder);
for (int i = 0; i < length; i++) {
for (int n = 0; n < _numSteps; n++) {
@@ -96,20 +134,46 @@ public abstract class AggregationStrategy<A> {
}
}
+ private void aggregateMultiKey(int length, AggregationResultHolder aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+ final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+ final A aggResult = getAggregationResultMultiKey(dictionaries, aggregationResultHolder);
+ final int[] rowDictIds = new int[_numCorrelateByKeys];
+ for (int i = 0; i < length; i++) {
+ for (int k = 0; k < _numCorrelateByKeys; k++) {
+ rowDictIds[k] = allCorrelationIds[k][i];
+ }
+ for (int n = 0; n < _numSteps; n++) {
+ if (steps[n][i] > 0) {
+ addMultiKey(aggResult, n, dictionaries, rowDictIds);
+ }
+ }
+ }
+ }
+
/**
* Performs aggregation on the given group key array and block value sets (aggregation group-by on single-value
* columns).
*/
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- final Dictionary dictionary = getDictionary(blockValSetMap);
- final int[] correlationIds = getCorrelationIds(blockValSetMap);
final int[][] steps = getSteps(blockValSetMap);
+ if (_numCorrelateByKeys == 1) {
+ aggregateGroupBySVSingleKey(length, groupKeyArray, groupByResultHolder, blockValSetMap, steps);
+ } else {
+ aggregateGroupBySVMultiKey(length, groupKeyArray, groupByResultHolder, blockValSetMap, steps);
+ }
+ }
+ private void aggregateGroupBySVSingleKey(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+ final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
for (int i = 0; i < length; i++) {
+ final int groupKey = groupKeyArray[i];
+ final A aggResult = getAggregationResultGroupBy(dictionary, groupByResultHolder, groupKey);
for (int n = 0; n < _numSteps; n++) {
- final int groupKey = groupKeyArray[i];
- final A aggResult = getAggregationResultGroupBy(dictionary, groupByResultHolder, groupKey);
if (steps[n][i] > 0) {
add(dictionary, aggResult, n, correlationIds[i]);
}
@@ -117,20 +181,47 @@ public abstract class AggregationStrategy<A> {
}
}
+ private void aggregateGroupBySVMultiKey(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+ final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+ final int[] rowDictIds = new int[_numCorrelateByKeys];
+ for (int i = 0; i < length; i++) {
+ for (int k = 0; k < _numCorrelateByKeys; k++) {
+ rowDictIds[k] = allCorrelationIds[k][i];
+ }
+ final int groupKey = groupKeyArray[i];
+ final A aggResult = getAggregationResultGroupByMultiKey(dictionaries, groupByResultHolder, groupKey);
+ for (int n = 0; n < _numSteps; n++) {
+ if (steps[n][i] > 0) {
+ addMultiKey(aggResult, n, dictionaries, rowDictIds);
+ }
+ }
+ }
+ }
+
/**
* Performs aggregation on the given group keys array and block value sets (aggregation group-by on multi-value
* columns).
*/
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- final Dictionary dictionary = getDictionary(blockValSetMap);
- final int[] correlationIds = getCorrelationIds(blockValSetMap);
final int[][] steps = getSteps(blockValSetMap);
+ if (_numCorrelateByKeys == 1) {
+ aggregateGroupByMVSingleKey(length, groupKeysArray, groupByResultHolder, blockValSetMap, steps);
+ } else {
+ aggregateGroupByMVMultiKey(length, groupKeysArray, groupByResultHolder, blockValSetMap, steps);
+ }
+ }
+ private void aggregateGroupByMVSingleKey(int length, int[][] groupKeysArray,
+ GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary dictionary = getPrimaryDictionary(blockValSetMap);
+ final int[] correlationIds = getPrimaryCorrelationIds(blockValSetMap);
for (int i = 0; i < length; i++) {
- for (int n = 0; n < _numSteps; n++) {
- for (int groupKey : groupKeysArray[i]) {
- final A aggResult = getAggregationResultGroupBy(dictionary, groupByResultHolder, groupKey);
+ for (int groupKey : groupKeysArray[i]) {
+ final A aggResult = getAggregationResultGroupBy(dictionary, groupByResultHolder, groupKey);
+ for (int n = 0; n < _numSteps; n++) {
if (steps[n][i] > 0) {
add(dictionary, aggResult, n, correlationIds[i]);
}
@@ -139,26 +230,74 @@ public abstract class AggregationStrategy<A> {
}
}
+ private void aggregateGroupByMVMultiKey(int length, int[][] groupKeysArray,
+ GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap, int[][] steps) {
+ final Dictionary[] dictionaries = getAllDictionaries(blockValSetMap);
+ final int[][] allCorrelationIds = getAllCorrelationDictIds(blockValSetMap);
+ final int[] rowDictIds = new int[_numCorrelateByKeys];
+ for (int i = 0; i < length; i++) {
+ for (int k = 0; k < _numCorrelateByKeys; k++) {
+ rowDictIds[k] = allCorrelationIds[k][i];
+ }
+ for (int groupKey : groupKeysArray[i]) {
+ final A aggResult = getAggregationResultGroupByMultiKey(dictionaries, groupByResultHolder, groupKey);
+ for (int n = 0; n < _numSteps; n++) {
+ if (steps[n][i] > 0) {
+ addMultiKey(aggResult, n, dictionaries, rowDictIds);
+ }
+ }
+ }
+ }
+ }
+
/**
* Adds a correlation id to the aggregation counter for a given step in the funnel.
*/
abstract void add(Dictionary dictionary, A aggResult, int step, int correlationId);
- private Dictionary getDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ /**
+ * Adds a row's composite correlation identity to the aggregation counter for a given step (multi-key path).
+ *
+ * @param aggResult the aggregation result to update
+ * @param step the funnel step index
+ * @param dictionaries one dictionary per correlate-by column
+ * @param correlationDictIds one dictionary ID per correlate-by column for the current row
+ * (this array is reused across rows; implementations must not hold a reference)
+ */
+ abstract void addMultiKey(A aggResult, int step, Dictionary[] dictionaries, int[] correlationDictIds);
+
+ Dictionary getPrimaryDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
final BlockValSet primaryCorrelationValSet = blockValSetMap.get(_primaryCorrelationCol);
- // FUNNELCOUNT requires dict-id reads from the forward index; a column with EncodingType.RAW + dictionaryIndex
- // exposes a Dictionary but BlockValSet#getDictionaryIdsSV throws on the RAW forward index. Gate on the
- // explicit forward-index encoding flag rather than dictionary nullness alone.
Preconditions.checkArgument(primaryCorrelationValSet.isDictionaryEncoded(),
"CORRELATE_BY column in FUNNELCOUNT aggregation function not supported, please use a dictionary encoded "
+ "column.");
return primaryCorrelationValSet.getDictionary();
}
- private int[] getCorrelationIds(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ private Dictionary[] getAllDictionaries(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ Dictionary[] dictionaries = new Dictionary[_numCorrelateByKeys];
+ for (int k = 0; k < _numCorrelateByKeys; k++) {
+ BlockValSet valSet = blockValSetMap.get(_correlateByExpressions.get(k));
+ Preconditions.checkArgument(valSet.isDictionaryEncoded(),
+ "CORRELATE_BY column in FUNNELCOUNT aggregation function not supported, please use a dictionary encoded "
+ + "column.");
+ dictionaries[k] = valSet.getDictionary();
+ }
+ return dictionaries;
+ }
+
+ private int[] getPrimaryCorrelationIds(Map<ExpressionContext, BlockValSet> blockValSetMap) {
return blockValSetMap.get(_primaryCorrelationCol).getDictionaryIdsSV();
}
+ private int[][] getAllCorrelationDictIds(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ int[][] allIds = new int[_numCorrelateByKeys][];
+ for (int k = 0; k < _numCorrelateByKeys; k++) {
+ allIds[k] = blockValSetMap.get(_correlateByExpressions.get(k)).getDictionaryIdsSV();
+ }
+ return allIds;
+ }
+
private int[][] getSteps(Map<ExpressionContext, BlockValSet> blockValSetMap) {
final int[][] steps = new int[_numSteps][];
for (int n = 0; n < _numSteps; n++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
index f726d936205..c0f3019fa1c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapAggregationStrategy.java
@@ -37,8 +37,18 @@ class BitmapAggregationStrategy extends AggregationStrategy<DictIdsWrapper> {
return new DictIdsWrapper(_numSteps, dictionary);
}
+ @Override
+ public DictIdsWrapper createAggregationResultMultiKey(Dictionary[] dictionaries) {
+ return new DictIdsWrapper(_numSteps, dictionaries);
+ }
+
@Override
protected void add(Dictionary dictionary, DictIdsWrapper dictIdsWrapper, int step, int correlationId) {
dictIdsWrapper._stepsBitmaps[step].add(correlationId);
}
+
+ @Override
+ void addMultiKey(DictIdsWrapper dictIdsWrapper, int step, Dictionary[] dictionaries, int[] correlationDictIds) {
+ dictIdsWrapper._stepsBitmaps[step].add(dictIdsWrapper.getCompositeCorrelationId(correlationDictIds));
+ }
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
index 1611b8dae8f..4b0373123df 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/BitmapResultExtractionStrategy.java
@@ -26,6 +26,13 @@ import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;
+/**
+ * Extracts intermediate bitmap results for cross-segment merging.
+ *
+ * <p>The bitmap strategy stores entities as 32-bit hash codes in a {@link RoaringBitmap}. For single-key INT
+ * columns, the actual int values are stored directly (exact). For other single-key types and all multi-key
+ * composites, hash codes are used (approximate — hash collisions can cause under-counting).
+ */
class BitmapResultExtractionStrategy implements ResultExtractionStrategy<DictIdsWrapper, List<RoaringBitmap>> {
protected final int _numSteps;
@@ -42,14 +49,59 @@ class BitmapResultExtractionStrategy implements ResultExtractionStrategy<DictIds
}
return result;
}
- Dictionary dictionary = dictIdsWrapper._dictionary;
List<RoaringBitmap> result = new ArrayList<>(_numSteps);
- for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
- result.add(convertToValueBitmap(dictionary, dictIdBitmap));
+ if (dictIdsWrapper.isMultiKey()) {
+ for (RoaringBitmap compositeIdBitmap : dictIdsWrapper._stepsBitmaps) {
+ result.add(convertCompositeToValueBitmap(dictIdsWrapper, compositeIdBitmap));
+ }
+ } else {
+ Dictionary dictionary = dictIdsWrapper._dictionaries[0];
+ for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
+ result.add(convertToValueBitmap(dictionary, dictIdBitmap));
+ }
}
return result;
}
+ /// Converts segment-local composite dictionary IDs to hash-coded value bitmaps for cross-segment merging.
+ /// Combines per-column value hashes directly — no string allocation. Same approximation as the
+ /// single-key non-INT path in {@link #convertToValueBitmap}: hash collisions may cause under-counting.
+ private RoaringBitmap convertCompositeToValueBitmap(DictIdsWrapper wrapper, RoaringBitmap compositeIdBitmap) {
+ RoaringBitmap valueBitmap = new RoaringBitmap();
+ PeekableIntIterator iterator = compositeIdBitmap.getIntIterator();
+ int numKeys = wrapper._dictionaries.length;
+ int[] dictIds = new int[numKeys];
+ while (iterator.hasNext()) {
+ wrapper.reverseCompositeId(iterator.next(), dictIds);
+ int hash = 1;
+ for (int k = 0; k < numKeys; k++) {
+ hash = 31 * hash + valueHashCode(wrapper._dictionaries[k], dictIds[k]);
+ }
+ valueBitmap.add(hash);
+ }
+ return valueBitmap;
+ }
+
+ /// Returns the hash code of a dictionary value using its native type, avoiding string conversion
+ /// for numeric types.
+ private static int valueHashCode(Dictionary dictionary, int dictId) {
+ switch (dictionary.getValueType()) {
+ case INT:
+ return Integer.hashCode(dictionary.getIntValue(dictId));
+ case LONG:
+ return Long.hashCode(dictionary.getLongValue(dictId));
+ case FLOAT:
+ return Float.hashCode(dictionary.getFloatValue(dictId));
+ case DOUBLE:
+ return Double.hashCode(dictionary.getDoubleValue(dictId));
+ case STRING:
+ return dictionary.getStringValue(dictId).hashCode();
+ default:
+ throw new IllegalArgumentException("Illegal data type for FUNNEL_COUNT aggregation function: "
+ + dictionary.getValueType());
+ }
+ }
+
/**
* Helper method to read dictionary and convert dictionary ids to hash code of the values for dictionary-encoded
* expression.
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
index c09d0128f29..a778d48319b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapper.java
@@ -18,19 +18,151 @@
*/
package org.apache.pinot.core.query.aggregation.function.funnel;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.roaringbitmap.RoaringBitmap;
+/**
+ * Holds per-step RoaringBitmaps keyed by correlation dictionary IDs.
+ *
+ * <p>For single-key CORRELATE_BY, stores raw dictionary IDs directly in the bitmaps (compact, fits in one
+ * RoaringBitmap container for typical segment sizes).
+ *
+ * <p>For multi-key CORRELATE_BY, composite IDs are assigned via stride-based arithmetic (when the combined key
+ * space fits in int) or a HashMap fallback for large key spaces.
+ */
final class DictIdsWrapper {
- final Dictionary _dictionary;
+ final Dictionary[] _dictionaries;
final RoaringBitmap[] _stepsBitmaps;
+ // Stride-based composite mapping (non-null only for multi-key when product of dict sizes fits in int)
+ private final int[] _strides;
+
+ // HashMap-based composite mapping (non-null only for multi-key when stride overflows int)
+ private final Map<IntArrayList, Integer> _compositeKeyMap;
+ private final List<int[]> _compositeKeyReverse;
+ private final IntArrayList _lookupKey;
+
DictIdsWrapper(int numSteps, Dictionary dictionary) {
- _dictionary = dictionary;
+ _dictionaries = new Dictionary[]{dictionary};
_stepsBitmaps = new RoaringBitmap[numSteps];
for (int n = 0; n < numSteps; n++) {
_stepsBitmaps[n] = new RoaringBitmap();
}
+ _strides = null;
+ _compositeKeyMap = null;
+ _compositeKeyReverse = null;
+ _lookupKey = null;
+ }
+
+ DictIdsWrapper(int numSteps, Dictionary[] dictionaries) {
+ _dictionaries = dictionaries;
+ _stepsBitmaps = new RoaringBitmap[numSteps];
+ for (int n = 0; n < numSteps; n++) {
+ _stepsBitmaps[n] = new RoaringBitmap();
+ }
+
+ if (dictionaries.length > 1) {
+ long totalSpace = 1;
+ boolean fitsInInt = true;
+ for (Dictionary d : dictionaries) {
+ totalSpace *= d.length();
+ if (totalSpace > Integer.MAX_VALUE) {
+ fitsInInt = false;
+ break;
+ }
+ }
+
+ if (fitsInInt) {
+ _strides = new int[dictionaries.length];
+ _strides[dictionaries.length - 1] = 1;
+ for (int k = dictionaries.length - 2; k >= 0; k--) {
+ _strides[k] = _strides[k + 1] * dictionaries[k + 1].length();
+ }
+ _compositeKeyMap = null;
+ _compositeKeyReverse = null;
+ _lookupKey = null;
+ } else {
+ _strides = null;
+ _compositeKeyMap = new HashMap<>();
+ _compositeKeyReverse = new ArrayList<>();
+ _lookupKey = new IntArrayList(dictionaries.length);
+ }
+ } else {
+ _strides = null;
+ _compositeKeyMap = null;
+ _compositeKeyReverse = null;
+ _lookupKey = null;
+ }
+ }
+
+ boolean isMultiKey() {
+ return _dictionaries.length > 1;
+ }
+
+ boolean isHashMapPath() {
+ return _compositeKeyMap != null;
+ }
+
+ /**
+ * Maps a tuple of per-column dictionary IDs to a single composite int suitable for RoaringBitmap.
+ * Only used for multi-key; for single-key, callers should add the dictId directly.
+ */
+ int getCompositeCorrelationId(int[] dictIds) {
+ if (_strides != null) {
+ int id = 0;
+ for (int k = 0; k < dictIds.length; k++) {
+ id += dictIds[k] * _strides[k];
+ }
+ return id;
+ }
+ _lookupKey.clear();
+ for (int dictId : dictIds) {
+ _lookupKey.add(dictId);
+ }
+ Integer existingId = _compositeKeyMap.get(_lookupKey);
+ if (existingId != null) {
+ return existingId;
+ }
+ IntArrayList insertKey = new IntArrayList(dictIds);
+ int id = _compositeKeyReverse.size();
+ _compositeKeyMap.put(insertKey, id);
+ _compositeKeyReverse.add(dictIds.clone());
+ return id;
+ }
+
+ /**
+ * Builds a collision-free composite string from dictionary values using length-prefix encoding.
+ * Each component is encoded as {@code length:value}, e.g. ("alice", "home") becomes "5:alice4:home".
+ */
+ static String toCompositeString(Dictionary[] dictionaries, int[] dictIds) {
+ StringBuilder sb = new StringBuilder();
+ for (int k = 0; k < dictionaries.length; k++) {
+ String val = dictionaries[k].getStringValue(dictIds[k]);
+ sb.append(val.length()).append(':').append(val);
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Reverse-maps a composite ID back to per-column dictionary IDs.
+ */
+ void reverseCompositeId(int compositeId, int[] outDictIds) {
+ if (_strides != null) {
+ int remaining = compositeId;
+ for (int k = 0; k < outDictIds.length - 1; k++) {
+ outDictIds[k] = remaining / _strides[k];
+ remaining %= _strides[k];
+ }
+ outDictIds[outDictIds.length - 1] = remaining;
+ return;
+ }
+ int[] stored = _compositeKeyReverse.get(compositeId);
+ System.arraycopy(stored, 0, outDictIds, 0, outDictIds.length);
}
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
index 5d0fb1eeb85..91e9bfca51a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunctionFactory.java
@@ -168,6 +168,9 @@ public class FunnelCountAggregationFunctionFactory implements Supplier<Aggregati
ResultExtractionStrategy<DictIdsWrapper, List<Long>> bitmapPartitionedResultExtractionStrategy() {
final MergeStrategy<List<RoaringBitmap>> bitmapMergeStrategy = bitmapMergeStrategy();
+ // For partitioned mode, each segment is self-contained: every row for a given correlation key
+ // appears in exactly one segment. Therefore we can count bitmap cardinality directly without
+ // converting segment-local composite IDs to global values — they will never be merged across segments.
return dictIdsWrapper -> {
if (dictIdsWrapper == null) {
return Collections.nCopies(_numSteps, 0L);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
index ac39461cef5..a86c0e22e85 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountSortedAggregationFunction.java
@@ -34,6 +34,8 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
* It leverages a more efficient counting strategy for segments sorted by correlate_by column, falls back to a regular
* counting strategy for unsorted segments (e.g. uncommitted segments).
*
+ * <p>For multi-key correlate-by, the sorted/partitioned optimization applies to the first (primary) column only.
+ *
* Example:
* SELECT
* dateTrunc('day', timestamp) AS ts,
@@ -59,14 +61,14 @@ public class FunnelCountSortedAggregationFunction<A> extends FunnelCountAggregat
super(expressions, stepExpressions, correlateByExpressions, aggregationStrategy, resultExtractionStrategy,
mergeStrategy);
_sortedAggregationStrategy = new SortedAggregationStrategy(stepExpressions, correlateByExpressions);
- _sortedResultExtractionStrategy = SortedAggregationResult::extractResult;;
+ _sortedResultExtractionStrategy = SortedAggregationResult::extractResult;
_primaryCorrelationCol = correlateByExpressions.get(0);
}
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- if (isSortedDictionary(blockValSetMap)) {
+ if (isPrimarySortedDictionary(blockValSetMap)) {
_sortedAggregationStrategy.aggregate(length, aggregationResultHolder, blockValSetMap);
} else {
super.aggregate(length, aggregationResultHolder, blockValSetMap);
@@ -76,7 +78,7 @@ public class FunnelCountSortedAggregationFunction<A> extends FunnelCountAggregat
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- if (isSortedDictionary(blockValSetMap)) {
+ if (isPrimarySortedDictionary(blockValSetMap)) {
_sortedAggregationStrategy.aggregateGroupBySV(length, groupKeyArray, groupByResultHolder, blockValSetMap);
} else {
super.aggregateGroupBySV(length, groupKeyArray, groupByResultHolder, blockValSetMap);
@@ -86,7 +88,7 @@ public class FunnelCountSortedAggregationFunction<A> extends FunnelCountAggregat
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- if (isSortedDictionary(blockValSetMap)) {
+ if (isPrimarySortedDictionary(blockValSetMap)) {
_sortedAggregationStrategy.aggregateGroupByMV(length, groupKeysArray, groupByResultHolder, blockValSetMap);
} else {
super.aggregateGroupByMV(length, groupKeysArray, groupByResultHolder, blockValSetMap);
@@ -111,15 +113,15 @@ public class FunnelCountSortedAggregationFunction<A> extends FunnelCountAggregat
}
}
- private boolean isSortedDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
- return getDictionary(blockValSetMap).isSorted();
+ private boolean isPrimarySortedDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ return getPrimaryDictionary(blockValSetMap).isSorted();
}
private boolean isSortedAggResult(Object aggResult) {
return aggResult instanceof SortedAggregationResult;
}
- private Dictionary getDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ private Dictionary getPrimaryDictionary(Map<ExpressionContext, BlockValSet> blockValSetMap) {
final Dictionary primaryCorrelationDictionary = blockValSetMap.get(_primaryCorrelationCol).getDictionary();
Preconditions.checkArgument(primaryCorrelationDictionary != null,
"CORRELATE_BY column in FUNNELCOUNT aggregation function not supported for sorted setting, "
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
index fad2bbf033a..675288b43c5 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SetResultExtractionStrategy.java
@@ -33,7 +33,10 @@ import org.roaringbitmap.RoaringBitmap;
/**
- * Aggregation strategy leveraging set algebra (unions/intersections).
+ * Extracts intermediate set results for cross-segment merging.
+ *
+ * <p>For single-key, converts dictionary IDs to typed value sets. For multi-key, converts composite IDs
+ * to length-prefix-encoded composite strings, producing a {@code Set<String>} per step.
*/
class SetResultExtractionStrategy implements ResultExtractionStrategy<DictIdsWrapper, List<Set>> {
protected final int _numSteps;
@@ -51,14 +54,33 @@ class SetResultExtractionStrategy implements ResultExtractionStrategy<DictIdsWra
}
return result;
}
- Dictionary dictionary = dictIdsWrapper._dictionary;
List<Set> result = new ArrayList<>(_numSteps);
- for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
- result.add(convertToValueSet(dictionary, dictIdBitmap));
+ if (dictIdsWrapper.isMultiKey()) {
+ for (RoaringBitmap compositeIdBitmap : dictIdsWrapper._stepsBitmaps) {
+ result.add(convertCompositeToValueSet(dictIdsWrapper, compositeIdBitmap));
+ }
+ } else {
+ Dictionary dictionary = dictIdsWrapper._dictionaries[0];
+ for (RoaringBitmap dictIdBitmap : dictIdsWrapper._stepsBitmaps) {
+ result.add(convertToValueSet(dictionary, dictIdBitmap));
+ }
}
return result;
}
+ private Set<String> convertCompositeToValueSet(DictIdsWrapper wrapper, RoaringBitmap compositeIdBitmap) {
+ int numValues = compositeIdBitmap.getCardinality();
+ int numKeys = wrapper._dictionaries.length;
+ int[] dictIds = new int[numKeys];
+ ObjectOpenHashSet<String> stringSet = new ObjectOpenHashSet<>(numValues);
+ PeekableIntIterator iterator = compositeIdBitmap.getIntIterator();
+ while (iterator.hasNext()) {
+ wrapper.reverseCompositeId(iterator.next(), dictIds);
+ stringSet.add(DictIdsWrapper.toCompositeString(wrapper._dictionaries, dictIds));
+ }
+ return stringSet;
+ }
+
private Set convertToValueSet(Dictionary dictionary, RoaringBitmap dictIdBitmap) {
int numValues = dictIdBitmap.getCardinality();
PeekableIntIterator iterator = dictIdBitmap.getIntIterator();
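The Javadoc in this hunk says multi-key composite IDs are turned into length-prefix-encoded composite strings, but the body of `DictIdsWrapper.toCompositeString` is not part of this excerpt. The following is a minimal sketch, assuming a `<length>:<value>` layout per column; the class `CompositeKeyEncoding` and that exact format are hypothetical. It shows why a length prefix is safer than joining with a separator: column values may themselves contain the separator.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical length-prefix codec; the real DictIdsWrapper.toCompositeString format may differ. */
public final class CompositeKeyEncoding {
  private CompositeKeyEncoding() {
  }

  /** Encodes each part as "<length>:<value>", so a separator inside a value stays unambiguous. */
  public static String encode(String... parts) {
    StringBuilder sb = new StringBuilder();
    for (String part : parts) {
      sb.append(part.length()).append(':').append(part);
    }
    return sb.toString();
  }

  /** Inverse of encode: read a length, skip the ':', then consume exactly that many characters. */
  public static List<String> decode(String encoded) {
    List<String> parts = new ArrayList<>();
    int pos = 0;
    while (pos < encoded.length()) {
      int colon = encoded.indexOf(':', pos);
      int len = Integer.parseInt(encoded.substring(pos, colon));
      int start = colon + 1;
      parts.add(encoded.substring(start, start + len));
      pos = start + len;
    }
    return parts;
  }

  public static void main(String[] args) {
    // A plain "|" join would map both keys to "a|b|c"; the length prefixes keep them distinct.
    String k1 = encode("a|b", "c");
    String k2 = encode("a", "b|c");
    System.out.println(k1);            // 3:a|b1:c
    System.out.println(k2);            // 1:a3:b|c
    System.out.println(k1.equals(k2)); // false
    System.out.println(decode(k1));    // [a|b, c]
  }
}
```

Decoding is shown only to demonstrate that the encoding is unambiguous; the extraction path in the diff needs only the forward direction, since the strings serve as set members for cross-segment merging.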
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
index eb773eac7ed..cf4bb2aa05a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResult.java
@@ -19,29 +19,49 @@
package org.apache.pinot.core.query.aggregation.function.funnel;
import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.util.Arrays;
/**
* Aggregation result data structure leveraged by sorted aggregation strategy.
+ *
+ * <p>For single-key, uses simple last-ID tracking since data is sorted by the correlation column.
+ * For multi-key, data is sorted by the primary (first) correlation column only; secondary keys
+ * are tracked via pre-allocated flat arrays within each primary-key group.
*/
class SortedAggregationResult {
+ private static final int INITIAL_CAPACITY = 8;
+
final int _numSteps;
final long[] _stepCounters;
+ private final int _numKeys;
+
+ // Single-key tracking
final boolean[] _correlatedSteps;
int _lastCorrelationId = Integer.MIN_VALUE;
+ // Multi-key tracking — flat arrays, pre-allocated once and reused across groups
+ private int _lastPrimaryId = Integer.MIN_VALUE;
+ private int[][] _entryKeys;
+ private boolean[][] _entrySteps;
+ private int _entryCount;
+
SortedAggregationResult(int numSteps) {
+ this(numSteps, 1);
+ }
+
+ SortedAggregationResult(int numSteps, int numKeys) {
_numSteps = numSteps;
- _stepCounters = new long[_numSteps];
- _correlatedSteps = new boolean[_numSteps];
+ _numKeys = numKeys;
+ _stepCounters = new long[numSteps];
+ _correlatedSteps = numKeys == 1 ? new boolean[numSteps] : null;
+ _entryKeys = numKeys > 1 ? new int[INITIAL_CAPACITY][numKeys] : null;
+ _entrySteps = numKeys > 1 ? new boolean[INITIAL_CAPACITY][numSteps] : null;
}
public void add(int step, int correlationId) {
if (correlationId != _lastCorrelationId) {
- // End of correlation group, calculate funnel conversion counts
incrStepCounters();
-
- // initialize next correlation group
for (int n = 0; n < _numSteps; n++) {
_correlatedSteps[n] = false;
}
@@ -50,7 +70,74 @@ class SortedAggregationResult {
_correlatedSteps[step] = true;
}
+ /**
+ * Multi-key add. Data must be sorted by correlationIds[0] (primary key).
+ * Secondary keys are tracked via linear scan over pre-allocated flat arrays.
+ *
+ * <p>The full correlationIds array (including the primary key at index 0) is used as the
+ * lookup key. The primary key is the same for every entry within a group, so including it
+ * is redundant but harmless — it avoids the cost of copying a sub-array.
+ */
+ public void addMultiKey(int step, int[] correlationIds) {
+ int primaryId = correlationIds[0];
+ if (primaryId != _lastPrimaryId) {
+ flushMultiKeyGroup();
+ _lastPrimaryId = primaryId;
+ _entryCount = 0;
+ }
+
+ for (int i = 0; i < _entryCount; i++) {
+ if (keysMatch(_entryKeys[i], correlationIds)) {
+ _entrySteps[i][step] = true;
+ return;
+ }
+ }
+
+ ensureCapacity();
+ System.arraycopy(correlationIds, 0, _entryKeys[_entryCount], 0, _numKeys);
+ Arrays.fill(_entrySteps[_entryCount], false);
+ _entrySteps[_entryCount][step] = true;
+ _entryCount++;
+ }
+
+ private boolean keysMatch(int[] stored, int[] incoming) {
+ for (int i = 0; i < _numKeys; i++) {
+ if (stored[i] != incoming[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private void ensureCapacity() {
+ if (_entryCount < _entryKeys.length) {
+ return;
+ }
+ int oldCap = _entryKeys.length;
+ int newCap = oldCap * 2;
+ _entryKeys = Arrays.copyOf(_entryKeys, newCap);
+ _entrySteps = Arrays.copyOf(_entrySteps, newCap);
+ for (int i = oldCap; i < newCap; i++) {
+ _entryKeys[i] = new int[_numKeys];
+ _entrySteps[i] = new boolean[_numSteps];
+ }
+ }
+
+ private void flushMultiKeyGroup() {
+ for (int i = 0; i < _entryCount; i++) {
+ for (int n = 0; n < _numSteps; n++) {
+ if (!_entrySteps[i][n]) {
+ break;
+ }
+ _stepCounters[n]++;
+ }
+ }
+ }
+
void incrStepCounters() {
+ if (_correlatedSteps == null) {
+ return;
+ }
for (int n = 0; n < _numSteps; n++) {
if (!_correlatedSteps[n]) {
break;
@@ -59,9 +146,16 @@ class SortedAggregationResult {
}
}
+ /**
+ * Extracts the final funnel result. Must be called exactly once.
+ */
public LongArrayList extractResult() {
- // count last correlation id left open
- incrStepCounters();
+ if (_numKeys > 1) {
+ flushMultiKeyGroup();
+ _entryCount = 0;
+ } else {
+ incrStepCounters();
+ }
return LongArrayList.wrap(_stepCounters);
}
}
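To cross-check the sorted multi-key path, a brute-force reference that ignores sort order can be useful. This is a hypothetical helper (`FunnelReference` is not a Pinot class): it keys a `HashMap` on the full composite key and applies the same prefix rule as `flushMultiKeyGroup`, where a key counts toward step n only if it reached every step 0..n. Fed the input of `testMultiKeySecondaryKeysWithinPrimaryGroup`, it should produce the `[2, 1]` that test expects.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical brute-force oracle for multi-key funnel counts; not a Pinot class. */
public final class FunnelReference {
  private final int _numSteps;
  // Full composite key (every correlation dict ID) -> which steps that key has reached.
  private final Map<List<Integer>, boolean[]> _stepsByKey = new HashMap<>();

  public FunnelReference(int numSteps) {
    _numSteps = numSteps;
  }

  /** Records that the key identified by correlationIds reached the given step; input order is irrelevant. */
  public void add(int step, int... correlationIds) {
    Integer[] boxed = new Integer[correlationIds.length];
    for (int i = 0; i < correlationIds.length; i++) {
      boxed[i] = correlationIds[i];
    }
    _stepsByKey.computeIfAbsent(Arrays.asList(boxed), k -> new boolean[_numSteps])[step] = true;
  }

  /** A key counts toward step n only if it reached every step 0..n (the rule flushMultiKeyGroup applies). */
  public long[] counts() {
    long[] counters = new long[_numSteps];
    for (boolean[] steps : _stepsByKey.values()) {
      for (int n = 0; n < _numSteps && steps[n]; n++) {
        counters[n]++;
      }
    }
    return counters;
  }

  public static void main(String[] args) {
    // Same input as testMultiKeySecondaryKeysWithinPrimaryGroup: (0,10) reaches both steps, (0,20) only step 0.
    FunnelReference ref = new FunnelReference(2);
    ref.add(0, 0, 10);
    ref.add(1, 0, 10);
    ref.add(0, 0, 20);
    System.out.println(Arrays.toString(ref.counts())); // [2, 1]
  }
}
```

The sorted path avoids this per-key hashing by relying on the primary column's sort order: only the secondary keys of the current primary group are held at once and are found by linear scan, which stays cheap when each primary key has few secondary combinations. The reference keeps every distinct key in memory, so it suits tests rather than segment-scale data.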
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
index 533d8723a74..7668e7a72bb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationStrategy.java
@@ -25,6 +25,8 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
/**
* Aggregation strategy for segments partitioned and sorted by the main correlation column.
+ * For multi-key correlate-by, data must be sorted by the first (primary) column; secondary
+ * keys are handled within each primary-key group by {@link SortedAggregationResult}.
*/
class SortedAggregationStrategy extends AggregationStrategy<SortedAggregationResult> {
public SortedAggregationStrategy(List<ExpressionContext> stepExpressions,
@@ -37,8 +39,18 @@ class SortedAggregationStrategy extends AggregationStrategy<SortedAggregationRes
return new SortedAggregationResult(_numSteps);
}
+ @Override
+ public SortedAggregationResult createAggregationResultMultiKey(Dictionary[] dictionaries) {
+ return new SortedAggregationResult(_numSteps, dictionaries.length);
+ }
+
@Override
void add(Dictionary dictionary, SortedAggregationResult aggResult, int step, int correlationId) {
aggResult.add(step, correlationId);
}
+
+ @Override
+ void addMultiKey(SortedAggregationResult aggResult, int step, Dictionary[] dictionaries, int[] correlationDictIds) {
+ aggResult.addMultiKey(step, correlationDictIds);
+ }
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
index a2ac25f8677..da16056705b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/ThetaSketchAggregationStrategy.java
@@ -46,6 +46,15 @@ class ThetaSketchAggregationStrategy extends AggregationStrategy<UpdateSketch[]>
return stepsSketches;
}
+ @Override
+ public UpdateSketch[] createAggregationResultMultiKey(Dictionary[] dictionaries) {
+ final UpdateSketch[] stepsSketches = new UpdateSketch[_numSteps];
+ for (int n = 0; n < _numSteps; n++) {
+ stepsSketches[n] = _updateSketchBuilder.build();
+ }
+ return stepsSketches;
+ }
+
@Override
void add(Dictionary dictionary, UpdateSketch[] stepsSketches, int step, int correlationId) {
final UpdateSketch sketch = stepsSketches[step];
@@ -66,8 +75,14 @@ class ThetaSketchAggregationStrategy extends AggregationStrategy<UpdateSketch[]>
sketch.update(dictionary.getStringValue(correlationId));
break;
default:
- throw new IllegalStateException("Illegal CORRELATED_BY column data type for FUNNEL_COUNT aggregation function: "
- + dictionary.getValueType());
+ throw new IllegalStateException(
+ "Illegal CORRELATED_BY column data type for FUNNEL_COUNT aggregation function: "
+ + dictionary.getValueType());
}
}
+
+ @Override
+ void addMultiKey(UpdateSketch[] stepsSketches, int step, Dictionary[] dictionaries, int[] correlationDictIds) {
+ stepsSketches[step].update(DictIdsWrapper.toCompositeString(dictionaries, correlationDictIds));
+ }
}
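`addMultiKey` in this hunk feeds each step's `UpdateSketch` the composite string rather than a single column value, so distinct counting happens per (user, category) pair. The sketch below substitutes exact `HashSet<String>` instances for the approximate theta sketches and assumes funnel semantics in which step n counts keys present in every step 0..n. The class name and the toy events are hypothetical, chosen to mimic the cross-category behavior the integration test describes for user 9.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical exact stand-in for the sketch path: one Set<String> per step instead of one UpdateSketch. */
public final class CompositeFunnelSets {
  private CompositeFunnelSets() {
  }

  /** Step n counts keys present in every step 0..n (cumulative intersection); assumes at least one step. */
  public static long[] funnel(List<Set<String>> stepSets) {
    long[] counts = new long[stepSets.size()];
    Set<String> survivors = new HashSet<>(stepSets.get(0));
    counts[0] = survivors.size();
    for (int n = 1; n < stepSets.size(); n++) {
      survivors.retainAll(stepSets.get(n));
      counts[n] = survivors.size();
    }
    return counts;
  }

  public static void main(String[] args) {
    // Toy events {step, user, category}: user 9 views and carts in home but checks out in electronics.
    String[][] events = {
        {"0", "9", "home"}, {"1", "9", "home"}, {"2", "9", "electronics"},
        {"0", "4", "clothing"}, {"1", "4", "clothing"}, {"2", "4", "clothing"}};
    List<Set<String>> byUser = new ArrayList<>();
    List<Set<String>> byUserAndCategory = new ArrayList<>();
    for (int n = 0; n < 3; n++) {
      byUser.add(new HashSet<>());
      byUserAndCategory.add(new HashSet<>());
    }
    for (String[] e : events) {
      int step = Integer.parseInt(e[0]);
      byUser.get(step).add(e[1]);
      // Length-prefixed composite key; the real DictIdsWrapper.toCompositeString format may differ.
      byUserAndCategory.get(step).add(e[1].length() + ":" + e[1] + e[2].length() + ":" + e[2]);
    }
    System.out.println(Arrays.toString(funnel(byUser)));            // [2, 2, 2]
    System.out.println(Arrays.toString(funnel(byUserAndCategory))); // [2, 2, 1]
  }
}
```

Keyed by user alone, user 9 survives into the checkout step; keyed by (user, category), the home and electronics pairs are different keys, so the checkout count drops. This is the same effect the integration test expects when it lowers the overall checkout count for the multi-key query.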
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java
new file mode 100644
index 00000000000..aba3ab13fd7
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/DictIdsWrapperTest.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import java.util.Arrays;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+
+public class DictIdsWrapperTest {
+
+ // Two dicts of 100_000 each → product 10^10 > Integer.MAX_VALUE → HashMap path
+ private static final int LARGE_DICT_SIZE = 100_000;
+
+ private static Dictionary mockDict(int size) {
+ Dictionary d = mock(Dictionary.class);
+ when(d.length()).thenReturn(size);
+ return d;
+ }
+
+ private static Dictionary[] largeDicts(int count) {
+ Dictionary[] dicts = new Dictionary[count];
+ for (int i = 0; i < count; i++) {
+ dicts[i] = mockDict(LARGE_DICT_SIZE);
+ }
+ return dicts;
+ }
+
+ // ── Single-key constructor ───────────────────────────────────────────────
+
+ @Test
+ public void testSingleKeyNotMultiKey() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, mockDict(100));
+ Assert.assertFalse(wrapper.isMultiKey());
+ Assert.assertFalse(wrapper.isHashMapPath());
+ }
+
+ // ── HashMap fallback path ────────────────────────────────────────────────
+
+ @Test
+ public void testHashMapPathSelectedWhenProductOverflows() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+ Assert.assertTrue(wrapper.isHashMapPath(), "should select HashMap path for large key space");
+ Assert.assertTrue(wrapper.isMultiKey());
+ }
+
+ @Test
+ public void testHashMapPathNewKeyGetsSequentialId() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+ Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{0, 0}), 0);
+ Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{0, 1}), 1);
+ Assert.assertEquals(wrapper.getCompositeCorrelationId(new int[]{1, 0}), 2);
+ }
+
+ @Test
+ public void testHashMapPathSameKeyReturnsSameId() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+ int first = wrapper.getCompositeCorrelationId(new int[]{5, 7});
+ int second = wrapper.getCompositeCorrelationId(new int[]{5, 7});
+ Assert.assertEquals(first, second);
+ }
+
+ @Test
+ public void testHashMapPathKeyOrderSensitive() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+ int id01 = wrapper.getCompositeCorrelationId(new int[]{0, 1});
+ int id10 = wrapper.getCompositeCorrelationId(new int[]{1, 0});
+ Assert.assertNotEquals(id01, id10, "[0,1] and [1,0] must map to different IDs");
+ }
+
+ @Test
+ public void testHashMapPathReverseRoundTrip() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(2, largeDicts(2));
+ int[][] keys = {{0, 0}, {0, 1}, {1, 0}, {99999, 99999}, {42, 7}};
+ for (int[] key : keys) {
+ int id = wrapper.getCompositeCorrelationId(key);
+ int[] out = new int[2];
+ wrapper.reverseCompositeId(id, out);
+ Assert.assertEquals(out, key, "reverseCompositeId must round-trip for key " + Arrays.toString(key));
+ }
+ }
+
+ @Test
+ public void testHashMapPathThreeColumns() {
+ DictIdsWrapper wrapper = new DictIdsWrapper(3, largeDicts(3));
+ int id = wrapper.getCompositeCorrelationId(new int[]{1, 2, 3});
+ int[] out = new int[3];
+ wrapper.reverseCompositeId(id, out);
+ Assert.assertEquals(out, new int[]{1, 2, 3});
+ }
+
+ // ── Stride path reverseCompositeId ──────────────────────────────────────
+
+ @Test
+ public void testStridePathReverseRoundTrip() {
+ Dictionary[] dicts = {mockDict(10), mockDict(20), mockDict(5)};
+ DictIdsWrapper wrapper = new DictIdsWrapper(3, dicts);
+ Assert.assertFalse(wrapper.isHashMapPath(), "should select stride path for small key space");
+
+ int[][] keys = {{0, 0, 0}, {9, 19, 4}, {3, 7, 2}, {0, 1, 0}};
+ for (int[] key : keys) {
+ int id = wrapper.getCompositeCorrelationId(key);
+ int[] out = new int[3];
+ wrapper.reverseCompositeId(id, out);
+ Assert.assertEquals(out, key, "stride reverseCompositeId must round-trip for key " + Arrays.toString(key));
+ }
+ }
+}
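These tests pin down two properties of `DictIdsWrapper`: the stride path is chosen when the product of dictionary sizes fits in an `int`, and the HashMap fallback hands out sequential IDs in first-seen order, with both paths round-tripping through `reverseCompositeId`. Below is a minimal sketch of that scheme, assuming row-major strides (last column varies fastest). `CompositeIdMapper` and its methods are hypothetical names, and Pinot's actual stride order and storage may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical composite-ID mapper mirroring the two DictIdsWrapper paths; not Pinot's implementation. */
public final class CompositeIdMapper {
  private final int[] _strides;                       // non-null only on the stride path
  private final Map<List<Integer>, Integer> _idByKey; // non-null only on the HashMap path
  private final List<int[]> _keyById;                 // composite ID -> dict IDs, HashMap path only

  public CompositeIdMapper(int[] dictSizes) {
    long product = 1;
    for (int size : dictSizes) {
      product *= size;
      if (product > Integer.MAX_VALUE) {
        break; // stop early so the long cannot overflow for many large dictionaries
      }
    }
    if (product <= Integer.MAX_VALUE) {
      // Row-major strides: the last column varies fastest (an assumption of this sketch).
      _strides = new int[dictSizes.length];
      int stride = 1;
      for (int i = dictSizes.length - 1; i >= 0; i--) {
        _strides[i] = stride;
        stride *= dictSizes[i];
      }
      _idByKey = null;
      _keyById = null;
    } else {
      _strides = null;
      _idByKey = new HashMap<>();
      _keyById = new ArrayList<>();
    }
  }

  public boolean isHashMapPath() {
    return _strides == null;
  }

  /** Maps per-column dict IDs to a single int ID. */
  public int getId(int[] dictIds) {
    if (_strides != null) {
      int id = 0;
      for (int i = 0; i < dictIds.length; i++) {
        id += dictIds[i] * _strides[i];
      }
      return id;
    }
    Integer[] boxed = new Integer[dictIds.length];
    for (int i = 0; i < dictIds.length; i++) {
      boxed[i] = dictIds[i];
    }
    // Unseen keys get the next sequential ID: 0, 1, 2, ...
    return _idByKey.computeIfAbsent(Arrays.asList(boxed), k -> {
      _keyById.add(dictIds.clone());
      return _keyById.size() - 1;
    });
  }

  /** Inverse of getId: writes the per-column dict IDs into out. */
  public void reverse(int id, int[] out) {
    if (_strides != null) {
      for (int i = 0; i < out.length; i++) {
        out[i] = id / _strides[i];
        id %= _strides[i];
      }
    } else {
      System.arraycopy(_keyById.get(id), 0, out, 0, out.length);
    }
  }

  public static void main(String[] args) {
    CompositeIdMapper stride = new CompositeIdMapper(new int[]{10, 20, 5});
    int id = stride.getId(new int[]{9, 19, 4});
    int[] out = new int[3];
    stride.reverse(id, out);
    System.out.println(stride.isHashMapPath() + " " + id + " " + Arrays.toString(out)); // false 999 [9, 19, 4]

    CompositeIdMapper hashed = new CompositeIdMapper(new int[]{100_000, 100_000});
    System.out.println(hashed.isHashMapPath());        // true
    System.out.println(hashed.getId(new int[]{0, 1})); // 0
    System.out.println(hashed.getId(new int[]{1, 0})); // 1
    System.out.println(hashed.getId(new int[]{0, 1})); // 0
  }
}
```

The stride path needs no per-key state, because an ID is a pure function of the dictionary IDs. The fallback trades that for a map plus a reverse list whose size grows with the number of distinct keys actually seen, which keeps IDs dense even when the theoretical key space exceeds `Integer.MAX_VALUE`.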
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java
new file mode 100644
index 00000000000..7e265376c6f
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/funnel/SortedAggregationResultTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function.funnel;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+
+public class SortedAggregationResultTest {
+
+ @Test
+ public void testMultiKeyExtractResultDoesNotDoubleCount() {
+ // Two entities (primary key 0 and 1), each completing both steps.
+ // Expected: stepCounters = [2, 2] (one completion per entity per step).
+ SortedAggregationResult result = new SortedAggregationResult(2, 2);
+ result.addMultiKey(0, new int[]{0, 10});
+ result.addMultiKey(1, new int[]{0, 10});
+ result.addMultiKey(0, new int[]{1, 20});
+ result.addMultiKey(1, new int[]{1, 20});
+
+ LongArrayList counts = result.extractResult();
+ Assert.assertEquals(counts.getLong(0), 2L, "step 0 count");
+ Assert.assertEquals(counts.getLong(1), 2L, "step 1 count");
+ }
+
+ @Test
+ public void testMultiKeySecondaryKeysWithinPrimaryGroup() {
+ // Primary key 0 with two secondary keys: (0,10) and (0,20).
+ // (0,10) completes both steps; (0,20) completes only step 0.
+ // Expected: stepCounters = [2, 1].
+ SortedAggregationResult result = new SortedAggregationResult(2, 2);
+ result.addMultiKey(0, new int[]{0, 10});
+ result.addMultiKey(1, new int[]{0, 10});
+ result.addMultiKey(0, new int[]{0, 20});
+
+ LongArrayList counts = result.extractResult();
+ Assert.assertEquals(counts.getLong(0), 2L, "step 0 count");
+ Assert.assertEquals(counts.getLong(1), 1L, "step 1 count");
+ }
+}
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
index c18674fa9b2..ae13706243b 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/FunnelCountTest.java
@@ -93,10 +93,17 @@ import static org.testng.Assert.assertNotNull;
*
* <h3>Expected funnel counts</h3>
* <pre>
- * Overall: [12, 10, 6, 3]
- * clothing: [ 4, 4, 2, 2]
- * electronics: [ 5, 4, 2, 1]
- * home: [ 3, 2, 0, 0]
+ * Single-key CORRELATE_BY(user_id):
+ * Overall: [12, 10, 6, 3]
+ * clothing: [ 4, 4, 2, 2]
+ * electronics: [ 5, 4, 2, 1]
+ * home: [ 3, 2, 0, 0]
+ *
+ * Multi-key CORRELATE_BY(user_id, category):
+ * Overall: [12, 10, 4, 3] (step 3 drops from 6→4: users 3,9 cross-category)
+ * clothing: [ 4, 4, 2, 2] (same — grouping already separates by category)
+ * electronics: [ 5, 4, 2, 1] (same)
+ * home: [ 3, 2, 0, 0] (same)
* </pre>
*/
@Test(suiteName = "CustomClusterIntegrationTest")
@@ -118,6 +125,11 @@ public class FunnelCountTest extends CustomDataQueryClusterIntegrationTest {
private static final long[] EXPECTED_CLOTHING = {4, 4, 2, 2};
private static final long[] EXPECTED_HOME = {3, 2, 0, 0};
+ // Multi-key: CORRELATE_BY(user_id, category)
+ // Cross-category users 3 and 9 no longer complete checkout within a single (user, category) pair.
+ private static final long[] EXPECTED_MULTI_KEY_OVERALL = {12, 10, 4, 3};
+ private static final long[] EXPECTED_MULTI_KEY_FILTERED = {7, 6, 3, 3};
+
@Override
protected long getCountStarResult() {
return COUNT_STAR;
@@ -231,6 +243,23 @@ public class FunnelCountTest extends CustomDataQueryClusterIntegrationTest {
CATEGORY_COL, funnelCountAggregation(settings), TABLE_NAME, CATEGORY_COL, CATEGORY_COL);
}
+ private String funnelCountMultiKeyAggregation(String settings) {
+ String settingsClause = (settings == null) ? "" : ", SETTINGS(" + settings + ")";
+ return String.format("FUNNEL_COUNT("
+ + "STEPS(%1$s = '%2$s', %1$s = '%3$s', %1$s = '%4$s', %1$s = '%5$s'), "
+ + "CORRELATE_BY(%6$s, %7$s)"
+ + "%8$s)", ACTION_COL, VIEW, CART, CHECKOUT, PURCHASE, USER_ID_COL, CATEGORY_COL, settingsClause);
+ }
+
+ private String overallMultiKeyQuery(String settings) {
+ return String.format("SELECT %s FROM %s", funnelCountMultiKeyAggregation(settings), TABLE_NAME);
+ }
+
+ private String groupByMultiKeyQuery(String settings) {
+ return String.format("SELECT %s, %s FROM %s GROUP BY %s ORDER BY %s",
+ CATEGORY_COL, funnelCountMultiKeyAggregation(settings), TABLE_NAME, CATEGORY_COL, CATEGORY_COL);
+ }
+
// ---------- assertion helpers ----------
private JsonNode getRows(JsonNode response) {
@@ -394,4 +423,70 @@ public class FunnelCountTest extends CustomDataQueryClusterIntegrationTest {
JsonNode rows = getRows(postQuery(emptyResultGroupByQuery(settings)));
assertEquals(rows.size(), 0, "Expected zero groups when all rows are filtered");
}
+
+ // ===================== Multi-key CORRELATE_BY tests =====================
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyOverall(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ JsonNode rows = getRows(postQuery(overallMultiKeyQuery(settings)));
+ assertOverallResult(rows, EXPECTED_MULTI_KEY_OVERALL);
+ }
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyGroupBy(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ JsonNode rows = getRows(postQuery(groupByMultiKeyQuery(settings)));
+ // Group-by category with CORRELATE_BY(user_id, category) produces the same results
+ // as single-key because grouping already separates rows by category.
+ assertGroupByResult(rows);
+ }
+
+ private String filteredMultiKeyQuery(String settings) {
+ return overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + " <= 7";
+ }
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyWithFilter(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ JsonNode rows = getRows(postQuery(filteredMultiKeyQuery(settings)));
+ assertEquals(rows.size(), 1);
+ assertStepCounts(rows.get(0).get(0), EXPECTED_MULTI_KEY_FILTERED);
+ }
+
+ // Multi-key: WHERE filter eliminates one segment entirely (users 7-12 only)
+ // user 9 crosses categories, so (user=9,home) only does view+cart, (user=9,electronics) only does checkout
+ // Expected: view=6, cart=5, checkout=2, purchase=1
+ private static final long[] EXPECTED_MULTI_KEY_ONE_SEGMENT = {6, 5, 2, 1};
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyFilterEliminatesOneSegment(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ String query = overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + " >= 7";
+ JsonNode rows = getRows(postQuery(query));
+ assertOverallResult(rows, EXPECTED_MULTI_KEY_ONE_SEGMENT);
+ }
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyEmptyResultOverall(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ String query = overallMultiKeyQuery(settings) + " WHERE " + USER_ID_COL + " > 100";
+ JsonNode rows = getRows(postQuery(query));
+ assertOverallResult(rows, EXPECTED_ALL_FILTERED);
+ }
+
+ @Test(dataProvider = "allStrategies")
+ public void testMultiKeyEmptyResultGroupBy(String settings)
+ throws Exception {
+ setUseMultiStageQueryEngine(false);
+ String query = String.format("SELECT %s, %s FROM %s WHERE %s > 100 GROUP BY %s ORDER BY %s",
+ CATEGORY_COL, funnelCountMultiKeyAggregation(settings), TABLE_NAME, USER_ID_COL, CATEGORY_COL, CATEGORY_COL);
+ JsonNode rows = getRows(postQuery(query));
+ assertEquals(rows.size(), 0, "Expected zero groups when all rows are filtered");
+ }
}