This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 246eea9784 [MINOR] Uncompressed ColGroup Outer TSMM
246eea9784 is described below

commit 246eea9784aa3b34c9eefdaee4666708b5a7db95
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sat Dec 30 14:10:02 2023 +0100

    [MINOR] Uncompressed ColGroup Outer TSMM
    
    Add support for sparse outer TSMM results in uncompressed column groups.
    This case was not handled in commit 1c26e2d299ace9f0b3b4974c9d8bac665fd9692e.
    
    Closes #1968
---
 .../compress/colgroup/ColGroupUncompressed.java    | 35 +++++++++++++++++-----
 .../component/compress/colgroup/ColGroupTest.java  | 21 ++++++++-----
 2 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index c4713d6e59..d5553deb41 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -532,14 +532,33 @@ public class ColGroupUncompressed extends AColGroup {
                // tsmm but only upper triangle.
                LibMatrixMult.matrixMultTransposeSelf(_data, tmp, true, false);
 
-               // copy that upper triangle part to ret
-               final int numColumns = ret.getNumColumns();
-               final double[] result = ret.getDenseBlockValues();
-               final double[] tmpV = tmp.getDenseBlockValues();
-               for(int row = 0, offTmp = 0; row < tCol; row++, offTmp += tCol) 
{
-                       final int offRet = _colIndexes.get(row) * numColumns;
-                       for(int col = row; col < tCol; col++)
-                               result[offRet + _colIndexes.get(col)] += 
tmpV[offTmp + col];
+               if(tmp.isInSparseFormat()){
+                       final int numColumns = ret.getNumColumns();
+                       final double[] result = ret.getDenseBlockValues();
+                       final SparseBlock sb = tmp.getSparseBlock();
+                       for(int row = 0; row < tCol; row++) {
+                               final int offRet = _colIndexes.get(row) * 
numColumns;
+                               if(sb.isEmpty(row))
+                                       continue;
+                               int apos = sb.pos(row);
+                               int alen = sb.size(row) + apos;
+                               int[] aix = sb.indexes(row);
+                               double[] aval = sb.values(row);
+                               for(int j = apos; j < alen; j++)
+                                       result[offRet + 
_colIndexes.get(aix[j])] += aval[j];
+                               
+                       }
+               }
+               else{
+                       // copy that upper triangle part to ret
+                       final int numColumns = ret.getNumColumns();
+                       final double[] result = ret.getDenseBlockValues();
+                       final double[] tmpV = tmp.getDenseBlockValues();
+                       for(int row = 0, offTmp = 0; row < tCol; row++, offTmp 
+= tCol) {
+                               final int offRet = _colIndexes.get(row) * 
numColumns;
+                               for(int col = row; col < tCol; col++)
+                                       result[offRet + _colIndexes.get(col)] 
+= tmpV[offTmp + col];
+                       }
                }
        }
 
diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
index 14f4a56c18..54a543ad13 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
@@ -1118,13 +1118,20 @@ public class ColGroupTest extends ColGroupBase {
 
        @Test
        public void tsmm() {
-               final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, false);
-               final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, false);
-               ot.allocateDenseBlock();
-               bt.allocateDenseBlock();
-               base.tsmm(bt, nRow);
-               other.tsmm(ot, nRow);
-               compare(ot, bt);
+               try{
+
+                       final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, 
false);
+                       final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, 
false);
+                       ot.allocateDenseBlock();
+                       bt.allocateDenseBlock();
+                       base.tsmm(bt, nRow);
+                       other.tsmm(ot, nRow);
+                       compare(ot, bt);
+               }
+               catch(Exception e){
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
        }
 
        @Test

Reply via email to