This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new ceb1acb060 Fix the handling of filtered agg in MSQE (#15214)
ceb1acb060 is described below

commit ceb1acb060c823d13f75406e7c707f21a78eab76
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Thu Mar 6 23:11:00 2025 -0700

    Fix the handling of filtered agg in MSQE (#15214)
---
 .../tests/MultiStageEngineIntegrationTest.java     |  44 ++++----
 .../operator/MultistageGroupByExecutor.java        | 124 ++++++++++-----------
 2 files changed, 79 insertions(+), 89 deletions(-)

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 49d15bc322..360d0095e0 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -1257,30 +1257,26 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestSet
       throws Exception {
     // Query written this way with a CTE and limit will be planned such that the multi-stage group by executor will be
     // used for both leaf and final aggregation
-    String sqlQuery = "SET mseMaxInitialResultHolderCapacity = 1;\n"
-        + "WITH tmp AS ("
-        + "  SELECT *"
-        + "  FROM mytable"
-        + "  WHERE AirlineID > 20000"
-        + "  LIMIT 10000"
-        + ") "
-        + "SELECT AirlineID,"
-        + "  COUNT(*) FILTER ("
-        + "    WHERE Origin = 'garbage'"
-        + ") "
-        + "FROM tmp "
-        + "GROUP BY AirlineID";
-    JsonNode result = postQuery(sqlQuery);
-    assertNoError(result);
-    // Ensure that result set is not empty
-    assertTrue(result.get("numRowsResultSet").asInt() > 0);
-
-    // Ensure that the count is 0 for all groups (because the aggregation filter does not match any rows)
-    JsonNode rows = result.get("resultTable").get("rows");
-    for (int i = 0; i < rows.size(); i++) {
-      assertEquals(rows.get(i).get(1).asInt(), 0);
-      // Ensure that the main filter was applied
-      assertTrue(rows.get(i).get(0).asInt() > 20000);
+    String aggregates1 = "COUNT(*) FILTER (WHERE Origin = 'garbage')";
+    String aggregates2 = aggregates1 + ", COUNT(*)";
+    String queryTemplate = "SET mseMaxInitialResultHolderCapacity = 1;\n"
+        + "WITH tmp AS (SELECT * FROM mytable WHERE AirlineID > 20000 LIMIT 10000)\n"
+        + "SELECT AirlineID, %s FROM tmp GROUP BY AirlineID";
+    String query1 = String.format(queryTemplate, aggregates1);
+    String query2 = String.format(queryTemplate, aggregates2);
+    for (String query : new String[]{query1, query2}) {
+      JsonNode result = postQuery(query);
+      assertNoError(result);
+      // Ensure that result set is not empty
+      assertTrue(result.get("numRowsResultSet").asInt() > 0);
+
+      // Ensure that the count is 0 for all groups (because the aggregation filter does not match any rows)
+      JsonNode rows = result.get("resultTable").get("rows");
+      for (int i = 0; i < rows.size(); i++) {
+        assertEquals(rows.get(i).get(1).asInt(), 0);
+        // Ensure that the main filter was applied
+        assertTrue(rows.get(i).get(0).asInt() > 20000);
+      }
     }
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index fc8dd924b6..d1d5029496 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -71,16 +71,9 @@ public class MultistageGroupByExecutor {
   // because they use the zero based integer indexes to store results.
   private final GroupIdGenerator _groupIdGenerator;
 
-  public MultistageGroupByExecutor(
-      int[] groupKeyIds,
-      AggregationFunction[] aggFunctions,
-      int[] filterArgIds,
-      int maxFilterArgId,
-      AggType aggType,
-      boolean leafReturnFinalResult,
-      DataSchema resultSchema,
-      Map<String, String> opChainMetadata,
-      @Nullable PlanNode.NodeHint nodeHint) {
+  public MultistageGroupByExecutor(int[] groupKeyIds, AggregationFunction[] aggFunctions, int[] filterArgIds,
+      int maxFilterArgId, AggType aggType, boolean leafReturnFinalResult, DataSchema resultSchema,
+      Map<String, String> opChainMetadata, @Nullable PlanNode.NodeHint nodeHint) {
     _groupKeyIds = groupKeyIds;
     _aggFunctions = aggFunctions;
     _filterArgIds = filterArgIds;
@@ -248,10 +241,7 @@ public class MultistageGroupByExecutor {
     return rows;
   }
 
-  private Object[] getRow(
-      Iterator<GroupIdGenerator.GroupKey> groupKeyIterator,
-      int numKeys,
-      int numFunctions,
+  private Object[] getRow(Iterator<GroupIdGenerator.GroupKey> groupKeyIterator, int numKeys, int numFunctions,
       ColumnDataType[] resultStoredTypes) {
     GroupIdGenerator.GroupKey groupKey = groupKeyIterator.next();
     int groupId = groupKey._groupId;
@@ -293,62 +283,66 @@ public class MultistageGroupByExecutor {
 
   private void processAggregate(TransferableBlock block) {
     if (_maxFilterArgId < 0) {
-      // No filter for any aggregation function
-      int[] intKeys = generateGroupByKeys(block);
-      for (int i = 0; i < _aggFunctions.length; i++) {
-        AggregationFunction aggFunction = _aggFunctions[i];
-        Map<ExpressionContext, BlockValSet> blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block);
-        GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
-        groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
-        aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
-      }
+      processAggregateWithoutFilter(block);
     } else {
-      // Some aggregation functions have filter, cache the matching rows
-      int[] intKeys = null;
-      RoaringBitmap[] matchedBitmaps = new RoaringBitmap[_maxFilterArgId + 1];
-      int[] numMatchedRowsArray = new int[_maxFilterArgId + 1];
-      int[][] filteredIntKeysArray = new int[_maxFilterArgId + 1][];
-      for (int i = 0; i < _aggFunctions.length; i++) {
-        AggregationFunction aggFunction = _aggFunctions[i];
-        int filterArgId = _filterArgIds[i];
-        if (filterArgId < 0) {
-          // No filter for this aggregation function
-          if (intKeys == null) {
-            intKeys = generateGroupByKeys(block);
-          }
-          Map<ExpressionContext, BlockValSet> blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block);
-          GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
-          groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
-          aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
-        } else {
-          // Need to filter the block before aggregation
-          RoaringBitmap matchedBitmap = matchedBitmaps[filterArgId];
-          if (matchedBitmap == null) {
-            matchedBitmap = AggregateOperator.getMatchedBitmap(block, filterArgId);
-            matchedBitmaps[filterArgId] = matchedBitmap;
-            int numMatchedRows = matchedBitmap.getCardinality();
-            numMatchedRowsArray[filterArgId] = numMatchedRows;
-            filteredIntKeysArray[filterArgId] = generateGroupByKeys(block, numMatchedRows, matchedBitmap);
-          }
-          int numMatchedRows = numMatchedRowsArray[filterArgId];
-          int[] filteredIntKeys = filteredIntKeysArray[filterArgId];
-          Map<ExpressionContext, BlockValSet> blockValSetMap =
-              AggregateOperator.getFilteredBlockValSetMap(aggFunction, block, numMatchedRows, matchedBitmap);
-          GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
-          groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups());
-          aggFunction.aggregateGroupBySV(numMatchedRows, filteredIntKeys, groupByResultHolder, blockValSetMap);
+      processAggregateWithFilter(block);
+    }
+  }
+
+  private void processAggregateWithoutFilter(TransferableBlock block) {
+    int[] intKeys = generateGroupByKeys(block);
+    int numGroups = _groupIdGenerator.getNumGroups();
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      Map<ExpressionContext, BlockValSet> blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block);
+      GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+      groupByResultHolder.ensureCapacity(numGroups);
+      aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+    }
+  }
+
+  private void processAggregateWithFilter(TransferableBlock block) {
+    // In the first loop, generate all the group keys, cache the matching rows
+    int[] intKeys = _filteredAggregationsSkipEmptyGroups ? null : generateGroupByKeys(block);
+    RoaringBitmap[] matchedBitmaps = new RoaringBitmap[_maxFilterArgId + 1];
+    int[] numMatchedRowsArray = new int[_maxFilterArgId + 1];
+    int[][] filteredIntKeysArray = new int[_maxFilterArgId + 1][];
+    for (int filterArgId : _filterArgIds) {
+      if (filterArgId < 0) {
+        // No filter for this aggregation function
+        if (intKeys == null) {
+          intKeys = generateGroupByKeys(block);
         }
-      }
-      if (intKeys == null && !_filteredAggregationsSkipEmptyGroups) {
-        // _groupIdGenerator should still have all the groups even if there are only filtered aggregates for SQL
-        // compliant results. However, if the query option to skip empty groups is set, we avoid this step for
-        // improved performance.
-        generateGroupByKeys(block);
-        for (int i = 0; i < _aggFunctions.length; i++) {
-          _aggregateResultHolders[i].ensureCapacity(_groupIdGenerator.getNumGroups());
+      } else {
+        // Need to filter the block before aggregation
+        if (matchedBitmaps[filterArgId] == null) {
+          RoaringBitmap matchedBitmap = AggregateOperator.getMatchedBitmap(block, filterArgId);
+          matchedBitmaps[filterArgId] = matchedBitmap;
+          int numMatchedRows = matchedBitmap.getCardinality();
+          numMatchedRowsArray[filterArgId] = numMatchedRows;
+          filteredIntKeysArray[filterArgId] = generateGroupByKeys(block, numMatchedRows, matchedBitmap);
         }
       }
     }
+
+    // In the second loop, aggregate the values
+    int numGroups = _groupIdGenerator.getNumGroups();
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+      groupByResultHolder.ensureCapacity(numGroups);
+      int filterArgId = _filterArgIds[i];
+      if (filterArgId < 0) {
+        Map<ExpressionContext, BlockValSet> blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block);
+        aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+      } else {
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getFilteredBlockValSetMap(aggFunction, block, numMatchedRowsArray[filterArgId],
+                matchedBitmaps[filterArgId]);
+        aggFunction.aggregateGroupBySV(numMatchedRowsArray[filterArgId], filteredIntKeysArray[filterArgId],
+            groupByResultHolder, blockValSetMap);
+      }
+    }
   }
 
   private void processMerge(TransferableBlock block) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to