This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new aa5f318 add mode aggregation function (#7318)
aa5f318 is described below
commit aa5f318d0708b9b0a3e570706c4236df94d29141
Author: Yash Agarwal <[email protected]>
AuthorDate: Thu Aug 19 23:15:29 2021 +0530
add mode aggregation function (#7318)
Add support for Mode Function.
Mode accepts an additional parameter to reduce multiple modes to a single
value: MIN/MAX/AVG
---
.../function/AggregationFunctionTypeTest.java | 1 +
.../apache/pinot/core/common/ObjectSerDeUtils.java | 152 +++-
.../function/AggregationFunctionFactory.java | 2 +
.../function/ModeAggregationFunction.java | 691 +++++++++++++++
.../pinot/core/common/ObjectSerDeUtilsTest.java | 68 ++
.../function/AggregationFunctionFactoryTest.java | 7 +
.../org/apache/pinot/queries/ModeQueriesTest.java | 949 +++++++++++++++++++++
.../pinot/segment/spi/AggregationFunctionType.java | 1 +
8 files changed, 1869 insertions(+), 2 deletions(-)
diff --git
a/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java
b/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java
index 70a25b5..e325b0a 100644
---
a/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java
@@ -32,6 +32,7 @@ public class AggregationFunctionTypeTest {
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MaX"),
AggregationFunctionType.MAX);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("SuM"),
AggregationFunctionType.SUM);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("AvG"),
AggregationFunctionType.AVG);
+
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MoDe"),
AggregationFunctionType.MODE);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MiNmAxRaNgE"),
AggregationFunctionType.MINMAXRANGE);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("DiStInCtCoUnT"),
AggregationFunctionType.DISTINCTCOUNT);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("DiStInCtCoUnThLl"),
AggregationFunctionType.DISTINCTCOUNTHLL);
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index 7417daa..123b03b 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -22,16 +22,24 @@ import
com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.primitives.Longs;
import com.tdunning.math.stats.MergingDigest;
import com.tdunning.math.stats.TDigest;
+import it.unimi.dsi.fastutil.doubles.Double2LongMap;
+import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleIterator;
import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
import it.unimi.dsi.fastutil.doubles.DoubleSet;
+import it.unimi.dsi.fastutil.floats.Float2LongMap;
+import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap;
import it.unimi.dsi.fastutil.floats.FloatIterator;
import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
import it.unimi.dsi.fastutil.floats.FloatSet;
+import it.unimi.dsi.fastutil.ints.Int2LongMap;
+import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.longs.Long2LongMap;
+import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
@@ -97,7 +105,11 @@ public class ObjectSerDeUtils {
BytesSet(19),
IdSet(20),
List(21),
- BigDecimal(22);
+ BigDecimal(22),
+ Int2LongMap(23),
+ Long2LongMap(24),
+ Float2LongMap(25),
+ Double2LongMap(26);
private final int _value;
ObjectType(int value) {
@@ -127,6 +139,14 @@ public class ObjectSerDeUtils {
return ObjectType.HyperLogLog;
} else if (value instanceof QuantileDigest) {
return ObjectType.QuantileDigest;
+ } else if (value instanceof Int2LongMap) {
+ return ObjectType.Int2LongMap;
+ } else if (value instanceof Long2LongMap) {
+ return ObjectType.Long2LongMap;
+ } else if (value instanceof Float2LongMap) {
+ return ObjectType.Float2LongMap;
+ } else if (value instanceof Double2LongMap) {
+ return ObjectType.Double2LongMap;
} else if (value instanceof Map) {
return ObjectType.Map;
} else if (value instanceof IntSet) {
@@ -874,6 +894,130 @@ public class ObjectSerDeUtils {
}
};
+  /**
+   * SerDe for {@link Int2LongMap}. The serialized form is the entry count (int) followed by, for every entry, its
+   * key (int) and its value (long). Deserialization rebuilds an {@link Int2LongOpenHashMap} from that layout.
+   */
+  public static final ObjectSerDe<Int2LongMap> INT_2_LONG_MAP_SER_DE = new ObjectSerDe<Int2LongMap>() {
+
+    @Override
+    public byte[] serialize(Int2LongMap map) {
+      int numEntries = map.size();
+      int bytesPerEntry = Integer.BYTES + Long.BYTES;
+      byte[] serialized = new byte[Integer.BYTES + numEntries * bytesPerEntry];
+      ByteBuffer writer = ByteBuffer.wrap(serialized);
+      writer.putInt(numEntries);
+      for (Int2LongMap.Entry keyAndValue : map.int2LongEntrySet()) {
+        writer.putInt(keyAndValue.getIntKey()).putLong(keyAndValue.getLongValue());
+      }
+      return serialized;
+    }
+
+    @Override
+    public Int2LongOpenHashMap deserialize(byte[] bytes) {
+      ByteBuffer reader = ByteBuffer.wrap(bytes);
+      return deserialize(reader);
+    }
+
+    @Override
+    public Int2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
+      int numEntries = byteBuffer.getInt();
+      Int2LongOpenHashMap result = new Int2LongOpenHashMap(numEntries);
+      // Entries were written as (key, value) pairs, so read them back in the same order
+      for (int remaining = numEntries; remaining > 0; remaining--) {
+        int key = byteBuffer.getInt();
+        long value = byteBuffer.getLong();
+        result.put(key, value);
+      }
+      return result;
+    }
+  };
+
+  /**
+   * SerDe for {@link Long2LongMap}. The serialized form is the entry count (int) followed by, for every entry, its
+   * key (long) and its value (long). Deserialization rebuilds a {@link Long2LongOpenHashMap} from that layout.
+   */
+  public static final ObjectSerDe<Long2LongMap> LONG_2_LONG_MAP_SER_DE = new ObjectSerDe<Long2LongMap>() {
+
+    @Override
+    public byte[] serialize(Long2LongMap map) {
+      int numEntries = map.size();
+      int bytesPerEntry = Long.BYTES + Long.BYTES;
+      byte[] serialized = new byte[Integer.BYTES + numEntries * bytesPerEntry];
+      ByteBuffer writer = ByteBuffer.wrap(serialized);
+      writer.putInt(numEntries);
+      for (Long2LongMap.Entry keyAndValue : map.long2LongEntrySet()) {
+        writer.putLong(keyAndValue.getLongKey()).putLong(keyAndValue.getLongValue());
+      }
+      return serialized;
+    }
+
+    @Override
+    public Long2LongOpenHashMap deserialize(byte[] bytes) {
+      ByteBuffer reader = ByteBuffer.wrap(bytes);
+      return deserialize(reader);
+    }
+
+    @Override
+    public Long2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
+      int numEntries = byteBuffer.getInt();
+      Long2LongOpenHashMap result = new Long2LongOpenHashMap(numEntries);
+      // Entries were written as (key, value) pairs, so read them back in the same order
+      for (int remaining = numEntries; remaining > 0; remaining--) {
+        long key = byteBuffer.getLong();
+        long value = byteBuffer.getLong();
+        result.put(key, value);
+      }
+      return result;
+    }
+  };
+
+  /**
+   * SerDe for {@link Float2LongMap}. The serialized form is the entry count (int) followed by, for every entry, its
+   * key (float) and its value (long). Deserialization rebuilds a {@link Float2LongOpenHashMap} from that layout.
+   */
+  public static final ObjectSerDe<Float2LongMap> FLOAT_2_LONG_MAP_SER_DE = new ObjectSerDe<Float2LongMap>() {
+
+    @Override
+    public byte[] serialize(Float2LongMap map) {
+      int numEntries = map.size();
+      int bytesPerEntry = Float.BYTES + Long.BYTES;
+      byte[] serialized = new byte[Integer.BYTES + numEntries * bytesPerEntry];
+      ByteBuffer writer = ByteBuffer.wrap(serialized);
+      writer.putInt(numEntries);
+      for (Float2LongMap.Entry keyAndValue : map.float2LongEntrySet()) {
+        writer.putFloat(keyAndValue.getFloatKey()).putLong(keyAndValue.getLongValue());
+      }
+      return serialized;
+    }
+
+    @Override
+    public Float2LongOpenHashMap deserialize(byte[] bytes) {
+      ByteBuffer reader = ByteBuffer.wrap(bytes);
+      return deserialize(reader);
+    }
+
+    @Override
+    public Float2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
+      int numEntries = byteBuffer.getInt();
+      Float2LongOpenHashMap result = new Float2LongOpenHashMap(numEntries);
+      // Entries were written as (key, value) pairs, so read them back in the same order
+      for (int remaining = numEntries; remaining > 0; remaining--) {
+        float key = byteBuffer.getFloat();
+        long value = byteBuffer.getLong();
+        result.put(key, value);
+      }
+      return result;
+    }
+  };
+
+  /**
+   * SerDe for {@link Double2LongMap}. The serialized form is the entry count (int) followed by, for every entry, its
+   * key (double) and its value (long). Deserialization rebuilds a {@link Double2LongOpenHashMap} from that layout.
+   */
+  public static final ObjectSerDe<Double2LongMap> DOUBLE_2_LONG_MAP_SER_DE = new ObjectSerDe<Double2LongMap>() {
+
+    @Override
+    public byte[] serialize(Double2LongMap map) {
+      int numEntries = map.size();
+      int bytesPerEntry = Double.BYTES + Long.BYTES;
+      byte[] serialized = new byte[Integer.BYTES + numEntries * bytesPerEntry];
+      ByteBuffer writer = ByteBuffer.wrap(serialized);
+      writer.putInt(numEntries);
+      for (Double2LongMap.Entry keyAndValue : map.double2LongEntrySet()) {
+        writer.putDouble(keyAndValue.getDoubleKey()).putLong(keyAndValue.getLongValue());
+      }
+      return serialized;
+    }
+
+    @Override
+    public Double2LongOpenHashMap deserialize(byte[] bytes) {
+      ByteBuffer reader = ByteBuffer.wrap(bytes);
+      return deserialize(reader);
+    }
+
+    @Override
+    public Double2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
+      int numEntries = byteBuffer.getInt();
+      Double2LongOpenHashMap result = new Double2LongOpenHashMap(numEntries);
+      // Entries were written as (key, value) pairs, so read them back in the same order
+      for (int remaining = numEntries; remaining > 0; remaining--) {
+        double key = byteBuffer.getDouble();
+        long value = byteBuffer.getLong();
+        result.put(key, value);
+      }
+      return result;
+    }
+  };
+
// NOTE: DO NOT change the order, it has to be the same order as the
ObjectType
//@formatter:off
private static final ObjectSerDe[] SER_DES = {
@@ -899,7 +1043,11 @@ public class ObjectSerDeUtils {
BYTES_SET_SER_DE,
ID_SET_SER_DE,
LIST_SER_DE,
- BIGDECIMAL_SER_DE
+ BIGDECIMAL_SER_DE,
+ INT_2_LONG_MAP_SER_DE,
+ LONG_2_LONG_MAP_SER_DE,
+ FLOAT_2_LONG_MAP_SER_DE,
+ DOUBLE_2_LONG_MAP_SER_DE
};
//@formatter:on
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index bf7c5aa..ccd45fc 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -122,6 +122,8 @@ public class AggregationFunctionFactory {
return new SumPrecisionAggregationFunction(arguments);
case AVG:
return new AvgAggregationFunction(firstArgument);
+ case MODE:
+ return new ModeAggregationFunction(arguments);
case MINMAXRANGE:
return new MinMaxRangeAggregationFunction(firstArgument);
case DISTINCTCOUNT:
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
new file mode 100644
index 0000000..b67152b
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
@@ -0,0 +1,691 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.doubles.Double2LongMap;
+import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap;
+import it.unimi.dsi.fastutil.floats.Float2LongMap;
+import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2IntMap;
+import it.unimi.dsi.fastutil.ints.Int2IntMaps;
+import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2LongMap;
+import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
+import it.unimi.dsi.fastutil.longs.Long2LongMap;
+import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+
+
+/**
+ * This function is used for Mode calculations.
+ * <p>The function can be used as MODE(expression, multiModeReducerType)
+ * <p>Following arguments are supported:
+ * <ul>
+ *   <li>Expression: the expression to compute the mode of, can be any numeric column</li>
+ *   <li>MultiModeReducerType (optional): how to reduce several equally frequent modes to a single value (MIN, MAX or
+ *   AVG; defaults to MIN)</li>
+ * </ul>
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<Map<? extends Number, Long>, Double> {
+
+ private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
+
+ private final MultiModeReducerType _multiModeReducerType;
+
+  /**
+   * Creates the MODE aggregation function from its query arguments.
+   *
+   * @param arguments the first element is the expression to aggregate; the optional second element is a literal
+   *                  naming the {@link MultiModeReducerType}, matched case-insensitively. When it is omitted the
+   *                  reducer defaults to MIN.
+   * @throws IllegalArgumentException if more than 2 arguments are given, or the reducer name is not a known
+   *                                  {@link MultiModeReducerType}
+   */
+  public ModeAggregationFunction(List<ExpressionContext> arguments) {
+    super(arguments.get(0));
+
+    int numArguments = arguments.size();
+    Preconditions.checkArgument(numArguments <= 2, "Mode expects at most 2 arguments, got: %s", numArguments);
+    if (numArguments > 1) {
+      // Enum.valueOf is case-sensitive: normalize the literal so that 'min' / 'Min' resolve to MIN instead of failing
+      // with "No enum constant". Locale.ROOT keeps the upper-casing independent of the JVM default locale.
+      String reducerName = arguments.get(1).getLiteral().toUpperCase(java.util.Locale.ROOT);
+      _multiModeReducerType = MultiModeReducerType.valueOf(reducerName);
+    } else {
+      _multiModeReducerType = MultiModeReducerType.MIN;
+    }
+  }
+
+  /**
+   * Creates an empty value-to-count map whose key type matches the given stored type.
+   *
+   * @throws IllegalStateException if the type is not one of INT, LONG, FLOAT or DOUBLE
+   */
+  private static Map<? extends Number, Long> getValueMap(DataType valueType) {
+    Map<? extends Number, Long> emptyMap;
+    switch (valueType) {
+      case INT:
+        emptyMap = new Int2LongOpenHashMap();
+        break;
+      case LONG:
+        emptyMap = new Long2LongOpenHashMap();
+        break;
+      case FLOAT:
+        emptyMap = new Float2LongOpenHashMap();
+        break;
+      case DOUBLE:
+        emptyMap = new Double2LongOpenHashMap();
+        break;
+      default:
+        throw new IllegalStateException("Illegal data type for MODE aggregation function: " + valueType);
+    }
+    return emptyMap;
+  }
+
+  /**
+   * Fetches the value map stored in the result holder. When the holder has no result yet, a new map for the given
+   * value type is created, stored in the holder and returned.
+   */
+  private static Map<? extends Number, Long> getValueMap(AggregationResultHolder aggregationResultHolder,
+      DataType valueType) {
+    Map<? extends Number, Long> existing = aggregationResultHolder.getResult();
+    if (existing != null) {
+      return existing;
+    }
+    Map<? extends Number, Long> created = getValueMap(valueType);
+    aggregationResultHolder.setValue(created);
+    return created;
+  }
+
+  /**
+   * Counts one occurrence of the INT value for the given group key, creating the group's count map in the result
+   * holder on first use.
+   */
+  private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, int value) {
+    Int2LongOpenHashMap counts = groupByResultHolder.getResult(groupKey);
+    if (counts == null) {
+      counts = new Int2LongOpenHashMap();
+      groupByResultHolder.setValueForKey(groupKey, counts);
+    }
+    // Inserts 1 for an unseen value, otherwise adds 1 to the existing count
+    counts.merge(value, 1, Long::sum);
+  }
+
+  /**
+   * Counts one occurrence of the LONG value for the given group key, creating the group's count map in the result
+   * holder on first use.
+   */
+  private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, long value) {
+    Long2LongOpenHashMap counts = groupByResultHolder.getResult(groupKey);
+    if (counts == null) {
+      counts = new Long2LongOpenHashMap();
+      groupByResultHolder.setValueForKey(groupKey, counts);
+    }
+    // Inserts 1 for an unseen value, otherwise adds 1 to the existing count
+    counts.merge(value, 1, Long::sum);
+  }
+
+  /**
+   * Counts one occurrence of the FLOAT value for the given group key, creating the group's count map in the result
+   * holder on first use.
+   */
+  private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, float value) {
+    Float2LongOpenHashMap counts = groupByResultHolder.getResult(groupKey);
+    if (counts == null) {
+      counts = new Float2LongOpenHashMap();
+      groupByResultHolder.setValueForKey(groupKey, counts);
+    }
+    // Inserts 1 for an unseen value, otherwise adds 1 to the existing count
+    counts.merge(value, 1, Long::sum);
+  }
+
+  /**
+   * Counts one occurrence of the DOUBLE value for the given group key, creating the group's count map in the result
+   * holder on first use.
+   */
+  private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, double value) {
+    Double2LongOpenHashMap counts = groupByResultHolder.getResult(groupKey);
+    if (counts == null) {
+      counts = new Double2LongOpenHashMap();
+      groupByResultHolder.setValueForKey(groupKey, counts);
+    }
+    // Inserts 1 for an unseen value, otherwise adds 1 to the existing count
+    counts.merge(value, 1, Long::sum);
+  }
+
+  /**
+   * Fetches the dictionary-id count map kept in the aggregation result holder. When the holder is still empty, a new
+   * {@code DictIdsWrapper} for the given dictionary is created and stored first.
+   */
+  protected static Int2IntOpenHashMap getDictIdCountMap(AggregationResultHolder aggregationResultHolder,
+      Dictionary dictionary) {
+    DictIdsWrapper wrapper = aggregationResultHolder.getResult();
+    if (wrapper != null) {
+      return wrapper._dictIdCountMap;
+    }
+    DictIdsWrapper created = new DictIdsWrapper(dictionary);
+    aggregationResultHolder.setValue(created);
+    return created._dictIdCountMap;
+  }
+
+  /**
+   * Fetches the dictionary-id count map kept for the given group key. When the group has no result yet, a new
+   * {@code DictIdsWrapper} for the given dictionary is created and stored under that key first.
+   */
+  protected static Int2IntOpenHashMap getDictIdCountMap(GroupByResultHolder groupByResultHolder, int groupKey,
+      Dictionary dictionary) {
+    DictIdsWrapper wrapper = groupByResultHolder.getResult(groupKey);
+    if (wrapper != null) {
+      return wrapper._dictIdCountMap;
+    }
+    DictIdsWrapper created = new DictIdsWrapper(dictionary);
+    groupByResultHolder.setValueForKey(groupKey, created);
+    return created._dictIdCountMap;
+  }
+
+  /**
+   * Translates the dictionary-id counts of a dictionary-encoded expression into value counts: every dictionary id is
+   * looked up in the wrapper's dictionary and its count is stored under the resulting value.
+   *
+   * @throws IllegalStateException if the dictionary value type is not one of INT, LONG, FLOAT or DOUBLE
+   */
+  private static Map<? extends Number, Long> convertToValueMap(DictIdsWrapper dictIdsWrapper) {
+    Dictionary dictionary = dictIdsWrapper._dictionary;
+    Int2IntOpenHashMap dictIdCounts = dictIdsWrapper._dictIdCountMap;
+    int numValues = dictIdCounts.size();
+    ObjectIterator<Int2IntMap.Entry> entries = Int2IntMaps.fastIterator(dictIdCounts);
+    DataType storedType = dictionary.getValueType();
+    switch (storedType) {
+      case INT: {
+        Int2LongOpenHashMap intCounts = new Int2LongOpenHashMap(numValues);
+        while (entries.hasNext()) {
+          Int2IntMap.Entry entry = entries.next();
+          int dictId = entry.getIntKey();
+          intCounts.put(dictionary.getIntValue(dictId), entry.getIntValue());
+        }
+        return intCounts;
+      }
+      case LONG: {
+        Long2LongOpenHashMap longCounts = new Long2LongOpenHashMap(numValues);
+        while (entries.hasNext()) {
+          Int2IntMap.Entry entry = entries.next();
+          int dictId = entry.getIntKey();
+          longCounts.put(dictionary.getLongValue(dictId), entry.getIntValue());
+        }
+        return longCounts;
+      }
+      case FLOAT: {
+        Float2LongOpenHashMap floatCounts = new Float2LongOpenHashMap(numValues);
+        while (entries.hasNext()) {
+          Int2IntMap.Entry entry = entries.next();
+          int dictId = entry.getIntKey();
+          floatCounts.put(dictionary.getFloatValue(dictId), entry.getIntValue());
+        }
+        return floatCounts;
+      }
+      case DOUBLE: {
+        Double2LongOpenHashMap doubleCounts = new Double2LongOpenHashMap(numValues);
+        while (entries.hasNext()) {
+          Int2IntMap.Entry entry = entries.next();
+          int dictId = entry.getIntKey();
+          doubleCounts.put(dictionary.getDoubleValue(dictId), entry.getIntValue());
+        }
+        return doubleCounts;
+      }
+      default:
+        throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
+    }
+  }
+
+  /**
+   * Turns the object held by a result holder into the segment-level intermediate result.
+   * <ul>
+   *   <li>{@code null}: an empty {@link Int2LongOpenHashMap}</li>
+   *   <li>{@code DictIdsWrapper}: the dictionary ids converted to values</li>
+   *   <li>anything else: expected to already be the value map, returned as is</li>
+   * </ul>
+   */
+  private static Map<? extends Number, Long> extractIntermediateResult(@Nullable Object result) {
+    if (result instanceof DictIdsWrapper) {
+      // Dictionary-encoded expression: the holder contains dictionary ids, not values
+      DictIdsWrapper wrapper = (DictIdsWrapper) result;
+      return convertToValueMap(wrapper);
+    }
+    if (result == null) {
+      // NOTE: Return an empty Int2LongOpenHashMap for empty result.
+      return new Int2LongOpenHashMap();
+    }
+    // Non-dictionary-encoded expression: the holder contains the value map itself
+    assert result instanceof Map;
+    return (Map) result;
+  }
+
+  /**
+   * Identifies this function as {@link AggregationFunctionType#MODE}.
+   */
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.MODE;
+  }
+
+  /**
+   * Creates the object-typed holder used for aggregation without grouping.
+   */
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    AggregationResultHolder holder = new ObjectAggregationResultHolder();
+    return holder;
+  }
+
+  /**
+   * Creates the object-typed holder used for group-by aggregation, passing both capacities through unchanged.
+   */
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    GroupByResultHolder holder = new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+    return holder;
+  }
+
+  /**
+   * Counts the first {@code length} values of this function's expression into the aggregation result holder.
+   * Dictionary-encoded blocks are counted by dictionary id; all other blocks are counted by value in a map matching
+   * the stored type.
+   *
+   * @throws IllegalStateException if the stored type is not one of INT, LONG, FLOAT or DOUBLE
+   */
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet valSet = blockValSetMap.get(_expression);
+    Dictionary dictionary = valSet.getDictionary();
+
+    if (dictionary != null) {
+      // Dictionary-encoded expression: count occurrences per dictionary id
+      int[] dictIds = valSet.getDictionaryIdsSV();
+      Int2IntOpenHashMap dictIdCounts = getDictIdCountMap(aggregationResultHolder, dictionary);
+      for (int idx = 0; idx < length; idx++) {
+        dictIdCounts.merge(dictIds[idx], 1, Integer::sum);
+      }
+      return;
+    }
+
+    // Non-dictionary-encoded expression: count occurrences per value
+    DataType storedType = valSet.getValueType().getStoredType();
+    Map<? extends Number, Long> counts = getValueMap(aggregationResultHolder, storedType);
+    switch (storedType) {
+      case INT: {
+        Int2LongOpenHashMap intCounts = (Int2LongOpenHashMap) counts;
+        int[] values = valSet.getIntValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          intCounts.merge(values[idx], 1, Long::sum);
+        }
+        break;
+      }
+      case LONG: {
+        Long2LongOpenHashMap longCounts = (Long2LongOpenHashMap) counts;
+        long[] values = valSet.getLongValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          longCounts.merge(values[idx], 1, Long::sum);
+        }
+        break;
+      }
+      case FLOAT: {
+        Float2LongOpenHashMap floatCounts = (Float2LongOpenHashMap) counts;
+        float[] values = valSet.getFloatValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          floatCounts.merge(values[idx], 1, Long::sum);
+        }
+        break;
+      }
+      case DOUBLE: {
+        Double2LongOpenHashMap doubleCounts = (Double2LongOpenHashMap) counts;
+        double[] values = valSet.getDoubleValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          doubleCounts.merge(values[idx], 1, Long::sum);
+        }
+        break;
+      }
+      default:
+        throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
+    }
+  }
+
+  /**
+   * Counts the first {@code length} values of this function's expression, each into the group given by the entry of
+   * {@code groupKeyArray} at the same index. Dictionary-encoded blocks are counted by dictionary id; all other
+   * blocks are counted by value.
+   *
+   * @throws IllegalStateException if the stored type is not one of INT, LONG, FLOAT or DOUBLE
+   */
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet valSet = blockValSetMap.get(_expression);
+    Dictionary dictionary = valSet.getDictionary();
+
+    if (dictionary != null) {
+      // Dictionary-encoded expression: count occurrences per dictionary id within each group
+      int[] dictIds = valSet.getDictionaryIdsSV();
+      for (int idx = 0; idx < length; idx++) {
+        Int2IntOpenHashMap dictIdCounts = getDictIdCountMap(groupByResultHolder, groupKeyArray[idx], dictionary);
+        dictIdCounts.merge(dictIds[idx], 1, Integer::sum);
+      }
+      return;
+    }
+
+    // Non-dictionary-encoded expression: count occurrences per value within each group
+    DataType storedType = valSet.getValueType().getStoredType();
+    switch (storedType) {
+      case INT: {
+        int[] values = valSet.getIntValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          setValueForGroupKeys(groupByResultHolder, groupKeyArray[idx], values[idx]);
+        }
+        break;
+      }
+      case LONG: {
+        long[] values = valSet.getLongValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          setValueForGroupKeys(groupByResultHolder, groupKeyArray[idx], values[idx]);
+        }
+        break;
+      }
+      case FLOAT: {
+        float[] values = valSet.getFloatValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          setValueForGroupKeys(groupByResultHolder, groupKeyArray[idx], values[idx]);
+        }
+        break;
+      }
+      case DOUBLE: {
+        double[] values = valSet.getDoubleValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          setValueForGroupKeys(groupByResultHolder, groupKeyArray[idx], values[idx]);
+        }
+        break;
+      }
+      default:
+        throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
+    }
+  }
+
+  /**
+   * Counts the first {@code length} values of this function's expression, each into every group listed in the entry
+   * of {@code groupKeysArray} at the same index. Dictionary-encoded blocks are counted by dictionary id; all other
+   * blocks are counted by value.
+   *
+   * @throws IllegalStateException if the stored type is not one of INT, LONG, FLOAT or DOUBLE
+   */
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet valSet = blockValSetMap.get(_expression);
+    Dictionary dictionary = valSet.getDictionary();
+
+    if (dictionary != null) {
+      // Dictionary-encoded expression: count occurrences per dictionary id within every matching group
+      int[] dictIds = valSet.getDictionaryIdsSV();
+      for (int idx = 0; idx < length; idx++) {
+        int[] groupKeys = groupKeysArray[idx];
+        for (int groupKey : groupKeys) {
+          Int2IntOpenHashMap dictIdCounts = getDictIdCountMap(groupByResultHolder, groupKey, dictionary);
+          dictIdCounts.merge(dictIds[idx], 1, Integer::sum);
+        }
+      }
+      return;
+    }
+
+    // Non-dictionary-encoded expression: count occurrences per value within every matching group
+    DataType storedType = valSet.getValueType().getStoredType();
+    switch (storedType) {
+      case INT: {
+        int[] values = valSet.getIntValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          int[] groupKeys = groupKeysArray[idx];
+          for (int groupKey : groupKeys) {
+            setValueForGroupKeys(groupByResultHolder, groupKey, values[idx]);
+          }
+        }
+        break;
+      }
+      case LONG: {
+        long[] values = valSet.getLongValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          int[] groupKeys = groupKeysArray[idx];
+          for (int groupKey : groupKeys) {
+            setValueForGroupKeys(groupByResultHolder, groupKey, values[idx]);
+          }
+        }
+        break;
+      }
+      case FLOAT: {
+        float[] values = valSet.getFloatValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          int[] groupKeys = groupKeysArray[idx];
+          for (int groupKey : groupKeys) {
+            setValueForGroupKeys(groupByResultHolder, groupKey, values[idx]);
+          }
+        }
+        break;
+      }
+      case DOUBLE: {
+        double[] values = valSet.getDoubleValuesSV();
+        for (int idx = 0; idx < length; idx++) {
+          int[] groupKeys = groupKeysArray[idx];
+          for (int groupKey : groupKeys) {
+            setValueForGroupKeys(groupByResultHolder, groupKey, values[idx]);
+          }
+        }
+        break;
+      }
+      default:
+        throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
+    }
+  }
+
+  /**
+   * Converts whatever the aggregation result holder contains into the value-to-count intermediate result.
+   */
+  @Override
+  public Map<? extends Number, Long> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    Object heldResult = aggregationResultHolder.getResult();
+    return extractIntermediateResult(heldResult);
+  }
+
+  /**
+   * Converts whatever the group-by result holder contains for the group key into the value-to-count intermediate
+   * result.
+   */
+  @Override
+  public Map<? extends Number, Long> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    Object heldResult = groupByResultHolder.getResult(groupKey);
+    return extractIntermediateResult(heldResult);
+  }
+
+  /**
+   * Merges two value-to-count maps. When either map is empty the other one is returned untouched. Otherwise the
+   * counts of {@code intermediateResult2} are added into {@code intermediateResult1}, which is modified in place
+   * and returned.
+   *
+   * @throws IllegalStateException if the two maps are not both the same supported map type
+   */
+  @Override
+  public Map<? extends Number, Long> merge(Map<? extends Number, Long> intermediateResult1,
+      Map<? extends Number, Long> intermediateResult2) {
+    if (intermediateResult1.isEmpty()) {
+      return intermediateResult2;
+    }
+    if (intermediateResult2.isEmpty()) {
+      return intermediateResult1;
+    }
+
+    if (intermediateResult1 instanceof Int2LongOpenHashMap && intermediateResult2 instanceof Int2LongOpenHashMap) {
+      Int2LongOpenHashMap target = (Int2LongOpenHashMap) intermediateResult1;
+      Int2LongOpenHashMap source = (Int2LongOpenHashMap) intermediateResult2;
+      source.int2LongEntrySet()
+          .fastForEach(entry -> target.merge(entry.getIntKey(), entry.getLongValue(), Long::sum));
+      return target;
+    }
+    if (intermediateResult1 instanceof Long2LongOpenHashMap && intermediateResult2 instanceof Long2LongOpenHashMap) {
+      Long2LongOpenHashMap target = (Long2LongOpenHashMap) intermediateResult1;
+      Long2LongOpenHashMap source = (Long2LongOpenHashMap) intermediateResult2;
+      source.long2LongEntrySet()
+          .fastForEach(entry -> target.merge(entry.getLongKey(), entry.getLongValue(), Long::sum));
+      return target;
+    }
+    if (intermediateResult1 instanceof Float2LongOpenHashMap
+        && intermediateResult2 instanceof Float2LongOpenHashMap) {
+      Float2LongOpenHashMap target = (Float2LongOpenHashMap) intermediateResult1;
+      Float2LongOpenHashMap source = (Float2LongOpenHashMap) intermediateResult2;
+      source.float2LongEntrySet()
+          .fastForEach(entry -> target.merge(entry.getFloatKey(), entry.getLongValue(), Long::sum));
+      return target;
+    }
+    if (intermediateResult1 instanceof Double2LongOpenHashMap
+        && intermediateResult2 instanceof Double2LongOpenHashMap) {
+      Double2LongOpenHashMap target = (Double2LongOpenHashMap) intermediateResult1;
+      Double2LongOpenHashMap source = (Double2LongOpenHashMap) intermediateResult2;
+      source.double2LongEntrySet()
+          .fastForEach(entry -> target.merge(entry.getDoubleKey(), entry.getLongValue(), Long::sum));
+      return target;
+    }
+
+    String firstType = intermediateResult1.getClass().getSimpleName();
+    String secondType = intermediateResult2.getClass().getSimpleName();
+    throw new IllegalStateException(
+        "Illegal data type for Intermediate Result of MODE aggregation function: " + firstType + ", " + secondType);
+  }
+
+  /**
+   * The intermediate result is a value-to-count map, which is reported as not comparable.
+   */
+  @Override
+  public boolean isIntermediateResultComparable() {
+    return false;
+  }
+
+  /**
+   * The intermediate result (a value-to-count map) is reported as an OBJECT column.
+   */
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.OBJECT;
+  }
+
+  /**
+   * The final mode is always reported as a DOUBLE column, whatever the type of the aggregated values.
+   */
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.DOUBLE;
+  }
+
+  /**
+   * Computes the final mode from the merged value-to-count map by dispatching to the overload for the concrete map
+   * type. An empty map yields {@code DEFAULT_FINAL_RESULT}.
+   *
+   * @throws IllegalStateException if the map is not one of the four supported fastutil map types
+   */
+  @Override
+  public Double extractFinalResult(Map<? extends Number, Long> intermediateResult) {
+    if (intermediateResult.isEmpty()) {
+      return DEFAULT_FINAL_RESULT;
+    }
+    if (intermediateResult instanceof Int2LongOpenHashMap) {
+      return extractFinalResult((Int2LongOpenHashMap) intermediateResult);
+    }
+    if (intermediateResult instanceof Long2LongOpenHashMap) {
+      return extractFinalResult((Long2LongOpenHashMap) intermediateResult);
+    }
+    if (intermediateResult instanceof Float2LongOpenHashMap) {
+      return extractFinalResult((Float2LongOpenHashMap) intermediateResult);
+    }
+    if (intermediateResult instanceof Double2LongOpenHashMap) {
+      return extractFinalResult((Double2LongOpenHashMap) intermediateResult);
+    }
+    String actualType = intermediateResult.getClass().getSimpleName();
+    throw new IllegalStateException(
+        "Illegal data type for Intermediate Result of MODE aggregation function: " + actualType);
+  }
+
+  /**
+   * Computes the mode of an INT value-to-count map in a single pass. When several values share the highest count,
+   * the configured reducer decides: MIN returns the smallest of them, MAX the largest, AVG their arithmetic mean.
+   * The map must not be empty, because the first entry is read unconditionally.
+   *
+   * @throws IllegalStateException if the reducer type is not MIN, MAX or AVG
+   */
+  public double extractFinalResult(Int2LongOpenHashMap intermediateResult) {
+    ObjectIterator<Int2LongMap.Entry> entries = intermediateResult.int2LongEntrySet().fastIterator();
+    // Read both fields of the first entry before advancing the iterator again
+    Int2LongMap.Entry first = entries.next();
+    long maxFrequency = first.getLongValue();
+    int firstValue = first.getIntKey();
+    switch (_multiModeReducerType) {
+      case MIN: {
+        int minMode = firstValue;
+        while (entries.hasNext()) {
+          Int2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          int value = entry.getIntKey();
+          boolean moreFrequent = frequency > maxFrequency;
+          boolean smallerTie = frequency == maxFrequency && value < minMode;
+          if (moreFrequent || smallerTie) {
+            maxFrequency = frequency;
+            minMode = value;
+          }
+        }
+        return minMode;
+      }
+      case MAX: {
+        int maxMode = firstValue;
+        while (entries.hasNext()) {
+          Int2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          int value = entry.getIntKey();
+          boolean moreFrequent = frequency > maxFrequency;
+          boolean largerTie = frequency == maxFrequency && value > maxMode;
+          if (moreFrequent || largerTie) {
+            maxFrequency = frequency;
+            maxMode = value;
+          }
+        }
+        return maxMode;
+      }
+      case AVG: {
+        double modeSum = firstValue;
+        int modeCount = 1;
+        while (entries.hasNext()) {
+          Int2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          if (frequency > maxFrequency) {
+            // A strictly more frequent value discards all modes collected so far
+            maxFrequency = frequency;
+            modeSum = entry.getIntKey();
+            modeCount = 1;
+          } else if (frequency == maxFrequency) {
+            modeSum += entry.getIntKey();
+            modeCount += 1;
+          }
+        }
+        return modeSum / modeCount;
+      }
+      default:
+        throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType);
+    }
+  }
+
+  /**
+   * Computes the mode of a LONG value-to-count map in a single pass. When several values share the highest count,
+   * the configured reducer decides: MIN returns the smallest of them, MAX the largest, AVG their arithmetic mean.
+   * The map must not be empty, because the first entry is read unconditionally.
+   *
+   * @throws IllegalStateException if the reducer type is not MIN, MAX or AVG
+   */
+  public double extractFinalResult(Long2LongOpenHashMap intermediateResult) {
+    ObjectIterator<Long2LongMap.Entry> entries = intermediateResult.long2LongEntrySet().fastIterator();
+    // Read both fields of the first entry before advancing the iterator again
+    Long2LongMap.Entry first = entries.next();
+    long maxFrequency = first.getLongValue();
+    long firstValue = first.getLongKey();
+    switch (_multiModeReducerType) {
+      case MIN: {
+        long minMode = firstValue;
+        while (entries.hasNext()) {
+          Long2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          long value = entry.getLongKey();
+          boolean moreFrequent = frequency > maxFrequency;
+          boolean smallerTie = frequency == maxFrequency && value < minMode;
+          if (moreFrequent || smallerTie) {
+            maxFrequency = frequency;
+            minMode = value;
+          }
+        }
+        return minMode;
+      }
+      case MAX: {
+        long maxMode = firstValue;
+        while (entries.hasNext()) {
+          Long2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          long value = entry.getLongKey();
+          boolean moreFrequent = frequency > maxFrequency;
+          boolean largerTie = frequency == maxFrequency && value > maxMode;
+          if (moreFrequent || largerTie) {
+            maxFrequency = frequency;
+            maxMode = value;
+          }
+        }
+        return maxMode;
+      }
+      case AVG: {
+        double modeSum = firstValue;
+        int modeCount = 1;
+        while (entries.hasNext()) {
+          Long2LongMap.Entry entry = entries.next();
+          long frequency = entry.getLongValue();
+          if (frequency > maxFrequency) {
+            // A strictly more frequent value discards all modes collected so far
+            maxFrequency = frequency;
+            modeSum = entry.getLongKey();
+            modeCount = 1;
+          } else if (frequency == maxFrequency) {
+            modeSum += entry.getLongKey();
+            modeCount += 1;
+          }
+        }
+        return modeSum / modeCount;
+      }
+      default:
+        throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType);
+    }
+  }
+
+  /**
+   * Reduces a FLOAT value-to-frequency map to a single mode.
+   *
+   * <p>The value with the highest frequency wins. When several values share that frequency, the configured
+   * {@code _multiModeReducerType} decides: MIN keeps the smallest value, MAX the largest, and AVG returns the mean of
+   * all tied values.
+   *
+   * @param intermediateResult value-to-frequency map; the first entry is read unconditionally, so it must not be empty
+   * @return the selected mode, widened to double
+   * @throws IllegalStateException if the reducer type is none of MIN, MAX, AVG
+   */
+  public double extractFinalResult(Float2LongOpenHashMap intermediateResult) {
+    ObjectIterator<Float2LongMap.Entry> entries = intermediateResult.float2LongEntrySet().fastIterator();
+    // Everything needed from the seed entry is read before the iterator advances again
+    Float2LongMap.Entry seed = entries.next();
+    long highestFrequency = seed.getLongValue();
+    switch (_multiModeReducerType) {
+      case MIN: {
+        float smallestMode = seed.getFloatKey();
+        while (entries.hasNext()) {
+          Float2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          boolean tiedButSmaller = frequency == highestFrequency && candidate.getFloatKey() < smallestMode;
+          if (frequency > highestFrequency || tiedButSmaller) {
+            highestFrequency = frequency;
+            smallestMode = candidate.getFloatKey();
+          }
+        }
+        return smallestMode;
+      }
+      case MAX: {
+        float largestMode = seed.getFloatKey();
+        while (entries.hasNext()) {
+          Float2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          boolean tiedButLarger = frequency == highestFrequency && candidate.getFloatKey() > largestMode;
+          if (frequency > highestFrequency || tiedButLarger) {
+            highestFrequency = frequency;
+            largestMode = candidate.getFloatKey();
+          }
+        }
+        return largestMode;
+      }
+      case AVG: {
+        double tiedSum = seed.getFloatKey();
+        int tiedCount = 1;
+        while (entries.hasNext()) {
+          Float2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          if (frequency > highestFrequency) {
+            // A strictly more frequent value discards everything accumulated so far
+            highestFrequency = frequency;
+            tiedSum = candidate.getFloatKey();
+            tiedCount = 1;
+          } else if (frequency == highestFrequency) {
+            tiedSum += candidate.getFloatKey();
+            tiedCount++;
+          }
+        }
+        return tiedSum / tiedCount;
+      }
+      default:
+        throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType);
+    }
+  }
+
+  /**
+   * Reduces a DOUBLE value-to-frequency map to a single mode.
+   *
+   * <p>The value with the highest frequency wins. When several values share that frequency, the configured
+   * {@code _multiModeReducerType} decides: MIN keeps the smallest value, MAX the largest, and AVG returns the mean of
+   * all tied values.
+   *
+   * @param intermediateResult value-to-frequency map; the first entry is read unconditionally, so it must not be empty
+   * @return the selected mode as a boxed {@link Double}
+   * @throws IllegalStateException if the reducer type is none of MIN, MAX, AVG
+   */
+  public Double extractFinalResult(Double2LongOpenHashMap intermediateResult) {
+    ObjectIterator<Double2LongMap.Entry> entries = intermediateResult.double2LongEntrySet().fastIterator();
+    // Everything needed from the seed entry is read before the iterator advances again
+    Double2LongMap.Entry seed = entries.next();
+    long highestFrequency = seed.getLongValue();
+    switch (_multiModeReducerType) {
+      case MIN: {
+        double smallestMode = seed.getDoubleKey();
+        while (entries.hasNext()) {
+          Double2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          boolean tiedButSmaller = frequency == highestFrequency && candidate.getDoubleKey() < smallestMode;
+          if (frequency > highestFrequency || tiedButSmaller) {
+            highestFrequency = frequency;
+            smallestMode = candidate.getDoubleKey();
+          }
+        }
+        return smallestMode;
+      }
+      case MAX: {
+        double largestMode = seed.getDoubleKey();
+        while (entries.hasNext()) {
+          Double2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          boolean tiedButLarger = frequency == highestFrequency && candidate.getDoubleKey() > largestMode;
+          if (frequency > highestFrequency || tiedButLarger) {
+            highestFrequency = frequency;
+            largestMode = candidate.getDoubleKey();
+          }
+        }
+        return largestMode;
+      }
+      case AVG: {
+        double tiedSum = seed.getDoubleKey();
+        int tiedCount = 1;
+        while (entries.hasNext()) {
+          Double2LongMap.Entry candidate = entries.next();
+          long frequency = candidate.getLongValue();
+          if (frequency > highestFrequency) {
+            // A strictly more frequent value discards everything accumulated so far
+            highestFrequency = frequency;
+            tiedSum = candidate.getDoubleKey();
+            tiedCount = 1;
+          } else if (frequency == highestFrequency) {
+            tiedSum += candidate.getDoubleKey();
+            tiedCount++;
+          }
+        }
+        return tiedSum / tiedCount;
+      }
+      default:
+        throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType);
+    }
+  }
+
+  /**
+   * Strategy for collapsing several equally frequent values into one result: the smallest, the largest, or their mean.
+   */
+  private enum MultiModeReducerType {
+    MIN,
+    MAX,
+    AVG
+  }
+
+  /**
+   * Holder that pairs a {@link Dictionary} with an initially empty int-to-int map (presumably dictionary id to
+   * occurrence count, going by the field name).
+   */
+  private static final class DictIdsWrapper {
+    final Dictionary _dictionary;
+    final Int2IntOpenHashMap _dictIdCountMap = new Int2IntOpenHashMap();
+
+    private DictIdsWrapper(Dictionary dictionary) {
+      _dictionary = dictionary;
+    }
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
index b019698..8e4c6df 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java
@@ -20,9 +20,13 @@ package org.apache.pinot.core.common;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.tdunning.math.stats.TDigest;
+import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
@@ -202,4 +206,68 @@ public class ObjectSerDeUtilsTest {
}
}
}
+
+  /**
+   * Round-trips randomly populated {@link Int2LongOpenHashMap}s through {@link ObjectSerDeUtils}.
+   */
+  @Test
+  public void testInt2LongMap() {
+    for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
+      int numPuts = RANDOM.nextInt(100);
+      Int2LongOpenHashMap expected = new Int2LongOpenHashMap(numPuts);
+      for (int remaining = numPuts; remaining > 0; remaining--) {
+        // Keys are drawn from [0, 20), so a later put may overwrite an earlier one
+        int key = RANDOM.nextInt(20);
+        long frequency = RANDOM.nextLong();
+        expected.put(key, frequency);
+      }
+
+      Int2LongOpenHashMap actual =
+          ObjectSerDeUtils.deserialize(ObjectSerDeUtils.serialize(expected), ObjectSerDeUtils.ObjectType.Int2LongMap);
+      assertEquals(actual, expected, ERROR_MESSAGE);
+    }
+  }
+
+  /**
+   * Round-trips randomly populated {@link Long2LongOpenHashMap}s through {@link ObjectSerDeUtils}.
+   */
+  @Test
+  public void testLong2LongMap() {
+    for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
+      int numPuts = RANDOM.nextInt(100);
+      Long2LongOpenHashMap expected = new Long2LongOpenHashMap(numPuts);
+      for (int remaining = numPuts; remaining > 0; remaining--) {
+        long key = RANDOM.nextLong();
+        long frequency = RANDOM.nextLong();
+        expected.put(key, frequency);
+      }
+
+      Long2LongOpenHashMap actual =
+          ObjectSerDeUtils.deserialize(ObjectSerDeUtils.serialize(expected), ObjectSerDeUtils.ObjectType.Long2LongMap);
+      assertEquals(actual, expected, ERROR_MESSAGE);
+    }
+  }
+
+  /**
+   * Round-trips randomly populated {@link Float2LongOpenHashMap}s through {@link ObjectSerDeUtils}.
+   */
+  @Test
+  public void testFloat2LongMap() {
+    for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
+      int numPuts = RANDOM.nextInt(100);
+      Float2LongOpenHashMap expected = new Float2LongOpenHashMap(numPuts);
+      for (int remaining = numPuts; remaining > 0; remaining--) {
+        float key = RANDOM.nextFloat();
+        long frequency = RANDOM.nextLong();
+        expected.put(key, frequency);
+      }
+
+      Float2LongOpenHashMap actual =
+          ObjectSerDeUtils.deserialize(ObjectSerDeUtils.serialize(expected), ObjectSerDeUtils.ObjectType.Float2LongMap);
+      assertEquals(actual, expected, ERROR_MESSAGE);
+    }
+  }
+
+  /**
+   * Round-trips randomly populated {@link Double2LongOpenHashMap}s through {@link ObjectSerDeUtils}.
+   */
+  @Test
+  public void testDouble2LongMap() {
+    for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
+      int numPuts = RANDOM.nextInt(100);
+      Double2LongOpenHashMap expected = new Double2LongOpenHashMap(numPuts);
+      for (int remaining = numPuts; remaining > 0; remaining--) {
+        double key = RANDOM.nextDouble();
+        long frequency = RANDOM.nextLong();
+        expected.put(key, frequency);
+      }
+
+      Double2LongOpenHashMap actual =
+          ObjectSerDeUtils.deserialize(ObjectSerDeUtils.serialize(expected), ObjectSerDeUtils.ObjectType.Double2LongMap);
+      assertEquals(actual, expected, ERROR_MESSAGE);
+    }
+  }
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index bb847b1..b855806 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -80,6 +80,13 @@ public class AggregationFunctionFactoryTest {
assertEquals(aggregationFunction.getColumnName(), "avg_column");
assertEquals(aggregationFunction.getResultColumnName(),
function.toString());
+ function = getFunction("MoDe");
+ aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(function,
DUMMY_QUERY_CONTEXT);
+ assertTrue(aggregationFunction instanceof ModeAggregationFunction);
+ assertEquals(aggregationFunction.getType(), AggregationFunctionType.MODE);
+ assertEquals(aggregationFunction.getColumnName(), "mode_column");
+ assertEquals(aggregationFunction.getResultColumnName(),
function.toString());
+
function = getFunction("MiNmAxRaNgE");
aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(function,
DUMMY_QUERY_CONTEXT);
assertTrue(aggregationFunction instanceof MinMaxRangeAggregationFunction);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java
new file mode 100644
index 0000000..381127a
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java
@@ -0,0 +1,949 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap;
+import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
+import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+import java.util.stream.Collectors;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.response.broker.AggregationResult;
+import org.apache.pinot.common.response.broker.BrokerResponseNative;
+import org.apache.pinot.common.response.broker.GroupByResult;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import
org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import
org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Queries test for MODE queries.
+ */
+@SuppressWarnings("rawtypes")
+public class ModeQueriesTest extends BaseQueriesTest {
+ private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(),
"ModeQueriesTest");
+ private static final String RAW_TABLE_NAME = "testTable";
+ private static final String SEGMENT_NAME = "testSegment";
+ private static final Random RANDOM = new Random();
+
+ private static final int NUM_RECORDS = 2000;
+ private static final int MAX_VALUE = 1000;
+
+ private static final String INT_COLUMN = "intColumn";
+ private static final String INT_MV_COLUMN = "intMvColumn";
+ private static final String LONG_COLUMN = "longColumn";
+ private static final String FLOAT_COLUMN = "floatColumn";
+ private static final String DOUBLE_COLUMN = "doubleColumn";
+ private static final String INT_NO_DICT_COLUMN = "intNoDictColumn";
+ private static final String LONG_NO_DICT_COLUMN = "longNoDictColumn";
+ private static final String FLOAT_NO_DICT_COLUMN = "floatNoDictColumn";
+ private static final String DOUBLE_NO_DICT_COLUMN = "doubleNoDictColumn";
+ private static final Schema SCHEMA = new
Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, DataType.INT)
+ .addMultiValueDimension(INT_MV_COLUMN,
DataType.INT).addSingleValueDimension(INT_NO_DICT_COLUMN, DataType.INT)
+ .addSingleValueDimension(LONG_COLUMN,
DataType.LONG).addSingleValueDimension(LONG_NO_DICT_COLUMN, DataType.LONG)
+ .addSingleValueDimension(FLOAT_COLUMN, DataType.FLOAT)
+ .addSingleValueDimension(FLOAT_NO_DICT_COLUMN, DataType.FLOAT)
+ .addSingleValueDimension(DOUBLE_COLUMN, DataType.DOUBLE)
+ .addSingleValueDimension(DOUBLE_NO_DICT_COLUMN, DataType.DOUBLE).build();
+ private static final TableConfig TABLE_CONFIG = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME)
+ .setNoDictionaryColumns(
+ Lists.newArrayList(INT_NO_DICT_COLUMN, LONG_NO_DICT_COLUMN,
FLOAT_NO_DICT_COLUMN, DOUBLE_NO_DICT_COLUMN))
+ .build();
+ private static final double DELTA = 0.00001;
+
+ private HashMap<Integer, Long> _values;
+ private Double _expectedResultMin;
+ private Double _expectedResultMax;
+ private Double _expectedResultAvg;
+ private IndexSegment _indexSegment;
+ private List<IndexSegment> _indexSegments;
+
+  /**
+   * Returns a filter clause that every generated row satisfies, since {@code intColumn} only holds non-negative values.
+   */
+  @Override
+  protected String getFilter() {
+    // NOTE: A match-all filter is used to switch between DictionaryBasedAggregationOperator and AggregationOperator
+    final String matchAllFilter = " WHERE intColumn >= 0";
+    return matchAllFilter;
+  }
+
+  /**
+   * Returns the single segment loaded in {@code setUp()}.
+   */
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return this._indexSegment;
+  }
+
+  /**
+   * Returns the segment list assembled in {@code setUp()}, which contains the same loaded segment twice.
+   */
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return this._indexSegments;
+  }
+
+  /**
+   * Generates {@code NUM_RECORDS} random rows, records how often each value occurs, derives the expected MIN/MAX/AVG
+   * modes from those counts, and builds and loads the test segment.
+   *
+   * <p>Every column of a row carries the same random value (cast to the column type), so one frequency map describes
+   * all of them.
+   *
+   * @throws Exception if the index directory cannot be cleaned or the segment cannot be built or loaded
+   */
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values = new HashMap<>(hashMapCapacity);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      GenericRow record = new GenericRow();
+      _values.merge(value, 1L, Long::sum);
+      record.putValue(INT_COLUMN, value);
+      record.putValue(INT_MV_COLUMN, new Integer[]{value, value});
+      record.putValue(INT_NO_DICT_COLUMN, value);
+      record.putValue(LONG_COLUMN, (long) value);
+      record.putValue(LONG_NO_DICT_COLUMN, (long) value);
+      record.putValue(FLOAT_COLUMN, (float) value);
+      record.putValue(FLOAT_NO_DICT_COLUMN, (float) value);
+      record.putValue(DOUBLE_COLUMN, (double) value);
+      record.putValue(DOUBLE_NO_DICT_COLUMN, (double) value);
+      records.add(record);
+    }
+
+    // Determine the highest frequency once; previously it was recomputed by a full scan for every key, three times over
+    Long maxFrequency = Collections.max(_values.values());
+    // All values that occur with the highest frequency, i.e. the (possibly multiple) modes
+    double[] modes = _values.entrySet().stream().filter(entry -> Objects.equals(entry.getValue(), maxFrequency))
+        .mapToDouble(entry -> entry.getKey().doubleValue()).toArray();
+    _expectedResultMin = Arrays.stream(modes).min().orElse(Double.NEGATIVE_INFINITY);
+    _expectedResultMax = Arrays.stream(modes).max().orElse(Double.NEGATIVE_INFINITY);
+    _expectedResultAvg = Arrays.stream(modes).average().orElse(Double.NEGATIVE_INFINITY);
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment;
+    // The same segment is listed twice for the multi-segment query path
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+  }
+
+  /**
+   * Checks MODE without a reducer argument on the dictionary-encoded columns: first the per-segment frequency maps,
+   * then the reduced broker response, which is expected to report the smallest of the most frequent values.
+   */
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable";
+
+    // Inner segment: execute the plan without the filter first, then with it
+    List<List<Object>> innerSegmentResults = new ArrayList<>(2);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      Operator operator;
+      if (withFilter) {
+        operator = getOperatorForPqlQueryWithFilter(query);
+      } else {
+        operator = getOperatorForPqlQuery(query);
+      }
+      assertTrue(operator instanceof AggregationOperator);
+      IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+          4 * NUM_RECORDS, NUM_RECORDS);
+      innerSegmentResults.add(resultsBlock.getAggregationResult());
+    }
+    List<Object> resultsWithoutFilter = innerSegmentResults.get(0);
+    List<Object> resultsWithFilter = innerSegmentResults.get(1);
+
+    assertNotNull(resultsWithoutFilter);
+    assertNotNull(resultsWithFilter);
+    assertEquals(resultsWithoutFilter, resultsWithFilter);
+
+    // Each intermediate result is a value-to-frequency map that must match the counts gathered in setUp()
+    Map<Long, Long> expectedLongCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue));
+    Map<Float, Long> expectedFloatCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue));
+    Map<Double, Long> expectedDoubleCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue));
+    assertTrue(Maps.difference((Int2LongOpenHashMap) resultsWithoutFilter.get(0), _values).areEqual());
+    assertTrue(Maps.difference((Long2LongOpenHashMap) resultsWithoutFilter.get(1), expectedLongCounts).areEqual());
+    assertTrue(Maps.difference((Float2LongOpenHashMap) resultsWithoutFilter.get(2), expectedFloatCounts).areEqual());
+    assertTrue(Maps.difference((Double2LongOpenHashMap) resultsWithoutFilter.get(3), expectedDoubleCounts).areEqual());
+
+    // Inter segments: the statistics are asserted at 4x the single-segment numbers, the mode itself is unchanged
+    double[] expectedResults = new double[4];
+    Arrays.fill(expectedResults, _expectedResultMin);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      BrokerResponseNative brokerResponse;
+      if (withFilter) {
+        brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      } else {
+        brokerResponse = getBrokerResponseForPqlQuery(query);
+      }
+      Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+      List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+      Assert.assertEquals(aggregationResults.size(), expectedResults.length);
+      for (int column = 0; column < expectedResults.length; column++) {
+        Serializable value = aggregationResults.get(column).getValue();
+        Assert.assertEquals(Double.parseDouble(value.toString()), expectedResults[column], DELTA);
+      }
+    }
+  }
+
+  /**
+   * Checks MODE without a reducer argument on the no-dictionary columns: first the per-segment frequency maps, then
+   * the reduced broker response, which is expected to report the smallest of the most frequent values.
+   */
+  @Test
+  public void testAggregationOnlyNoDictionary() {
+    String query =
+        "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn), MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable";
+
+    // Inner segment: execute the plan without the filter first, then with it
+    List<List<Object>> innerSegmentResults = new ArrayList<>(2);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      Operator operator;
+      if (withFilter) {
+        operator = getOperatorForPqlQueryWithFilter(query);
+      } else {
+        operator = getOperatorForPqlQuery(query);
+      }
+      assertTrue(operator instanceof AggregationOperator);
+      IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+          4 * NUM_RECORDS, NUM_RECORDS);
+      innerSegmentResults.add(resultsBlock.getAggregationResult());
+    }
+    List<Object> resultsWithoutFilter = innerSegmentResults.get(0);
+    List<Object> resultsWithFilter = innerSegmentResults.get(1);
+
+    assertNotNull(resultsWithoutFilter);
+    assertNotNull(resultsWithFilter);
+    assertEquals(resultsWithoutFilter, resultsWithFilter);
+
+    // Each intermediate result is a value-to-frequency map that must match the counts gathered in setUp()
+    Map<Long, Long> expectedLongCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue));
+    Map<Float, Long> expectedFloatCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue));
+    Map<Double, Long> expectedDoubleCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue));
+    assertTrue(Maps.difference((Int2LongOpenHashMap) resultsWithoutFilter.get(0), _values).areEqual());
+    assertTrue(Maps.difference((Long2LongOpenHashMap) resultsWithoutFilter.get(1), expectedLongCounts).areEqual());
+    assertTrue(Maps.difference((Float2LongOpenHashMap) resultsWithoutFilter.get(2), expectedFloatCounts).areEqual());
+    assertTrue(Maps.difference((Double2LongOpenHashMap) resultsWithoutFilter.get(3), expectedDoubleCounts).areEqual());
+
+    // Inter segments: the statistics are asserted at 4x the single-segment numbers, the mode itself is unchanged
+    double[] expectedResults = new double[4];
+    Arrays.fill(expectedResults, _expectedResultMin);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      BrokerResponseNative brokerResponse;
+      if (withFilter) {
+        brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      } else {
+        brokerResponse = getBrokerResponseForPqlQuery(query);
+      }
+      Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+      List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+      Assert.assertEquals(aggregationResults.size(), expectedResults.length);
+      for (int column = 0; column < expectedResults.length; column++) {
+        Serializable value = aggregationResults.get(column).getValue();
+        Assert.assertEquals(Double.parseDouble(value.toString()), expectedResults[column], DELTA);
+      }
+    }
+  }
+
+  /**
+   * Checks MODE with the explicit 'MIN' reducer: first the per-segment frequency maps, then the reduced broker
+   * response, which is expected to report the smallest of the most frequent values.
+   */
+  @Test
+  public void testAggregationOnlyWithMultiModeReducerOptionMIN() {
+    String query =
+        "SELECT MODE(intColumn, 'MIN'), MODE(longColumn, 'MIN'), MODE(floatColumn, 'MIN'), MODE(doubleColumn, 'MIN') FROM testTable";
+
+    // Inner segment: execute the plan without the filter first, then with it
+    List<List<Object>> innerSegmentResults = new ArrayList<>(2);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      Operator operator;
+      if (withFilter) {
+        operator = getOperatorForPqlQueryWithFilter(query);
+      } else {
+        operator = getOperatorForPqlQuery(query);
+      }
+      assertTrue(operator instanceof AggregationOperator);
+      IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+          4 * NUM_RECORDS, NUM_RECORDS);
+      innerSegmentResults.add(resultsBlock.getAggregationResult());
+    }
+    List<Object> resultsWithoutFilter = innerSegmentResults.get(0);
+    List<Object> resultsWithFilter = innerSegmentResults.get(1);
+
+    assertNotNull(resultsWithoutFilter);
+    assertNotNull(resultsWithFilter);
+    assertEquals(resultsWithoutFilter, resultsWithFilter);
+
+    // Each intermediate result is a value-to-frequency map that must match the counts gathered in setUp()
+    Map<Long, Long> expectedLongCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue));
+    Map<Float, Long> expectedFloatCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue));
+    Map<Double, Long> expectedDoubleCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue));
+    assertTrue(Maps.difference((Int2LongOpenHashMap) resultsWithoutFilter.get(0), _values).areEqual());
+    assertTrue(Maps.difference((Long2LongOpenHashMap) resultsWithoutFilter.get(1), expectedLongCounts).areEqual());
+    assertTrue(Maps.difference((Float2LongOpenHashMap) resultsWithoutFilter.get(2), expectedFloatCounts).areEqual());
+    assertTrue(Maps.difference((Double2LongOpenHashMap) resultsWithoutFilter.get(3), expectedDoubleCounts).areEqual());
+
+    // Inter segments: the statistics are asserted at 4x the single-segment numbers, the mode itself is unchanged
+    double[] expectedResults = new double[4];
+    Arrays.fill(expectedResults, _expectedResultMin);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      BrokerResponseNative brokerResponse;
+      if (withFilter) {
+        brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      } else {
+        brokerResponse = getBrokerResponseForPqlQuery(query);
+      }
+      Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+      List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+      Assert.assertEquals(aggregationResults.size(), expectedResults.length);
+      for (int column = 0; column < expectedResults.length; column++) {
+        Serializable value = aggregationResults.get(column).getValue();
+        Assert.assertEquals(Double.parseDouble(value.toString()), expectedResults[column], DELTA);
+      }
+    }
+  }
+
+  /**
+   * Checks MODE with the explicit 'MAX' reducer: first the per-segment frequency maps, then the reduced broker
+   * response, which is expected to report the largest of the most frequent values.
+   */
+  @Test
+  public void testAggregationOnlyWithMultiModeReducerOptionMAX() {
+    String query =
+        "SELECT MODE(intColumn, 'MAX'), MODE(longColumn, 'MAX'), MODE(floatColumn, 'MAX'), MODE(doubleColumn, 'MAX') FROM testTable";
+
+    // Inner segment: execute the plan without the filter first, then with it
+    List<List<Object>> innerSegmentResults = new ArrayList<>(2);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      Operator operator;
+      if (withFilter) {
+        operator = getOperatorForPqlQueryWithFilter(query);
+      } else {
+        operator = getOperatorForPqlQuery(query);
+      }
+      assertTrue(operator instanceof AggregationOperator);
+      IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+          4 * NUM_RECORDS, NUM_RECORDS);
+      innerSegmentResults.add(resultsBlock.getAggregationResult());
+    }
+    List<Object> resultsWithoutFilter = innerSegmentResults.get(0);
+    List<Object> resultsWithFilter = innerSegmentResults.get(1);
+
+    assertNotNull(resultsWithoutFilter);
+    assertNotNull(resultsWithFilter);
+    assertEquals(resultsWithoutFilter, resultsWithFilter);
+
+    // Each intermediate result is a value-to-frequency map that must match the counts gathered in setUp()
+    Map<Long, Long> expectedLongCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue));
+    Map<Float, Long> expectedFloatCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue));
+    Map<Double, Long> expectedDoubleCounts =
+        _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue));
+    assertTrue(Maps.difference((Int2LongOpenHashMap) resultsWithoutFilter.get(0), _values).areEqual());
+    assertTrue(Maps.difference((Long2LongOpenHashMap) resultsWithoutFilter.get(1), expectedLongCounts).areEqual());
+    assertTrue(Maps.difference((Float2LongOpenHashMap) resultsWithoutFilter.get(2), expectedFloatCounts).areEqual());
+    assertTrue(Maps.difference((Double2LongOpenHashMap) resultsWithoutFilter.get(3), expectedDoubleCounts).areEqual());
+
+    // Inter segments: the statistics are asserted at 4x the single-segment numbers, the mode itself is unchanged
+    double[] expectedResults = new double[4];
+    Arrays.fill(expectedResults, _expectedResultMax);
+    for (boolean withFilter : new boolean[]{false, true}) {
+      BrokerResponseNative brokerResponse;
+      if (withFilter) {
+        brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      } else {
+        brokerResponse = getBrokerResponseForPqlQuery(query);
+      }
+      Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+      Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+      Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+      List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+      Assert.assertEquals(aggregationResults.size(), expectedResults.length);
+      for (int column = 0; column < expectedResults.length; column++) {
+        Serializable value = aggregationResults.get(column).getValue();
+        Assert.assertEquals(Double.parseDouble(value.toString()), expectedResults[column], DELTA);
+      }
+    }
+  }
+
+ /**
+  * Runs MODE with the 'AVG' multi-mode reducer over the int, long, float and double columns without GROUP BY,
+  * first against a single segment (with and without the test filter) and then through the broker path.
+  */
+ @Test
+ public void testAggregationOnlyWithMultiModeReducerOptionAVG() {
+   String query =
+       "SELECT MODE(intColumn, 'AVG'), MODE(longColumn, 'AVG'), MODE(floatColumn, 'AVG'), MODE(doubleColumn, 'AVG') FROM testTable";
+
+   // Single segment, query without filter
+   Operator unfilteredOperator = getOperatorForPqlQuery(query);
+   assertTrue(unfilteredOperator instanceof AggregationOperator);
+   IntermediateResultsBlock unfilteredBlock = ((AggregationOperator) unfilteredOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(unfilteredOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   List<Object> unfilteredResults = unfilteredBlock.getAggregationResult();
+
+   // Single segment, query with filter: the same statistics are expected
+   Operator filteredOperator = getOperatorForPqlQueryWithFilter(query);
+   assertTrue(filteredOperator instanceof AggregationOperator);
+   IntermediateResultsBlock filteredBlock = ((AggregationOperator) filteredOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(filteredOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   List<Object> filteredResults = filteredBlock.getAggregationResult();
+
+   assertNotNull(unfilteredResults);
+   assertNotNull(filteredResults);
+   assertEquals(unfilteredResults, filteredResults);
+
+   // The intermediate result of each column is a primitive-keyed map that must equal _values once the keys of
+   // _values are converted to the column type
+   Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) unfilteredResults.get(0);
+   assertTrue(Maps.difference(intResult, _values).areEqual());
+   Long2LongOpenHashMap longResult = (Long2LongOpenHashMap) unfilteredResults.get(1);
+   assertTrue(Maps.difference(longResult,
+       _values.entrySet().stream().collect(Collectors.toMap(entry -> entry.getKey().longValue(), Map.Entry::getValue)))
+       .areEqual());
+   Float2LongOpenHashMap floatResult = (Float2LongOpenHashMap) unfilteredResults.get(2);
+   assertTrue(Maps.difference(floatResult,
+       _values.entrySet().stream().collect(Collectors.toMap(entry -> entry.getKey().floatValue(), Map.Entry::getValue)))
+       .areEqual());
+   Double2LongOpenHashMap doubleResult = (Double2LongOpenHashMap) unfilteredResults.get(3);
+   assertTrue(Maps.difference(doubleResult,
+       _values.entrySet().stream().collect(Collectors.toMap(entry -> entry.getKey().doubleValue(), Map.Entry::getValue)))
+       .areEqual());
+
+   // Broker path: every column must reduce to _expectedResultAvg, while the reported statistics are 4x the
+   // single-segment ones. The unfiltered query is checked first, then the filtered one.
+   double[] expectedResults = {_expectedResultAvg, _expectedResultAvg, _expectedResultAvg, _expectedResultAvg};
+   for (boolean withFilter : new boolean[]{false, true}) {
+     BrokerResponseNative brokerResponse;
+     if (withFilter) {
+       brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+     } else {
+       brokerResponse = getBrokerResponseForPqlQuery(query);
+     }
+
+     Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+     Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+     Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+     Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+     List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+     Assert.assertEquals(aggregationResults.size(), expectedResults.length);
+     for (int column = 0; column < expectedResults.length; column++) {
+       Serializable actualValue = aggregationResults.get(column).getValue();
+       Assert.assertEquals(Double.parseDouble(actualValue.toString()), expectedResults[column], DELTA);
+     }
+   }
+ }
+
+ /**
+  * Runs MODE over the int, long, float and double columns grouped by the single-value intColumn, checking the
+  * per-group intermediate results of a single segment and then the group-by results of the broker response.
+  */
+ @Test
+ public void testAggregationGroupBySv() {
+   String query =
+       "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable GROUP BY intColumn";
+
+   // Single segment
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key))).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key))).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ /**
+  * Runs MODE over the int, long, float and double columns grouped by the multi-value intMvColumn, checking the
+  * per-group intermediate results of a single segment and then the group-by results of the broker response.
+  */
+ @Test
+ public void testAggregationGroupByMv() {
+   String query =
+       "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable GROUP BY intMvColumn";
+
+   // Single segment; 5 * NUM_RECORDS entries are expected post filter here (presumably the four aggregated
+   // columns plus intMvColumn)
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       5 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to twice the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key) * 2)).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key) * 2)).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key) * 2)).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key) * 2)).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 5 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ @Test
+ public void testAggregationGroupBySvNoDictionary() {
+ String query =
+ "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn),
MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable GROUP BY
intNoDictColumn";
+
+ // Inner segment
+ Operator operator = getOperatorForPqlQuery(query);
+ assertTrue(operator instanceof AggregationGroupByOperator);
+ IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0,
+ 4 * NUM_RECORDS, NUM_RECORDS);
+ AggregationGroupByResult aggregationGroupByResult =
resultsBlock.getAggregationGroupByResult();
+ assertNotNull(aggregationGroupByResult);
+ int numGroups = 0;
+ Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator =
aggregationGroupByResult.getGroupKeyIterator();
+ while (groupKeyIterator.hasNext()) {
+ numGroups++;
+ GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+ Integer key = (Integer) groupKey._keys[0];
+ assertTrue(_values.containsKey(key));
+ assertTrue(
+ Maps.difference((Int2LongOpenHashMap)
aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId),
+ Collections.singletonMap(key, _values.get(key))).areEqual());
+ assertTrue(
+ Maps.difference((Long2LongOpenHashMap)
aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId),
+ Collections.singletonMap(key.longValue(),
_values.get(key))).areEqual());
+ assertTrue(
+ Maps.difference((Float2LongOpenHashMap)
aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId),
+ Collections.singletonMap(key.floatValue(),
_values.get(key))).areEqual());
+ assertTrue(
+ Maps.difference((Double2LongOpenHashMap)
aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId),
+ Collections.singletonMap(key.doubleValue(),
_values.get(key))).areEqual());
+ }
+ assertEquals(numGroups, _values.size());
+
+ // Inter segments (expect 4 * inner segment result)
+ BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+ Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+ Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+ Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4
* NUM_RECORDS);
+ Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+ // size of this array will be equal to number of aggregation functions
since
+ // we return each aggregation function separately
+ List<AggregationResult> aggregationResults =
brokerResponse.getAggregationResults();
+ int numAggregationColumns = aggregationResults.size();
+ Assert.assertEquals(numAggregationColumns, 4);
+ for (AggregationResult aggregationResult : aggregationResults) {
+ Assert.assertNull(aggregationResult.getValue());
+ List<GroupByResult> groupByResults =
aggregationResult.getGroupByResult();
+ numGroups = groupByResults.size();
+ for (int i = 0; i < numGroups; i++) {
+ GroupByResult groupByResult = groupByResults.get(i);
+ List<String> group = groupByResult.getGroup();
+ assertEquals(group.size(), 1);
+ assertTrue(_values.containsKey(Integer.parseInt(group.get(0))));
+ assertEquals(Double.parseDouble(groupByResult.getValue().toString()),
Double.parseDouble(group.get(0)), DELTA);
+ }
+ }
+ }
+
+ /**
+  * Runs MODE over the int, long, float and double no-dictionary columns grouped by the multi-value intMvColumn,
+  * checking the per-group intermediate results of a single segment and then the group-by results of the broker
+  * response.
+  */
+ @Test
+ public void testAggregationGroupByMvNoDictionary() {
+   String query =
+       "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn), MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable GROUP BY intMvColumn";
+
+   // Single segment; 5 * NUM_RECORDS entries are expected post filter here (presumably the four aggregated
+   // columns plus intMvColumn)
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       5 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to twice the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key) * 2)).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key) * 2)).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key) * 2)).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key) * 2)).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 5 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ /**
+  * Runs MODE with the 'MIN' multi-mode reducer over the int, long, float and double columns grouped by
+  * intColumn, checking the per-group intermediate results of a single segment and then the group-by results of
+  * the broker response.
+  */
+ @Test
+ public void testAggregationGroupBySvWithMultiModeReducerOptionMIN() {
+   String query =
+       "SELECT MODE(intColumn, 'MIN'), MODE(longColumn, 'MIN'), MODE(floatColumn, 'MIN'), MODE(doubleColumn, 'MIN') FROM testTable GROUP BY intColumn";
+
+   // Single segment
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key))).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key))).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ /**
+  * Runs MODE with the 'MAX' multi-mode reducer over the int, long, float and double columns grouped by
+  * intColumn, checking the per-group intermediate results of a single segment and then the group-by results of
+  * the broker response.
+  */
+ @Test
+ public void testAggregationGroupBySvWithMultiModeReducerOptionMAX() {
+   String query =
+       "SELECT MODE(intColumn, 'MAX'), MODE(longColumn, 'MAX'), MODE(floatColumn, 'MAX'), MODE(doubleColumn, 'MAX') FROM testTable GROUP BY intColumn";
+
+   // Single segment
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key))).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key))).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ /**
+  * Runs MODE with the 'AVG' multi-mode reducer over the int, long, float and double columns grouped by
+  * intColumn, checking the per-group intermediate results of a single segment and then the group-by results of
+  * the broker response.
+  */
+ @Test
+ public void testAggregationGroupBySvWithMultiModeReducerOptionAVG() {
+   String query =
+       "SELECT MODE(intColumn, 'AVG'), MODE(longColumn, 'AVG'), MODE(floatColumn, 'AVG'), MODE(doubleColumn, 'AVG') FROM testTable GROUP BY intColumn";
+
+   // Single segment
+   Operator groupByOperator = getOperatorForPqlQuery(query);
+   assertTrue(groupByOperator instanceof AggregationGroupByOperator);
+   IntermediateResultsBlock block = ((AggregationGroupByOperator) groupByOperator).nextBlock();
+   QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(), NUM_RECORDS, 0,
+       4 * NUM_RECORDS, NUM_RECORDS);
+   AggregationGroupByResult segmentResult = block.getAggregationGroupByResult();
+   assertNotNull(segmentResult);
+
+   int groupsSeen = 0;
+   for (Iterator<GroupKeyGenerator.GroupKey> it = segmentResult.getGroupKeyIterator(); it.hasNext(); ) {
+     groupsSeen++;
+     GroupKeyGenerator.GroupKey currentGroup = it.next();
+     Integer key = (Integer) currentGroup._keys[0];
+     assertTrue(_values.containsKey(key));
+
+     // For every column the group's map must hold exactly one entry: the group key (converted to the column
+     // type) mapped to the value stored for that key in _values
+     Int2LongOpenHashMap intResult = (Int2LongOpenHashMap) segmentResult.getResultForGroupId(0, currentGroup._groupId);
+     assertTrue(Maps.difference(intResult, Collections.singletonMap(key, _values.get(key))).areEqual());
+     Long2LongOpenHashMap longResult =
+         (Long2LongOpenHashMap) segmentResult.getResultForGroupId(1, currentGroup._groupId);
+     assertTrue(Maps.difference(longResult, Collections.singletonMap(key.longValue(), _values.get(key))).areEqual());
+     Float2LongOpenHashMap floatResult =
+         (Float2LongOpenHashMap) segmentResult.getResultForGroupId(2, currentGroup._groupId);
+     assertTrue(Maps.difference(floatResult, Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual());
+     Double2LongOpenHashMap doubleResult =
+         (Double2LongOpenHashMap) segmentResult.getResultForGroupId(3, currentGroup._groupId);
+     assertTrue(
+         Maps.difference(doubleResult, Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual());
+   }
+   assertEquals(groupsSeen, _values.size());
+
+   // Broker path: the reported statistics are expected to be 4x the single-segment ones
+   BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+   Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+   Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS);
+   Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+
+   // Each aggregation function is returned separately, so there is one AggregationResult per function
+   List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+   Assert.assertEquals(aggregationResults.size(), 4);
+   for (AggregationResult aggregationResult : aggregationResults) {
+     Assert.assertNull(aggregationResult.getValue());
+     for (GroupByResult groupByResult : aggregationResult.getGroupByResult()) {
+       List<String> groupKeys = groupByResult.getGroup();
+       assertEquals(groupKeys.size(), 1);
+       String groupKey = groupKeys.get(0);
+       assertTrue(_values.containsKey(Integer.parseInt(groupKey)));
+       // The reduced mode of a group is asserted to equal the group key itself
+       assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(groupKey), DELTA);
+     }
+   }
+ }
+
+ /**
+  * Destroys the index segment used by this test class and deletes the on-disk index directory.
+  *
+  * @throws IOException if the index directory cannot be deleted
+  */
+ @AfterClass
+ public void tearDown()
+     throws IOException {
+   try {
+     _indexSegment.destroy();
+   } finally {
+     // Delete the index directory even when destroying the segment fails, so that no files are left behind on
+     // disk for subsequent test runs.
+     FileUtils.deleteDirectory(INDEX_DIR);
+   }
+ }
+}
diff --git
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 1683787..9d10e8a 100644
---
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -32,6 +32,7 @@ public enum AggregationFunctionType {
SUM("sum"),
SUMPRECISION("sumPrecision"),
AVG("avg"),
+ MODE("mode"),
MINMAXRANGE("minMaxRange"),
DISTINCTCOUNT("distinctCount"),
DISTINCTCOUNTBITMAP("distinctCountBitmap"),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]