This is an automated email from the ASF dual-hosted git repository.

tingchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 17ee2024dd FUNNEL_COUNT Aggregation Function (#10867)
17ee2024dd is described below

commit 17ee2024dd7d49ce2f5ebf00ef9d71b6ca589a00
Author: dario-liberman <130933669+dario-liber...@users.noreply.github.com>
AuthorDate: Wed Jun 21 19:26:57 2023 +0200

    FUNNEL_COUNT Aggregation Function (#10867)
    
    * New Funnel Aggregation Function
    
    * Funnel analytics support - FUNNEL_COUNT aggregation function
    
    * Delete FunnelAggregationFunction.java
    
    * Simplify Tests
    
    * Simplify Tests
    
    * Simplify Tests
    
    * within -> across
    
    Fix javadoc
    
    * Update FunnelCountAggregationFunction.java
    
    Address comments
    
    ---------
    
    Co-authored-by: Dario Liberman <dario.liber...@uber.com>
---
 .../org/apache/pinot/common/utils/DataSchema.java  |   3 +
 .../blocks/results/AggregationResultsBlock.java    |   4 +
 .../function/AggregationFunctionFactory.java       |   3 +
 .../function/FunnelCountAggregationFunction.java   | 511 +++++++++++++++++++++
 .../pinot/queries/BaseFunnelCountQueriesTest.java  | 252 ++++++++++
 .../queries/FunnelCountQueriesNonSortedTest.java   |  57 +++
 .../queries/FunnelCountQueriesSortedTest.java      |  65 +++
 .../org/apache/pinot/queries/QueriesTestUtils.java |   3 +-
 .../pinot/segment/spi/AggregationFunctionType.java |   5 +-
 9 files changed, 901 insertions(+), 2 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
index d8fedfc3d7..4020aff70a 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonPropertyOrder;
 import com.google.common.collect.Ordering;
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -469,6 +470,8 @@ public class DataSchema {
     private static long[] toLongArray(Object value) {
       if (value instanceof long[]) {
         return (long[]) value;
+      } else if (value instanceof LongArrayList) {
+        return ((LongArrayList) value).elements();
       } else {
         int[] intValues = (int[]) value;
         int length = intValues.length;
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
index b10ebdd3b9..2d816bffca 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.core.operator.blocks.results;
 
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.util.Collections;
@@ -182,6 +183,9 @@ public class AggregationResultsBlock extends BaseResultsBlock {
       case DOUBLE_ARRAY:
         dataTableBuilder.setColumn(index, ((DoubleArrayList) 
result).elements());
         break;
+      case LONG_ARRAY:
+        dataTableBuilder.setColumn(index, ((LongArrayList) result).elements());
+        break;
       default:
         throw new IllegalStateException("Illegal column data type in final 
result: " + columnDataType);
     }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 06fbb1db66..7f96072c9e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -358,6 +358,9 @@ public class AggregationFunctionFactory {
           case ARGMIN:
             throw new IllegalArgumentException(
                 "Aggregation function: " + function + " is only supported in 
selection without alias.");
+          case FUNNELCOUNT:
+            return new FunnelCountAggregationFunction(arguments);
+
           default:
             throw new IllegalArgumentException();
         }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java
new file mode 100644
index 0000000000..4eecad1002
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java
@@ -0,0 +1,511 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import 
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The {@code FunnelCountAggregationFunction} calculates the number of step conversions for a given partition column and
+ * a list of boolean expressions.
+ * <p>IMPORTANT: This function relies on the partition column being partitioned for each segment, where there are no
+ * common values across different segments.
+ * <p>This function calculates the exact number of step matches per partition key within the segment, then sums up the
+ * results from different segments.
+ * <p>The result holds one counter per step. Within a segment, counter {@code i} is the number of distinct values of the
+ * first {@code CORRELATE_BY} argument that have at least one matching row for every step from the first up to step
+ * {@code i}. The order in which those rows occur is not taken into account.
+ *
+ * Example:
+ *   SELECT
+ *    dateTrunc('day', timestamp) AS ts,
+ *    FUNNEL_COUNT(
+ *      STEPS(url = '/addToCart', url = '/checkout', url = '/orderConfirmation'),
+ *      CORRELATE_BY(user)
+ *    ) as step_counts
+ *    FROM user_log
+ *    WHERE url in ('/addToCart', '/checkout', '/orderConfirmation')
+ *    GROUP BY 1
+ */
+public class FunnelCountAggregationFunction implements AggregationFunction<List<Long>, LongArrayList> {
+  final List<ExpressionContext> _expressions;
+  final List<ExpressionContext> _stepExpressions;
+  final List<ExpressionContext> _correlateByExpressions;
+  // Only this (first) CORRELATE_BY argument is used to correlate rows, through its dictionary ids.
+  final ExpressionContext _primaryCorrelationCol;
+  final int _numSteps;
+
+  final SegmentAggregationStrategy<?, List<Long>> _sortedAggregationStrategy;
+  final SegmentAggregationStrategy<?, List<Long>> _bitmapAggregationStrategy;
+
+  /**
+   * Creates the function from its raw arguments.
+   *
+   * @param expressions the function arguments; must contain a {@code STEPS(...)} expression and a
+   *                    {@code CORRELATE_BY(...)} expression with at least one argument
+   * @throws IllegalStateException if {@code STEPS} or {@code CORRELATE_BY} is missing, or if {@code CORRELATE_BY} has
+   *                               no argument
+   */
+  public FunnelCountAggregationFunction(List<ExpressionContext> expressions) {
+    _expressions = expressions;
+    _correlateByExpressions = Option.CORRELATE_BY.getInputExpressions(expressions);
+    _primaryCorrelationCol = Option.CORRELATE_BY.getFirstInputExpression(expressions);
+    _stepExpressions = Option.STEPS.getInputExpressions(expressions);
+    _numSteps = _stepExpressions.size();
+    _sortedAggregationStrategy = new SortedAggregationStrategy();
+    _bitmapAggregationStrategy = new BitmapAggregationStrategy();
+  }
+
+  @Override
+  public String getResultColumnName() {
+    return getType().getName().toLowerCase() + "(" + _expressions.stream().map(ExpressionContext::toString)
+        .collect(Collectors.joining(",")) + ")";
+  }
+
+  /**
+   * Returns the CORRELATE_BY arguments followed by the STEPS arguments, as a new list on every call.
+   */
+  @Override
+  public List<ExpressionContext> getInputExpressions() {
+    final List<ExpressionContext> inputs = new ArrayList<>();
+    inputs.addAll(_correlateByExpressions);
+    inputs.addAll(_stepExpressions);
+    return inputs;
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.FUNNELCOUNT;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregate(length, aggregationResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupBySV(length, groupKeyArray,
+        groupByResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupByMV(length, groupKeysArray,
+        groupByResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public List<Long> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    // The strategy is inferred from the stored result type; a null result (no rows seen) falls back to the bitmap
+    // strategy, which reports all-zero counters in that case.
+    return getAggregationStrategyByAggregationResult(aggregationResultHolder.getResult()).extractAggregationResult(
+        aggregationResultHolder);
+  }
+
+  @Override
+  public List<Long> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return getAggregationStrategyByAggregationResult(groupByResultHolder.getResult(groupKey)).extractGroupByResult(
+        groupByResultHolder, groupKey);
+  }
+
+  /**
+   * Merges two intermediate results by adding their step counters element-wise.
+   *
+   * @return a new list; neither operand is modified
+   * @throws IllegalStateException if the operands do not hold the same number of counters
+   */
+  @Override
+  public List<Long> merge(List<Long> a, List<Long> b) {
+    int length = a.size();
+    Preconditions.checkState(length == b.size(), "The two operand arrays are not of the same size! provided %s, %s",
+        length, b.size());
+
+    LongArrayList result = toLongArrayList(a);
+    long[] elements = result.elements();
+    for (int i = 0; i < length; i++) {
+      elements[i] += b.get(i);
+    }
+    return result;
+  }
+
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.LONG_ARRAY;
+  }
+
+  /**
+   * Returns a copy of the step counters. The copy's backing array holds exactly one slot per step. This matters
+   * because LONG_ARRAY results are serialized through {@link LongArrayList#elements()}.
+   */
+  @Override
+  public LongArrayList extractFinalResult(List<Long> result) {
+    return toLongArrayList(result);
+  }
+
+  @Override
+  public String toExplainString() {
+    // getInputExpressions() builds a new list on every call, so compute it once.
+    final List<ExpressionContext> inputs = getInputExpressions();
+    StringBuilder stringBuilder = new StringBuilder(getType().getName()).append('(');
+    for (int i = 0; i < inputs.size(); i++) {
+      if (i > 0) {
+        stringBuilder.append(", ");
+      }
+      stringBuilder.append(inputs.get(i).toString());
+    }
+    return stringBuilder.append(')').toString();
+  }
+
+  /**
+   * Copies the given list into a new {@link LongArrayList} whose backing array length equals its size. As a result,
+   * {@link LongArrayList#elements()} exposes no unused trailing slots.
+   */
+  private static LongArrayList toLongArrayList(List<Long> longList) {
+    if (longList instanceof LongArrayList) {
+      return LongArrayList.wrap(((LongArrayList) longList).toLongArray());
+    }
+    final long[] values = new long[longList.size()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = longList.get(i);
+    }
+    return LongArrayList.wrap(values);
+  }
+
+  /**
+   * Returns an intermediate result with one zero counter per step. It is used when a segment or group saw no rows, so
+   * that {@link #merge} always receives operands of the same size.
+   */
+  private LongArrayList zeroStepCounts() {
+    return LongArrayList.wrap(new long[_numSteps]);
+  }
+
+  /**
+   * Returns the dictionary ids of the primary correlation column for the current block.
+   */
+  private int[] getCorrelationIds(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    return blockValSetMap.get(_primaryCorrelationCol).getDictionaryIdsSV();
+  }
+
+  /**
+   * Returns, for every step, the per-row int values of its expression. A value {@code > 0} marks a step match.
+   */
+  private int[][] getSteps(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    final int[][] steps = new int[_numSteps][];
+    for (int n = 0; n < _numSteps; n++) {
+      final BlockValSet stepBlockValSet = blockValSetMap.get(_stepExpressions.get(n));
+      steps[n] = stepBlockValSet.getIntValuesSV();
+    }
+    return steps;
+  }
+
+  /**
+   * Tells whether the primary correlation column's dictionary is sorted.
+   *
+   * @throws IllegalArgumentException if the primary correlation column has no dictionary
+   */
+  private boolean isSorted(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    final Dictionary primaryCorrelationDictionary = blockValSetMap.get(_primaryCorrelationCol).getDictionary();
+    if (primaryCorrelationDictionary == null) {
+      throw new IllegalArgumentException(
+          "CORRELATE_BY column in FUNNELCOUNT aggregation function not supported, please use a dictionary encoded "
+              + "column.");
+    }
+    return primaryCorrelationDictionary.isSorted();
+  }
+
+  private SegmentAggregationStrategy<?, List<Long>> getAggregationStrategyByBlockValSetMap(
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    return isSorted(blockValSetMap) ? _sortedAggregationStrategy : _bitmapAggregationStrategy;
+  }
+
+  private SegmentAggregationStrategy<?, List<Long>> getAggregationStrategyByAggregationResult(Object aggResult) {
+    return aggResult instanceof SortedAggregationResult ? _sortedAggregationStrategy : _bitmapAggregationStrategy;
+  }
+
+  /**
+   * Named option expressions accepted as FUNNEL_COUNT arguments, matched by function name.
+   */
+  enum Option {
+    STEPS("steps"),
+    CORRELATE_BY("correlateby");
+
+    final String _name;
+
+    Option(String name) {
+      _name = name;
+    }
+
+    boolean matches(ExpressionContext expression) {
+      if (expression.getType() != ExpressionContext.Type.FUNCTION) {
+        return false;
+      }
+      return _name.equals(expression.getFunction().getFunctionName());
+    }
+
+    Optional<ExpressionContext> find(List<ExpressionContext> expressions) {
+      return expressions.stream().filter(this::matches).findFirst();
+    }
+
+    /**
+     * Returns the arguments of the first expression matching this option.
+     *
+     * @throws IllegalStateException if no expression matches
+     */
+    public List<ExpressionContext> getInputExpressions(List<ExpressionContext> expressions) {
+      return this.find(expressions).map(exp -> exp.getFunction().getArguments())
+          .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT requires " + _name));
+    }
+
+    /**
+     * Returns the first argument of the first expression matching this option.
+     *
+     * @throws IllegalStateException if no expression matches or it has no arguments
+     */
+    public ExpressionContext getFirstInputExpression(List<ExpressionContext> expressions) {
+      return this.find(expressions)
+          .flatMap(exp -> exp.getFunction().getArguments().stream().findFirst())
+          .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT: " + _name + " requires an argument."));
+    }
+  }
+
+  /**
+   * Interface for segment aggregation strategy.
+   *
+   * <p>The implementation should be stateless, and can be shared among multiple segments in multiple threads. The
+   * result for each segment should be stored and passed in via the result holder.
+   * There should be no assumptions beyond segment boundaries, different aggregation strategies may be utilized
+   * across different segments for a given query.
+   *
+   * @param <A> Aggregation result accumulated across blocks within segment, kept by result holder.
+   * @param <I> Intermediate result at segment level (extracted from aforementioned aggregation result).
+   */
+  @ThreadSafe
+  static abstract class SegmentAggregationStrategy<A, I> {
+
+    /**
+     * Returns an aggregation result for this aggregation strategy to be kept in a result holder (aggregation only).
+     */
+    abstract A createAggregationResult();
+
+    /**
+     * Returns the aggregation result stored for the group key, creating and storing one first if absent.
+     */
+    public A getAggregationResultGroupBy(GroupByResultHolder groupByResultHolder, int groupKey) {
+      A aggResult = groupByResultHolder.getResult(groupKey);
+      if (aggResult == null) {
+        aggResult = createAggregationResult();
+        groupByResultHolder.setValueForKey(groupKey, aggResult);
+      }
+      return aggResult;
+    }
+
+    /**
+     * Returns the aggregation result stored in the holder, creating and storing one first if absent.
+     */
+    public A getAggregationResult(AggregationResultHolder aggregationResultHolder) {
+      A aggResult = aggregationResultHolder.getResult();
+      if (aggResult == null) {
+        aggResult = createAggregationResult();
+        aggregationResultHolder.setValue(aggResult);
+      }
+      return aggResult;
+    }
+
+    /**
+     * Performs aggregation on the given block value sets (aggregation only).
+     */
+    abstract void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Performs aggregation on the given group key array and block value sets (aggregation group-by on single-value
+     * columns).
+     */
+    abstract void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Performs aggregation on the given group keys array and block value sets (aggregation group-by on multi-value
+     * columns).
+     */
+    abstract void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Extracts the intermediate result from the aggregation result holder (aggregation only).
+     */
+    public I extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+      return extractIntermediateResult(aggregationResultHolder.getResult());
+    }
+
+    /**
+     * Extracts the intermediate result from the group-by result holder for the given group key (aggregation group-by).
+     */
+    public I extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+      return extractIntermediateResult(groupByResultHolder.getResult(groupKey));
+    }
+
+    /**
+     * Converts an aggregation result, which may be {@code null} when no rows were aggregated, into the intermediate
+     * result.
+     */
+    abstract I extractIntermediateResult(A aggregationResult);
+  }
+
+  /**
+   * Aggregation strategy leveraging roaring bitmap algebra (unions/intersections).
+   * Keeps one bitmap of correlation ids per step.
+   */
+  class BitmapAggregationStrategy extends SegmentAggregationStrategy<RoaringBitmap[], List<Long>> {
+
+    @Override
+    public RoaringBitmap[] createAggregationResult() {
+      final RoaringBitmap[] stepsBitmaps = new RoaringBitmap[_numSteps];
+      for (int n = 0; n < _numSteps; n++) {
+        stepsBitmaps[n] = new RoaringBitmap();
+      }
+      return stepsBitmaps;
+    }
+
+    @Override
+    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      final RoaringBitmap[] stepsBitmaps = getAggregationResult(aggregationResultHolder);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          if (steps[n][i] > 0) {
+            stepsBitmaps[n].add(correlationIds[i]);
+          }
+        }
+      }
+    }
+
+    @Override
+    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          if (steps[n][i] > 0) {
+            // A group only gets bitmaps once one of its rows matches a step.
+            getAggregationResultGroupBy(groupByResultHolder, groupKeyArray[i])[n].add(correlationIds[i]);
+          }
+        }
+      }
+    }
+
+    @Override
+    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          // The step match depends only on the row, so test it once before fanning out to the row's group keys.
+          if (steps[n][i] > 0) {
+            for (int groupKey : groupKeysArray[i]) {
+              getAggregationResultGroupBy(groupByResultHolder, groupKey)[n].add(correlationIds[i]);
+            }
+          }
+        }
+      }
+    }
+
+    /**
+     * Computes the step counters from the per-step bitmaps.
+     *
+     * <p>Note: intersects the given bitmaps in place. Each bitmap is left holding only the ids that reached all
+     * previous steps.
+     */
+    @Override
+    public List<Long> extractIntermediateResult(RoaringBitmap[] stepsBitmaps) {
+      if (stepsBitmaps == null) {
+        // Nothing was aggregated: report zero for every step rather than an empty list.
+        return zeroStepCounts();
+      }
+
+      final long[] result = new long[_numSteps];
+      for (int i = 0; i < _numSteps; i++) {
+        if (i > 0) {
+          // Keep only correlation ids that also reached the previous step.
+          stepsBitmaps[i].and(stepsBitmaps[i - 1]);
+        }
+        result[i] = stepsBitmaps[i].getCardinality();
+      }
+      return LongArrayList.wrap(result);
+    }
+  }
+
+  /**
+   * Aggregation strategy for segments sorted by the main correlation column.
+   */
+  class SortedAggregationStrategy extends SegmentAggregationStrategy<SortedAggregationResult, List<Long>> {
+
+    @Override
+    public SortedAggregationResult createAggregationResult() {
+      return new SortedAggregationResult();
+    }
+
+    @Override
+    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      final SortedAggregationResult agg = getAggregationResult(aggregationResultHolder);
+
+      for (int i = 0; i < length; i++) {
+        agg.sortedCount(steps, i, correlationIds[i]);
+      }
+    }
+
+    @Override
+    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int i = 0; i < length; i++) {
+        final int groupKey = groupKeyArray[i];
+        final SortedAggregationResult agg = getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+        agg.sortedCount(steps, i, correlationIds[i]);
+      }
+    }
+
+    @Override
+    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int i = 0; i < length; i++) {
+        for (int groupKey : groupKeysArray[i]) {
+          final SortedAggregationResult agg = getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+          agg.sortedCount(steps, i, correlationIds[i]);
+        }
+      }
+    }
+
+    @Override
+    public List<Long> extractIntermediateResult(SortedAggregationResult agg) {
+      if (agg == null) {
+        // Nothing was aggregated: report zero for every step rather than an empty list.
+        return zeroStepCounts();
+      }
+
+      return LongArrayList.wrap(agg.extractResult());
+    }
+  }
+
+  /**
+   * Aggregation result data structure leveraged by sorted aggregation strategy.
+   * It tracks the steps reached by the current run of equal correlation ids. When the id changes, it folds that run
+   * into the step counters.
+   */
+  class SortedAggregationResult {
+    public long[] _stepCounters = new long[_numSteps];
+    // Sentinel that no dictionary id should equal, so the first row always opens a new correlation group.
+    public int _lastCorrelationId = Integer.MIN_VALUE;
+    public boolean[] _correlatedSteps = new boolean[_numSteps];
+
+    /**
+     * Accounts for row {@code i}, whose correlation id is {@code correlationId}.
+     */
+    public void sortedCount(int[][] steps, int i, int correlationId) {
+      if (correlationId == _lastCorrelationId) {
+        // same correlation as before, keep accumulating.
+        for (int n = 0; n < _numSteps; n++) {
+          _correlatedSteps[n] |= steps[n][i] > 0;
+        }
+      } else {
+        // End of correlation group, calculate funnel conversion counts
+        incrStepCounters();
+
+        // initialize next correlation group
+        for (int n = 0; n < _numSteps; n++) {
+          _correlatedSteps[n] = steps[n][i] > 0;
+        }
+        _lastCorrelationId = correlationId;
+      }
+    }
+
+    /**
+     * Increments the counters of the leading steps reached by the current correlation group, stopping at the first
+     * step it did not reach.
+     */
+    void incrStepCounters() {
+      for (int n = 0; n < _numSteps; n++) {
+        if (!_correlatedSteps[n]) {
+          break;
+        }
+        _stepCounters[n]++;
+      }
+    }
+
+    /**
+     * Folds in the last open correlation group and returns the counters. Intended to be called once.
+     */
+    public long[] extractResult() {
+      // count last correlation id left open
+      incrStepCounters();
+
+      return _stepCounters;
+    }
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
new file mode 100644
index 0000000000..ef5c7d596f
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
@@ -0,0 +1,252 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.GroupByOperator;
+import 
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Base queries test for FUNNEL_COUNT queries.
+ * Each strategy gets its own test.
+ */
+@SuppressWarnings("rawtypes")
+abstract public class BaseFunnelCountQueriesTest extends BaseQueriesTest {
+  // Index directory for this test; deleted at the start of setUp(). Presumably subclasses build the segment here —
+  // TODO confirm in buildSegment implementations.
+  protected static final File INDEX_DIR =
+      new File(FileUtils.getTempDirectory(), "FunnelCountQueriesTest");
+  protected static final String RAW_TABLE_NAME = "testTable";
+  protected static final String SEGMENT_NAME = "testSegment";
+  // Unseeded, so the generated ids differ between runs.
+  protected static final Random RANDOM = new Random();
+
+  // Number of generated rows; each row's id is drawn from [0, MAX_VALUE).
+  protected static final int NUM_RECORDS = 2000;
+  protected static final int MAX_VALUE = 1000;
+  // Modulus used by the MOD(idColumn, ...) group-by query.
+  protected static final int NUM_GROUPS = 100;
+  // Rows whose idColumn is below this value are filtered out by the queries.
+  protected static final int FILTER_LIMIT = 50;
+  protected static final String ID_COLUMN = "idColumn";
+  protected static final String STEP_COLUMN = "stepColumn";
+  // Step values assigned to rows in alternation: even rows get "A", odd rows get "B".
+  protected static final String[] STEPS = {"A", "B"};
+  protected static final Schema SCHEMA = new Schema.SchemaBuilder()
+      .addSingleValueDimension(ID_COLUMN, DataType.INT)
+      .addSingleValueDimension(STEP_COLUMN, DataType.STRING)
+      .build();
+  protected static final TableConfigBuilder TABLE_CONFIG_BUILDER =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME);
+
+  // Distinct ids of the generated rows for step "A" (index 0) and step "B" (index 1).
+  private Set<Integer>[] _values = new Set[2];
+  // Ids of all generated rows, duplicates included.
+  private List<Integer> _all = new ArrayList<>();
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  // Expected number of entries scanned in the filter for one segment; supplied by the subclass.
+  protected abstract int getExpectedNumEntriesScannedInFilter();
+  protected abstract TableConfig getTableConfig();
+  // Builds the segment under test from the generated rows.
+  protected abstract IndexSegment buildSegment(List<GenericRow> records) throws Exception;
+
+  /**
+   * Filter appended to every query: keeps only ids at or above {@code FILTER_LIMIT}.
+   */
+  @Override
+  protected String getFilter() {
+    return " WHERE idColumn >= " + FILTER_LIMIT;
+  }
+
+  /**
+   * Returns the single segment built in setUp().
+   */
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return this._indexSegment;
+  }
+
+  /**
+   * Returns the segments used for the inter-segment checks, as assigned in setUp().
+   */
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return this._indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    // Remove anything left in the index directory before building the segment.
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    IndexSegment segment = buildSegment(genereateRows());
+    _indexSegment = segment;
+    // The same segment instance is listed twice to act as two segments.
+    _indexSegments = Arrays.asList(segment, segment);
+  }
+
+  /**
+   * Generates NUM_RECORDS rows with random ids, alternating step "A" and step "B". It also records the ids in
+   * {@code _all} (every row) and {@code _values[step]} (distinct ids per step) for computing expected results.
+   */
+  private List<GenericRow> genereateRows() {
+    int capacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values[0] = new HashSet<>(capacity);
+    _values[1] = new HashSet<>(capacity);
+
+    List<GenericRow> rows = new ArrayList<>(NUM_RECORDS);
+    for (int rowIdx = 0; rowIdx < NUM_RECORDS; rowIdx++) {
+      // Even rows are step "A", odd rows are step "B".
+      int stepIdx = rowIdx % 2;
+      int id = RANDOM.nextInt(MAX_VALUE);
+
+      GenericRow row = new GenericRow();
+      row.putValue(ID_COLUMN, id);
+      row.putValue(STEP_COLUMN, STEPS[stepIdx]);
+      rows.add(row);
+
+      _all.add(id);
+      _values[stepIdx].add(id);
+    }
+    return rows;
+  }
+
+  /**
+   * Checks FUNNEL_COUNT without group-by. The per-segment counts are checked against sets computed from the
+   * generated rows, and the broker result is checked against 4x those counts.
+   */
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable";
+
+    // Expected inner-segment values, using the same predicate as getFilter().
+    Predicate<Integer> passesFilter = id -> id >= FILTER_LIMIT;
+    long numFilteredDocs = _all.stream().filter(passesFilter).count();
+    Set<Integer> reachedA = _values[0].stream().filter(passesFilter).collect(Collectors.toSet());
+    Set<Integer> reachedAandB = new HashSet<>(_values[1].stream().filter(passesFilter).collect(Collectors.toSet()));
+    reachedAandB.retainAll(reachedA);
+    long[] expectedCounts = new long[]{reachedA.size(), reachedAandB.size()};
+
+    Operator operator = getOperatorWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    AggregationResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), numFilteredDocs,
+        getExpectedNumEntriesScannedInFilter(), 2 * numFilteredDocs, NUM_RECORDS);
+    List<Object> results = resultsBlock.getResults();
+    assertNotNull(results);
+    assertEquals(results.size(), 1);
+    LongArrayList stepCounts = (LongArrayList) results.get(0);
+    for (int step = 0; step < 2; step++) {
+      assertEquals(stepCounts.getLong(step), expectedCounts[step]);
+    }
+
+    // The broker-level result is expected to be 4x the inner-segment result.
+    for (int step = 0; step < expectedCounts.length; step++) {
+      expectedCounts[step] *= 4;
+    }
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponseWithFilter(query), 4 * numFilteredDocs,
+        4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * numFilteredDocs, 4 * NUM_RECORDS,
+        new Object[]{expectedCounts});
+  }
+
+  /**
+   * Runs FUNNEL_COUNT grouped by {@code MOD(idColumn, NUM_GROUPS)} over ids {@code >= FILTER_LIMIT} and checks the
+   * per-group funnel counts, both inside one segment and across segments.
+   */
+  @Test
+  public void testAggregationGroupBy() {
+    String query = String.format("SELECT "
+        + "MOD(idColumn, %s), "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable "
+        + "WHERE idColumn >= %s "
+        + "GROUP BY 1 ORDER BY 1 LIMIT %s", NUM_GROUPS, FILTER_LIMIT, NUM_GROUPS);
+
+    long expectedFilteredNumDocs = _all.stream().filter(id -> id >= FILTER_LIMIT).count();
+
+    // Expected funnel per group key; stays null for groups where no filtered id reached any step.
+    // Per-group sets are plain locals, which avoids unchecked generic array creation.
+    long[][] expectedResult = new long[NUM_GROUPS][];
+    int expectedNumGroups = 0;
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      int group = i;
+      Predicate<Integer> filter = id -> id >= FILTER_LIMIT && id % NUM_GROUPS == group;
+      Set<Integer> filteredA = _values[0].stream().filter(filter).collect(Collectors.toSet());
+      Set<Integer> filteredB = _values[1].stream().filter(filter).collect(Collectors.toSet());
+      if (filteredA.isEmpty() && filteredB.isEmpty()) {
+        continue;
+      }
+      Set<Integer> intersection = new HashSet<>(filteredA);
+      intersection.retainAll(filteredB);
+      expectedNumGroups++;
+      expectedResult[group] = new long[]{filteredA.size(), intersection.size()};
+    }
+
+    // Inner segment
+    GroupByOperator groupByOperator = getOperator(query);
+    GroupByResultsBlock resultsBlock = groupByOperator.nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(),
+        expectedFilteredNumDocs, getExpectedNumEntriesScannedInFilter(), 2 * expectedFilteredNumDocs, NUM_RECORDS);
+
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    int numGroups = 0;
+    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator();
+    while (groupKeyIterator.hasNext()) {
+      numGroups++;
+      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+      // The MOD(...) group key arrives as a Double.
+      int key = ((Double) groupKey._keys[0]).intValue();
+      assertEquals(aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId),
+          new LongArrayList(expectedResult[key]));
+    }
+    assertEquals(numGroups, expectedNumGroups);
+
+    // Inter segments: expect 4 * inner segment result, one row per non-empty group in ascending key order (ORDER BY 1)
+    List<Object[]> expectedRows = new ArrayList<>();
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      long[] groupResult = expectedResult[i];
+      if (groupResult == null) {
+        continue;
+      }
+      for (int step = 0; step < groupResult.length; step++) {
+        groupResult[step] *= 4;
+      }
+      expectedRows.add(new Object[]{Double.valueOf(i), groupResult});
+    }
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query),
+        4 * expectedFilteredNumDocs, 4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * expectedFilteredNumDocs,
+        4 * NUM_RECORDS, expectedRows);
+  }
+
+  /**
+   * Destroys the segment built in {@link #setUp()} and then deletes the index directory from disk.
+   *
+   * @throws IOException if the index directory cannot be deleted
+   */
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
new file mode 100644
index 0000000000..c89a5d74c9
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.util.Collections;
+import java.util.List;
+import 
org.apache.pinot.segment.local.indexsegment.mutable.MutableSegmentImplTestUtils;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.MutableSegment;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.data.readers.GenericRow;
+
+
+/**
+ * Queries test for FUNNEL_COUNT queries against a mutable segment, with no sorted column configured.
+ */
+public class FunnelCountQueriesNonSortedTest extends BaseFunnelCountQueriesTest {
+
+  /**
+   * No sorted column is configured for this variant, so the filter is expected to scan every record.
+   */
+  @Override
+  protected int getExpectedNumEntriesScannedInFilter() {
+    return NUM_RECORDS;
+  }
+
+  @Override
+  protected TableConfig getTableConfig() {
+    return TABLE_CONFIG_BUILDER.build();
+  }
+
+  /**
+   * Indexes {@code records}, in their given order, into a fresh mutable segment created from {@code SCHEMA} with
+   * all three column-set arguments left empty.
+   */
+  @Override
+  protected IndexSegment buildSegment(List<GenericRow> records)
+      throws Exception {
+    MutableSegment mutableSegment = MutableSegmentImplTestUtils
+        .createMutableSegmentImpl(SCHEMA, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(),
+            false);
+    for (GenericRow record : records) {
+      mutableSegment.index(record, null);
+    }
+    return mutableSegment;
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
new file mode 100644
index 0000000000..f06fe26637
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.util.Comparator;
+import java.util.List;
+import 
org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import 
org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+
+
+/**
+ * Queries test for FUNNEL_COUNT queries using the sorted strategy: rows are sorted by {@code ID_COLUMN} and loaded
+ * into an immutable segment whose table config declares {@code ID_COLUMN} as the sorted column.
+ */
+public class FunnelCountQueriesSortedTest extends BaseFunnelCountQueriesTest {
+
+  /**
+   * With {@code ID_COLUMN} configured as the sorted column, the filter is expected to scan no entries.
+   */
+  @Override
+  protected int getExpectedNumEntriesScannedInFilter() {
+    return 0;
+  }
+
+  /**
+   * NOTE(review): if {@code TABLE_CONFIG_BUILDER} is a shared static builder, {@code setSortedColumn} presumably
+   * mutates it for every other user of the builder in the same JVM — confirm and consider a fresh builder.
+   */
+  @Override
+  protected TableConfig getTableConfig() {
+    return TABLE_CONFIG_BUILDER.setSortedColumn(ID_COLUMN).build();
+  }
+
+  /**
+   * Sorts {@code records} by {@code ID_COLUMN}, writes an immutable segment under {@code INDEX_DIR} and loads it
+   * with {@link ReadMode#mmap}.
+   */
+  @Override
+  protected IndexSegment buildSegment(List<GenericRow> records)
+      throws Exception {
+    // Simulate PinotSegmentSorter; note that this reorders the caller's list in place.
+    records.sort(Comparator.comparingInt(rec -> (Integer) rec.getValue(ID_COLUMN)));
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(getTableConfig(), SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+    return ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java 
b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
index c5ba9cdb8b..c5c9eea9a3 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
@@ -133,7 +133,8 @@ public class QueriesTestUtils {
   private static void validateRows(List<Object[]> actual, List<Object[]> 
expected) {
     assertEquals(actual.size(), expected.size());
     for (int i = 0; i < actual.size(); i++) {
-      assertEquals(actual.get(i), expected.get(i));
+      // Cast to Object so the generic assertEquals(Object, Object) overload is selected; it delegates to
+      // assertArrayEquals, so array-valued cells in a row (e.g. long[] funnel counts) are compared by content.
+      assertEquals((Object) actual.get(i), (Object) expected.get(i));
     }
   }
 
diff --git 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 7b2a02d666..5201fdcd2b 100644
--- 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -116,7 +116,10 @@ public enum AggregationFunctionType {
   
PARENTARGMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + 
ARGMIN.getName()),
   
PARENTARGMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + 
ARGMAX.getName()),
   CHILDARGMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX 
+ ARGMIN.getName()),
-  CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX 
+ ARGMAX.getName());
+  CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX 
+ ARGMAX.getName()),
+
+  // funnel aggregate functions
+  FUNNELCOUNT("funnelCount");
 
   private static final Set<String> NAMES = 
Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
       func.getName(), 
func.getName().toLowerCase())).collect(Collectors.toSet());


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to