This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 6caa9c02e81de88f691763f25e93497b0b0d2381
Author: baunsgaard <[email protected]>
AuthorDate: Sun Aug 29 23:35:48 2021 +0200

    [MINOR] CLA update tsmm
    
    This commit does two things: first, it optimizes the tsmm by exploiting
    the common elements in SDC groups, and second, it updates the cost
    calculation to compute some cost for single column groups.
---
 .../sysds/runtime/compress/colgroup/AColGroup.java |   7 +
 .../compress/colgroup/ColGroupCompressed.java      |   6 +-
 .../runtime/compress/colgroup/ColGroupSDC.java     |   3 +-
 .../compress/colgroup/ColGroupSDCZeros.java        |  21 +-
 .../compress/colgroup/ColGroupUncompressed.java    |  15 ++
 .../compress/cost/ComputationCostEstimator.java    |   7 +-
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 219 ++++++++++++++++-----
 7 files changed, 222 insertions(+), 56 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index e675426..0460cf2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -516,6 +516,13 @@ public abstract class AColGroup implements Serializable {
         */
        public abstract AColGroup replace(double pattern, double replace);
 
+       /**
+        * Compute the column sum
+        * 
+        * @param c The array to add the column sum to.
+        */
+       public abstract void computeColSums(double[] c);
+
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
index 968e261..c060596 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
@@ -36,7 +36,7 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 public abstract class ColGroupCompressed extends AColGroup {
 
        private static final long serialVersionUID = 6219835795420081223L;
-       
+
        final protected int _numRows;
 
        protected ColGroupCompressed(int numRows) {
@@ -72,6 +72,10 @@ public abstract class ColGroupCompressed extends AColGroup {
 
        protected abstract void computeRowSums(double[] c, boolean square, int 
rl, int ru);
 
+       public void computeColSums(double[] c){
+               computeColSums(c, false);
+       }
+
        protected abstract void computeColSums(double[] c, boolean square);
 
        protected abstract void computeRowMxx(double[] c, Builtin builtin, int 
rl, int ru);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index ffe38b1..5c6d365 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -507,11 +507,10 @@ public class ColGroupSDC extends ColGroupValue {
 
        @Override
        public Dictionary preAggregateThatSDCZerosStructure(ColGroupSDCZeros 
that, Dictionary ret) {
-
                final AIterator itThat = that._indexes.getIterator();
                final AIterator itThis = _indexes.getIterator();
                final int nCol = that._colIndexes.length;
-               final int defThis = this.getNumValues() * nCol - nCol;
+               final int defThis = getNumValues() - 1;
 
                while(itThat.hasNext()) {
                        final int thatV = itThat.value();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index 030cf83..2397fd1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -436,7 +436,26 @@ public class ColGroupSDCZeros extends ColGroupValue {
 
        @Override
        public Dictionary preAggregateThatSDCStructure(ColGroupSDC that, 
Dictionary ret, boolean preModified) {
-               throw new NotImplementedException();
+               if(preModified){
+                       final AIterator itThat = that._indexes.getIterator();
+                       final AIterator itThis = _indexes.getIterator();
+                       final int nCol = that._colIndexes.length;
+       
+                       while(itThat.hasNext() && itThis.hasNext()) {
+                               if(itThat.value() == itThis.value()) {
+                                       final int fr = 
that.getIndex(itThat.getDataIndexAndIncrement());
+                                       final int to = 
getIndex(itThis.getDataIndexAndIncrement());
+                                       that._dict.addToEntry(ret, fr, to, 
nCol);
+                               }
+                               else if(itThat.value() < itThis.value())
+                                       itThat.next();
+                               else
+                                       itThis.next();
+                       }
+                       return ret;
+               }else{
+                       throw new NotImplementedException("Not implemented not 
PreModded preaggregate of SDC");
+               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index a9412d0..ab94646 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -620,4 +620,19 @@ public class ColGroupUncompressed extends AColGroup {
                MatrixBlock replaced = _data.replaceOperations(new 
MatrixBlock(), pattern, replace);
                return new ColGroupUncompressed(_colIndexes, replaced);
        }
+
+       @Override
+       public void computeColSums(double[] c) {
+               // TODO Auto-generated method stub
+               MatrixBlock colSum = _data.colSum();
+               if(colSum.isInSparseFormat()) {
+                       throw new NotImplementedException();
+               }
+               else {
+                       double[] dv = colSum.getDenseBlockValues();
+                       for(int i = 0; i < _colIndexes.length; i++)
+                               c[_colIndexes[i]] += dv[i];
+                       
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index fe4cb83..0152edc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -79,9 +79,12 @@ public class ComputationCostEstimator implements 
ICostEstimate {
                cost += _decompressions * decompressionCost(g);
                cost += _overlappingDecompressions * 
overlappingDecompressionCost(g);
                // 16 is assuming that the left side is 16 rows.
-               cost += _leftMultiplications * leftMultCost(g) * 16;
+               double lmc = leftMultCost(g) * 16;
+               cost += _leftMultiplications * lmc;
                // 16 is assuming that the right side is 16 rows.
-               cost += _rightMultiplications * rightMultCost(g) * 16;
+               double rmc = rightMultCost(g) * 16;
+               cost += _rightMultiplications * rmc;
+               cost += _compressedMultiplication * (lmc + rmc);
                cost += _dictionaryOps * dictionaryOpsCost(g);
                return cost;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 10cd2e8..c0b4f40 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -91,53 +91,65 @@ public class CLALibLeftMultBy {
        }
 
        public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, 
MatrixBlock result, int k) {
-               final int numColumns = cmb.getNumColumns();
                final boolean overlapping = cmb.isOverlapping();
-               List<AColGroup> groups = cmb.getColGroups();
+               final List<AColGroup> groups = cmb.getColGroups();
+
                result.allocateDenseBlock();
 
                if(overlapping) {
                        LOG.warn("Inefficient TSMM with overlapping matrix 
could be implemented multi-threaded but is not yet.");
                        leftMultByCompressedTransposedMatrix(groups, groups, 
result);
                }
-               else if(k <= 1) {
-                       for(int i = 0; i < groups.size(); i++)
-                               
leftMultByCompressedTransposedMatrix(groups.get(i), groups, result, i, 
groups.size());
-               }
                else {
-                       try {
-                               ExecutorService pool = CommonThreadPool.get(k);
-                               ArrayList<Callable<Object>> tasks = new 
ArrayList<>();
+                       final boolean containsSDC = containsSDC(groups);
+                       final int numColumns = cmb.getNumColumns();
+                       final double[] constV = containsSDC ? new 
double[cmb.getNumColumns()] : null;
+                       final List<AColGroup> filteredGroups = 
filterSDCGroups(groups, constV);
+                       final double[] colSums = containsSDC ? new 
double[cmb.getNumColumns()] : null;
 
+                       if(containsSDC)
                                for(int i = 0; i < groups.size(); i++) {
-                                       final AColGroup g = groups.get(i);
-                                       tasks.add(new 
LeftMultByCompressedTransposedMatrixTask(groups, g, result, i, groups.size()));
+                                       AColGroup gi = groups.get(i);
+                                       if(!(gi instanceof ColGroupSDC || gi 
instanceof ColGroupSDCSingle))
+                                               gi.computeColSums(colSums);
                                }
 
-                               for(Future<Object> tret : pool.invokeAll(tasks))
-                                       tret.get();
-                               pool.shutdown();
-                       }
-                       catch(InterruptedException | ExecutionException e) {
-                               throw new DMLRuntimeException(e);
+                       if(k <= 1)
+                               tsmmColGroups(groups, filteredGroups, result);
+                       else
+                               tsmmColGroupsParallel(groups, filteredGroups, 
result, k);
+
+                       double[] retV = result.getDenseBlockValues();
+
+                       // Move values in the lower part of the matrix to the 
upper part
+                       copyToUpperTriangle(retV, numColumns);
+
+                       // add the correction layer for the subtracted common 
values.
+                       if(colSums != null) {
+                               outerProduct(colSums, constV, retV);
+                               addToUpperTriangle(retV, numColumns);
                        }
                }
-               // Move values in the lower part of the matrix to the upper part
-               copyToUpperTriangle(result.getDenseBlockValues(), numColumns);
-               // calculate the number of non zeros, and allocate all value 
locations by copying upper triangle back to bottom.
+
                long nnz = LinearAlgebraUtils.copyUpperToLowerTriangle(result);
                result.setNonZeros(nnz);
-               // Evaluate if the output should be sparsely allocated.
                result.examSparsity();
        }
 
        private static void copyToUpperTriangle(final double[] c, final int 
cols) {
                for(int i = 0, offC = 0; i < cols; i++, offC += cols)
-                       for(int j = i, offR = i * cols; j < cols; j++, offR += 
cols) {
+                       for(int j = (i + 1), offR = (i + 1) * cols; j < cols; 
j++, offR += cols) {
                                final double prev = c[offC + j];
                                if(prev == 0)
                                        c[offC + j] = c[i + offR];
+                               c[i + offR] = 0;
                        }
+       }
+
+       private static void addToUpperTriangle(final double[] c, final int 
cols) {
+               for(int i = 0, offC = 0; i < cols; i++, offC += cols)
+                       for(int j = (i + 1), offR = (i + 1) * cols; j < cols; 
j++, offR += cols)
+                               c[offC + j] += c[i + offR];
 
        }
 
@@ -181,15 +193,6 @@ public class CLALibLeftMultBy {
                private final int _start;
                private final int _end;
 
-               protected 
LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup 
left, MatrixBlock ret,
-                       int start, int end) {
-                       _groups = groups;
-                       _left = left;
-                       _ret = ret;
-                       _start = start;
-                       _end = end;
-               }
-
                protected 
LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup 
left, MatrixBlock ret) {
                        _groups = groups;
                        _left = left;
@@ -227,9 +230,85 @@ public class CLALibLeftMultBy {
                        else
                                rhs.tsmm(ret);
                }
+       }
+
+       private static void tsmmColGroups(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret) {
+               for(int i = 0; i < groups.size(); i++)
+                       tsmmColGroupsIndexI(groups, filteredGroups, ret, i);
+       }
+
+       private static void tsmmColGroupsParallel(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret,
+               int k) {
+               try {
+                       ExecutorService pool = CommonThreadPool.get(k);
+                       ArrayList<Callable<Object>> tasks = new ArrayList<>();
+
+                       for(int i = 0; i < filteredGroups.size(); i++)
+                               tasks.add(new tsmmColGroupTask(groups, 
filteredGroups, ret, i));
+
+                       for(Future<Object> tret : pool.invokeAll(tasks))
+                               tret.get();
+                       pool.shutdown();
+               }
+               catch(InterruptedException | ExecutionException e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
 
+       private static void tsmmColGroupsIndexI(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret,
+               int i) {
+               final AColGroup full_lhs = groups.get(i);
+               final AColGroup lhs = filteredGroups.get(i);
+               final int start = i;
+               final int end = groups.size();
+               full_lhs.tsmm(ret);
+               boolean isSDC = full_lhs instanceof ColGroupSDC || full_lhs 
instanceof ColGroupSDCSingle;
+               // if(isSDC) {
+               // Arrays.fill(tmp, 0);
+               // full_lhs.computeColSums(tmp);
+               // }
+               for(int id = start + 1; id < end; id++) {
+                       final AColGroup full_rhs = groups.get(id);
+                       final AColGroup rhs = filteredGroups.get(id);
+                       if(isSDC && (full_rhs instanceof ColGroupSDC || 
full_rhs instanceof ColGroupSDCSingle)) {
+                               // Full
+                               full_lhs.leftMultByAColGroup(full_rhs, ret);
+
+                               // Partial
+                               // full_lhs.leftMultByAColGroup(rhs, ret);
+                               // multiplyWithMostCommonElement(tmp, 
(ColGroupValue) full_rhs, ret);
+                       }
+                       else {
+                               lhs.leftMultByAColGroup(rhs, ret);
+                       }
+               }
        }
 
+       // private static void multiplyWithMostCommonElement(double[] colSum, 
ColGroupValue full, MatrixBlock ret) {
+       // final ADictionary d = full.getDictionary();
+       // final double[] result = ret.getDenseBlockValues();
+       // final int numVals = full.getNumValues();
+       // final int[] colIndexes = full.getColIndices();
+       // final int numColumns = ret.getNumColumns();
+       // if(d instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) 
d).getMatrixBlock().isInSparseFormat()) {
+       // throw new NotImplementedException();
+       // }
+       // else {
+       // final int offsetToDefault = numVals * full.getNumCols() - numVals;
+       // final double[] dv = d.getValues();
+       // for(int row = 0; row < colSum.length; row++) {
+
+       // final int offOut = numColumns * row;
+       // final double vLeft = colSum[row];
+       // if(vLeft != 0) {
+       // for(int colId = 0; colId < colIndexes.length; colId++) {
+       // result[offOut + colIndexes[colId]] += vLeft * dv[offsetToDefault + 
colId];
+       // }
+       // }
+       // }
+       // }
+       // }
+
        private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, 
MatrixBlock that, MatrixBlock ret, int k,
                boolean overlapping) {
 
@@ -237,28 +316,13 @@ public class CLALibLeftMultBy {
                        ret.setNonZeros(0);
                        return ret;
                }
-               final int numColumnsOut = ret.getNumColumns();
-               boolean containsSDC = false;
 
-               for(AColGroup g : colGroups) {
-                       if(g instanceof ColGroupSDC || g instanceof 
ColGroupSDCSingle)
-                               containsSDC = true;
-               }
+               final int numColumnsOut = ret.getNumColumns();
+               final boolean containsSDC = containsSDC(colGroups);
 
-               final List<AColGroup> filteredGroups = containsSDC ? new 
ArrayList<>() : colGroups;
                // a constant colgroup summing the default values.
                final double[] constV = containsSDC ? new double[numColumnsOut] 
: null;
-
-               if(containsSDC) {
-                       for(AColGroup g : colGroups) {
-                               if(g instanceof ColGroupSDC)
-                                       filteredGroups.add(((ColGroupSDC) 
g).extractCommon(constV));
-                               else if(g instanceof ColGroupSDCSingle)
-                                       filteredGroups.add(((ColGroupSDCSingle) 
g).extractCommon(constV));
-                               else
-                                       filteredGroups.add(g);
-                       }
-               }
+               final List<AColGroup> filteredGroups = 
filterSDCGroups(colGroups, constV);
 
                ret.allocateDenseBlock();
                final double[] rowSums = containsSDC ? new 
double[that.getNumRows()] : null;
@@ -418,6 +482,32 @@ public class CLALibLeftMultBy {
                }
        }
 
+       private static class tsmmColGroupTask implements Callable<Object> {
+               private final List<AColGroup> _groups;
+               private final List<AColGroup> _filteredGroups;
+               private final MatrixBlock _ret;
+               private final int _index;
+
+               protected tsmmColGroupTask(List<AColGroup> groups, 
List<AColGroup> filteredGroups, MatrixBlock ret, int i) {
+                       _groups = groups;
+                       _filteredGroups = filteredGroups;
+                       _ret = ret;
+                       _index = i;
+               }
+
+               @Override
+               public MatrixBlock call() {
+                       try {
+                               tsmmColGroupsIndexI(_groups, _filteredGroups, 
_ret, _index);
+                       }
+                       catch(Exception e) {
+                               e.printStackTrace();
+                               throw new DMLRuntimeException(e);
+                       }
+                       return _ret;
+               }
+       }
+
        private static void leftMultByMatrixPrimitive(List<AColGroup> 
colGroups, MatrixBlock that, MatrixBlock ret, int rl,
                int ru, double[] rowSums) {
                if(that.isInSparseFormat())
@@ -435,7 +525,7 @@ public class CLALibLeftMultBy {
                        }
                        if(rowSum != null) {
                                final SparseBlock sb = that.getSparseBlock();
-                               if(!sb.isEmpty(i)){
+                               if(!sb.isEmpty(i)) {
                                        final int apos = sb.pos(i);
                                        final int alen = sb.size(i) + apos;
                                        final double[] aval = sb.values(i);
@@ -538,4 +628,33 @@ public class CLALibLeftMultBy {
                Collections.sort(ColGroupValues, 
Comparator.comparing(AColGroup::getNumValues).reversed());
                return ColGroupValues;
        }
+
+       private static boolean containsSDC(List<AColGroup> groups) {
+               boolean containsSDC = false;
+
+               for(AColGroup g : groups) {
+                       if(g instanceof ColGroupSDC || g instanceof 
ColGroupSDCSingle) {
+                               containsSDC = true;
+                               break;
+                       }
+               }
+               return containsSDC;
+       }
+
+       private static List<AColGroup> filterSDCGroups(List<AColGroup> groups, 
double[] constV) {
+               if(constV != null) {
+                       final List<AColGroup> filteredGroups = new 
ArrayList<>();
+                       for(AColGroup g : groups) {
+                               if(g instanceof ColGroupSDC)
+                                       filteredGroups.add(((ColGroupSDC) 
g).extractCommon(constV));
+                               else if(g instanceof ColGroupSDCSingle)
+                                       filteredGroups.add(((ColGroupSDCSingle) 
g).extractCommon(constV));
+                               else
+                                       filteredGroups.add(g);
+                       }
+                       return filteredGroups;
+               }
+               else
+                       return groups;
+       }
 }

Reply via email to