This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new ceb1acb060 Fix the handling of filtered agg in MSQE (#15214)
ceb1acb060 is described below
commit ceb1acb060c823d13f75406e7c707f21a78eab76
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Thu Mar 6 23:11:00 2025 -0700
Fix the handling of filtered agg in MSQE (#15214)
---
.../tests/MultiStageEngineIntegrationTest.java | 44 ++++----
.../operator/MultistageGroupByExecutor.java | 124 ++++++++++-----------
2 files changed, 79 insertions(+), 89 deletions(-)
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 49d15bc322..360d0095e0 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -1257,30 +1257,26 @@ public class MultiStageEngineIntegrationTest extends
BaseClusterIntegrationTestSet
throws Exception {
// Query written this way with a CTE and limit will be planned such that
the multi-stage group by executor will be
// used for both leaf and final aggregation
- String sqlQuery = "SET mseMaxInitialResultHolderCapacity = 1;\n"
- + "WITH tmp AS ("
- + " SELECT *"
- + " FROM mytable"
- + " WHERE AirlineID > 20000"
- + " LIMIT 10000"
- + ") "
- + "SELECT AirlineID,"
- + " COUNT(*) FILTER ("
- + " WHERE Origin = 'garbage'"
- + ") "
- + "FROM tmp "
- + "GROUP BY AirlineID";
- JsonNode result = postQuery(sqlQuery);
- assertNoError(result);
- // Ensure that result set is not empty
- assertTrue(result.get("numRowsResultSet").asInt() > 0);
-
- // Ensure that the count is 0 for all groups (because the aggregation
filter does not match any rows)
- JsonNode rows = result.get("resultTable").get("rows");
- for (int i = 0; i < rows.size(); i++) {
- assertEquals(rows.get(i).get(1).asInt(), 0);
- // Ensure that the main filter was applied
- assertTrue(rows.get(i).get(0).asInt() > 20000);
+ String aggregates1 = "COUNT(*) FILTER (WHERE Origin = 'garbage')";
+ String aggregates2 = aggregates1 + ", COUNT(*)";
+ String queryTemplate = "SET mseMaxInitialResultHolderCapacity = 1;\n"
+ + "WITH tmp AS (SELECT * FROM mytable WHERE AirlineID > 20000 LIMIT
10000)\n"
+ + "SELECT AirlineID, %s FROM tmp GROUP BY AirlineID";
+ String query1 = String.format(queryTemplate, aggregates1);
+ String query2 = String.format(queryTemplate, aggregates2);
+ for (String query : new String[]{query1, query2}) {
+ JsonNode result = postQuery(query);
+ assertNoError(result);
+ // Ensure that result set is not empty
+ assertTrue(result.get("numRowsResultSet").asInt() > 0);
+
+ // Ensure that the count is 0 for all groups (because the aggregation
filter does not match any rows)
+ JsonNode rows = result.get("resultTable").get("rows");
+ for (int i = 0; i < rows.size(); i++) {
+ assertEquals(rows.get(i).get(1).asInt(), 0);
+ // Ensure that the main filter was applied
+ assertTrue(rows.get(i).get(0).asInt() > 20000);
+ }
}
}
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index fc8dd924b6..d1d5029496 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -71,16 +71,9 @@ public class MultistageGroupByExecutor {
// because they use the zero based integer indexes to store results.
private final GroupIdGenerator _groupIdGenerator;
- public MultistageGroupByExecutor(
- int[] groupKeyIds,
- AggregationFunction[] aggFunctions,
- int[] filterArgIds,
- int maxFilterArgId,
- AggType aggType,
- boolean leafReturnFinalResult,
- DataSchema resultSchema,
- Map<String, String> opChainMetadata,
- @Nullable PlanNode.NodeHint nodeHint) {
+ public MultistageGroupByExecutor(int[] groupKeyIds, AggregationFunction[]
aggFunctions, int[] filterArgIds,
+ int maxFilterArgId, AggType aggType, boolean leafReturnFinalResult,
DataSchema resultSchema,
+ Map<String, String> opChainMetadata, @Nullable PlanNode.NodeHint
nodeHint) {
_groupKeyIds = groupKeyIds;
_aggFunctions = aggFunctions;
_filterArgIds = filterArgIds;
@@ -248,10 +241,7 @@ public class MultistageGroupByExecutor {
return rows;
}
- private Object[] getRow(
- Iterator<GroupIdGenerator.GroupKey> groupKeyIterator,
- int numKeys,
- int numFunctions,
+ private Object[] getRow(Iterator<GroupIdGenerator.GroupKey>
groupKeyIterator, int numKeys, int numFunctions,
ColumnDataType[] resultStoredTypes) {
GroupIdGenerator.GroupKey groupKey = groupKeyIterator.next();
int groupId = groupKey._groupId;
@@ -293,62 +283,66 @@ public class MultistageGroupByExecutor {
private void processAggregate(TransferableBlock block) {
if (_maxFilterArgId < 0) {
- // No filter for any aggregation function
- int[] intKeys = generateGroupByKeys(block);
- for (int i = 0; i < _aggFunctions.length; i++) {
- AggregationFunction aggFunction = _aggFunctions[i];
- Map<ExpressionContext, BlockValSet> blockValSetMap =
AggregateOperator.getBlockValSetMap(aggFunction, block);
- GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
- groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
- aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys,
groupByResultHolder, blockValSetMap);
- }
+ processAggregateWithoutFilter(block);
} else {
- // Some aggregation functions have filter, cache the matching rows
- int[] intKeys = null;
- RoaringBitmap[] matchedBitmaps = new RoaringBitmap[_maxFilterArgId + 1];
- int[] numMatchedRowsArray = new int[_maxFilterArgId + 1];
- int[][] filteredIntKeysArray = new int[_maxFilterArgId + 1][];
- for (int i = 0; i < _aggFunctions.length; i++) {
- AggregationFunction aggFunction = _aggFunctions[i];
- int filterArgId = _filterArgIds[i];
- if (filterArgId < 0) {
- // No filter for this aggregation function
- if (intKeys == null) {
- intKeys = generateGroupByKeys(block);
- }
- Map<ExpressionContext, BlockValSet> blockValSetMap =
AggregateOperator.getBlockValSetMap(aggFunction, block);
- GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
- groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
- aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys,
groupByResultHolder, blockValSetMap);
- } else {
- // Need to filter the block before aggregation
- RoaringBitmap matchedBitmap = matchedBitmaps[filterArgId];
- if (matchedBitmap == null) {
- matchedBitmap = AggregateOperator.getMatchedBitmap(block,
filterArgId);
- matchedBitmaps[filterArgId] = matchedBitmap;
- int numMatchedRows = matchedBitmap.getCardinality();
- numMatchedRowsArray[filterArgId] = numMatchedRows;
- filteredIntKeysArray[filterArgId] = generateGroupByKeys(block,
numMatchedRows, matchedBitmap);
- }
- int numMatchedRows = numMatchedRowsArray[filterArgId];
- int[] filteredIntKeys = filteredIntKeysArray[filterArgId];
- Map<ExpressionContext, BlockValSet> blockValSetMap =
- AggregateOperator.getFilteredBlockValSetMap(aggFunction, block,
numMatchedRows, matchedBitmap);
- GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
- groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
- aggFunction.aggregateGroupBySV(numMatchedRows, filteredIntKeys,
groupByResultHolder, blockValSetMap);
+ processAggregateWithFilter(block);
+ }
+ }
+
+ private void processAggregateWithoutFilter(TransferableBlock block) {
+ int[] intKeys = generateGroupByKeys(block);
+ int numGroups = _groupIdGenerator.getNumGroups();
+ for (int i = 0; i < _aggFunctions.length; i++) {
+ AggregationFunction aggFunction = _aggFunctions[i];
+ Map<ExpressionContext, BlockValSet> blockValSetMap =
AggregateOperator.getBlockValSetMap(aggFunction, block);
+ GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+ groupByResultHolder.ensureCapacity(numGroups);
+ aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys,
groupByResultHolder, blockValSetMap);
+ }
+ }
+
+ private void processAggregateWithFilter(TransferableBlock block) {
+ // In the first loop, generate all the group keys, cache the matching rows
+ int[] intKeys = _filteredAggregationsSkipEmptyGroups ? null :
generateGroupByKeys(block);
+ RoaringBitmap[] matchedBitmaps = new RoaringBitmap[_maxFilterArgId + 1];
+ int[] numMatchedRowsArray = new int[_maxFilterArgId + 1];
+ int[][] filteredIntKeysArray = new int[_maxFilterArgId + 1][];
+ for (int filterArgId : _filterArgIds) {
+ if (filterArgId < 0) {
+ // No filter for this aggregation function
+ if (intKeys == null) {
+ intKeys = generateGroupByKeys(block);
}
- }
- if (intKeys == null && !_filteredAggregationsSkipEmptyGroups) {
- // _groupIdGenerator should still have all the groups even if there
are only filtered aggregates for SQL
- // compliant results. However, if the query option to skip empty
groups is set, we avoid this step for
- // improved performance.
- generateGroupByKeys(block);
- for (int i = 0; i < _aggFunctions.length; i++) {
-
_aggregateResultHolders[i].ensureCapacity(_groupIdGenerator.getNumGroups());
+ } else {
+ // Need to filter the block before aggregation
+ if (matchedBitmaps[filterArgId] == null) {
+ RoaringBitmap matchedBitmap =
AggregateOperator.getMatchedBitmap(block, filterArgId);
+ matchedBitmaps[filterArgId] = matchedBitmap;
+ int numMatchedRows = matchedBitmap.getCardinality();
+ numMatchedRowsArray[filterArgId] = numMatchedRows;
+ filteredIntKeysArray[filterArgId] = generateGroupByKeys(block,
numMatchedRows, matchedBitmap);
}
}
}
+
+ // In the second loop, aggregate the values
+ int numGroups = _groupIdGenerator.getNumGroups();
+ for (int i = 0; i < _aggFunctions.length; i++) {
+ AggregationFunction aggFunction = _aggFunctions[i];
+ GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+ groupByResultHolder.ensureCapacity(numGroups);
+ int filterArgId = _filterArgIds[i];
+ if (filterArgId < 0) {
+ Map<ExpressionContext, BlockValSet> blockValSetMap =
AggregateOperator.getBlockValSetMap(aggFunction, block);
+ aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys,
groupByResultHolder, blockValSetMap);
+ } else {
+ Map<ExpressionContext, BlockValSet> blockValSetMap =
+ AggregateOperator.getFilteredBlockValSetMap(aggFunction, block,
numMatchedRowsArray[filterArgId],
+ matchedBitmaps[filterArgId]);
+ aggFunction.aggregateGroupBySV(numMatchedRowsArray[filterArgId],
filteredIntKeysArray[filterArgId],
+ groupByResultHolder, blockValSetMap);
+ }
+ }
}
private void processMerge(TransferableBlock block) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]