This is an automated email from the ASF dual-hosted git repository.

ankitsultana pushed a commit to branch release-1.2.0-rc
in repository https://gitbox.apache.org/repos/asf/pinot.git

commit 6943c0c692a5bc2512f96da5657d6d830ed050c8
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Fri Aug 9 10:47:39 2024 -0700

    Fix filtered aggregate with ordering (#13784)
---
 .../apache/pinot/core/data/table/TableResizer.java | 33 ++++-----
 .../core/query/request/context/QueryContext.java   | 31 +++------
 .../BrokerRequestToQueryContextConverterTest.java  | 79 +++++++---------------
 .../tests/OfflineClusterIntegrationTest.java       | 26 ++++---
 4 files changed, 61 insertions(+), 108 deletions(-)

diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 45ded8f1e5..452186b909 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -54,9 +54,7 @@ public class TableResizer {
   private final int _numGroupByExpressions;
   private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
   private final AggregationFunction[] _aggregationFunctions;
-  private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
   private final Map<Pair<FunctionContext, FilterContext>, Integer> 
_filteredAggregationIndexMap;
-  private final List<Pair<AggregationFunction, FilterContext>> 
_filteredAggregationFunctions;
   private final int _numOrderByExpressions;
   private final OrderByValueExtractor[] _orderByValueExtractors;
   private final Comparator<IntermediateRecord> _intermediateRecordComparator;
@@ -82,10 +80,8 @@ public class TableResizer {
 
     _aggregationFunctions = queryContext.getAggregationFunctions();
     assert _aggregationFunctions != null;
-    _aggregationFunctionIndexMap = 
queryContext.getAggregationFunctionIndexMap();
-    assert _aggregationFunctionIndexMap != null;
     _filteredAggregationIndexMap = 
queryContext.getFilteredAggregationsIndexMap();
-    _filteredAggregationFunctions = 
queryContext.getFilteredAggregationFunctions();
+    assert _filteredAggregationIndexMap != null;
 
     List<OrderByExpressionContext> orderByExpressions = 
queryContext.getOrderByExpressions();
     assert orderByExpressions != null;
@@ -148,26 +144,26 @@ public class TableResizer {
     FunctionContext function = expression.getFunction();
     Preconditions.checkState(function != null, "Failed to find ORDER-BY 
expression: %s in the GROUP-BY clause",
         expression);
+    FunctionContext aggregation;
+    FilterContext filter;
     if (function.getType() == FunctionContext.Type.AGGREGATION) {
       // Aggregation function
-      int index = _aggregationFunctionIndexMap.get(function);
-      // For final aggregate result, we can handle it the same way as group key
-      return _hasFinalInput ? new 
GroupByExpressionExtractor(_numGroupByExpressions + index)
-          : new AggregationFunctionExtractor(index);
+      aggregation = function;
+      filter = null;
     } else if (function.getType() == FunctionContext.Type.TRANSFORM && 
"FILTER".equalsIgnoreCase(
         function.getFunctionName())) {
       // Filtered aggregation
-      FunctionContext aggregation = 
function.getArguments().get(0).getFunction();
-      ExpressionContext filterExpression = function.getArguments().get(1);
-      FilterContext filter = RequestContextUtils.getFilter(filterExpression);
-      int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, 
filter));
-      // For final aggregate result, we can handle it the same way as group key
-      return _hasFinalInput ? new 
GroupByExpressionExtractor(_numGroupByExpressions + index)
-          : new AggregationFunctionExtractor(index, 
_filteredAggregationFunctions.get(index).getLeft());
+      aggregation = function.getArguments().get(0).getFunction();
+      filter = RequestContextUtils.getFilter(function.getArguments().get(1));
     } else {
       // Post-aggregation function
       return new PostAggregationFunctionExtractor(function);
     }
+
+    int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
+    // For final aggregate result, we can handle it the same way as group key
+    return _hasFinalInput ? new 
GroupByExpressionExtractor(_numGroupByExpressions + index)
+        : new AggregationFunctionExtractor(index);
   }
 
   /**
@@ -441,11 +437,6 @@ public class TableResizer {
       _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex];
     }
 
-    AggregationFunctionExtractor(int aggregationFunctionIndex, 
AggregationFunction aggregationFunction) {
-      _index = aggregationFunctionIndex + _numGroupByExpressions;
-      _aggregationFunction = aggregationFunction;
-    }
-
     @Override
     public ColumnDataType getValueType() {
       return _aggregationFunction.getFinalResultColumnType();
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index 6c4a3d75c3..aee9261d94 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -91,10 +91,9 @@ public class QueryContext {
 
   // Pre-calculate the aggregation functions and columns for the query so that 
it can be shared across all the segments
   private AggregationFunction[] _aggregationFunctions;
-  private Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
-  private boolean _hasFilteredAggregations;
   private List<Pair<AggregationFunction, FilterContext>> 
_filteredAggregationFunctions;
   private Map<Pair<FunctionContext, FilterContext>, Integer> 
_filteredAggregationsIndexMap;
+  private boolean _hasFilteredAggregations;
   private Set<String> _columns;
 
   // Other properties to be shared across all the segments
@@ -272,22 +271,6 @@ public class QueryContext {
     return _filteredAggregationFunctions;
   }
 
-  /**
-   * Returns the filtered aggregation expressions for the query.
-   */
-  public boolean hasFilteredAggregations() {
-    return _hasFilteredAggregations;
-  }
-
-  /**
-   * Returns a map from the AGGREGATION FunctionContext to the index of the 
corresponding AggregationFunction in the
-   * aggregation functions array.
-   */
-  @Nullable
-  public Map<FunctionContext, Integer> getAggregationFunctionIndexMap() {
-    return _aggregationFunctionIndexMap;
-  }
-
   /**
    * Returns a map from the filtered aggregation (pair of AGGREGATION 
FunctionContext and FILTER FilterContext) to the
    * index of corresponding AggregationFunction in the aggregation functions 
array.
@@ -297,6 +280,13 @@ public class QueryContext {
     return _filteredAggregationsIndexMap;
   }
 
+  /**
+   * Returns whether the query contains filtered aggregations (aggregations with a FILTER clause).
+   */
+  public boolean hasFilteredAggregations() {
+    return _hasFilteredAggregations;
+  }
+
   /**
    * Returns the columns (IDENTIFIER expressions) in the query.
    */
@@ -619,12 +609,7 @@ public class QueryContext {
         for (int i = 0; i < numAggregations; i++) {
           aggregationFunctions[i] = 
filteredAggregationFunctions.get(i).getLeft();
         }
-        Map<FunctionContext, Integer> aggregationFunctionIndexMap = new 
HashMap<>();
-        for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : 
filteredAggregationsIndexMap.entrySet()) {
-          aggregationFunctionIndexMap.put(entry.getKey().getLeft(), 
entry.getValue());
-        }
         queryContext._aggregationFunctions = aggregationFunctions;
-        queryContext._aggregationFunctionIndexMap = 
aggregationFunctionIndexMap;
         queryContext._filteredAggregationFunctions = 
filteredAggregationFunctions;
         queryContext._filteredAggregationsIndexMap = 
filteredAggregationsIndexMap;
       }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
index 7c74e022af..ef331ebf59 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
@@ -480,21 +480,21 @@ public class BrokerRequestToQueryContextConverterTest {
       assertEquals(aggregationFunctions[3].getResultColumnName(), "sum(col4)");
       assertEquals(aggregationFunctions[4].getResultColumnName(), "max(col4)");
       assertEquals(aggregationFunctions[5].getResultColumnName(), "max(col1)");
-      Map<FunctionContext, Integer> aggregationFunctionIndexMap = 
queryContext.getAggregationFunctionIndexMap();
-      assertNotNull(aggregationFunctionIndexMap);
-      assertEquals(aggregationFunctionIndexMap.size(), 6);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 0);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 1);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "min",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 2);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 3);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 4);
-      assertEquals((int) aggregationFunctionIndexMap.get(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
-          
Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 5);
+      Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = 
queryContext.getFilteredAggregationsIndexMap();
+      assertNotNull(indexMap);
+      assertEquals(indexMap.size(), 6);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
+          Collections.singletonList(ExpressionContext.forIdentifier("col1"))), 
null)), 0);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
+          Collections.singletonList(ExpressionContext.forIdentifier("col2"))), 
null)), 1);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "min",
+          Collections.singletonList(ExpressionContext.forIdentifier("col2"))), 
null)), 2);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
+          Collections.singletonList(ExpressionContext.forIdentifier("col4"))), 
null)), 3);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
+          Collections.singletonList(ExpressionContext.forIdentifier("col4"))), 
null)), 4);
+      assertEquals((int) indexMap.get(Pair.of(new 
FunctionContext(FunctionContext.Type.AGGREGATION, "max",
+          Collections.singletonList(ExpressionContext.forIdentifier("col1"))), 
null)), 5);
     }
 
     // DistinctCountThetaSketch (string literal and escape quote)
@@ -540,21 +540,10 @@ public class BrokerRequestToQueryContextConverterTest {
       assertTrue(filteredAggregationFunctions.get(1).getLeft() instanceof 
CountAggregationFunction);
       assertEquals(filteredAggregationFunctions.get(1).getRight().toString(), 
"foo < '6'");
 
-      Map<FunctionContext, Integer> aggregationIndexMap = 
queryContext.getAggregationFunctionIndexMap();
-      assertNotNull(aggregationIndexMap);
-      assertEquals(aggregationIndexMap.size(), 1);
-      for (Map.Entry<FunctionContext, Integer> entry : 
aggregationIndexMap.entrySet()) {
-        FunctionContext aggregation = entry.getKey();
-        int index = entry.getValue();
-        assertEquals(aggregation.toString(), "count(*)");
-        assertTrue(index == 0 || index == 1);
-      }
-
-      Map<Pair<FunctionContext, FilterContext>, Integer> 
filteredAggregationsIndexMap =
-          queryContext.getFilteredAggregationsIndexMap();
-      assertNotNull(filteredAggregationsIndexMap);
-      assertEquals(filteredAggregationsIndexMap.size(), 2);
-      for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : 
filteredAggregationsIndexMap.entrySet()) {
+      Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = 
queryContext.getFilteredAggregationsIndexMap();
+      assertNotNull(indexMap);
+      assertEquals(indexMap.size(), 2);
+      for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : 
indexMap.entrySet()) {
         Pair<FunctionContext, FilterContext> pair = entry.getKey();
         FunctionContext aggregation = pair.getLeft();
         FilterContext filter = pair.getRight();
@@ -600,32 +589,10 @@ public class BrokerRequestToQueryContextConverterTest {
       assertTrue(filteredAggregationFunctions.get(3).getLeft() instanceof 
MinAggregationFunction);
       assertEquals(filteredAggregationFunctions.get(3).getRight().toString(), 
"salary > '50000'");
 
-      Map<FunctionContext, Integer> aggregationIndexMap = 
queryContext.getAggregationFunctionIndexMap();
-      assertNotNull(aggregationIndexMap);
-      assertEquals(aggregationIndexMap.size(), 2);
-      for (Map.Entry<FunctionContext, Integer> entry : 
aggregationIndexMap.entrySet()) {
-        FunctionContext aggregation = entry.getKey();
-        int index = entry.getValue();
-        switch (index) {
-          case 0:
-          case 1:
-            assertEquals(aggregation.toString(), "sum(salary)");
-            break;
-          case 2:
-          case 3:
-            assertEquals(aggregation.toString(), "min(salary)");
-            break;
-          default:
-            fail();
-            break;
-        }
-      }
-
-      Map<Pair<FunctionContext, FilterContext>, Integer> 
filteredAggregationsIndexMap =
-          queryContext.getFilteredAggregationsIndexMap();
-      assertNotNull(filteredAggregationsIndexMap);
-      assertEquals(filteredAggregationsIndexMap.size(), 4);
-      for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : 
filteredAggregationsIndexMap.entrySet()) {
+      Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = 
queryContext.getFilteredAggregationsIndexMap();
+      assertNotNull(indexMap);
+      assertEquals(indexMap.size(), 4);
+      for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : 
indexMap.entrySet()) {
         Pair<FunctionContext, FilterContext> pair = entry.getKey();
         FunctionContext aggregation = pair.getLeft();
         FilterContext filter = pair.getRight();
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index ca1ea0b1f9..7f93ba75ef 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -137,14 +137,6 @@ public class OfflineClusterIntegrationTest extends 
BaseClusterIntegrationTestSet
       new StarTreeIndexConfig(Collections.singletonList("DestState"), null,
           
Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()),
 null, 100);
   private static final String TEST_STAR_TREE_QUERY_2 = "SELECT COUNT(*) FROM 
mytable WHERE DestState = 'CA'";
-  private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG =
-      "SELECT COUNT(*), COUNT(*) FILTER (WHERE Carrier = 'UA') FROM mytable 
WHERE DestState = 'CA'";
-  // This query contains a filtered aggregation which cannot be solved with 
startree, but the COUNT(*) still should be
-  private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG_MIXED =
-      "SELECT COUNT(*), AVG(ArrDelay) FILTER (WHERE Carrier = 'UA') FROM 
mytable WHERE DestState = 'CA'";
-  private static final StarTreeIndexConfig STAR_TREE_INDEX_CONFIG_3 =
-      new StarTreeIndexConfig(List.of("Carrier", "DestState"), null,
-          
Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()),
 null, 100);
 
   // For default columns test
   private static final String TEST_EXTRA_COLUMNS_QUERY = "SELECT COUNT(*) FROM 
mytable WHERE NewAddedIntMetric = 1";
@@ -3472,6 +3464,24 @@ public class OfflineClusterIntegrationTest extends 
BaseClusterIntegrationTestSet
     testQuery("SELECT BOOL_OR(CAST(Diverted AS BOOLEAN)) FROM mytable");
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testGroupByAggregationWithLimitZero(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    testQuery("SELECT Origin, SUM(ArrDelay) FROM mytable GROUP BY Origin LIMIT 0");
+  }
+
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testFilteredAggregationWithGroupByOrdering(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+
+    // Test that the ordering is applied to the correct aggregation (the one without a FILTER clause)
+    // See https://github.com/apache/pinot/pull/13784
+    testQuery("SELECT DestCityName, COUNT(*) AS c1, COUNT(*) FILTER (WHERE AirTime = 0) AS c2 FROM mytable "
+        + "GROUP BY DestCityName ORDER BY c1 DESC LIMIT 10");
+  }
+
   private String buildSkipIndexesOption(String columnsAndIndexes) {
     return "SET " + SKIP_INDEXES + "='" + columnsAndIndexes + "'; ";
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to