This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 57398289c1 [SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators
57398289c1 is described below

commit 57398289c1da83e57b612059a63b6e0d9aca19ed
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Aug 23 12:55:59 2024 +0200

    [SYSTEMDS-3725] Fix countDistinct/unique compilation/runtime operators
---
 src/main/java/org/apache/sysds/common/Types.java   |  8 +--
 .../org/apache/sysds/lops/PartialAggregate.java    | 12 ----
 .../org/apache/sysds/parser/DMLTranslator.java     | 12 ++--
 .../sysds/runtime/matrix/data/LibMatrixSketch.java | 70 +++++++++++++++++-----
 .../sysds/test/functions/unique/UniqueBase.java    |  2 +-
 .../sysds/test/functions/unique/UniqueRow.java     | 13 ++--
 6 files changed, 69 insertions(+), 48 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 4161f0c23d..a7397ae54b 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -493,12 +493,8 @@ public interface Types {
                PROD, SUM_PROD,
                TRACE, MEAN, VAR,
                MAXINDEX, MININDEX,
-               COUNT_DISTINCT, 
-               ROW_COUNT_DISTINCT, //TODO should be direction
-               COL_COUNT_DISTINCT,
-               COUNT_DISTINCT_APPROX, 
-               COUNT_DISTINCT_APPROX_ROW, //TODO should be direction
-               COUNT_DISTINCT_APPROX_COL,
+               COUNT_DISTINCT,
+               COUNT_DISTINCT_APPROX,
                UNIQUE;
 
                @Override
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 467c7c69b0..ed6ffe6e71 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -352,12 +352,6 @@ public class PartialAggregate extends Lop
                                }
                        }
 
-                       case ROW_COUNT_DISTINCT:
-                               return "uacdr";
-
-                       case COL_COUNT_DISTINCT:
-                               return "uacdc";
-
                        case COUNT_DISTINCT_APPROX: {
                                switch (dir) {
                                        case RowCol: return "uacdap";
@@ -369,12 +363,6 @@ public class PartialAggregate extends Lop
                                }
                        }
 
-                       case COUNT_DISTINCT_APPROX_ROW:
-                               return "uacdapr";
-
-                       case COUNT_DISTINCT_APPROX_COL:
-                               return "uacdapc";
-
                        case UNIQUE: {
                                switch (dir) {
                                        case RowCol: return "unique";
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 5ff351da4c..77ed904821 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2066,12 +2066,12 @@ public class DMLTranslator
 
                        case COUNT_DISTINCT_APPROX_ROW:
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
-                                               
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
+                                               AggOp.COUNT_DISTINCT_APPROX, 
Direction.Row, paramHops.get("data"));
                                break;
 
                        case COUNT_DISTINCT_APPROX_COL:
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
-                                               
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
+                                               AggOp.COUNT_DISTINCT_APPROX, 
Direction.Col, paramHops.get("data"));
                                break;
 
                        case UNIQUE:
@@ -2795,13 +2795,13 @@ public class DMLTranslator
                }
 
                case ROW_COUNT_DISTINCT:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
-                                       
AggOp.valueOf(source.getOpCode().name()), Direction.Row, expr);
+                       currBuiltinOp = new AggUnaryOp(target.getName(),
+                               DataType.MATRIX, target.getValueType(), 
AggOp.COUNT_DISTINCT, Direction.Row, expr);
                        break;
 
                case COL_COUNT_DISTINCT:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
-                                       
AggOp.valueOf(source.getOpCode().name()), Direction.Col, expr);
+                       currBuiltinOp = new AggUnaryOp(target.getName(),
+                               DataType.MATRIX, target.getValueType(), 
AggOp.COUNT_DISTINCT, Direction.Col, expr);
                        break;
 
                default:
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
index 346651bdd0..8fdc276d66 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
@@ -19,11 +19,9 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.sysds.common.Types;
 
 import java.util.HashSet;
-import java.util.Iterator;
 
 public class LibMatrixSketch {
 
@@ -35,31 +33,71 @@ public class LibMatrixSketch {
                int clen = blkIn.getNumColumns();
 
                MatrixBlock blkOut = null;
+               // TODO optimize for dense/sparse/compressed (once multi-column 
support added)
+               
                switch (dir) {
-                       case RowCol:
-                               if( clen != 1 )
-                                       throw new 
NotImplementedException("Unique only support single-column vectors yet");
-                               // TODO optimize for dense/sparse/compressed 
(once multi-column support added)
-                               
+                       case RowCol: {
                                // obtain set of unique items (dense input 
vector)
                                HashSet<Double> hashSet = new HashSet<>();
                                for( int i=0; i<rlen; i++ ) {
-                                       hashSet.add(blkIn.get(i, 0));
+                                       for( int j=0; j<clen; j++ )
+                                               hashSet.add(blkIn.get(i, j));
                                }
                                
                                // allocate output block and place values
                                int rlen2 = hashSet.size();
                                blkOut = new MatrixBlock(rlen2, 1, 
false).allocateBlock();
-                               Iterator<Double> iter = hashSet.iterator();
-                               for( int i=0; i<rlen2; i++ ) {
-                                       blkOut.set(i, 0, iter.next());
+                               int pos = 0;
+                               for( Double val : hashSet )
+                                       blkOut.set(pos++, 0, val);
+                               break;
+                       }
+                       case Row: {
+                               //2-pass algorithm to avoid unnecessarily large 
mem requirements
+                               HashSet<Double> hashSet = new HashSet<>();
+                               int clen2 = 0;
+                               for( int i=0; i<rlen; i++ ) {
+                                       hashSet.clear();
+                                       for( int j=0; j<clen; j++ )
+                                               hashSet.add(blkIn.get(i, j));
+                                       clen2 = Math.max(clen2, hashSet.size());
+                               }
+                               
+                               //actual 
+                               blkOut = new MatrixBlock(rlen, clen2, 
false).allocateBlock();
+                               for( int i=0; i<rlen; i++ ) {
+                                       hashSet.clear();
+                                       for( int j=0; j<clen; j++ )
+                                               hashSet.add(blkIn.get(i, j));
+                                       int pos = 0;
+                                       for( Double val : hashSet )
+                                               blkOut.set(i, pos++, val);
                                }
                                break;
-
-                       case Row:
-                       case Col:
-                               throw new NotImplementedException("Unique 
Row/Col has not been implemented yet");
-
+                       }
+                       case Col: {
+                               //2-pass algorithm to avoid unnecessarily large 
mem requirements
+                               HashSet<Double> hashSet = new HashSet<>();
+                               int rlen2 = 0;
+                               for( int j=0; j<clen; j++ ) {
+                                       hashSet.clear();
+                                       for( int i=0; i<rlen; i++ )
+                                               hashSet.add(blkIn.get(i, j));
+                                       rlen2 = Math.max(rlen2, hashSet.size());
+                               }
+                               
+                               //actual 
+                               blkOut = new MatrixBlock(rlen2, clen, 
false).allocateBlock();
+                               for( int j=0; j<clen; j++ ) {
+                                       hashSet.clear();
+                                       for( int i=0; i<rlen; i++ )
+                                               hashSet.add(blkIn.get(i, j));
+                                       int pos = 0;
+                                       for( Double val : hashSet )
+                                               blkOut.set(pos++, j, val);
+                               }
+                               break;
+                       }
                        default:
                                throw new 
IllegalArgumentException("Unrecognized direction: " + dir);
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java 
b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
index d834fe45ae..6e65c01f7c 100644
--- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
@@ -45,7 +45,7 @@ public abstract class UniqueBase extends AutomatedTestBase {
                        
loadTestConfiguration(getTestConfiguration(getTestName()));
                        String HOME = SCRIPT_DIR + getTestDir();
                        fullDMLScriptName = HOME + getTestName() + ".dml";
-                       programArgs = new String[]{ "-args", input("I"), 
output("A")};
+                       programArgs = new String[]{"-args", input("I"), 
output("A")};
 
                        writeInputMatrixWithMTD("I", inputMatrix, true);
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java 
b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
index ee8c664efa..fda9aa4a3c 100644
--- a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRow.java
@@ -27,7 +27,6 @@ public class UniqueRow extends UniqueBase {
        private final static String TEST_DIR = "functions/unique/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
UniqueRow.class.getSimpleName() + "/";
 
-
        @Override
        protected String getTestName() {
                return TEST_NAME;
@@ -52,22 +51,22 @@ public class UniqueRow extends UniqueBase {
 
        @Test
        public void testSkinnyCP() {
-               double[][] inputMatrix = 
{{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
-               double[][] expectedMatrix = {{1},{6},{9},{4},{2},{0}};
+               double[][] inputMatrix = {{1,1,6,9,4,2,0,9,0,0,4,4}};
+               double[][] expectedMatrix = {{1,6,9,4,2,0}};
                uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
        }
 
        @Test
        public void testSquareCP() {
-               double[][] inputMatrix = {{1, 2, 3}, {4, 5, 6}, {1, 2, 3}};
-               double[][] expectedMatrix = {{1, 2, 3},{4, 5, 6}};
+               double[][] inputMatrix = {{1, 4, 1}, {2, 5, 2}, {3, 6, 3}};
+               double[][] expectedMatrix = {{1, 4},{2, 5},{3, 6}};
                uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
        }
 
        @Test
        public void testWideCP() {
-               double[][] inputMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 
12}, {1, 2, 3, 4, 5, 6}};
-               double[][] expectedMatrix = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 
11, 12}};
+               double[][] inputMatrix = 
{{1,7,1},{2,8,2},{3,9,3},{4,10,4},{5,11,5},{6,12,6}};
+               double[][] expectedMatrix = 
{{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}};
                uniqueTest(inputMatrix, expectedMatrix, Types.ExecType.CP, 0.0);
        }
 

Reply via email to