This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 6c9f0ff1304125111ee20d3a3309f45f65bc6661 Author: Badrul Chowdhury <[email protected]> AuthorDate: Fri Oct 28 14:38:21 2022 +0200 [SYSTEMDS-3413] Row/Col aggregation for countDistinct This patch converts countDistinct() from a non-parameterized builtin to a parameterized builtin function to allow for one new parameter: dir, for direction. The value of dir can be r or c, denoting row-wise and column-wise aggregation, respectively. This patch only implements the CP case; the SP case will throw a NotImplementedException and will be addressed in a subsequent patch. Closes #1677 --- .../java/org/apache/sysds/common/Builtins.java | 10 +++- src/main/java/org/apache/sysds/common/Types.java | 6 +- .../org/apache/sysds/lops/PartialAggregate.java | 29 +++++++-- .../sysds/parser/BuiltinFunctionExpression.java | 8 --- .../org/apache/sysds/parser/DMLTranslator.java | 22 +++++-- .../ParameterizedBuiltinFunctionExpression.java | 60 +++++++++++++++++-- .../runtime/instructions/CPInstructionParser.java | 2 + .../runtime/instructions/InstructionUtils.java | 51 +++++++++++----- .../runtime/instructions/SPInstructionParser.java | 4 +- .../cp/AggregateUnaryCPInstruction.java | 70 +++++++--------------- .../spark/AggregateUnarySketchSPInstruction.java | 34 +++++------ .../matrix/data/LibMatrixCountDistinct.java | 4 +- .../matrix/operators/CountDistinctOperator.java | 45 ++++---------- .../test/component/matrix/CountDistinctTest.java | 5 +- .../countDistinct/CountDistinctApproxCol.java | 2 +- .../countDistinct/CountDistinctApproxRow.java | 2 +- ...istinctApproxCol.java => CountDistinctCol.java} | 11 ++-- ...ctApproxCol.java => CountDistinctColAlias.java} | 11 ++-- ...istinctApproxRow.java => CountDistinctRow.java} | 7 ++- ...ctApproxRow.java => CountDistinctRowAlias.java} | 7 ++- .../countDistinct/CountDistinctRowCol.java | 2 +- ....java => CountDistinctRowColParameterized.java} | 6 +- .../countDistinct/CountDistinctRowOrColBase.java | 32 +++++----- 
.../CountDistinctApproxCol.java | 20 ++++++- .../CountDistinctApproxColAlias.java} | 26 ++++++-- .../CountDistinctApproxRow.java | 20 ++++++- .../CountDistinctApproxRowAlias.java} | 26 ++++++-- .../CountDistinctApproxRowCol.java | 5 +- .../CountDistinctApproxRowColParameterized.java} | 11 ++-- .../{countDistinct.dml => countDistinctCol.dml} | 2 +- ...countDistinct.dml => countDistinctColAlias.dml} | 2 +- .../{countDistinct.dml => countDistinctRow.dml} | 2 +- ...countDistinct.dml => countDistinctRowAlias.dml} | 2 +- .../{countDistinct.dml => countDistinctRowCol.dml} | 2 +- ...ct.dml => countDistinctRowColParameterized.dml} | 2 +- .../countDistinctApproxCol.dml | 0 .../countDistinctApproxColAlias.dml} | 2 +- .../countDistinctApproxRow.dml | 0 .../countDistinctApproxRowAlias.dml} | 2 +- .../countDistinctApproxRowCol.dml} | 2 +- .../countDistinctApproxRowColParameterized.dml} | 0 41 files changed, 339 insertions(+), 217 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 5e9509696b..262212570e 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -37,7 +37,7 @@ import org.apache.sysds.common.Types.ReturnType; * building SystemDS, these scripts are packaged into the jar as well. 
*/ public enum Builtins { - //builtin functions + // Builtin functions without parameters ABSTAIN("abstain", true), ABS("abs", false), ACOS("acos", false), @@ -93,7 +93,6 @@ public enum Builtins { CORRECTTYPOSAPPLY("correctTyposApply", true), COS("cos", false), COSH("cosh", false), - COUNT_DISTINCT("countDistinct",false), COV("cov", false), COX("cox", true), CSPLINE("cspline", true), @@ -305,10 +304,15 @@ public enum Builtins { XGBOOSTPREDICT_CLASS("xgboostPredictClassification", true), XOR("xor", false), - //parameterized builtin functions + // Parameterized functions with parameters AUTODIFF("autoDiff", false, true), CDF("cdf", false, true), + COUNT_DISTINCT("countDistinct",false, true), + COUNT_DISTINCT_ROW("countDistinctRow",false, true), + COUNT_DISTINCT_COL("countDistinctCol",false, true), COUNT_DISTINCT_APPROX("countDistinctApprox", false, true), + COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true), + COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true), CVLM("cvlm", true, false), GROUPEDAGG("aggregate", "groupedAggregate", false, true), INVCDF("icdf", false, true), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index c013b7890b..4f613a40d7 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -197,9 +197,9 @@ public class Types PROD(4), SUM_PROD(5), TRACE(6), MEAN(7), VAR(8), MAXINDEX(9), MININDEX(10), - COUNT_DISTINCT(11), - COUNT_DISTINCT_APPROX(12); - + COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12), COUNT_DISTINCT_COL(13), + COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), COUNT_DISTINCT_APPROX_COL(16); + @Override public String toString() { switch(this) { diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java index 050d87a3fd..0481c7373a 100644 --- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java +++ 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java @@ -342,19 +342,38 @@ public class PartialAggregate extends Lop } case COUNT_DISTINCT: { - if(dir == Direction.RowCol ) - return "uacd"; - break; + switch (dir) { + case RowCol: return "uacd"; + case Row: return "uacdr"; + case Col: return "uacdc"; + default: + throw new LopsException("PartialAggregate.getOpcode() - " + + "Unknown aggregate direction: " + dir); + } } - + + case COUNT_DISTINCT_ROW: + return "uacdr"; + + case COUNT_DISTINCT_COL: + return "uacdc"; + case COUNT_DISTINCT_APPROX: { switch (dir) { case RowCol: return "uacdap"; case Row: return "uacdapr"; case Col: return "uacdapc"; + default: + throw new LopsException("PartialAggregate.getOpcode() - " + + "Unknown aggregate direction: " + dir); } - break; } + + case COUNT_DISTINCT_APPROX_ROW: + return "uacdapr"; + + case COUNT_DISTINCT_APPROX_COL: + return "uacdapc"; } //should never come here for normal compilation diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index c5db658dfa..c3aca47d38 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -943,14 +943,6 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setBlocksize(0); output.setValueType(ValueType.INT64); break; - case COUNT_DISTINCT: - checkNumParameters(1); - checkDataTypeParam(getFirstExpr(), DataType.MATRIX); - output.setDataType(DataType.SCALAR); - output.setDimensions(0, 0); - output.setBlocksize(0); - output.setValueType(ValueType.INT64); - break; case LINEAGE: checkNumParameters(1); checkDataTypeParam(getFirstExpr(), diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 315f54ff72..553bf56fc5 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2034,15 +2034,16 @@ public class DMLTranslator target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) : HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS); break; + case LISTNV: currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.LIST, paramHops); break; + case COUNT_DISTINCT: case COUNT_DISTINCT_APPROX: - // Default direction and data type - Direction dir = Direction.RowCol; - DataType dataType = DataType.SCALAR; + Direction dir = Direction.RowCol; // Default direction + DataType dataType = DataType.SCALAR; // Default output data type LiteralOp dirOp = (LiteralOp) paramHops.get("dir"); if (dirOp != null) { @@ -2062,6 +2063,19 @@ public class DMLTranslator currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data")); break; + + case COUNT_DISTINCT_ROW: + case COUNT_DISTINCT_APPROX_ROW: + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), + AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data")); + break; + + case COUNT_DISTINCT_COL: + case COUNT_DISTINCT_APPROX_COL: + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), + AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data")); + break; + default: throw new ParseException(source.printErrorLocation() + "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode()); @@ -2361,10 +2375,10 @@ public class DMLTranslator case SUM: case PROD: case VAR: - case COUNT_DISTINCT: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr); break; + case MEAN: if ( expr2 == null ) { // example: x = mean(Y); diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index c83b6f3911..bdfd38c5a4 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -246,7 +246,15 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier validateParamserv(output, conditional); break; + case COUNT_DISTINCT: + case COUNT_DISTINCT_ROW: + case COUNT_DISTINCT_COL: + validateCountDistinct(output, conditional); + break; + case COUNT_DISTINCT_APPROX: + case COUNT_DISTINCT_APPROX_ROW: + case COUNT_DISTINCT_APPROX_COL: validateCountDistinctApprox(output, conditional); break; @@ -353,6 +361,45 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier output.setBlocksize(-1); } + private void validateCountDistinct(DataIdentifier output, boolean conditional) { + HashMap<String, Expression> varParams = getVarParams(); + + // "data" is the only parameter that is allowed to be unnamed + if (varParams.containsKey(null)) { + varParams.put("data", varParams.remove(null)); + } + + // Validate the number of parameters + String fname = getOpCode().getName(); + String usageMessage = "function " + fname + " takes at least 1 and at most 2 parameters"; + if (varParams.size() < 1) { + raiseValidateError("Too few parameters: " + usageMessage, conditional); + } + + if (varParams.size() > 2) { + raiseValidateError("Too many parameters: " + usageMessage, conditional); + } + + // Check parameter names are valid + Set<String> validParameterNames = CollectionUtils.asSet("data", "dir"); + checkInvalidParameters(getOpCode(), varParams, validParameterNames); + + // Check parameter expression data types match expected + checkDataType(false, fname, "data", DataType.MATRIX, conditional); + checkDataValueType(false, fname, "data", DataType.MATRIX, 
ValueType.FP64, conditional); + + // We need the dimensions of the input matrix to determine the output matrix characteristics + // Validate data parameter, lookup previously defined var or resolve expression + Identifier dataId = varParams.get("data").getOutput(); + if (dataId == null) { + raiseValidateError("Cannot parse input parameter \"data\" to function " + fname, conditional); + } + + checkStringParam(true, fname, "dir", conditional); + // Check data value of "dir" parameter + validateAggregationDirection(dataId, output); + } + private void validateCountDistinctApprox(DataIdentifier output, boolean conditional) { Set<String> validTypeNames = CollectionUtils.asSet("KMV"); HashMap<String, Expression> varParams = getVarParams(); @@ -390,7 +437,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkStringParam(true, fname, "type", conditional); // Check data value of "type" parameter - if (varParams.keySet().contains("type")) { + if (varParams.containsKey("type")) { String typeString = varParams.get("type").toString().toUpperCase(); if (!validTypeNames.contains(typeString)) { raiseValidateError("Unrecognized type for optional parameter " + typeString, conditional); @@ -402,7 +449,12 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkStringParam(true, fname, "dir", conditional); // Check data value of "dir" parameter - if (varParams.keySet().contains("dir")) { + validateAggregationDirection(dataId, output); + } + + private void validateAggregationDirection(Identifier dataId, DataIdentifier output) { + HashMap<String, Expression> varParams = getVarParams(); + if (varParams.containsKey("dir")) { String directionString = varParams.get("dir").toString().toUpperCase(); // Set output type and dimensions based on direction @@ -435,9 +487,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier } else { raiseValidateError("Invalid argument: " + directionString + " is not recognized"); } 
- - // default to dir="rc" - } else { + } else { // default to dir="rc" output.setDataType(DataType.SCALAR); output.setDimensions(0, 0); output.setBlocksize(0); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index f2d3080ddc..b83a78674b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -117,6 +117,8 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "exists" , CPType.AggregateUnary); String2CPInstructionType.put( "lineage" , CPType.AggregateUnary); String2CPInstructionType.put( "uacd" , CPType.AggregateUnary); + String2CPInstructionType.put( "uacdr" , CPType.AggregateUnary); + String2CPInstructionType.put( "uacdc" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdap" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdapr" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdapc" , CPType.AggregateUnary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 2c8468955f..d87e772709 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -82,26 +82,16 @@ import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.functionobjects.ReduceDiag; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.functionobjects.Xor; -import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; +import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import 
org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE; import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType; import org.apache.sysds.runtime.matrix.data.LibCommonsMath; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.CMOperator; +import org.apache.sysds.runtime.matrix.operators.*; import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes; -import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; -import org.apache.sysds.runtime.matrix.operators.Operator; -import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; -import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.matrix.operators.TernaryOperator; -import org.apache.sysds.runtime.matrix.operators.UnaryOperator; public class InstructionUtils @@ -287,7 +277,14 @@ public class InstructionUtils public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) { return parseBasicAggregateUnaryOperator(opcode, 1); } - + + /** + * Parse the given opcode into an aggregate unary operator. + * + * @param opcode opcode + * @param numThreads number of threads + * @return Parsed aggregate unary operator object. Caller must handle possible null return value. 
+ */ public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode, int numThreads) { AggregateUnaryOperator aggun = null; @@ -420,7 +417,31 @@ public class InstructionUtils AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min")); aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); } - + else if ( opcode.equalsIgnoreCase("uacd") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.RowCol, ReduceAll.getReduceAllFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdr") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.Row, ReduceCol.getReduceColFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdc") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.Col, ReduceRow.getReduceRowFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdap") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.RowCol, ReduceAll.getReduceAllFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdapr") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.Row, ReduceCol.getReduceColFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdapc") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.Col, ReduceRow.getReduceRowFnObject()); + } + return aggun; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 73329b954d..9496cf465e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -126,7 +126,9 @@ public class 
SPInstructionParser extends InstructionParser String2SPInstructionType.put( "uac*" , SPType.AggregateUnary); String2SPInstructionType.put( "uatrace" , SPType.AggregateUnary); String2SPInstructionType.put( "uaktrace", SPType.AggregateUnary); - String2SPInstructionType.put( "uacdap" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacd" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacdr" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacdc" , SPType.AggregateUnary); // Aggregate unary sketch operators String2SPInstructionType.put( "uacdap" , SPType.AggregateUnarySketch); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java index ddf00ada2b..6fc0107520 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java @@ -33,11 +33,13 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock; import org.apache.sysds.runtime.lineage.LineageDedupUtils; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -45,6 +47,9 @@ import 
org.apache.sysds.runtime.matrix.operators.SimpleOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.utils.Explain; +import java.util.HashSet; +import java.util.Set; + public class AggregateUnaryCPInstruction extends UnaryCPInstruction { // private static final Log LOG = LogFactory.getLog(AggregateUnaryCPInstruction.class.getName()); @@ -81,36 +86,19 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction { return new AggregateUnaryCPInstruction(new SimpleOperator(Builtin.getBuiltinFnObject(opcode)), in1, out, AUType.valueOf(opcode.toUpperCase()), opcode, str); } - else if(opcode.equalsIgnoreCase("uacd")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT) - .setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT, - opcode, str); + else if(opcode.equalsIgnoreCase("uacd") + || opcode.equalsIgnoreCase("uacdr") + || opcode.equalsIgnoreCase("uacdc")){ + AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode, + Integer.parseInt(parts[3])); + return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.COUNT_DISTINCT, opcode, str); } - else if(opcode.equalsIgnoreCase("uacdap")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, - opcode, str); - } - else if(opcode.equalsIgnoreCase("uacdapr")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.Row) - .setIndexFunction(ReduceCol.getReduceColFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, - opcode, str); - } - else if(opcode.equalsIgnoreCase("uacdapc")){ - 
CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.Col) - .setIndexFunction(ReduceRow.getReduceRowFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, + else if(opcode.equalsIgnoreCase("uacdap") + || opcode.equalsIgnoreCase("uacdapr") + || opcode.equalsIgnoreCase("uacdapc")){ + AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode, + Integer.parseInt(parts[3])); + return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str); } else if(opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin")){ @@ -199,34 +187,18 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction { ec.setScalarOutput(output_name, new StringObject(out)); break; } - case COUNT_DISTINCT: { - if( !ec.getVariables().keySet().contains(input1.getName()) ) - throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist."); - MatrixBlock input = ec.getMatrixInput(input1.getName()); - - // Operator type: test and cast - if (!(_optr instanceof CountDistinctOperator)) { - throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName()); - } - CountDistinctOperator op = (CountDistinctOperator) (_optr); - - //TODO add support for row or col count distinct. 
- int res = (int) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0); - ec.releaseMatrixInput(input1.getName()); - ec.setScalarOutput(output_name, new IntObject(res)); - break; - } + case COUNT_DISTINCT: case COUNT_DISTINCT_APPROX: { if(!ec.getVariables().keySet().contains(input1.getName())) { throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist."); } - MatrixBlock input = ec.getMatrixInput(input1.getName()); + + // Operator type: test and cast if (!(_optr instanceof CountDistinctOperator)) { throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName()); } - - CountDistinctOperator op = (CountDistinctOperator) _optr; // It is safe to cast at this point + CountDistinctOperator op = (CountDistinctOperator) _optr; if (op.getDirection().isRowCol()) { long res = (long) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java index 71bc75fd45..703828e3a1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java @@ -53,20 +53,6 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String instr) { super(SPType.AggregateUnarySketch, op, in, out, opcode, instr); this.op = (CountDistinctOperator) super.getOperator(); - - if (opcode.equals("uacdap")) { - this.op.setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - } else if (opcode.equals("uacdapr")) { - this.op.setDirection(Types.Direction.Row) - 
.setIndexFunction(ReduceCol.getReduceColFnObject()); - } else if (opcode.equals("uacdapc")) { - this.op.setDirection(Types.Direction.Col) - .setIndexFunction(ReduceRow.getReduceRowFnObject()); - } else { - throw new DMLException("Unrecognized opcode " + opcode); - } - this.aggtype = aggtype; } @@ -79,7 +65,19 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { CPOperand out = new CPOperand(parts[2]); AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]); - CountDistinctOperator cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Hash.HashType.LinearHash); + CountDistinctOperator cdop = null; + if (opcode.equals("uacdap")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.RowCol, + ReduceAll.getReduceAllFnObject(), Hash.HashType.LinearHash); + } else if (opcode.equals("uacdapr")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Row, + ReduceCol.getReduceColFnObject(), Hash.HashType.LinearHash); + } else if (opcode.equals("uacdapc")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Col, + ReduceRow.getReduceRowFnObject(), Hash.HashType.LinearHash); + } else { + throw new DMLException("Unrecognized opcode: " + opcode); + } return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, opcode, str); } @@ -147,7 +145,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { out3 = out2.mapValues(new CalculateAggregateSketchFunction(this.op)); - updateUnaryAggOutputDataCharacteristics(sec, this.op.getIndexFunction()); + updateUnaryAggOutputDataCharacteristics(sec, this.op.indexFn); // put output RDD handle into symbol table sec.setRDDHandleForVariable(output.getName(), out3); @@ -173,7 +171,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { MatrixBlock blkIn = arg0._2(); MatrixIndexes ixOut = new MatrixIndexes(); - 
this.op.getIndexFunction().execute(ixIn, ixOut); + this.op.indexFn.execute(ixIn, ixOut); return LibMatrixCountDistinct.createSketch(blkIn, this.op); } @@ -222,7 +220,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { MatrixIndexes idxOut = new MatrixIndexes(); MatrixBlock blkOut = blkIn; // Do not create sketch yet - this._op.getIndexFunction().execute(idxIn, idxOut); + this._op.indexFn.execute(idxIn, idxOut); return new Tuple2<>(idxOut, blkOut); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java index 814b7737f9..ee97c8d340 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java @@ -52,7 +52,6 @@ public interface LibMatrixCountDistinct { * Public method to count the number of distinct values inside a matrix. Depending on which CountDistinctOperator * selected it either gets the absolute number or a estimated value. * - * TODO: Support counting num distinct in rows, or columns axis. * TODO: If the MatrixBlock type is CompressedMatrix, simply read the values from the ColGroups. 
* * @param in the input matrix to count number distinct values in @@ -252,7 +251,7 @@ public interface LibMatrixCountDistinct { } } } else { // Col aggregation - blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumRows()); + blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns()); blkOut.allocateBlock(); // All dense and sparse formats (COO, CSR, MCSR) are row-major formats, so there is no obvious way to iterate @@ -300,7 +299,6 @@ public interface LibMatrixCountDistinct { if (csrBlock.isEmpty(rix)) { continue; } - distinct.clear(); int rpos = csrBlock.pos(rix); int clen = csrBlock.size(rix); int[] cixs = csrBlock.indexes(); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java index 1c430c9134..c33accf943 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java @@ -22,19 +22,20 @@ package org.apache.sysds.runtime.matrix.operators; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType; import org.apache.sysds.utils.Hash.HashType; -public class CountDistinctOperator extends Operator { +public class CountDistinctOperator extends AggregateUnaryOperator { private static final long serialVersionUID = 7615123453265129670L; private final CountDistinctOperatorTypes operatorType; + private final Types.Direction direction; private final HashType hashType; - private Types.Direction direction; - private IndexFunction indexFunction; - public CountDistinctOperator(AUType opType) { - super(true); + public CountDistinctOperator(AUType opType, 
Types.Direction direction, IndexFunction indexFunction) { + super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1); + switch(opType) { case COUNT_DISTINCT: this.operatorType = CountDistinctOperatorTypes.COUNT; @@ -46,25 +47,15 @@ public class CountDistinctOperator extends Operator { throw new DMLRuntimeException(opType + " not supported for CountDistinct Operator"); } this.hashType = HashType.LinearHash; + this.direction = direction; } - public CountDistinctOperator(CountDistinctOperatorTypes operatorType) { - super(true); - this.operatorType = operatorType; - this.hashType = HashType.StandardJava; - } - - public CountDistinctOperator(CountDistinctOperatorTypes operatorType, HashType hashType) { - super(true); - this.operatorType = operatorType; - this.hashType = hashType; - } + public CountDistinctOperator(CountDistinctOperatorTypes operatorType, Types.Direction direction, + IndexFunction indexFunction, HashType hashType) { + super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1); - public CountDistinctOperator(CountDistinctOperatorTypes operatorType, IndexFunction indexFunction, - HashType hashType) { - super(true); this.operatorType = operatorType; - this.indexFunction = indexFunction; + this.direction = direction; this.hashType = hashType; } @@ -76,21 +67,7 @@ public class CountDistinctOperator extends Operator { return hashType; } - public IndexFunction getIndexFunction() { - return indexFunction; - } - - public CountDistinctOperator setIndexFunction(IndexFunction indexFunction) { - this.indexFunction = indexFunction; - return this; - } - public Types.Direction getDirection() { return direction; } - - public CountDistinctOperator setDirection(Types.Direction direction) { - this.direction = direction; - return this; - } } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java index 5de18c4b3e..4b4909e27a 100644 
--- a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java @@ -29,6 +29,7 @@ import java.util.Collection; import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.api.DMLException; import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; @@ -138,7 +139,9 @@ public class CountDistinctTest { @Test public void testEstimation() { try { - CountDistinctOperator op = new CountDistinctOperator(et, ht).setDirection(Types.Direction.RowCol); + CountDistinctOperator op = new CountDistinctOperator(et, Types.Direction.RowCol, + ReduceAll.getReduceAllFnObject(), ht); + if(expectedException != null) { assertThrows(expectedException.getClass(), () -> { LibMatrixCountDistinct.estimateDistinctValues(in, op); diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java index 5a7eccc447..69f5fa1ef1 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java @@ -26,7 +26,7 @@ import org.junit.Test; public class CountDistinctApproxCol extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxCol"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; @Override diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java index c9aa75e375..07f3fcac38 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java @@ -26,7 +26,7 @@ import org.junit.Test; public class CountDistinctApproxRow extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxRow"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; @Override diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java similarity index 90% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java index 5a7eccc447..fb26da4da2 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java @@ -23,11 +23,11 @@ import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; import org.junit.Test; -public class CountDistinctApproxCol extends CountDistinctRowOrColBase { +public class CountDistinctCol extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxCol"; + private final static String TEST_NAME = "countDistinctCol"; private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctCol.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -52,6 +52,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { @Override public void setUp() { super.addTestConfiguration(); + super.setRunSparkTests(false); } @Test @@ -73,7 +74,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { double sparsity = 0.1; double tolerance = actualDistinctCount * this.percentTolerance; - super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, tolerance); } @@ -84,7 +85,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { double sparsity = 0.1; double tolerance = actualDistinctCount * this.percentTolerance; - super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, tolerance); } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java similarity index 89% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java index 5a7eccc447..08620d13d1 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java @@ -23,11 +23,11 @@ import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; import org.junit.Test; -public class 
CountDistinctApproxCol extends CountDistinctRowOrColBase { +public class CountDistinctColAlias extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxCol"; + private final static String TEST_NAME = "countDistinctColAlias"; private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctColAlias.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -52,6 +52,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { @Override public void setUp() { super.addTestConfiguration(); + super.setRunSparkTests(false); } @Test @@ -73,7 +74,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { double sparsity = 0.1; double tolerance = actualDistinctCount * this.percentTolerance; - super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, tolerance); } @@ -84,7 +85,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { double sparsity = 0.1; double tolerance = actualDistinctCount * this.percentTolerance; - super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, tolerance); } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java similarity index 93% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java index c9aa75e375..568c7516d0 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java @@ -23,11 +23,11 @@ import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; import org.junit.Test; -public class CountDistinctApproxRow extends CountDistinctRowOrColBase { +public class CountDistinctRow extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxRow"; + private final static String TEST_NAME = "countDistinctRow"; private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctRow.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -52,6 +52,7 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase { @Override public void setUp() { super.addTestConfiguration(); + super.setRunSparkTests(false); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java similarity index 93% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java copy to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java index c9aa75e375..c9e24cd38d 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java @@ -23,11 +23,11 @@ import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; import org.junit.Test; -public class 
CountDistinctApproxRow extends CountDistinctRowOrColBase { +public class CountDistinctRowAlias extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxRow"; + private final static String TEST_NAME = "countDistinctRowAlias"; private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowAlias.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -52,6 +52,7 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase { @Override public void setUp() { super.addTestConfiguration(); + super.setRunSparkTests(false); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java index 3de4a61bcd..8f7d9acb8f 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java @@ -24,7 +24,7 @@ import org.junit.Test; public class CountDistinctRowCol extends CountDistinctRowColBase { - public String TEST_NAME = "countDistinct"; + public String TEST_NAME = "countDistinctRowCol"; public String TEST_DIR = "functions/countDistinct/"; public String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowCol.class.getSimpleName() + "/"; diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java similarity index 85% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java index 3de4a61bcd..02048595e6 100644 
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java @@ -22,11 +22,11 @@ package org.apache.sysds.test.functions.countDistinct; import org.apache.sysds.common.Types.ExecType; import org.junit.Test; -public class CountDistinctRowCol extends CountDistinctRowColBase { +public class CountDistinctRowColParameterized extends CountDistinctRowColBase { - public String TEST_NAME = "countDistinct"; + public String TEST_NAME = "countDistinctRowColParameterized"; public String TEST_DIR = "functions/countDistinct/"; - public String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowCol.class.getSimpleName() + "/"; + public String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowColParameterized.class.getSimpleName() + "/"; protected String getTestClassDir() { return TEST_CLASS_DIR; diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java index a880c0d0dd..0d517776a2 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java @@ -30,6 +30,8 @@ import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.Test; +import static org.junit.Assume.assumeTrue; + public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Override @@ -43,6 +45,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { protected abstract Types.Direction getDirection(); + private boolean runSparkTests = true; + protected void addTestConfiguration() { TestUtils.clearAssertionInformation(); addTestConfiguration(getTestName(), new TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"})); @@ -50,19 +54,8 @@ 
public abstract class CountDistinctRowOrColBase extends CountDistinctBase { this.percentTolerance = 0.2; } - /** - * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch. - */ - @Test - public void testCPDenseXLarge() { - Types.ExecType ex = Types.ExecType.CP; - - int actualDistinctCount = 10000; - int rows = 10000, cols = 10000; - double sparsity = 0.9; - double tolerance = actualDistinctCount * this.percentTolerance; - - countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + public void setRunSparkTests(boolean runSparkTests) { + this.runSparkTests = runSparkTests; } @Test @@ -91,6 +84,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Test public void testSparkSparseLargeMultiBlockAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -103,6 +98,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Test public void testSparkDenseLargeMultiBlockAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -115,6 +112,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Test public void testSparkSparseLargeNoneAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -127,6 +126,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Test public void testSparkDenseLargeNoneAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -145,9 +146,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { } blkIn = new MatrixBlock(blkIn, sparseBlockType, true); - CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX) 
- .setDirection(direction) - .setIndexFunction(ReduceCol.getReduceColFnObject()); + CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + direction, ReduceCol.getReduceColFnObject()); MatrixBlock blkOut = LibMatrixCountDistinct.estimateDistinctValues(blkIn, op); double[][] expectedMatrix = getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount); diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java similarity index 82% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java index 5a7eccc447..6752bc29bb 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java @@ -17,16 +17,17 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; import org.junit.Test; public class CountDistinctApproxCol extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxCol"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; @Override @@ -99,4 +100,19 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. 
+ */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java similarity index 78% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java index 5a7eccc447..e87813f464 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java @@ -17,17 +17,18 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; import org.junit.Test; -public class CountDistinctApproxCol extends CountDistinctRowOrColBase { +public class CountDistinctApproxColAlias extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxCol"; - private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; + private final static String TEST_NAME = "countDistinctApproxColAlias"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxColAlias.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -99,4 +100,19 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. 
+ */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java similarity index 82% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java copy to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java index c9aa75e375..6e4678f5a8 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java @@ -17,16 +17,17 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; import org.junit.Test; public class CountDistinctApproxRow extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxRow"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; @Override @@ -99,4 +100,19 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase { countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. 
+ */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java similarity index 78% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java copy to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java index c9aa75e375..99b6d60e04 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java @@ -17,17 +17,18 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; import org.junit.Test; -public class CountDistinctApproxRow extends CountDistinctRowOrColBase { +public class CountDistinctApproxRowAlias extends CountDistinctRowOrColBase { - private final static String TEST_NAME = "countDistinctApproxRow"; - private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; + private final static String TEST_NAME = "countDistinctApproxRowAlias"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowAlias.class.getSimpleName() + "/"; @Override protected String getTestClassDir() { @@ -99,4 +100,19 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase { countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. 
+ */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java similarity index 96% copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java copy to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java index e59b002887..4c0f27bd5b 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java @@ -17,15 +17,16 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase; import org.junit.Test; public class CountDistinctApproxRowCol extends CountDistinctRowColBase { private final static String TEST_NAME = "countDistinctApproxRowCol"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowCol.class.getSimpleName() + "/"; @Override diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java similarity index 91% rename from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java rename to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java index e59b002887..df532a95d8 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java @@ -17,16 +17,17 @@ * under the License. 
*/ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase; import org.junit.Test; -public class CountDistinctApproxRowCol extends CountDistinctRowColBase { +public class CountDistinctApproxRowColParameterized extends CountDistinctRowColBase { - private final static String TEST_NAME = "countDistinctApproxRowCol"; - private final static String TEST_DIR = "functions/countDistinct/"; - private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowCol.class.getSimpleName() + "/"; + private final static String TEST_NAME = "countDistinctApproxRowColParameterized"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowColParameterized.class.getSimpleName() + "/"; @Override public void setUp() { diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctCol.dml similarity index 96% copy from src/test/scripts/functions/countDistinct/countDistinct.dml copy to src/test/scripts/functions/countDistinct/countDistinctCol.dml index 3b21bc89f1..3f2918ee1e 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctCol.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinct(input, dir="c") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml similarity index 96% copy from src/test/scripts/functions/countDistinct/countDistinct.dml copy to 
src/test/scripts/functions/countDistinct/countDistinctColAlias.dml index 3b21bc89f1..3eeb8ed54a 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinctCol(input, dir="c") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctRow.dml similarity index 96% copy from src/test/scripts/functions/countDistinct/countDistinct.dml copy to src/test/scripts/functions/countDistinct/countDistinctRow.dml index 3b21bc89f1..f8665f6fc7 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctRow.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinct(input, dir="r") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml similarity index 96% copy from src/test/scripts/functions/countDistinct/countDistinct.dml copy to src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml index 3b21bc89f1..62d7196ce1 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinctRow(input, dir="r") write(res, $5, format="text") diff --git 
a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml similarity index 95% copy from src/test/scripts/functions/countDistinct/countDistinct.dml copy to src/test/scripts/functions/countDistinct/countDistinctRowCol.dml index 3b21bc89f1..7ac9dd53fc 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinct(input) # default is dir="rc" write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml similarity index 96% rename from src/test/scripts/functions/countDistinct/countDistinct.dml rename to src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml index 3b21bc89f1..9bd22867d0 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinct(input, dir="rc") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml similarity index 100% copy from src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml copy to src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml similarity index 94% copy from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml copy to src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml index 2c5b6cf412..83a9f5070c 100644 --- a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinctApprox(input, dir="rc", type="KMV") +res = countDistinctApproxCol(input, dir="c", type="KMV") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml similarity index 100% rename from src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml similarity index 94% copy from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml copy to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml index 2c5b6cf412..f4be480156 100644 --- a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinctApprox(input, dir="rc", type="KMV") +res = countDistinctApproxRow(input, dir="r", type="KMV") write(res, $5, format="text") diff --git 
a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml similarity index 95% rename from src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml index 777a56a443..21245ecfbb 100644 --- a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinctApprox(input, dir="c", type="KMV") +res = countDistinctApprox(input, type="KMV") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml similarity index 100% rename from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml
