This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 47eff886cd Add more array scalar functions (#11555)
47eff886cd is described below

commit 47eff886cd0a130cbff9566175f7ec36ea7792e0
Author: Xuanyi Li <xuany...@uber.com>
AuthorDate: Tue Sep 19 17:04:28 2023 -0700

    Add more array scalar functions (#11555)
    
    * scalar func
    
    * fix unit test
    
    * fix silly bug in intersectIndices
    
    * add indexOfAll for long, float and double, including unit test
---
 .../common/function/scalar/ArrayFunctions.java     |  81 ++++++++++++++
 .../function/BaseTransformFunctionTest.java        |  37 +++++++
 .../ScalarTransformFunctionWrapperTest.java        | 116 +++++++++++++++++++++
 3 files changed, 234 insertions(+)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
index a15cc931b4..a9a6d39e72 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
@@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
 import it.unimi.dsi.fastutil.objects.ObjectSet;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.pinot.spi.annotations.ScalarFunction;
 import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
@@ -74,6 +76,85 @@ public class ArrayFunctions {
     return ArrayUtils.indexOf(values, valueToFind);
   }
 
+  /**
+   * Returns every position at which {@code valueToFind} occurs in {@code value}.
+   *
+   * @param value the multi-value array to scan
+   * @param valueToFind the element to look for
+   * @return the 0-based matching positions in ascending order; an empty array when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllInt(int[] value, int valueToFind) {
+    // Primitive buffer sized for the worst case (every element matches) avoids boxing each index.
+    int[] indices = new int[value.length];
+    int numMatches = 0;
+    for (int i = 0; i < value.length; i++) {
+      if (value[i] == valueToFind) {
+        indices[numMatches++] = i;
+      }
+    }
+    return Arrays.copyOf(indices, numMatches);
+  }
+
+  /**
+   * Returns every position at which {@code valueToFind} occurs in {@code value}.
+   *
+   * @param value the multi-value array to scan
+   * @param valueToFind the element to look for
+   * @return the 0-based matching positions in ascending order; an empty array when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllLong(long[] value, long valueToFind) {
+    // Primitive buffer sized for the worst case (every element matches) avoids boxing each index.
+    int[] indices = new int[value.length];
+    int numMatches = 0;
+    for (int i = 0; i < value.length; i++) {
+      if (value[i] == valueToFind) {
+        indices[numMatches++] = i;
+      }
+    }
+    return Arrays.copyOf(indices, numMatches);
+  }
+
+  /**
+   * Returns every position at which {@code valueToFind} occurs in {@code value}.
+   *
+   * <p>Elements are compared with {@code ==}: {@code NaN} never matches, and {@code 0.0f} matches {@code -0.0f}.
+   *
+   * @param value the multi-value array to scan
+   * @param valueToFind the element to look for
+   * @return the 0-based matching positions in ascending order; an empty array when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllFloat(float[] value, float valueToFind) {
+    // Primitive buffer sized for the worst case (every element matches) avoids boxing each index.
+    int[] indices = new int[value.length];
+    int numMatches = 0;
+    for (int i = 0; i < value.length; i++) {
+      if (value[i] == valueToFind) {
+        indices[numMatches++] = i;
+      }
+    }
+    return Arrays.copyOf(indices, numMatches);
+  }
+
+  /**
+   * Returns every position at which {@code valueToFind} occurs in {@code value}.
+   *
+   * <p>Elements are compared with {@code ==}: {@code NaN} never matches, and {@code 0.0} matches {@code -0.0}.
+   *
+   * @param value the multi-value array to scan
+   * @param valueToFind the element to look for
+   * @return the 0-based matching positions in ascending order; an empty array when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllDouble(double[] value, double valueToFind) {
+    // Primitive buffer sized for the worst case (every element matches) avoids boxing each index.
+    int[] indices = new int[value.length];
+    int numMatches = 0;
+    for (int i = 0; i < value.length; i++) {
+      if (value[i] == valueToFind) {
+        indices[numMatches++] = i;
+      }
+    }
+    return Arrays.copyOf(indices, numMatches);
+  }
+
+  /**
+   * Returns every position at which {@code valueToFind} occurs in {@code value}, comparing with
+   * {@link String#equals(Object)}.
+   *
+   * @param value the multi-value array to scan
+   * @param valueToFind the element to look for; a {@code null} search value matches {@code null} elements
+   * @return the 0-based matching positions in ascending order; an empty array when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllString(String[] value, String valueToFind) {
+    // Primitive buffer sized for the worst case (every element matches) avoids boxing each index.
+    int[] indices = new int[value.length];
+    int numMatches = 0;
+    for (int i = 0; i < value.length; i++) {
+      String element = value[i];
+      // Null-safe equality: the previous valueToFind.equals(...) threw an NPE for a null search value.
+      if (valueToFind == null ? element == null : valueToFind.equals(element)) {
+        indices[numMatches++] = i;
+      }
+    }
+    return Arrays.copyOf(indices, numMatches);
+  }
+
+  /**
+   * Intersects two index arrays, e.g. the results of the {@code arrayIndexOfAll*} functions.
+   *
+   * <p>Both inputs must be sorted in strictly increasing order (as index arrays produced by the
+   * {@code arrayIndexOfAll*} functions are). The result holds, in increasing order, the indices present in both.
+   *
+   * <p>Typical usage:
+   * <pre>
+   *   col1: ["a", "b", "a", "b"]
+   *   col2: ["c", "d", "d", "c"]
+   * </pre>
+   * To find the first index {@code idx} such that {@code col1[idx] == "b" && col2[idx] == "d"}, intersect
+   * {@code arrayIndexOfAllString(col1, "b")} = [1, 3] with {@code arrayIndexOfAllString(col2, "d")} = [1, 2],
+   * giving [1], and take its first element with {@code arrayElementAtInt}.
+   * NOTE(review): the previous example called {@code arrayElementAtInt(0, ...)}; confirm arrayElementAtInt's
+   * argument order and whether its position is 0- or 1-based before relying on a concrete call.
+   *
+   * @param values1 strictly increasing indices
+   * @param values2 strictly increasing indices
+   * @return the indices contained in both inputs, in increasing order; empty if none are shared
+   */
+  @ScalarFunction
+  public static int[] intersectIndices(int[] values1, int[] values2) {
+    // TODO: if values1.length << values2.length, binary-searching values2 for each element of values1 is faster
+    int[] common = new int[Math.min(values1.length, values2.length)];
+    int numCommon = 0;
+    int i = 0;
+    int j = 0;
+    while (i < values1.length && j < values2.length) {
+      int left = values1[i];
+      int right = values2[j];
+      if (left == right) {
+        common[numCommon++] = left;
+        i++;
+        j++;
+      } else if (left < right) {
+        // values2 only grows from here on, so values1[i] can never be matched: skip it.
+        i++;
+      } else {
+        // values1 only grows from here on, so values2[j] can never be matched: skip it.
+        j++;
+      }
+    }
+    return Arrays.copyOf(common, numCommon);
+  }
+
   @ScalarFunction
   public static boolean arrayContainsInt(int[] values, int valueToFind) {
     return ArrayUtils.contains(values, valueToFind);
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
index ed2a5b4b7b..129c67ad73 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
@@ -102,9 +102,16 @@ public abstract class BaseTransformFunctionTest {
   protected static final String STRING_MV_COLUMN = "stringMV";
   protected static final String STRING_ALPHANUM_MV_COLUMN = "stringAlphaNumMV";
   protected static final String STRING_LONG_MV_COLUMN = "stringLongMV";
+  // deterministic MV is useful for testing IndexOf and IndexOfAll
+  protected static final String STRING_ALPHANUM_MV_COLUMN_2 = 
"stringAlphaNumMV2";
   protected static final String TIME_COLUMN = "timeColumn";
   protected static final String TIMESTAMP_COLUMN = "timestampColumn";
   protected static final String TIMESTAMP_COLUMN_NULL = "timestampColumnNull";
+  protected static final String INT_MONO_INCREASING_MV_1 = 
"intMonoIncreasingMV1";
+  protected static final String INT_MONO_INCREASING_MV_2 = 
"intMonoIncreasingMV2";
+  protected static final String LONG_MV_COLUMN_2 = "longMV2";
+  protected static final String FLOAT_MV_COLUMN_2 = "floatMV2";
+  protected static final String DOUBLE_MV_COLUMN_2 = "doubleMV2";
 
   protected static final String JSON_COLUMN = "json";
   protected static final String DEFAULT_JSON_COLUMN = "defaultJson";
@@ -122,11 +129,17 @@ public abstract class BaseTransformFunctionTest {
   protected final double[][] _doubleMVValues = new double[NUM_ROWS][];
   protected final String[][] _stringMVValues = new String[NUM_ROWS][];
   protected final String[][] _stringAlphaNumericMVValues = new 
String[NUM_ROWS][];
+  protected final String[][] _stringAlphaNumericMV2Values = new 
String[NUM_ROWS][];
   protected final String[][] _stringLongFormatMVValues = new 
String[NUM_ROWS][];
   protected final long[] _timeValues = new long[NUM_ROWS];
   protected final String[] _jsonValues = new String[NUM_ROWS];
   protected final float[][] _vector1Values = new float[NUM_ROWS][];
   protected final float[][] _vector2Values = new float[NUM_ROWS][];
+  protected final int[][] _intMonoIncreasingMV1Values = new int[NUM_ROWS][];
+  protected final int[][] _intMonoIncreasingMV2Values = new int[NUM_ROWS][];
+  protected final long[][] _longMV2Values = new long[NUM_ROWS][];
+  protected final float[][] _floatMV2Values = new float[NUM_ROWS][];
+  protected final double[][] _doubleMV2Values = new double[NUM_ROWS][];
 
   protected Map<String, DataSource> _dataSourceMap;
   protected ProjectionBlock _projectionBlock;
@@ -155,9 +168,15 @@ public abstract class BaseTransformFunctionTest {
       _doubleMVValues[i] = new double[numValues];
       _stringMVValues[i] = new String[numValues];
       _stringAlphaNumericMVValues[i] = new String[numValues];
+      _stringAlphaNumericMV2Values[i] = new String[numValues];
       _stringLongFormatMVValues[i] = new String[numValues];
       _vector1Values[i] = new float[VECTOR_DIM_SIZE];
       _vector2Values[i] = new float[VECTOR_DIM_SIZE];
+      _intMonoIncreasingMV1Values[i] = new int[numValues];
+      _intMonoIncreasingMV2Values[i] = new int[numValues];
+      _longMV2Values[i] = new long[numValues];
+      _floatMV2Values[i] = new float[numValues];
+      _doubleMV2Values[i] = new double[numValues];
 
       for (int j = 0; j < numValues; j++) {
         _intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE);
@@ -166,7 +185,13 @@ public abstract class BaseTransformFunctionTest {
         _doubleMVValues[i][j] = 1 + RANDOM.nextDouble();
         _stringMVValues[i][j] = df.format(_intSVValues[i] * 
RANDOM.nextDouble());
         _stringAlphaNumericMVValues[i][j] = 
RandomStringUtils.randomAlphanumeric(26);
+        _stringAlphaNumericMV2Values[i][j] = "a";
         _stringLongFormatMVValues[i][j] = df.format(_intSVValues[i] * 
RANDOM.nextLong());
+        _intMonoIncreasingMV1Values[i][j] = j;
+        _intMonoIncreasingMV2Values[i][j] = j + 1;
+        _longMV2Values[i][j] = 1L;
+        _floatMV2Values[i][j] = 1.0f;
+        _doubleMV2Values[i][j] = 1.0;
       }
 
       for (int j = 0; j < VECTOR_DIM_SIZE; j++) {
@@ -219,6 +244,7 @@ public abstract class BaseTransformFunctionTest {
       map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i]));
       map.put(STRING_MV_COLUMN, _stringMVValues[i]);
       map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]);
+      map.put(STRING_ALPHANUM_MV_COLUMN_2, _stringAlphaNumericMV2Values[i]);
       map.put(STRING_LONG_MV_COLUMN, _stringLongFormatMVValues[i]);
       map.put(TIMESTAMP_COLUMN, _timeValues[i]);
       if (isNullRow(i)) {
@@ -229,6 +255,11 @@ public abstract class BaseTransformFunctionTest {
       map.put(TIME_COLUMN, _timeValues[i]);
       _jsonValues[i] = JsonUtils.objectToJsonNode(map).toString();
       map.put(JSON_COLUMN, _jsonValues[i]);
+      map.put(INT_MONO_INCREASING_MV_1, 
ArrayUtils.toObject(_intMonoIncreasingMV1Values[i]));
+      map.put(INT_MONO_INCREASING_MV_2, 
ArrayUtils.toObject(_intMonoIncreasingMV2Values[i]));
+      map.put(LONG_MV_COLUMN_2, ArrayUtils.toObject(_longMV2Values[i]));
+      map.put(FLOAT_MV_COLUMN_2, ArrayUtils.toObject(_floatMV2Values[i]));
+      map.put(DOUBLE_MV_COLUMN_2, ArrayUtils.toObject(_doubleMV2Values[i]));
       GenericRow row = new GenericRow();
       row.init(map);
       rows.add(row);
@@ -254,10 +285,16 @@ public abstract class BaseTransformFunctionTest {
         .addMultiValueDimension(DOUBLE_MV_COLUMN, FieldSpec.DataType.DOUBLE)
         .addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING)
         .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN, 
FieldSpec.DataType.STRING)
+        .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN_2, 
FieldSpec.DataType.STRING)
         .addMultiValueDimension(STRING_LONG_MV_COLUMN, 
FieldSpec.DataType.STRING)
         .addMultiValueDimension(VECTOR_1_COLUMN, FieldSpec.DataType.FLOAT)
         .addMultiValueDimension(VECTOR_2_COLUMN, FieldSpec.DataType.FLOAT)
         .addMultiValueDimension(ZERO_VECTOR_COLUMN, FieldSpec.DataType.FLOAT)
+        .addMultiValueDimension(INT_MONO_INCREASING_MV_1, 
FieldSpec.DataType.INT)
+        .addMultiValueDimension(INT_MONO_INCREASING_MV_2, 
FieldSpec.DataType.INT)
+        .addMultiValueDimension(LONG_MV_COLUMN_2, FieldSpec.DataType.LONG)
+        .addMultiValueDimension(FLOAT_MV_COLUMN_2, FieldSpec.DataType.FLOAT)
+        .addMultiValueDimension(DOUBLE_MV_COLUMN_2, FieldSpec.DataType.DOUBLE)
         .addDateTime(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP, 
"1:MILLISECONDS:EPOCH", "1:MILLISECONDS")
         .addDateTime(TIMESTAMP_COLUMN_NULL, FieldSpec.DataType.TIMESTAMP, 
"1:MILLISECONDS:EPOCH", "1:MILLISECONDS")
         .addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG, 
TimeUnit.MILLISECONDS, TIME_COLUMN), null).build();
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
index c16f0e9c23..5befeccc07 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
@@ -953,6 +953,122 @@ public class ScalarTransformFunctionWrapperTest extends 
BaseTransformFunctionTes
     testTransformFunction(transformFunction, expectedValues);
   }
 
+  @Test
+  public void testArrayIndexOfAllInt() {
+    // Rows of INT_MONO_INCREASING_MV_1 are filled with [0, 1, 2, ...], so 0 is only ever found at position 0.
+    String sql = String.format("array_index_of_all_int(%s, 0)", INT_MONO_INCREASING_MV_1);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    int row = 0;
+    while (row < NUM_ROWS) {
+      expected[row] = new int[]{0};
+      row++;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  @Test
+  public void testArrayIndexOfAllLong() {
+    // Every element of LONG_MV_COLUMN_2 is 1L, so searching for 1 must return every position of each row.
+    String sql = String.format("array_index_of_all_long(%s, 1)", LONG_MV_COLUMN_2);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_longMV2Values[row].length];
+      for (int pos = allPositions.length - 1; pos >= 0; pos--) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  @Test
+  public void testArrayIndexOfAllFloat() {
+    // Every element of FLOAT_MV_COLUMN_2 is 1.0f, so searching for 1.0 must return every position of each row.
+    String sql = String.format("array_index_of_all_float(%s, 1.0)", FLOAT_MV_COLUMN_2);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_floatMV2Values[row].length];
+      for (int pos = allPositions.length - 1; pos >= 0; pos--) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  @Test
+  public void testArrayIndexOfAllDouble() {
+    // Every element of DOUBLE_MV_COLUMN_2 is 1.0, so searching for 1.0 must return every position of each row.
+    String sql = String.format("array_index_of_all_double(%s, 1.0)", DOUBLE_MV_COLUMN_2);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_doubleMV2Values[row].length];
+      for (int pos = allPositions.length - 1; pos >= 0; pos--) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  @Test
+  public void testArrayIndexOfAllString() {
+    // Every element of STRING_ALPHANUM_MV_COLUMN_2 is "a", so searching for 'a' must return every position.
+    String sql = String.format("array_index_of_all_string(%s, 'a')", STRING_ALPHANUM_MV_COLUMN_2);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_stringAlphaNumericMV2Values[row].length];
+      for (int pos = allPositions.length - 1; pos >= 0; pos--) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  @Test
+  public void testIntersectIndices() {
+    String sql = String.format("intersect_indices(%s, %s)", INT_MONO_INCREASING_MV_1, INT_MONO_INCREASING_MV_2);
+    ExpressionContext expr = RequestContextUtils.getExpression(sql);
+    TransformFunction function = TransformFunctionFactory.get(expr, _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      // Per row, MV_1 holds [0, ..., n - 1] and MV_2 holds [1, ..., n]; they share exactly 1 .. n - 1.
+      int[] shared = new int[_intMonoIncreasingMV1Values[row].length - 1];
+      for (int k = 0; k < shared.length; k++) {
+        shared[k] = k + 1;
+      }
+      expected[row] = shared;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
   @Test
   public void testBase64TransformFunction() {
     ExpressionContext expression = 
RequestContextUtils.getExpression(String.format("toBase64(%s)", 
BYTES_SV_COLUMN));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to