This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a6d8bc0975 [MINOR] minor cleanups and optimizations to CLA MM primitives
a6d8bc0975 is described below

commit a6d8bc09752043f34f9593590f77112c8abfd449
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Feb 3 18:17:07 2025 +0100

    [MINOR] minor cleanups and optimizations to CLA MM primitives
    
    This commit includes a specialized decompressing MM for DDC column groups with identity matrix dictionaries.
    
    Closes #2210
---
 .../runtime/compress/colgroup/ColGroupDDC.java     |  26 ++-
 .../sysds/runtime/compress/lib/CLALibMMChain.java  |  59 +++++-
 .../runtime/compress/lib/CLALibRightMultBy.java    | 211 +++++++++++++--------
 .../sysds/runtime/compress/lib/CLALibTSMM.java     |  48 ++---
 4 files changed, 230 insertions(+), 114 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 86ebb4400e..c1b9c65f22 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -601,6 +601,30 @@ public class ColGroupDDC extends APreAgg implements 
IMapToDataGroup {
 
        @Override
        public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, 
int rl, int ru, int nRows, int crl, int cru) {
+               if(_dict instanceof IdentityDictionary)
+                       identityRightDecompressingMult(right, ret, rl, ru, crl, 
cru);
+               else
+                       defaultRightDecompressingMult(right, ret, rl, ru, crl, 
cru);
+       }
+
+       private void identityRightDecompressingMult(MatrixBlock right, 
MatrixBlock ret, int rl, int ru, int crl, int cru) {
+               final double[] b = right.getDenseBlockValues();
+               final double[] c = ret.getDenseBlockValues();
+               final int jd = right.getNumColumns();
+               final int vLen = 8;
+               final int lenJ = cru - crl;
+               final int end = cru - (lenJ % vLen);
+               for(int i = rl; i < ru; i++) {
+                       int k = _data.getIndex(i);
+                       final int offOut = i * jd + crl;
+                       final double aa = 1;
+                       final int k_right = _colIndexes.get(k);
+                       vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, 
vLen);
+
+               }
+       }
+
+       private void defaultRightDecompressingMult(MatrixBlock right, 
MatrixBlock ret, int rl, int ru, int crl, int cru) {
                final double[] a = _dict.getValues();
                final double[] b = right.getDenseBlockValues();
                final double[] c = ret.getDenseBlockValues();
@@ -930,8 +954,6 @@ public class ColGroupDDC extends APreAgg implements 
IMapToDataGroup {
                }
        }
 
-
-
        private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, 
int pos, double[] values2, int pos2, int cl,
                int cu) {
                IdentityDictionary a = (IdentityDictionary) _dict;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
index 6207460d3d..d82d58e323 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
@@ -34,6 +34,7 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.utils.stats.Timing;
 
 /**
  * Support compressed MM chain operation to fuse the following cases :
@@ -53,6 +54,9 @@ import 
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 public final class CLALibMMChain {
        static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
 
+       /** Reusable cache intermediate double array for temporary 
decompression */
+       private static ThreadLocal<double[]> cacheIntermediate = null;
+
        private CLALibMMChain() {
                // private constructor
        }
@@ -87,20 +91,31 @@ public final class CLALibMMChain {
        public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock 
v, MatrixBlock w, MatrixBlock out,
                ChainType ctype, int k) {
 
+               Timing t = new Timing();
                if(x.isEmpty())
                        return returnEmpty(x, out);
 
                // Morph the columns to efficient types for the operation.
                x = filterColGroups(x);
+               double preFilterTime = t.stop();
 
                // Allow overlapping intermediate if the intermediate is 
guaranteed not to be overlapping.
                final boolean allowOverlap = x.getColGroups().size() == 1 && 
isOverlappingAllowed();
 
                // Right hand side multiplication
-               MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, 
null, k, allowOverlap);
+               MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, 
null, k, true);
+
+               double rmmTime = t.stop();
 
-               if(ctype == ChainType.XtwXv) // Multiply intermediate with 
vector if needed
+               if(ctype == ChainType.XtwXv) { // Multiply intermediate with 
vector if needed
                        tmp = binaryMultW(tmp, w, k);
+               }
+
+               if(!allowOverlap && tmp instanceof CompressedMatrixBlock) {
+                       tmp = decompressIntermediate((CompressedMatrixBlock) 
tmp, k);
+               }
+
+               double decompressTime = t.stop();
 
                if(tmp instanceof CompressedMatrixBlock)
                        // Compressed Compressed Matrix Multiplication
@@ -109,12 +124,50 @@ public final class CLALibMMChain {
                        // LMM with Compressed - uncompressed multiplication.
                        CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, 
out, k);
 
+               double lmmTime = t.stop();
                if(out.getNumColumns() != 1) // transpose the output to make it 
a row output if needed
                        out = LibMatrixReorg.transposeInPlace(out, k);
 
+               if(LOG.isDebugEnabled()) {
+                       StringBuilder sb = new StringBuilder("\n");
+                       sb.append("\nPreFilter Time      : " + preFilterTime);
+                       sb.append("\nChain RMM           : " + rmmTime);
+                       sb.append("\nChain RMM Decompress: " + decompressTime);
+                       sb.append("\nChain LMM           : " + lmmTime);
+                       sb.append("\nChain Transpose     : " + t.stop());
+                       LOG.debug(sb.toString());
+               }
+
                return out;
        }
 
+       private static MatrixBlock decompressIntermediate(CompressedMatrixBlock 
tmp, int k) {
+               // cacheIntermediate
+               final int rows = tmp.getNumRows();
+               final int cols = tmp.getNumColumns();
+               final int nCells = rows * cols;
+               final double[] tmpArr;
+               if(cacheIntermediate == null) {
+                       tmpArr = new double[nCells];
+                       cacheIntermediate = new ThreadLocal<>();
+                       cacheIntermediate.set(tmpArr);
+               }
+               else {
+                       double[] cachedArr = cacheIntermediate.get();
+                       if(cachedArr == null || cachedArr.length < nCells) {
+                               tmpArr = new double[nCells];
+                               cacheIntermediate.set(tmpArr);
+                       }
+                       else {
+                               tmpArr = cachedArr;
+                       }
+               }
+
+               final MatrixBlock tmpV = new MatrixBlock(tmp.getNumRows(), 
tmp.getNumColumns(), tmpArr);
+               CLALibDecompress.decompressTo((CompressedMatrixBlock) tmp, 
tmpV, 0, 0, k, false, true);
+               return tmpV;
+       }
+
        private static boolean isOverlappingAllowed() {
                return 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
        }
@@ -146,6 +199,8 @@ public final class CLALibMMChain {
                final List<AColGroup> groups = x.getColGroups();
                final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(groups);
                if(shouldFilter) {
+                       if(CLALibUtils.alreadyPreFiltered(groups, 
x.getNumColumns()))
+                               return x;
                        final int nCol = x.getNumColumns();
                        final double[] constV = new double[nCol];
                        final List<AColGroup> filteredGroups = 
CLALibUtils.filterGroups(groups, constV);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 966051cd8b..f14d6833d9 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.compress.lib;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 
@@ -30,23 +29,20 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
-import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
 import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
-import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
 public final class CLALibRightMultBy {
        private static final Log LOG = 
LogFactory.getLog(CLALibRightMultBy.class.getName());
 
-       private CLALibRightMultBy(){
+       private CLALibRightMultBy() {
                // private constructor
        }
 
@@ -59,42 +55,104 @@ public final class CLALibRightMultBy {
        public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, int k,
                boolean allowOverlap) {
 
-               final int rr = m1.getNumRows();
-               final int rc = m2.getNumColumns();
+               try {
+                       final int rr = m1.getNumRows();
+                       final int rc = m2.getNumColumns();
 
-               if(m1.isEmpty() || m2.isEmpty()) {
-                       LOG.trace("Empty right multiply");
-                       if(ret == null)
-                               ret = new MatrixBlock(rr, rc, 0);
-                       else
-                               ret.reset(rr, rc, 0);
-                       return ret;
+                       if(m1.isEmpty() || m2.isEmpty()) {
+                               LOG.trace("Empty right multiply");
+                               if(ret == null)
+                                       ret = new MatrixBlock(rr, rc, 0);
+                               else
+                                       ret.reset(rr, rc, 0);
+                               return ret;
+                       }
+                       else {
+                               if(m2 instanceof CompressedMatrixBlock)
+                                       m2 = ((CompressedMatrixBlock) 
m2).getUncompressed("Uncompressed right side of right MM", k);
+
+                               if(betterIfDecompressed(m1)) {
+                                       // perform uncompressed multiplication.
+                                       return decompressingMatrixMult(m1, m2, 
k);
+                               }
+
+                               if(!allowOverlap) {
+                                       LOG.trace("Overlapping output not 
allowed in call to Right MM");
+                                       return RMM(m1, m2, k);
+                               }
+
+                               final CompressedMatrixBlock retC = 
RMMOverlapping(m1, m2, k);
+
+                               if(retC.isEmpty())
+                                       return retC;
+                               else {
+                                       if(retC.isOverlapping())
+                                               retC.setNonZeros((long) rr * 
rc); // set non zeros to fully dense in case of overlapping.
+                                       else
+                                               retC.recomputeNonZeros(k); // 
recompute if non overlapping compressed out.
+                                       return retC;
+                               }
+                       }
                }
-               else {
-                       if(m2 instanceof CompressedMatrixBlock)
-                               m2 = ((CompressedMatrixBlock) 
m2).getUncompressed("Uncompressed right side of right MM", k);
+               catch(Exception e) {
+                       throw new RuntimeException("Failed Right MM", e);
+               }
+       }
 
-                       if(!allowOverlap) {
-                               LOG.trace("Overlapping output not allowed in 
call to Right MM");
-                               return RMM(m1, m2, k);
+       private static MatrixBlock 
decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k)
+               throws Exception {
+               final ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       final int rl = m1.getNumRows();
+                       final int cr = m2.getNumColumns();
+                       // final int rr = m2.getNumRows(); // shared dim
+                       final MatrixBlock ret = new MatrixBlock(rl, cr, false);
+                       ret.allocateBlock();
+
+                       // MatrixBlock m1uc = m1.decompress(k);
+                       final List<Future<Long>> tasks = new ArrayList<>();
+                       final List<AColGroup> groups = m1.getColGroups();
+                       final int blkI = Math.max((int) Math.ceil((double) rl / 
k), 16);
+                       final int blkJ = blkI > 16 ? cr : Math.max((cr / k), 
512); // make it a multiplicative of 8.
+                       for(int i = 0; i < rl; i += blkI) {
+                               final int startI = i;
+                               final int endI = Math.min(i + blkI, rl);
+                               for(int j = 0; j < cr; j += blkJ) {
+                                       final int startJ = j;
+                                       final int endJ = Math.min(j + blkJ, cr);
+                                       tasks.add(pool.submit(() -> {
+                                               for(AColGroup g : groups)
+                                                       
g.rightDecompressingMult(m2, ret, startI, endI, rl, startJ, endJ);
+                                               return 
ret.recomputeNonZeros(startI, endI - 1, startJ, endJ - 1);
+                                       }));
+                               }
                        }
+                       long nnz = 0;
+                       for(Future<Long> t : tasks)
+                               nnz += t.get();
 
-                       final CompressedMatrixBlock retC = RMMOverlapping(m1, 
m2, k);
+                       ret.setNonZeros(nnz);
+                       ret.examSparsity();
+                       return ret;
+               }
+               finally {
+                       pool.shutdown();
+               }
 
-                       if(retC.isEmpty())
-                               return retC;
-                       else {
-                               if(retC.isOverlapping())
-                                       retC.setNonZeros((long) rr * rc); // 
set non zeros to fully dense in case of overlapping.
-                               else
-                                       retC.recomputeNonZeros(); // recompute 
if non overlapping compressed out.
-                               return retC;
+       }
+
+       private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
+               for(AColGroup g : m.getColGroups()) {
+                       if(!(g instanceof ColGroupUncompressed) && 
g.getNumValues() * 2 >= m.getNumRows()) {
+                               return true;
                        }
                }
-
+               return false;
        }
 
-       private static CompressedMatrixBlock 
RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
+       private static CompressedMatrixBlock 
RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k)
+               throws Exception {
+
                final int rl = m1.getNumRows();
                final int cr = that.getNumColumns();
                final int rr = that.getNumRows(); // shared dim
@@ -103,13 +161,19 @@ public final class CLALibRightMultBy {
                final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, 
cr);
 
                final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(colGroups);
+               final double[] constV;
+               final List<AColGroup> filteredGroups;
 
-               double[] constV = shouldFilter ? new double[rr] : null;
-               final List<AColGroup> filteredGroups = 
CLALibUtils.filterGroups(colGroups, constV);
-               if(colGroups == filteredGroups)
+               if(shouldFilter) {
+                       constV = new double[rr];
+                       filteredGroups = CLALibUtils.filterGroups(colGroups, 
constV);
+               }
+               else {
+                       filteredGroups = colGroups;
                        constV = null;
+               }
 
-               if(k == 1)
+               if(k == 1 || filteredGroups.size() == 1)
                        RMMSingle(filteredGroups, that, retCg);
                else
                        RMMParallel(filteredGroups, that, retCg, k);
@@ -117,7 +181,7 @@ public final class CLALibRightMultBy {
                if(constV != null) {
                        final MatrixBlock cb = new MatrixBlock(1, 
constV.length, constV);
                        final MatrixBlock cbRet = new MatrixBlock(1, 
that.getNumColumns(), false);
-                       LibMatrixMult.matrixMult(cb, that, cbRet);
+                       LibMatrixMult.matrixMult(cb, that, cbRet); // mm on row 
vector left.
                        if(!cbRet.isEmpty())
                                addConstant(cbRet, retCg);
                }
@@ -133,35 +197,18 @@ public final class CLALibRightMultBy {
        }
 
        private static void addConstant(MatrixBlock constantRow, 
List<AColGroup> out) {
-               final int nCol = constantRow.getNumColumns();
-               int bestCandidate = -1;
-               int bestCandidateValuesSize = Integer.MAX_VALUE;
-               for(int i = 0; i < out.size(); i++) {
-                       AColGroup g = out.get(i);
-                       if(g instanceof ColGroupDDC && g.getNumCols() == nCol 
&& g.getNumValues() < bestCandidateValuesSize)
-                               bestCandidate = i;
-               }
-
                constantRow.sparseToDense();
-
-               if(bestCandidate != -1) {
-                       AColGroup bc = out.get(bestCandidate);
-                       out.remove(bestCandidate);
-                       AColGroup ng = bc.binaryRowOpRight(new 
BinaryOperator(Plus.getPlusFnObject(), 1),
-                               constantRow.getDenseBlockValues(), true);
-                       out.add(ng);
-               }
-               else
-                       
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
+               
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
        }
 
-       private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock 
that, int k) {
+       private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock 
that, int k) throws Exception {
+
+               // Timing t = new Timing();
                // this version returns a decompressed result.
                final int rl = m1.getNumRows();
                final int cr = that.getNumColumns();
                final int rr = that.getNumRows(); // shared dim
                final List<AColGroup> colGroups = m1.getColGroups();
-               final List<AColGroup> retCg = new ArrayList<>();
 
                final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(colGroups);
 
@@ -169,11 +216,25 @@ public final class CLALibRightMultBy {
                MatrixBlock ret = new MatrixBlock(rl, cr, false);
                final Future<MatrixBlock> f = ret.allocateBlockAsync();
 
-               double[] constV = shouldFilter ? new double[rr] : null;
-               final List<AColGroup> filteredGroups = 
CLALibUtils.filterGroups(colGroups, constV);
-               if(colGroups == filteredGroups)
+               double[] constV;
+               final List<AColGroup> filteredGroups;
+
+               if(shouldFilter) {
+                       if(CLALibUtils.alreadyPreFiltered(colGroups, cr)) {
+                               filteredGroups = new 
ArrayList<>(colGroups.size() - 1);
+                               constV = 
CLALibUtils.filterGroupsAndSplitPreAggOneConst(colGroups, filteredGroups);
+                       }
+                       else {
+                               constV = new double[rr];
+                               filteredGroups = 
CLALibUtils.filterGroups(colGroups, constV);
+                       }
+               }
+               else {
+                       filteredGroups = colGroups;
                        constV = null;
+               }
 
+               final List<AColGroup> retCg = new 
ArrayList<>(filteredGroups.size());
                if(k == 1)
                        RMMSingle(filteredGroups, that, retCg);
                else
@@ -186,21 +247,12 @@ public final class CLALibRightMultBy {
                        constV = mmTemp.isEmpty() ? null : 
mmTemp.getDenseBlockValues();
                }
 
-               ret = asyncRet(f);
+               ret = f.get();
                CLALibDecompress.decompressDense(ret, retCg, constV, 0, k, 
true);
 
                return ret;
        }
 
-       private static <T> T asyncRet(Future<T> in) {
-               try {
-                       return in.get();
-               }
-               catch(Exception e) {
-                       throw new DMLRuntimeException(e);
-               }
-       }
-
        private static boolean RMMSingle(List<AColGroup> filteredGroups, 
MatrixBlock that, List<AColGroup> retCg) {
                boolean containsNull = false;
                final IColIndex allCols = 
ColIndexFactory.create(that.getNumColumns());
@@ -214,7 +266,8 @@ public final class CLALibRightMultBy {
                return containsNull;
        }
 
-       private static boolean RMMParallel(List<AColGroup> filteredGroups, 
MatrixBlock that, List<AColGroup> retCg, int k) {
+       private static boolean RMMParallel(List<AColGroup> filteredGroups, 
MatrixBlock that, List<AColGroup> retCg, int k)
+               throws Exception {
                final ExecutorService pool = CommonThreadPool.get(k);
                boolean containsNull = false;
                try {
@@ -230,10 +283,7 @@ public final class CLALibRightMultBy {
                                        containsNull = true;
                        }
                }
-               catch(InterruptedException | ExecutionException e) {
-                       throw new DMLRuntimeException(e);
-               }
-               finally{
+               finally {
                        pool.shutdown();
                }
                return containsNull;
@@ -253,13 +303,8 @@ public final class CLALibRightMultBy {
                }
 
                @Override
-               public AColGroup call() {
-                       try {
-                               return _colGroup.rightMultByMatrix(_b, 
_allCols, _k);
-                       }
-                       catch(Exception e) {
-                               throw new DMLRuntimeException(e);
-                       }
+               public AColGroup call() throws Exception {
+                       return _colGroup.rightMultByMatrix(_b, _allCols, _k);
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
index 5f5e63c9ac..a1d47a9b15 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
@@ -52,8 +52,15 @@ public final class CLALibTSMM {
         * @param k   The parallelization degree allowed
         */
        public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, 
MatrixBlock ret, int k) {
+
                final List<AColGroup> groups = cmb.getColGroups();
+
                final int numColumns = cmb.getNumColumns();
+               if(groups.size() >= numColumns) {
+                       MatrixBlock m = cmb.getUncompressed("TSMM to many 
columngroups", k);
+                       LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
+                       return;
+               }
                final int numRows = cmb.getNumRows();
                final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(groups);
                final boolean overlapping = cmb.isOverlapping();
@@ -63,8 +70,10 @@ public final class CLALibTSMM {
                        tsmmColGroups(filteredGroups, ret, numRows, 
overlapping, k);
                        addCorrectionLayer(filteredGroups, ret, numRows, 
numColumns, constV);
                }
-               else
+               else {
+
                        tsmmColGroups(groups, ret, numRows, overlapping, k);
+               }
 
                ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
                ret.examSparsity();
@@ -77,10 +86,7 @@ public final class CLALibTSMM {
                addCorrectionLayer(constV, filteredColSum, nRows, retV);
        }
 
-       public static void addCorrectionLayer(double[] constV, double[] 
correctedSum, int nRow, double[] ret) {
-               outerProductUpperTriangle(constV, correctedSum, ret);
-               outerProductUpperTriangleWithScaling(correctedSum, constV, 
nRow, ret);
-       }
+
 
        private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock 
ret, int nRows, boolean overlapping, int k) {
                if(k <= 1)
@@ -108,7 +114,7 @@ public final class CLALibTSMM {
        }
 
        private static void tsmmColGroupsMultiThread(List<AColGroup> groups, 
MatrixBlock ret, int nRows, int k) {
-               final ExecutorService pool = CommonThreadPool.get(k);           
+               final ExecutorService pool = CommonThreadPool.get(k);
                try {
                        final ArrayList<Callable<MatrixBlock>> tasks = new 
ArrayList<>((groups.size() * (1 + groups.size())) / 2);
                        for(int i = 0; i < groups.size(); i++) {
@@ -123,31 +129,19 @@ public final class CLALibTSMM {
                catch(InterruptedException | ExecutionException e) {
                        throw new DMLRuntimeException(e);
                }
-               finally{
+               finally {
                        pool.shutdown();
                }
        }
 
-       private static void outerProductUpperTriangle(final double[] 
leftRowSum, final double[] rightColumnSum,
-               final double[] result) {
-               for(int row = 0; row < leftRowSum.length; row++) {
-                       final int offOut = rightColumnSum.length * row;
-                       final double vLeft = leftRowSum[row];
-                       for(int col = row; col < rightColumnSum.length; col++) {
-                               result[offOut + col] += vLeft * 
rightColumnSum[col];
-                       }
-               }
-       }
-
-       private static void outerProductUpperTriangleWithScaling(final double[] 
leftRowSum, final double[] rightColumnSum,
-               final int scale, final double[] result) {
-               // note this scaling is a bit different since it is 
encapsulating two scalar multiplications via an addition in
-               // the outer loop.
-               for(int row = 0; row < leftRowSum.length; row++) {
-                       final int offOut = rightColumnSum.length * row;
-                       final double vLeft = leftRowSum[row] + 
rightColumnSum[row] * scale;
-                       for(int col = row; col < rightColumnSum.length; col++) {
-                               result[offOut + col] += vLeft * 
rightColumnSum[col];
+       public static void addCorrectionLayer(double[] constV, double[] 
filteredColSum, int nRow, double[] ret) {
+               final int nColRow = constV.length;
+               for(int row = 0; row < nColRow; row++){
+                       int offOut = nColRow * row;
+                       final double v1l = constV[row];
+                       final double v2l = filteredColSum[row] + constV[row] * 
nRow;
+                       for(int col = row; col < nColRow; col++){
+                               ret[offOut + col] += v1l * filteredColSum[col]  
+ v2l * constV[col];
                        }
                }
        }

Reply via email to