This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d96fa66871 [SYSTEMDS-3172] Performance improvement CSC sparse block
d96fa66871 is described below
commit d96fa66871263f056964c2d35caee0aad3cf5649
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Oct 25 17:08:07 2024 +0200
[SYSTEMDS-3172] Performance improvement CSC sparse block
This patch makes some simple performance improvements in order to
reduce the runtime of the sparse component tests (300+s -> 30s). In
detail, the runtimes of specific tests improved as follows:
* SparseBlockMerge: 149s -> 14.7s
* SparseBlockIndexRange: 110s -> 13.4s
* SparseBlockGetFirstIndex: 29s -> 1.3s
---
.../apache/sysds/runtime/data/SparseBlockCSC.java | 134 +++++----------------
1 file changed, 33 insertions(+), 101 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
index 1532c02b7c..b38c3525c9 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
@@ -28,7 +28,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
-import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
@@ -119,10 +118,8 @@ public class SparseBlockCSC extends SparseBlock {
for(SparseRow column : columns) {
int rowIdx[] = column.indexes();
double vals[] = column.values();
- for(int i = 0; i < column.size(); i++) {
- _indexes[valPos + i] = rowIdx[i];
- _values[valPos + i] = vals[i];
- }
+ System.arraycopy(rowIdx, 0, _indexes, valPos,
column.size());
+ System.arraycopy(vals, 0, _values, valPos,
column.size());
_ptr[ptrPos] = _ptr[ptrPos - 1] + column.size();
ptrPos++;
valPos += column.size();
@@ -483,8 +480,9 @@ public class SparseBlockCSC extends SparseBlock {
throw new RuntimeException("Row index has to be zero or
larger.");
int nnz = 0;
- for(int i = 0; i < _size; i++) {
- if(_indexes[i] == r)
+ for(int c=0; c<_ptr.length-1; c++) {
+ int ix = Arrays.binarySearch(_indexes, _ptr[c],
_ptr[c+1], r);
+ if(ix >= 0)
nnz++;
}
return nnz;
@@ -531,12 +529,13 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public boolean isEmpty(int r) {
- boolean empty = true;
- for(int i = 0; i < _size; i++) {
- if(_indexes[i] == r)
+ int clen = numCols();
+ for(int c=0; c<clen; c++) {
+ int ix = Arrays.binarySearch(_indexes, _ptr[c],
_ptr[c+1], r);
+ if(ix >= 0)
return false;
}
- return empty;
+ return true;
}
public boolean isEmptyCol(int c) {
@@ -609,26 +608,15 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public int[] indexes(int r) {
- //Count elements per row
- //int[] rowCounts = numElemPerRow();
-
- // Compute csr pointers
- int[] csrPtr = rowPointerAll();
-
- // Populate CSR indices array
- int[] csrIndices = new int[_size];
- // Temporary array to keep track of the current position in
each row
- int[] currentPos = Arrays.copyOf(csrPtr, _rlen);
-
- for(int col = 0; col < numCols(); col++) {
- for(int i = _ptr[col]; i < _ptr[col + 1]; i++) {
- int row = _indexes[i];
- int pos = currentPos[row]++;
- csrIndices[pos] = col;
- }
+ int clen = numCols();
+ int[] cix = new int[clen];
+ int pos = 0;
+ for(int c = 0; c < clen; c++) {
+ int ix = Arrays.binarySearch(_indexes, _ptr[c],
_ptr[c+1], r);
+ if(ix >= 0)
+ cix[pos++] = c;
}
-
- return csrIndices;
+ return cix;
}
public int[] indexesCol(int c) {
@@ -637,20 +625,15 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public double[] values(int r) {
- // Only use first _size elements for sorting
- Integer[] idx = new Integer[_size];
- for(int i = 0; i < _size; i++)
- idx[i] = i;
-
- // Sort indices based on corresponding index values
- Arrays.sort(idx, Comparator.comparingInt(i -> _indexes[i]));
-
- // Create values array sorted in row order
- double[] csrValues = new double[_size];
- for(int i = 0; i < _size; i++) {
- csrValues[i] = _values[idx[i]];
+ int clen = numCols();
+ double[] vals = new double[clen];
+ int pos = 0;
+ for(int c = 0; c < clen; c++) {
+ int ix = Arrays.binarySearch(_indexes, _ptr[c],
_ptr[c+1], r);
+ if(ix >= 0)
+ vals[pos++] = _values[ix];
}
- return csrValues;
+ return vals;
}
public double[] valuesCol(int c) {
@@ -659,12 +642,7 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public int pos(int r) {
- int nnz = 0;
- for(int i = 0; i < _size; i++) {
- if(_indexes[i] < r)
- nnz++;
- }
- return nnz;
+ return 0;
}
public int posCol(int c) {
@@ -787,7 +765,6 @@ public class SparseBlockCSC extends SparseBlock {
shiftRightAndInsert(pos + len, r, v);
}
incrPtr(c + 1);
-
}
@Override
@@ -1005,27 +982,12 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public SparseRow get(int r) {
- int rowSize = size(r);
- if(rowSize == 0)
- return new SparseRowScalar();
-
- //Create sparse row
- SparseRowVector row = new SparseRowVector(rowSize);
-
- for(int i = 0; i < _size; i++) {
- if(_indexes[i] == r) {
- //Search for index i in pointer array
- for(int j = 0; j < _ptr.length; j++) {
- // two possible cases
- if(_ptr[j] < i && _ptr[j + 1] > i) {
- row.set(j, _values[i]);
- }
- else if(_ptr[j] == i && _ptr[j + 1] >
i) {
- row.set(j, _values[i]);
- break;
- }
- }
- }
+ int clen = numCols();
+ SparseRowVector row = new SparseRowVector(clen);
+ for(int c = 0; c < clen; c++) {
+ int ix = Arrays.binarySearch(_indexes, _ptr[c],
_ptr[c+1], r);
+ if(ix >= 0)
+ row.append(c, _values[ix]);
}
return row;
}
@@ -1329,16 +1291,6 @@ public class SparseBlockCSC extends SparseBlock {
_ptr[i] += cnt;
}
- private void incrRowPtr(int rl, int[] csrPtr) {
- incrRowPtr(rl, csrPtr, 1);
- }
-
- private void incrRowPtr(int rl, int[] csrPtr, int cnt) {
- for(int i = rl; i < csrPtr.length; i++) {
- csrPtr[i] += cnt;
- }
- }
-
private void decrPtr(int cl) {
decrPtr(cl, 1);
}
@@ -1348,26 +1300,6 @@ public class SparseBlockCSC extends SparseBlock {
_ptr[i] -= cnt;
}
- @SuppressWarnings("unused")
- private int[] numElemPerRow() {
- int rlen = numRows();
- int[] rowCount = new int[rlen];
- for(int i = 0; i < _size; i++) {
- rowCount[_indexes[i]] += 1;
- }
- return rowCount;
- }
-
- private int[] rowPointerAll() {
- int rlen = numRows();
- int[] csrPtr = new int[rlen + 1];
- csrPtr[0] = 0;
- for(int i = 0; i < _size; i++)
- incrRowPtr(_indexes[i] + 1, csrPtr);
-
- return csrPtr;
- }
-
private int internPosFIndexLTECol(int r, int c) {
int pos = posCol(c);
int len = sizeCol(c);