This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 09f0be0  [SYSTEMDS-3283] Multi-threaded ctable instruction
09f0be0 is described below

commit 09f0be03c820d6851033bfc6469df7703cac0faa
Author: arnabp <[email protected]>
AuthorDate: Mon Jan 31 09:26:46 2022 +0100

    [SYSTEMDS-3283] Multi-threaded ctable instruction
    
    This patch implements a multi-threaded version of the
    F = ctable(A, B, W) case; other cases will be supported
    in the future. Each thread builds a separate CTableMap
    from a block of rows, and the partial maps are then
    cascade-merged pairwise.
    This implementation shows an 8x speedup for
    23M rows with 470K unique values.
    
    Closes #1530.
---
 .../sysds/runtime/functionobjects/CTable.java      | 143 +++++++++++++++++++++
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  25 ++--
 2 files changed, 157 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
index fc44ed1..291effa 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
@@ -24,8 +24,17 @@ import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
 public class CTable extends ValueFunction 
 {
        private static final long serialVersionUID = -5374880447194177236L;
@@ -139,4 +148,138 @@ public class CTable extends ValueFunction
                        throw new DMLRuntimeException("Erroneous input while 
computing the contingency table (value <= zero): "+v2);
                return new Pair<>(new MatrixIndexes(row, col), w);
        }
+
+       /**
+        * Multi-threaded CTable for the case F = ctable(A,B,W).
+        * The rows of the inputs are divided into equal-sized blocks, one task per block.
+        * All tasks concurrently build their own CTableMaps; the partial maps are then
+        * cascade-merged pairwise and the final map is aggregated into resultMap.
+        * Only column 0 of in1, in2 and w is read by the tasks.
+        * TODO: Support other cases
+        *
+        * @param in1       first input vector (A)
+        * @param in2       second input vector (B)
+        * @param w         weight vector (W)
+        * @param resultMap map into which the final cell values are aggregated
+        * @param k         degree of parallelism, used for the thread pool and the row-block partitioning
+        * @throws DMLRuntimeException if a task fails or the computation is interrupted
+        */
+       public void execute(MatrixBlock in1, MatrixBlock in2, MatrixBlock w, CTableMap resultMap, int k) {
+               ExecutorService pool = CommonThreadPool.get(k);
+               ArrayList<CTableMap> partialMaps = new ArrayList<>();
+               try {
+                       // Assign an equal-sized block of rows to each task; each task builds
+                       // a separate CTableMap in a lock-free manner
+                       List<Callable<Object>> tasks = new ArrayList<>();
+                       int[] blockSizes = UtilFunctions.getBlockSizes(in1.getNumRows(), k);
+                       for(int startRow = 0, i = 0; i < blockSizes.length; startRow += blockSizes[i], i++)
+                               tasks.add(getPartialCTableTask(in1, in2, w, startRow, blockSizes[i], partialMaps));
+                       waitForAll(pool.invokeAll(tasks));
+
+                       // Cascade-merge the partial maps pairwise until at most one is left
+                       while(partialMaps.size() > 1) {
+                               ArrayList<CTableMap> mergedMaps = new ArrayList<>();
+                               List<Callable<Object>> mergeTasks = new ArrayList<>();
+                               int next = 0;
+                               // Each task merges 2 maps and publishes the merged map
+                               for(; next + 1 < partialMaps.size(); next += 2)
+                                       mergeTasks.add(getMergePartialCTMapsTask(partialMaps.get(next),
+                                               partialMaps.get(next + 1), mergedMaps));
+                               waitForAll(pool.invokeAll(mergeTasks));
+                               // With an odd number of maps, carry the last one over to the next round
+                               if(next < partialMaps.size())
+                                       mergedMaps.add(partialMaps.get(next));
+                               partialMaps = mergedMaps;
+                       }
+               }
+               catch(Exception ex) {
+                       // Preserve the interrupt status for callers higher up the stack
+                       if(ex instanceof InterruptedException)
+                               Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException(ex);
+               }
+               finally {
+                       // Release the pool on both the success and the failure path
+                       pool.shutdown();
+               }
+
+               // No row blocks means no partial maps; avoid get(0) on an empty list
+               if(partialMaps.isEmpty())
+                       return;
+               // Deep copy the last merged map into the result map
+               Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = partialMaps.get(0).getIterator();
+               while(iter.hasNext()) {
+                       LongLongDoubleHashMap.ADoubleEntry e = iter.next();
+                       resultMap.aggregate(e.getKey1(), e.getKey2(), e.value);
+               }
+       }
+
+       /** Blocks until every future has completed, rethrowing the first task failure. */
+       private static void waitForAll(List<Future<Object>> futures) throws Exception {
+               for(Future<Object> f : futures)
+                       f.get();
+       }
+
+       /**
+        * Creates a task that builds a partial CTableMap from the block of blockSize rows
+        * starting at startInd and appends it to pmaps.
+        */
+       public Callable<Object> getPartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+               int startInd, int blockSize, ArrayList<CTableMap> pmaps) {
+               Callable<Object> task = new PartialCTableTask(in1, in2, w, startInd, blockSize, pmaps);
+               return task;
+       }
+
+       /** Creates a task that merges map1 and map2 into a new map and appends it to pmaps. */
+       public Callable<Object> getMergePartialCTMapsTask(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+               Callable<Object> task = new MergePartialCTMaps(map1, map2, pmaps);
+               return task;
+       }
+
+       /**
+        * Task that scans one block of rows of the three inputs (column 0 only),
+        * accumulates the (v1, v2, w) values into a private CTableMap, and then
+        * appends that map to the shared list of partial maps.
+        */
+       private static class PartialCTableTask implements Callable<Object> {
+               private final MatrixBlock _in1;
+               private final MatrixBlock _in2;
+               private final MatrixBlock _w;
+               private final int _startInd;
+               private final int _blockSize;
+               private final ArrayList<CTableMap> _partialCTmaps;
+
+               protected PartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+                       int startRow, int blockSize, ArrayList<CTableMap> pmaps) {
+                       _in1 = in1;
+                       _in2 = in2;
+                       _w = w;
+                       _startInd = startRow;
+                       _blockSize = blockSize;
+                       _partialCTmaps = pmaps;
+               }
+
+               @Override
+               public Object call() throws Exception {
+                       final CTable fn = CTable.getCTableFnObject();
+                       final CTableMap local = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+                       final int end = UtilFunctions.getEndIndex(_in1.getNumRows(), _startInd, _blockSize);
+                       int row = _startInd;
+                       while(row < end) {
+                               // arguments are evaluated left to right: in1, in2, then w
+                               fn.execute(_in1.quickGetValue(row, 0), _in2.quickGetValue(row, 0),
+                                       _w.quickGetValue(row, 0), false, local);
+                               row++;
+                       }
+                       // the shared list is the only cross-thread state; guard the append
+                       synchronized(_partialCTmaps) {
+                               _partialCTmaps.add(local);
+                       }
+                       return null;
+               }
+       }
+
+       /**
+        * Task that sums two partial CTableMaps into a freshly allocated map and
+        * appends the result to the shared output list. The input maps are only iterated.
+        */
+       private static class MergePartialCTMaps implements Callable<Object> {
+               private final CTableMap _map1;
+               private final CTableMap _map2;
+               private final ArrayList<CTableMap> _partialCTmaps;
+
+               protected MergePartialCTMaps(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+                       _map1 = map1;
+                       _map2 = map2;
+                       _partialCTmaps = pmaps;
+               }
+
+               /** Aggregates every entry of src into dest. */
+               private static void addInto(CTableMap src, CTableMap dest) {
+                       for(Iterator<LongLongDoubleHashMap.ADoubleEntry> it = src.getIterator(); it.hasNext(); ) {
+                               LongLongDoubleHashMap.ADoubleEntry entry = it.next();
+                               dest.aggregate(entry.getKey1(), entry.getKey2(), entry.value);
+                       }
+               }
+
+               @Override
+               public Object call() throws Exception {
+                       CTableMap sum = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+                       addInto(_map1, sum);
+                       addInto(_map2, sum);
+                       synchronized(_partialCTmaps) {
+                               _partialCTmaps.add(sum);
+                       }
+                       return null;
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index a0fcef6..683cd14 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5501,19 +5501,22 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                MatrixBlock that = checkType(thatVal);
                MatrixBlock that2 = checkType(that2Val);
                CTable ctable = CTable.getCTableFnObject();
-               
+               int k = OptimizerUtils.getTransformNumThreads();
                //sparse-unsafe ctable execution
                //(because input values of 0 are invalid and have to result in 
errors) 
-               if(resultBlock == null) 
-               {
-                       for( int i=0; i<rlen; i++ )
-                               for( int j=0; j<clen; j++ )
-                               {
-                                       double v1 = this.quickGetValue(i, j);
-                                       double v2 = that.quickGetValue(i, j);
-                                       double w = that2.quickGetValue(i, j);
-                                       ctable.execute(v1, v2, w, false, 
resultMap);
-                               }               
+               if(resultBlock == null) {
+                       if (k > 1 && clen == 1)
+                               //TODO: Find the optimum k during compilation
+                               ctable.execute(this, that, that2, resultMap, k);
+                       else {
+                               for(int i = 0; i < rlen; i++)
+                                       for(int j = 0; j < clen; j++) {
+                                               double v1 = 
this.quickGetValue(i, j);
+                                               double v2 = 
that.quickGetValue(i, j);
+                                               double w = 
that2.quickGetValue(i, j);
+                                               ctable.execute(v1, v2, w, 
false, resultMap);
+                                       }
+                       }
                }
                else 
                {

Reply via email to