This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 6c9f0ff1304125111ee20d3a3309f45f65bc6661
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Fri Oct 28 14:38:21 2022 +0200

    [SYSTEMDS-3413] Row/Col aggregation for countDistinct
    
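    This patch converts countDistinct() from a non-parameterized builtin to
    a parameterized builtin function so that it accepts one new optional
    parameter, dir (direction). The value of dir can be "r" or "c", denoting
    row-wise and column-wise aggregation respectively; if dir is omitted,
    the aggregation is over all cells ("rc"), as before. The patch also
    adds the aliases countDistinctRow(), countDistinctCol(),
    countDistinctApproxRow() and countDistinctApproxCol(). Only the CP
    backend is implemented; the SP case throws a NotImplementedException,
    which will be addressed in a subsequent patch.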
    
    Closes #1677
---
 .../java/org/apache/sysds/common/Builtins.java     | 10 +++-
 src/main/java/org/apache/sysds/common/Types.java   |  6 +-
 .../org/apache/sysds/lops/PartialAggregate.java    | 29 +++++++--
 .../sysds/parser/BuiltinFunctionExpression.java    |  8 ---
 .../org/apache/sysds/parser/DMLTranslator.java     | 22 +++++--
 .../ParameterizedBuiltinFunctionExpression.java    | 60 +++++++++++++++++--
 .../runtime/instructions/CPInstructionParser.java  |  2 +
 .../runtime/instructions/InstructionUtils.java     | 51 +++++++++++-----
 .../runtime/instructions/SPInstructionParser.java  |  4 +-
 .../cp/AggregateUnaryCPInstruction.java            | 70 +++++++---------------
 .../spark/AggregateUnarySketchSPInstruction.java   | 34 +++++------
 .../matrix/data/LibMatrixCountDistinct.java        |  4 +-
 .../matrix/operators/CountDistinctOperator.java    | 45 ++++----------
 .../test/component/matrix/CountDistinctTest.java   |  5 +-
 .../countDistinct/CountDistinctApproxCol.java      |  2 +-
 .../countDistinct/CountDistinctApproxRow.java      |  2 +-
 ...istinctApproxCol.java => CountDistinctCol.java} | 11 ++--
 ...ctApproxCol.java => CountDistinctColAlias.java} | 11 ++--
 ...istinctApproxRow.java => CountDistinctRow.java} |  7 ++-
 ...ctApproxRow.java => CountDistinctRowAlias.java} |  7 ++-
 .../countDistinct/CountDistinctRowCol.java         |  2 +-
 ....java => CountDistinctRowColParameterized.java} |  6 +-
 .../countDistinct/CountDistinctRowOrColBase.java   | 32 +++++-----
 .../CountDistinctApproxCol.java                    | 20 ++++++-
 .../CountDistinctApproxColAlias.java}              | 26 ++++++--
 .../CountDistinctApproxRow.java                    | 20 ++++++-
 .../CountDistinctApproxRowAlias.java}              | 26 ++++++--
 .../CountDistinctApproxRowCol.java                 |  5 +-
 .../CountDistinctApproxRowColParameterized.java}   | 11 ++--
 .../{countDistinct.dml => countDistinctCol.dml}    |  2 +-
 ...countDistinct.dml => countDistinctColAlias.dml} |  2 +-
 .../{countDistinct.dml => countDistinctRow.dml}    |  2 +-
 ...countDistinct.dml => countDistinctRowAlias.dml} |  2 +-
 .../{countDistinct.dml => countDistinctRowCol.dml} |  2 +-
 ...ct.dml => countDistinctRowColParameterized.dml} |  2 +-
 .../countDistinctApproxCol.dml                     |  0
 .../countDistinctApproxColAlias.dml}               |  2 +-
 .../countDistinctApproxRow.dml                     |  0
 .../countDistinctApproxRowAlias.dml}               |  2 +-
 .../countDistinctApproxRowCol.dml}                 |  2 +-
 .../countDistinctApproxRowColParameterized.dml}    |  0
 41 files changed, 339 insertions(+), 217 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 5e9509696b..262212570e 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -37,7 +37,7 @@ import org.apache.sysds.common.Types.ReturnType;
  * building SystemDS, these scripts are packaged into the jar as well.
  */
 public enum Builtins {
-       //builtin functions
+       // Builtin functions (non-parameterized)
        ABSTAIN("abstain", true),
        ABS("abs", false),
        ACOS("acos", false),
@@ -93,7 +93,6 @@ public enum Builtins {
        CORRECTTYPOSAPPLY("correctTyposApply", true),
        COS("cos", false),
        COSH("cosh", false),
-       COUNT_DISTINCT("countDistinct",false),
        COV("cov", false),
        COX("cox", true),
        CSPLINE("cspline", true),
@@ -305,10 +304,15 @@ public enum Builtins {
        XGBOOSTPREDICT_CLASS("xgboostPredictClassification", true),
        XOR("xor", false),
 
-       //parameterized builtin functions
+       // Parameterized builtin functions
        AUTODIFF("autoDiff", false, true),
        CDF("cdf", false, true),
+       COUNT_DISTINCT("countDistinct",false, true),
+       COUNT_DISTINCT_ROW("countDistinctRow",false, true),
+       COUNT_DISTINCT_COL("countDistinctCol",false, true),
        COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
+       COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true),
+       COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true),
        CVLM("cvlm", true, false),
        GROUPEDAGG("aggregate", "groupedAggregate", false, true),
        INVCDF("icdf", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index c013b7890b..4f613a40d7 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -197,9 +197,9 @@ public class Types
                PROD(4), SUM_PROD(5),
                TRACE(6), MEAN(7), VAR(8),
                MAXINDEX(9), MININDEX(10),
-               COUNT_DISTINCT(11),
-               COUNT_DISTINCT_APPROX(12);
-               
+               COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12), 
COUNT_DISTINCT_COL(13),
+               COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), 
COUNT_DISTINCT_APPROX_COL(16);
+
                @Override
                public String toString() {
                        switch(this) {
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 050d87a3fd..0481c7373a 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -342,19 +342,38 @@ public class PartialAggregate extends Lop
                        }
 
                        case COUNT_DISTINCT: {
-                               if(dir == Direction.RowCol )
-                                       return "uacd";
-                               break;
+                               switch (dir) {
+                                       case RowCol: return "uacd";
+                                       case Row: return "uacdr";
+                                       case Col: return "uacdc";
+                                       default:
+                                               throw new 
LopsException("PartialAggregate.getOpcode() - "
+                                                               + "Unknown 
aggregate direction: " + dir);
+                               }
                        }
-                       
+
+                       case COUNT_DISTINCT_ROW:
+                               return "uacdr";
+
+                       case COUNT_DISTINCT_COL:
+                               return "uacdc";
+
                        case COUNT_DISTINCT_APPROX: {
                                switch (dir) {
                                        case RowCol: return "uacdap";
                                        case Row: return "uacdapr";
                                        case Col: return "uacdapc";
+                                       default:
+                                               throw new 
LopsException("PartialAggregate.getOpcode() - "
+                                                               + "Unknown 
aggregate direction: " + dir);
                                }
-                               break;
                        }
+
+                       case COUNT_DISTINCT_APPROX_ROW:
+                               return "uacdapr";
+
+                       case COUNT_DISTINCT_APPROX_COL:
+                               return "uacdapc";
                }
                
                //should never come here for normal compilation
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index c5db658dfa..c3aca47d38 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -943,14 +943,6 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setBlocksize(0);
                        output.setValueType(ValueType.INT64);
                        break;
-               case COUNT_DISTINCT:
-                       checkNumParameters(1);
-                       checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
-                       output.setDataType(DataType.SCALAR);
-                       output.setDimensions(0, 0);
-                       output.setBlocksize(0);
-                       output.setValueType(ValueType.INT64);
-                       break;
                case LINEAGE:
                        checkNumParameters(1);
                        checkDataTypeParam(getFirstExpr(),
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 315f54ff72..553bf56fc5 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2034,15 +2034,16 @@ public class DMLTranslator
                                                target.getValueType(), 
ParamBuiltinOp.TOSTRING, paramHops) :
                                        
HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), 
OpOp2.PLUS);
                                break;
+
                        case LISTNV:
                                currBuiltinOp = new 
ParameterizedBuiltinOp(target.getName(), target.getDataType(),
                                        target.getValueType(), 
ParamBuiltinOp.LIST, paramHops);
                                break;
 
+                       case COUNT_DISTINCT:
                        case COUNT_DISTINCT_APPROX:
-                               // Default direction and data type
-                               Direction dir = Direction.RowCol;
-                               DataType dataType = DataType.SCALAR;
+                               Direction dir = Direction.RowCol;  // Default 
direction
+                               DataType dataType = DataType.SCALAR;  // 
Default output data type
 
                                LiteralOp dirOp = (LiteralOp) 
paramHops.get("dir");
                                if (dirOp != null) {
@@ -2062,6 +2063,19 @@ public class DMLTranslator
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), dataType, target.getValueType(),
                                                
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
                                break;
+
+                       case COUNT_DISTINCT_ROW:
+                       case COUNT_DISTINCT_APPROX_ROW:
+                               currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
+                                               
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
+                               break;
+
+                       case COUNT_DISTINCT_COL:
+                       case COUNT_DISTINCT_APPROX_COL:
+                               currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
+                                               
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
+                               break;
+
                        default:
                                throw new 
ParseException(source.printErrorLocation() + 
                                        
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + 
source.getOpCode());
@@ -2361,10 +2375,10 @@ public class DMLTranslator
                case SUM:
                case PROD:
                case VAR:
-               case COUNT_DISTINCT:
                        currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(),
                                        
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
                        break;
+
                case MEAN:
                        if ( expr2 == null ) {
                                // example: x = mean(Y);
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index c83b6f3911..bdfd38c5a4 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -246,7 +246,15 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        validateParamserv(output, conditional);
                        break;
 
+               case COUNT_DISTINCT:
+               case COUNT_DISTINCT_ROW:
+               case COUNT_DISTINCT_COL:
+                       validateCountDistinct(output, conditional);
+                       break;
+
                case COUNT_DISTINCT_APPROX:
+               case COUNT_DISTINCT_APPROX_ROW:
+               case COUNT_DISTINCT_APPROX_COL:
                        validateCountDistinctApprox(output, conditional);
                        break;
 
@@ -353,6 +361,45 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                output.setBlocksize(-1);
        }
 
+       private void validateCountDistinct(DataIdentifier output, boolean 
conditional) {
+               HashMap<String, Expression> varParams = getVarParams();
+
+               // "data" is the only parameter that is allowed to be unnamed
+               if (varParams.containsKey(null)) {
+                       varParams.put("data", varParams.remove(null));
+               }
+
+               // Validate the number of parameters
+               String fname = getOpCode().getName();
+               String usageMessage = "function " + fname + " takes at least 1 
and at most 2 parameters";
+               if (varParams.size() < 1) {
+                       raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
+               }
+
+               if (varParams.size() > 2) {
+                       raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+               }
+
+               // Check parameter names are valid
+               Set<String> validParameterNames = CollectionUtils.asSet("data", 
"dir");
+               checkInvalidParameters(getOpCode(), varParams, 
validParameterNames);
+
+               // Check parameter expression data types match expected
+               checkDataType(false, fname, "data", DataType.MATRIX, 
conditional);
+               checkDataValueType(false, fname, "data", DataType.MATRIX, 
ValueType.FP64, conditional);
+
+               // We need the dimensions of the input matrix to determine the 
output matrix characteristics
+               // Validate data parameter, lookup previously defined var or 
resolve expression
+               Identifier dataId = varParams.get("data").getOutput();
+               if (dataId == null) {
+                       raiseValidateError("Cannot parse input parameter 
\"data\" to function " + fname, conditional);
+               }
+
+               checkStringParam(true, fname, "dir", conditional);
+               // Check data value of "dir" parameter
+               validateAggregationDirection(dataId, output);
+       }
+
        private void validateCountDistinctApprox(DataIdentifier output, boolean 
conditional) {
                Set<String> validTypeNames = CollectionUtils.asSet("KMV");
                HashMap<String, Expression> varParams = getVarParams();
@@ -390,7 +437,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
 
                checkStringParam(true, fname, "type", conditional);
                // Check data value of "type" parameter
-               if (varParams.keySet().contains("type")) {
+               if (varParams.containsKey("type")) {
                        String typeString = 
varParams.get("type").toString().toUpperCase();
                        if (!validTypeNames.contains(typeString)) {
                                raiseValidateError("Unrecognized type for 
optional parameter " + typeString, conditional);
@@ -402,7 +449,12 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
 
                checkStringParam(true, fname, "dir", conditional);
                // Check data value of "dir" parameter
-               if (varParams.keySet().contains("dir")) {
+               validateAggregationDirection(dataId, output);
+       }
+
+       private void validateAggregationDirection(Identifier dataId, 
DataIdentifier output) {
+               HashMap<String, Expression> varParams = getVarParams();
+               if (varParams.containsKey("dir")) {
                        String directionString = 
varParams.get("dir").toString().toUpperCase();
 
                        // Set output type and dimensions based on direction
@@ -435,9 +487,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        } else {
                                raiseValidateError("Invalid argument: " + 
directionString + " is not recognized");
                        }
-
-               // default to dir="rc"
-               } else {
+               } else {  // default to dir="rc"
                        output.setDataType(DataType.SCALAR);
                        output.setDimensions(0, 0);
                        output.setBlocksize(0);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index f2d3080ddc..b83a78674b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -117,6 +117,8 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "exists"  , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "lineage" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacd"    , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uacdr"   , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uacdc"   , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdap"  , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdapr" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdapc" , 
CPType.AggregateUnary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 2c8468955f..d87e772709 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -82,26 +82,16 @@ import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceDiag;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.functionobjects.Xor;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import 
org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
 import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
-import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.matrix.operators.*;
 import 
org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
-import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
-import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
-import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
-import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
-import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
 
 
 public class InstructionUtils 
@@ -287,7 +277,14 @@ public class InstructionUtils
        public static AggregateUnaryOperator 
parseBasicAggregateUnaryOperator(String opcode) {
                return parseBasicAggregateUnaryOperator(opcode, 1);
        }
-       
+
+       /**
+        * Parse the given opcode into an aggregate unary operator.
+        *
+        * @param opcode opcode
+        * @param numThreads number of threads
+        * @return Parsed aggregate unary operator object. Caller must handle 
possible null return value.
+        */
        public static AggregateUnaryOperator 
parseBasicAggregateUnaryOperator(String opcode, int numThreads)
        {
                AggregateUnaryOperator aggun = null;
@@ -420,7 +417,31 @@ public class InstructionUtils
                        AggregateOperator agg = new 
AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
                        aggun = new AggregateUnaryOperator(agg, 
ReduceRow.getReduceRowFnObject(), numThreads);
                }
-               
+               else if ( opcode.equalsIgnoreCase("uacd") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
+                                       Direction.RowCol, 
ReduceAll.getReduceAllFnObject());
+               }
+               else if ( opcode.equalsIgnoreCase("uacdr") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
+                                       Direction.Row, 
ReduceCol.getReduceColFnObject());
+               }
+               else if ( opcode.equalsIgnoreCase("uacdc") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
+                                       Direction.Col, 
ReduceRow.getReduceRowFnObject());
+               }
+               else if ( opcode.equalsIgnoreCase("uacdap") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
+                                       Direction.RowCol, 
ReduceAll.getReduceAllFnObject());
+               }
+               else if ( opcode.equalsIgnoreCase("uacdapr") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
+                                       Direction.Row, 
ReduceCol.getReduceColFnObject());
+               }
+               else if ( opcode.equalsIgnoreCase("uacdapc") ) {
+                       aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
+                                       Direction.Col, 
ReduceRow.getReduceRowFnObject());
+               }
+
                return aggun;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 73329b954d..9496cf465e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -126,7 +126,9 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uac*"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uatrace" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uaktrace", 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uacdap"  , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uacd"    , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uacdr"   , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uacdc"   , 
SPType.AggregateUnary);
 
                // Aggregate unary sketch operators
                String2SPInstructionType.put( "uacdap" , 
SPType.AggregateUnarySketch);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index ddf00ada2b..6fc0107520 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -33,11 +33,13 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.lineage.LineageDedupUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -45,6 +47,9 @@ import 
org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.utils.Explain;
 
+import java.util.HashSet;
+import java.util.Set;
+
 public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
        // private static final Log LOG = 
LogFactory.getLog(AggregateUnaryCPInstruction.class.getName());
 
@@ -81,36 +86,19 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                        return new AggregateUnaryCPInstruction(new 
SimpleOperator(Builtin.getBuiltinFnObject(opcode)),
                                in1, out, AUType.valueOf(opcode.toUpperCase()), 
opcode, str);
                } 
-               else if(opcode.equalsIgnoreCase("uacd")){
-                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT)
-                                       .setDirection(Types.Direction.RowCol)
-                                       
.setIndexFunction(ReduceAll.getReduceAllFnObject());
-
-                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT,
-                                       opcode, str);
+               else if(opcode.equalsIgnoreCase("uacd")
+                               || opcode.equalsIgnoreCase("uacdr")
+                               || opcode.equalsIgnoreCase("uacdc")){
+                       AggregateUnaryOperator aggun = 
InstructionUtils.parseBasicAggregateUnaryOperator(opcode,
+                                       Integer.parseInt(parts[3]));
+                       return new AggregateUnaryCPInstruction(aggun, in1, out, 
AUType.COUNT_DISTINCT, opcode, str);
                }
-               else if(opcode.equalsIgnoreCase("uacdap")){
-                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
-                                       .setDirection(Types.Direction.RowCol)
-                                       
.setIndexFunction(ReduceAll.getReduceAllFnObject());
-
-                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
-                                       opcode, str);
-               }
-               else if(opcode.equalsIgnoreCase("uacdapr")){
-                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
-                                       .setDirection(Types.Direction.Row)
-                                       
.setIndexFunction(ReduceCol.getReduceColFnObject());
-
-                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
-                                       opcode, str);
-               }
-               else if(opcode.equalsIgnoreCase("uacdapc")){
-                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
-                                       .setDirection(Types.Direction.Col)
-                                       
.setIndexFunction(ReduceRow.getReduceRowFnObject());
-
-                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
+               else if(opcode.equalsIgnoreCase("uacdap")
+                               || opcode.equalsIgnoreCase("uacdapr")
+                               || opcode.equalsIgnoreCase("uacdapc")){
+                       AggregateUnaryOperator aggun = 
InstructionUtils.parseBasicAggregateUnaryOperator(opcode,
+                                       Integer.parseInt(parts[3]));
+                       return new AggregateUnaryCPInstruction(aggun, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
                                        opcode, str);
                }
                else if(opcode.equalsIgnoreCase("uarimax") || 
opcode.equalsIgnoreCase("uarimin")){
@@ -199,34 +187,18 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                ec.setScalarOutput(output_name, new 
StringObject(out));
                                break;
                        }
-                       case COUNT_DISTINCT: {
-                               if( 
!ec.getVariables().keySet().contains(input1.getName()) )
-                                       throw new DMLRuntimeException("Variable 
'" + input1.getName() + "' does not exist.");
-                               MatrixBlock input = 
ec.getMatrixInput(input1.getName());
-
-                               // Operator type: test and cast
-                               if (!(_optr instanceof CountDistinctOperator)) {
-                                       throw new DMLRuntimeException("Operator 
should be instance of " + CountDistinctOperator.class.getSimpleName());
-                               }
-                               CountDistinctOperator op = 
(CountDistinctOperator) (_optr);
-
-                               //TODO add support for row or col count 
distinct.
-                               int res = (int) 
LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
-                               ec.releaseMatrixInput(input1.getName());
-                               ec.setScalarOutput(output_name, new 
IntObject(res));
-                               break;
-                       }
+                       case COUNT_DISTINCT:
                        case COUNT_DISTINCT_APPROX: {
                                
if(!ec.getVariables().keySet().contains(input1.getName())) {
                                        throw new DMLRuntimeException("Variable 
'" + input1.getName() + "' does not exist.");
                                }
-
                                MatrixBlock input = 
ec.getMatrixInput(input1.getName());
+
+                               // Operator type: test and cast
                                if (!(_optr instanceof CountDistinctOperator)) {
                                        throw new DMLRuntimeException("Operator 
should be instance of " + CountDistinctOperator.class.getSimpleName());
                                }
-
-                               CountDistinctOperator op = 
(CountDistinctOperator) _optr;  // It is safe to cast at this point
+                               CountDistinctOperator op = 
(CountDistinctOperator) _optr;
 
                                if (op.getDirection().isRowCol()) {
                                        long res = (long) 
LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index 71bc75fd45..703828e3a1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -53,20 +53,6 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
     protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, 
CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String instr) {
         super(SPType.AggregateUnarySketch, op, in, out, opcode, instr);
         this.op = (CountDistinctOperator) super.getOperator();
-
-        if (opcode.equals("uacdap")) {
-            this.op.setDirection(Types.Direction.RowCol)
-                    .setIndexFunction(ReduceAll.getReduceAllFnObject());
-        } else if (opcode.equals("uacdapr")) {
-            this.op.setDirection(Types.Direction.Row)
-                    .setIndexFunction(ReduceCol.getReduceColFnObject());
-        } else if (opcode.equals("uacdapc")) {
-            this.op.setDirection(Types.Direction.Col)
-                    .setIndexFunction(ReduceRow.getReduceRowFnObject());
-        } else {
-            throw new DMLException("Unrecognized opcode " + opcode);
-        }
-
         this.aggtype = aggtype;
     }
 
@@ -79,7 +65,19 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
         CPOperand out = new CPOperand(parts[2]);
         AggBinaryOp.SparkAggType aggtype = 
AggBinaryOp.SparkAggType.valueOf(parts[3]);
 
-        CountDistinctOperator cdop = new 
CountDistinctOperator(CountDistinctOperatorTypes.KMV, Hash.HashType.LinearHash);
+        CountDistinctOperator cdop = null;
+        if (opcode.equals("uacdap")) {
+            cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, 
Types.Direction.RowCol,
+                    ReduceAll.getReduceAllFnObject(), 
Hash.HashType.LinearHash);
+        } else if (opcode.equals("uacdapr")) {
+            cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, 
Types.Direction.Row,
+                    ReduceCol.getReduceColFnObject(), 
Hash.HashType.LinearHash);
+        } else if (opcode.equals("uacdapc")) {
+            cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, 
Types.Direction.Col,
+                    ReduceRow.getReduceRowFnObject(), 
Hash.HashType.LinearHash);
+        } else {
+            throw new DMLException("Unrecognized opcode: " + opcode);
+        }
 
         return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, 
opcode, str);
     }
@@ -147,7 +145,7 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
 
             out3 = out2.mapValues(new 
CalculateAggregateSketchFunction(this.op));
 
-            updateUnaryAggOutputDataCharacteristics(sec, 
this.op.getIndexFunction());
+            updateUnaryAggOutputDataCharacteristics(sec, this.op.indexFn);
 
             // put output RDD handle into symbol table
             sec.setRDDHandleForVariable(output.getName(), out3);
@@ -173,7 +171,7 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
             MatrixBlock blkIn = arg0._2();
 
             MatrixIndexes ixOut = new MatrixIndexes();
-            this.op.getIndexFunction().execute(ixIn, ixOut);
+            this.op.indexFn.execute(ixIn, ixOut);
 
             return LibMatrixCountDistinct.createSketch(blkIn, this.op);
         }
@@ -222,7 +220,7 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
 
             MatrixIndexes idxOut = new MatrixIndexes();
             MatrixBlock blkOut = blkIn;  // Do not create sketch yet
-            this._op.getIndexFunction().execute(idxIn, idxOut);
+            this._op.indexFn.execute(idxIn, idxOut);
 
             return new Tuple2<>(idxOut, blkOut);
         }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 814b7737f9..ee97c8d340 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -52,7 +52,6 @@ public interface LibMatrixCountDistinct {
         * Public method to count the number of distinct values inside a 
matrix. Depending on which CountDistinctOperator
         * selected it either gets the absolute number or a estimated value.
         * 
-        * TODO: Support counting num distinct in rows, or columns axis.
         * TODO: If the MatrixBlock type is CompressedMatrix, simply read the 
values from the ColGroups.
         * 
         * @param in the input matrix to count number distinct values in
@@ -252,7 +251,7 @@ public interface LibMatrixCountDistinct {
                                }
                        }
                } else {  // Col aggregation
-                       blkOut = new MatrixBlock(1, blkIn.getNumColumns(), 
false, blkIn.getNumRows());
+                       blkOut = new MatrixBlock(1, blkIn.getNumColumns(), 
false, blkIn.getNumColumns());
                        blkOut.allocateBlock();
 
                        // All dense and sparse formats (COO, CSR, MCSR) are 
row-major formats, so there is no obvious way to iterate
@@ -300,7 +299,6 @@ public interface LibMatrixCountDistinct {
                                                if (csrBlock.isEmpty(rix)) {
                                                        continue;
                                                }
-                                               distinct.clear();
                                                int rpos = csrBlock.pos(rix);
                                                int clen = csrBlock.size(rix);
                                                int[] cixs = csrBlock.indexes();
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
index 1c430c9134..c33accf943 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -22,19 +22,20 @@ package org.apache.sysds.runtime.matrix.operators;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.functionobjects.IndexFunction;
+import org.apache.sysds.runtime.functionobjects.Plus;
 import 
org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
 import org.apache.sysds.utils.Hash.HashType;
 
-public class CountDistinctOperator extends Operator {
+public class CountDistinctOperator extends AggregateUnaryOperator {
        private static final long serialVersionUID = 7615123453265129670L;
 
        private final CountDistinctOperatorTypes operatorType;
+       private final Types.Direction direction;
        private final HashType hashType;
-       private Types.Direction direction;
-       private IndexFunction indexFunction;
 
-       public CountDistinctOperator(AUType opType) {
-               super(true);
+       public CountDistinctOperator(AUType opType, Types.Direction direction, 
IndexFunction indexFunction) {
+               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, 1);
+
                switch(opType) {
                        case COUNT_DISTINCT:
                                this.operatorType = 
CountDistinctOperatorTypes.COUNT;
@@ -46,25 +47,15 @@ public class CountDistinctOperator extends Operator {
                                throw new DMLRuntimeException(opType + " not 
supported for CountDistinct Operator");
                }
                this.hashType = HashType.LinearHash;
+               this.direction = direction;
        }
 
-       public CountDistinctOperator(CountDistinctOperatorTypes operatorType) {
-               super(true);
-               this.operatorType = operatorType;
-               this.hashType = HashType.StandardJava;
-       }
-
-       public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
HashType hashType) {
-               super(true);
-               this.operatorType = operatorType;
-               this.hashType = hashType;
-       }
+       public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
Types.Direction direction,
+                                                                IndexFunction 
indexFunction, HashType hashType) {
+               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, 1);
 
-       public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
IndexFunction indexFunction,
-               HashType hashType) {
-               super(true);
                this.operatorType = operatorType;
-               this.indexFunction = indexFunction;
+               this.direction = direction;
                this.hashType = hashType;
        }
 
@@ -76,21 +67,7 @@ public class CountDistinctOperator extends Operator {
                return hashType;
        }
 
-       public IndexFunction getIndexFunction() {
-               return indexFunction;
-       }
-
-       public CountDistinctOperator setIndexFunction(IndexFunction 
indexFunction) {
-               this.indexFunction = indexFunction;
-               return this;
-       }
-
        public Types.Direction getDirection() {
                return direction;
        }
-
-       public CountDistinctOperator setDirection(Types.Direction direction) {
-               this.direction = direction;
-               return this;
-       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index 5de18c4b3e..4b4909e27a 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -29,6 +29,7 @@ import java.util.Collection;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.api.DMLException;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
@@ -138,7 +139,9 @@ public class CountDistinctTest {
        @Test
        public void testEstimation() {
                try {
-                       CountDistinctOperator op = new 
CountDistinctOperator(et, ht).setDirection(Types.Direction.RowCol);
+                       CountDistinctOperator op = new 
CountDistinctOperator(et, Types.Direction.RowCol,
+                                       ReduceAll.getReduceAllFnObject(), ht);
+
                        if(expectedException != null) {
                                assertThrows(expectedException.getClass(), () 
-> {
                                        
LibMatrixCountDistinct.estimateDistinctValues(in, op);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index 5a7eccc447..69f5fa1ef1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -26,7 +26,7 @@ import org.junit.Test;
 public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 
        private final static String TEST_NAME = "countDistinctApproxCol";
-       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
 
        @Override
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index c9aa75e375..07f3fcac38 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -26,7 +26,7 @@ import org.junit.Test;
 public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 
        private final static String TEST_NAME = "countDistinctApproxRow";
-       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
 
        @Override
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java
similarity index 90%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java
index 5a7eccc447..fb26da4da2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java
@@ -23,11 +23,11 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.junit.Test;
 
-public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
+public class CountDistinctCol extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxCol";
+       private final static String TEST_NAME = "countDistinctCol";
        private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctCol.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -52,6 +52,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
        @Override
        public void setUp() {
                super.addTestConfiguration();
+               super.setRunSparkTests(false);
        }
 
        @Test
@@ -73,7 +74,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
                double sparsity = 0.1;
                double tolerance = actualDistinctCount * this.percentTolerance;
 
-               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
                                tolerance);
        }
 
@@ -84,7 +85,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
                double sparsity = 0.1;
                double tolerance = actualDistinctCount * this.percentTolerance;
 
-               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
                                tolerance);
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java
similarity index 89%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java
index 5a7eccc447..08620d13d1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java
@@ -23,11 +23,11 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.junit.Test;
 
-public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
+public class CountDistinctColAlias extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxCol";
+       private final static String TEST_NAME = "countDistinctColAlias";
        private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctColAlias.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -52,6 +52,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
        @Override
        public void setUp() {
                super.addTestConfiguration();
+               super.setRunSparkTests(false);
        }
 
        @Test
@@ -73,7 +74,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
                double sparsity = 0.1;
                double tolerance = actualDistinctCount * this.percentTolerance;
 
-               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
                                tolerance);
        }
 
@@ -84,7 +85,7 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
                double sparsity = 0.1;
                double tolerance = actualDistinctCount * this.percentTolerance;
 
-               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
                                tolerance);
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java
similarity index 93%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java
index c9aa75e375..568c7516d0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java
@@ -23,11 +23,11 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.junit.Test;
 
-public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
+public class CountDistinctRow extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxRow";
+       private final static String TEST_NAME = "countDistinctRow";
        private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRow.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -52,6 +52,7 @@ public class CountDistinctApproxRow extends 
CountDistinctRowOrColBase {
        @Override
        public void setUp() {
                super.addTestConfiguration();
+               super.setRunSparkTests(false);
        }
 
        @Test
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java
similarity index 93%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java
index c9aa75e375..c9e24cd38d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java
@@ -23,11 +23,11 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.junit.Test;
 
-public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
+public class CountDistinctRowAlias extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxRow";
+       private final static String TEST_NAME = "countDistinctRowAlias";
        private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowAlias.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -52,6 +52,7 @@ public class CountDistinctApproxRow extends 
CountDistinctRowOrColBase {
        @Override
        public void setUp() {
                super.addTestConfiguration();
+               super.setRunSparkTests(false);
        }
 
        @Test
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
index 3de4a61bcd..8f7d9acb8f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
@@ -24,7 +24,7 @@ import org.junit.Test;
 
 public class CountDistinctRowCol extends CountDistinctRowColBase {
 
-       public String TEST_NAME = "countDistinct";
+       public String TEST_NAME = "countDistinctRowCol";
        public String TEST_DIR = "functions/countDistinct/";
        public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowCol.class.getSimpleName() + "/";
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java
similarity index 85%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java
index 3de4a61bcd..02048595e6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java
@@ -22,11 +22,11 @@ package org.apache.sysds.test.functions.countDistinct;
 import org.apache.sysds.common.Types.ExecType;
 import org.junit.Test;
 
-public class CountDistinctRowCol extends CountDistinctRowColBase {
+public class CountDistinctRowColParameterized extends CountDistinctRowColBase {
 
-       public String TEST_NAME = "countDistinct";
+       public String TEST_NAME = "countDistinctRowColParameterized";
        public String TEST_DIR = "functions/countDistinct/";
-       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowCol.class.getSimpleName() + "/";
+       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowColParameterized.class.getSimpleName() + "/";
 
        protected String getTestClassDir() {
                return TEST_CLASS_DIR;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
index a880c0d0dd..0d517776a2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -30,6 +30,8 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Test;
 
+import static org.junit.Assume.assumeTrue;
+
 public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
 
        @Override
@@ -43,6 +45,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
        protected abstract Types.Direction getDirection();
 
+       private boolean runSparkTests = true;
+
        protected void addTestConfiguration() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
@@ -50,19 +54,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
                this.percentTolerance = 0.2;
        }
 
-       /**
-        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch.
-        */
-       @Test
-       public void testCPDenseXLarge() {
-               Types.ExecType ex = Types.ExecType.CP;
-
-               int actualDistinctCount = 10000;
-               int rows = 10000, cols = 10000;
-               double sparsity = 0.9;
-               double tolerance = actualDistinctCount * this.percentTolerance;
-
-               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       public void setRunSparkTests(boolean runSparkTests) {
+               this.runSparkTests = runSparkTests;
        }
 
        @Test
@@ -91,6 +84,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
        @Test
        public void testSparkSparseLargeMultiBlockAggregation() {
+               assumeTrue(runSparkTests);
+
                Types.ExecType execType = Types.ExecType.SPARK;
 
                int actualDistinctCount = 10;
@@ -103,6 +98,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
        @Test
        public void testSparkDenseLargeMultiBlockAggregation() {
+               assumeTrue(runSparkTests);
+
                Types.ExecType execType = Types.ExecType.SPARK;
 
                int actualDistinctCount = 10;
@@ -115,6 +112,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
        @Test
        public void testSparkSparseLargeNoneAggregation() {
+               assumeTrue(runSparkTests);
+
                Types.ExecType execType = Types.ExecType.SPARK;
 
                int actualDistinctCount = 10;
@@ -127,6 +126,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
        @Test
        public void testSparkDenseLargeNoneAggregation() {
+               assumeTrue(runSparkTests);
+
                Types.ExecType execType = Types.ExecType.SPARK;
 
                int actualDistinctCount = 10;
@@ -145,9 +146,8 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
                }
                blkIn = new MatrixBlock(blkIn, sparseBlockType, true);
 
-               CountDistinctOperator op = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX)
-                               .setDirection(direction)
-                               
.setIndexFunction(ReduceCol.getReduceColFnObject());
+               CountDistinctOperator op = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
+                               direction, ReduceCol.getReduceColFnObject());
 
                MatrixBlock blkOut = 
LibMatrixCountDistinct.estimateDistinctValues(blkIn, op);
                double[][] expectedMatrix = 
getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java
similarity index 82%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java
index 5a7eccc447..6752bc29bb 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java
@@ -17,16 +17,17 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase;
 import org.junit.Test;
 
 public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 
        private final static String TEST_NAME = "countDistinctApproxCol";
-       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
 
        @Override
@@ -99,4 +100,19 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
 
                countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
        }
+
+       /**
+        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch in CP exec mode.
+        */
+       @Test
+       public void testCPDenseXLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10000;
+               int rows = 10000, cols = 10000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java
similarity index 78%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java
index 5a7eccc447..e87813f464 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java
@@ -17,17 +17,18 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase;
 import org.junit.Test;
 
-public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
+public class CountDistinctApproxColAlias extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxCol";
-       private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
+       private final static String TEST_NAME = "countDistinctApproxColAlias";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxColAlias.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -99,4 +100,19 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
 
                countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
        }
+
+       /**
+        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch in CP exec mode.
+        */
+       @Test
+       public void testCPDenseXLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10000;
+               int rows = 10000, cols = 10000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java
similarity index 82%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java
index c9aa75e375..6e4678f5a8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java
@@ -17,16 +17,17 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase;
 import org.junit.Test;
 
 public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 
        private final static String TEST_NAME = "countDistinctApproxRow";
-       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
 
        @Override
@@ -99,4 +100,19 @@ public class CountDistinctApproxRow extends 
CountDistinctRowOrColBase {
 
                countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
        }
+
+       /**
+        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch in CP exec mode.
+        */
+       @Test
+       public void testCPDenseXLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10000;
+               int rows = 10000, cols = 10000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java
similarity index 78%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java
index c9aa75e375..99b6d60e04 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java
@@ -17,17 +17,18 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase;
 import org.junit.Test;
 
-public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
+public class CountDistinctApproxRowAlias extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxRow";
-       private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
+       private final static String TEST_NAME = "countDistinctApproxRowAlias";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowAlias.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -99,4 +100,19 @@ public class CountDistinctApproxRow extends 
CountDistinctRowOrColBase {
 
                countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
        }
+
+       /**
+        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch in CP exec mode.
+        */
+       @Test
+       public void testCPDenseXLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10000;
+               int rows = 10000, cols = 10000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java
similarity index 96%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java
index e59b002887..4c0f27bd5b 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java
@@ -17,15 +17,16 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase;
 import org.junit.Test;
 
 public class CountDistinctApproxRowCol extends CountDistinctRowColBase {
 
        private final static String TEST_NAME = "countDistinctApproxRowCol";
-       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowCol.class.getSimpleName() + "/";
 
        @Override
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java
similarity index 91%
rename from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
rename to 
src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java
index e59b002887..df532a95d8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java
@@ -17,16 +17,17 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.test.functions.countDistinctApprox;
 
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase;
 import org.junit.Test;
 
-public class CountDistinctApproxRowCol extends CountDistinctRowColBase {
+public class CountDistinctApproxRowColParameterized extends 
CountDistinctRowColBase {
 
-       private final static String TEST_NAME = "countDistinctApproxRowCol";
-       private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowCol.class.getSimpleName() + "/";
+       private final static String TEST_NAME = 
"countDistinctApproxRowColParameterized";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowColParameterized.class.getSimpleName() + "/";
 
        @Override
        public void setUp() {
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctCol.dml
similarity index 96%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctCol.dml
index 3b21bc89f1..3f2918ee1e 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctCol.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinct(input, dir="c")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
similarity index 96%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
index 3b21bc89f1..3eeb8ed54a 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinctCol(input, dir="c")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctRow.dml
similarity index 96%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctRow.dml
index 3b21bc89f1..f8665f6fc7 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctRow.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinct(input, dir="r")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
similarity index 96%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
index 3b21bc89f1..62d7196ce1 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinctRow(input, dir="r")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml
similarity index 95%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctRowCol.dml
index 3b21bc89f1..7ac9dd53fc 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinct(input)  # default is dir="rc"
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml 
b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml
similarity index 96%
rename from src/test/scripts/functions/countDistinct/countDistinct.dml
rename to 
src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml
index 3b21bc89f1..9bd22867d0 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ 
b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
+res = countDistinct(input, dir="rc")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml
similarity index 100%
copy from src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
copy to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
similarity index 94%
copy from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
copy to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
index 2c5b6cf412..83a9f5070c 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApprox(input, dir="rc", type="KMV")
+res = countDistinctApproxCol(input, dir="c", type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml
similarity index 100%
rename from src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
rename to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
similarity index 94%
copy from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
copy to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
index 2c5b6cf412..f4be480156 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApprox(input, dir="rc", type="KMV")
+res = countDistinctApproxRow(input, dir="r", type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml
similarity index 95%
rename from src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
rename to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml
index 777a56a443..21245ecfbb 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApprox(input, dir="c", type="KMV")
+res = countDistinctApprox(input, type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml
similarity index 100%
rename from 
src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
rename to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml

Reply via email to