This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 63931bc [SYSTEMDS-3285] Multithreaded compaction for transformencode
63931bc is described below
commit 63931bc8592eaa6d6eccca05d2fc37df29ca0ce5
Author: arnabp <[email protected]>
AuthorDate: Wed Feb 2 21:12:11 2022 +0100
[SYSTEMDS-3285] Multithreaded compaction for transformencode
This patch replaces the HashSet used to track sparse row indexes
during apply with a list, and adds multithreaded compaction logic.
Together these changes remove the post-processing bottleneck for
PassThrough and DummyCoding, yielding a 3x improvement on the
Criteo dataset (10M rows).
---
.../runtime/transform/encode/ColumnEncoder.java | 17 ++++++-----
.../transform/encode/ColumnEncoderComposite.java | 5 ++--
.../transform/encode/ColumnEncoderDummycode.java | 13 ++++----
.../transform/encode/ColumnEncoderPassThrough.java | 13 ++++----
.../transform/encode/MultiColumnEncoder.java | 35 +++++++++++++++-------
5 files changed, 46 insertions(+), 37 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index f98d323..d5aee51 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -29,9 +29,7 @@ import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
-import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
@@ -56,10 +54,10 @@ import org.apache.sysds.utils.stats.TransformStatistics;
public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder> {
protected static final Log LOG =
LogFactory.getLog(ColumnEncoder.class.getName());
protected static final int APPLY_ROW_BLOCKS_PER_COLUMN = 1;
- public static int BUILD_ROW_BLOCKS_PER_COLUMN = 1;
+ public static int BUILD_ROW_BLOCKS_PER_COLUMN = -1;
private static final long serialVersionUID = 2299156350718979064L;
protected int _colID;
- protected Set<Integer> _sparseRowsWZeros = null;
+ protected ArrayList<Integer> _sparseRowsWZeros = null;
protected enum TransformType{
BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, N_A
@@ -354,14 +352,14 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
return new ColumnApplyTask<>(this, in, out, outputCol,
startRow, blk);
}
- public Set<Integer> getSparseRowsWZeros(){
+ public List<Integer> getSparseRowsWZeros(){
return _sparseRowsWZeros;
}
- protected void addSparseRowsWZeros(Set<Integer> sparseRowsWZeros){
+ protected void addSparseRowsWZeros(ArrayList<Integer> sparseRowsWZeros){
synchronized (this){
if(_sparseRowsWZeros == null)
- _sparseRowsWZeros = new HashSet<>();
+ _sparseRowsWZeros = new ArrayList<>();
_sparseRowsWZeros.addAll(sparseRowsWZeros);
}
}
@@ -371,7 +369,10 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
}
protected int getNumBuildRowPartitions(){
- return ConfigurationManager.getParallelBuildBlocks();
+ if (BUILD_ROW_BLOCKS_PER_COLUMN == -1)
+ return ConfigurationManager.getParallelBuildBlocks();
+ else
+ return BUILD_ROW_BLOCKS_PER_COLUMN;
}
public enum EncoderType {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index b7c162b..4f72f9c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -28,7 +28,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
@@ -358,12 +357,12 @@ public class ColumnEncoderComposite extends ColumnEncoder {
}
@Override
- public Set<Integer> getSparseRowsWZeros(){
+ public List<Integer> getSparseRowsWZeros(){
return
_columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> {
if(l == null)
return null;
return l.stream();
- }).collect(Collectors.toSet());
+ }).collect(Collectors.toList());
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 01e7d5c..e0efe53 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -24,10 +24,7 @@ import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Objects;
-import java.util.Set;
+import java.util.*;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -89,7 +86,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
}
boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK ==
SparseBlock.Type.MCSR;
mcsr = false; //force CSR for transformencode
- Set<Integer> sparseRowsWZeros = null;
+ ArrayList<Integer> sparseRowsWZeros = null;
int index = _colID - 1;
for(int r = rowStart; r < getEndIndex(in.getNumRows(),
rowStart, blk); r++) {
// Since the recoded values are already offset in the
output matrix (same as input at this point)
@@ -111,7 +108,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
double val =
out.getSparseBlock().get(r).values()[index];
if(Double.isNaN(val)){
if(sparseRowsWZeros == null)
- sparseRowsWZeros = new
HashSet<>();
+ sparseRowsWZeros = new
ArrayList<>();
sparseRowsWZeros.add(r);
out.getSparseBlock().get(r).values()[index] = 0;
continue;
@@ -126,7 +123,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
double val = csrblock.values()[rptr[r]+index];
if(Double.isNaN(val)){
if(sparseRowsWZeros == null)
- sparseRowsWZeros = new
HashSet<>();
+ sparseRowsWZeros = new
ArrayList<>();
sparseRowsWZeros.add(r);
csrblock.values()[rptr[r]+index] = 0;
//test
continue;
@@ -137,7 +134,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
csrblock.values()[rptr[r]+index] = 1;
}
}
- if(sparseRowsWZeros != null){
+ if(sparseRowsWZeros != null) {
addSparseRowsWZeros(sparseRowsWZeros);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index 5ceba09..2f5739f 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.transform.encode;
import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -70,7 +71,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
@Override
protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
- double codes[] = new double[endInd-startInd];
+ double[] codes = new double[endInd-startInd];
for (int i=startInd; i<endInd; i++) {
codes[i-startInd] = in.getDoubleNaN(i, _colID-1);
}
@@ -78,7 +79,8 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
}
protected void applySparse(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
- Set<Integer> sparseRowsWZeros = null;
+ //Set<Integer> sparseRowsWZeros = null;
+ ArrayList<Integer> sparseRowsWZeros = null;
boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK ==
SparseBlock.Type.MCSR;
mcsr = false; //force CSR for transformencode
int index = _colID - 1;
@@ -92,7 +94,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
double v = codes[ii-rowStart];
if(v == 0) {
if(sparseRowsWZeros == null)
- sparseRowsWZeros = new
HashSet<>();
+ sparseRowsWZeros = new
ArrayList<>();
sparseRowsWZeros.add(ii);
}
if (mcsr) {
@@ -101,11 +103,6 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
row.indexes()[index] = outputCol;
}
else { //csr
- if(v == 0) {
- if(sparseRowsWZeros == null)
- sparseRowsWZeros = new
HashSet<>();
- sparseRowsWZeros.add(ii);
- }
// Manually fill the column-indexes and
values array
SparseBlockCSR csrblock =
(SparseBlockCSR)out.getSparseBlock();
int rptr[] = csrblock.rowPointers();
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 053c452..2e71680 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -30,10 +30,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
+import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -428,18 +425,36 @@ public class MultiColumnEncoder implements Encoder {
private void outputMatrixPostProcessing(MatrixBlock output){
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- Set<Integer> indexSet = _columnEncoders.stream()
+ int k = OptimizerUtils.getTransformNumThreads();
+ ForkJoinPool myPool = new ForkJoinPool(k);
+ List<Integer> indexSet = _columnEncoders.stream().parallel()
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> {
if(l == null)
return null;
return l.stream();
- }).collect(Collectors.toSet());
- if(!indexSet.stream().allMatch(Objects::isNull)){
- for(Integer row : indexSet){
- // TODO: Maybe MT in special cases when the
number of rows is large
- output.getSparseBlock().get(row).compact();
+ }).collect(Collectors.toList());
+
+ if (k == 1) {
+
if(!indexSet.stream().parallel().allMatch(Objects::isNull)) {
+ for(Integer row : indexSet)
+
output.getSparseBlock().get(row).compact();
+ }
+ }
+ else {
+ try {
+
if(!indexSet.stream().allMatch(Objects::isNull)) {
+ myPool.submit(() -> {
+
indexSet.stream().parallel().forEach(row -> {
+
output.getSparseBlock().get(row).compact();
+ });
+ }).get();
+ }
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
}
}
+ myPool.shutdown();
output.recomputeNonZeros();
if(DMLScript.STATISTICS)
TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0);