This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 246eea9784 [MINOR] Uncompressed ColGroup Outer TSMM
246eea9784 is described below
commit 246eea9784aa3b34c9eefdaee4666708b5a7db95
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sat Dec 30 14:10:02 2023 +0100
[MINOR] Uncompressed ColGroup Outer TSMM
Add support for a sparse intermediate result in the outer TSMM of uncompressed column groups.
This case was not handled in 1c26e2d299ace9f0b3b4974c9d8bac665fd9692e.
Closes #1968
---
.../compress/colgroup/ColGroupUncompressed.java | 35 +++++++++++++++++-----
.../component/compress/colgroup/ColGroupTest.java | 21 ++++++++-----
2 files changed, 41 insertions(+), 15 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index c4713d6e59..d5553deb41 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -532,14 +532,33 @@ public class ColGroupUncompressed extends AColGroup {
// tsmm but only upper triangle.
LibMatrixMult.matrixMultTransposeSelf(_data, tmp, true, false);
- // copy that upper triangle part to ret
- final int numColumns = ret.getNumColumns();
- final double[] result = ret.getDenseBlockValues();
- final double[] tmpV = tmp.getDenseBlockValues();
- for(int row = 0, offTmp = 0; row < tCol; row++, offTmp += tCol)
{
- final int offRet = _colIndexes.get(row) * numColumns;
- for(int col = row; col < tCol; col++)
- result[offRet + _colIndexes.get(col)] +=
tmpV[offTmp + col];
+ if(tmp.isInSparseFormat()){
+ final int numColumns = ret.getNumColumns();
+ final double[] result = ret.getDenseBlockValues();
+ final SparseBlock sb = tmp.getSparseBlock();
+ for(int row = 0; row < tCol; row++) {
+ final int offRet = _colIndexes.get(row) *
numColumns;
+ if(sb.isEmpty(row))
+ continue;
+ int apos = sb.pos(row);
+ int alen = sb.size(row) + apos;
+ int[] aix = sb.indexes(row);
+ double[] aval = sb.values(row);
+ for(int j = apos; j < alen; j++)
+ result[offRet +
_colIndexes.get(aix[j])] += aval[j];
+
+ }
+ }
+ else{
+ // copy that upper triangle part to ret
+ final int numColumns = ret.getNumColumns();
+ final double[] result = ret.getDenseBlockValues();
+ final double[] tmpV = tmp.getDenseBlockValues();
+ for(int row = 0, offTmp = 0; row < tCol; row++, offTmp
+= tCol) {
+ final int offRet = _colIndexes.get(row) *
numColumns;
+ for(int col = row; col < tCol; col++)
+ result[offRet + _colIndexes.get(col)]
+= tmpV[offTmp + col];
+ }
}
}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
index 14f4a56c18..54a543ad13 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
@@ -1118,13 +1118,20 @@ public class ColGroupTest extends ColGroupBase {
@Test
public void tsmm() {
- final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, false);
- final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, false);
- ot.allocateDenseBlock();
- bt.allocateDenseBlock();
- base.tsmm(bt, nRow);
- other.tsmm(ot, nRow);
- compare(ot, bt);
+ try{
+
+ final MatrixBlock bt = new MatrixBlock(maxCol, maxCol,
false);
+ final MatrixBlock ot = new MatrixBlock(maxCol, maxCol,
false);
+ ot.allocateDenseBlock();
+ bt.allocateDenseBlock();
+ base.tsmm(bt, nRow);
+ other.tsmm(ot, nRow);
+ compare(ot, bt);
+ }
+ catch(Exception e){
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
}
@Test