This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 41ce5231ce [SYSTEMDS-3683] Improved sparse block non-empty row
iterators
41ce5231ce is described below
commit 41ce5231ce03897297d92734f859872bf4cfa713
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Mar 25 18:13:07 2024 +0100
[SYSTEMDS-3683] Improved sparse block non-empty row iterators
1) Specialized iterators for DCSR, CSR, MCSR, and COO
2) Generic iterable to be used in enhanced for loops
---
.../org/apache/sysds/runtime/data/SparseBlock.java | 112 ++++-----------------
.../apache/sysds/runtime/data/SparseBlockCOO.java | 71 ++++++-------
.../apache/sysds/runtime/data/SparseBlockCSR.java | 59 ++++++-----
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 66 ++++++------
.../apache/sysds/runtime/data/SparseBlockMCSR.java | 59 ++++++-----
.../test/component/sparse/SparseBlockIterator.java | 4 +-
6 files changed, 140 insertions(+), 231 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index c2fd193d7c..f6f44552af 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -306,33 +306,6 @@ public abstract class SparseBlock implements Serializable,
Block
* @return starting position of row r
*/
public abstract int pos(int r);
-
- /**
- * Get the next non-zero row index in the row array.
- *
- * @param r previous row index starting at 0
- * @param ru exclusive upper row index starting at 0
- * @return next non-zero row index
- */
- public abstract int nextNonZeroRowIndex(int r, int ru);
-
- /**
- * Get the starting index in the row array.
- *
- * @param r inclusive lower row index starting at 0
- * @param ru exclusive upper row index starting at 0
- * @return starting index in row array
- */
- public abstract int setSearchIndex(int r, int ru);
-
- /**
- * Get the next index in the row array.
- *
- * @param r previous row index starting at 0
- * @param ru exclusive upper row index starting at 0
- * @return next index in row array
- */
- public abstract int updateSearchIndex(int r, int ru);
////////////////////////
@@ -586,10 +559,10 @@ public abstract class SparseBlock implements
Serializable, Block
* This iterator facilitates traversal over rows that contain at least
one non-zero element,
* skipping entirely zero rows. The returned integers represent the
indexes of non-empty rows.
*
- * @return iterator
+ * @return iterable
*/
- public Iterator<Integer> getNonEmptyRowIterator() {
- return new SparseNonEmptyRowIterator(0, numRows());
+ public Iterable<Integer> getNonEmptyRows() {
+ return new SparseNonEmptyRowIterable(0, numRows());
}
/**
@@ -599,12 +572,14 @@ public abstract class SparseBlock implements
Serializable, Block
*
* @param rl inclusive lower row index starting at 0
* @param ru exclusive upper row index starting at 0
- * @return Integer iterator
+ * @return iterable
*/
- public Iterator<Integer> getNonEmptyRowIterator(int rl, int ru) {
- return new SparseNonEmptyRowIterator(rl, ru);
+ public Iterable<Integer> getNonEmptyRows(int rl, int ru) {
+ return new SparseNonEmptyRowIterable(rl, ru);
}
+ public abstract Iterator<Integer> getNonEmptyRowsIterator(int rl, int
ru);
+
@Override
public abstract String toString();
@@ -769,71 +744,20 @@ public abstract class SparseBlock implements
Serializable, Block
}
}
- //TODO: move to individual sparse blocks for performance/separation ->
MB
- private class SparseNonEmptyRowIterator implements Iterator<Integer> {
- private int _rlen = 0; //row upper
- private int _curRow = -1; //current row
- private boolean _noNext = false; //end indicator
- private int _searchIndex = 0;
- private int _previousSearchIndex = -1;
-
- protected SparseNonEmptyRowIterator(int rl, int ru) {
- _rlen = ru;
- _curRow = rl;
- _searchIndex = setSearchIndex(_curRow, ru);
- if(_searchIndex == -1) {
- _noNext = true;
- }
- }
+ //generic iterable for use in enhanced for loops: for(int i :
s.getNonEmptyRows())
+ private class SparseNonEmptyRowIterable implements Iterable<Integer> {
+ private final int _rl; //row lower
+ private final int _ru; //row upper
- @Override
- public boolean hasNext() {
- return !_noNext;
+ protected SparseNonEmptyRowIterable(int rl, int ru) {
+ _rl = rl;
+ _ru =ru;
}
@Override
- public Integer next() {
- if(SparseBlock.this instanceof SparseBlockDCSR ||
SparseBlock.this instanceof SparseBlockCOO) {
- _curRow = nextNonZeroRowIndex(_searchIndex,
_rlen);
- _previousSearchIndex = _searchIndex;
- _searchIndex =
updateSearchIndex(_previousSearchIndex, _rlen);
- if(_previousSearchIndex == _searchIndex) {
- _noNext = true;
- }
- return _curRow;
- }
- else if(SparseBlock.this instanceof SparseBlockCSR) {
- _curRow = nextNonZeroRowIndex(_searchIndex,
_rlen);
- _searchIndex = updateSearchIndex(_curRow,
_rlen);
- _searchIndex = setSearchIndex(_searchIndex,
_rlen); // special case: single non-zero row
- if(_curRow == _previousSearchIndex || _curRow
== _searchIndex || _searchIndex == -1) {
- _noNext = true;
- _searchIndex = _curRow;
- }
- _previousSearchIndex = _curRow;
- return _curRow;
- }
- else { //MCSR
- _previousSearchIndex =
nextNonZeroRowIndex(_searchIndex, _rlen);
- _curRow =
updateSearchIndex(_previousSearchIndex, _rlen);
- if(_previousSearchIndex == _curRow) {
- _noNext = true;
- }
- else {
- _searchIndex = _curRow;
- }
- return _previousSearchIndex;
- }
- }
-
- @Override
- public void remove() {
- throw new RuntimeException("SparseBlockIterator is
unsupported!");
+ public Iterator<Integer> iterator() {
+ //use specialized non-empty row iterators of sparse
blocks
+ return getNonEmptyRowsIterator(_rl, _ru);
}
-
- /**
- * Moves cursor to next non-zero row or indicates that no more
- * rows are available.
- */
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
index b7028e1e1f..1ad2cd57d0 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
@@ -354,43 +354,6 @@ public class SparseBlockCOO extends SparseBlock
return index;
}
- @Override
- public int nextNonZeroRowIndex(int r, int ru) {
- return _rindexes[r];
- }
-
- @Override
- public int setSearchIndex(int r, int ru) {
- int insertionPoint = -1;
- int result = Arrays.binarySearch(_rindexes, r);
- if(result < 0) {
- insertionPoint = -result - 1;
- if(_rindexes[insertionPoint] == ru) {
- return -1;
- }
- return insertionPoint;
- }
- else {
- if(_rindexes[result] == ru) {
- return -1;
- }
- return result;
- }
- }
-
- @Override
- public int updateSearchIndex(int r, int ru) {
- int currentRow = _rindexes[r];
- int i = r;
- while(i < _rindexes.length && _rindexes[i] < ru) {
- if(_rindexes[i] != currentRow) {
- return i;
- }
- i++;
- }
- return r;
- }
-
@Override
public boolean set(int r, int c, double v) {
int pos = pos(r);
@@ -767,14 +730,14 @@ public class SparseBlockCOO extends SparseBlock
@Override
public IJV next( ) {
- retijv.set(_rindexes[_pos], _cindexes[_pos],
_values[_pos++]);
+ retijv.set(_rindexes[_pos], _cindexes[_pos],
_values[_pos++]);
return retijv;
}
@Override
public void remove() {
- throw new RuntimeException("SparseBlockCOOIterator is
unsupported!");
- }
+ throw new RuntimeException("SparseBlockCOOIterator is
unsupported!");
+ }
}
/**
@@ -803,4 +766,32 @@ public class SparseBlockCOO extends SparseBlock
public double[] values() {
return _values;
}
+
+ @Override
+ public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+ return new NonEmptyRowsIteratorCOO(rl, ru);
+ }
+
+ public class NonEmptyRowsIteratorCOO implements Iterator<Integer> {
+ private int _rpos;
+ private final int _ru;
+
+ public NonEmptyRowsIteratorCOO(int rl, int ru) {
+ _rpos = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public boolean hasNext() {
+ //TODO specialize for COO, but equivalent to
existing sparse ops
+ while( _rpos<_ru && isEmpty(_rpos) )
+ _rpos++;
+ return _rpos < _ru;
+ }
+
+ @Override
+ public Integer next() {
+ return _rpos++;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index 18caf806d7..13e844007b 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.data;
import java.io.DataInput;
import java.io.IOException;
import java.util.Arrays;
+import java.util.Iterator;
import org.apache.sysds.runtime.util.SortUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -465,37 +466,6 @@ public class SparseBlockCSR extends SparseBlock
return _ptr[r];
}
- @Override
- public int nextNonZeroRowIndex(int r, int ru) {
- for(int i = r; i < ru; i++) {
- if(_ptr[i] < _ptr[i + 1]) {
- return i;
- }
- }
- return r - 1;
- }
-
- @Override
- public int setSearchIndex(int r, int ru) {
- if(_ptr[r] == _ptr[ru]) {
- return -1; //zero matrix
- }
- return r;
- }
-
- @Override
- public int updateSearchIndex(int r, int ru) {
- if(r + 2 == ru && _ptr[r + 1] == _ptr[r + 2]) {
- return r;
- }
- else if(r + 1 == ru) {
- return r;
- }
- else {
- return r + 1;
- }
- }
-
@Override
public boolean set(int r, int c, double v) {
int pos = pos(r);
@@ -1021,6 +991,33 @@ public class SparseBlockCSR extends SparseBlock
return false;
}
+ @Override
+ public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+ return new NonEmptyRowsIteratorCSR(rl, ru);
+ }
+
+ public class NonEmptyRowsIteratorCSR implements Iterator<Integer> {
+ private int _rpos;
+ private final int _ru;
+
+ public NonEmptyRowsIteratorCSR(int rl, int ru) {
+ _rpos = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public boolean hasNext() {
+ while( _rpos<_ru && isEmpty(_rpos) )
+ _rpos++;
+ return _rpos < _ru;
+ }
+
+ @Override
+ public Integer next() {
+ return _rpos++;
+ }
+ }
+
///////////////////////////
// private helper methods
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 5daa4a18d1..33c1d3582f 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -341,39 +341,6 @@ public class SparseBlockDCSR extends SparseBlock
return _rowptr[idx];
}
- @Override
- public int nextNonZeroRowIndex(int r, int ru) {
- return _rowidx[r];
- }
-
- @Override
- public int setSearchIndex(int r, int ru) {
- int insertionPoint = -1;
- int result = Arrays.binarySearch(_rowidx, r);
- if(result < 0) {
- insertionPoint = -result - 1;
- if(_rowidx[insertionPoint] == ru) {
- return -1;
- }
- return insertionPoint;
- }
- else {
- if(_rowidx[result] == ru) {
- return -1;
- }
- return result;
- }
- }
-
- @Override
- public int updateSearchIndex(int r, int ru) {
- int nextIndex = r + 1;
- if(nextIndex >= _rowidx.length || _rowidx[nextIndex] >= ru) {
- nextIndex = r;
- }
- return nextIndex;
- }
-
@Override
public boolean set(int r, int c, double v) {
int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
@@ -807,6 +774,31 @@ public class SparseBlockDCSR extends SparseBlock
return true;
return false;
}
+
+ @Override
+ public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+ return new NonEmptyRowsIteratorDCSR(rl, ru);
+ }
+
+ public class NonEmptyRowsIteratorDCSR implements Iterator<Integer> {
+ private int _rpos;
+ private final int _ru;
+
+ public NonEmptyRowsIteratorDCSR(int rl, int ru) {
+ _rpos = (rl==0) ? 0 : posRowIndex(rl);
+ _ru = ru;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return _rpos < _nnzr && _rowidx[_rpos] < _ru;
+ }
+
+ @Override
+ public Integer next() {
+ return _rowidx[_rpos++];
+ }
+ }
///////////////////////////
// private helper methods
@@ -987,8 +979,14 @@ public class SparseBlockDCSR extends SparseBlock
}
private void incrRowPtr(int rowIndex, int cnt) {
-
for( int i = rowIndex; i < _nnzr + 1; i++ )
_rowptr[i] += cnt;
}
+
+ private int posRowIndex(int r) {
+ int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+ if( rowIndex < 0 )
+ rowIndex = -rowIndex - 1;
+ return rowIndex;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
index 62f4480ff0..70e36bf4cd 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.data;
+import java.util.Iterator;
+
import org.apache.sysds.utils.MemoryEstimates;
/**
@@ -339,36 +341,6 @@ public class SparseBlockMCSR extends SparseBlock
return 0;
}
- @Override
- public int nextNonZeroRowIndex(int r, int ru) {
- for(int i = r; i < ru; i++) {
- if(_rows[i] != null) {
- return i;
- }
- }
- return r;
- }
-
- @Override
- public int setSearchIndex(int r, int ru) {
- for(int i = r; i < ru; i++) {
- if(_rows[i] != null) {
- return i;
- }
- }
- return -1;
- }
-
- @Override
- public int updateSearchIndex(int r, int ru) {
- for(int i = r; i < ru; i++) {
- if(_rows[i] != null && i != r) {
- return i;
- }
- }
- return r;
- }
-
@Override
public boolean set(int r, int c, double v) {
if( !isAllocated(r) )
@@ -517,6 +489,33 @@ public class SparseBlockMCSR extends SparseBlock
return sb.toString();
}
+ @Override
+ public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+ return new NonEmptyRowsIteratorMCSR(rl, ru);
+ }
+
+ public class NonEmptyRowsIteratorMCSR implements Iterator<Integer> {
+ private int _rpos;
+ private final int _ru;
+
+ public NonEmptyRowsIteratorMCSR(int rl, int ru) {
+ _rpos = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public boolean hasNext() {
+ while( _rpos<_ru && isEmpty(_rpos) )
+ _rpos++;
+ return _rpos < _ru;
+ }
+
+ @Override
+ public Integer next() {
+ return _rpos++;
+ }
+ }
+
/**
* Helper function for MCSR -> {COO, CSR}
* @return the underlying array of {@link SparseRow}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
index 523e7d27db..d13a2076cd 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
@@ -233,8 +233,8 @@ public class SparseBlockIterator extends AutomatedTestBase {
List<Integer> manualNonZeroRows = new ArrayList<>();
List<Integer> iteratorNonZeroRows = new ArrayList<>();
Iterator<Integer> iterRows = !partial ?
- sblock.getNonEmptyRowIterator() :
- sblock.getNonEmptyRowIterator(rl, rows);
+ sblock.getNonEmptyRowsIterator(0, rows) :
+ sblock.getNonEmptyRowsIterator(rl, rows);
for(int i = rl; i < rows; i++)
if(!sblock.isEmpty(i))