This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 63931bc  [SYSTEMDS-3285] Multithreaded compaction for transformencode
63931bc is described below

commit 63931bc8592eaa6d6eccca05d2fc37df29ca0ce5
Author: arnabp <[email protected]>
AuthorDate: Wed Feb 2 21:12:11 2022 +0100

    [SYSTEMDS-3285] Multithreaded compaction for transformencode
    
    This patch replaces the HashSet with a list for tracking the sparse
    row indexes during apply, and adds multithreaded compaction
    logic. This change removes the post-processing bottleneck for
    PassThrough and DummyCoding, which led to a 3x improvement
    for the Criteo dataset (10M rows).
---
 .../runtime/transform/encode/ColumnEncoder.java    | 17 ++++++-----
 .../transform/encode/ColumnEncoderComposite.java   |  5 ++--
 .../transform/encode/ColumnEncoderDummycode.java   | 13 ++++----
 .../transform/encode/ColumnEncoderPassThrough.java | 13 ++++----
 .../transform/encode/MultiColumnEncoder.java       | 35 +++++++++++++++-------
 5 files changed, 46 insertions(+), 37 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index f98d323..d5aee51 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -29,9 +29,7 @@ import java.io.ObjectOutput;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
 import java.util.concurrent.Callable;
 
 import org.apache.commons.logging.Log;
@@ -56,10 +54,10 @@ import org.apache.sysds.utils.stats.TransformStatistics;
 public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder> {
        protected static final Log LOG = 
LogFactory.getLog(ColumnEncoder.class.getName());
        protected static final int APPLY_ROW_BLOCKS_PER_COLUMN = 1;
-       public static int BUILD_ROW_BLOCKS_PER_COLUMN = 1;
+       public static int BUILD_ROW_BLOCKS_PER_COLUMN = -1;
        private static final long serialVersionUID = 2299156350718979064L;
        protected int _colID;
-       protected Set<Integer> _sparseRowsWZeros = null;
+       protected ArrayList<Integer> _sparseRowsWZeros = null;
 
        protected enum TransformType{
                BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, N_A
@@ -354,14 +352,14 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
                return new ColumnApplyTask<>(this, in, out, outputCol, 
startRow, blk);
        }
 
-       public Set<Integer> getSparseRowsWZeros(){
+       public List<Integer> getSparseRowsWZeros(){
                return _sparseRowsWZeros;
        }
 
-       protected void addSparseRowsWZeros(Set<Integer> sparseRowsWZeros){
+       protected void addSparseRowsWZeros(ArrayList<Integer> sparseRowsWZeros){
                synchronized (this){
                        if(_sparseRowsWZeros == null)
-                               _sparseRowsWZeros = new HashSet<>();
+                               _sparseRowsWZeros = new ArrayList<>();
                        _sparseRowsWZeros.addAll(sparseRowsWZeros);
                }
        }
@@ -371,7 +369,10 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
        }
 
        protected int getNumBuildRowPartitions(){
-               return ConfigurationManager.getParallelBuildBlocks();
+               if (BUILD_ROW_BLOCKS_PER_COLUMN == -1)
+                       return ConfigurationManager.getParallelBuildBlocks();
+               else
+                       return BUILD_ROW_BLOCKS_PER_COLUMN;
        }
 
        public enum EncoderType {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index b7c162b..4f72f9c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -28,7 +28,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.stream.Collectors;
 
@@ -358,12 +357,12 @@ public class ColumnEncoderComposite extends ColumnEncoder 
{
        }
 
        @Override
-       public Set<Integer> getSparseRowsWZeros(){
+       public List<Integer> getSparseRowsWZeros(){
                return 
_columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> {
                                        if(l == null)
                                                return null;
                                        return l.stream();
-                               }).collect(Collectors.toSet());
+                               }).collect(Collectors.toList());
        }
 
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 01e7d5c..e0efe53 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -24,10 +24,7 @@ import static 
org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 import java.io.IOException;
 import java.io.ObjectInput;
 import java.io.ObjectOutput;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Objects;
-import java.util.Set;
+import java.util.*;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -89,7 +86,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                }
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
                mcsr = false; //force CSR for transformencode
-               Set<Integer> sparseRowsWZeros = null;
+               ArrayList<Integer> sparseRowsWZeros = null;
                int index = _colID - 1;
                for(int r = rowStart; r < getEndIndex(in.getNumRows(), 
rowStart, blk); r++) {
                        // Since the recoded values are already offset in the 
output matrix (same as input at this point)
@@ -111,7 +108,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                                double val = 
out.getSparseBlock().get(r).values()[index];
                                if(Double.isNaN(val)){
                                        if(sparseRowsWZeros == null)
-                                               sparseRowsWZeros = new 
HashSet<>();
+                                               sparseRowsWZeros = new 
ArrayList<>();
                                        sparseRowsWZeros.add(r);
                                        
out.getSparseBlock().get(r).values()[index] = 0;
                                        continue;
@@ -126,7 +123,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                                double val = csrblock.values()[rptr[r]+index];
                                if(Double.isNaN(val)){
                                        if(sparseRowsWZeros == null)
-                                               sparseRowsWZeros = new 
HashSet<>();
+                                               sparseRowsWZeros = new 
ArrayList<>();
                                        sparseRowsWZeros.add(r);
                                        csrblock.values()[rptr[r]+index] = 0; 
//test
                                        continue;
@@ -137,7 +134,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                                csrblock.values()[rptr[r]+index] = 1;
                        }
                }
-               if(sparseRowsWZeros != null){
+               if(sparseRowsWZeros != null) {
                        addSparseRowsWZeros(sparseRowsWZeros);
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index 5ceba09..2f5739f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.transform.encode;
 
 import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -70,7 +71,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
        @Override
        protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) 
{
                int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
-               double codes[] = new double[endInd-startInd];
+               double[] codes = new double[endInd-startInd];
                for (int i=startInd; i<endInd; i++) {
                        codes[i-startInd] = in.getDoubleNaN(i, _colID-1);
                }
@@ -78,7 +79,8 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
        }
 
        protected void applySparse(CacheBlock in, MatrixBlock out, int 
outputCol, int rowStart, int blk){
-               Set<Integer> sparseRowsWZeros = null;
+               //Set<Integer> sparseRowsWZeros = null;
+               ArrayList<Integer> sparseRowsWZeros = null;
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
                mcsr = false; //force CSR for transformencode
                int index = _colID - 1;
@@ -92,7 +94,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
                                double v = codes[ii-rowStart];
                                if(v == 0) {
                                        if(sparseRowsWZeros == null)
-                                               sparseRowsWZeros = new 
HashSet<>();
+                                               sparseRowsWZeros = new 
ArrayList<>();
                                        sparseRowsWZeros.add(ii);
                                }
                                if (mcsr) {
@@ -101,11 +103,6 @@ public class ColumnEncoderPassThrough extends 
ColumnEncoder {
                                        row.indexes()[index] = outputCol;
                                }
                                else { //csr
-                                       if(v == 0) {
-                                               if(sparseRowsWZeros == null)
-                                                       sparseRowsWZeros = new 
HashSet<>();
-                                               sparseRowsWZeros.add(ii);
-                                       }
                                        // Manually fill the column-indexes and 
values array
                                        SparseBlockCSR csrblock = 
(SparseBlockCSR)out.getSparseBlock();
                                        int rptr[] = csrblock.rowPointers();
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 053c452..2e71680 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -30,10 +30,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
+import java.util.concurrent.*;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -428,18 +425,36 @@ public class MultiColumnEncoder implements Encoder {
 
        private void outputMatrixPostProcessing(MatrixBlock output){
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-               Set<Integer> indexSet = _columnEncoders.stream()
+               int k = OptimizerUtils.getTransformNumThreads();
+               ForkJoinPool myPool = new ForkJoinPool(k);
+               List<Integer> indexSet = _columnEncoders.stream().parallel()
                                
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> {
                                        if(l == null)
                                                return null;
                                        return l.stream();
-                               }).collect(Collectors.toSet());
-               if(!indexSet.stream().allMatch(Objects::isNull)){
-                       for(Integer row : indexSet){
-                               // TODO: Maybe MT in special cases when the 
number of rows is large
-                               output.getSparseBlock().get(row).compact();
+                               }).collect(Collectors.toList());
+
+               if (k == 1) {
+                       
if(!indexSet.stream().parallel().allMatch(Objects::isNull)) {
+                               for(Integer row : indexSet)
+                                       
output.getSparseBlock().get(row).compact();
+                       }
+               }
+               else {
+                       try {
+                               
if(!indexSet.stream().allMatch(Objects::isNull)) {
+                                       myPool.submit(() -> {
+                                               
indexSet.stream().parallel().forEach(row -> {
+                                                       
output.getSparseBlock().get(row).compact();
+                                               });
+                                       }).get();
+                               }
+                       }
+                       catch(Exception ex) {
+                               throw new DMLRuntimeException(ex);
                        }
                }
+               myPool.shutdown();
                output.recomputeNonZeros();
                if(DMLScript.STATISTICS)
                        
TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0);

Reply via email to