This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 89f8925a2eb9027be4f0691895e4889e72f3903a
Author: Pindikura Ravindra <[email protected]>
AuthorDate: Thu Sep 20 15:55:49 2018 +0530

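    [Gandiva] add a Java perf test for Filter
    
    - removed some DCHECK statements that were causing a perf issue
    - removed the diagnostic info-level logging from Filter and Projector
      evaluate()/close()
    - modified the other micro benchmarks to evaluate 1M records and to
      assert an upper bound on the evaluation time.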
---
 cpp/src/gandiva/selection_vector.cc                |   4 +-
 cpp/src/gandiva/selection_vector_impl.h            |  12 +-
 .../org/apache/arrow/gandiva/evaluator/Filter.java |   4 -
 .../apache/arrow/gandiva/evaluator/Projector.java  |   3 -
 .../arrow/gandiva/evaluator/BaseEvaluatorTest.java | 148 ++++++++++++++++-----
 .../gandiva/evaluator/MicroBenchmarkTest.java      |  42 +++++-
 6 files changed, 157 insertions(+), 56 deletions(-)

diff --git a/cpp/src/gandiva/selection_vector.cc b/cpp/src/gandiva/selection_vector.cc
index 83c1d0d..a36cba9 100644
--- a/cpp/src/gandiva/selection_vector.cc
+++ b/cpp/src/gandiva/selection_vector.cc
@@ -40,6 +40,8 @@ Status SelectionVector::PopulateFromBitMap(const uint8_t* bitmap, int bitmap_siz
     return Status::Invalid(ss.str());
   }
 
+  int max_slots = GetMaxSlots();
+
   // jump  8-bytes at a time, add the index corresponding to each valid bit to the
   // the selection vector.
   int selection_idx = 0;
@@ -57,7 +59,7 @@ Status SelectionVector::PopulateFromBitMap(const uint8_t* bitmap, int bitmap_siz
         break;
       }
 
-      if (selection_idx >= GetMaxSlots()) {
+      if (selection_idx >= max_slots) {
         return Status::Invalid("selection vector has no remaining slots");
       }
       SetIndex(selection_idx, pos_in_bitmap);
diff --git a/cpp/src/gandiva/selection_vector_impl.h b/cpp/src/gandiva/selection_vector_impl.h
index dcdd222..5e5d271 100644
--- a/cpp/src/gandiva/selection_vector_impl.h
+++ b/cpp/src/gandiva/selection_vector_impl.h
@@ -40,17 +40,9 @@ class SelectionVectorImpl : public SelectionVector {
     raw_data_ = reinterpret_cast<C_TYPE*>(buffer->mutable_data());
   }
 
-  uint GetIndex(int index) const override {
-    DCHECK_LE(index, max_slots_);
-    return raw_data_[index];
-  }
-
-  void SetIndex(int index, uint value) override {
-    DCHECK_LE(index, max_slots_);
-    DCHECK_LE(value, GetMaxSupportedValue());
+  uint GetIndex(int index) const override { return raw_data_[index]; }
 
-    raw_data_[index] = static_cast<C_TYPE>(value);
-  }
+  void SetIndex(int index, uint value) override { raw_data_[index] = value; }
 
   ArrayPtr ToArray() const override;
 
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
index 51bf198..de4a24e 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
@@ -103,8 +103,6 @@ public class Filter {
     if (this.closed) {
       throw new EvaluatorClosedException();
     }
-    //TODO: remove later, only for diagnostic.
-    logger.info("Evaluate called for module with id {}", moduleId);
     int numRows = recordBatch.getLength();
     if (selectionVector.getMaxRecords() < numRows) {
       logger.error("selectionVector has capacity for " + numRows
@@ -141,8 +139,6 @@ public class Filter {
    * Closes the LLVM module representing this filter.
    */
   public void close() throws GandivaException {
-    //TODO: remove later, only for diagnostic.
-    logger.info("Close called for module with id {}", moduleId);
     if (this.closed) {
       return;
     }
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
index a0b3b02..7213b67 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
@@ -110,8 +110,6 @@ public class Projector {
    */
   public void evaluate(ArrowRecordBatch recordBatch, List<ValueVector> outColumns)
           throws GandivaException {
-    //TODO: remove later, only for diagnostic.
-    logger.info("Evaluate called for module with id {}", moduleId);
     if (this.closed) {
       throw new EvaluatorClosedException();
     }
@@ -163,7 +161,6 @@ public class Projector {
    * Closes the LLVM module representing this evaluator.
    */
   public void close() throws GandivaException {
-    logger.info("Close called for module with id {}", moduleId);
     if (this.closed) {
       return;
     }
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
index 33d782a..33f7649 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
@@ -20,6 +20,7 @@ package org.apache.arrow.gandiva.evaluator;
 
 import io.netty.buffer.ArrowBuf;
 import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.Condition;
 import org.apache.arrow.gandiva.expression.ExpressionTree;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
@@ -37,6 +38,90 @@ import org.junit.Before;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.TimeUnit;
+
+interface BaseEvaluator {
+  void evaluate(ArrowRecordBatch recordBatch, BufferAllocator allocator) throws GandivaException;
+
+  long getElapsedMillis();
+}
+
+class ProjectEvaluator implements BaseEvaluator {
+  private Projector projector;
+  private DataAndVectorGenerator generator;
+  private int numExprs;
+  private int maxRowsInBatch;
+  private long elapsedTime = 0;
+  private List<ValueVector> outputVectors = new ArrayList<>();
+
+  public ProjectEvaluator(Projector projector,
+                          DataAndVectorGenerator generator,
+                          int numExprs,
+                          int maxRowsInBatch) {
+    this.projector = projector;
+    this.generator = generator;
+    this.numExprs = numExprs;
+    this.maxRowsInBatch = maxRowsInBatch;
+  }
+
+  @Override
+  public void evaluate(ArrowRecordBatch recordBatch,
+                       BufferAllocator allocator) throws GandivaException {
+    // Allocate one output vector per expression inside the try block so that
+    // vectors created before a failure are closed and the list is always reset.
+    try {
+      for (int i = 0; i < numExprs; i++) {
+        ValueVector valueVector = generator.generateOutputVector(maxRowsInBatch);
+        outputVectors.add(valueVector);
+      }
+
+      long start = System.nanoTime();
+      projector.evaluate(recordBatch, outputVectors);
+      long finish = System.nanoTime();
+      elapsedTime += (finish - start);
+    } finally {
+      for (ValueVector valueVector : outputVectors) {
+        valueVector.close();
+      }
+      outputVectors.clear();
+    }
+  }
+
+  @Override
+  public long getElapsedMillis() {
+    return TimeUnit.NANOSECONDS.toMillis(elapsedTime);
+  }
+}
+
+class FilterEvaluator implements BaseEvaluator {
+  private Filter filter;
+  private long elapsedTime = 0;
+
+  public FilterEvaluator(Filter filter) {
+    this.filter = filter;
+  }
+
+  @Override
+  public void evaluate(ArrowRecordBatch recordBatch,
+                       BufferAllocator allocator) throws GandivaException {
+    ArrowBuf selectionBuffer = allocator.buffer(recordBatch.getLength() * 2);
+    SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+    try {
+      long start = System.nanoTime();
+      filter.evaluate(recordBatch, selectionVector);
+      long finish = System.nanoTime();
+      elapsedTime += (finish - start);
+    } finally {
+      selectionBuffer.close();
+    }
+  }
+
+  @Override
+  public long getElapsedMillis() {
+    return TimeUnit.NANOSECONDS.toMillis(elapsedTime);
+  }
+}
 
 interface DataAndVectorGenerator {
   public void writeData(ArrowBuf buffer);
@@ -186,20 +271,15 @@ class BaseEvaluatorTest {
     }
   }
 
-  private long generateDataAndEvaluate(DataAndVectorGenerator generator,
-                                       Projector evaluator,
-                                       int numFields, int numExprs,
+  private void generateDataAndEvaluate(DataAndVectorGenerator generator,
+                                       BaseEvaluator evaluator,
+                                       int numFields,
                                        int numRows, int maxRowsInBatch,
                                        int inputFieldSize)
     throws GandivaException, Exception {
     int numRemaining = numRows;
     List<ArrowBuf> inputData = new ArrayList<ArrowBuf>();
     List<ArrowFieldNode> fieldNodes = new ArrayList<ArrowFieldNode>();
-    List<ValueVector> outputVectors = new ArrayList<ValueVector>();
-
-    long start;
-    long finish;
-    long elapsedTime = 0;
 
     // set the bitmap
     while (numRemaining > 0) {
@@ -222,46 +302,50 @@ class BaseEvaluatorTest {
       // create record batch
       ArrowRecordBatch recordBatch = new ArrowRecordBatch(numRowsInBatch, fieldNodes, inputData);
 
-      // set up output vectors
-      // for each expression, generate the output vector
-      for (int i = 0; i < numExprs; i++) {
-        ValueVector valueVector = generator.generateOutputVector(maxRowsInBatch);
-        outputVectors.add(valueVector);
-      }
-
-      start = System.nanoTime();
-      evaluator.evaluate(recordBatch, outputVectors);
-      finish = System.nanoTime();
+      evaluator.evaluate(recordBatch, allocator);
 
-      elapsedTime += (finish - start);
       // fix numRemaining
       numRemaining -= numRowsInBatch;
 
       // release refs
       releaseRecordBatch(recordBatch);
-      releaseValueVectors(outputVectors);
 
       inputData.clear();
       fieldNodes.clear();
-      outputVectors.clear();
     }
-
-    return (elapsedTime / MILLION);
   }
 
-  long timedEvaluate(DataAndVectorGenerator generator,
-                     Schema schema, List<ExpressionTree> exprs,
-                     int numRows, int maxRowsInBatch,
-                     int inputFieldSize)
+  long timedProject(DataAndVectorGenerator generator,
+                    Schema schema, List<ExpressionTree> exprs,
+                    int numRows, int maxRowsInBatch,
+                    int inputFieldSize)
   throws GandivaException, Exception {
-    Projector eval = Projector.make(schema, exprs);
+    Projector projector = Projector.make(schema, exprs);
+    try {
+      ProjectEvaluator evaluator =
+        new ProjectEvaluator(projector, generator, exprs.size(), maxRowsInBatch);
+      generateDataAndEvaluate(generator, evaluator,
+        schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize);
+      return evaluator.getElapsedMillis();
+    } finally {
+      projector.close();
+    }
+  }
+
+  long timedFilter(DataAndVectorGenerator generator,
+                   Schema schema, Condition condition,
+                   int numRows, int maxRowsInBatch,
+                   int inputFieldSize)
+  throws GandivaException, Exception {
 
+    Filter filter = Filter.make(schema, condition);
     try {
-      return generateDataAndEvaluate(generator, eval,
-              schema.getFields().size(), exprs.size(),
-              numRows, maxRowsInBatch, inputFieldSize);
+      FilterEvaluator evaluator = new FilterEvaluator(filter);
+      generateDataAndEvaluate(generator, evaluator,
+        schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize);
+      return evaluator.getElapsedMillis();
     } finally {
-      eval.close();
+      filter.close();
     }
   }
 }
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java
index 4bb88fa..c860963 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java
@@ -20,16 +20,20 @@ package org.apache.arrow.gandiva.evaluator;
 
 import com.google.common.collect.Lists;
 
+import org.apache.arrow.gandiva.expression.Condition;
 import org.apache.arrow.gandiva.expression.ExpressionTree;
 import org.apache.arrow.gandiva.expression.TreeBuilder;
 import org.apache.arrow.gandiva.expression.TreeNode;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.List;
 
 public class MicroBenchmarkTest extends BaseEvaluatorTest {
+  private double toleranceRatio = 4.0;
+
   @Test
   public void testAdd3() throws Exception {
     Field x = Field.nullable("x", int32);
@@ -44,12 +48,13 @@ public class MicroBenchmarkTest extends BaseEvaluatorTest {
     List<Field> cols = Lists.newArrayList(x, N2x, N3x);
     Schema schema = new Schema(cols);
 
-    long timeTaken = timedEvaluate(new Int32DataAndVectorGenerator(allocator),
+    long timeTaken = timedProject(new Int32DataAndVectorGenerator(allocator),
             schema,
             Lists.newArrayList(expr),
-            100 * MILLION, 16 * THOUSAND,
+            1 * MILLION, 16 * THOUSAND,
             4);
-    System.out.println("Time taken for evaluating 100m records of add3 is " + timeTaken + "ms");
+    System.out.println("Time taken for projecting 1m records of add3 is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 10 * toleranceRatio);
   }
 
   @Test
@@ -101,11 +106,36 @@ public class MicroBenchmarkTest extends BaseEvaluatorTest {
     ExpressionTree expr = TreeBuilder.makeExpression(topNode, x);
     Schema schema = new Schema(Lists.newArrayList(x));
 
-    long timeTaken = timedEvaluate(new BoundedInt32DataAndVectorGenerator(allocator, 250),
+    long timeTaken = timedProject(new BoundedInt32DataAndVectorGenerator(allocator, 250),
             schema,
             Lists.newArrayList(expr),
-            100 * MILLION, 16 * THOUSAND,
+            1 * MILLION, 16 * THOUSAND,
             4);
-    System.out.println("Time taken for evaluating 100m records of nestedIf is " + timeTaken + "ms");
+    System.out.println("Time taken for projecting 1m records of nestedIf is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 15 * toleranceRatio);
   }
+
+  @Test
+  public void testFilterAdd2() throws Exception {
+    Field x = Field.nullable("x", int32);
+    Field N2x = Field.nullable("N2x", int32);
+    Field N3x = Field.nullable("N3x", int32);
+
+    // x + N2x < N3x
+    TreeNode add = TreeBuilder.makeFunction("add", Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(N2x)), int32);
+    TreeNode lessThan = TreeBuilder.makeFunction("less_than", Lists.newArrayList(add, TreeBuilder.makeField(N3x)), boolType);
+    Condition condition = TreeBuilder.makeCondition(lessThan);
+
+    List<Field> cols = Lists.newArrayList(x, N2x, N3x);
+    Schema schema = new Schema(cols);
+
+    long timeTaken = timedFilter(new Int32DataAndVectorGenerator(allocator),
+      schema,
+      condition,
+      1 * MILLION, 16 * THOUSAND,
+      4);
+    System.out.println("Time taken for filtering 1m records of a+b<c is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 12 * toleranceRatio);
+  }
+
 }
\ No newline at end of file

Reply via email to