This is an automated email from the ASF dual-hosted git repository.

suneet pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new c68388ebcd Vectorized version of string last aggregator (#12493)
c68388ebcd is described below

commit c68388ebcd6cc2a77d2f6c41320906b32a5f6028
Author: somu-imply <93540295+somu-im...@users.noreply.github.com>
AuthorDate: Mon May 9 17:02:38 2022 -0700

    Vectorized version of string last aggregator (#12493)
    
    * Vectorized version of string last aggregator
    
    * Updating string last and adding testcases
    
    * Updating code and adding testcases for serializable pairs
    
    * Addressing review comments
---
 .../aggregation/first/StringFirstLastUtils.java    |  43 +++++
 .../last/StringLastAggregatorFactory.java          |  28 +++
 .../last/StringLastVectorAggregator.java           | 190 +++++++++++++++++++++
 .../last/StringLastVectorAggregatorTest.java       | 167 ++++++++++++++++++
 .../apache/druid/sql/calcite/CalciteQueryTest.java |  39 ++++-
 5 files changed, 462 insertions(+), 5 deletions(-)

diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
 
b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
index 323ce413e6..6b93be7d70 100644
--- 
a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
+++ 
b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
@@ -27,6 +27,8 @@ import org.apache.druid.segment.DimensionHandlerUtils;
 import org.apache.druid.segment.NilColumnValueSelector;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
 
 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
@@ -59,6 +61,47 @@ public class StringFirstLastUtils
            || SerializablePairLongString.class.isAssignableFrom(clazz);
   }
 
+  /**
+   * Tells whether the given object *might* be a SerializablePairLongString: true when its runtime class is either
+   * a supertype or a subtype of SerializablePairLongString. A null object never needs the check.
+   */
+  public static boolean objectNeedsFoldCheck(Object obj)
+  {
+    if (obj != null) {
+      final Class<?> runtimeClass = obj.getClass();
+      final Class<?> pairClass = SerializablePairLongString.class;
+      return runtimeClass.isAssignableFrom(pairClass) || pairClass.isAssignableFrom(runtimeClass);
+    }
+    return false;
+  }
+
+  /**
+   * Reads the (time, string) pair found at {@code index} in the given vector selectors.
+   *
+   * If the object at {@code index} is a SerializablePairLongString, its own timestamp and string are used.
+   * Any other non-null object is converted to a string and paired with the time vector entry at {@code index}.
+   *
+   * Bounds checking of {@code index} against both vectors is the responsibility of the caller.
+   *
+   * @return a newly created pair, or null when the object at {@code index} is null (nulls are not aggregated)
+   */
+  @Nullable
+  public static SerializablePairLongString readPairFromVectorSelectorsAtIndex(
+      BaseLongVectorValueSelector timeSelector,
+      VectorObjectSelector valueSelector,
+      int index
+  )
+  {
+    final Object object = valueSelector.getObjectVector()[index];
+    if (object == null) {
+      // Don't aggregate nulls.
+      return null;
+    }
+
+    if (object instanceof SerializablePairLongString) {
+      // Already a pair: it carries its own timestamp, so the time vector is not consulted.
+      final SerializablePairLongString pair = (SerializablePairLongString) object;
+      final long time = pair.lhs;
+      return new SerializablePairLongString(time, pair.rhs);
+    }
+
+    final long time = timeSelector.getLongVector()[index];
+    return new SerializablePairLongString(time, DimensionHandlerUtils.convertObjectToString(object));
+  }
+
   @Nullable
   public static SerializablePairLongString readPairFromSelectors(
       final BaseLongColumnValueSelector timeSelector,
diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
 
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
index 71bf66e608..39c5b29647 100644
--- 
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
+++ 
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
@@ -31,14 +31,20 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.AggregatorUtil;
 import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
 import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
+import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
 
 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
@@ -141,6 +147,28 @@ public class StringLastAggregatorFactory extends 
AggregatorFactory
     }
   }
 
+  // Always reports that this factory can be vectorized; the supplied inspector is not consulted.
+  @Override
+  public boolean canVectorize(ColumnInspector columnInspector)
+  {
+    return true;
+  }
+
+  /**
+   * Builds the vectorized aggregator for this factory from selectors on {@code fieldName} and {@code timeColumn}.
+   * When the selector factory reports no capabilities for {@code fieldName}, the aggregator is given a null time
+   * selector instead of the real one.
+   */
+  @Override
+  public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+  {
+    final ColumnCapabilities fieldCapabilities = selectorFactory.getColumnCapabilities(fieldName);
+    final VectorObjectSelector valueSelector = selectorFactory.makeObjectSelector(fieldName);
+    final BaseLongVectorValueSelector timeSelector =
+        (BaseLongVectorValueSelector) selectorFactory.makeValueSelector(timeColumn);
+
+    return new StringLastVectorAggregator(
+        fieldCapabilities == null ? null : timeSelector,
+        valueSelector,
+        maxStringBytes
+    );
+  }
+
   @Override
   public Comparator getComparator()
   {
diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
 
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
new file mode 100644
index 0000000000..045360ba61
--- /dev/null
+++ 
b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
+import org.apache.druid.segment.DimensionHandlerUtils;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized aggregator that keeps the string value with the latest timestamp.
+ *
+ * The running result is stored in the aggregation buffer with StringFirstLastUtils.writePair and read back with
+ * StringFirstLastUtils.readPair; the timestamp of the stored pair is the long found at the buffer position.
+ * Input objects are either plain values, whose timestamp is taken from the time selector, or
+ * SerializablePairLongString objects, which carry their own timestamp and are folded into the stored result.
+ */
+public class StringLastVectorAggregator implements VectorAggregator
+{
+  // Initial buffer contents: the smallest representable time with no value.
+  private static final SerializablePairLongString INIT = new SerializablePairLongString(
+      DateTimes.MIN.getMillis(),
+      null
+  );
+  // May be null; in that case both aggregate methods are no-ops.
+  @Nullable
+  private final BaseLongVectorValueSelector timeSelector;
+  private final VectorObjectSelector valueSelector;
+  private final int maxStringBytes;
+  // Latest timestamp known for the position handled by the most recent range aggregate call.
+  protected long lastTime;
+
+  /**
+   * @param timeSelector   selector for row timestamps; may be null, which disables aggregation
+   * @param valueSelector  selector for the values to aggregate
+   * @param maxStringBytes passed through to StringFirstLastUtils.writePair on every write
+   */
+  public StringLastVectorAggregator(
+      @Nullable final BaseLongVectorValueSelector timeSelector,
+      final VectorObjectSelector valueSelector,
+      final int maxStringBytes
+  )
+  {
+    this.timeSelector = timeSelector;
+    this.valueSelector = valueSelector;
+    this.maxStringBytes = maxStringBytes;
+  }
+
+  /**
+   * Writes the initial pair (minimum time, null value) at {@code position}.
+   */
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    StringFirstLastUtils.writePair(buf, position, INIT, maxStringBytes);
+  }
+
+  /**
+   * Aggregates rows {@code [startRow, endRow)} of the current vectors into the single buffer {@code position}.
+   * Rows are visited from the end of the range towards the start, and null values are skipped.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    if (timeSelector == null) {
+      return;
+    }
+    final long[] times = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+
+    lastTime = buf.getLong(position);
+    for (int i = endRow - 1; i >= startRow; i--) {
+      final Object object = objectsWhichMightBeStrings[i];
+      if (object == null) {
+        continue;
+      }
+      if (StringFirstLastUtils.objectNeedsFoldCheck(object)) {
+        // Less efficient code path when folding is a possibility: a foldable object carries its own timestamp,
+        // so the row's entry in the time vector must not be used to skip it or to stop the scan.
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            i
+        );
+        if (inPair != null && inPair.lhs >= lastTime) {
+          lastTime = inPair.lhs;
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else {
+        // Plain value: stop at the first row that is older than what is already stored.
+        // NOTE(review): this early exit relies on the time vector being in ascending order - confirm.
+        if (times[i] < lastTime) {
+          break;
+        }
+        lastTime = times[i];
+        StringFirstLastUtils.writePair(
+            buf,
+            position,
+            new SerializablePairLongString(lastTime, DimensionHandlerUtils.convertObjectToString(object)),
+            maxStringBytes
+        );
+      }
+    }
+  }
+
+  /**
+   * Aggregates {@code numRows} rows, each into its own buffer position {@code positions[i] + positionOffset}.
+   * Row {@code i} of the batch maps to vector index {@code rows[i]}, or to {@code i} itself when {@code rows}
+   * is null.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    if (timeSelector == null) {
+      // Same contract as the range overload: nothing to aggregate without a time selector.
+      return;
+    }
+    final long[] timeVector = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+    final boolean foldNeeded = firstNonNullObjectNeedsFoldCheck(objectsWhichMightBeStrings);
+
+    for (int i = 0; i < numRows; i++) {
+      final int position = positions[i] + positionOffset;
+      final int row = rows == null ? i : rows[i];
+      final long storedTime = buf.getLong(position);
+      if (foldNeeded) {
+        // Compare on the pair's own timestamp; the time vector entry is only used by the helper when the
+        // object turns out not to be a pair.
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            row
+        );
+        if (inPair != null && inPair.lhs >= storedTime) {
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else if (timeVector[row] >= storedTime) {
+        // NOTE(review): unlike the range overload, a null value is not skipped here and is written together
+        // with its timestamp - confirm this is intended.
+        final String value = DimensionHandlerUtils.convertObjectToString(objectsWhichMightBeStrings[row]);
+        StringFirstLastUtils.writePair(
+            buf,
+            position,
+            new SerializablePairLongString(timeVector[row], value),
+            maxStringBytes
+        );
+      }
+    }
+  }
+
+  /**
+   * Returns the pair currently stored at {@code position}, as read by StringFirstLastUtils.readPair.
+   */
+  @Nullable
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return StringFirstLastUtils.readPair(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+
+  /**
+   * Decides from the first non-null element whether the vector may hold foldable pair objects.
+   */
+  private static boolean firstNonNullObjectNeedsFoldCheck(Object[] objects)
+  {
+    for (Object obj : objects) {
+      if (obj != null) {
+        return StringFirstLastUtils.objectNeedsFoldCheck(obj);
+      }
+    }
+    return false;
+  }
+}
+
diff --git 
a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
 
b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
new file mode 100644
index 0000000000..428ff3e374
--- /dev/null
+++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+
+// Unit tests for StringLastVectorAggregator and the vectorized path of StringLastAggregatorFactory, driven by
+// Mockito-mocked vector selectors.
+@RunWith(MockitoJUnitRunner.class)
+public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
+{
+  private static final double EPSILON = 1e-5;
+  // Values returned by the mocked object selector; index 2 is null.
+  private static final String[] VALUES = new String[]{"a", "b", null, "c"};
+  private static final boolean[] NULLS = new boolean[]{false, false, true, 
false};
+  private static final String NAME = "NAME";
+  private static final String FIELD_NAME = "FIELD_NAME";
+  private static final String TIME_COL = "__time";
+  // Timestamps returned by the mocked time selector, one per entry of VALUES.
+  private long[] times = {2436, 6879, 7888, 8224};
+  // Identical row timestamps used together with the pair inputs below.
+  private long[] timesSame = {2436, 2436};
+  // Pair inputs: the first pair has the greater timestamp (lhs).
+  private SerializablePairLongString[] pairs = {
+      new SerializablePairLongString(2345100L, "last"),
+      new SerializablePairLongString(2345001L, "notLast")
+  };
+
+  @Mock
+  private VectorObjectSelector selector;
+  @Mock
+  private VectorObjectSelector selectorForPairs;
+  @Mock
+  private BaseLongVectorValueSelector timeSelector;
+  @Mock
+  private BaseLongVectorValueSelector timeSelectorForPairs;
+  private ByteBuffer buf;
+  private StringLastVectorAggregator target;
+  private StringLastVectorAggregator targetWithPairs;
+
+  private StringLastAggregatorFactory stringLastAggregatorFactory;
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private VectorColumnSelectorFactory selectorFactory;
+
+  @Before
+  public void setup()
+  {
+    // The buffer starts out filled with random bytes; position 0 is then initialized through the aggregator.
+    byte[] randomBytes = new byte[1024];
+    ThreadLocalRandom.current().nextBytes(randomBytes);
+    buf = ByteBuffer.wrap(randomBytes);
+    Mockito.doReturn(VALUES).when(selector).getObjectVector();
+    Mockito.doReturn(times).when(timeSelector).getLongVector();
+    Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
+    Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
+    target = new StringLastVectorAggregator(timeSelector, selector, 10);
+    targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, 
selectorForPairs, 10);
+    clearBufferForPositions(0, 0);
+
+
+    
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
+    
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
+    stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, 
FIELD_NAME, TIME_COL, 10);
+
+  }
+
+  @Test
+  public void testAggregateWithPairs()
+  {
+    targetWithPairs.aggregate(buf, 0, 0, pairs.length);
+    Pair<Long, String> result = (Pair<Long, String>) targetWithPairs.get(buf, 
0);
+    // pairs[0] is expected as the result because its timestamp (lhs) is the greater of the two.
+    Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue());
+    Assert.assertEquals(pairs[0].rhs, result.rhs);
+  }
+
+  @Test
+  public void testFactory()
+  {
+    
Assert.assertTrue(stringLastAggregatorFactory.canVectorize(selectorFactory));
+    VectorAggregator vectorAggregator = 
stringLastAggregatorFactory.factorizeVector(selectorFactory);
+    Assert.assertNotNull(vectorAggregator);
+    Assert.assertEquals(StringLastVectorAggregator.class, 
vectorAggregator.getClass());
+  }
+
+  @Test
+  public void initValueShouldBeMinDate()
+  {
+    target.init(buf, 0);
+    long initVal = buf.getLong(0);
+    Assert.assertEquals(DateTimes.MIN.getMillis(), initVal);
+  }
+
+  @Test
+  public void aggregate()
+  {
+    // Aggregating the whole vector into position 0 is expected to yield the entry at index 3.
+    target.aggregate(buf, 0, 0, VALUES.length);
+    Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 0);
+    Assert.assertEquals(times[3], result.lhs.longValue());
+    Assert.assertEquals(VALUES[3], result.rhs);
+  }
+
+  @Test
+  public void aggregateBatchWithoutRows()
+  {
+    // With a null rows array, batch entry i is expected to hold vector entry i; this includes the null value
+    // at index 2.
+    int[] positions = new int[]{0, 43, 70};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    target.aggregate(buf, 3, positions, null, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 
positions[i] + positionOffset);
+      Assert.assertEquals(times[i], result.lhs.longValue());
+      Assert.assertEquals(VALUES[i], result.rhs);
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRows()
+  {
+    // With an explicit rows array, batch entry i is expected to hold vector entry rows[i].
+    int[] positions = new int[]{0, 43, 70};
+    int[] rows = new int[]{3, 2, 0};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    target.aggregate(buf, 3, positions, rows, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 
positions[i] + positionOffset);
+      Assert.assertEquals(times[rows[i]], result.lhs.longValue());
+      Assert.assertEquals(VALUES[rows[i]], result.rhs);
+    }
+  }
+
+  // Calls target.init on the buffer at offset + position for every given position.
+  private void clearBufferForPositions(int offset, int... positions)
+  {
+    for (int position : positions) {
+      target.init(buf, offset + position);
+    }
+  }
+}
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 9feb822679..1af59214f8 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -683,8 +683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testLatestAggregators() throws Exception
   {
-    // Cannot vectorize until StringLast is vectorized
-    skipVectorize();
+
     testQuery(
         "SELECT "
         + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
@@ -944,6 +943,39 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
     );
   }
 
+  // Checks that LATEST(dim4, 10) grouped by dim2 plans to a GroupByQuery with a StringLastAggregatorFactory and
+  // returns the expected rows in both null-handling modes. This test does not call skipVectorize().
+  @Test
+  public void testStringLatestGroupBy() throws Exception
+  {
+    testQuery(
+        "SELECT dim2, LATEST(dim4,10) AS val1 FROM druid.numfoo GROUP BY dim2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE3)
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(dimensions(new 
DefaultDimensionSpec("dim2", "_d0")))
+                        .setAggregatorSpecs(aggregators(
+                                                new 
StringLastAggregatorFactory("a0", "dim4", null, 10)
+                                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        // Expected rows: SQL-compatible null handling yields separate groups for null and "".
+        NullHandling.sqlCompatible()
+        ? ImmutableList.of(
+            new Object[]{null, "b"},
+            new Object[]{"", "a"},
+            new Object[]{"a", "b"},
+            new Object[]{"abc", "b"}
+        )
+        : ImmutableList.of(
+            new Object[]{"", "b"},
+            new Object[]{"a", "b"},
+            new Object[]{"abc", "b"}
+        )
+    );
+  }
+
   // This test the off-heap (buffer) version of the EarliestAggregator 
(Double/Float/Long)
   @Test
   public void testPrimitiveEarliestInSubquery() throws Exception
@@ -999,9 +1031,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testStringLatestInSubquery() throws Exception
   {
-    // Cannot vectorize LATEST aggregator for Strings
-    skipVectorize();
-
     testQuery(
         "SELECT SUM(val) FROM (SELECT dim2, LATEST(dim1, 10) AS val FROM foo 
GROUP BY dim2)",
         ImmutableList.of(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@druid.apache.org
For additional commands, e-mail: commits-h...@druid.apache.org

Reply via email to