This is an automated email from the ASF dual-hosted git repository.
estrauss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8f4dba14f6 [SYSTEMDS-3801] Fix missing method implementations in
ColGroupSDCZeros
8f4dba14f6 is described below
commit 8f4dba14f6f4f94ad34de559d2d72168868fcc2d
Author: e-strauss <[email protected]>
AuthorDate: Tue Dec 3 00:58:31 2024 +0100
[SYSTEMDS-3801] Fix missing method implementations in ColGroupSDCZeros
The previous master version broke the AWARE experiment for the kmeans+
algorithm. This patch fixes the regression and adds the missing method
implementations for DenseBlocks in ColGroupSDCZeros.
With these changes, the runtime of the kmeans+ algorithm on the US Census
dataset additionally decreased from 40s to 32s.
Closes #2149.
---
.../runtime/compress/colgroup/ColGroupSDC.java | 4 +--
.../compress/colgroup/ColGroupSDCZeros.java | 38 +++++++++++++++++++---
.../compress/colgroup/dictionary/ADictionary.java | 12 ++++++-
.../compress/colgroup/dictionary/IDictionary.java | 15 ++++++++-
.../colgroup/dictionary/PlaceHolderDict.java | 8 ++++-
.../compress/dictionary/PlaceHolderDictTest.java | 9 +++--
6 files changed, 74 insertions(+), 12 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index ea4f2fb581..e78bea93a2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -683,7 +683,7 @@ public class ColGroupSDC extends ASDC implements
IMapToDataGroup {
}
else {
while(c < points.length && points[c].o == of) {
- _dict.put(sr,
_data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
+ _dict.putSparse(sr,
_data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
c++;
}
of = it.next();
@@ -696,7 +696,7 @@ public class ColGroupSDC extends ASDC implements
IMapToDataGroup {
}
while(of == last && c < points.length && points[c].o == of) {
- _dict.put(sr, _data.getIndex(it.getDataIndex()),
points[c].r, nCol, _colIndexes);
+ _dict.putSparse(sr, _data.getIndex(it.getDataIndex()),
points[c].r, nCol, _colIndexes);
c++;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index c1e081f253..d250969a6a 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -836,7 +836,7 @@ public class ColGroupSDCZeros extends ASDCZero implements
IMapToDataGroup {
while(of < last && c < points.length) {
if(points[c].o == of) {
- c = processRow(points, sr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
+ c = processRowSparse(points, sr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
of = it.next();
}
else if(points[c].o < of)
@@ -848,18 +848,46 @@ public class ColGroupSDCZeros extends ASDCZero implements
IMapToDataGroup {
while(c < points.length && points[c].o < last)
c++;
- c = processRow(points, sr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
+ c = processRowSparse(points, sr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
}
@Override
protected void denseSelection(MatrixBlock selection, P[] points,
MatrixBlock ret, int rl, int ru) {
- throw new NotImplementedException();
+ final DenseBlock dr = ret.getDenseBlock();
+ final int nCol = _colIndexes.size();
+ final AIterator it = _indexes.getIterator();
+ final int last = _indexes.getOffsetToLast();
+ int c = 0;
+ int of = it.value();
+
+ while(of < last && c < points.length) {
+ if(points[c].o == of) {
+ c = processRowDense(points, dr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
+ of = it.next();
+ }
+ else if(points[c].o < of)
+ c++;
+ else
+ of = it.next();
+ }
+ // increment the c pointer until it is pointing at
least to last point or is done.
+ while(c < points.length && points[c].o < last)
+ c++;
+ c = processRowDense(points, dr, nCol, c, of,
_data.getIndex(it.getDataIndex()));
+ }
+
+ private int processRowSparse(P[] points, final SparseBlock sr, final
int nCol, int c, int of, final int did) {
+ while(c < points.length && points[c].o == of) {
+ _dict.putSparse(sr, did, points[c].r, nCol,
_colIndexes);
+ c++;
+ }
+ return c;
}
- private int processRow(P[] points, final SparseBlock sr, final int
nCol, int c, int of, final int did) {
+ private int processRowDense(P[] points, final DenseBlock dr, final int
nCol, int c, int of, final int did) {
while(c < points.length && points[c].o == of) {
- _dict.put(sr, did, points[c].r, nCol, _colIndexes);
+ _dict.putDense(dr, did, points[c].r, nCol, _colIndexes);
c++;
}
return c;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index d41e2675f5..7d88573e3a 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary;
import java.io.Serializable;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -87,8 +88,17 @@ public abstract class ADictionary implements IDictionary,
Serializable {
}
@Override
- public void put(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns) {
+ public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns) {
for(int i = 0; i < nCol; i++)
sb.append(rowOut, columns.get(i), getValue(idx, i,
nCol));
}
+
+ @Override
+ public void putDense(DenseBlock dr, int idx, int rowOut, int nCol,
IColIndex columns) {
+ double[] dv = dr.values(rowOut);
+ int off = dr.pos(rowOut);
+ for(int i = 0; i < nCol; i++)
+ dv[off + columns.get(i)] += getValue(idx, i, nCol);
+ }
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index a7a74775be..bfe4ef23c3 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -25,6 +25,7 @@ import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -989,6 +990,18 @@ public interface IDictionary {
* @param nCol The number of columns in the dictionary
* @param columns The columns to output into.
*/
- public void put(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns);
+ public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns);
+
+ /**
+ * Put the row specified into the dense block, by adding the dictionary values to the target cells.
+ *
+ * @param db The dense block to put into
+ * @param idx The dictionary index to put in.
+ * @param rowOut The row in the dense block to put it into
+ * @param nCol The number of columns in the dictionary
+ * @param columns The columns to output into.
+ */
+ public void putDense(DenseBlock db, int idx, int rowOut, int nCol,
IColIndex columns);
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
index 88a7be2619..68a3fb3fac 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
@@ -25,6 +25,7 @@ import java.io.IOException;
import java.io.Serializable;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -526,7 +527,12 @@ public class PlaceHolderDict implements IDictionary,
Serializable {
}
@Override
- public void put(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns) {
+ public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns) {
+ throw new RuntimeException(errMessage);
+ }
+
+ @Override
+ public void putDense(DenseBlock sb, int idx, int rowOut, int nCol,
IColIndex columns) {
throw new RuntimeException(errMessage);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
index 88e5d8adcc..5a112a800c 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
@@ -490,8 +490,13 @@ public class PlaceHolderDictTest {
}
@Test(expected = Exception.class)
- public void put() {
- d.put(null, 1, 1, 1, null);
+ public void putDense() {
+ d.putDense(null, 1, 1, 1, null);
+ }
+
+ @Test(expected = Exception.class)
+ public void putSparse() {
+ d.putSparse(null, 1, 1, 1, null);
}
@Test