This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 57398289c1 [SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators 57398289c1 is described below commit 57398289c1da83e57b612059a63b6e0d9aca19ed Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Aug 23 12:55:59 2024 +0200 [SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators --- src/main/java/org/apache/sysds/common/Types.java | 8 +-- .../org/apache/sysds/lops/PartialAggregate.java | 12 ---- .../org/apache/sysds/parser/DMLTranslator.java | 12 ++-- .../sysds/runtime/matrix/data/LibMatrixSketch.java | 70 +++++++++++++++++----- .../sysds/test/functions/unique/UniqueBase.java | 2 +- .../sysds/test/functions/unique/UniqueRow.java | 13 ++-- 6 files changed, 69 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 4161f0c23d..a7397ae54b 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -493,12 +493,8 @@ public interface Types { PROD, SUM_PROD, TRACE, MEAN, VAR, MAXINDEX, MININDEX, - COUNT_DISTINCT, - ROW_COUNT_DISTINCT, //TODO should be direction - COL_COUNT_DISTINCT, - COUNT_DISTINCT_APPROX, - COUNT_DISTINCT_APPROX_ROW, //TODO should be direction - COUNT_DISTINCT_APPROX_COL, + COUNT_DISTINCT, + COUNT_DISTINCT_APPROX, UNIQUE; @Override diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java index 467c7c69b0..ed6ffe6e71 100644 --- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java +++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java @@ -352,12 +352,6 @@ public class PartialAggregate extends Lop } } - case ROW_COUNT_DISTINCT: - return "uacdr"; - - case COL_COUNT_DISTINCT: - return "uacdc"; - case COUNT_DISTINCT_APPROX: { switch (dir) { case RowCol: return "uacdap"; @@ -369,12 +363,6 @@ public class PartialAggregate extends Lop } } - case COUNT_DISTINCT_APPROX_ROW: - return "uacdapr"; - - case COUNT_DISTINCT_APPROX_COL: - return "uacdapc"; - case UNIQUE: { switch (dir) { case RowCol: return "unique"; diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 5ff351da4c..77ed904821 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2066,12 +2066,12 @@ public class DMLTranslator case COUNT_DISTINCT_APPROX_ROW: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), - AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data")); + AggOp.COUNT_DISTINCT_APPROX, Direction.Row, paramHops.get("data")); break; case COUNT_DISTINCT_APPROX_COL: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), - AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data")); + AggOp.COUNT_DISTINCT_APPROX, Direction.Col, paramHops.get("data")); break; case UNIQUE: @@ -2795,13 +2795,13 @@ public class DMLTranslator } case ROW_COUNT_DISTINCT: - currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), - AggOp.valueOf(source.getOpCode().name()), Direction.Row, expr); + currBuiltinOp = new AggUnaryOp(target.getName(), + DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Row, expr); break; case COL_COUNT_DISTINCT: - currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), - AggOp.valueOf(source.getOpCode().name()), Direction.Col, expr); + currBuiltinOp = new AggUnaryOp(target.getName(), + DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; default: diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java index 346651bdd0..8fdc276d66 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java @@ -19,11 +19,9 @@ package org.apache.sysds.runtime.matrix.data; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types; import java.util.HashSet; -import java.util.Iterator; public class LibMatrixSketch { @@ -35,31 +33,71 @@ public class LibMatrixSketch { int clen = blkIn.getNumColumns(); MatrixBlock blkOut = null; + // TODO optimize for dense/sparse/compressed (once multi-column support added) + switch (dir) { - case RowCol: - if( clen != 1 ) - throw new NotImplementedException("Unique only support single-column vectors yet"); - // TODO optimize for dense/sparse/compressed (once multi-column support added) - + case RowCol: { // obtain set of unique items (dense input vector) HashSet<Double> hashSet = new HashSet<>(); for( int i=0; i<rlen; i++ ) { - hashSet.add(blkIn.get(i, 0)); + for( int j=0; j<clen; j++ ) + hashSet.add(blkIn.get(i, j)); } // allocate output block and place values int rlen2 = hashSet.size(); blkOut = new MatrixBlock(rlen2, 1, false).allocateBlock(); - Iterator<Double> iter = hashSet.iterator(); - for( int i=0; i<rlen2; i++ ) { - blkOut.set(i, 0, iter.next()); + int pos = 0; + for( Double val : hashSet ) + blkOut.set(pos++, 0, val); + break; + } + case Row: { + //2-pass algorithm to avoid unnecessarily large mem requirements + HashSet<Double> hashSet = new HashSet<>(); + int clen2 = 0; + for( int i=0; i<rlen; i++ ) { + hashSet.clear(); + for( int j=0; j<clen; j++ ) + hashSet.add(blkIn.get(i, j)); + clen2 = Math.max(clen2, hashSet.size()); + } + + //actual + blkOut = new MatrixBlock(rlen, clen2, false).allocateBlock(); + for( int i=0; i<rlen; i++ ) { + hashSet.clear(); + for( int j=0; j<clen; j++ ) + hashSet.add(blkIn.get(i, j)); + int pos = 0; + for( Double val : hashSet ) + blkOut.set(i, pos++, val); } break; - - case Row: - case Col: - throw new NotImplementedException("Unique Row/Col has not been implemented yet"); - + } + case Col: { + //2-pass algorithm to avoid unnecessarily large mem requirements + HashSet<Double> hashSet = new HashSet<>(); + int rlen2 = 0; + for( int j=0; j<clen; j++ ) { + hashSet.clear(); + for( int i=0; i<rlen; i++ ) + hashSet.add(blkIn.get(i, j)); + rlen2 = Math.max(rlen2, hashSet.size()); + } + + //actual + blkOut = new MatrixBlock(rlen2, clen, false).allocateBlock(); + for( int j=0; j<clen; j++ ) { + hashSet.clear(); + for( int i=0; i<rlen; i++ ) + hashSet.add(blkIn.get(i, j)); + int pos = 0; + for( Double val : hashSet ) + blkOut.set(pos++, j, val); + } + break; + } default: throw new IllegalArgumentException("Unrecognized direction: " + dir); } diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java index d834fe45ae..6e65c01f7c 100644 --- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java +++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java @@ -45,7 +45,7 @@ public abstract class UniqueBase extends AutomatedTestBase { loadTestConfiguration(getTestConfiguration(getTestName())); String HOME = SCRIPT_DIR + getTestDir(); fullDMLScriptName = HOME + getTestName() + ".dml"; - programArgs = new String[]{ "-args", input("I"), output("A")}; + programArgs = new String[]{"-args", input("I"), output("A")}; writeInputMatrixWithMTD("I", inputMatrix, true); diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java index ee8c664efa..fda9aa4a3c 100644 --- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java +++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java @@ -27,7 +27,6 @@ public class UniqueRow extends UniqueBase { private final static String TEST_DIR = "functions/unique/"; private static final String TEST_CLASS_DIR = TEST_DIR + UniqueRow.class.getSimpleName() + "/"; - @Override protected String getTestName() { return TEST_NAME; @@ -52,22 +51,22 @@ public class UniqueRow extends UniqueBase { @Test public void testSkinnyCP() { - double[][] inputMatrix = {{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}}; - double[][] expectedMatrix = {{1},{6},{9},{4},{2},{0}}; + double[][] inputMatrix = {{1,1,6,9,4,2,0,9,0,0,4,4}}; + double[][] expectedMatrix = {{1,6,9,4,2,0}}; uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0); } @Test public void testSquareCP() { - double[][] inputMatrix = {{1, 2, 3}, {4, 5, 6}, {1, 2, 3}}; - double[][] expectedMatrix = {{1, 2, 3},{4, 5, 6}}; + double[][] inputMatrix = {{1, 4, 1}, {2, 5, 2}, {3, 6, 3}}; + double[][] expectedMatrix = {{1, 4},{2, 5},{3, 6}}; uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0); } @Test public void testWideCP() { - double[][] inputMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {1, 2, 3, 4, 5, 6}}; - double[][] expectedMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}; + double[][] inputMatrix = {{1,7,1},{2,8,2},{3,9,3},{4,10,4},{5,11,5},{6,12,6}}; + double[][] expectedMatrix = {{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}}; uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0); }