This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5deaa1a  [MINOR] update cost function for compressed matrix 
multiplication
5deaa1a is described below

commit 5deaa1acd40890d09428b424ce35162c774706da
Author: baunsgaard <[email protected]>
AuthorDate: Mon Aug 30 13:20:47 2021 +0200

    [MINOR] update cost function for compressed matrix multiplication
---
 .../runtime/compress/colgroup/ColGroupFactory.java |  2 +-
 .../compress/cost/ComputationCostEstimator.java    | 29 +++++--
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 97 +++++++++++-----------
 3 files changed, 72 insertions(+), 56 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 2b95dba..9416618 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -216,7 +216,7 @@ public final class ColGroupFactory {
                try {
                        final int nrUniqueEstimate = cg.getNumVals();
                        CompressionType estimatedBestCompressionType = 
cg.getBestCompressionType();
-                       
+
                        if(estimatedBestCompressionType == CompressionType.SDC 
&& cs.costComputationType == CostType.W_TREE) {
                                if(cg.getCompressionSize(CompressionType.DDC) < 
cg.getCompressionSize(CompressionType.SDC) * 3)
                                        estimatedBestCompressionType = 
CompressionType.DDC;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 0152edc..8db6729 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -84,7 +84,8 @@ public class ComputationCostEstimator implements 
ICostEstimate {
                // 16 is assuming that the right side is 16 rows.
                double rmc = rightMultCost(g) * 16;
                cost += _rightMultiplications * rmc;
-               cost += _compressedMultiplication * (lmc + rmc);
+               // cost += _compressedMultiplication * (lmc + rmc);
+               cost += _compressedMultiplication * _compressedMultCost(g);
                cost += _dictionaryOps * dictionaryOpsCost(g);
                return cost;
        }
@@ -97,24 +98,38 @@ public class ComputationCostEstimator implements 
ICostEstimate {
                final int nCols = g.getColumns().length;
                final double preAggregateCost = _nRows;
 
-               final int numberTuples = g.getNumVals();
+               final double numberTuples = g.getNumVals();
                final double tupleSparsity = g.getTupleSparsity();
-               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols : numberTuples *
-                       nCols * tupleSparsity;
+               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols * tupleSparsity *
+                       1.4 : numberTuples * nCols;
                if(numberTuples < 64000)
                        return preAggregateCost + postScalingCost;
                else
-                       // scale up cost worse if there is higher number of 
tuples.
                        return preAggregateCost * (numberTuples / 6400) + 
postScalingCost * (numberTuples / 64000);
 
        }
 
+       private double _compressedMultCost(CompressedSizeInfoColGroup g) {
+               final int nCols = g.getColumns().length;
+               final double mcf = g.getMostCommonFraction();
+               final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.7 * 
mcf) : _nRows;
+
+               final double numberTuples = (float) g.getNumVals();
+               final double tupleSparsity = g.getTupleSparsity();
+               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols * tupleSparsity *
+                       1.4 : numberTuples * nCols;
+               if(numberTuples < 64000)
+                       return preAggregateCost + postScalingCost;
+               else
+                       return preAggregateCost * (numberTuples / 64000) + 
postScalingCost * (numberTuples / 64000);
+       }
+
        private static double rightMultCost(CompressedSizeInfoColGroup g) {
                final int nCols = g.getColumns().length;
                final int numberTuples = g.getNumVals() * 10;
                final double tupleSparsity = g.getTupleSparsity();
-               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols : numberTuples *
-                       nCols * tupleSparsity;
+               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols * tupleSparsity *
+                       1.4 : numberTuples * nCols;
 
                return postScalingCost;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index c0b4f40..5fc9016 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -233,8 +233,10 @@ public class CLALibLeftMultBy {
        }
 
        private static void tsmmColGroups(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret) {
-               for(int i = 0; i < groups.size(); i++)
+               for(int i = 0; i < groups.size(); i++) {
+                       groups.get(i).tsmm(ret);
                        tsmmColGroupsIndexI(groups, filteredGroups, ret, i);
+               }
        }
 
        private static void tsmmColGroupsParallel(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret,
@@ -242,9 +244,15 @@ public class CLALibLeftMultBy {
                try {
                        ExecutorService pool = CommonThreadPool.get(k);
                        ArrayList<Callable<Object>> tasks = new ArrayList<>();
-
-                       for(int i = 0; i < filteredGroups.size(); i++)
-                               tasks.add(new tsmmColGroupTask(groups, 
filteredGroups, ret, i));
+                       // if(groups.size()< 10){
+
+                       // }
+                       final int numColGroups = groups.size();
+                       for(int i = 0; i < numColGroups; i++) {
+                               tasks.add(new 
tsmmSelfColGroupTask(groups.get(i), ret));
+                               for(int j = i +1; j < numColGroups; j++)
+                                       tasks.add(new tsmmColGroupTask(groups, 
filteredGroups, ret, i, j, j+1));
+                       }
 
                        for(Future<Object> tret : pool.invokeAll(tasks))
                                tret.get();
@@ -257,58 +265,25 @@ public class CLALibLeftMultBy {
 
        private static void tsmmColGroupsIndexI(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret,
                int i) {
+               tsmmColGroupsIndexIStartEnd(groups, filteredGroups, ret, i, i + 
1, groups.size());
+       }
+
+       private static void tsmmColGroupsIndexIStartEnd(List<AColGroup> groups, 
List<AColGroup> filteredGroups,
+               MatrixBlock ret, int i, int start, int end) {
                final AColGroup full_lhs = groups.get(i);
                final AColGroup lhs = filteredGroups.get(i);
-               final int start = i;
-               final int end = groups.size();
-               full_lhs.tsmm(ret);
                boolean isSDC = full_lhs instanceof ColGroupSDC || full_lhs 
instanceof ColGroupSDCSingle;
-               // if(isSDC) {
-               // Arrays.fill(tmp, 0);
-               // full_lhs.computeColSums(tmp);
-               // }
-               for(int id = start + 1; id < end; id++) {
+               for(int id = start ; id < end; id++) {
                        final AColGroup full_rhs = groups.get(id);
                        final AColGroup rhs = filteredGroups.get(id);
-                       if(isSDC && (full_rhs instanceof ColGroupSDC || 
full_rhs instanceof ColGroupSDCSingle)) {
-                               // Full
+                       if(isSDC && (full_rhs instanceof ColGroupSDC || 
full_rhs instanceof ColGroupSDCSingle))
                                full_lhs.leftMultByAColGroup(full_rhs, ret);
-
-                               // Partial
-                               // full_lhs.leftMultByAColGroup(rhs, ret);
-                               // multiplyWithMostCommonElement(tmp, 
(ColGroupValue) full_rhs, ret);
-                       }
-                       else {
+                       else
                                lhs.leftMultByAColGroup(rhs, ret);
-                       }
+
                }
        }
 
-       // private static void multiplyWithMostCommonElement(double[] colSum, 
ColGroupValue full, MatrixBlock ret) {
-       // final ADictionary d = full.getDictionary();
-       // final double[] result = ret.getDenseBlockValues();
-       // final int numVals = full.getNumValues();
-       // final int[] colIndexes = full.getColIndices();
-       // final int numColumns = ret.getNumColumns();
-       // if(d instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) 
d).getMatrixBlock().isInSparseFormat()) {
-       // throw new NotImplementedException();
-       // }
-       // else {
-       // final int offsetToDefault = numVals * full.getNumCols() - numVals;
-       // final double[] dv = d.getValues();
-       // for(int row = 0; row < colSum.length; row++) {
-
-       // final int offOut = numColumns * row;
-       // final double vLeft = colSum[row];
-       // if(vLeft != 0) {
-       // for(int colId = 0; colId < colIndexes.length; colId++) {
-       // result[offOut + colIndexes[colId]] += vLeft * dv[offsetToDefault + 
colId];
-       // }
-       // }
-       // }
-       // }
-       // }
-
        private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, 
MatrixBlock that, MatrixBlock ret, int k,
                boolean overlapping) {
 
@@ -487,18 +462,44 @@ public class CLALibLeftMultBy {
                private final List<AColGroup> _filteredGroups;
                private final MatrixBlock _ret;
                private final int _index;
+               private final int _start;
+               private final int _end;
 
-               protected tsmmColGroupTask(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret, int i) {
+               protected tsmmColGroupTask(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret, int i, int start, int end) {
                        _groups = groups;
                        _filteredGroups = filteredGroups;
                        _ret = ret;
                        _index = i;
+                       _start = start;
+                       _end = end;
+               }
+
+               @Override
+               public MatrixBlock call() {
+                       try {
+                               tsmmColGroupsIndexIStartEnd(_groups, 
_filteredGroups, _ret, _index, _start, _end);
+                       }
+                       catch(Exception e) {
+                               e.printStackTrace();
+                               throw new DMLRuntimeException(e);
+                       }
+                       return _ret;
+               }
+       }
+
+       private static class tsmmSelfColGroupTask implements Callable<Object> {
+               private final AColGroup _g;
+               private final MatrixBlock _ret;
+
+               protected tsmmSelfColGroupTask(AColGroup g, MatrixBlock ret) {
+                       _g = g;
+                       _ret = ret;
                }
 
                @Override
                public MatrixBlock call() {
                        try {
-                               tsmmColGroupsIndexI(_groups, _filteredGroups, 
_ret, _index);
+                               _g.tsmm(_ret);
                        }
                        catch(Exception e) {
                                e.printStackTrace();

Reply via email to