This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5deaa1a [MINOR] update cost function for compressed matrix
multiplication
5deaa1a is described below
commit 5deaa1acd40890d09428b424ce35162c774706da
Author: baunsgaard <[email protected]>
AuthorDate: Mon Aug 30 13:20:47 2021 +0200
[MINOR] update cost function for compressed matrix multiplication
---
.../runtime/compress/colgroup/ColGroupFactory.java | 2 +-
.../compress/cost/ComputationCostEstimator.java | 29 +++++--
.../runtime/compress/lib/CLALibLeftMultBy.java | 97 +++++++++++-----------
3 files changed, 72 insertions(+), 56 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 2b95dba..9416618 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -216,7 +216,7 @@ public final class ColGroupFactory {
try {
final int nrUniqueEstimate = cg.getNumVals();
CompressionType estimatedBestCompressionType =
cg.getBestCompressionType();
-
+
if(estimatedBestCompressionType == CompressionType.SDC
&& cs.costComputationType == CostType.W_TREE) {
if(cg.getCompressionSize(CompressionType.DDC) <
cg.getCompressionSize(CompressionType.SDC) * 3)
estimatedBestCompressionType =
CompressionType.DDC;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 0152edc..8db6729 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -84,7 +84,8 @@ public class ComputationCostEstimator implements
ICostEstimate {
// 16 is assuming that the right side is 16 rows.
double rmc = rightMultCost(g) * 16;
cost += _rightMultiplications * rmc;
- cost += _compressedMultiplication * (lmc + rmc);
+ // cost += _compressedMultiplication * (lmc + rmc);
+ cost += _compressedMultiplication * _compressedMultCost(g);
cost += _dictionaryOps * dictionaryOpsCost(g);
return cost;
}
@@ -97,24 +98,38 @@ public class ComputationCostEstimator implements
ICostEstimate {
final int nCols = g.getColumns().length;
final double preAggregateCost = _nRows;
- final int numberTuples = g.getNumVals();
+ final double numberTuples = g.getNumVals();
final double tupleSparsity = g.getTupleSparsity();
- final double postScalingCost = (nCols > 1 && tupleSparsity >
0.4) ? numberTuples * nCols : numberTuples *
- nCols * tupleSparsity;
+ final double postScalingCost = (nCols > 1 && tupleSparsity >
0.4) ? numberTuples * nCols * tupleSparsity *
+ 1.4 : numberTuples * nCols;
if(numberTuples < 64000)
return preAggregateCost + postScalingCost;
else
- // scale up cost worse if there is higher number of
tuples.
return preAggregateCost * (numberTuples / 6400) +
postScalingCost * (numberTuples / 64000);
}
+ private double _compressedMultCost(CompressedSizeInfoColGroup g) {
+ final int nCols = g.getColumns().length;
+ final double mcf = g.getMostCommonFraction();
+ final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.7 *
mcf) : _nRows;
+
+ final double numberTuples = (float) g.getNumVals();
+ final double tupleSparsity = g.getTupleSparsity();
+ final double postScalingCost = (nCols > 1 && tupleSparsity >
0.4) ? numberTuples * nCols * tupleSparsity *
+ 1.4 : numberTuples * nCols;
+ if(numberTuples < 64000)
+ return preAggregateCost + postScalingCost;
+ else
+ return preAggregateCost * (numberTuples / 64000) +
postScalingCost * (numberTuples / 64000);
+ }
+
private static double rightMultCost(CompressedSizeInfoColGroup g) {
final int nCols = g.getColumns().length;
final int numberTuples = g.getNumVals() * 10;
final double tupleSparsity = g.getTupleSparsity();
- final double postScalingCost = (nCols > 1 && tupleSparsity >
0.4) ? numberTuples * nCols : numberTuples *
- nCols * tupleSparsity;
+ final double postScalingCost = (nCols > 1 && tupleSparsity >
0.4) ? numberTuples * nCols * tupleSparsity *
+ 1.4 : numberTuples * nCols;
return postScalingCost;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index c0b4f40..5fc9016 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -233,8 +233,10 @@ public class CLALibLeftMultBy {
}
private static void tsmmColGroups(List<AColGroup> groups,
List<AColGroup> filteredGroups, MatrixBlock ret) {
- for(int i = 0; i < groups.size(); i++)
+ for(int i = 0; i < groups.size(); i++) {
+ groups.get(i).tsmm(ret);
tsmmColGroupsIndexI(groups, filteredGroups, ret, i);
+ }
}
private static void tsmmColGroupsParallel(List<AColGroup> groups,
List<AColGroup> filteredGroups, MatrixBlock ret,
@@ -242,9 +244,15 @@ public class CLALibLeftMultBy {
try {
ExecutorService pool = CommonThreadPool.get(k);
ArrayList<Callable<Object>> tasks = new ArrayList<>();
-
- for(int i = 0; i < filteredGroups.size(); i++)
- tasks.add(new tsmmColGroupTask(groups,
filteredGroups, ret, i));
+ // if(groups.size()< 10){
+
+ // }
+ final int numColGroups = groups.size();
+ for(int i = 0; i < numColGroups; i++) {
+ tasks.add(new
tsmmSelfColGroupTask(groups.get(i), ret));
+ for(int j = i +1; j < numColGroups; j++)
+ tasks.add(new tsmmColGroupTask(groups,
filteredGroups, ret, i, j, j+1));
+ }
for(Future<Object> tret : pool.invokeAll(tasks))
tret.get();
@@ -257,58 +265,25 @@ public class CLALibLeftMultBy {
private static void tsmmColGroupsIndexI(List<AColGroup> groups,
List<AColGroup> filteredGroups, MatrixBlock ret,
int i) {
+ tsmmColGroupsIndexIStartEnd(groups, filteredGroups, ret, i, i +
1, groups.size());
+ }
+
+ private static void tsmmColGroupsIndexIStartEnd(List<AColGroup> groups,
List<AColGroup> filteredGroups,
+ MatrixBlock ret, int i, int start, int end) {
final AColGroup full_lhs = groups.get(i);
final AColGroup lhs = filteredGroups.get(i);
- final int start = i;
- final int end = groups.size();
- full_lhs.tsmm(ret);
boolean isSDC = full_lhs instanceof ColGroupSDC || full_lhs
instanceof ColGroupSDCSingle;
- // if(isSDC) {
- // Arrays.fill(tmp, 0);
- // full_lhs.computeColSums(tmp);
- // }
- for(int id = start + 1; id < end; id++) {
+ for(int id = start ; id < end; id++) {
final AColGroup full_rhs = groups.get(id);
final AColGroup rhs = filteredGroups.get(id);
- if(isSDC && (full_rhs instanceof ColGroupSDC ||
full_rhs instanceof ColGroupSDCSingle)) {
- // Full
+ if(isSDC && (full_rhs instanceof ColGroupSDC ||
full_rhs instanceof ColGroupSDCSingle))
full_lhs.leftMultByAColGroup(full_rhs, ret);
-
- // Partial
- // full_lhs.leftMultByAColGroup(rhs, ret);
- // multiplyWithMostCommonElement(tmp,
(ColGroupValue) full_rhs, ret);
- }
- else {
+ else
lhs.leftMultByAColGroup(rhs, ret);
- }
+
}
}
- // private static void multiplyWithMostCommonElement(double[] colSum,
ColGroupValue full, MatrixBlock ret) {
- // final ADictionary d = full.getDictionary();
- // final double[] result = ret.getDenseBlockValues();
- // final int numVals = full.getNumValues();
- // final int[] colIndexes = full.getColIndices();
- // final int numColumns = ret.getNumColumns();
- // if(d instanceof MatrixBlockDictionary && ((MatrixBlockDictionary)
d).getMatrixBlock().isInSparseFormat()) {
- // throw new NotImplementedException();
- // }
- // else {
- // final int offsetToDefault = numVals * full.getNumCols() - numVals;
- // final double[] dv = d.getValues();
- // for(int row = 0; row < colSum.length; row++) {
-
- // final int offOut = numColumns * row;
- // final double vLeft = colSum[row];
- // if(vLeft != 0) {
- // for(int colId = 0; colId < colIndexes.length; colId++) {
- // result[offOut + colIndexes[colId]] += vLeft * dv[offsetToDefault +
colId];
- // }
- // }
- // }
- // }
- // }
-
private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups,
MatrixBlock that, MatrixBlock ret, int k,
boolean overlapping) {
@@ -487,18 +462,44 @@ public class CLALibLeftMultBy {
private final List<AColGroup> _filteredGroups;
private final MatrixBlock _ret;
private final int _index;
+ private final int _start;
+ private final int _end;
- protected tsmmColGroupTask(List<AColGroup> groups,
List<AColGroup> filteredGroups, MatrixBlock ret, int i) {
+ protected tsmmColGroupTask(List<AColGroup> groups,
List<AColGroup> filteredGroups, MatrixBlock ret, int i, int start, int end) {
_groups = groups;
_filteredGroups = filteredGroups;
_ret = ret;
_index = i;
+ _start = start;
+ _end = end;
+ }
+
+ @Override
+ public MatrixBlock call() {
+ try {
+ tsmmColGroupsIndexIStartEnd(_groups,
_filteredGroups, _ret, _index, _start, _end);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException(e);
+ }
+ return _ret;
+ }
+ }
+
+ private static class tsmmSelfColGroupTask implements Callable<Object> {
+ private final AColGroup _g;
+ private final MatrixBlock _ret;
+
+ protected tsmmSelfColGroupTask(AColGroup g, MatrixBlock ret) {
+ _g = g;
+ _ret = ret;
}
@Override
public MatrixBlock call() {
try {
- tsmmColGroupsIndexI(_groups, _filteredGroups,
_ret, _index);
+ _g.tsmm(_ret);
}
catch(Exception e) {
e.printStackTrace();