This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 57398289c1 [SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators
57398289c1 is described below
commit 57398289c1da83e57b612059a63b6e0d9aca19ed
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Aug 23 12:55:59 2024 +0200
[SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators
---
src/main/java/org/apache/sysds/common/Types.java | 8 +--
.../org/apache/sysds/lops/PartialAggregate.java | 12 ----
.../org/apache/sysds/parser/DMLTranslator.java | 12 ++--
.../sysds/runtime/matrix/data/LibMatrixSketch.java | 70 +++++++++++++++++-----
.../sysds/test/functions/unique/UniqueBase.java | 2 +-
.../sysds/test/functions/unique/UniqueRow.java | 13 ++--
6 files changed, 69 insertions(+), 48 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 4161f0c23d..a7397ae54b 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -493,12 +493,8 @@ public interface Types {
PROD, SUM_PROD,
TRACE, MEAN, VAR,
MAXINDEX, MININDEX,
- COUNT_DISTINCT,
- ROW_COUNT_DISTINCT, //TODO should be direction
- COL_COUNT_DISTINCT,
- COUNT_DISTINCT_APPROX,
- COUNT_DISTINCT_APPROX_ROW, //TODO should be direction
- COUNT_DISTINCT_APPROX_COL,
+ COUNT_DISTINCT,
+ COUNT_DISTINCT_APPROX,
UNIQUE;
@Override
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 467c7c69b0..ed6ffe6e71 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -352,12 +352,6 @@ public class PartialAggregate extends Lop
}
}
- case ROW_COUNT_DISTINCT:
- return "uacdr";
-
- case COL_COUNT_DISTINCT:
- return "uacdc";
-
case COUNT_DISTINCT_APPROX: {
switch (dir) {
case RowCol: return "uacdap";
@@ -369,12 +363,6 @@ public class PartialAggregate extends Lop
}
}
- case COUNT_DISTINCT_APPROX_ROW:
- return "uacdapr";
-
- case COUNT_DISTINCT_APPROX_COL:
- return "uacdapc";
-
case UNIQUE: {
switch (dir) {
case RowCol: return "unique";
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 5ff351da4c..77ed904821 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2066,12 +2066,12 @@ public class DMLTranslator
case COUNT_DISTINCT_APPROX_ROW:
currBuiltinOp = new
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
-
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
+ AggOp.COUNT_DISTINCT_APPROX,
Direction.Row, paramHops.get("data"));
break;
case COUNT_DISTINCT_APPROX_COL:
currBuiltinOp = new
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
-
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
+ AggOp.COUNT_DISTINCT_APPROX,
Direction.Col, paramHops.get("data"));
break;
case UNIQUE:
@@ -2795,13 +2795,13 @@ public class DMLTranslator
}
case ROW_COUNT_DISTINCT:
- currBuiltinOp = new AggUnaryOp(target.getName(),
DataType.MATRIX, target.getValueType(),
-
AggOp.valueOf(source.getOpCode().name()), Direction.Row, expr);
+ currBuiltinOp = new AggUnaryOp(target.getName(),
+ DataType.MATRIX, target.getValueType(),
AggOp.COUNT_DISTINCT, Direction.Row, expr);
break;
case COL_COUNT_DISTINCT:
- currBuiltinOp = new AggUnaryOp(target.getName(),
DataType.MATRIX, target.getValueType(),
-
AggOp.valueOf(source.getOpCode().name()), Direction.Col, expr);
+ currBuiltinOp = new AggUnaryOp(target.getName(),
+ DataType.MATRIX, target.getValueType(),
AggOp.COUNT_DISTINCT, Direction.Col, expr);
break;
default:
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
index 346651bdd0..8fdc276d66 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
@@ -19,11 +19,9 @@
package org.apache.sysds.runtime.matrix.data;
-import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.common.Types;
import java.util.HashSet;
-import java.util.Iterator;
public class LibMatrixSketch {
@@ -35,31 +33,71 @@ public class LibMatrixSketch {
int clen = blkIn.getNumColumns();
MatrixBlock blkOut = null;
+ // TODO optimize for dense/sparse/compressed (once multi-column
support added)
+
switch (dir) {
- case RowCol:
- if( clen != 1 )
- throw new
NotImplementedException("Unique only support single-column vectors yet");
- // TODO optimize for dense/sparse/compressed
(once multi-column support added)
-
+ case RowCol: {
// obtain set of unique items (dense input
vector)
HashSet<Double> hashSet = new HashSet<>();
for( int i=0; i<rlen; i++ ) {
- hashSet.add(blkIn.get(i, 0));
+ for( int j=0; j<clen; j++ )
+ hashSet.add(blkIn.get(i, j));
}
// allocate output block and place values
int rlen2 = hashSet.size();
blkOut = new MatrixBlock(rlen2, 1,
false).allocateBlock();
- Iterator<Double> iter = hashSet.iterator();
- for( int i=0; i<rlen2; i++ ) {
- blkOut.set(i, 0, iter.next());
+ int pos = 0;
+ for( Double val : hashSet )
+ blkOut.set(pos++, 0, val);
+ break;
+ }
+ case Row: {
+ //2-pass algorithm to avoid unnecessarily large
mem requirements
+ HashSet<Double> hashSet = new HashSet<>();
+ int clen2 = 0;
+ for( int i=0; i<rlen; i++ ) {
+ hashSet.clear();
+ for( int j=0; j<clen; j++ )
+ hashSet.add(blkIn.get(i, j));
+ clen2 = Math.max(clen2, hashSet.size());
+ }
+
+ //actual
+ blkOut = new MatrixBlock(rlen, clen2,
false).allocateBlock();
+ for( int i=0; i<rlen; i++ ) {
+ hashSet.clear();
+ for( int j=0; j<clen; j++ )
+ hashSet.add(blkIn.get(i, j));
+ int pos = 0;
+ for( Double val : hashSet )
+ blkOut.set(i, pos++, val);
}
break;
-
- case Row:
- case Col:
- throw new NotImplementedException("Unique
Row/Col has not been implemented yet");
-
+ }
+ case Col: {
+ //2-pass algorithm to avoid unnecessarily large
mem requirements
+ HashSet<Double> hashSet = new HashSet<>();
+ int rlen2 = 0;
+ for( int j=0; j<clen; j++ ) {
+ hashSet.clear();
+ for( int i=0; i<rlen; i++ )
+ hashSet.add(blkIn.get(i, j));
+ rlen2 = Math.max(rlen2, hashSet.size());
+ }
+
+ //actual
+ blkOut = new MatrixBlock(rlen2, clen,
false).allocateBlock();
+ for( int j=0; j<clen; j++ ) {
+ hashSet.clear();
+ for( int i=0; i<rlen; i++ )
+ hashSet.add(blkIn.get(i, j));
+ int pos = 0;
+ for( Double val : hashSet )
+ blkOut.set(pos++, j, val);
+ }
+ break;
+ }
default:
throw new
IllegalArgumentException("Unrecognized direction: " + dir);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
index d834fe45ae..6e65c01f7c 100644
--- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
@@ -45,7 +45,7 @@ public abstract class UniqueBase extends AutomatedTestBase {
loadTestConfiguration(getTestConfiguration(getTestName()));
String HOME = SCRIPT_DIR + getTestDir();
fullDMLScriptName = HOME + getTestName() + ".dml";
- programArgs = new String[]{ "-args", input("I"),
output("A")};
+ programArgs = new String[]{"-args", input("I"),
output("A")};
writeInputMatrixWithMTD("I", inputMatrix, true);
diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
index ee8c664efa..fda9aa4a3c 100644
--- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
@@ -27,7 +27,6 @@ public class UniqueRow extends UniqueBase {
private final static String TEST_DIR = "functions/unique/";
private static final String TEST_CLASS_DIR = TEST_DIR +
UniqueRow.class.getSimpleName() + "/";
-
@Override
protected String getTestName() {
return TEST_NAME;
@@ -52,22 +51,22 @@ public class UniqueRow extends UniqueBase {
@Test
public void testSkinnyCP() {
- double[][] inputMatrix =
{{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
- double[][] expectedMatrix = {{1},{6},{9},{4},{2},{0}};
+ double[][] inputMatrix = {{1,1,6,9,4,2,0,9,0,0,4,4}};
+ double[][] expectedMatrix = {{1,6,9,4,2,0}};
uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
}
@Test
public void testSquareCP() {
- double[][] inputMatrix = {{1, 2, 3}, {4, 5, 6}, {1, 2, 3}};
- double[][] expectedMatrix = {{1, 2, 3},{4, 5, 6}};
+ double[][] inputMatrix = {{1, 4, 1}, {2, 5, 2}, {3, 6, 3}};
+ double[][] expectedMatrix = {{1, 4},{2, 5},{3, 6}};
uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
}
@Test
public void testWideCP() {
- double[][] inputMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11,
12}, {1, 2, 3, 4, 5, 6}};
- double[][] expectedMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10,
11, 12}};
+ double[][] inputMatrix =
{{1,7,1},{2,8,2},{3,9,3},{4,10,4},{5,11,5},{6,12,6}};
+ double[][] expectedMatrix =
{{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}};
uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
}