This is an automated email from the ASF dual-hosted git repository. tingchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 17ee2024dd FUNNEL_COUNT Aggregation Function (#10867) 17ee2024dd is described below commit 17ee2024dd7d49ce2f5ebf00ef9d71b6ca589a00 Author: dario-liberman <130933669+dario-liber...@users.noreply.github.com> AuthorDate: Wed Jun 21 19:26:57 2023 +0200 FUNNEL_COUNT Aggregation Function (#10867) * New Funnel Aggregation Function * Funnel analytics support - FUNNEL_COUNT aggregation function * Delete FunnelAggregationFunction.java * Simplify Tests * Simplify Tests * Simplify Tests * within -> across Fix javadoc * Update FunnelCountAggregationFunction.java Address comments --------- Co-authored-by: Dario Liberman <dario.liber...@uber.com> --- .../org/apache/pinot/common/utils/DataSchema.java | 3 + .../blocks/results/AggregationResultsBlock.java | 4 + .../function/AggregationFunctionFactory.java | 3 + .../function/FunnelCountAggregationFunction.java | 511 +++++++++++++++++++++ .../pinot/queries/BaseFunnelCountQueriesTest.java | 252 ++++++++++ .../queries/FunnelCountQueriesNonSortedTest.java | 57 +++ .../queries/FunnelCountQueriesSortedTest.java | 65 +++ .../org/apache/pinot/queries/QueriesTestUtils.java | 3 +- .../pinot/segment/spi/AggregationFunctionType.java | 5 +- 9 files changed, 901 insertions(+), 2 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java index d8fedfc3d7..4020aff70a 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyOrder; import com.google.common.collect.Ordering; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; import java.io.ByteArrayOutputStream; import 
java.io.DataOutputStream; import java.io.IOException; @@ -469,6 +470,8 @@ public class DataSchema { private static long[] toLongArray(Object value) { if (value instanceof long[]) { return (long[]) value; + } else if (value instanceof LongArrayList) { + return ((LongArrayList) value).elements(); } else { int[] intValues = (int[]) value; int length = intValues.length; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index b10ebdd3b9..2d816bffca 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -19,6 +19,7 @@ package org.apache.pinot.core.operator.blocks.results; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; import java.io.IOException; import java.math.BigDecimal; import java.util.Collections; @@ -182,6 +183,9 @@ public class AggregationResultsBlock extends BaseResultsBlock { case DOUBLE_ARRAY: dataTableBuilder.setColumn(index, ((DoubleArrayList) result).elements()); break; + case LONG_ARRAY: + dataTableBuilder.setColumn(index, ((LongArrayList) result).elements()); + break; default: throw new IllegalStateException("Illegal column data type in final result: " + columnDataType); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java index 06fbb1db66..7f96072c9e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java @@ -358,6 +358,9 @@ 
public class AggregationFunctionFactory { case ARGMIN: throw new IllegalArgumentException( "Aggregation function: " + function + " is only supported in selection without alias."); + case FUNNELCOUNT: + return new FunnelCountAggregationFunction(arguments); + default: throw new IllegalArgumentException(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java new file mode 100644 index 0000000000..4eecad1002 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java @@ -0,0 +1,511 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The {@code FunnelCountAggregationFunction} calculates the number of step conversions for a given partition column and
+ * a list of boolean expressions.
+ * <p>IMPORTANT: This function relies on the partition column being partitioned for each segment, where there are no
+ * common values across different segments.
+ * <p>This function calculates the exact number of step matches per partition key within the segment, then sums up the
+ * results from different segments.
+ * <p>The result has one count per step: the count for step {@code n} is the number of distinct correlation values
+ * that matched every step {@code 0..n}.
+ *
+ * <p>Example:
+ * <pre>
+ *   SELECT
+ *     dateTrunc('day', timestamp) AS ts,
+ *     FUNNEL_COUNT(
+ *       STEPS(url = '/addToCart', url = '/checkout', url = '/orderConfirmation'),
+ *       CORRELATE_BY(user)
+ *     ) as step_counts
+ *   FROM user_log
+ *   WHERE url in ('/addToCart', '/checkout', '/orderConfirmation')
+ *   GROUP BY 1
+ * </pre>
+ */
+public class FunnelCountAggregationFunction implements AggregationFunction<List<Long>, LongArrayList> {
+  // All function arguments as received (used to build the result column name).
+  final List<ExpressionContext> _expressions;
+  // Arguments of STEPS(...): one boolean expression per funnel step, in funnel order.
+  final List<ExpressionContext> _stepExpressions;
+  // Arguments of CORRELATE_BY(...).
+  final List<ExpressionContext> _correlateByExpressions;
+  // First CORRELATE_BY argument; its dictionary ids identify the correlation groups.
+  final ExpressionContext _primaryCorrelationCol;
+  final int _numSteps;
+
+  final SegmentAggregationStrategy<?, List<Long>> _sortedAggregationStrategy;
+  final SegmentAggregationStrategy<?, List<Long>> _bitmapAggregationStrategy;
+
+  public FunnelCountAggregationFunction(List<ExpressionContext> expressions) {
+    _expressions = expressions;
+    _correlateByExpressions = Option.CORRELATE_BY.getInputExpressions(expressions);
+    _primaryCorrelationCol = Option.CORRELATE_BY.getFirstInputExpression(expressions);
+    _stepExpressions = Option.STEPS.getInputExpressions(expressions);
+    _numSteps = _stepExpressions.size();
+    // Result extraction reads the first step unconditionally, so an empty STEPS() must be rejected up front.
+    Preconditions.checkArgument(_numSteps > 0, "FUNNELCOUNT: STEPS requires at least one step expression.");
+    _sortedAggregationStrategy = new SortedAggregationStrategy();
+    _bitmapAggregationStrategy = new BitmapAggregationStrategy();
+  }
+
+  @Override
+  public String getResultColumnName() {
+    return getType().getName().toLowerCase() + "(" + _expressions.stream().map(ExpressionContext::toString)
+        .collect(Collectors.joining(",")) + ")";
+  }
+
+  @Override
+  public List<ExpressionContext> getInputExpressions() {
+    final List<ExpressionContext> inputs = new ArrayList<>();
+    inputs.addAll(_correlateByExpressions);
+    inputs.addAll(_stepExpressions);
+    return inputs;
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.FUNNELCOUNT;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregate(length, aggregationResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupBySV(length, groupKeyArray,
+        groupByResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupByMV(length, groupKeysArray,
+        groupByResultHolder, blockValSetMap);
+  }
+
+  @Override
+  public List<Long> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return getAggregationStrategyByAggregationResult(aggregationResultHolder.getResult()).extractAggregationResult(
+        aggregationResultHolder);
+  }
+
+  @Override
+  public List<Long> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return getAggregationStrategyByAggregationResult(groupByResultHolder.getResult(groupKey)).extractGroupByResult(
+        groupByResultHolder, groupKey);
+  }
+
+  /**
+   * Sums two intermediate results element-wise (one count per step). Neither operand is modified; the result is a
+   * fresh list.
+   *
+   * @throws IllegalStateException if the operands do not have the same number of entries
+   */
+  @Override
+  public List<Long> merge(List<Long> a, List<Long> b) {
+    int length = a.size();
+    Preconditions.checkState(length == b.size(), "The two operand arrays are not of the same size! provided %s, %s",
+        length, b.size());
+
+    LongArrayList result = toLongArrayList(a);
+    long[] elements = result.elements();
+    for (int i = 0; i < length; i++) {
+      elements[i] += b.get(i);
+    }
+    return result;
+  }
+
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.LONG_ARRAY;
+  }
+
+  @Override
+  public LongArrayList extractFinalResult(List<Long> result) {
+    return toLongArrayList(result);
+  }
+
+  @Override
+  public String toExplainString() {
+    StringBuilder stringBuilder = new StringBuilder(getType().getName()).append('(');
+    // getInputExpressions() builds a new list on every call, so fetch it once.
+    List<ExpressionContext> inputExpressions = getInputExpressions();
+    if (!inputExpressions.isEmpty()) {
+      stringBuilder.append(inputExpressions.get(0).toString());
+      for (int i = 1; i < inputExpressions.size(); i++) {
+        stringBuilder.append(", ").append(inputExpressions.get(i).toString());
+      }
+    }
+    return stringBuilder.append(')').toString();
+  }
+
+  /**
+   * Returns a mutable copy of {@code longList} as a {@link LongArrayList}; the input is never modified.
+   */
+  private static LongArrayList toLongArrayList(List<Long> longList) {
+    return longList instanceof LongArrayList ? ((LongArrayList) longList).clone() : new LongArrayList(longList);
+  }
+
+  private int[] getCorrelationIds(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    return blockValSetMap.get(_primaryCorrelationCol).getDictionaryIdsSV();
+  }
+
+  /**
+   * Returns one int array per step; a value {@code > 0} at index i means row i matched that step.
+   */
+  private int[][] getSteps(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    final int[][] steps = new int[_numSteps][];
+    for (int n = 0; n < _numSteps; n++) {
+      final BlockValSet stepBlockValSet = blockValSetMap.get(_stepExpressions.get(n));
+      steps[n] = stepBlockValSet.getIntValuesSV();
+    }
+    return steps;
+  }
+
+  private boolean isSorted(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    final Dictionary primaryCorrelationDictionary = blockValSetMap.get(_primaryCorrelationCol).getDictionary();
+    if (primaryCorrelationDictionary == null) {
+      throw new IllegalArgumentException(
+          "CORRELATE_BY column in FUNNELCOUNT aggregation function not supported, please use a dictionary encoded "
+              + "column.");
+    }
+    return primaryCorrelationDictionary.isSorted();
+  }
+
+  private SegmentAggregationStrategy<?, List<Long>> getAggregationStrategyByBlockValSetMap(
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    return isSorted(blockValSetMap) ? _sortedAggregationStrategy : _bitmapAggregationStrategy;
+  }
+
+  // A null result falls through to the bitmap strategy, whose extraction handles null as "all zero counts".
+  private SegmentAggregationStrategy<?, List<Long>> getAggregationStrategyByAggregationResult(Object aggResult) {
+    return aggResult instanceof SortedAggregationResult ? _sortedAggregationStrategy : _bitmapAggregationStrategy;
+  }
+
+  /**
+   * Option sub-functions accepted as arguments, matched by their (lower-case) function name.
+   */
+  enum Option {
+    STEPS("steps"),
+    CORRELATE_BY("correlateby");
+
+    final String _name;
+
+    Option(String name) {
+      _name = name;
+    }
+
+    boolean matches(ExpressionContext expression) {
+      if (expression.getType() != ExpressionContext.Type.FUNCTION) {
+        return false;
+      }
+      return _name.equals(expression.getFunction().getFunctionName());
+    }
+
+    Optional<ExpressionContext> find(List<ExpressionContext> expressions) {
+      return expressions.stream().filter(this::matches).findFirst();
+    }
+
+    public List<ExpressionContext> getInputExpressions(List<ExpressionContext> expressions) {
+      return this.find(expressions).map(exp -> exp.getFunction().getArguments())
+          .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT requires " + _name));
+    }
+
+    public ExpressionContext getFirstInputExpression(List<ExpressionContext> expressions) {
+      return this.find(expressions)
+          .flatMap(exp -> exp.getFunction().getArguments().stream().findFirst())
+          .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT: " + _name + " requires an argument."));
+    }
+  }
+
+  /**
+   * Interface for segment aggregation strategy.
+   *
+   * <p>The implementation should be stateless, and can be shared among multiple segments in multiple threads. The
+   * result for each segment should be stored and passed in via the result holder.
+   * There should be no assumptions beyond segment boundaries; different aggregation strategies may be utilized
+   * across different segments for a given query.
+   *
+   * @param <A> Aggregation result accumulated across blocks within segment, kept by result holder.
+   * @param <I> Intermediate result at segment level (extracted from aforementioned aggregation result).
+   */
+  @ThreadSafe
+  static abstract class SegmentAggregationStrategy<A, I> {
+
+    /**
+     * Returns an aggregation result for this aggregation strategy to be kept in a result holder (aggregation only).
+     */
+    abstract A createAggregationResult();
+
+    public A getAggregationResultGroupBy(GroupByResultHolder groupByResultHolder, int groupKey) {
+      A aggResult = groupByResultHolder.getResult(groupKey);
+      if (aggResult == null) {
+        aggResult = createAggregationResult();
+        groupByResultHolder.setValueForKey(groupKey, aggResult);
+      }
+      return aggResult;
+    }
+
+    public A getAggregationResult(AggregationResultHolder aggregationResultHolder) {
+      A aggResult = aggregationResultHolder.getResult();
+      if (aggResult == null) {
+        aggResult = createAggregationResult();
+        aggregationResultHolder.setValue(aggResult);
+      }
+      return aggResult;
+    }
+
+    /**
+     * Performs aggregation on the given block value sets (aggregation only).
+     */
+    abstract void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Performs aggregation on the given group key array and block value sets (aggregation group-by on single-value
+     * columns).
+     */
+    abstract void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Performs aggregation on the given group keys array and block value sets (aggregation group-by on multi-value
+     * columns).
+     */
+    abstract void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+    /**
+     * Extracts the intermediate result from the aggregation result holder (aggregation only).
+     */
+    public I extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+      return extractIntermediateResult(aggregationResultHolder.getResult());
+    }
+
+    /**
+     * Extracts the intermediate result from the group-by result holder for the given group key (aggregation group-by).
+     */
+    public I extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+      return extractIntermediateResult(groupByResultHolder.getResult(groupKey));
+    }
+
+    /**
+     * Converts an aggregation result into the intermediate result. Implementations must accept {@code null} (no
+     * result was ever stored in the holder) and return a result with one entry per step.
+     */
+    abstract I extractIntermediateResult(A aggregationResult);
+  }
+
+  /**
+   * Aggregation strategy leveraging roaring bitmap algebra (unions/intersections).
+   */
+  class BitmapAggregationStrategy extends SegmentAggregationStrategy<RoaringBitmap[], List<Long>> {
+
+    @Override
+    public RoaringBitmap[] createAggregationResult() {
+      final RoaringBitmap[] stepsBitmaps = new RoaringBitmap[_numSteps];
+      for (int n = 0; n < _numSteps; n++) {
+        stepsBitmaps[n] = new RoaringBitmap();
+      }
+      return stepsBitmaps;
+    }
+
+    @Override
+    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      final RoaringBitmap[] stepsBitmaps = getAggregationResult(aggregationResultHolder);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          if (steps[n][i] > 0) {
+            stepsBitmaps[n].add(correlationIds[i]);
+          }
+        }
+      }
+    }
+
+    @Override
+    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          final int groupKey = groupKeyArray[i];
+          if (steps[n][i] > 0) {
+            getAggregationResultGroupBy(groupByResultHolder, groupKey)[n].add(correlationIds[i]);
+          }
+        }
+      }
+    }
+
+    @Override
+    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int n = 0; n < _numSteps; n++) {
+        for (int i = 0; i < length; i++) {
+          for (int groupKey : groupKeysArray[i]) {
+            if (steps[n][i] > 0) {
+              getAggregationResultGroupBy(groupByResultHolder, groupKey)[n].add(correlationIds[i]);
+            }
+          }
+        }
+      }
+    }
+
+    @Override
+    public List<Long> extractIntermediateResult(RoaringBitmap[] stepsBitmaps) {
+      if (stepsBitmaps == null) {
+        // No result was stored for this segment/group: report zero conversions for every step. The list must have
+        // _numSteps entries so merge() can sum it with other results (new LongArrayList(_numSteps) would be EMPTY).
+        return LongArrayList.wrap(new long[_numSteps]);
+      }
+
+      long[] result = new long[_numSteps];
+      result[0] = stepsBitmaps[0].getCardinality();
+      for (int i = 1; i < _numSteps; i++) {
+        // Intersect this step with the previous one in place, so step i keeps only ids that matched steps 0..i.
+        stepsBitmaps[i].and(stepsBitmaps[i - 1]);
+        result[i] = stepsBitmaps[i].getCardinality();
+      }
+      return LongArrayList.wrap(result);
+    }
+  }
+
+  /**
+   * Aggregation strategy for segments sorted by the main correlation column.
+   */
+  class SortedAggregationStrategy extends SegmentAggregationStrategy<SortedAggregationResult, List<Long>> {
+
+    @Override
+    public SortedAggregationResult createAggregationResult() {
+      return new SortedAggregationResult();
+    }
+
+    @Override
+    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      final SortedAggregationResult agg = getAggregationResult(aggregationResultHolder);
+
+      for (int i = 0; i < length; i++) {
+        agg.sortedCount(steps, i, correlationIds[i]);
+      }
+    }
+
+    @Override
+    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int i = 0; i < length; i++) {
+        final int groupKey = groupKeyArray[i];
+        final SortedAggregationResult agg =
+            getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+        agg.sortedCount(steps, i, correlationIds[i]);
+      }
+    }
+
+    @Override
+    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+        Map<ExpressionContext, BlockValSet> blockValSetMap) {
+      final int[] correlationIds = getCorrelationIds(blockValSetMap);
+      final int[][] steps = getSteps(blockValSetMap);
+
+      for (int i = 0; i < length; i++) {
+        for (int groupKey : groupKeysArray[i]) {
+          final SortedAggregationResult agg = getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+          agg.sortedCount(steps, i, correlationIds[i]);
+        }
+      }
+    }
+
+    @Override
+    public List<Long> extractIntermediateResult(SortedAggregationResult agg) {
+      if (agg == null) {
+        // No result was stored for this segment/group: zero conversions, one entry per step (see merge()).
+        return LongArrayList.wrap(new long[_numSteps]);
+      }
+
+      return LongArrayList.wrap(agg.extractResult());
+    }
+  }
+
+  /**
+   * Aggregation result data structure leveraged by sorted aggregation strategy.
+   *
+   * <p>Because rows arrive ordered by correlation id, only the currently open correlation group is tracked:
+   * {@code _correlatedSteps} records which steps the open group has matched so far, and {@code _stepCounters}
+   * accumulates the funnel counts of every group that has already been closed.
+   */
+  class SortedAggregationResult {
+    public long[] _stepCounters = new long[_numSteps];
+    public int _lastCorrelationId = Integer.MIN_VALUE;
+    public boolean[] _correlatedSteps = new boolean[_numSteps];
+
+    public void sortedCount(int[][] steps, int i, int correlationId) {
+      if (correlationId == _lastCorrelationId) {
+        // same correlation as before, keep accumulating.
+        for (int n = 0; n < _numSteps; n++) {
+          _correlatedSteps[n] |= steps[n][i] > 0;
+        }
+      } else {
+        // End of correlation group, calculate funnel conversion counts
+        incrStepCounters();
+
+        // initialize next correlation group
+        for (int n = 0; n < _numSteps; n++) {
+          _correlatedSteps[n] = steps[n][i] > 0;
+        }
+        _lastCorrelationId = correlationId;
+      }
+    }
+
+    /**
+     * Closes the open correlation group by adding it to the accumulated step counters.
+     */
+    void incrStepCounters() {
+      incrStepCounters(_stepCounters);
+    }
+
+    /**
+     * Adds the open correlation group to {@code counters}: step n is counted only if steps 0..n all matched.
+     */
+    private void incrStepCounters(long[] counters) {
+      for (int n = 0; n < _numSteps; n++) {
+        if (!_correlatedSteps[n]) {
+          break;
+        }
+        counters[n]++;
+      }
+    }
+
+    public long[] extractResult() {
+      // Count the correlation group left open, but on a copy: extraction must not mutate the accumulated state,
+      // otherwise extracting the same result twice would count the open group twice.
+      long[] result = _stepCounters.clone();
+      incrStepCounters(result);
+      return result;
+    }
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
new file mode 100644
index 0000000000..ef5c7d596f
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
@@ -0,0 +1,252 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.GroupByOperator;
+import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Base queries test for FUNNEL_COUNT queries.
+ * Concrete subclasses decide how the segment is built ({@link #buildSegment}) and its table config, so that each
+ * aggregation strategy gets its own test. Rows alternate between step "A" and step "B" with random ids.
+ */
+@SuppressWarnings("rawtypes")
+abstract public class BaseFunnelCountQueriesTest extends BaseQueriesTest {
+  protected static final File INDEX_DIR =
+      new File(FileUtils.getTempDirectory(), "FunnelCountQueriesTest");
+  protected static final String RAW_TABLE_NAME = "testTable";
+  protected static final String SEGMENT_NAME = "testSegment";
+  protected static final Random RANDOM = new Random();
+
+  protected static final int NUM_RECORDS = 2000;
+  protected static final int MAX_VALUE = 1000;
+  protected static final int NUM_GROUPS = 100;
+  protected static final int FILTER_LIMIT = 50;
+  protected static final String ID_COLUMN = "idColumn";
+  protected static final String STEP_COLUMN = "stepColumn";
+  protected static final String[] STEPS = {"A", "B"};
+  protected static final Schema SCHEMA = new Schema.SchemaBuilder()
+      .addSingleValueDimension(ID_COLUMN, DataType.INT)
+      .addSingleValueDimension(STEP_COLUMN, DataType.STRING)
+      .build();
+  protected static final TableConfigBuilder TABLE_CONFIG_BUILDER =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME);
+
+  // Distinct ids seen per step: _values[0] for step "A", _values[1] for step "B".
+  private final Set<Integer>[] _values = new Set[2];
+  // Id of every generated row, in generation order (used to compute expected doc counts).
+  private final List<Integer> _all = new ArrayList<>();
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  protected abstract int getExpectedNumEntriesScannedInFilter();
+  protected abstract TableConfig getTableConfig();
+  protected abstract IndexSegment buildSegment(List<GenericRow> records) throws Exception;
+
+  @Override
+  protected String getFilter() {
+    return String.format(" WHERE idColumn >= %s", FILTER_LIMIT);
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = generateRows();
+    _indexSegment = buildSegment(records);
+    _indexSegments =
+        Arrays.asList(_indexSegment, _indexSegment);
+  }
+
+  /**
+   * Generates {@link #NUM_RECORDS} rows with random ids, alternating step "A" / "B", and records the expected values
+   * in {@code _all} and {@code _values}.
+   */
+  private List<GenericRow> generateRows() {
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values[0] = new HashSet<>(hashMapCapacity);
+    _values[1] = new HashSet<>(hashMapCapacity);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      GenericRow record = new GenericRow();
+      record.putValue(ID_COLUMN, value);
+      record.putValue(STEP_COLUMN, STEPS[i % 2]);
+      records.add(record);
+      _all.add(value);
+      _values[i % 2].add(value);
+    }
+    return records;
+  }
+
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable";
+
+    // Inner segment
+    Predicate<Integer> filter = id -> id >= FILTER_LIMIT;
+    long expectedFilteredNumDocs = _all.stream().filter(filter).count();
+    Set<Integer> filteredA = _values[0].stream().filter(filter).collect(Collectors.toSet());
+    Set<Integer> filteredB = _values[1].stream().filter(filter).collect(Collectors.toSet());
+    Set<Integer> intersection = new HashSet<>(filteredA);
+    intersection.retainAll(filteredB);
+    long[] expectedResult = { filteredA.size(), intersection.size() };
+
+    Operator operator = getOperatorWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    AggregationResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
+        expectedFilteredNumDocs, getExpectedNumEntriesScannedInFilter(), 2 * expectedFilteredNumDocs, NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getResults();
+    assertNotNull(aggregationResult);
+    assertEquals(aggregationResult.size(), 1);
+    for (int i = 0; i < 2; i++) {
+      assertEquals(((LongArrayList) aggregationResult.get(0)).getLong(i),
+          expectedResult[i]);
+    }
+
+    // Inter segments (expect 4 * inner segment result)
+    for (int i = 0; i < 2; i++) {
+      expectedResult[i] = 4 * expectedResult[i];
+    }
+    Object[] expectedResults = { expectedResult };
+
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponseWithFilter(query),
+        4 * expectedFilteredNumDocs, 4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * expectedFilteredNumDocs,
+        4 * NUM_RECORDS, expectedResults);
+  }
+
+  @Test
+  public void testAggregationGroupBy() {
+    String query = String.format("SELECT "
+        + "MOD(idColumn, %s), "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable "
+        + "WHERE idColumn >= %s "
+        + "GROUP BY 1 ORDER BY 1 LIMIT %s", NUM_GROUPS, FILTER_LIMIT, NUM_GROUPS);
+
+    // Expected per-group results (null entry = group absent from the result)
+    Set<Integer>[] filteredA = new Set[NUM_GROUPS];
+    Set<Integer>[] filteredB = new Set[NUM_GROUPS];
+    Set<Integer>[] intersection = new Set[NUM_GROUPS];
+    long[][] expectedResult = new long[NUM_GROUPS][];
+
+    long expectedFilteredNumDocs = _all.stream().filter(id -> id >= FILTER_LIMIT).count();
+
+    int expectedNumGroups = 0;
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      final int group = i;
+      Predicate<Integer> filter = id -> id >= FILTER_LIMIT && id % NUM_GROUPS == group;
+      filteredA[group] = _values[0].stream().filter(filter).collect(Collectors.toSet());
+      filteredB[group] = _values[1].stream().filter(filter).collect(Collectors.toSet());
+      intersection[group] = new HashSet<>(filteredA[group]);
+      intersection[group].retainAll(filteredB[group]);
+      if (!filteredA[group].isEmpty() || !filteredB[group].isEmpty()) {
+        expectedNumGroups++;
+        expectedResult[group] = new long[] { filteredA[group].size(), intersection[group].size() };
+      }
+    }
+
+    // Inner segment
+    GroupByOperator groupByOperator = getOperator(query);
+    GroupByResultsBlock resultsBlock = groupByOperator.nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(),
+        expectedFilteredNumDocs, getExpectedNumEntriesScannedInFilter(), 2 * expectedFilteredNumDocs, NUM_RECORDS);
+
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    int numGroups = 0;
+    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator();
+    while (groupKeyIterator.hasNext()) {
+      numGroups++;
+      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+      int key = ((Double) groupKey._keys[0]).intValue();
+      assertEquals(aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId),
+          new LongArrayList(expectedResult[key]));
+    }
+    assertEquals(numGroups, expectedNumGroups);
+
+    // Inter segments (expect 4 * inner segment result)
+    List<Object[]> expectedRows = new ArrayList<>();
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      if (expectedResult[i] == null) {
+        continue;
+      }
+      for (int step = 0; step < 2; step++) {
+        expectedResult[i][step] = 4 * expectedResult[i][step];
+      }
+      Object[] expectedRow = { Double.valueOf(i), expectedResult[i] };
+      expectedRows.add(expectedRow);
+    }
+
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query),
+        4 * expectedFilteredNumDocs, 4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * expectedFilteredNumDocs,
+        4 * NUM_RECORDS, expectedRows);
+  }
+
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
new file mode 100644
index 0000000000..c89a5d74c9
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.queries; + +import java.util.Collections; +import java.util.List; +import org.apache.pinot.segment.local.indexsegment.mutable.MutableSegmentImplTestUtils; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.MutableSegment; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.data.readers.GenericRow; + + +/** + * Queries test for FUNNEL_COUNT queries. 
+ */ +@SuppressWarnings("rawtypes") +public class FunnelCountQueriesNonSortedTest extends BaseFunnelCountQueriesTest { + + @Override + protected int getExpectedNumEntriesScannedInFilter() { + return NUM_RECORDS; + } + + @Override + protected TableConfig getTableConfig() { + return TABLE_CONFIG_BUILDER.build(); + } + + @Override + protected IndexSegment buildSegment(List<GenericRow> records) + throws Exception { + MutableSegment mutableSegment = MutableSegmentImplTestUtils + .createMutableSegmentImpl(SCHEMA, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), + false); + for (GenericRow record : records) { + mutableSegment.index(record, null); + } + return mutableSegment; + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java new file mode 100644 index 0000000000..f06fe26637 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.queries; + +import java.io.File; +import java.util.Comparator; +import java.util.List; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.utils.ReadMode; + + +/** + * Queries test for FUNNEL_COUNT queries using sorted strategy. + */ +@SuppressWarnings("rawtypes") +public class FunnelCountQueriesSortedTest extends BaseFunnelCountQueriesTest { + + @Override + protected int getExpectedNumEntriesScannedInFilter() { + return 0; + } + + @Override + protected TableConfig getTableConfig() { + return TABLE_CONFIG_BUILDER.setSortedColumn(ID_COLUMN).build(); + } + + @Override + protected IndexSegment buildSegment(List<GenericRow> records) + throws Exception { + // Simulate PinotSegmentSorter + records.sort(Comparator.comparingInt(rec -> (Integer) rec.getValue(ID_COLUMN))); + + SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(getTableConfig(), SCHEMA); + segmentGeneratorConfig.setTableName(RAW_TABLE_NAME); + segmentGeneratorConfig.setSegmentName(SEGMENT_NAME); + segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath()); + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records)); + driver.build(); + return ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java index c5ba9cdb8b..c5c9eea9a3 
100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java @@ -133,7 +133,8 @@ public class QueriesTestUtils { private static void validateRows(List<Object[]> actual, List<Object[]> expected) { assertEquals(actual.size(), expected.size()); for (int i = 0; i < actual.size(); i++) { - assertEquals(actual.get(i), expected.get(i)); + // Generic assertEquals delegates to assertArrayEquals, which can test for equality of array values in rows. + assertEquals((Object) actual.get(i), (Object) expected.get(i)); } } diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index 7b2a02d666..5201fdcd2b 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -116,7 +116,10 @@ public enum AggregationFunctionType { PARENTARGMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARGMIN.getName()), PARENTARGMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARGMAX.getName()), CHILDARGMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + ARGMIN.getName()), - CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + ARGMAX.getName()); + CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + ARGMAX.getName()), + + // funnel aggregate functions + FUNNELCOUNT("funnelCount"); private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toLowerCase())).collect(Collectors.toSet()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: 
commits-h...@pinot.apache.org