This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 65a8cabe56 [SYSTEMDS-3552] New double-compressed spark row block
65a8cabe56 is described below
commit 65a8cabe56bd96e65f446ce0a7fa71b9d36e2738
Author: Jaybit0 <[email protected]>
AuthorDate: Sat Jan 13 17:06:14 2024 +0100
[SYSTEMDS-3552] New double-compressed spark row block
DIA 23/24 project.
Closes 1969.
---
.../org/apache/sysds/runtime/data/SparseBlock.java | 7 +-
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 946 +++++++++++++++++++++
.../sysds/runtime/data/SparseBlockFactory.java | 3 +
.../encode/ColumnEncoderWordEmbedding.java | 2 +-
.../component/sparse/SparseBlockAlignment.java | 39 +-
.../component/sparse/SparseBlockAppendSort.java | 32 +
.../test/component/sparse/SparseBlockDelete.java | 17 +
.../component/sparse/SparseBlockGetFirstIndex.java | 47 +
.../test/component/sparse/SparseBlockGetSet.java | 70 +-
.../component/sparse/SparseBlockIndexRange.java | 32 +
.../test/component/sparse/SparseBlockIterator.java | 32 +
.../component/sparse/SparseBlockMemEstimate.java | 5 +
.../test/component/sparse/SparseBlockMerge.java | 140 +++
.../test/component/sparse/SparseBlockScan.java | 17 +
.../test/component/sparse/SparseBlockSize.java | 21 +-
15 files changed, 1393 insertions(+), 17 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index cd1bd751f3..b19d132503 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -48,9 +48,10 @@ public abstract class SparseBlock implements Serializable,
Block
protected static final double RESIZE_FACTOR2 = 1.1; //factor after
reaching est nnz
public enum Type {
- MCSR,
- CSR,
- COO,
+ COO, // coordinate format
+ CSR, // compressed sparse rows
+ DCSR, // double compressed sparse rows (only non-empty rows are materialized)
+ MCSR, // modified compressed sparse rows (update-friendly)
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
new file mode 100644
index 0000000000..fd187f3064
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -0,0 +1,946 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.data;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.matrix.data.IJV;
+import org.apache.sysds.runtime.util.SortUtils;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.utils.MemoryEstimates;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+import static java.util.stream.IntStream.range;
+
+// TODO: handling of completely empty matrices? This will cause errors right
now
+public class SparseBlockDCSR extends SparseBlock
+{
+ private static final long serialVersionUID = 456844244252549431L;
+
+ private static final Log LOG =
LogFactory.getLog(SparseBlockDCSR.class.getName());
+
+ private int[] _rowidx = null; // row index array (size: >=
+ private int[] _rowptr = null; //
+ private int[] _colidx = null; // column index array (size: >=nnz)
+ private double[] _values = null; // value array (size: >=nnz)
+ private int _size = 0; // actual nnz
+ private int _rlen = 0; // number of rows
+ private int _nnzr = 0; // number of nonzero rows
+
+ public SparseBlockDCSR(int rlen) {
+ this(rlen, INIT_CAPACITY);
+ }
+
+ public SparseBlockDCSR(int rlen, int capacity) {
+ //TODO: This allocates too much space (we care about number of
non-empty rows)
+ LOG.warn("Allocating a DCSR-block using row-length. This will
lead to significant overhead!");
+ LOG.warn("If you want to initialize a sparse block using rlen,
choose SparseBlockCSR instead!");
+
+ _rowidx = new int[rlen];
+ _rowptr = new int[rlen + 1];
+ _colidx = new int[capacity];
+ _values = new double[capacity];
+ _rlen = rlen;
+ _size = 0;
+ _nnzr = 0;
+ }
+
+ public SparseBlockDCSR(int rlen, int capacity, int size, int nnzr){
+ LOG.warn("Allocating a DCSR-block using row-length. This will
lead to significant overhead!");
+ _rowidx = new int[rlen];
+ _rowptr = new int[rlen + 1];
+ _colidx = new int[capacity];
+ _values = new double[capacity];
+ _rlen = rlen;
+ _size = size;
+ _nnzr = nnzr;
+ }
+
+ public SparseBlockDCSR(int[] rowIdx, int[] rowPtr, int[] colIdx,
double[] values, int rlen, int nnz, int nnzr){
+ LOG.warn("Allocating a DCSR-block using row-length. This will
lead to significant overhead!");
+ _rowidx = rowIdx;
+ _rowptr = rowPtr;
+ _colidx = colIdx;
+ _values = values;
+ _rlen = rlen;
+ _size = nnz;
+ _nnzr = nnzr;
+ }
+
+ /**
+ * Copy constructor sparse block abstraction.
+ *
+ * @param sblock sparse block to copy
+ */
+ public SparseBlockDCSR(SparseBlock sblock)
+ {
+ long size = sblock.size();
+ if( size > Integer.MAX_VALUE )
+ throw new RuntimeException("SparseBlockDCSR supports
nnz<=Integer.MAX_VALUE but got "+size);
+
+ //special case SparseBlockDCSR
+ if( sblock instanceof SparseBlockDCSR ) {
+ SparseBlockDCSR ocsr = (SparseBlockDCSR)sblock;
+ _rowidx = Arrays.copyOf(ocsr._rowidx, ocsr._nnzr);
+ _rowptr = Arrays.copyOf(ocsr._rowptr, ocsr._nnzr+1);
+ _colidx = Arrays.copyOf(ocsr._colidx, ocsr._size);
+ _values = Arrays.copyOf(ocsr._values, ocsr._size);
+ _rlen = ocsr._rlen;
+ _nnzr = ocsr._nnzr;
+ _size = ocsr._size;
+ }
+ else if( sblock instanceof SparseBlockCSR ) {
+ // More efficient conversion from CSR to DCSR
+ int rlen = sblock.numRows();
+
+ SparseBlockCSR ocsr = (SparseBlockCSR)sblock;
+ _rowidx = range(0, rlen).filter(rowIdx ->
!sblock.isEmpty(rowIdx)).toArray();
+ _rowptr = new int[_rowidx.length + 1];
+ _colidx = Arrays.copyOf(ocsr.indexes(),
(int)ocsr.size());
+ _values = Arrays.copyOf(ocsr.values(),
(int)ocsr.size());
+ _rlen = rlen;
+ _nnzr = _rowidx.length;
+ _size = (int)ocsr.size();
+
+ int pos = 0;
+
+ for (int rowptr = 0; rowptr < _rowidx.length; rowptr++)
{
+ pos += sblock.size(_rowidx[rowptr]);
+ _rowptr[rowptr+1] = pos;
+ }
+ }
+ //general case SparseBlock
+ else {
+ int rlen = sblock.numRows();
+
+ _rowidx = range(0, rlen).filter(rowIdx ->
!sblock.isEmpty(rowIdx)).toArray();
+ _rowptr = new int[_rowidx.length + 1];
+ _colidx = new int[(int)size];
+ _values = new double[(int)size];
+ _rlen = rlen;
+ _nnzr = _rowidx.length;
+ _size = (int)size;
+
+ int pos = 0;
+
+ for ( int rowIdx : _rowidx ) {
+ int apos = sblock.pos(rowIdx);
+ int alen = sblock.size(rowIdx);
+ int[] aix = sblock.indexes(rowIdx);
+ double[] avals = sblock.values(rowIdx);
+ System.arraycopy(aix, apos, _colidx, pos, alen);
+ System.arraycopy(avals, apos, _values, pos,
alen);
+ pos += alen;
+ _rowptr[rowIdx+1] = pos;
+ }
+ }
+ }
+
+ /**
+ * Get the estimated in-memory size of the sparse block in CSR
+ * with the given dimensions w/o accounting for overallocation.
+ *
+ * @param nrows number of rows
+ * @param ncols number of columns
+ * @param sparsity sparsity ratio
+ * @return memory estimate
+ */
+ public static long estimateSizeInMemory(long nrows, long ncols, double
sparsity) {
+ double lnnz = Math.max(INIT_CAPACITY,
Math.ceil(sparsity*nrows*ncols));
+
+ //32B overhead per array, int arr in nrows, int/double arr in
nnz
+ double size = 16;
// Memory overhead of the object
+ size += 4 + 4 + 4 + 4;
// 3x int field + 0 (padding not necessary)
+ size += MemoryEstimates.intArrayCost(nrows); //
rowidx array (row indices)
+ size += MemoryEstimates.intArrayCost(nrows+1); // rowptr array
(row pointers)
+ size += MemoryEstimates.intArrayCost((long) lnnz); // colidx
array (column indexes)
+ size += MemoryEstimates.doubleArrayCost((long) lnnz);// values
array (non-zero values)
+
+ //robustness for long overflows
+ return (long) Math.min(size, Long.MAX_VALUE);
+ }
+
+ ///////////////////
+ //SparseBlock implementation
+
+ @Override
+ public void allocate(int r) {
+ //do nothing everything preallocated
+ }
+
+ @Override
+ public void allocate(int r, int nnz) {
+ //do nothing everything preallocated
+ }
+
+ @Override
+ public void allocate(int r, int ennz, int maxnnz) {
+ //do nothing everything preallocated
+ }
+
+ @Override
+ public void compact(int r) {
+ //do nothing everything preallocated
+ }
+
+ @Override
+ public int numRows() {
+ return _rlen;
+ }
+
+ @Override
+ public boolean isThreadSafe() {
+ return false;
+ }
+
+ @Override
+ public boolean isContiguous() {
+ return true;
+ }
+
+ @Override
+ public boolean isAllocated(int r) {
+ return true;
+ }
+
+ @Override
+ public void reset() {
+ if( _size > 0 ) {
+ _size = 0;
+ _nnzr = 0;
+ _rlen = 0;
+ }
+ }
+
+ @Override
+ public void reset(int ennz, int maxnnz) {
+ if( _size > 0 ) {
+ _size = 0;
+ _nnzr = 0;
+ _rlen = 0;
+ }
+ }
+
+ @Override
+ public void reset(int r, int ennz, int maxnnz) {
+ deleteIndexRange(r, 0, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public long size() {
+ return _size;
+ }
+
+ @Override
+ public int size(int r) {
+ int idx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (idx < 0)
+ return 0;
+
+ return _rowptr[idx+1] - _rowptr[idx];
+ }
+
+ @Override
+ public long size(int rl, int ru) {
+ int lowerIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, rl);
+
+ if (lowerIdx < 0)
+ lowerIdx = -lowerIdx - 1;
+
+ int upperIdx = Arrays.binarySearch(_rowidx, lowerIdx, _nnzr,
ru);
+
+ if (upperIdx < 0)
+ upperIdx = -upperIdx - 1;
+
+ return _rowptr[upperIdx] - _rowptr[lowerIdx];
+ }
+
+ @Override
+ public long size(int rl, int ru, int cl, int cu) {
+ long nnz = 0;
+
+ int lRowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, rl);
+ if (lRowIdx < 0)
+ lRowIdx = -lRowIdx - 1;
+
+ int uRowIdx = Arrays.binarySearch(_rowidx, lRowIdx, _nnzr, ru);
+ if (uRowIdx < 0)
+ uRowIdx = -uRowIdx - 1;
+
+ for (int rowIdx = lRowIdx; rowIdx < uRowIdx; rowIdx++) {
+ int clIdx = Arrays.binarySearch(_colidx,
_rowptr[rowIdx], _rowptr[rowIdx+1], cl);
+ if (clIdx < 0)
+ clIdx = -clIdx - 1;
+
+ int cuIdx = Arrays.binarySearch(_colidx, clIdx,
_rowptr[rowIdx+1], cu);
+ if (cuIdx < 0)
+ cuIdx = -cuIdx - 1;
+
+ nnz += cuIdx - clIdx;
+ }
+ return nnz;
+ }
+
+ @Override
+ public boolean isEmpty(int r) {
+ return size(r) == 0;
+ }
+
+ @Override
+ public int[] indexes(int r) {
+ return _colidx;
+ }
+
+ @Override
+ public double[] values(int r) {
+ return _values;
+ }
+
+ @Override
+ public int pos(int r) {
+ int idx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (idx < 0)
+ idx = Math.max(-idx - 2, 0);
+
+ return _rowptr[idx];
+ }
+
+ @Override
+ public boolean set(int r, int c, double v) {
+ int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+ boolean rowExists = rowIndex >= 0;
+
+ if (!rowExists) {
+ // Nothing to do
+ if (v == 0)
+ return false;
+
+ int rowInsertionIndex = -rowIndex - 1;
+ insertRow(rowInsertionIndex, r,
_rowptr[rowInsertionIndex]);
+ incrRowPtr(rowInsertionIndex+1);
+ insertCol(_rowptr[rowInsertionIndex], c, v);
+ return true;
+ }
+
+ int pos = _rowptr[rowIndex];
+ int len = _rowptr[rowIndex+1] - pos;
+ int index = Arrays.binarySearch(_colidx, pos, pos+len, c);
+ boolean colExists = index >= 0;
+
+ if (v != 0) {
+ if (colExists) {
+ _values[index] = v;
+ return false;
+ }
+
+ // Insert a new column into an existing row
+ insertCol(-index-1, c, v);
+ incrRowPtr(rowIndex+1);
+
+ return true;
+ }
+
+ if (!colExists)
+ return false;
+
+ // If there is only one entry in the row, we have to remove the
entire row
+ if (len == 1) {
+ deleteRow(rowIndex);
+ rowIndex--;
+ }
+
+ // remove the column
+ incrRowPtr(rowIndex+1, -1);
+ deleteCol(index);
+
+ return true;
+ }
+
+ @Override
+ public boolean add(int r, int c, double v) {
+ // TODO: performance
+ double oldValue = get(r, c);
+
+ if (v == 0)
+ return false;
+
+ return set(r, c, oldValue + v);
+ }
+
+ @Override
+ public void set(int r, SparseRow row, boolean deep) {
+ int newRowSize = row.size();
+
+ int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+ boolean rowExists = rowIndex >= 0;
+
+ if (!rowExists) {
+ // Nothing to do
+ if (newRowSize == 0)
+ return;
+
+ int rowInsertionIndex = -rowIndex - 1;
+ insertRow(rowInsertionIndex, r,
_rowptr[rowInsertionIndex]);
+ incrRowPtr(rowInsertionIndex+1, newRowSize);
+ insertCols(_rowptr[rowInsertionIndex], row.indexes(),
row.values());
+ return;
+ }
+
+ int pos = _rowptr[rowIndex];
+ int oldRowSize = _rowptr[rowIndex+1] - pos;
+
+ if (newRowSize == 0) {
+ // Delete row
+ deleteRow(rowIndex);
+ incrRowPtr(rowIndex, -oldRowSize);
+ deleteCols(pos, oldRowSize);
+ return;
+ }
+
+ incrRowPtr(rowIndex+1, newRowSize-oldRowSize);
+ insertCols(pos, row.indexes(), row.values(), oldRowSize);
+ }
+
+ @Override
+ public void append(int r, int c, double v) {
+ // TODO performance
+ set(r, c, v);
+ }
+
+ @Override
+ public void setIndexRange(int r, int cl, int cu, double[] v, int vix,
int vlen) {
+ int lnnz = UtilFunctions.computeNnz(v, vix, vlen);
+
+ if (lnnz == 0) {
+ deleteIndexRange(r, cl, cu);
+ return;
+ }
+
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx < 0) {
+ rowIdx = -rowIdx - 1;
+ insertRow(rowIdx, r, _rowptr[rowIdx]);
+ }
+
+ int rowStart = _rowptr[rowIdx];
+ int rowEnd = _rowptr[rowIdx+1];
+
+ int clIdx = Arrays.binarySearch(_colidx, rowStart, rowEnd, cl);
+ if (clIdx < 0)
+ clIdx = -clIdx - 1;
+
+ int cuIdx = Arrays.binarySearch(_colidx, clIdx, rowEnd, cu);
+ if (cuIdx < 0)
+ cuIdx = -cuIdx - 1;
+
+ int oldnnz = cuIdx - clIdx;
+
+ allocateCols(clIdx, lnnz, oldnnz);
+ incrRowPtr(rowIdx+1, lnnz - oldnnz);
+
+ int insertionIndex = clIdx;
+
+ for (int i = vix; i < vix+vlen; i++) {
+ if (v[i] != 0) {
+ _colidx[insertionIndex] = cl + i - vix;
+ _values[insertionIndex] = v[i];
+ insertionIndex++;
+ }
+ }
+ }
+
+ @Override
+ public void setIndexRange(int r, int cl, int cu, double[] v, int[] vix,
int vpos, int vlen) {
+ if (vlen == 0) {
+ deleteIndexRange(r, cl, cu);
+ return;
+ }
+
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx < 0) {
+ rowIdx = -rowIdx - 1;
+ insertRow(rowIdx, r, _rowptr[rowIdx]);
+ }
+
+ int rowStart = _rowptr[rowIdx];
+ int rowEnd = _rowptr[rowIdx+1];
+
+ int clIdx = Arrays.binarySearch(_colidx, rowStart, rowEnd, cl);
+ if (clIdx < 0)
+ clIdx = -clIdx - 1;
+
+ int cuIdx = Arrays.binarySearch(_colidx, clIdx, rowEnd, cu);
+ if (cuIdx < 0)
+ cuIdx = -cuIdx - 1;
+
+ int oldnnz = cuIdx - clIdx;
+
+ allocateCols(clIdx, vlen, oldnnz);
+ incrRowPtr(rowIdx+1, vlen - oldnnz);
+
+ int insertionIndex = clIdx;
+
+ for (int i = vpos; i < vpos+vlen; i++) {
+ if (v[i] != 0) {
+ _colidx[insertionIndex] = cl - vix[i];
+ _values[insertionIndex] = v[i];
+ insertionIndex++;
+ }
+ }
+ }
+
+ @Override
+ public void deleteIndexRange(int r, int cl, int cu) {
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+ if( rowIdx < 0 ) //nothing to delete
+ return;
+
+ int nnz = _rowptr[rowIdx+1] - _rowptr[rowIdx];
+
+ int start = Arrays.binarySearch(_colidx, _rowptr[rowIdx],
_rowptr[rowIdx+1], cl);
+ if (start < 0)
+ start = -start-1;
+
+ int end = Arrays.binarySearch(_colidx, start,
_rowptr[rowIdx+1], cu);
+ if( end < 0 ) //delete all remaining
+ end = -end-1;
+
+ if (end-start <= 0) // Nothing to delete
+ return;
+
+ if (nnz == end-start) {
+ deleteRow(rowIdx);
+ rowIdx--;
+ }
+
+ //overlapping array copy (shift rhs values left)
+ System.arraycopy(_colidx, end, _colidx, start, _size-end);
+ System.arraycopy(_values, end, _values, start, _size-end);
+ _size -= (end-start);
+
+ incrRowPtr(rowIdx+1, start-end);
+ }
+
+ @Override
+ public void sort() {
+ for( int i=0; i < _rowidx.length; i++ )
+ sortFromRowIndex(i);
+ }
+
+ @Override
+ public void sort(int r) {
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx >= 0)
+ sortFromRowIndex(rowIdx);
+ }
+
+ private void sortFromRowIndex(int rowIndex) {
+ int pos = _rowptr[rowIndex];
+ int len = _rowptr[rowIndex+1] - pos;
+ if( !SortUtils.isSorted(pos, pos+len, _colidx) )
+ SortUtils.sortByIndex(pos, pos+len, _colidx, _values);
+ }
+
+ @Override
+ public double get(int r, int c) {
+ int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIndex < 0)
+ return 0;
+
+ int pos = _rowptr[rowIndex];
+ int len = _rowptr[rowIndex+1] - pos;
+
+ //search for existing col index in [pos,pos+len)
+ int index = Arrays.binarySearch(_colidx, pos, pos+len, c);
+ return (index >= 0) ? _values[index] : 0;
+ }
+
+ @Override
+ public SparseRow get(int r) {
+ if( isEmpty(r) )
+ return new SparseRowScalar();
+ int pos = pos(r);
+ int len = size(r);
+
+ SparseRowVector row = new SparseRowVector(len);
+ System.arraycopy(_colidx, pos, row.indexes(), 0, len);
+ System.arraycopy(_values, pos, row.values(), 0, len);
+ row.setSize(len);
+ return row;
+ }
+
+ @Override
+ public Iterator<IJV> getIterator() {
+ // TODO: performance
+ return super.getIterator();
+ }
+
+ @Override
+ public int posFIndexLTE(int r, int c) {
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx < 0)
+ return -1;
+
+ int colIdx = Arrays.binarySearch(_colidx, _rowptr[rowIdx],
_rowptr[rowIdx+1], c);
+
+ if (colIdx < 0)
+ colIdx = -colIdx - 2;
+
+ // There is no element smaller or equal in this row
+ if (colIdx < _rowptr[rowIdx])
+ return -1;
+
+ return colIdx - _rowptr[rowIdx];
+ }
+
+ @Override
+ public final int posFIndexGTE(int r, int c) {
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx < 0)
+ return -1;
+
+ int colIdx = Arrays.binarySearch(_colidx, _rowptr[rowIdx],
_rowptr[rowIdx+1], c);
+
+ if (colIdx < 0)
+ colIdx = -colIdx - 1;
+
+ // There is no element greater or equal in this row
+ if (colIdx >= _rowptr[rowIdx+1])
+ return -1;
+
+ return colIdx - _rowptr[rowIdx];
+ }
+
+ @Override
+ public int posFIndexGT(int r, int c) {
+ int rowIdx = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+
+ if (rowIdx < 0)
+ return -1;
+
+ int colIdx = Arrays.binarySearch(_colidx, _rowptr[rowIdx],
_rowptr[rowIdx+1], c);
+
+ if (colIdx >= 0)
+ colIdx++;
+ else
+ colIdx = -colIdx - 1;
+
+ // There is no element great in this row
+ if (colIdx >= _rowptr[rowIdx+1])
+ return -1;
+
+ return colIdx - _rowptr[rowIdx];
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("SparseBlockCSR: rlen=");
+ sb.append(numRows());
+ sb.append(", nnz=");
+ sb.append(size());
+ sb.append("\n");
+ final int rowDigits =
(int)Math.max(Math.ceil(Math.log10(numRows())),1) ;
+ for(int rowIdx = 0; rowIdx < _rowidx.length; rowIdx++) {
+ // append row
+ final int row = _rowidx[rowIdx];
+ final int pos = _rowptr[rowIdx];
+ final int len = _rowptr[rowIdx+1] - pos;
+
+ sb.append(String.format("%0"+rowDigits+"d ", row));
+ for(int j = pos; j < pos + len; j++) {
+ if(_values[j] == (long) _values[j])
+
sb.append(String.format("%"+rowDigits+"d:%d", _colidx[j], (long)_values[j]));
+ else
+
sb.append(String.format("%"+rowDigits+"d:%s", _colidx[j],
Double.toString(_values[j])));
+ if(j + 1 < pos + len)
+ sb.append(" ");
+ }
+ sb.append("\n");
+ }
+
+ return sb.toString();
+ }
+
+ @Override
+ public boolean checkValidity(int rlen, int clen, long nnz, boolean
strict) {
+ //1. correct meta data
+ if ( rlen < 0 || clen < 0 ) {
+ throw new RuntimeException("Invalid block dimensions:
"+rlen+" "+clen);
+ }
+
+ //2. correct array lengths
+ if (_size != nnz && _rowptr.length != _rowidx.length + 1 &&
_values.length < nnz && _colidx.length < nnz ) {
+ throw new RuntimeException("Incorrect array lengths.");
+ }
+
+ //3. non-decreasing row pointers
+ for ( int i=1; i <_rowidx.length; i++ ) {
+ if (_rowidx[i-1] > _rowidx[i])
+ throw new RuntimeException("Row indices are
decreasing at row: " + i
+ + ", with indices " +
_rowidx[i-1] + " > " +_rowidx[i]);
+ }
+
+ for (int i = 1; i < _rowptr.length; i++ ) {
+ if (_rowptr[i - 1] > _rowptr[i]) {
+ throw new RuntimeException("Row pointers are
decreasing at row: " + i
+ + ", with pointers " +
_rowptr[i-1] + " > " +_rowptr[i]);
+ }
+ }
+
+ //4. sorted column indexes per row
+ for ( int rowIdx = 0; rowIdx < _rowidx.length; rowIdx++ ) {
+ int apos = _rowidx[rowIdx];
+ int alen = _rowidx[rowIdx+1] - apos;
+
+ for( int k = apos + 1; k < apos + alen; k++)
+ if( _colidx[k-1] >= _colidx[k] )
+ throw new RuntimeException("Wrong
sparse row ordering: "
+ + k + " " +
_colidx[k-1] + " " + _colidx[k]);
+
+ for( int k=apos; k<apos+alen; k++ )
+ if( _values[k] == 0 )
+ throw new RuntimeException("Wrong
sparse row: zero at "
+ + k + " at col index "
+ _colidx[k]);
+ }
+
+ //5. non-existing zero values
+ for( int i=0; i<_size; i++ ) {
+ if( _values[i] == 0 ) {
+ throw new RuntimeException("The values array
should not contain zeros."
+ + " The " + i + "th value is
"+_values[i]);
+ }
+ }
+
+ //6. a capacity that is no larger than nnz times resize factor.
+ int capacity = _values.length;
+ if(capacity > nnz*RESIZE_FACTOR1 ) {
+ throw new RuntimeException("Capacity is larger than the
nnz times a resize factor."
+ + " Current size: "+capacity+ ", while
Expected size:"+nnz*RESIZE_FACTOR1);
+ }
+
+ return true;
+ }
+
+ @Override //specialized for CSR
+ public boolean contains(double pattern) {
+ boolean NaNpattern = Double.isNaN(pattern);
+ double[] vals = _values;
+ int len = _size;
+ for(int i=0; i<len; i++)
+ if(vals[i]==pattern || (NaNpattern &&
Double.isNaN(vals[i])))
+ return true;
+ return false;
+ }
+
+ ///////////////////////////
+ // private helper methods
+
+ private int newCapacity(int minsize) {
+ //compute new size until minsize reached
+ double tmpCap = Math.max(_values.length, 1);
+ while( tmpCap < minsize ) {
+ tmpCap *= (tmpCap <= 1024) ?
+ RESIZE_FACTOR1 : RESIZE_FACTOR2;
+ }
+ return (int)Math.min(tmpCap, Integer.MAX_VALUE);
+ }
+
+ private void deleteRow(int rowIdx) {
+ System.arraycopy(_rowidx, rowIdx + 1, _rowidx, rowIdx,
_nnzr-rowIdx-1);
+ System.arraycopy(_rowptr, rowIdx + 1, _rowptr, rowIdx,
_nnzr-rowIdx);
+ _nnzr--;
+ }
+
+ private void insertRow(int ix, int row, int rowPtr) {
+ if (_nnzr >= _rowidx.length) {
+ resizeAndInsertRow(ix, row, rowPtr);
+ return;
+ }
+
+ System.arraycopy(_rowidx, ix, _rowidx, ix+1, _nnzr-ix);
+ System.arraycopy(_rowptr, ix, _rowptr, ix+1, _nnzr-ix+1);
+ _rowidx[ix] = row;
+ _rowptr[ix] = rowPtr;
+ _nnzr++;
+ }
+
+ private void resizeAndInsertRow(int ix, int row, int rowPtr) {
+ //compute new size
+ int newCap = newCapacity(_rowidx.length+1);
+
+ int[] oldrowidx = _rowidx;
+ int[] oldrowptr = _rowptr;
+ _rowidx = new int[newCap];
+ _rowptr = new int[newCap+1];
+
+ //copy lhs values to new array
+ System.arraycopy(oldrowidx, 0, _rowidx, 0, ix);
+ System.arraycopy(oldrowptr, 0, _rowptr, 0, ix);
+
+ //copy rhs values to new array
+ System.arraycopy(oldrowidx, ix, _rowidx, ix+1, _nnzr-ix);
+ System.arraycopy(oldrowptr, ix, _rowptr, ix+1, _nnzr-ix+1);
+
+ _rowidx[ix] = row;
+ _rowptr[ix] = rowPtr;
+
+ _nnzr++;
+ }
+
+ private void deleteCol(int ix) {
+ // Without removing row
+ //overlapping array copy (shift rhs values left by 1)
+ System.arraycopy(_colidx, ix+1, _colidx, ix, _size-ix-1);
+ System.arraycopy(_values, ix+1, _values, ix, _size-ix-1);
+ _size--;
+ }
+
+ private void insertCol(int ix, int c, double v) {
+ // Without inserting row
+ if (_size >= _colidx.length) {
+ resizeAndInsertCol(ix, c, v);
+ return;
+ }
+
+ System.arraycopy(_colidx, ix, _colidx, ix+1, _size-ix);
+ System.arraycopy(_values, ix, _values, ix+1, _size-ix);
+
+ _colidx[ix] = c;
+ _values[ix] = v;
+ _size++;
+ }
+
+ private void deleteCols(int ix, int len) {
+ insertCols(ix, new int[0], new double[0], len, 0, 0);
+ }
+
+ private void insertCols(int ix, int[] cols, double[] vals) {
+ insertCols(ix, cols, vals, 0, 0, vals.length);
+ }
+
+ private void insertCols(int ix, int[] cols, double[] vals, int
overwriteNum) {
+ insertCols(ix, cols, vals, overwriteNum, 0, vals.length);
+ }
+
+ private void insertCols(int ix, int[] cols, double[] vals, int
overwriteNum, int vix, int vlen) {
+ // Without inserting row
+ if (_size + vlen - overwriteNum > _colidx.length) {
+ resizeAndInsertCols(ix, cols, vals, overwriteNum, vix,
vlen);
+ return;
+ }
+
+ allocateCols(ix, vlen, overwriteNum);
+
+ System.arraycopy(cols, vix, _colidx, ix, vlen);
+ System.arraycopy(vals, vix, _values, ix, vlen);
+ }
+
+ private void resizeAndInsertCols(int ix, int[] cols, double[] vals, int
overwriteNum, int vix, int vlen) {
+ resizeAndAllocateCols(ix, vlen, overwriteNum);
+
+ //copy new vals into row
+ System.arraycopy(cols, vix, _colidx, ix, vlen);
+ System.arraycopy(vals, vix, _values, ix, vlen);
+ }
+
+ @SuppressWarnings("unused")
+ private void allocateCols(int ix, int numCols) {
+ allocateCols(ix, numCols, 0);
+ }
+
+ private void allocateCols(int ix, int numCols, int overwriteNum) {
+ if (numCols == 0)
+ return;
+
+ if (_size + numCols - overwriteNum > _colidx.length) {
+ resizeAndAllocateCols(ix, numCols, overwriteNum);
+ return;
+ }
+
+ System.arraycopy(_colidx, ix+overwriteNum, _colidx, ix+numCols,
_size-ix-overwriteNum);
+ System.arraycopy(_values, ix+overwriteNum, _values, ix+numCols,
_size-ix-overwriteNum);
+ _size += numCols - overwriteNum;
+ }
+
+ private void resizeAndAllocateCols(int ix, int numCols, int
overwriteNum) {
+ //compute new size
+ int newCap = newCapacity(_size + numCols - overwriteNum);
+
+ int[] oldcolidx = _colidx;
+ double[] oldvalues = _values;
+ _colidx = new int[newCap];
+ _values = new double[newCap];
+
+ //copy lhs values to new array
+ System.arraycopy(oldcolidx, 0, _colidx, 0, ix);
+ System.arraycopy(oldvalues, 0, _values, 0, ix);
+
+ //copy rhs values to new array
+ System.arraycopy(oldcolidx, ix + overwriteNum, _colidx,
ix+numCols, _size-ix-overwriteNum);
+ System.arraycopy(oldvalues, ix + overwriteNum, _values,
ix+numCols, _size-ix-overwriteNum);
+
+ _size += numCols - overwriteNum;
+ }
+
+ private void resizeAndInsertCol(int ix, int c, double v) {
+ //compute new size
+ int newCap = newCapacity(_values.length+1);
+
+ int[] oldcolidx = _colidx;
+ double[] oldvalues = _values;
+ _colidx = new int[newCap];
+ _values = new double[newCap];
+
+ //copy lhs values to new array
+ System.arraycopy(oldcolidx, 0, _colidx, 0, ix);
+ System.arraycopy(oldvalues, 0, _values, 0, ix);
+
+ //copy rhs values to new array
+ System.arraycopy(oldcolidx, ix, _colidx, ix+1, _size-ix);
+ System.arraycopy(oldvalues, ix, _values, ix+1, _size-ix);
+
+ //insert new value
+ _colidx[ix] = c;
+ _values[ix] = v;
+
+ _size++;
+ }
+
+ private void incrRowPtr(int rowIndex) {
+ incrRowPtr(rowIndex, 1);
+ }
+
+ private void incrRowPtr(int rowIndex, int cnt) {
+
+ for( int i = rowIndex; i < _nnzr + 1; i++ )
+ _rowptr[i] += cnt;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
index e03fab9963..efc4a534ef 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
@@ -36,6 +36,7 @@ public abstract class SparseBlockFactory{
case MCSR: return new SparseBlockMCSR(rlen, -1);
case CSR: return new SparseBlockCSR(rlen);
case COO: return new SparseBlockCOO(rlen);
+ case DCSR: return new SparseBlockDCSR(rlen);
default:
throw new RuntimeException("Unexpected sparse
block type: "+type.toString());
}
@@ -63,6 +64,7 @@ public abstract class SparseBlockFactory{
case MCSR: return new SparseBlockMCSR(sblock);
case CSR: return new SparseBlockCSR(sblock);
case COO: return new SparseBlockCOO(sblock);
+ case DCSR: return new SparseBlockDCSR(sblock);
default:
throw new RuntimeException("Unexpected sparse
block type: "+type.toString());
}
@@ -83,6 +85,7 @@ public abstract class SparseBlockFactory{
case MCSR: return
SparseBlockMCSR.estimateSizeInMemory(nrows, ncols, sparsity);
case CSR: return
SparseBlockCSR.estimateSizeInMemory(nrows, ncols, sparsity);
case COO: return
SparseBlockCOO.estimateSizeInMemory(nrows, ncols, sparsity);
+ case DCSR: return
SparseBlockDCSR.estimateSizeInMemory(nrows, ncols, sparsity);
default:
throw new RuntimeException("Unexpected sparse
block type: "+type.toString());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
index 65fde02994..8e0b36f9e0 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -32,7 +32,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.concurrent.ConcurrentHashMap;
public class ColumnEncoderWordEmbedding extends ColumnEncoder {
private MatrixBlock _wordEmbeddings;
@@ -45,6 +44,7 @@ public class ColumnEncoderWordEmbedding extends ColumnEncoder
{
_wordEmbeddings = new MatrixBlock();
}
+ // Looks up key in _rcdMap; returns -1 when the key is absent.
+ @SuppressWarnings("unused")
private long lookupRCDMap(Object key) {
return _rcdMap.getOrDefault(key, -1L);
}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
index c66a9cb3f0..22e830cdb3 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -94,6 +95,21 @@ public class SparseBlockAlignment extends AutomatedTestBase
runSparseBlockScanTest(SparseBlock.Type.COO, sparsity3, true);
}
+ // DCSR variants of the positive alignment test at the three sparsity levels
+ @Test
+ public void testSparseBlockDCSR1Pos() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity1, true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Pos() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity2, true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Pos() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity3, true);
+ }
+
@Test
public void testSparseBlockMCSR1Neg() {
runSparseBlockScanTest(SparseBlock.Type.MCSR, sparsity1, false);
@@ -138,6 +154,21 @@ public class SparseBlockAlignment extends AutomatedTestBase
public void testSparseBlockCOO3Neg() {
runSparseBlockScanTest(SparseBlock.Type.COO, sparsity3, false);
}
+
+ // DCSR variants of the negative alignment test at the three sparsity levels
+ @Test
+ public void testSparseBlockDCSR1Neg() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity1, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Neg() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity2, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Neg() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity3, false);
+ }
private void runSparseBlockScanTest( SparseBlock.Type btype, double
sparsity, boolean positive)
{
@@ -154,6 +185,7 @@ public class SparseBlockAlignment extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//init second sparse block and deep copy
@@ -162,6 +194,7 @@ public class SparseBlockAlignment extends AutomatedTestBase
case MCSR: sblock2 = new
SparseBlockMCSR(sblock); break;
case CSR: sblock2 = new SparseBlockCSR(sblock);
break;
case COO: sblock2 = new SparseBlockCOO(sblock);
break;
+ case DCSR: sblock2 = new
SparseBlockDCSR(sblock); break;
}
//modify second block if necessary
@@ -181,8 +214,12 @@ public class SparseBlockAlignment extends AutomatedTestBase
for( int i=0; i<rows; i++ ) {
if( i==37 || i==38 )
rowsAligned37 &= sblock.isAligned(i,
sblock2);
- else if( i<37 ) //CSR/COO different after
update pos
+ else if( i<37 ) {//CSR/COO different after
update pos
rowsAlignedRest &= sblock.isAligned(i,
sblock2);
+ if (!sblock.isAligned(i, sblock2)) {
+ Assert.fail("Alignment problem
at row: " + i + " (" + sblock.size(i) + " vs " + sblock2.size(i) + ")");
+ }
+ }
}
if( rowsAligned37 != positive )
Assert.fail("Wrong row alignment indicated:
"+rowsAligned37+", expected: "+positive);
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAppendSort.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAppendSort.java
index f75e42575a..e16f1536e7 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAppendSort.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAppendSort.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -145,6 +146,36 @@ public class SparseBlockAppendSort extends
AutomatedTestBase
public void testSparseBlockCOO3Rand() {
runSparseBlockAppendSortTest(SparseBlock.Type.COO, sparsity3,
InitType.RAND_SET);
}
+
+ @Test
+ public void testSparseBlockDCSR1Seq() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity1,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Seq() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity2,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Seq() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity3,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1Rand() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity1,
InitType.RAND_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Rand() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity2,
InitType.RAND_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Rand() {
+ runSparseBlockAppendSortTest(SparseBlock.Type.DCSR, sparsity3,
InitType.RAND_SET);
+ }
private void runSparseBlockAppendSortTest( SparseBlock.Type btype,
double sparsity, InitType itype)
{
@@ -159,6 +190,7 @@ public class SparseBlockAppendSort extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(rows,
cols); break;
case CSR: sblock = new SparseBlockCSR(rows,
cols); break;
case COO: sblock = new SparseBlockCOO(rows,
cols); break;
+ case DCSR: sblock = new SparseBlockDCSR(rows,
cols); break;
}
if(itype == InitType.SEQ_SET) {
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockDelete.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockDelete.java
index 6167979e68..2c79a297d8 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockDelete.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockDelete.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.sparse;
import java.util.Iterator;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -98,6 +99,21 @@ public class SparseBlockDelete extends AutomatedTestBase
public void testSparseBlockCOO3() {
runSparseBlockDeleteTest(SparseBlock.Type.COO, sparsity3);
}
+
+ @Test
+ public void testSparseBlockDCSR1() {
+ runSparseBlockDeleteTest(SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2() {
+ runSparseBlockDeleteTest(SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3() {
+ runSparseBlockDeleteTest(SparseBlock.Type.DCSR, sparsity3);
+ }
private void runSparseBlockDeleteTest( SparseBlock.Type btype, double
sparsity)
{
@@ -114,6 +130,7 @@ public class SparseBlockDelete extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//delete range per row via set
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetFirstIndex.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetFirstIndex.java
index 09e0c59e2f..82a641f0f0 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetFirstIndex.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetFirstIndex.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -189,6 +190,51 @@ public class SparseBlockGetFirstIndex extends
AutomatedTestBase
public void testSparseBlockCOO3LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO,
sparsity3, IndexType.LTE);
}
+
+ @Test
+ public void testSparseBlockDCSR1GT() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity1, IndexType.GT);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2GT() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity2, IndexType.GT);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3GT() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity3, IndexType.GT);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1GTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity1, IndexType.GTE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2GTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity2, IndexType.GTE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3GTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity3, IndexType.GTE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1LTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity1, IndexType.LTE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2LTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity2, IndexType.LTE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3LTE() {
+ runSparseBlockGetFirstIndexTest(SparseBlock.Type.DCSR,
sparsity3, IndexType.LTE);
+ }
private void runSparseBlockGetFirstIndexTest( SparseBlock.Type btype,
double sparsity, IndexType itype)
{
@@ -205,6 +251,7 @@ public class SparseBlockGetFirstIndex extends
AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//check for correct number of non-zeros
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetSet.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetSet.java
index 27a7c91bd0..533cb643d6 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetSet.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockGetSet.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -43,7 +44,7 @@ import java.util.Iterator;
public class SparseBlockGetSet extends AutomatedTestBase
{
private final static int rows = 132;
- private final static int cols = 60;
+ private final static int cols = 60;
private final static double sparsity1 = 0.1;
private final static double sparsity2 = 0.2;
private final static double sparsity3 = 0.3;
@@ -193,23 +194,69 @@ public class SparseBlockGetSet extends AutomatedTestBase
public void testSparseBlockCOO3Rand() {
runSparseBlockGetSetTest(SparseBlock.Type.COO, sparsity3,
InitType.RAND_SET);
}
+
+ @Test
+ public void testSparseBlockDCSR1Bulk() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity1,
InitType.BULK);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Bulk() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity2,
InitType.BULK);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Bulk() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity3,
InitType.BULK);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1Seq() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity1,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Seq() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity2,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Seq() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity3,
InitType.SEQ_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1Rand() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity1,
InitType.RAND_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Rand() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity2,
InitType.RAND_SET);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Rand() {
+ runSparseBlockGetSetTest(SparseBlock.Type.DCSR, sparsity3,
InitType.RAND_SET);
+ }
private void runSparseBlockGetSetTest( SparseBlock.Type btype, double
sparsity, InitType itype)
{
try
{
//data generation
- double[][] A = getRandomMatrix(rows, cols, -10, 10,
sparsity, 7654321);
+ double[][] A = getRandomMatrix(rows, cols, -10, 10,
sparsity, 7654321);
//init sparse block
SparseBlock sblock = null;
if( itype == InitType.BULK ) {
MatrixBlock mbtmp =
DataConverter.convertToMatrixBlock(A);
- SparseBlock srtmp = mbtmp.getSparseBlock();
+ SparseBlock srtmp = mbtmp.getSparseBlock();
switch( btype ) {
case MCSR: sblock = new
SparseBlockMCSR(srtmp); break;
case CSR: sblock = new
SparseBlockCSR(srtmp); break;
case COO: sblock = new
SparseBlockCOO(srtmp); break;
+ case DCSR: sblock = new
SparseBlockDCSR(srtmp); break;
}
}
else if( itype == InitType.SEQ_SET || itype ==
InitType.RAND_SET ) {
@@ -217,6 +264,7 @@ public class SparseBlockGetSet extends AutomatedTestBase
case MCSR: sblock = new
SparseBlockMCSR(rows, cols); break;
case CSR: sblock = new
SparseBlockCSR(rows, cols); break;
case COO: sblock = new
SparseBlockCOO(rows, cols); break;
+ case DCSR: sblock = new
SparseBlockDCSR(rows, cols); break;
}
if(itype == InitType.SEQ_SET) {
@@ -232,7 +280,9 @@ public class SparseBlockGetSet extends AutomatedTestBase
Iterator<ADoubleEntry> iter =
map.getIterator();
while( iter.hasNext() ) { //random hash
order
ADoubleEntry e = iter.next();
- sblock.set((int)e.getKey1(),
(int)e.getKey2(), e.value);
+ int r = (int)e.getKey1();
+ int c = (int)e.getKey2();
+ sblock.set(r, c, e.value);
}
}
}
@@ -254,12 +304,12 @@ public class SparseBlockGetSet extends AutomatedTestBase
//check correct isEmpty return
for( int i=0; i<rows; i++ )
if( sblock.isEmpty(i) != (rnnz[i]==0) )
- Assert.fail("Wrong isEmpty(row) result
for row nnz: "+rnnz[i]);
+ Assert.fail("Wrong isEmpty(row) result
for row nnz: "+rnnz[i] + "(row: " + i + ")");
- //check correct values
- for( int i=0; i<rows; i++ )
- if( !sblock.isEmpty(i) )
- for( int j=0; j<cols; j++ ) {
+ //check correct values
+ for( int i=0; i<rows; i++ )
+ if( !sblock.isEmpty(i) )
+ for( int j=0; j<cols; j++ ) {
double tmp = sblock.get(i, j);
if( tmp != A[i][j] )
Assert.fail("Wrong get
value for cell ("+i+","+j+"): "+tmp+", expected: "+A[i][j]);
@@ -270,4 +320,4 @@ public class SparseBlockGetSet extends AutomatedTestBase
throw new RuntimeException(ex);
}
}
-}
\ No newline at end of file
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIndexRange.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIndexRange.java
index bb86a053eb..2dcd17fe99 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIndexRange.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIndexRange.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.component.sparse;
import java.util.Arrays;
import java.util.Iterator;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -149,6 +150,36 @@ public class SparseBlockIndexRange extends
AutomatedTestBase
public void testSparseBlockCOO3Insert() {
runSparseBlockIndexRangeTest(SparseBlock.Type.COO, sparsity3,
UpdateType.INSERT);
}
+
+ @Test
+ public void testSparseBlockDCSR1Delete() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity1,
UpdateType.DELETE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Delete() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity2,
UpdateType.DELETE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Delete() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity3,
UpdateType.DELETE);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1Insert() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity1,
UpdateType.INSERT);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Insert() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity2,
UpdateType.INSERT);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Insert() {
+ runSparseBlockIndexRangeTest(SparseBlock.Type.DCSR, sparsity3,
UpdateType.INSERT);
+ }
private void runSparseBlockIndexRangeTest( SparseBlock.Type btype,
double sparsity, UpdateType utype)
{
@@ -165,6 +196,7 @@ public class SparseBlockIndexRange extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//delete range per row via set
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
index b73f5f41b4..068bedf78e 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.sparse;
import java.util.Iterator;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -142,6 +143,36 @@ public class SparseBlockIterator extends AutomatedTestBase
public void testSparseBlockCOO3Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
true);
}
+
+ @Test
+ public void testSparseBlockDCSR1Full() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Full() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Full() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1Partial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Partial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Partial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
true);
+ }
private void runSparseBlockIteratorTest( SparseBlock.Type btype, double
sparsity, boolean partial)
{
@@ -158,6 +189,7 @@ public class SparseBlockIterator extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//check for correct number of non-zeros
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMemEstimate.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMemEstimate.java
index 465112124b..93ca8b8cdb 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMemEstimate.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMemEstimate.java
@@ -59,6 +59,7 @@ public class SparseBlockMemEstimate extends AutomatedTestBase
double memMCSR =
SparseBlockFactory.estimateSizeSparseInMemory(SparseBlock.Type.MCSR, rows,
cols, sparsity);
double memCSR =
SparseBlockFactory.estimateSizeSparseInMemory(SparseBlock.Type.CSR, rows, cols,
sparsity);
double memCOO =
SparseBlockFactory.estimateSizeSparseInMemory(SparseBlock.Type.COO, rows, cols,
sparsity);
+ double memDCSR =
SparseBlockFactory.estimateSizeSparseInMemory(SparseBlock.Type.DCSR, rows,
cols, sparsity);
double memDense = MatrixBlock.estimateSizeDenseInMemory(rows,
cols);
//check negative estimate
@@ -68,6 +69,8 @@ public class SparseBlockMemEstimate extends AutomatedTestBase
Assert.fail("SparseBlockCSR memory estimate <= 0.");
if( memCOO <= 0 )
Assert.fail("SparseBlockCOO memory estimate <= 0.");
+ if( memDCSR <= 0 )
+ Assert.fail("SparseBlockDCSR memory estimate <= 0.");
//check dense estimate
if( memMCSR > memDense )
@@ -76,6 +79,8 @@ public class SparseBlockMemEstimate extends AutomatedTestBase
Assert.fail("SparseBlockCSR memory estimate larger than
dense estimate.");
if( memCOO > memDense )
Assert.fail("SparseBlockCOO memory estimate larger than
dense estimate.");
+ if( memDCSR > memDense )
+ Assert.fail("SparseBlockDCSR memory estimate larger
than dense estimate.");
//check sparse estimates relations
if( sparsity == sparsity1 ) { //sparse (pref CSR)
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
index a25c5f9c0f..ce2211bc6e 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockMerge.java
@@ -101,6 +101,26 @@ public class SparseBlockMerge extends AutomatedTestBase
public void testMergeMCSR_COO_3() {
runSparseBlockMergeTest(SparseBlock.Type.MCSR,
SparseBlock.Type.COO, sparsity3);
}
+
+ @Test
+ public void testMergeMCSR_DCSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.MCSR,
SparseBlock.Type.DCSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeMCSR_DCSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.MCSR,
SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeMCSR_DCSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.MCSR,
SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeMCSR_DCSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.MCSR,
SparseBlock.Type.DCSR, sparsity3);
+ }
@Test
public void testMergeCSR_CSR_0() {
@@ -141,6 +161,26 @@ public class SparseBlockMerge extends AutomatedTestBase
public void testMergeCSR_MCSR_3() {
runSparseBlockMergeTest(SparseBlock.Type.CSR,
SparseBlock.Type.MCSR, sparsity3);
}
+
+ @Test
+ public void testMergeCSR_DCSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.CSR,
SparseBlock.Type.DCSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeCSR_DCSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.CSR,
SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeCSR_DCSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.CSR,
SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeCSR_DCSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.CSR,
SparseBlock.Type.DCSR, sparsity3);
+ }
@Test
public void testMergeCSR_COO_0() {
@@ -221,6 +261,106 @@ public class SparseBlockMerge extends AutomatedTestBase
public void testMergeCOO_CSR_3() {
runSparseBlockMergeTest(SparseBlock.Type.COO,
SparseBlock.Type.CSR, sparsity3);
}
+
+ @Test
+ public void testMergeCOO_DCSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.COO,
SparseBlock.Type.DCSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeCOO_DCSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.COO,
SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeCOO_DCSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.COO,
SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeCOO_DCSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.COO,
SparseBlock.Type.DCSR, sparsity3);
+ }
+
+ @Test
+ public void testMergeDCSR_DCSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.DCSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeDCSR_DCSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeDCSR_DCSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeDCSR_DCSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.DCSR, sparsity3);
+ }
+
+ @Test
+ public void testMergeDCSR_CSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.CSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeDCSR_CSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.CSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeDCSR_CSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.CSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeDCSR_CSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.CSR, sparsity3);
+ }
+
+ @Test
+ public void testMergeDCSR_MCSR_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.MCSR, sparsity0);
+ }
+
+ @Test
+ public void testMergeDCSR_MCSR_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.MCSR, sparsity1);
+ }
+
+ @Test
+ public void testMergeDCSR_MCSR_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.MCSR, sparsity2);
+ }
+
+ @Test
+ public void testMergeDCSR_MCSR_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.MCSR, sparsity3);
+ }
+
+ @Test
+ public void testMergeDCSR_COO_0() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.COO, sparsity0);
+ }
+
+ @Test
+ public void testMergeDCSR_COO_1() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.COO, sparsity1);
+ }
+
+ @Test
+ public void testMergeDCSR_COO_2() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.COO, sparsity2);
+ }
+
+ @Test
+ public void testMergeDCSR_COO_3() {
+ runSparseBlockMergeTest(SparseBlock.Type.DCSR,
SparseBlock.Type.COO, sparsity3);
+ }
private void runSparseBlockMergeTest( SparseBlock.Type btype1,
SparseBlock.Type btype2, double sparsity)
{
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockScan.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockScan.java
index 762f071d6c..d13bb93e38 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockScan.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockScan.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -93,6 +94,21 @@ public class SparseBlockScan extends AutomatedTestBase
public void testSparseBlockCOO3Full() {
runSparseBlockScanTest(SparseBlock.Type.COO, sparsity3);
}
+
+ @Test
+ public void testSparseBlockDCSR1Full() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2Full() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3Full() {
+ runSparseBlockScanTest(SparseBlock.Type.DCSR, sparsity3);
+ }
private void runSparseBlockScanTest( SparseBlock.Type btype, double
sparsity)
{
@@ -109,6 +125,7 @@ public class SparseBlockScan extends AutomatedTestBase
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//check for correct number of non-zeros
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockSize.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockSize.java
index e70e08c9e1..39dafec0c8 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockSize.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockSize.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.component.sparse;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -98,6 +99,21 @@ public class SparseBlockSize extends AutomatedTestBase
public void testSparseBlockCOO3() {
runSparseBlockSizeTest(SparseBlock.Type.COO, sparsity3);
}
+
+ @Test
+ public void testSparseBlockDCSR1() {
+ runSparseBlockSizeTest(SparseBlock.Type.DCSR, sparsity1);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2() {
+ runSparseBlockSizeTest(SparseBlock.Type.DCSR, sparsity2);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3() {
+ runSparseBlockSizeTest(SparseBlock.Type.DCSR, sparsity3);
+ }
private void runSparseBlockSizeTest( SparseBlock.Type btype, double
sparsity)
{
@@ -109,11 +125,12 @@ public class SparseBlockSize extends AutomatedTestBase
//init sparse block
SparseBlock sblock = null;
MatrixBlock mbtmp =
DataConverter.convertToMatrixBlock(A);
- SparseBlock srtmp = mbtmp.getSparseBlock();
+ SparseBlock srtmp = mbtmp.getSparseBlock();
switch( btype ) {
case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
case CSR: sblock = new SparseBlockCSR(srtmp);
break;
case COO: sblock = new SparseBlockCOO(srtmp);
break;
+ case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
}
//prepare summary statistics nnz
@@ -149,7 +166,7 @@ public class SparseBlockSize extends AutomatedTestBase
//check index range nnz
if( sblock.size(rl, ru, cl, cu) != nnz2 )
Assert.fail("Wrong number of range non-zeros: "
+
- sblock.size(rl, ru, cl, cu)+",
expected: "+nnz2);
+ sblock.size(rl, ru, cl, cu)+",
expected: "+nnz2);
}
catch(Exception ex) {
ex.printStackTrace();