This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cde89c5879 [SYSTEMDS-3552] Fix initialize/merge of DCSR sparse blocks
cde89c5879 is described below

commit cde89c587943fe12ba2cb685db033aee6c52e6a3
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 29 17:38:54 2024 +0100

    [SYSTEMDS-3552] Fix initialize/merge of DCSR sparse blocks
    
    This patch fixes two bugs in the new DCSR sparse block representation:
    * The initialization from MCSR incorrectly indexed the compressed
      row-pointer arrays by row indexes, which only works if all rows have
      at least one non-zero.
    * The sparse block merge incorrectly added zeros into the column-index
      and value arrays because it used the size of temporarily created
      sparse rows (their minimum capacity) instead of their actual length.
---
 .../apache/sysds/runtime/data/SparseBlockDCSR.java | 42 +++++++++-------------
 .../test/component/sparse/SparseBlockMerge.java    |  8 +++--
 2 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 33c1d3582f..3029370d63 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -31,7 +31,6 @@ import java.util.Iterator;
 
 import static java.util.stream.IntStream.range;
 
-// TODO: handling of completely empty matrices? This will cause errors right 
now
 public class SparseBlockDCSR extends SparseBlock
 {
        private static final long serialVersionUID = 456844244252549431L;
@@ -121,17 +120,15 @@ public class SparseBlockDCSR extends SparseBlock
                        _nnzr = _rowidx.length;
                        _size = (int)ocsr.size();
 
-                       int pos = 0;
-
-                       for (int rowptr = 0; rowptr < _rowidx.length; rowptr++) 
{
-                               pos += sblock.size(_rowidx[rowptr]);
-                               _rowptr[rowptr+1] = pos;
+                       int vpos = 0;
+                       for (int i = 0; i < _rowidx.length; i++) {
+                               vpos += sblock.size(_rowidx[i]);
+                               _rowptr[i+1] = vpos;
                        }
                }
                //general case SparseBlock
                else {
                        int rlen = sblock.numRows();
-
                        _rowidx = range(0, rlen).filter(rowIdx -> 
!sblock.isEmpty(rowIdx)).toArray();
                        _rowptr = new int[_rowidx.length + 1];
                        _colidx = new int[(int)size];
@@ -140,17 +137,16 @@ public class SparseBlockDCSR extends SparseBlock
                        _nnzr = _rowidx.length;
                        _size = (int)size;
 
-                       int pos = 0;
-
+                       int vpos = 0, rpos = 1;
                        for ( int rowIdx : _rowidx ) {
                                int apos = sblock.pos(rowIdx);
                                int alen = sblock.size(rowIdx);
                                int[] aix = sblock.indexes(rowIdx);
                                double[] avals = sblock.values(rowIdx);
-                               System.arraycopy(aix, apos, _colidx, pos, alen);
-                               System.arraycopy(avals, apos, _values, pos, 
alen);
-                               pos += alen;
-                               _rowptr[rowIdx+1] = pos;
+                               System.arraycopy(aix, apos, _colidx, vpos, 
alen);
+                               System.arraycopy(avals, apos, _values, vpos, 
alen);
+                               vpos += alen;
+                               _rowptr[rpos++] = vpos;
                        }
                }
        }
@@ -268,7 +264,6 @@ public class SparseBlockDCSR extends SparseBlock
        @Override
        public int size(int r) {
                int idx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
-
                if (idx < 0)
                        return 0;
 
@@ -347,14 +342,14 @@ public class SparseBlockDCSR extends SparseBlock
                boolean rowExists = rowIndex >= 0;
 
                if (!rowExists) {
-                       // Nothing to do
-                       if (v == 0)
+                       if (v == 0) // Nothing to do
                                return false;
 
                        int rowInsertionIndex = -rowIndex - 1;
-                       insertRow(rowInsertionIndex, r, 
_rowptr[rowInsertionIndex]);
+                       int tmp = _rowptr[rowInsertionIndex];
+                       insertRow(rowInsertionIndex, r, tmp);
                        incrRowPtr(rowInsertionIndex+1);
-                       insertCol(_rowptr[rowInsertionIndex], c, v);
+                       insertCol(tmp, c, v);
                        return true;
                }
 
@@ -416,9 +411,10 @@ public class SparseBlockDCSR extends SparseBlock
                                return;
 
                        int rowInsertionIndex = -rowIndex - 1;
-                       insertRow(rowInsertionIndex, r, 
_rowptr[rowInsertionIndex]);
+                       int tmp = _rowptr[rowInsertionIndex];
+                       insertRow(rowInsertionIndex, r, tmp);
                        incrRowPtr(rowInsertionIndex+1, newRowSize);
-                       insertCols(_rowptr[rowInsertionIndex], row.indexes(), 
row.values());
+                       insertCols(tmp, row.indexes(), row.values(), 0, 0, 
newRowSize);
                        return;
                }
 
@@ -601,7 +597,7 @@ public class SparseBlockDCSR extends SparseBlock
                        return new SparseRowScalar();
                int pos = pos(r);
                int len = size(r);
-
+               
                SparseRowVector row = new SparseRowVector(len);
                System.arraycopy(_colidx, pos, row.indexes(), 0, len);
                System.arraycopy(_values, pos, row.values(), 0, len);
@@ -882,10 +878,6 @@ public class SparseBlockDCSR extends SparseBlock
                insertCols(ix, new int[0], new double[0], len, 0, 0);
        }
 
-       private void insertCols(int ix, int[] cols, double[] vals) {
-               insertCols(ix, cols, vals, 0, 0, vals.length);
-       }
-
        private void insertCols(int ix, int[] cols, double[] vals, int 
overwriteNum) {
                insertCols(ix, cols, vals, overwriteNum, 0, vals.length);
        }
diff --git 
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java 
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
index ce2211bc6e..544569aef2 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.sparse;
 
 import org.junit.Assert;
 import org.junit.Test;
+
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -406,8 +407,11 @@ public class SparseBlockMerge extends AutomatedTestBase
                                        int[] aix = sblock.indexes(i);
                                        double[] avals = sblock.values(i);
                                        for( int j=0; j<alen; j++ ) {
-                                               if( avals[apos+j] != 
A[i][aix[apos+j]] )
-                                                       Assert.fail("Wrong 
value returned by scan: "+avals[apos+j]+", expected: "+A[i][apos+aix[j]]);
+                                               if( avals[apos+j] != 
A[i][aix[apos+j]] ) {
+                                                       
System.out.println("Issue at row "+i);
+                                                       Assert.fail("Wrong 
value returned by scan: "
+                                                               + avals[apos+j] 
+", expected: "+ A[i][aix[apos+j]]);
+                                               }
                                                count++;
                                        }
                                }

Reply via email to