This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 41ce5231ce [SYSTEMDS-3683] Improved sparse block non-empty row 
iterators
41ce5231ce is described below

commit 41ce5231ce03897297d92734f859872bf4cfa713
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Mar 25 18:13:07 2024 +0100

    [SYSTEMDS-3683] Improved sparse block non-empty row iterators
    
    1) Specialized iterators for DCSR, CSR, MCSR, and COO
    2) Generic iterable to be used in enhanced for loops
---
 .../org/apache/sysds/runtime/data/SparseBlock.java | 112 ++++-----------------
 .../apache/sysds/runtime/data/SparseBlockCOO.java  |  71 ++++++-------
 .../apache/sysds/runtime/data/SparseBlockCSR.java  |  59 ++++++-----
 .../apache/sysds/runtime/data/SparseBlockDCSR.java |  66 ++++++------
 .../apache/sysds/runtime/data/SparseBlockMCSR.java |  59 ++++++-----
 .../test/component/sparse/SparseBlockIterator.java |   4 +-
 6 files changed, 140 insertions(+), 231 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index c2fd193d7c..f6f44552af 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -306,33 +306,6 @@ public abstract class SparseBlock implements Serializable, 
Block
         * @return starting position of row r
         */
        public abstract int pos(int r);
-
-       /**
-        * Get the next non-zero row index in the row array.
-        *
-        * @param r  previous row index starting at 0
-        * @param ru exclusive upper row index starting at 0
-        * @return next non-zero row index
-        */
-       public abstract int nextNonZeroRowIndex(int r, int ru);
-
-       /**
-        * Get the starting index in the row array.
-        *
-        * @param r  inclusive lower row index starting at 0
-        * @param ru exclusive upper row index starting at 0
-        * @return starting index in row array
-        */
-       public abstract int setSearchIndex(int r, int ru);
-
-       /**
-        * Get the next index in the row array.
-        *
-        * @param r  previous row index starting at 0
-        * @param ru exclusive upper row index starting at 0
-        * @return next index in row array
-        */
-       public abstract int updateSearchIndex(int r, int ru);
        
        
        ////////////////////////
@@ -586,10 +559,10 @@ public abstract class SparseBlock implements 
Serializable, Block
         * This iterator facilitates traversal over rows that contain at least 
one non-zero element,
         * skipping entirely zero rows. The returned integers represent the 
indexes of non-empty rows.
         *
-        * @return iterator
+        * @return iterable
         */
-       public Iterator<Integer> getNonEmptyRowIterator() {
-               return new SparseNonEmptyRowIterator(0, numRows());
+       public Iterable<Integer> getNonEmptyRows() {
+               return new SparseNonEmptyRowIterable(0, numRows());
        }
 
        /**
@@ -599,12 +572,14 @@ public abstract class SparseBlock implements 
Serializable, Block
         *
         * @param rl inclusive lower row index starting at 0
         * @param ru exclusive upper row index starting at 0
-        * @return Integer iterator
+        * @return iterable
         */
-       public Iterator<Integer> getNonEmptyRowIterator(int rl, int ru) {
-               return new SparseNonEmptyRowIterator(rl, ru);
+       public Iterable<Integer> getNonEmptyRows(int rl, int ru) {
+               return new SparseNonEmptyRowIterable(rl, ru);
        }
        
+       public abstract Iterator<Integer> getNonEmptyRowsIterator(int rl, int 
ru);
+       
        @Override 
        public abstract String toString();
        
@@ -769,71 +744,20 @@ public abstract class SparseBlock implements 
Serializable, Block
                }
        }
        
-       //TODO: move to individual sparse blocks for performance/separation -> 
MB
-       private class SparseNonEmptyRowIterator implements Iterator<Integer> {
-               private int _rlen = 0; //row upper
-               private int _curRow = -1; //current row
-               private boolean _noNext = false; //end indicator
-               private int _searchIndex = 0;
-               private int _previousSearchIndex = -1;
-
-               protected SparseNonEmptyRowIterator(int rl, int ru) {
-                       _rlen = ru;
-                       _curRow = rl;
-                       _searchIndex = setSearchIndex(_curRow, ru);
-                       if(_searchIndex == -1) {
-                               _noNext = true;
-                       }
-               }
+       //generic iterable for use in enhanced for loops: for(int i : 
s.getNonEmptyRows())
+       private class SparseNonEmptyRowIterable implements Iterable<Integer> {
+               private final int _rl; //row lower
+               private final int _ru; //row upper
 
-               @Override
-               public boolean hasNext() {
-                       return !_noNext;
+               protected SparseNonEmptyRowIterable(int rl, int ru) {
+                       _rl = rl;
+                       _ru  =ru;
                }
 
                @Override
-               public Integer next() {
-                       if(SparseBlock.this instanceof SparseBlockDCSR || 
SparseBlock.this instanceof SparseBlockCOO) {
-                               _curRow = nextNonZeroRowIndex(_searchIndex, 
_rlen);
-                               _previousSearchIndex = _searchIndex;
-                               _searchIndex = 
updateSearchIndex(_previousSearchIndex, _rlen);
-                               if(_previousSearchIndex == _searchIndex) {
-                                       _noNext = true;
-                               }
-                               return _curRow;
-                       }
-                       else if(SparseBlock.this instanceof SparseBlockCSR) {
-                               _curRow = nextNonZeroRowIndex(_searchIndex, 
_rlen);
-                               _searchIndex = updateSearchIndex(_curRow, 
_rlen);
-                               _searchIndex = setSearchIndex(_searchIndex, 
_rlen); // special case: single non-zero row
-                               if(_curRow == _previousSearchIndex || _curRow 
== _searchIndex || _searchIndex == -1) {
-                                       _noNext = true;
-                                       _searchIndex = _curRow;
-                               }
-                               _previousSearchIndex = _curRow;
-                               return _curRow;
-                       }
-                       else { //MCSR
-                               _previousSearchIndex = 
nextNonZeroRowIndex(_searchIndex, _rlen);
-                               _curRow = 
updateSearchIndex(_previousSearchIndex, _rlen);
-                               if(_previousSearchIndex == _curRow) {
-                                       _noNext = true;
-                               }
-                               else {
-                                       _searchIndex = _curRow;
-                               }
-                               return _previousSearchIndex;
-                       }
-               }
-
-               @Override
-               public void remove() {
-                       throw new RuntimeException("SparseBlockIterator is 
unsupported!");
+               public Iterator<Integer> iterator() {
+                       //use specialized non-empty row iterators of sparse 
blocks
+                       return getNonEmptyRowsIterator(_rl, _ru);
                }
-
-               /**
-                * Moves cursor to next non-zero row or indicates that no more
-                * rows are available.
-                */
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
index b7028e1e1f..1ad2cd57d0 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
@@ -354,43 +354,6 @@ public class SparseBlockCOO extends SparseBlock
                return index;
        }
 
-       @Override
-       public int nextNonZeroRowIndex(int r, int ru) {
-               return _rindexes[r];
-       }
-
-       @Override
-       public int setSearchIndex(int r, int ru) {
-               int insertionPoint = -1;
-               int result = Arrays.binarySearch(_rindexes, r);
-               if(result < 0) {
-                       insertionPoint = -result - 1;
-                       if(_rindexes[insertionPoint] == ru) {
-                               return -1;
-                       }
-                       return insertionPoint;
-               }
-               else {
-                       if(_rindexes[result] == ru) {
-                               return -1;
-                       }
-                       return result;
-               }
-       }
-
-       @Override
-       public int updateSearchIndex(int r, int ru) {
-               int currentRow = _rindexes[r];
-               int i = r;
-               while(i < _rindexes.length && _rindexes[i] < ru) {
-                       if(_rindexes[i] != currentRow) {
-                               return i;
-                       }
-                       i++;
-               }
-               return r;
-       }
-
        @Override
        public boolean set(int r, int c, double v) {
                int pos = pos(r);
@@ -767,14 +730,14 @@ public class SparseBlockCOO extends SparseBlock
 
                @Override
                public IJV next( ) {
-                       retijv.set(_rindexes[_pos], _cindexes[_pos], 
_values[_pos++]);                  
+                       retijv.set(_rindexes[_pos], _cindexes[_pos], 
_values[_pos++]);
                        return retijv;
                }
 
                @Override
                public void remove() {
-                       throw new RuntimeException("SparseBlockCOOIterator is 
unsupported!");                   
-               }               
+                       throw new RuntimeException("SparseBlockCOOIterator is 
unsupported!");
+               }
        }
        
        /**
@@ -803,4 +766,32 @@ public class SparseBlockCOO extends SparseBlock
        public double[] values() {
                return _values;
        }
+       
+       @Override
+       public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+               return new NonEmptyRowsIteratorCOO(rl, ru);
+       }
+       
+       public class NonEmptyRowsIteratorCOO implements Iterator<Integer> {
+               private int _rpos;
+               private final int _ru;
+               
+               public NonEmptyRowsIteratorCOO(int rl, int ru) {
+                       _rpos = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public boolean hasNext() {
+                       //TODO specialize for COO, but equivalent to 
existing sparse ops
+                       while( _rpos<_ru && isEmpty(_rpos) )
+                               _rpos++;
+                       return _rpos < _ru;
+               }
+
+               @Override
+               public Integer next() {
+                       return _rpos++;
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index 18caf806d7..13e844007b 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.data;
 import java.io.DataInput;
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Iterator;
 
 import org.apache.sysds.runtime.util.SortUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -465,37 +466,6 @@ public class SparseBlockCSR extends SparseBlock
                return _ptr[r];
        }
 
-       @Override
-       public int nextNonZeroRowIndex(int r, int ru) {
-               for(int i = r; i < ru; i++) {
-                       if(_ptr[i] < _ptr[i + 1]) {
-                               return i;
-                       }
-               }
-               return r - 1;
-       }
-
-       @Override
-       public int setSearchIndex(int r, int ru) {
-               if(_ptr[r] == _ptr[ru]) {
-                       return -1; //zero matrix
-               }
-               return r;
-       }
-
-       @Override
-       public int updateSearchIndex(int r, int ru) {
-               if(r + 2 == ru && _ptr[r + 1] == _ptr[r + 2]) {
-                       return r;
-               }
-               else if(r + 1 == ru) {
-                       return r;
-               }
-               else {
-                       return r + 1;
-               }
-       }
-
        @Override
        public boolean set(int r, int c, double v) {
                int pos = pos(r);
@@ -1021,6 +991,33 @@ public class SparseBlockCSR extends SparseBlock
                return false;
        }
 
+       @Override
+       public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+               return new NonEmptyRowsIteratorCSR(rl, ru);
+       }
+       
+       public class NonEmptyRowsIteratorCSR implements Iterator<Integer> {
+               private int _rpos;
+               private final int _ru;
+               
+               public NonEmptyRowsIteratorCSR(int rl, int ru) {
+                       _rpos = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public boolean hasNext() {
+                       while( _rpos<_ru && isEmpty(_rpos) )
+                               _rpos++;
+                       return _rpos < _ru;
+               }
+
+               @Override
+               public Integer next() {
+                       return _rpos++;
+               }
+       }
+       
        ///////////////////////////
        // private helper methods
        
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 5daa4a18d1..33c1d3582f 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -341,39 +341,6 @@ public class SparseBlockDCSR extends SparseBlock
                return _rowptr[idx];
        }
 
-       @Override
-       public int nextNonZeroRowIndex(int r, int ru) {
-               return _rowidx[r];
-       }
-
-       @Override
-       public int setSearchIndex(int r, int ru) {
-               int insertionPoint = -1;
-               int result = Arrays.binarySearch(_rowidx, r);
-               if(result < 0) {
-                       insertionPoint = -result - 1;
-                       if(_rowidx[insertionPoint] == ru) {
-                               return -1;
-                       }
-                       return insertionPoint;
-               }
-               else {
-                       if(_rowidx[result] == ru) {
-                               return -1;
-                       }
-                       return result;
-               }
-       }
-
-       @Override
-       public int updateSearchIndex(int r, int ru) {
-               int nextIndex = r + 1;
-               if(nextIndex >= _rowidx.length || _rowidx[nextIndex] >= ru) {
-                       nextIndex = r;
-               }
-               return nextIndex;
-       }
-
        @Override
        public boolean set(int r, int c, double v) {
                int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
@@ -807,6 +774,31 @@ public class SparseBlockDCSR extends SparseBlock
                                return true;
                return false;
        }
+       
+       @Override
+       public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+               return new NonEmptyRowsIteratorDCSR(rl, ru);
+       }
+       
+       public class NonEmptyRowsIteratorDCSR implements Iterator<Integer> {
+               private int _rpos;
+               private final int _ru;
+               
+               public NonEmptyRowsIteratorDCSR(int rl, int ru) {
+                       _rpos = (rl==0) ? 0 : posRowIndex(rl);
+                       _ru = ru;
+               }
+               
+               @Override
+               public boolean hasNext() {
+                       return _rpos < _nnzr && _rowidx[_rpos] < _ru;
+               }
+
+               @Override
+               public Integer next() {
+                       return _rowidx[_rpos++];
+               }
+       }
 
        ///////////////////////////
        // private helper methods
@@ -987,8 +979,14 @@ public class SparseBlockDCSR extends SparseBlock
        }
 
        private void incrRowPtr(int rowIndex, int cnt) {
-
                for( int i = rowIndex; i < _nnzr + 1; i++ )
                        _rowptr[i] += cnt;
        }
+       
+       private int posRowIndex(int r) {
+               int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+               if( rowIndex < 0 )
+                       rowIndex = -rowIndex - 1;
+               return rowIndex;
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
index 62f4480ff0..70e36bf4cd 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.runtime.data;
 
+import java.util.Iterator;
+
 import org.apache.sysds.utils.MemoryEstimates;
 
 /**
@@ -339,36 +341,6 @@ public class SparseBlockMCSR extends SparseBlock
                return 0;
        }
 
-       @Override
-       public int nextNonZeroRowIndex(int r, int ru) {
-               for(int i = r; i < ru; i++) {
-                       if(_rows[i] != null) {
-                               return i;
-                       }
-               }
-               return r;
-       }
-
-       @Override
-       public int setSearchIndex(int r, int ru) {
-               for(int i = r; i < ru; i++) {
-                       if(_rows[i] != null) {
-                               return i;
-                       }
-               }
-               return -1;
-       }
-
-       @Override
-       public int updateSearchIndex(int r, int ru) {
-               for(int i = r; i < ru; i++) {
-                       if(_rows[i] != null && i != r) {
-                               return i;
-                       }
-               }
-               return r;
-       }
-
        @Override
        public boolean set(int r, int c, double v) {
                if( !isAllocated(r) )
@@ -517,6 +489,33 @@ public class SparseBlockMCSR extends SparseBlock
                return sb.toString();
        }
        
+       @Override
+       public Iterator<Integer> getNonEmptyRowsIterator(int rl, int ru) {
+               return new NonEmptyRowsIteratorMCSR(rl, ru);
+       }
+       
+       public class NonEmptyRowsIteratorMCSR implements Iterator<Integer> {
+               private int _rpos;
+               private final int _ru;
+               
+               public NonEmptyRowsIteratorMCSR(int rl, int ru) {
+                       _rpos = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public boolean hasNext() {
+                       while( _rpos<_ru && isEmpty(_rpos) )
+                               _rpos++;
+                       return _rpos < _ru;
+               }
+
+               @Override
+               public Integer next() {
+                       return _rpos++;
+               }
+       }
+       
        /**
         * Helper function for MCSR -&gt; {COO, CSR}
         * @return the underlying array of {@link SparseRow}
diff --git 
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java 
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
index 523e7d27db..d13a2076cd 100644
--- 
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
+++ 
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
@@ -233,8 +233,8 @@ public class SparseBlockIterator extends AutomatedTestBase {
                        List<Integer> manualNonZeroRows = new ArrayList<>();
                        List<Integer> iteratorNonZeroRows = new ArrayList<>();
                        Iterator<Integer> iterRows = !partial ?
-                               sblock.getNonEmptyRowIterator() :
-                               sblock.getNonEmptyRowIterator(rl, rows);
+                               sblock.getNonEmptyRowsIterator(0, rows) :
+                               sblock.getNonEmptyRowsIterator(rl, rows);
 
                        for(int i = rl; i < rows; i++)
                                if(!sblock.isEmpty(i))

Reply via email to