This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 36307cb501 Fixes filtered agg result column naming and filtered agg order-by compat (#10092)
36307cb501 is described below

commit 36307cb5019836c755a2422fa63ef380839eba58
Author: Evan Galpin <[email protected]>
AuthorDate: Fri Jan 13 12:07:15 2023 -0800

    Fixes filtered agg result column naming and filtered agg order-by compat (#10092)
---
 .../apache/pinot/core/data/table/TableResizer.java |  21 ++++
 .../operator/blocks/results/ResultsBlockUtils.java |  20 ++-
 .../operator/query/FilteredGroupByOperator.java    |  13 +-
 .../apache/pinot/core/plan/GroupByPlanNode.java    |   3 +-
 .../function/AggregationFunctionUtils.java         |   8 ++
 .../query/reduce/AggregationDataTableReducer.java  |  17 ++-
 .../pinot/queries/FilteredAggregationsTest.java    | 134 +++++++++++++++------
 ...terSegmentAggregationMultiValueQueriesTest.java |   3 +-
 ...SegmentAggregationMultiValueRawQueriesTest.java |   4 +-
 9 files changed, 173 insertions(+), 50 deletions(-)

diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 7f6704fd7a..cbbe6abdce 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -28,9 +28,12 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.request.context.OrderByExpressionContext;
+import org.apache.pinot.common.request.context.RequestContextUtils;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
@@ -51,6 +54,8 @@ public class TableResizer {
   private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
   private final AggregationFunction[] _aggregationFunctions;
   private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
+  private final Map<Pair<FunctionContext, FilterContext>, Integer> 
_filteredAggregationIndexMap;
+  private final List<Pair<AggregationFunction, FilterContext>> 
_filteredAggregationFunctions;
   private final int _numOrderByExpressions;
   private final OrderByValueExtractor[] _orderByValueExtractors;
   private final Comparator<IntermediateRecord> _intermediateRecordComparator;
@@ -73,6 +78,8 @@ public class TableResizer {
     assert _aggregationFunctions != null;
     _aggregationFunctionIndexMap = 
queryContext.getAggregationFunctionIndexMap();
     assert _aggregationFunctionIndexMap != null;
+    _filteredAggregationIndexMap = 
queryContext.getFilteredAggregationsIndexMap();
+    _filteredAggregationFunctions = 
queryContext.getFilteredAggregationFunctions();
 
     List<OrderByExpressionContext> orderByExpressions = 
queryContext.getOrderByExpressions();
     assert orderByExpressions != null;
@@ -137,6 +144,15 @@ public class TableResizer {
     if (function.getType() == FunctionContext.Type.AGGREGATION) {
       // Aggregation function
       return new 
AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function));
+    } else if (function.getType() == FunctionContext.Type.TRANSFORM
+        && "FILTER".equalsIgnoreCase(function.getFunctionName())) {
+      FunctionContext aggregation = 
function.getArguments().get(0).getFunction();
+      ExpressionContext filterExpression = function.getArguments().get(1);
+      FilterContext filter = RequestContextUtils.getFilter(filterExpression);
+
+      int functionIndex = 
_filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
+      AggregationFunction aggregationFunction = 
_filteredAggregationFunctions.get(functionIndex).getLeft();
+      return new AggregationFunctionExtractor(functionIndex, 
aggregationFunction);
     } else {
       // Post-aggregation function
       return new PostAggregationFunctionExtractor(function);
@@ -414,6 +430,11 @@ public class TableResizer {
       _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex];
     }
 
+    AggregationFunctionExtractor(int aggregationFunctionIndex, 
AggregationFunction aggregationFunction) {
+      _index = aggregationFunctionIndex + _numGroupByExpressions;
+      _aggregationFunction = aggregationFunction;
+    }
+
     @Override
     public ColumnDataType getValueType() {
       return _aggregationFunction.getFinalResultColumnType();
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
index 5f5e7d0769..6fe8346a8b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
@@ -22,10 +22,13 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import 
org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import 
org.apache.pinot.core.query.aggregation.function.DistinctAggregationFunction;
 import org.apache.pinot.core.query.distinct.DistinctTable;
 import org.apache.pinot.core.query.request.context.QueryContext;
@@ -68,6 +71,8 @@ public class ResultsBlockUtils {
 
   private static AggregationResultsBlock 
buildEmptyAggregationQueryResults(QueryContext queryContext) {
     AggregationFunction[] aggregationFunctions = 
queryContext.getAggregationFunctions();
+    List<Pair<AggregationFunction, FilterContext>> 
filteredAggregationFunctions =
+        queryContext.getFilteredAggregationFunctions();
     assert aggregationFunctions != null;
     int numAggregations = aggregationFunctions.length;
     List<Object> results = new ArrayList<>(numAggregations);
@@ -78,12 +83,12 @@ public class ResultsBlockUtils {
   }
 
   private static GroupByResultsBlock 
buildEmptyGroupByQueryResults(QueryContext queryContext) {
-    AggregationFunction[] aggregationFunctions = 
queryContext.getAggregationFunctions();
-    assert aggregationFunctions != null;
-    int numAggregations = aggregationFunctions.length;
+    List<Pair<AggregationFunction, FilterContext>> 
filteredAggregationFunctions =
+        queryContext.getFilteredAggregationFunctions();
+
     List<ExpressionContext> groupByExpressions = 
queryContext.getGroupByExpressions();
     assert groupByExpressions != null;
-    int numColumns = groupByExpressions.size() + numAggregations;
+    int numColumns = groupByExpressions.size() + 
filteredAggregationFunctions.size();
     String[] columnNames = new String[numColumns];
     ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
     int index = 0;
@@ -93,9 +98,12 @@ public class ResultsBlockUtils {
       columnDataTypes[index] = ColumnDataType.STRING;
       index++;
     }
-    for (AggregationFunction aggregationFunction : aggregationFunctions) {
+    for (Pair<AggregationFunction, FilterContext> aggFilterPair : 
filteredAggregationFunctions) {
       // NOTE: Use AggregationFunction.getResultColumnName() for SQL format 
response
-      columnNames[index] = aggregationFunction.getResultColumnName();
+      AggregationFunction aggregationFunction = aggFilterPair.getLeft();
+      String columnName =
+          AggregationFunctionUtils.getResultColumnName(aggregationFunction, 
aggFilterPair.getRight());
+      columnNames[index] = columnName;
       columnDataTypes[index] = 
aggregationFunction.getIntermediateResultColumnType();
       index++;
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
index e895d817dd..872a999f54 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -24,6 +24,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.data.table.IntermediateRecord;
@@ -34,6 +35,7 @@ import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import 
org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import 
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -62,6 +64,7 @@ public class FilteredGroupByOperator extends 
BaseOperator<GroupByResultsBlock> {
   private final QueryContext _queryContext;
 
   public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,
+      List<Pair<AggregationFunction, FilterContext>> 
filteredAggregationFunctions,
       List<Pair<AggregationFunction[], TransformOperator>> 
aggFunctionsWithTransformOperator,
       ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext 
queryContext) {
     _aggregationFunctions = aggregationFunctions;
@@ -87,9 +90,12 @@ public class FilteredGroupByOperator extends 
BaseOperator<GroupByResultsBlock> {
 
     // Extract column names and data types for aggregation functions
     for (int i = 0; i < numAggregationFunctions; i++) {
-      AggregationFunction aggregationFunction = aggregationFunctions[i];
       int index = numGroupByExpressions + i;
-      columnNames[index] = aggregationFunction.getResultColumnName();
+      Pair<AggregationFunction, FilterContext> filteredAggPair = 
filteredAggregationFunctions.get(i);
+      AggregationFunction aggregationFunction = filteredAggPair.getLeft();
+      String columnName =
+          AggregationFunctionUtils.getResultColumnName(aggregationFunction, 
filteredAggPair.getRight());
+      columnNames[index] = columnName;
       columnDataTypes[index] = 
aggregationFunction.getIntermediateResultColumnType();
     }
 
@@ -102,7 +108,8 @@ public class FilteredGroupByOperator extends 
BaseOperator<GroupByResultsBlock> {
     int numAggregations = _aggregationFunctions.length;
 
     GroupByResultHolder[] groupByResultHolders = new 
GroupByResultHolder[numAggregations];
-    IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new 
IdentityHashMap<>(numAggregations);
+    IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap =
+        new IdentityHashMap<>(_aggregationFunctions.length);
     for (int i = 0; i < numAggregations; i++) {
       resultHolderIndexMap.put(_aggregationFunctions[i], i);
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java 
b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
index 99fdec9746..ccb51143e6 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
@@ -77,7 +77,8 @@ public class GroupByPlanNode implements PlanNode {
     List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
         AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, 
_queryContext,
             filterOperatorPair.getRight(), transformOperator, 
groupByExpressions);
-    return new 
FilteredGroupByOperator(_queryContext.getAggregationFunctions(), 
aggToTransformOpList,
+    return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(),
+        _queryContext.getFilteredAggregationFunctions(), aggToTransformOpList,
         _queryContext.getGroupByExpressions().toArray(new 
ExpressionContext[0]), numTotalDocs, _queryContext);
   }
 
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 0dcecb046d..6b1dd21e3c 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -259,4 +259,12 @@ public class AggregationFunctionUtils {
 
     return aggToTransformOpList;
   }
+
+  public static String getResultColumnName(AggregationFunction 
aggregationFunction, @Nullable FilterContext filter) {
+      String columnName = aggregationFunction.getResultColumnName();
+      if (filter != null) {
+        columnName += " FILTER(WHERE " + filter + ")";
+      }
+      return columnName;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index b727df9c30..739c1f691e 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -21,9 +21,12 @@ package org.apache.pinot.core.query.reduce;
 import com.google.common.base.Preconditions;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.metrics.BrokerMetrics;
+import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
@@ -42,10 +45,12 @@ import org.roaringbitmap.RoaringBitmap;
 public class AggregationDataTableReducer implements DataTableReducer {
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
+  private final List<Pair<AggregationFunction, FilterContext>> 
_filteredAggregationFunctions;
 
   AggregationDataTableReducer(QueryContext queryContext) {
     _queryContext = queryContext;
     _aggregationFunctions = queryContext.getAggregationFunctions();
+    _filteredAggregationFunctions = 
queryContext.getFilteredAggregationFunctions();
   }
 
   /**
@@ -150,11 +155,17 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
     int numAggregationFunctions = _aggregationFunctions.length;
     String[] columnNames = new String[numAggregationFunctions];
     ColumnDataType[] columnDataTypes = new 
ColumnDataType[numAggregationFunctions];
-    for (int i = 0; i < numAggregationFunctions; i++) {
-      AggregationFunction aggregationFunction = _aggregationFunctions[i];
-      columnNames[i] = aggregationFunction.getResultColumnName();
+
+    int i = 0;
+    for (Pair<AggregationFunction, FilterContext> aggFilterPair : 
_filteredAggregationFunctions) {
+      AggregationFunction aggregationFunction = aggFilterPair.getLeft();
+      String columnName =
+          AggregationFunctionUtils.getResultColumnName(aggregationFunction, 
aggFilterPair.getRight());
+      columnNames[i] = columnName;
       columnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
+      i++;
     }
+
     return new DataSchema(columnNames, columnDataTypes);
   }
 }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
index 9d772abc3f..2ea664ec67 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
@@ -161,51 +161,84 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
 
   @Test
   public void testSimpleQueries() {
-    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) 
FROM MyTable WHERE INT_COL < 1000000";
-    String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 
9999 AND INT_COL < 1000000";
+    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) 
sum1 FROM MyTable WHERE INT_COL < 1000000";
+    String nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE 
INT_COL > 9999 AND INT_COL < 1000000";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) FROM MyTable 
WHERE INT_COL > 1";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 1 AND 
INT_COL < 3";
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) sum1 FROM 
MyTable WHERE INT_COL > 1";
+    nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 1 
AND INT_COL < 3";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) FROM MyTable";
-    nonFilterQuery = "SELECT COUNT(*) FROM MyTable WHERE INT_COL = 4";
+    filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) count1 FROM 
MyTable";
+    nonFilterQuery = "SELECT COUNT(*) count1 FROM MyTable WHERE INT_COL = 4";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) FROM 
MyTable ";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 8000";
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) sum1 FROM 
MyTable ";
+    nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 
8000";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) FROM 
MyTable WHERE INT_COL > 1";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE NO_INDEX_COL <= 1 
AND INT_COL > 1";
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) sum1 
FROM MyTable WHERE INT_COL > 1";
+    nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE NO_INDEX_COL 
<= 1 AND INT_COL > 1";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable WHERE NO_INDEX_COL > 
-1";
-    nonFilterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable";
+    filterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable WHERE 
NO_INDEX_COL > -1";
+    nonFilterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) FROM 
MyTable";
-    nonFilterQuery = "SELECT AVG(INT_COL) FROM MyTable";
+    filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) avg1 
FROM MyTable";
+    nonFilterQuery = "SELECT AVG(INT_COL) avg1 FROM MyTable";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990), 
MAX(INT_COL) FILTER(WHERE INT_COL > 29990) "
-        + "FROM MyTable";
-    nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE 
INT_COL > 29990";
+    filterQuery =
+        "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) min1, 
MAX(INT_COL) FILTER(WHERE INT_COL > 29990) max1"
+            + " FROM MyTable";
+    nonFilterQuery = "SELECT MIN(INT_COL) min1, MAX(INT_COL) max1 FROM MyTable 
WHERE INT_COL > 29990";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) FROM MyTable";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true";
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) sum1 FROM 
MyTable";
+    nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE 
BOOLEAN_COL=true";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(STRING_COL, 'abc')) FROM MyTable";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true 
AND STARTSWITH(STRING_COL, 'abc')";
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(STRING_COL, 'abc')) sum1 FROM MyTable";
+    nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE 
BOOLEAN_COL=true AND STARTSWITH(STRING_COL, 'abc')";
     testQuery(filterQuery, nonFilterQuery);
 
     filterQuery =
-        "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " + "MyTable";
+        "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(REVERSE(STRING_COL), 'abc')) sum1 FROM MyTable";
+    nonFilterQuery =
+        "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true AND 
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
+  @Test
+  public void testFilterResultColumnNameGroupBy() {
+    String filterQuery =
+        "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE 
INT_COL < 1000000 GROUP BY BOOLEAN_COL";
+    String nonFilterQuery =
+        "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\" 
FROM MyTable WHERE INT_COL > 9999 AND "
+            + "INT_COL < 1000000 GROUP BY BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+
+    filterQuery =
+        "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 
1000000) FROM MyTable GROUP BY BOOLEAN_COL";
+    nonFilterQuery =
+        "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND 
INT_COL < '1000000'))\" FROM MyTable "
+            + "WHERE INT_COL > 9999 AND INT_COL < 1000000 GROUP BY 
BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
+  @Test
+  public void testFilterResultColumnNameNonGroupBy() {
+    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) 
FROM MyTable WHERE INT_COL < 1000000";
+    String nonFilterQuery =
+        "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\" 
FROM MyTable WHERE INT_COL > 9999 AND "
+            + "INT_COL < 1000000";
+    testQuery(filterQuery, nonFilterQuery);
+
+    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL 
< 1000000) FROM MyTable";
     nonFilterQuery =
-        "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND 
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
+        "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND 
INT_COL < '1000000'))\" FROM MyTable "
+            + "WHERE INT_COL > 9999 AND INT_COL < 1000000";
     testQuery(filterQuery, nonFilterQuery);
   }
 
@@ -305,9 +338,9 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
 
   @Test
   public void testMultipleAggregationsOnSameFilter() {
-    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990), "
-        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable";
-    String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable 
WHERE INT_COL > 29990";
+    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990) testMin, "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable";
+    String nonFilterQuery = "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax 
FROM MyTable WHERE INT_COL > 29990";
     testQuery(filterQuery, nonFilterQuery);
 
     filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS 
total_min, "
@@ -321,6 +354,26 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
     testQuery(filterQuery, nonFilterQuery);
   }
 
+  @Test
+  public void testMultipleAggregationsOnSameFilterOrderByFiltered() {
+    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990) testMin, "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable 
ORDER BY testMax";
+    String nonFilterQuery =
+        "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE 
INT_COL > 29990 ORDER BY testMax";
+    testQuery(filterQuery, nonFilterQuery);
+
+    filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS 
total_min, "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, "
+        + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, "
+        + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2 
FROM MyTable ORDER BY total_sum";
+    nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL 
ELSE 99999 END) AS total_min, "
+        + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS 
total_max, "
+        + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS 
total_sum, "
+        + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END) 
AS total_max2 FROM MyTable ORDER BY "
+        + "total_sum";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
   @Test
   public void testMixedAggregationsOfSameType() {
     String filterQuery = "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE 
INT_COL > 25000) AS total_sum FROM MyTable";
@@ -337,8 +390,8 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
 
   @Test
   public void testGroupBy() {
-    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) 
FROM MyTable GROUP BY BOOLEAN_COL";
-    String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 
25000 GROUP BY BOOLEAN_COL";
+    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) 
testSum FROM MyTable GROUP BY BOOLEAN_COL";
+    String nonFilterQuery = "SELECT SUM(INT_COL) testSum FROM MyTable WHERE 
INT_COL > 25000 GROUP BY BOOLEAN_COL";
     testQuery(filterQuery, nonFilterQuery);
   }
 
@@ -356,17 +409,19 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
   @Test
   public void testGroupBySameFilter() {
     String filterQuery =
-        "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL) 
FILTER(WHERE INT_COL > 25000) FROM MyTable "
-            + "GROUP BY BOOLEAN_COL";
-    String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable 
WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL";
+        "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg, 
SUM(INT_COL) FILTER(WHERE INT_COL > 25000) "
+            + "testSum FROM MyTable GROUP BY BOOLEAN_COL";
+    String nonFilterQuery =
+        "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE 
INT_COL > 25000 GROUP BY BOOLEAN_COL";
     testQuery(filterQuery, nonFilterQuery);
   }
 
   @Test
   public void testMultipleAggregationsOnSameFilterGroupBy() {
-    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990), "
-        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY 
BOOLEAN_COL";
-    String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable 
WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL";
+    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990) testMin, "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable 
GROUP BY BOOLEAN_COL";
+    String nonFilterQuery =
+        "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE 
INT_COL > 29990 GROUP BY BOOLEAN_COL";
     testQuery(filterQuery, nonFilterQuery);
 
     filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS 
total_min, "
@@ -380,4 +435,15 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
         + "BOOLEAN_COL";
     testQuery(filterQuery, nonFilterQuery);
   }
+
+  @Test
+  public void testGroupBySameFilterOrderByFiltered() {
+    String filterQuery =
+        "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg, 
SUM(INT_COL) FILTER(WHERE INT_COL > 25000) "
+            + "testSum FROM MyTable GROUP BY BOOLEAN_COL ORDER BY testAvg";
+    String nonFilterQuery =
+        "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE 
INT_COL > 25000 GROUP BY BOOLEAN_COL "
+            + "ORDER BY testAvg";
+    testQuery(filterQuery, nonFilterQuery);
+  }
 }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
index 668aecd0ba..760c1c78c1 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
@@ -519,7 +519,8 @@ public class InterSegmentAggregationMultiValueQueriesTest 
extends BaseMultiValue
   public void testFilteredAggregations() {
     String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable 
WHERE column3 > 0";
     BrokerResponseNative brokerResponse = getBrokerResponse(query);
-    DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"}, 
new ColumnDataType[]{ColumnDataType.LONG});
+    DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) 
FILTER(WHERE column1 > '5')"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG});
     ResultTable expectedResultTable =
         new ResultTable(expectedDataSchema, Collections.singletonList(new 
Object[]{370236L}));
     QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L, 
0L, 400000L, expectedResultTable);
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
index 7b4325df6d..06d89e6573 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
@@ -530,8 +530,8 @@ public class 
InterSegmentAggregationMultiValueRawQueriesTest extends BaseMultiVa
   public void testFilteredAggregations() {
     String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable 
WHERE column3 > 0";
     BrokerResponseNative brokerResponse = getBrokerResponse(query);
-    DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"}, 
new DataSchema.ColumnDataType[]
-        {DataSchema.ColumnDataType.LONG});
+    DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) 
FILTER(WHERE column1 > '5')"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG});
     ResultTable expectedResultTable =
         new ResultTable(expectedDataSchema, Collections.singletonList(new 
Object[]{370236L}));
     QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L, 
0L, 400000L, expectedResultTable);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to