This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 8ecd1fb [SYSTEMDS-3105] CLA Left MM Shard Common element sum upgrade
8ecd1fb is described below
commit 8ecd1fb5a14d6b82fc41d240ce2f477ef4d859e9
Author: baunsgaard <[email protected]>
AuthorDate: Fri Aug 27 22:02:29 2021 +0200
[SYSTEMDS-3105] CLA Left MM Shard Common element sum upgrade
This commit expands on the shared element sum from yesterday, improving
the performance gains further. Together with today's workload-aware
improvements, this makes LMM with 16 rows on the left
go 19-21x faster than the default sparse matrix multiplication on
census_enc,
making it take 10 ms per multiplication vs 200 ms with our default.
---
.../runtime/compress/lib/CLALibLeftMultBy.java | 222 ++++++++++++---------
1 file changed, 126 insertions(+), 96 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 3ca657b..10cd2e8 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -37,6 +37,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
+import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -73,7 +74,7 @@ public class CLALibLeftMultBy {
if(m2.isEmpty())
return ret;
- ret = leftMultByMatrix(m1.getColGroups(), m2, ret, k,
m1.getNumColumns(), m1.isOverlapping());
+ ret = leftMultByMatrix(m1.getColGroups(), m2, ret, k,
m1.isOverlapping());
ret.recomputeNonZeros();
return ret;
}
@@ -230,13 +231,13 @@ public class CLALibLeftMultBy {
}
private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups,
MatrixBlock that, MatrixBlock ret, int k,
- int numColumns, boolean overlapping) {
+ boolean overlapping) {
if(that.isEmpty()) {
ret.setNonZeros(0);
return ret;
}
-
+ final int numColumnsOut = ret.getNumColumns();
boolean containsSDC = false;
for(AColGroup g : colGroups) {
@@ -246,7 +247,7 @@ public class CLALibLeftMultBy {
final List<AColGroup> filteredGroups = containsSDC ? new
ArrayList<>() : colGroups;
// a constant colgroup summing the default values.
- final double[] constV = containsSDC ? new double[numColumns] :
null;
+ final double[] constV = containsSDC ? new double[numColumnsOut]
: null;
if(containsSDC) {
for(AColGroup g : colGroups) {
@@ -260,23 +261,17 @@ public class CLALibLeftMultBy {
}
ret.allocateDenseBlock();
+ final double[] rowSums = containsSDC ? new
double[that.getNumRows()] : null;
if(k == 1) {
- leftMultByMatrixPrimitive(filteredGroups, that, ret,
numColumns, 0, that.getNumRows());
- if(containsSDC) {
- MatrixBlock rowSum = that.rowSum();
- if(rowSum.isInSparseFormat())
- rowSum.sparseToDense();
- double[] rowSums = rowSum.getDenseBlockValues();
- outerProduct(rowSums, constV,
ret.getDenseBlockValues());
- }
+ leftMultByMatrixPrimitive(filteredGroups, that, ret, 0,
that.getNumRows(), rowSums);
}
else {
try {
final ExecutorService pool =
CommonThreadPool.get(k);
final ArrayList<Callable<MatrixBlock>> tasks =
new ArrayList<>();
- final int rowBlockSize = that.getNumRows() < 8
? 1 : Math.min(Math.max(that.getNumRows() / k, 1), 8);
- double[] rowSums = null;
+ final int rowBlockSize = that.getNumRows() <= k
? 1 : Math.min(Math.max(that.getNumRows() / k * 2, 1),
+ 8);
if(overlapping) {
for(AColGroup g : filteredGroups) {
@@ -288,54 +283,56 @@ public class CLALibLeftMultBy {
}
List<Future<MatrixBlock>> futures =
pool.invokeAll(tasks);
- if(containsSDC) {
- MatrixBlock rowSum =
that.rowSum();
- if(rowSum.isInSparseFormat())
- rowSum.sparseToDense();
- rowSums =
rowSum.getDenseBlockValues();
- }
pool.shutdown();
BinaryOperator op = new
BinaryOperator(Plus.getPlusFnObject());
for(Future<MatrixBlock> future :
futures)
ret.binaryOperationsInPlace(op,
future.get());
}
else {
- if(rowBlockSize > 2) {
+ final int numberSplits = Math.max((k /
(ret.getNumRows() / rowBlockSize)), 1);
+ // LOG.error("RowBLockSize:"
+rowBlockSize + " Splits " + numberSplits);
+ if(numberSplits == 1) {
for(int blo = 0; blo <
that.getNumRows(); blo += rowBlockSize) {
- tasks.add(new
LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, numColumns, blo,
- Math.min(blo +
rowBlockSize, that.getNumRows())));
+ tasks.add(new
LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, blo,
+ Math.min(blo +
rowBlockSize, that.getNumRows()), rowSums));
}
}
else {
- List<List<AColGroup>> split =
split(filteredGroups, Math.max(k / that.getNumRows(), 1));
+ List<List<AColGroup>> split =
split(filteredGroups, numberSplits);
for(int blo = 0; blo <
that.getNumRows(); blo += rowBlockSize) {
- for(List<AColGroup> gr
: split)
- tasks.add(new
LeftMatrixColGroupMultTaskNew(gr, that, ret, numColumns, blo,
-
Math.min(blo + rowBlockSize, that.getNumRows())));
+ for(int i = 0; i <
split.size(); i++) {
+ List<AColGroup>
gr = split.get(i);
+ if(i == 0) {
+ // the
first thread also have the responsibility to calculate the som of the left
+ // hand
side.
+
tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, blo,
+
Math.min(blo + rowBlockSize, that.getNumRows()), rowSums));
+ }
+ else {
+
tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, blo,
+
Math.min(blo + rowBlockSize, that.getNumRows()), null));
+ }
+ }
}
}
List<Future<MatrixBlock>> futures =
pool.invokeAll(tasks);
- if(containsSDC) {
- MatrixBlock rowSum =
that.rowSum();
- if(rowSum.isInSparseFormat())
- rowSum.sparseToDense();
- rowSums =
rowSum.getDenseBlockValues();
- }
+
pool.shutdown();
for(Future<MatrixBlock> future :
futures)
future.get();
}
- if(containsSDC)
- outerProduct(rowSums, constV,
ret.getDenseBlockValues());
-
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
}
+ // add the correction layer for the subtracted common values.
+ if(rowSums != null)
+ outerProduct(rowSums, constV,
ret.getDenseBlockValues());
+
ret.recomputeNonZeros();
return ret;
}
@@ -396,112 +393,145 @@ public class CLALibLeftMultBy {
private final MatrixBlock _ret;
private final int _rl;
private final int _ru;
- private final int _numColumns;
+ private final double[] _rowSums;
- protected LeftMatrixColGroupMultTaskNew(List<AColGroup> groups,
MatrixBlock that, MatrixBlock ret,
- int numColumns, int rl, int ru) {
+ protected LeftMatrixColGroupMultTaskNew(List<AColGroup> groups,
MatrixBlock that, MatrixBlock ret, int rl,
+ int ru, double[] rowSums) {
_groups = groups;
_that = that;
_ret = ret;
_rl = rl;
_ru = ru;
- _numColumns = numColumns;
+ _rowSums = rowSums;
}
@Override
public MatrixBlock call() {
try {
- leftMultByMatrixPrimitive(_groups, _that, _ret,
_numColumns, _rl, _ru);
+ leftMultByMatrixPrimitive(_groups, _that, _ret,
_rl, _ru, _rowSums);
}
catch(Exception e) {
+ e.printStackTrace();
throw new DMLRuntimeException(e);
}
return _ret;
}
}
- private static void leftMultByMatrixPrimitive(List<AColGroup>
colGroups, MatrixBlock that, MatrixBlock ret,
- int numColumns, int rl, int ru) {
+ private static void leftMultByMatrixPrimitive(List<AColGroup>
colGroups, MatrixBlock that, MatrixBlock ret, int rl,
+ int ru, double[] rowSums) {
+ if(that.isInSparseFormat())
+ leftMultByMatrixPrimitiveSparse(colGroups, that, ret,
rl, ru, rowSums);
+ else
+ leftMultByMatrixPrimitiveDense(colGroups, that, ret,
rl, ru, rowSums);
+ }
+
+ private static void leftMultByMatrixPrimitiveSparse(List<AColGroup>
colGroups, MatrixBlock that, MatrixBlock ret,
+ int rl, int ru, double[] rowSum) {
- if(that.isInSparseFormat()) {
- for(int i = rl; i < ru; i++) {
- for(int j = 0; j < colGroups.size(); j++) {
- colGroups.get(j).leftMultByMatrix(that,
ret, i, i + 1);
+ for(int i = rl; i < ru; i++) {
+ for(int j = 0; j < colGroups.size(); j++) {
+ colGroups.get(j).leftMultByMatrix(that, ret, i,
i + 1);
+ }
+ if(rowSum != null) {
+ final SparseBlock sb = that.getSparseBlock();
+ if(!sb.isEmpty(i)){
+ final int apos = sb.pos(i);
+ final int alen = sb.size(i) + apos;
+ final double[] aval = sb.values(i);
+ for(int j = apos; j < alen; j++)
+ rowSum[i] += aval[j];
}
}
}
- else {
- // The number of rows to process together
- final int rowBlockSize = 1;
- // The number of column groups to process together
- final int colGroupBlocking = 16;
+ }
- // Allocate pre Aggregate Array List
- final List<MatrixBlock> preAgg =
populatePreAggregate(colGroupBlocking);
- // Allocate a ColGroupValue array for the Column Groups
of Value Type.
- final List<ColGroupValue> ColGroupValues =
preFilterAndMultiply(colGroups, that, ret, numColumns, rl, ru);
+ private static void leftMultByMatrixPrimitiveDense(List<AColGroup>
colGroups, MatrixBlock that, MatrixBlock ret,
+ int rl, int ru, double[] rowSum) {
- // Allocate temporary Result matrix.
- MatrixBlock tmpRes = new MatrixBlock(rowBlockSize,
numColumns, false);
+ final int numColsOut = ret.getNumColumns();
+ // Allocate a ColGroupValue array for the Column Groups of
Value Type and multiply out any other columns.
+ final List<ColGroupValue> ColGroupValues =
preFilterAndMultiply(colGroups, that, ret, rl, ru);
- for(int g = 0; g < ColGroupValues.size(); g +=
colGroupBlocking) {
- final int gEnd = Math.min(g + colGroupBlocking,
colGroups.size());
+ // The number of rows to process together
+ final int rowBlockSize = 1;
+ // The number of column groups to process together
+ // the value should ideally be set so that the colgroups fits
into cache together with a row block.
+ // currently we only try to avoid having a dangling small
number of column groups in the last block.
+ final int colGroupBlocking = ColGroupValues.size() % 16 < 4 ?
20 : 16;
- // for each column group in the current block
allocate the preaggregate array.
- for(int j = g; j < gEnd && j <
ColGroupValues.size(); j++) {
- ColGroupValue cg =
ColGroupValues.get(j);
- int nVals = cg.getNumValues();
- preAgg.get(j %
colGroupBlocking).reset(rowBlockSize, nVals, false);
- }
+ // Allocate pre Aggregate Array List
+ final MatrixBlock[] preAgg =
populatePreAggregate(colGroupBlocking);
- // int colBlockSize = 16000;
- int colBlockSize = 64000;
-
- // For each row block
- for(int h = rl; h < ru; h += rowBlockSize) {
- // For each column block
- for(int i = 0; i <
that.getNumColumns(); i += colBlockSize) {
- // Pre Aggregate each column
group in block
- for(int j = g; j < gEnd && j <
ColGroupValues.size(); j++) {
-
ColGroupValues.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking),
h,
- Math.min(h +
rowBlockSize, ru), i, Math.min(i + colBlockSize, that.getNumColumns()));
- }
- }
+ // Allocate temporary Result matrix.
+ MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, numColsOut,
false);
+
+ // For each column group block
+ for(int g = 0; g < ColGroupValues.size(); g +=
colGroupBlocking) {
+ final int gEnd = Math.min(g + colGroupBlocking,
ColGroupValues.size());
+
+ // For each column group in the current block allocate
the preaggregate array.
+ for(int j = g; j < gEnd && j < ColGroupValues.size();
j++) {
+ ColGroupValue cg = ColGroupValues.get(j);
+ int nVals = cg.getNumValues();
+ preAgg[j %
colGroupBlocking].reset(rowBlockSize, nVals, false);
+ }
+
+ int colBlockSize = 32000;
+
+ // For each row block
+ for(int h = rl; h < ru; h += rowBlockSize) {
+ // For each column block
+ final int rowUpper = Math.min(h + rowBlockSize,
ru);
+ for(int i = 0; i < that.getNumColumns(); i +=
colBlockSize) {
+ final int colUpper = Math.min(i +
colBlockSize, that.getNumColumns());
+ // Pre Aggregate each column group in
block
for(int j = g; j < gEnd && j <
ColGroupValues.size(); j++) {
- ColGroupValue vj =
ColGroupValues.get(j);
- MatrixBlock preAggJ =
preAgg.get(j % colGroupBlocking);
- preAggJ.recomputeNonZeros();
- tmpRes.reset(rowBlockSize,
vj.getNumCols(), false);
- MatrixBlock tmp =
vj.leftMultByPreAggregateMatrix(preAggJ, tmpRes);
- vj.addMatrixToResult(tmp, ret,
h, Math.min(h + rowBlockSize, ru));
- preAggJ.reset();
+
ColGroupValues.get(j).preAggregateDense(that, preAgg[j % colGroupBlocking], h,
rowUpper, i,
+ colUpper);
}
+ if(rowSum != null) {
+ final double[] thatV =
that.getDenseBlockValues();
+ for(int r = h; r < rowUpper;
r++) {
+ final int rowOff = r *
that.getNumColumns();
+ for(int c = rowOff + i;
c < rowOff + colUpper; c++)
+ rowSum[r] +=
thatV[c];
+ }
+ }
+ }
+ // Multiply out the preAggregate to the output
matrix.
+ for(int j = g; j < gEnd && j <
ColGroupValues.size(); j++) {
+ ColGroupValue vj =
ColGroupValues.get(j);
+ MatrixBlock preAggJ = preAgg[j %
colGroupBlocking];
+ preAggJ.recomputeNonZeros();
+ tmpRes.reset(rowBlockSize,
vj.getNumCols(), false);
+ MatrixBlock tmp =
vj.leftMultByPreAggregateMatrix(preAggJ, tmpRes);
+ vj.addMatrixToResult(tmp, ret, h,
Math.min(h + rowBlockSize, ru));
+ preAggJ.reset();
}
}
}
+
}
- private static List<MatrixBlock> populatePreAggregate(int
colGroupBlocking) {
- final List<MatrixBlock> preAgg = new ArrayList<>();
+ private static MatrixBlock[] populatePreAggregate(int colGroupBlocking)
{
+ final MatrixBlock[] preAgg = new MatrixBlock[colGroupBlocking];
// poplate the preAgg array.
for(int j = 0; j < colGroupBlocking; j++) {
-
MatrixBlock m = new MatrixBlock(1, 1, false);
m.allocateDenseBlock();
- preAgg.add(m);
+ preAgg[j] = m;
}
return preAgg;
}
private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup>
colGroups, MatrixBlock that,
- MatrixBlock ret, int numColumns, int rl, int ru) {
- final List<ColGroupValue> ColGroupValues = new ArrayList<>();
+ MatrixBlock ret, int rl, int ru) {
+ final List<ColGroupValue> ColGroupValues = new
ArrayList<>(colGroups.size());
for(int j = 0; j < colGroups.size(); j++) {
AColGroup a = colGroups.get(j);
- if(a instanceof ColGroupValue) {
- ColGroupValue av = (ColGroupValue) a;
- ColGroupValues.add(av);
- }
+ if(a instanceof ColGroupValue)
+ ColGroupValues.add((ColGroupValue) a);
else
a.leftMultByMatrix(that, ret, rl, ru);
}