This is an automated email from the ASF dual-hosted git repository. xiangfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 47eff886cd Add more array scalar functions (#11555) 47eff886cd is described below commit 47eff886cd0a130cbff9566175f7ec36ea7792e0 Author: Xuanyi Li <xuany...@uber.com> AuthorDate: Tue Sep 19 17:04:28 2023 -0700 Add more array scalar functions (#11555) * scalar func * fix unit test * fix silly bug in intersectIndices * add indexOfAll for long, float and double, including unit test --- .../common/function/scalar/ArrayFunctions.java | 81 ++++++++++++++ .../function/BaseTransformFunctionTest.java | 37 +++++++ .../ScalarTransformFunctionWrapperTest.java | 116 +++++++++++++++++++++ 3 files changed, 234 insertions(+) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index a15cc931b4..a9a6d39e72 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet; import it.unimi.dsi.fastutil.objects.ObjectSet; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -74,6 +76,85 @@ public class ArrayFunctions { return ArrayUtils.indexOf(values, valueToFind); } + @ScalarFunction + public static int[] arrayIndexOfAllInt(int[] value, int valueToFind) { + List<Integer> indices = new ArrayList<>(); + for (int i = 0; i < value.length; i++) { + if (value[i] == valueToFind) { + indices.add(i); + } + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + + @ScalarFunction + public static int[] arrayIndexOfAllLong(long[] value, long valueToFind) { + List<Integer> indices = new ArrayList<>(); + for (int i = 0; i < value.length; i++) { + if (value[i] == valueToFind) { + indices.add(i); + } + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + + @ScalarFunction + public static int[] arrayIndexOfAllFloat(float[] value, float valueToFind) { + List<Integer> indices = new ArrayList<>(); + for (int i = 0; i < value.length; i++) { + if (value[i] == valueToFind) { + indices.add(i); + } + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + + @ScalarFunction + public static int[] arrayIndexOfAllDouble(double[] value, double valueToFind) { + List<Integer> indices = new ArrayList<>(); + for (int i = 0; i < value.length; i++) { + if (value[i] == valueToFind) { + indices.add(i); + } + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + + @ScalarFunction + public static int[] arrayIndexOfAllString(String[] value, String valueToFind) { + List<Integer> indices = new ArrayList<>(); + for (int i = 0; i < value.length; i++) { + if (valueToFind.equals(value[i])) { + indices.add(i); + } + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + + /** + * Assume values1, and values2 are monotonous increasing indices of MV cols. + * Here is the common usage: + * col1: ["a", "b", "a", "b"] + * col2: ["c", "d", "d", "c"] + * The user want to get the first index called idx, s.t. col1[idx] == "b" && col2[idx] == "d" + * arrayElementAtInt(0, intersectIndices(arrayIndexOfAllString(col1, "b"), arrayIndexOfAllString(col2, "d"))) + */ + @ScalarFunction + public static int[] intersectIndices(int[] values1, int[] values2) { + // TODO: if values1.length << values2.length. Use binary search can speed up the query + int i = 0; + int j = 0; + List<Integer> indices = new ArrayList<>(); + while (i < values1.length && j < values2.length) { + if (values1[i] == values2[j]) { + indices.add(values1[i]); + j++; + } + i++; + } + return indices.stream().mapToInt(Integer::intValue).toArray(); + } + @ScalarFunction public static boolean arrayContainsInt(int[] values, int valueToFind) { return ArrayUtils.contains(values, valueToFind); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java index ed2a5b4b7b..129c67ad73 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java @@ -102,9 +102,16 @@ public abstract class BaseTransformFunctionTest { protected static final String STRING_MV_COLUMN = "stringMV"; protected static final String STRING_ALPHANUM_MV_COLUMN = "stringAlphaNumMV"; protected static final String STRING_LONG_MV_COLUMN = "stringLongMV"; + // deterministic MV is useful for testing IndexOf and IndexOfAll + protected static final String STRING_ALPHANUM_MV_COLUMN_2 = "stringAlphaNumMV2"; protected static final String TIME_COLUMN = "timeColumn"; protected static final String TIMESTAMP_COLUMN = "timestampColumn"; protected static final String TIMESTAMP_COLUMN_NULL = "timestampColumnNull"; + protected static final String INT_MONO_INCREASING_MV_1 = "intMonoIncreasingMV1"; + protected static final String INT_MONO_INCREASING_MV_2 = "intMonoIncreasingMV2"; + protected static final String LONG_MV_COLUMN_2 = "longMV2"; + protected static final String FLOAT_MV_COLUMN_2 = "floatMV2"; + protected static final String DOUBLE_MV_COLUMN_2 = "doubleMV2"; protected static final String JSON_COLUMN = "json"; protected static final String DEFAULT_JSON_COLUMN = "defaultJson"; @@ -122,11 +129,17 @@ public abstract class BaseTransformFunctionTest { protected final double[][] _doubleMVValues = new double[NUM_ROWS][]; protected final String[][] _stringMVValues = new String[NUM_ROWS][]; protected final String[][] _stringAlphaNumericMVValues = new String[NUM_ROWS][]; + protected final String[][] _stringAlphaNumericMV2Values = new String[NUM_ROWS][]; protected final String[][] _stringLongFormatMVValues = new String[NUM_ROWS][]; protected final long[] _timeValues = new long[NUM_ROWS]; protected final String[] _jsonValues = new String[NUM_ROWS]; protected final float[][] _vector1Values = new float[NUM_ROWS][]; protected final float[][] _vector2Values = new float[NUM_ROWS][]; + protected final int[][] _intMonoIncreasingMV1Values = new int[NUM_ROWS][]; + protected final int[][] _intMonoIncreasingMV2Values = new int[NUM_ROWS][]; + protected final long[][] _longMV2Values = new long[NUM_ROWS][]; + protected final float[][] _floatMV2Values = new float[NUM_ROWS][]; + protected final double[][] _doubleMV2Values = new double[NUM_ROWS][]; protected Map<String, DataSource> _dataSourceMap; protected ProjectionBlock _projectionBlock; @@ -155,9 +168,15 @@ public abstract class BaseTransformFunctionTest { _doubleMVValues[i] = new double[numValues]; _stringMVValues[i] = new String[numValues]; _stringAlphaNumericMVValues[i] = new String[numValues]; + _stringAlphaNumericMV2Values[i] = new String[numValues]; _stringLongFormatMVValues[i] = new String[numValues]; _vector1Values[i] = new float[VECTOR_DIM_SIZE]; _vector2Values[i] = new float[VECTOR_DIM_SIZE]; + _intMonoIncreasingMV1Values[i] = new int[numValues]; + _intMonoIncreasingMV2Values[i] = new int[numValues]; + _longMV2Values[i] = new long[numValues]; + _floatMV2Values[i] = new float[numValues]; + _doubleMV2Values[i] = new double[numValues]; for (int j = 0; j < numValues; j++) { _intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE); @@ -166,7 +185,13 @@ public abstract class BaseTransformFunctionTest { _doubleMVValues[i][j] = 1 + RANDOM.nextDouble(); _stringMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextDouble()); _stringAlphaNumericMVValues[i][j] = RandomStringUtils.randomAlphanumeric(26); + _stringAlphaNumericMV2Values[i][j] = "a"; _stringLongFormatMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextLong()); + _intMonoIncreasingMV1Values[i][j] = j; + _intMonoIncreasingMV2Values[i][j] = j + 1; + _longMV2Values[i][j] = 1L; + _floatMV2Values[i][j] = 1.0f; + _doubleMV2Values[i][j] = 1.0; } for (int j = 0; j < VECTOR_DIM_SIZE; j++) { @@ -219,6 +244,7 @@ public abstract class BaseTransformFunctionTest { map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i])); map.put(STRING_MV_COLUMN, _stringMVValues[i]); map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]); + map.put(STRING_ALPHANUM_MV_COLUMN_2, _stringAlphaNumericMV2Values[i]); map.put(STRING_LONG_MV_COLUMN, _stringLongFormatMVValues[i]); map.put(TIMESTAMP_COLUMN, _timeValues[i]); if (isNullRow(i)) { @@ -229,6 +255,11 @@ public abstract class BaseTransformFunctionTest { map.put(TIME_COLUMN, _timeValues[i]); _jsonValues[i] = JsonUtils.objectToJsonNode(map).toString(); map.put(JSON_COLUMN, _jsonValues[i]); + map.put(INT_MONO_INCREASING_MV_1, ArrayUtils.toObject(_intMonoIncreasingMV1Values[i])); + map.put(INT_MONO_INCREASING_MV_2, ArrayUtils.toObject(_intMonoIncreasingMV2Values[i])); + map.put(LONG_MV_COLUMN_2, ArrayUtils.toObject(_longMV2Values[i])); + map.put(FLOAT_MV_COLUMN_2, ArrayUtils.toObject(_floatMV2Values[i])); + map.put(DOUBLE_MV_COLUMN_2, ArrayUtils.toObject(_doubleMV2Values[i])); GenericRow row = new GenericRow(); row.init(map); rows.add(row); @@ -254,10 +285,16 @@ public abstract class BaseTransformFunctionTest { .addMultiValueDimension(DOUBLE_MV_COLUMN, FieldSpec.DataType.DOUBLE) .addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING) .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN, FieldSpec.DataType.STRING) + .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN_2, FieldSpec.DataType.STRING) .addMultiValueDimension(STRING_LONG_MV_COLUMN, FieldSpec.DataType.STRING) .addMultiValueDimension(VECTOR_1_COLUMN, FieldSpec.DataType.FLOAT) .addMultiValueDimension(VECTOR_2_COLUMN, FieldSpec.DataType.FLOAT) .addMultiValueDimension(ZERO_VECTOR_COLUMN, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(INT_MONO_INCREASING_MV_1, FieldSpec.DataType.INT) + .addMultiValueDimension(INT_MONO_INCREASING_MV_2, FieldSpec.DataType.INT) + .addMultiValueDimension(LONG_MV_COLUMN_2, FieldSpec.DataType.LONG) + .addMultiValueDimension(FLOAT_MV_COLUMN_2, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(DOUBLE_MV_COLUMN_2, FieldSpec.DataType.DOUBLE) .addDateTime(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS") .addDateTime(TIMESTAMP_COLUMN_NULL, FieldSpec.DataType.TIMESTAMP, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS") .addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG, TimeUnit.MILLISECONDS, TIME_COLUMN), null).build(); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java index c16f0e9c23..5befeccc07 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java @@ -953,6 +953,122 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes testTransformFunction(transformFunction, expectedValues); } + @Test + public void testArrayIndexOfAllInt() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("array_index_of_all_int(%s, 0)", INT_MONO_INCREASING_MV_1)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int[] expectedValue = {0}; + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfAllLong() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("array_index_of_all_long(%s, 1)", LONG_MV_COLUMN_2)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int len = _longMV2Values[i].length; + int[] expectedValue = new int[len]; + for (int j = 0; j < len; j++) { + expectedValue[j] = j; + } + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfAllFloat() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("array_index_of_all_float(%s, 1.0)", FLOAT_MV_COLUMN_2)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int len = _floatMV2Values[i].length; + int[] expectedValue = new int[len]; + for (int j = 0; j < len; j++) { + expectedValue[j] = j; + } + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfAllDouble() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("array_index_of_all_double(%s, 1.0)", DOUBLE_MV_COLUMN_2)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int len = _doubleMV2Values[i].length; + int[] expectedValue = new int[len]; + for (int j = 0; j < len; j++) { + expectedValue[j] = j; + } + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfAllString() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("array_index_of_all_string(%s, 'a')", STRING_ALPHANUM_MV_COLUMN_2)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int len = _stringAlphaNumericMV2Values[i].length; + int[] expectedValue = new int[len]; + for (int j = 0; j < len; j++) { + expectedValue[j] = j; + } + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testIntersectIndices() { + ExpressionContext expression = RequestContextUtils.getExpression( + String.format("intersect_indices(%s, %s)", INT_MONO_INCREASING_MV_1, INT_MONO_INCREASING_MV_2)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + int len = _intMonoIncreasingMV1Values[i].length; + int[] expectedValue = new int[len - 1]; + for (int j = 0; j < expectedValue.length; j++) { + expectedValue[j] = j + 1; + } + expectedValues[i] = expectedValue; + } + testTransformFunctionMV(transformFunction, expectedValues); + } + @Test public void testBase64TransformFunction() { ExpressionContext expression = RequestContextUtils.getExpression(String.format("toBase64(%s)", BYTES_SV_COLUMN)); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org