This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cde89c5879 [SYSTEMDS-3552] Fix initialize/merge of DCSR sparse blocks
cde89c5879 is described below
commit cde89c587943fe12ba2cb685db033aee6c52e6a3
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 29 17:38:54 2024 +0100
[SYSTEMDS-3552] Fix initialize/merge of DCSR sparse blocks
This patch fixes two bugs in the new DCSR sparse block representation:
* The initialization from MCSR incorrectly indexed the compressed
row-pointer array by row index instead of by position among the
non-empty rows, which only works if all rows have at least one non-zero.
* The sparse block merge incorrectly added zeros into the column-index
and value arrays because it used the array length of temporarily created
sparse rows (i.e., their min capacity) instead of their actual length.
---
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 42 +++++++++-------------
.../test/component/sparse/SparseBlockMerge.java | 8 +++--
2 files changed, 23 insertions(+), 27 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 33c1d3582f..3029370d63 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -31,7 +31,6 @@ import java.util.Iterator;
import static java.util.stream.IntStream.range;
-// TODO: handling of completely empty matrices? This will cause errors right
now
public class SparseBlockDCSR extends SparseBlock
{
private static final long serialVersionUID = 456844244252549431L;
@@ -121,17 +120,15 @@ public class SparseBlockDCSR extends SparseBlock
_nnzr = _rowidx.length;
_size = (int)ocsr.size();
- int pos = 0;
-
- for (int rowptr = 0; rowptr < _rowidx.length; rowptr++)
{
- pos += sblock.size(_rowidx[rowptr]);
- _rowptr[rowptr+1] = pos;
+ int vpos = 0;
+ for (int i = 0; i < _rowidx.length; i++) {
+ vpos += sblock.size(_rowidx[i]);
+ _rowptr[i+1] = vpos;
}
}
//general case SparseBlock
else {
int rlen = sblock.numRows();
-
_rowidx = range(0, rlen).filter(rowIdx ->
!sblock.isEmpty(rowIdx)).toArray();
_rowptr = new int[_rowidx.length + 1];
_colidx = new int[(int)size];
@@ -140,17 +137,16 @@ public class SparseBlockDCSR extends SparseBlock
_nnzr = _rowidx.length;
_size = (int)size;
- int pos = 0;
-
+ int vpos = 0, rpos = 1;
for ( int rowIdx : _rowidx ) {
int apos = sblock.pos(rowIdx);
int alen = sblock.size(rowIdx);
int[] aix = sblock.indexes(rowIdx);
double[] avals = sblock.values(rowIdx);
- System.arraycopy(aix, apos, _colidx, pos, alen);
- System.arraycopy(avals, apos, _values, pos,
alen);
- pos += alen;
- _rowptr[rowIdx+1] = pos;
+ System.arraycopy(aix, apos, _colidx, vpos,
alen);
+ System.arraycopy(avals, apos, _values, vpos,
alen);
+ vpos += alen;
+ _rowptr[rpos++] = vpos;
}
}
}
@@ -268,7 +264,6 @@ public class SparseBlockDCSR extends SparseBlock
@Override
public int size(int r) {
int idx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
-
if (idx < 0)
return 0;
@@ -347,14 +342,14 @@ public class SparseBlockDCSR extends SparseBlock
boolean rowExists = rowIndex >= 0;
if (!rowExists) {
- // Nothing to do
- if (v == 0)
+ if (v == 0) // Nothing to do
return false;
int rowInsertionIndex = -rowIndex - 1;
- insertRow(rowInsertionIndex, r,
_rowptr[rowInsertionIndex]);
+ int tmp = _rowptr[rowInsertionIndex];
+ insertRow(rowInsertionIndex, r, tmp);
incrRowPtr(rowInsertionIndex+1);
- insertCol(_rowptr[rowInsertionIndex], c, v);
+ insertCol(tmp, c, v);
return true;
}
@@ -416,9 +411,10 @@ public class SparseBlockDCSR extends SparseBlock
return;
int rowInsertionIndex = -rowIndex - 1;
- insertRow(rowInsertionIndex, r,
_rowptr[rowInsertionIndex]);
+ int tmp = _rowptr[rowInsertionIndex];
+ insertRow(rowInsertionIndex, r, tmp);
incrRowPtr(rowInsertionIndex+1, newRowSize);
- insertCols(_rowptr[rowInsertionIndex], row.indexes(),
row.values());
+ insertCols(tmp, row.indexes(), row.values(), 0, 0,
newRowSize);
return;
}
@@ -601,7 +597,7 @@ public class SparseBlockDCSR extends SparseBlock
return new SparseRowScalar();
int pos = pos(r);
int len = size(r);
-
+
SparseRowVector row = new SparseRowVector(len);
System.arraycopy(_colidx, pos, row.indexes(), 0, len);
System.arraycopy(_values, pos, row.values(), 0, len);
@@ -882,10 +878,6 @@ public class SparseBlockDCSR extends SparseBlock
insertCols(ix, new int[0], new double[0], len, 0, 0);
}
- private void insertCols(int ix, int[] cols, double[] vals) {
- insertCols(ix, cols, vals, 0, 0, vals.length);
- }
-
private void insertCols(int ix, int[] cols, double[] vals, int
overwriteNum) {
insertCols(ix, cols, vals, overwriteNum, 0, vals.length);
}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
index ce2211bc6e..544569aef2 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.sparse;
import org.junit.Assert;
import org.junit.Test;
+
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -406,8 +407,11 @@ public class SparseBlockMerge extends AutomatedTestBase
int[] aix = sblock.indexes(i);
double[] avals = sblock.values(i);
for( int j=0; j<alen; j++ ) {
- if( avals[apos+j] !=
A[i][aix[apos+j]] )
- Assert.fail("Wrong
value returned by scan: "+avals[apos+j]+", expected: "+A[i][apos+aix[j]]);
+ if( avals[apos+j] !=
A[i][aix[apos+j]] ) {
+
System.out.println("Issue at row "+i);
+ Assert.fail("Wrong
value returned by scan: "
+ + avals[apos+j]
+", expected: "+ A[i][aix[apos+j]]);
+ }
count++;
}
}