This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 5590135 [SYSTEMDS-2996] countDistinctApprox Builtin function
5590135 is described below
commit 5590135bd2d73c50e9528db5a83f15c0f7964d4b
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sun Jan 30 21:27:16 2022 -0800
[SYSTEMDS-2996] countDistinctApprox Builtin function
This commit adds the countDistinctApprox builtin function, which
allows faster, approximate counting of distinct elements in a matrix,
either over the whole matrix or per row or per column. It also adds
Spark support for this new instruction.
Closes #1531
Closes #1554
(Just to make sure GitHub sees that you are the author)
Co-authored-by: Badrul Chowdhury <[email protected]>
---
.../java/org/apache/sysds/common/Builtins.java | 2 +-
src/main/java/org/apache/sysds/common/Types.java | 3 +
src/main/java/org/apache/sysds/conf/DMLConfig.java | 2 +
.../org/apache/sysds/lops/PartialAggregate.java | 9 +-
.../sysds/parser/BuiltinFunctionExpression.java | 16 +-
.../org/apache/sysds/parser/DMLTranslator.java | 60 ++-
.../ParameterizedBuiltinFunctionExpression.java | 140 +++++-
.../estim/CompressedSizeEstimatorUltraSparse.java | 4 +-
.../sysds/runtime/functionobjects/Builtin.java | 2 +-
.../runtime/instructions/CPInstructionParser.java | 14 +-
.../runtime/instructions/SPInstructionParser.java | 15 +-
.../cp/AggregateUnaryCPInstruction.java | 70 ++-
.../spark/AggregateUnarySPInstruction.java | 1 -
.../spark/AggregateUnarySketchSPInstruction.java | 293 +++++++++++++
.../runtime/instructions/spark/SPInstruction.java | 2 +-
.../matrix/data/LibMatrixCountDistinct.java | 200 ++-------
.../runtime/matrix/data/sketch/MatrixSketch.java | 68 +++
.../CountDistinctApproxSketch.java | 56 +++
.../data/sketch/countdistinctapprox/KMVSketch.java | 488 +++++++++++++++++++++
.../countdistinctapprox/SmallestPriorityQueue.java | 84 ++++
.../matrix/operators/CountDistinctOperator.java | 60 ++-
.../operators/CountDistinctOperatorTypes.java} | 35 +-
.../test/component/matrix/CountDistinctTest.java | 62 ++-
...inctApprox.java => CountDistinctApproxCol.java} | 37 +-
...ntDistinct.java => CountDistinctApproxRow.java} | 29 +-
.../countDistinct/CountDistinctApproxRowCol.java | 140 ++++++
.../functions/countDistinct/CountDistinctBase.java | 107 ++---
...CountDistinct.java => CountDistinctRowCol.java} | 14 +-
.../countDistinct/CountDistinctRowColBase.java | 81 ++++
.../countDistinct/CountDistinctRowOrColBase.java | 142 ++++++
.../functions/countDistinct/countDistinct.dml | 1 -
...stinctApprox.dml => countDistinctApproxCol.dml} | 4 +-
...ountDistinct.dml => countDistinctApproxRow.dml} | 5 +-
...tDistinct.dml => countDistinctApproxRowCol.dml} | 5 +-
34 files changed, 1831 insertions(+), 420 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 2fab87b..8eec1e5 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -93,7 +93,6 @@ public enum Builtins {
COS("cos", false),
COSH("cosh", false),
COUNT_DISTINCT("countDistinct",false),
- COUNT_DISTINCT_APPROX("countDistinctApprox",false),
COV("cov", false),
COX("cox", true),
CSPLINE("cspline", true),
@@ -306,6 +305,7 @@ public enum Builtins {
//parameterized builtin functions
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
+ COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
CVLM("cvlm", true, false),
GROUPEDAGG("aggregate", "groupedAggregate", false, true),
INVCDF("icdf", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 85a11e4..916935c 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -153,6 +153,9 @@ public class Types
public boolean isCol() {
return this == Col;
}
+ public boolean isRowCol() {
+ return this == RowCol;
+ }
@Override
public String toString() {
switch(this) {
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 2e27c17..f46be3b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -25,6 +25,7 @@ import java.io.IOException;
import java.io.StringWriter;
import java.util.HashMap;
+import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
@@ -245,6 +246,7 @@ public class DMLConfig
private DocumentBuilder getDocumentBuilder() throws
ParserConfigurationException {
if (_documentBuilder == null) {
DocumentBuilderFactory factory =
DocumentBuilderFactory.newInstance();
+
factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true); // Prevent
XML Injection
factory.setIgnoringComments(true); //ignore XML comments
_documentBuilder = factory.newDocumentBuilder();
}
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 1a3bde7..050d87a 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -217,7 +217,7 @@ public class PartialAggregate extends Lop
}
/**
- * Instruction generation for for CP and Spark
+ * Instruction generation for CP and Spark
*/
@Override
public String getInstructions(String input1, String output)
@@ -348,8 +348,11 @@ public class PartialAggregate extends Lop
}
case COUNT_DISTINCT_APPROX: {
- if(dir == Direction.RowCol )
- return "uacdap";
+ switch (dir) {
+ case RowCol: return "uacdap";
+ case Row: return "uacdapr";
+ case Col: return "uacdapc";
+ }
break;
}
}
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index e3cb0ee..19b7177 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -623,10 +623,10 @@ public class BuiltinFunctionExpression extends
DataIdentifier
case MEAN:
//checkNumParameters(2, false); // mean(Y) or mean(Y,W)
if (getSecondExpr() != null) {
- checkNumParameters (2);
+ checkNumParameters(2);
}
else {
- checkNumParameters (1);
+ checkNumParameters(1);
}
checkMatrixParam(getFirstExpr());
@@ -933,7 +933,6 @@ public class BuiltinFunctionExpression extends
DataIdentifier
output.setValueType(ValueType.INT64);
break;
case COUNT_DISTINCT:
- case COUNT_DISTINCT_APPROX:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
output.setDataType(DataType.SCALAR);
@@ -941,7 +940,6 @@ public class BuiltinFunctionExpression extends
DataIdentifier
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;
-
case LINEAGE:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
@@ -951,14 +949,12 @@ public class BuiltinFunctionExpression extends
DataIdentifier
output.setBlocksize(0);
output.setValueType(ValueType.STRING);
break;
-
case LIST:
output.setDataType(DataType.LIST);
output.setValueType(ValueType.UNKNOWN);
output.setDimensions(getAllExpr().length, 1);
output.setBlocksize(-1);
break;
-
case EXISTS:
checkNumParameters(1);
checkStringOrDataIdentifier(getFirstExpr());
@@ -1825,9 +1821,9 @@ public class BuiltinFunctionExpression extends
DataIdentifier
protected void checkNumParameters(int count) { //always unconditional
if (getFirstExpr() == null && _args.length > 0) {
raiseValidateError("Missing argument for function " +
this.getOpCode(), false,
- LanguageErrorCodes.INVALID_PARAMETERS);
+ LanguageErrorCodes.INVALID_PARAMETERS);
}
-
+
// Not sure the rationale for the first two if loops, but will
keep them for backward compatibility
if (((count == 1) && (getSecondExpr() != null || getThirdExpr()
!= null))
|| ((count == 2) && (getThirdExpr() != null))) {
@@ -1843,7 +1839,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
} else if (count == 0 && (_args.length > 0
|| getSecondExpr() != null || getThirdExpr() !=
null)) {
raiseValidateError("Missing argument for function " +
this.getOpCode()
- + "(). This function doesn't take any
arguments.", false);
+ + "(). This function doesn't take any
arguments.", false);
}
}
@@ -1870,7 +1866,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
if( !ArrayUtils.contains(dt, e.getOutput().getDataType()) )
raiseValidateError("Non-matching expected data type for
function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
-
+
protected void checkMatrixFrameParam(Expression e) { //always
unconditional
if (e.getOutput().getDataType() != DataType.MATRIX &&
e.getOutput().getDataType() != DataType.FRAME) {
raiseValidateError("Expecting matrix or frame parameter
for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 84212a7..ef51904 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -30,6 +30,22 @@ import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
+import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.OpOpDnn;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
@@ -62,22 +78,6 @@ import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.compile.Dag;
-import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.common.Builtins;
-import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.common.Types.FileFormat;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.common.Types.OpOp2;
-import org.apache.sysds.common.Types.OpOp3;
-import org.apache.sysds.common.Types.OpOpDG;
-import org.apache.sysds.common.Types.OpOpData;
-import org.apache.sysds.common.Types.OpOpDnn;
-import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.common.Types.ReOrgOp;
-import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.PrintStatement.PRINTTYPE;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -91,7 +91,6 @@ import
org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
-
public class DMLTranslator
{
private static final Log LOG =
LogFactory.getLog(DMLTranslator.class.getName());
@@ -2035,6 +2034,29 @@ public class DMLTranslator
target.getValueType(),
ParamBuiltinOp.LIST, paramHops);
break;
+ case COUNT_DISTINCT_APPROX:
+ // Default direction and data type
+ Direction dir = Direction.RowCol;
+ DataType dataType = DataType.SCALAR;
+
+ LiteralOp dirOp = (LiteralOp)
paramHops.get("dir");
+ if (dirOp != null) {
+ String dirString =
dirOp.getStringValue().toUpperCase();
+ if
(dirString.equals(Direction.RowCol.toString())) {
+ dir = Direction.RowCol;
+ dataType = DataType.SCALAR;
+ } else if
(dirString.equals(Direction.Row.toString())) {
+ dir = Direction.Row;
+ dataType = DataType.MATRIX;
+ } else if
(dirString.equals(Direction.Col.toString())) {
+ dir = Direction.Col;
+ dataType = DataType.MATRIX;
+ }
+ }
+
+ currBuiltinOp = new
AggUnaryOp(target.getName(), dataType, target.getValueType(),
+
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
+ break;
default:
throw new
ParseException(source.printErrorLocation() +
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " +
source.getOpCode());
@@ -2335,11 +2357,9 @@ public class DMLTranslator
case PROD:
case VAR:
case COUNT_DISTINCT:
- case COUNT_DISTINCT_APPROX:
currBuiltinOp = new AggUnaryOp(target.getName(),
DataType.SCALAR, target.getValueType(),
- AggOp.valueOf(source.getOpCode().name()),
Direction.RowCol, expr);
+
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
break;
-
case MEAN:
if ( expr2 == null ) {
// example: x = mean(Y);
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 442d1e6..6b6ca9b 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -29,13 +29,14 @@ import java.util.Set;
import java.util.stream.Collectors;
import org.antlr.v4.runtime.ParserRuleContext;
-import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
import org.apache.sysds.runtime.util.CollectionUtils;
+import org.apache.wink.json4j.JSONObject;
public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
@@ -245,6 +246,10 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
validateParamserv(output, conditional);
break;
+ case COUNT_DISTINCT_APPROX:
+ validateCountDistinctApprox(output, conditional);
+ break;
+
default: //always unconditional (because unsupported operation)
//handle common issue of transformencode
if( getOpCode()==Builtins.TRANSFORMENCODE )
@@ -258,7 +263,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
private void validateAutoDiff(DataIdentifier output, boolean
conditional) {
//validate data / metadata (recode maps)
- checkDataType("lineage", LINEAGE_TRACE, DataType.LIST,
conditional);
+ checkDataType(false, "lineage", LINEAGE_TRACE, DataType.LIST,
conditional);
//validate specification
checkDataValueType(false, "lineage", LINEAGE_TRACE,
DataType.LIST, ValueType.UNKNOWN, conditional);
@@ -266,7 +271,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
// set output characteristics
output.setDataType(DataType.LIST);
output.setValueType(ValueType.UNKNOWN);
- // TODO dimension should be set to -1 but could not set due to
lineage parsing error in Spark contetx
+ // TODO dimension should be set to -1 but could not set due to
lineage parsing error in Spark context
output.setDimensions(varParams.size(), 1);
// output.setDimensions(-1, 1);
output.setBlocksize(-1);
@@ -319,9 +324,9 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
- checkDataType(fname, Statement.PS_MODEL, DataType.LIST,
conditional); // check the model which is the only non-parameterized argument
- checkDataType(fname, Statement.PS_FEATURES, DataType.MATRIX,
conditional);
- checkDataType(fname, Statement.PS_LABELS, DataType.MATRIX,
conditional);
+ checkDataType(false, fname, Statement.PS_MODEL, DataType.LIST,
conditional); // check the model which is the only non-parameterized argument
+ checkDataType(false, fname, Statement.PS_FEATURES,
DataType.MATRIX, conditional);
+ checkDataType(false, fname, Statement.PS_LABELS,
DataType.MATRIX, conditional);
checkDataValueType(true, fname, Statement.PS_VAL_FEATURES,
DataType.MATRIX, ValueType.FP64, conditional);
checkDataValueType(true, fname, Statement.PS_VAL_LABELS,
DataType.MATRIX, ValueType.FP64, conditional);
checkDataValueType(false, fname, Statement.PS_UPDATE_FUN,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -347,6 +352,99 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
output.setBlocksize(-1);
}
+ private void validateCountDistinctApprox(DataIdentifier output, boolean
conditional) {
+ Set<String> validTypeNames = CollectionUtils.asSet("KMV");
+ HashMap<String, Expression> varParams = getVarParams();
+
+ // "data" is the only parameter that is allowed to be unnamed
+ if (varParams.containsKey(null)) {
+ varParams.put("data", varParams.remove(null));
+ }
+
+ // Validate the number of parameters
+ String fname = getOpCode().getName();
+ String usageMessage = "function " + fname + " takes at least 1
and at most 3 parameters";
+ if (varParams.size() < 1) {
+ raiseValidateError("Too few parameters: " +
usageMessage, conditional);
+ }
+
+ if (varParams.size() > 3) {
+ raiseValidateError("Too many parameters: " +
usageMessage, conditional);
+ }
+
+ // Check parameter names are valid
+ Set<String> validParameterNames = CollectionUtils.asSet("data",
"type", "dir");
+ checkInvalidParameters(getOpCode(), varParams,
validParameterNames);
+
+ // Check parameter expression data types match expected
+ checkDataType(false, fname, "data", DataType.MATRIX,
conditional);
+ checkDataValueType(false, fname, "data", DataType.MATRIX,
ValueType.FP64, conditional);
+
+ // We need the dimensions of the input matrix to determine the
output matrix characteristics
+ // Validate data parameter, lookup previously defined var or
resolve expression
+ Identifier dataId = varParams.get("data").getOutput();
+ if (dataId == null) {
+ raiseValidateError("Cannot parse input parameter
\"data\" to function " + fname, conditional);
+ }
+
+ checkStringParam(true, fname, "type", conditional);
+ // Check data value of "type" parameter
+ if (varParams.keySet().contains("type")) {
+ String typeString =
varParams.get("type").toString().toUpperCase();
+ if (!validTypeNames.contains(typeString)) {
+ raiseValidateError("Unrecognized type for
optional parameter " + typeString, conditional);
+ }
+ } else {
+ // default to KMV
+ addVarParam("type", new StringIdentifier("KMV", this));
+ }
+
+ checkStringParam(true, fname, "dir", conditional);
+ // Check data value of "dir" parameter
+ if (varParams.keySet().contains("dir")) {
+ String directionString =
varParams.get("dir").toString().toUpperCase();
+
+ // Set output type and dimensions based on direction
+
+ // "r" -> count across all rows, resulting in a Mx1
matrix
+ if
(directionString.equals(Types.Direction.Row.toString())) {
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(dataId.getDim1(), 1);
+ output.setBlocksize(dataId.getBlocksize());
+ output.setValueType(ValueType.INT64);
+ output.setNnz(dataId.getDim1());
+
+ // "c" -> count across all cols, resulting in a 1xN
matrix
+ } else if
(directionString.equals(Types.Direction.Col.toString())) {
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(1, dataId.getDim2());
+ output.setBlocksize(dataId.getBlocksize());
+ output.setValueType(ValueType.INT64);
+ output.setNnz(dataId.getDim2());
+
+ // "rc" -> count across all rows and cols in input
matrix, resulting in a single value
+ } else if
(directionString.equals(Types.Direction.RowCol.toString())) {
+ output.setDataType(DataType.SCALAR);
+ output.setDimensions(0, 0);
+ output.setBlocksize(0);
+ output.setValueType(ValueType.INT64);
+ output.setNnz(1);
+
+ // unrecognized value for "dir" parameter, should "cr"
be valid?
+ } else {
+ raiseValidateError("Invalid argument: " +
directionString + " is not recognized");
+ }
+
+ // default to dir="rc"
+ } else {
+ output.setDataType(DataType.SCALAR);
+ output.setDimensions(0, 0);
+ output.setBlocksize(0);
+ output.setValueType(ValueType.INT64);
+ output.setNnz(1);
+ }
+ }
+
private void checkStringParam(boolean optional, String fname, String
pname, boolean conditional) {
Expression param = getVarParam(pname);
if (param == null) {
@@ -365,7 +463,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
private void validateTokenize(DataIdentifier output, boolean
conditional)
{
//validate data / metadata (recode maps)
- checkDataType("tokenize", TF_FN_PARAM_DATA, DataType.FRAME,
conditional);
+ checkDataType(false, "tokenize", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
//validate specification
checkDataValueType(false, "tokenize", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -381,8 +479,8 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
private void validateTransformApply(DataIdentifier output, boolean
conditional)
{
//validate data / metadata (recode maps)
- checkDataType("transformapply", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
- checkDataType("transformapply", TF_FN_PARAM_MTD2,
DataType.FRAME, conditional);
+ checkDataType(false, "transformapply", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
+ checkDataType(false, "transformapply", TF_FN_PARAM_MTD2,
DataType.FRAME, conditional);
//validate specification
checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -397,8 +495,8 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
private void validateTransformDecode(DataIdentifier output, boolean
conditional)
{
//validate data / metadata (recode maps)
- checkDataType("transformdecode", TF_FN_PARAM_DATA,
DataType.MATRIX, conditional);
- checkDataType("transformdecode", TF_FN_PARAM_MTD2,
DataType.FRAME, conditional);
+ checkDataType(false, "transformdecode", TF_FN_PARAM_DATA,
DataType.MATRIX, conditional);
+ checkDataType(false, "transformdecode", TF_FN_PARAM_MTD2,
DataType.FRAME, conditional);
//validate specification
checkDataValueType(false, "transformdecode", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -414,7 +512,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
{
//validate data / metadata (recode maps)
Expression exprTarget = getVarParam(Statement.GAGG_TARGET);
- checkDataType("transformcolmap", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
+ checkDataType(false, "transformcolmap", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
//validate specification
checkDataValueType(false,"transformcolmap", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -444,7 +542,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
private void validateTransformEncode(DataIdentifier output1,
DataIdentifier output2, boolean conditional)
{
//validate data / metadata (recode maps)
- checkDataType("transformencode", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
+ checkDataType(false, "transformencode", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
//validate specification
checkDataValueType(false, "transformencode", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
@@ -871,12 +969,18 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
output.setBlocksize(-1);
}
- private void checkDataType( String fname, String pname, DataType dt,
boolean conditional ) {
+ private void checkDataType(boolean optional, String fname, String
pname, DataType dt, boolean conditional) {
Expression data = getVarParam(pname);
- if( data==null )
- raiseValidateError("Named parameter '" + pname + "'
missing. Please specify the input.", conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
- else if( data.getOutput().getDataType() != dt )
- raiseValidateError("Input to "+fname+"::"+pname+" must
be of type '"+dt.toString()+"'. It should not be of type
'"+data.getOutput().getDataType()+"'.", conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
+ if(data == null) {
+ if(optional)
+ return;
+ raiseValidateError("Named parameter '" + pname + "'
missing. Please specify the input.", conditional,
+ LanguageErrorCodes.INVALID_PARAMETERS);
+ }
+ else if(data.getOutput().getDataType() != dt)
+ raiseValidateError("Input to " + fname + "::" + pname +
" must be of type '" + dt.toString()
+ + "'. It should not be of type '" +
data.getOutput().getDataType() + "'.", conditional,
+ LanguageErrorCodes.INVALID_PARAMETERS);
}
private void checkDataValueType(boolean optional, String fname, String
pname, DataType dt, ValueType vt,
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
index 7a31f13..23ad02c 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
@@ -26,7 +26,7 @@ import
org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
/**
* UltraSparse compressed size estimator (examines entire dataset).
@@ -39,7 +39,7 @@ public class CompressedSizeEstimatorUltraSparse extends
CompressedSizeEstimator
private CompressedSizeEstimatorUltraSparse(MatrixBlock data,
CompressionSettings compSettings) {
super(data, compSettings);
- CountDistinctOperator op = new
CountDistinctOperator(CountDistinctTypes.COUNT);
+ CountDistinctOperator op = new
CountDistinctOperator(CountDistinctOperatorTypes.COUNT);
final int _numRows = getNumRows();
if(LOG.isDebugEnabled()) {
diff --git
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 4f423c2..7866f23 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -223,7 +223,7 @@ public class Builtin extends ValueFunction
// compared and performs just the value part of
the comparison. We
// return an integer cast down to a double,
since the aggregation
// API doesn't have any way to return anything
but a double. The
- // integer returned takes on three posssible
values: //
+ // integer returned takes on three possible
values: //
// . 0 => keep the index associated with
in1 //
// . 1 => use the index associated with in2
//
// . 2 => use whichever index is higher
(tie in value) //
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index c08985b..d3b8ad6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -96,7 +96,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "uacvar" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uamax" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uarmax" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarimax", CPType.AggregateUnary);
+ String2CPInstructionType.put( "uarimax" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uacmax" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uamin" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uarmin" ,
CPType.AggregateUnary);
@@ -110,13 +110,15 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "uac*" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uatrace" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uaktrace",
CPType.AggregateUnary);
- String2CPInstructionType.put( "nrow" ,CPType.AggregateUnary);
- String2CPInstructionType.put( "ncol" ,CPType.AggregateUnary);
- String2CPInstructionType.put( "length" ,CPType.AggregateUnary);
- String2CPInstructionType.put( "exists" ,CPType.AggregateUnary);
- String2CPInstructionType.put( "lineage" ,CPType.AggregateUnary);
+ String2CPInstructionType.put( "nrow" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "ncol" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "length" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "exists" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "lineage" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uacd" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uacdap" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "uacdapr" ,
CPType.AggregateUnary);
+ String2CPInstructionType.put( "uacdapc" ,
CPType.AggregateUnary);
String2CPInstructionType.put( "uaggouterchain",
CPType.UaggOuterChain);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 56cd49a..965617d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -42,6 +42,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import
org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.AggregateUnarySketchSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
@@ -62,6 +63,7 @@ import
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction
import org.apache.sysds.runtime.instructions.spark.DeCompressionSPInstruction;
import org.apache.sysds.runtime.instructions.spark.DnnSPInstruction;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.LIBSVMReblockSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MatrixReshapeSPInstruction;
@@ -87,7 +89,7 @@ import
org.apache.sysds.runtime.instructions.spark.UnaryFrameSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ZipmmSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.LIBSVMReblockSPInstruction;
+
public class SPInstructionParser extends InstructionParser
{
@@ -110,7 +112,7 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "uacvar" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uamax" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uarmax" ,
SPType.AggregateUnary);
- String2SPInstructionType.put( "uarimax" ,
SPType.AggregateUnary);
+ String2SPInstructionType.put( "uarimax" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uacmax" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uamin" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uarmin" ,
SPType.AggregateUnary);
@@ -124,6 +126,12 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "uac*" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uatrace" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uaktrace",
SPType.AggregateUnary);
+ String2SPInstructionType.put( "uacdap" ,
SPType.AggregateUnary);
+
+ // Aggregate unary sketch operators
+ String2SPInstructionType.put( "uacdap" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdapr",
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdapc",
SPType.AggregateUnarySketch);
//binary aggregate operators (matrix multiplication operators)
String2SPInstructionType.put( "mapmm" , SPType.MAPMM);
@@ -388,6 +396,9 @@ public class SPInstructionParser extends InstructionParser
case AggregateUnary:
return
AggregateUnarySPInstruction.parseInstruction(str);
+ case AggregateUnarySketch:
+ return
AggregateUnarySketchSPInstruction.parseInstruction(str);
+
case AggregateTernary:
return
AggregateTernarySPInstruction.parseInstruction(str);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index ef1ff08..fbcf6ff 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -28,6 +29,9 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageDedupUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -82,8 +86,28 @@ public class AggregateUnaryCPInstruction extends
UnaryCPInstruction {
in1, out, AUType.COUNT_DISTINCT, opcode, str);
}
else if(opcode.equalsIgnoreCase("uacdap")){
- return new AggregateUnaryCPInstruction(new
SimpleOperator(null),
- in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
+ CountDistinctOperator op = new
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+ .setDirection(Types.Direction.RowCol)
+
.setIndexFunction(ReduceAll.getReduceAllFnObject());
+
+ return new AggregateUnaryCPInstruction(op, in1, out,
AUType.COUNT_DISTINCT_APPROX,
+ opcode, str);
+ }
+ else if(opcode.equalsIgnoreCase("uacdapr")){
+ CountDistinctOperator op = new
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+ .setDirection(Types.Direction.Row)
+
.setIndexFunction(ReduceCol.getReduceColFnObject());
+
+ return new AggregateUnaryCPInstruction(op, in1, out,
AUType.COUNT_DISTINCT_APPROX,
+ opcode, str);
+ }
+ else if(opcode.equalsIgnoreCase("uacdapc")){
+ CountDistinctOperator op = new
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+ .setDirection(Types.Direction.Col)
+
.setIndexFunction(ReduceRow.getReduceRowFnObject());
+
+ return new AggregateUnaryCPInstruction(op, in1, out,
AUType.COUNT_DISTINCT_APPROX,
+ opcode, str);
}
else if(opcode.equalsIgnoreCase("uarimax") ||
opcode.equalsIgnoreCase("uarimin")){
// parse with number of outputs
@@ -171,17 +195,55 @@ public class AggregateUnaryCPInstruction extends
UnaryCPInstruction {
ec.setScalarOutput(output_name, new
StringObject(out));
break;
}
- case COUNT_DISTINCT:
- case COUNT_DISTINCT_APPROX: {
+ case COUNT_DISTINCT: {
if(
!ec.getVariables().keySet().contains(input1.getName()) )
throw new DMLRuntimeException("Variable
'" + input1.getName() + "' does not exist.");
MatrixBlock input =
ec.getMatrixInput(input1.getName());
CountDistinctOperator op = new
CountDistinctOperator(_type);
+ //TODO add support for row or col count
distinct.
int res =
LibMatrixCountDistinct.estimateDistinctValues(input, op);
ec.releaseMatrixInput(input1.getName());
ec.setScalarOutput(output_name, new
IntObject(res));
break;
}
+ case COUNT_DISTINCT_APPROX: {
+
if(!ec.getVariables().keySet().contains(input1.getName())) {
+ throw new DMLRuntimeException("Variable
'" + input1.getName() + "' does not exist.");
+ }
+
+ MatrixBlock input =
ec.getMatrixInput(input1.getName());
+ if (!(_optr instanceof CountDistinctOperator)) {
+ throw new DMLRuntimeException("Operator
should be instance of " + CountDistinctOperator.class.getSimpleName());
+ }
+
+ CountDistinctOperator op =
(CountDistinctOperator) _optr; // It is safe to cast at this point
+
+ if (op.getDirection().isRowCol()) {
+ int res =
LibMatrixCountDistinct.estimateDistinctValues(input, op);
+ ec.releaseMatrixInput(input1.getName());
+ ec.setScalarOutput(output_name, new
IntObject(res));
+ } else if (op.getDirection().isRow()) {
+ //TODO Do not slice out the matrix but
directly process on the input
+ MatrixBlock res = input.slice(0,
input.getNumRows() - 1, 0, 0);
+ for (int i = 0; i < input.getNumRows();
++i) {
+ res.setValue(i, 0,
LibMatrixCountDistinct.estimateDistinctValues(input.slice(i, i), op));
+ }
+ ec.releaseMatrixInput(input1.getName());
+ ec.setMatrixOutput(output_name, res);
+ } else if (op.getDirection().isCol()) {
+ //TODO Do not slice out the matrix but
directly process on the input
+ MatrixBlock res = input.slice(0, 0, 0,
input.getNumColumns() - 1);
+ for (int j = 0; j <
input.getNumColumns(); ++j) {
+ res.setValue(0, j,
LibMatrixCountDistinct.estimateDistinctValues(input.slice(0, input.getNumRows()
- 1, j, j), op));
+ }
+ ec.releaseMatrixInput(input1.getName());
+ ec.setMatrixOutput(output_name, res);
+ } else {
+ throw new
DMLRuntimeException("Direction for CountDistinctOperator not recognized");
+ }
+
+ break;
+ }
default: {
AggregateUnaryOperator au_op =
(AggregateUnaryOperator) _optr;
if (input1.getDataType() == DataType.MATRIX) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 04b6650..38b032c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -289,7 +289,6 @@ public class AggregateUnarySPInstruction extends
UnarySPInstruction {
public RDDUAggValueFunction( AggregateUnaryOperator op, int
blen ) {
_op = op;
_blen = blen;
- _blen = blen;
_ix = new MatrixIndexes(1,1);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
new file mode 100644
index 0000000..71bc75f
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.spark;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Hash;
+import scala.Tuple2;
+
+public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
+	private final AggBinaryOp.SparkAggType aggtype;
+	private final CountDistinctOperator op;
+
+	/**
+	 * Creates a Spark instruction that approximates the number of distinct values via sketches.
+	 *
+	 * The opcode selects the aggregation direction: {@code uacdap} (whole matrix), {@code uacdapr} (per row) or
+	 * {@code uacdapc} (per column). The given operator is configured in place with the matching direction and
+	 * index function.
+	 *
+	 * @param op      operator to use; cast to {@link CountDistinctOperator}
+	 * @param in      input operand
+	 * @param out     output operand
+	 * @param aggtype Spark aggregation type
+	 * @param opcode  instruction opcode
+	 * @param instr   full instruction string
+	 * @throws DMLException if the opcode is not one of the supported sketch opcodes
+	 */
+	protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, CPOperand out,
+		AggBinaryOp.SparkAggType aggtype, String opcode, String instr) {
+		super(SPType.AggregateUnarySketch, op, in, out, opcode, instr);
+		this.op = (CountDistinctOperator) super.getOperator();
+
+		if (opcode.equals("uacdap")) {
+			this.op.setDirection(Types.Direction.RowCol)
+				.setIndexFunction(ReduceAll.getReduceAllFnObject());
+		} else if (opcode.equals("uacdapr")) {
+			this.op.setDirection(Types.Direction.Row)
+				.setIndexFunction(ReduceCol.getReduceColFnObject());
+		} else if (opcode.equals("uacdapc")) {
+			this.op.setDirection(Types.Direction.Col)
+				.setIndexFunction(ReduceRow.getReduceRowFnObject());
+		} else {
+			throw new DMLException("Unrecognized opcode " + opcode);
+		}
+
+		this.aggtype = aggtype;
+	}
+
+	/**
+	 * Parses an instruction string of the form {@code opcode, input, output, aggtype}.
+	 *
+	 * @param str the instruction string
+	 * @return the parsed instruction, using a KMV sketch with linear hashing
+	 */
+	public static AggregateUnarySketchSPInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 3);
+		String opcode = parts[0];
+
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand out = new CPOperand(parts[2]);
+		AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]);
+
+		CountDistinctOperator cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Hash.HashType.LinearHash);
+
+		return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, opcode, str);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		if (input1.getDataType() == Types.DataType.MATRIX) {
+			processMatrixSketch(ec);
+		} else {
+			processTensorSketch(ec);
+		}
+	}
+
+	private void processMatrixSketch(ExecutionContext ec) {
+		SparkExecutionContext sec = (SparkExecutionContext) ec;
+
+		//get input
+		JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
+
+		// dir = RowCol and (dim1() > 1000 || dim2() > 1000)
+		if (aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
+
+			// Create a single sketch and derive approximate count distinct from the sketch
+			JavaRDD<CorrMatrixBlock> out1 = in.map(new AggregateUnarySketchCreateFunction(this.op));
+
+			// Using fold() instead of reduce() for stable aggregation
+			// Instantiating CorrMatrixBlock mutable buffer with empty matrix block so that it can be serialized properly
+			CorrMatrixBlock out2 = out1.fold(new CorrMatrixBlock(new MatrixBlock()),
+				new AggregateUnarySketchUnionAllFunction(this.op));
+
+			MatrixBlock out3 = LibMatrixCountDistinct.countDistinctValuesFromSketch(out2, this.op);
+
+			// put output block into symbol table (no lineage because single block)
+			// this also includes implicit maintenance of matrix characteristics
+			sec.setMatrixOutput(output.getName(), out3);
+		} else {
+
+			if (aggtype != AggBinaryOp.SparkAggType.NONE && aggtype != AggBinaryOp.SparkAggType.MULTI_BLOCK) {
+				throw new DMLRuntimeException(String.format("Unsupported aggregation type: %s", aggtype));
+			}
+
+			JavaPairRDD<MatrixIndexes, MatrixBlock> out1;
+			JavaPairRDD<MatrixIndexes, CorrMatrixBlock> out2;
+			JavaPairRDD<MatrixIndexes, MatrixBlock> out3;
+
+			// dir = Row || Col || RowCol and (dim1() <= 1000 || dim2() <= 1000)
+			if (aggtype == AggBinaryOp.SparkAggType.NONE) {
+				// Input matrix is small enough for a single index, so there is no need to execute index function.
+				// Reuse the CreateCombinerFunction(), although there is no need to merge values within the same
+				// partition, or combiners across partitions for that matter.
+				out2 = in.mapValues(new AggregateUnarySketchCreateCombinerFunction(this.op));
+
+			// aggType = MULTI_BLOCK: dir = Row || Col and (dim1() > 1000 || dim2() > 1000)
+			} else {
+				// Execute index function to group rows/columns together based on aggregation direction
+				out1 = in.mapToPair(new RowColGroupingFunction(this.op));
+
+				// Create sketch per index
+				out2 = out1.combineByKey(new AggregateUnarySketchCreateCombinerFunction(this.op),
+					new AggregateUnarySketchMergeValueFunction(this.op),
+					new AggregateUnarySketchMergeCombinerFunction(this.op));
+			}
+
+			out3 = out2.mapValues(new CalculateAggregateSketchFunction(this.op));
+
+			updateUnaryAggOutputDataCharacteristics(sec, this.op.getIndexFunction());
+
+			// put output RDD handle into symbol table
+			sec.setRDDHandleForVariable(output.getName(), out3);
+			sec.addLineageRDD(output.getName(), input1.getName());
+		}
+	}
+
+	private void processTensorSketch(ExecutionContext ec) {
+		throw new NotImplementedException("Aggregate sketch instruction for tensors has not been implemented yet.");
+	}
+
+	/** Maps every input block to its own sketch (the block index is irrelevant for a full RowCol sketch). */
+	private static class AggregateUnarySketchCreateFunction implements Function<Tuple2<MatrixIndexes, MatrixBlock>, CorrMatrixBlock> {
+		private static final long serialVersionUID = 7295176181965491548L;
+		private final CountDistinctOperator op;
+
+		public AggregateUnarySketchCreateFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public CorrMatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
+			// The output index was previously computed here but never used; only the block matters.
+			return LibMatrixCountDistinct.createSketch(arg0._2(), this.op);
+		}
+	}
+
+	/** Fold function that unions two sketches, treating the fold's empty zero value as identity. */
+	private static class AggregateUnarySketchUnionAllFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
+		private static final long serialVersionUID = -3799519241499062936L;
+		private final CountDistinctOperator op;
+
+		public AggregateUnarySketchUnionAllFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
+
+			// Spark's fold() seeds each partition AND the driver-side accumulator with the zero value. An empty
+			// partition therefore returns the bare zero value, and merging it into the still-untouched driver
+			// accumulator combines two zero values; that is legitimate and must not be reported as corrupt.
+			if (isFoldIdentity(arg0) && isFoldIdentity(arg1)) {
+				return arg0;
+			}
+
+			// Input matrix blocks must have corresponding sketch metadata
+			if (arg0.getCorrection() == null && arg1.getCorrection() == null) {
+				throw new DMLRuntimeException("Corrupt sketch: metadata is missing");
+			}
+
+			if ((arg0.getValue().getNumRows() == 0 && arg0.getValue().getNumColumns() == 0) || arg0.getCorrection() == null) {
+				arg0.set(arg1.getValue(), arg1.getCorrection());
+				return arg0;
+			} else if ((arg1.getValue().getNumRows() == 0 && arg1.getValue().getNumColumns() == 0) || arg1.getCorrection() == null) {
+				return arg0;
+			}
+
+			return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+		}
+
+		/** True for the fold zero value: an empty 0x0 block without sketch metadata. */
+		private static boolean isFoldIdentity(CorrMatrixBlock blk) {
+			MatrixBlock value = blk.getValue();
+			return blk.getCorrection() == null && value != null
+				&& value.getNumRows() == 0 && value.getNumColumns() == 0;
+		}
+	}
+
+	/** Re-keys each block by the output index of the configured index function (no sketch yet). */
+	private static class RowColGroupingFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+
+		private static final long serialVersionUID = -3456633769452405482L;
+		private final CountDistinctOperator _op;
+
+		public RowColGroupingFunction(CountDistinctOperator op) {
+			this._op = op;
+		}
+
+		@Override
+		public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
+			MatrixIndexes idxIn = arg0._1();
+			MatrixBlock blkIn = arg0._2();
+
+			MatrixIndexes idxOut = new MatrixIndexes();
+			MatrixBlock blkOut = blkIn; // Do not create sketch yet
+			this._op.getIndexFunction().execute(idxIn, idxOut);
+
+			return new Tuple2<>(idxOut, blkOut);
+		}
+	}
+
+	/** combineByKey: creates the initial sketch for the first block seen for a key. */
+	private static class AggregateUnarySketchCreateCombinerFunction implements Function<MatrixBlock, CorrMatrixBlock>
+	{
+		private static final long serialVersionUID = 8997980606986435297L;
+		private final CountDistinctOperator op;
+
+		private AggregateUnarySketchCreateCombinerFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public CorrMatrixBlock call(MatrixBlock arg0)
+			throws Exception {
+
+			return LibMatrixCountDistinct.createSketch(arg0, this.op);
+		}
+	}
+
+	/** combineByKey: sketches a further block and unions it into the partition-local sketch. */
+	private static class AggregateUnarySketchMergeValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock>
+	{
+		private static final long serialVersionUID = -7006864809860460549L;
+		private final CountDistinctOperator op;
+
+		public AggregateUnarySketchMergeValueFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
+			CorrMatrixBlock arg1WithCorr = LibMatrixCountDistinct.createSketch(arg1, this.op);
+			return LibMatrixCountDistinct.unionSketch(arg0, arg1WithCorr, this.op);
+		}
+	}
+
+	/** combineByKey: unions partition-local sketches across partitions. */
+	private static class AggregateUnarySketchMergeCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock>
+	{
+		private static final long serialVersionUID = 172215143740379070L;
+		private final CountDistinctOperator op;
+
+		public AggregateUnarySketchMergeCombinerFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
+			return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+		}
+	}
+
+	/** Converts a final sketch into the distinct-count result block. */
+	private static class CalculateAggregateSketchFunction implements Function<CorrMatrixBlock, MatrixBlock>
+	{
+		private static final long serialVersionUID = 7504873483231717138L;
+		private final CountDistinctOperator op;
+
+		public CalculateAggregateSketchFunction(CountDistinctOperator op) {
+			this.op = op;
+		}
+
+		@Override
+		public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
+			return LibMatrixCountDistinct.countDistinctValuesFromSketch(arg0, this.op);
+		}
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
index c9935b1..830ba4d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
@@ -37,7 +37,7 @@ public abstract class SPInstruction extends Instruction {
CentralMoment, Covariance, QSort, QPick,
ParameterizedBuiltin, MAppend, RAppend, GAppend,
GAlignedAppend, Rand,
MatrixReshape, Ctable, Quaternary, CumsumAggregate,
CumsumOffset, BinUaggChain, UaggOuterChain,
- Write, SpoofFused, Dnn
+ Write, SpoofFused, Dnn, AggregateUnarySketch
}
protected final SPType _sptype;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index fa95aee..4b13abc 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -19,9 +19,7 @@
package org.apache.sysds.runtime.matrix.data;
-import java.util.Collections;
import java.util.HashSet;
-import java.util.PriorityQueue;
import java.util.Set;
import org.apache.commons.lang.NotImplementedException;
@@ -32,16 +30,17 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
-import org.apache.sysds.utils.Hash;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.utils.Hash.HashType;
/**
* This class contains various methods for counting the number of distinct
values inside a MatrixBlock
*/
-public class LibMatrixCountDistinct {
- private static final Log LOG =
LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+public interface LibMatrixCountDistinct {
+ static final Log LOG =
LogFactory.getLog(LibMatrixCountDistinct.class.getName());
/**
* The minimum number NonZero of cells in the input before using
approximate techniques for counting number of
@@ -49,10 +48,6 @@ public class LibMatrixCountDistinct {
*/
public static int minimumSize = 1024;
- private LibMatrixCountDistinct() {
- // Prevent instantiation via private constructor.
- }
-
/**
* Public method to count the number of distinct values inside a
matrix. Depending on which CountDistinctOperator
* selected it either gets the absolute number or a estimated value.
@@ -61,7 +56,7 @@ public class LibMatrixCountDistinct {
*
* TODO: Add support for distributed spark operations
*
- * TODO: If the MatrixBlock type is CompressedMatrix, simply read the
vaules from the ColGroups.
+ * TODO: If the MatrixBlock type is CompressedMatrix, simply read the
values from the ColGroups.
*
* @param in the input matrix to count number distinct values in
* @param op the selected operator to use
@@ -69,11 +64,12 @@ public class LibMatrixCountDistinct {
*/
public static int estimateDistinctValues(MatrixBlock in,
CountDistinctOperator op) {
int res = 0;
- if(op.operatorType == CountDistinctTypes.KMV &&
- (op.hashType == HashType.ExpHash || op.hashType ==
HashType.StandardJava)) {
- throw new DMLException("Invalid hashing configuration
using " + op.hashType + " and " + op.operatorType);
+ if(op.getOperatorType() == CountDistinctOperatorTypes.KMV &&
+ (op.getHashType() == HashType.ExpHash ||
op.getHashType() == HashType.StandardJava)) {
+ throw new DMLException(
+ "Invalid hashing configuration using " +
op.getHashType() + " and " + op.getOperatorType());
}
- else if(op.operatorType == CountDistinctTypes.HLL) {
+ else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
{
throw new NotImplementedException("HyperLogLog not
implemented");
}
// shortcut in simplest case.
@@ -84,27 +80,27 @@ public class LibMatrixCountDistinct {
res = countDistinctValuesNaive(in);
}
else {
- switch(op.operatorType) {
+ switch(op.getOperatorType()) {
case COUNT:
res = countDistinctValuesNaive(in);
break;
case KMV:
- res = countDistinctValuesKVM(in, op);
+ res = new
KMVSketch(op).getScalarValue(in);
break;
default:
throw new DMLException("Invalid or not
implemented Estimator Type");
}
}
- if(res == 0)
+ if(res <= 0)
throw new DMLRuntimeException("Impossible estimate of
distinct values");
return res;
}
/**
- * Naive implementation of counting Distinct values.
+ * Naive implementation of counting distinct values.
*
- * Benefit Precise, but uses memory, on the scale of inputs number of
distinct values.
+ * Benefit: precise, but uses memory, on the scale of inputs number of
distinct values.
*
* @param in The input matrix to count number distinct values in
* @return The absolute distinct count
@@ -151,155 +147,35 @@ public class LibMatrixCountDistinct {
}
private static Set<Double> countDistinctValuesNaive(double[]
valuesPart, Set<Double> distinct) {
- for(double v : valuesPart) {
+ for(double v : valuesPart)
distinct.add(v);
- }
return distinct;
}
- /**
- * KMV synopsis(for k minimum values) Distinct-Value Estimation
- *
- * Kevin S. Beyer, Peter J. Haas, Berthold Reinwald, Yannis Sismanis,
Rainer Gemulla:
- *
- * On synopses for distinct‐value estimation under multiset operations.
SIGMOD 2007
- *
- * TODO: Add multi-threaded version
- *
- * @param in The Matrix Block to estimate the number of distinct values
in
- * @return The distinct count estimate
- */
- private static int countDistinctValuesKVM(MatrixBlock in,
CountDistinctOperator op) {
-
- // D is the number of possible distinct values in the
MatrixBlock.
- // plus 1 to take account of 0 input.
- long D = in.getNonZeros() + 1;
-
- /**
- * To ensure that the likelihood to hash to the same value we
need O(D^2) positions to hash to assign. If the
- * value is higher than int (which is the area we hash to) then
use Integer Max value as largest hashing space.
- */
- long tmp = D * D;
- int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE :
(int) tmp;
- LOG.debug("M not forced to int size: " + tmp);
- LOG.debug("M: " + M);
- /**
- * The estimator is asymptotically unbiased as k becomes large,
but memory usage also scales with k. Furthermore k
- * value must be within range: D >> k >> 0
- */
- int k = D > 64 ? 64 : (int) D;
- SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
-
- countDistinctValuesKVM(in, op.hashType, k, spq, M);
-
- LOG.debug("M: " + M);
- LOG.debug("smallest hash:" + spq.peek());
- LOG.debug("spq: " + spq.toString());
-
- if(spq.size() < k) {
- return spq.size();
- }
- else {
- double U_k = (double) spq.poll() / (double) M;
- LOG.debug("U_k : " + U_k);
- double estimate = (double) (k - 1) / U_k;
- LOG.debug("Estimate: " + estimate);
- double ceilEstimate = Math.min(estimate, (double) D);
- LOG.debug("Ceil worst case: " + D);
- return (int) ceilEstimate;
- }
- }
-
- private static void countDistinctValuesKVM(MatrixBlock in, HashType
hashType, int k, SmallestPriorityQueue spq,
- int m) {
- double[] data;
- if(in.isEmpty())
- spq.add(0);
- else if(in instanceof CompressedMatrixBlock)
- throw new NotImplementedException();
- else if(in.sparseBlock != null) {
- SparseBlock sb = in.sparseBlock;
- if(in.sparseBlock.isContiguous()) {
- data = sb.values(0);
- countDistinctValuesKVM(data, hashType, k, spq,
m);
- }
- else {
- for(int i = 0; i < in.getNumRows(); i++) {
- if(!sb.isEmpty(i)) {
- data = in.sparseBlock.values(i);
- countDistinctValuesKVM(data,
hashType, k, spq, m);
- }
- }
- }
- }
- else {
- DenseBlock db = in.denseBlock;
- final int bil = db.index(0);
- final int biu = db.index(in.rlen);
- for(int i = bil; i <= biu; i++) {
- data = db.valuesAt(i);
- countDistinctValuesKVM(data, hashType, k, spq,
m);
- }
- }
+	/**
+	 * Turns a finished sketch into its distinct-count result block.
+	 *
+	 * @param arg0 the sketch (value plus sketch metadata)
+	 * @param op   operator selecting the estimator; only KMV is currently supported
+	 * @return the estimated distinct counts
+	 */
+	public static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
+		// HLL and every other estimator share the same "not implemented" outcome
+		final boolean isKmv = op.getOperatorType() == CountDistinctOperatorTypes.KMV;
+		if(!isKmv)
+			throw new NotImplementedException("Not implemented yet");
+		return new KMVSketch(op).getMatrixValue(arg0);
}
- private static void countDistinctValuesKVM(double[] data, HashType
hashType, int k, SmallestPriorityQueue spq,
- int m) {
- for(double fullValue : data) {
- int hash = Hash.hash(fullValue, hashType);
- int v = (Math.abs(hash)) % (m - 1) + 1;
- spq.add(v);
- }
+	/**
+	 * Builds the initial sketch for a single input block.
+	 *
+	 * @param blkIn the block to sketch
+	 * @param op    operator selecting the estimator; only KMV is currently supported
+	 * @return the sketch of {@code blkIn}
+	 */
+	public static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
+		// HLL and every other estimator share the same "not implemented" outcome
+		final boolean isKmv = op.getOperatorType() == CountDistinctOperatorTypes.KMV;
+		if(!isKmv)
+			throw new NotImplementedException("Not implemented yet");
+		return new KMVSketch(op).create(blkIn);
}
- /**
- * Deceiving name, but is used to contain the k smallest values
inserted.
- *
- * TODO: add utility method to join two partitions
- *
- * TODO: Replace Standard Java Set and Priority Queue with optimized
versions.
- */
- private static class SmallestPriorityQueue {
- private Set<Integer> containedSet;
- private PriorityQueue<Integer> smallestHashes;
- private int k;
-
- public SmallestPriorityQueue(int k) {
- smallestHashes = new PriorityQueue<>(k,
Collections.reverseOrder());
- containedSet = new HashSet<>(1);
- this.k = k;
- }
-
- public void add(int v) {
- if(!containedSet.contains(v)) {
- if(smallestHashes.size() < k) {
- smallestHashes.add(v);
- containedSet.add(v);
- }
- else if(v < smallestHashes.peek()) {
- LOG.trace(smallestHashes.peek() + " --
" + v);
- smallestHashes.add(v);
- containedSet.add(v);
-
containedSet.remove(smallestHashes.poll());
- }
- }
- }
-
- public int size() {
- return smallestHashes.size();
- }
-
- public int peek() {
- return smallestHashes.peek();
- }
-
- public int poll() {
- return smallestHashes.poll();
- }
-
- @Override
- public String toString() {
- return smallestHashes.toString();
- }
+	/**
+	 * Merges two sketches into one sketch covering both inputs.
+	 *
+	 * @param arg0 first sketch
+	 * @param arg1 second sketch
+	 * @param op   operator selecting the estimator; only KMV is currently supported
+	 * @return the combined sketch
+	 */
+	public static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
+		// HLL and every other estimator share the same "not implemented" outcome
+		final boolean isKmv = op.getOperatorType() == CountDistinctOperatorTypes.KMV;
+		if(!isKmv)
+			throw new NotImplementedException("Not implemented yet");
+		return new KMVSketch(op).union(arg0, arg1);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
new file mode 100644
index 0000000..f9c5f63
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch;
+
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Common contract of sketches that summarize matrix blocks and can be combined (e.g. for approximate
+ * distinct counting), both locally and across distributed partitions.
+ *
+ * @param <T> type of the scalar value derived from a sketch
+ */
+public interface MatrixSketch<T> {
+
+	/**
+	 * Get the scalar distinct count of an input matrix block.
+	 *
+	 * @param blkIn An input block to estimate the number of distinct values in
+	 * @return The distinct count estimate
+	 */
+	T getScalarValue(MatrixBlock blkIn);
+
+	/**
+	 * Obtain the matrix of distinct count values from a sketch. Used for estimating the number of distinct
+	 * values in rows or columns.
+	 *
+	 * @param blkIn The sketch block to extract the count from
+	 * @return The result matrix block
+	 */
+	MatrixBlock getMatrixValue(CorrMatrixBlock blkIn);
+
+	/**
+	 * Create an initial sketch of a given block.
+	 *
+	 * @param blkIn A block to process
+	 * @return A sketch
+	 */
+	CorrMatrixBlock create(MatrixBlock blkIn);
+
+	/**
+	 * Union two sketches together to form a combined sketch.
+	 *
+	 * @param arg0 Sketch one
+	 * @param arg1 Sketch two
+	 * @return The combined sketch
+	 */
+	CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+
+	/**
+	 * Intersect two sketches.
+	 *
+	 * @param arg0 Sketch one
+	 * @param arg1 Sketch two
+	 * @return The intersected sketch
+	 */
+	CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
new file mode 100644
index 0000000..9893e09
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * Common base of the approximate distinct-count sketches. The class is package private, so only sketches in this
+ * package can extend it.
+ */
+abstract class CountDistinctApproxSketch implements MatrixSketch<Integer> {
+	/** Operator this sketch works for; the constructor guarantees it has a supported, non-null direction. */
+	CountDistinctOperator op;
+
+	/**
+	 * @param op operator to sketch for
+	 * @throws DMLRuntimeException if op is not a CountDistinctOperator, or its direction is missing or unsupported
+	 */
+	CountDistinctApproxSketch(Operator op) {
+		this.op = requireCountDistinctOperator(op);
+		requireSupportedDirection(this.op);
+	}
+
+	/** Narrow a generic operator to a CountDistinctOperator, rejecting any other operator type. */
+	private static CountDistinctOperator requireCountDistinctOperator(Operator candidate) {
+		if(candidate instanceof CountDistinctOperator)
+			return (CountDistinctOperator) candidate;
+		throw new DMLRuntimeException(
+			String.format("Cannot create %s with given operator", CountDistinctApproxSketch.class.getSimpleName()));
+	}
+
+	/** Reject operators without a direction, or with a direction other than row, column or row-and-column. */
+	private static void requireSupportedDirection(CountDistinctOperator cdOp) {
+		if(cdOp.getDirection() == null)
+			throw new DMLRuntimeException("No direction was set for the operator");
+
+		boolean supported = cdOp.getDirection().isRow() || cdOp.getDirection().isCol() ||
+			cdOp.getDirection().isRowCol();
+		if(!supported)
+			throw new DMLRuntimeException(String.format("Unexpected direction: %s", cdOp.getDirection()));
+	}
+
+	/**
+	 * Sanity-check sketch metadata. Only the first row, laid out as (nHashes, k, D), is inspected.
+	 *
+	 * @param corrBlock metadata (correction) block of a sketch
+	 * @throws DMLRuntimeException if the block has fewer than 3 columns or any of the 3 values is negative
+	 */
+	protected void validateSketchMetadata(MatrixBlock corrBlock) {
+		if(corrBlock.getNumColumns() < 3)
+			throw new DMLRuntimeException("Sketch metadata is corrupt");
+		for(int col = 0; col < 3; col++) {
+			if(corrBlock.getValue(0, col) < 0)
+				throw new DMLRuntimeException("Sketch metadata is corrupt");
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
new file mode 100644
index 0000000..01cfb28
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Hash;
+
+/**
+ * KMV synopsis (for k minimum values) Distinct-Value Estimation
+ *
+ * Kevin S. Beyer, Peter J. Haas, Berthold Reinwald, Yannis Sismanis, Rainer Gemulla:
+ *
+ * On synopses for distinct-value estimation under multiset operations. SIGMOD 2007
+ *
+ * Every cell value is hashed into [1, M - 1] and the k smallest distinct hashes are retained. The sketch is exact
+ * when no hash had to be evicted (fewer than k hashes seen, or k already covers every possible distinct value).
+ * Otherwise the distinct count is estimated as (k - 1) / U_k, where U_k is the k-th smallest hash divided by M.
+ *
+ * Sketches built by {@link #create(MatrixBlock)} store the retained hashes in the value block, largest retained hash
+ * first. They are stored per row for Row/RowCol and per column for Col. The correction block holds one
+ * (nHashes, k, D) metadata row per sketch.
+ *
+ * NOTE(review): M is derived from each slice's own D, so sketches of different blocks hash into different spaces,
+ * yet union() compares their hashes directly. Confirm whether a shared hash space is required for accurate unions.
+ *
+ * TODO: Add multi-threaded version
+ */
+public class KMVSketch extends CountDistinctApproxSketch {
+
+	private static final Log LOG = LogFactory.getLog(KMVSketch.class.getName());
+
+	/** Upper bound on k: the estimator improves as k grows, but so does the memory of every sketch. */
+	private static final int MAX_K = 64;
+
+	/** Largest D whose square still fits into an int (46340^2 <= Integer.MAX_VALUE < 46341^2). */
+	private static final long MAX_D_WITH_INT_SQUARE = 46340;
+
+	public KMVSketch(Operator op) {
+		super(op);
+	}
+
+	@Override
+	public Integer getScalarValue(MatrixBlock in) {
+		// D is an upper bound on the number of distinct values in the MatrixBlock:
+		// every non-zero may be distinct, plus 1 to take account of 0 input.
+		long D = in.getNonZeros() + 1;
+		int M = getHashSpace(D);
+		// The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k.
+		// Furthermore k must be within range: D >> k >> 0
+		int k = getK(D);
+
+		SmallestPriorityQueue spq = getKSmallestHashes(in, k, M);
+
+		if(LOG.isDebugEnabled()) {
+			LOG.debug("M: " + M);
+			LOG.debug("kth smallest hash:" + (spq.isEmpty() ? "n/a" : spq.peek()));
+			LOG.debug("spq: " + spq.toString());
+		}
+
+		// No hash was evicted, so the sketch holds every distinct hash and its size is the count. This also covers
+		// the empty block (D = k = 1), whose single placeholder would otherwise produce 0 / 0 = NaN, i.e. 0.
+		if(spq.size() < k || k >= D)
+			return spq.size();
+
+		double kthSmallestHash = spq.poll();
+		double U_k = kthSmallestHash / (double) M;
+		double estimate = (double) (k - 1) / U_k;
+		double ceilEstimate = Math.min(estimate, (double) D);
+
+		if(LOG.isDebugEnabled()) {
+			LOG.debug("U_k : " + U_k);
+			LOG.debug("Estimate: " + estimate);
+			LOG.debug("Ceil worst case: " + D);
+		}
+		return (int) ceilEstimate;
+	}
+
+	/** Number of hashes to retain when at most D distinct values are possible: min(D, MAX_K). */
+	private static int getK(long D) {
+		return D > MAX_K ? MAX_K : (int) D;
+	}
+
+	/**
+	 * Size M of the hash space for at most D distinct values. To keep the likelihood of two values hashing to the
+	 * same slot low, O(D^2) slots are used, capped at Integer.MAX_VALUE (the range of the int hash). D is compared
+	 * instead of D * D so the square cannot overflow.
+	 */
+	private static int getHashSpace(long D) {
+		return D > MAX_D_WITH_INT_SQUARE ? Integer.MAX_VALUE : (int) (D * D);
+	}
+
+	private SmallestPriorityQueue getKSmallestHashes(MatrixBlock in, int k, int M) {
+		SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+		countDistinctValuesKMV(in, op.getHashType(), spq, M);
+		return spq;
+	}
+
+	/** Offer the hash of every cell of the block (zeros included) to spq. */
+	private void countDistinctValuesKMV(MatrixBlock in, Hash.HashType hashType, SmallestPriorityQueue spq, int m) {
+		if(in.isEmpty()) {
+			// Only the value 0 is present. A placeholder is added instead of a hash because m can be 1 here (D = 1),
+			// which leaves no valid hash slot.
+			spq.add(0);
+		}
+		else if(in instanceof CompressedMatrixBlock)
+			throw new NotImplementedException("Cannot approximate distinct count for compressed matrices");
+		else if(in.getSparseBlock() != null) {
+			SparseBlock sb = in.getSparseBlock();
+			// Only [pos, pos + size) of each row holds values; the backing arrays can have spare capacity.
+			for(int r = 0; r < in.getNumRows(); r++) {
+				if(!sb.isEmpty(r))
+					countDistinctValuesKMV(sb.values(r), sb.pos(r), sb.size(r), hashType, spq, m);
+			}
+			// Sparse blocks do not store zeros: hash 0 once if any cell is zero, as the dense path does.
+			if(in.getNonZeros() < (long) in.getNumRows() * in.getNumColumns())
+				spq.add(hashToSpace(0.0, hashType, m));
+		}
+		else {
+			// Read the logical cells row by row. valuesAt() exposes whole backing arrays, which may be larger than
+			// rows * cols. Also, index(getNumRows()) can point past the last block of multi-block storage.
+			DenseBlock db = in.getDenseBlock();
+			int nCols = in.getNumColumns();
+			for(int r = 0; r < in.getNumRows(); r++)
+				countDistinctValuesKMV(db.values(r), db.pos(r), nCols, hashType, spq, m);
+		}
+	}
+
+	/** Offer the hashes of data[offset, offset + length) to spq. */
+	private void countDistinctValuesKMV(double[] data, int offset, int length, Hash.HashType hashType,
+		SmallestPriorityQueue spq, int m) {
+		for(int i = offset; i < offset + length; i++)
+			spq.add(hashToSpace(data[i], hashType, m));
+	}
+
+	/**
+	 * Map the hash of a value into [1, m - 1]; requires m >= 2. The absolute value is taken in long arithmetic
+	 * because Math.abs(Integer.MIN_VALUE) is still negative in int arithmetic.
+	 */
+	private static int hashToSpace(double value, Hash.HashType hashType, int m) {
+		long hash = Math.abs((long) Hash.hash(value, hashType));
+		return (int) (hash % (m - 1) + 1);
+	}
+
+	@Override
+	public MatrixBlock getMatrixValue(CorrMatrixBlock arg0) {
+		MatrixBlock blkIn = arg0.getValue();
+		if(op.getDirection() == Types.Direction.Row) {
+			// N x 1 blkOut -> slice out the first column of the matrix
+			MatrixBlock blkOut = blkIn.slice(0, blkIn.getNumRows() - 1, 0, 0);
+			for(int i = 0; i < blkIn.getNumRows(); ++i)
+				getDistinctCountFromSketchByIndex(arg0, i, blkOut);
+			return blkOut;
+		}
+		else if(op.getDirection() == Types.Direction.Col) {
+			// 1 x N blkOut -> slice out the first row of the matrix
+			MatrixBlock blkOut = blkIn.slice(0, 0, 0, blkIn.getNumColumns() - 1);
+			for(int j = 0; j < blkIn.getNumColumns(); ++j)
+				getDistinctCountFromSketchByIndex(arg0, j, blkOut);
+			return blkOut;
+		}
+		else { // op.getDirection().isRowCol()
+			// 1 x 1 blkOut -> slice out the first row and column of the matrix
+			MatrixBlock blkOut = blkIn.slice(0, 0, 0, 0);
+			getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
+			return blkOut;
+		}
+	}
+
+	/** Write the distinct count of sketch idx into position idx of blkOut. */
+	private void getDistinctCountFromSketchByIndex(CorrMatrixBlock arg0, int idx, MatrixBlock blkOut) {
+		if(op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
+			// Without this check the raw hash copied into blkOut would be reported as the count.
+			throw new DMLRuntimeException(String.format("%s cannot compute a distinct count for operator type %s",
+				KMVSketch.class.getSimpleName(), op.getOperatorType()));
+		}
+
+		MatrixBlock blkIn = arg0.getValue();
+		MatrixBlock blkInCorr = arg0.getCorrection();
+		boolean rowWise = op.getDirection().isRow() || op.getDirection().isRowCol();
+
+		// The largest retained hash (the k-th smallest once the sketch is full) is stored first.
+		double kthSmallestHash = rowWise ? blkIn.getValue(idx, 0) : blkIn.getValue(0, idx);
+
+		double nHashes = blkInCorr.getValue(idx, 0);
+		double k = blkInCorr.getValue(idx, 1);
+		double D = blkInCorr.getValue(idx, 2);
+		double M = getHashSpace((long) D);
+
+		double ceilEstimate;
+		if(nHashes == 0) {
+			// Nothing was hashed (empty slice): only the value 0 is present.
+			ceilEstimate = 1;
+		}
+		else if(nHashes < k || k >= D) {
+			// No hash was evicted, so the sketch holds every distinct hash.
+			ceilEstimate = nHashes;
+		}
+		else {
+			double U_k = kthSmallestHash / M;
+			double estimate = (k - 1) / U_k;
+			ceilEstimate = Math.min(estimate, D);
+		}
+
+		if(rowWise)
+			blkOut.setValue(idx, 0, ceilEstimate);
+		else // op.getDirection().isCol()
+			blkOut.setValue(0, idx, ceilEstimate);
+	}
+
+	// Create sketch
+	@Override
+	public CorrMatrixBlock create(MatrixBlock blkIn) {
+		// The correction block holds one (nHashes, k, D) metadata row per sketch: O(N) extra space.
+		if(op.getDirection().isRowCol()) {
+			// A single 1 x k sketch for the whole block.
+			MatrixBlock blkOut = new MatrixBlock(1, getK(blkIn.getNonZeros() + 1), false);
+			MatrixBlock blkOutCorr = new MatrixBlock(1, 3, false);
+			createSketchByIndex(blkIn, blkOut, blkOutCorr, 0);
+			return new CorrMatrixBlock(blkOut, blkOutCorr);
+		}
+		else if(op.getDirection().isRow()) {
+			// One sketch per row. It is written into a copy so the caller's block is left untouched.
+			MatrixBlock blkOut = new MatrixBlock(blkIn);
+			MatrixBlock blkOutCorr = new MatrixBlock(blkIn.getNumRows(), 3, false);
+			for(int i = 0; i < blkIn.getNumRows(); ++i)
+				createSketchByIndex(blkIn, blkOut, blkOutCorr, i);
+			return new CorrMatrixBlock(blkOut, blkOutCorr);
+		}
+		else if(op.getDirection().isCol()) {
+			// One sketch per column. It is written into a copy so the caller's block is left untouched.
+			MatrixBlock blkOut = new MatrixBlock(blkIn);
+			MatrixBlock blkOutCorr = new MatrixBlock(blkIn.getNumColumns(), 3, false);
+			for(int j = 0; j < blkIn.getNumColumns(); ++j)
+				createSketchByIndex(blkIn, blkOut, blkOutCorr, j);
+			return new CorrMatrixBlock(blkOut, blkOutCorr);
+		}
+		else {
+			throw new DMLRuntimeException(String.format("Unexpected direction: %s", op.getDirection()));
+		}
+	}
+
+	private MatrixBlock sliceMatrixBlockByIndexDirection(MatrixBlock blkIn, int idx) {
+		if(op.getDirection().isRow())
+			return blkIn.slice(idx, idx);
+		else if(op.getDirection().isCol())
+			return blkIn.slice(0, blkIn.getNumRows() - 1, idx, idx);
+		else
+			return blkIn;
+	}
+
+	/**
+	 * Sketch slice idx of blkIn: row idx for Row, column idx for Col, or the whole block for RowCol. The retained
+	 * hashes are written into sketchMB at position idx. The (nHashes, k, D) metadata goes into row idx of
+	 * sketchMetaMB.
+	 */
+	private void createSketchByIndex(MatrixBlock blkIn, MatrixBlock sketchMB, MatrixBlock sketchMetaMB, int idx) {
+		MatrixBlock blkInSlice = sliceMatrixBlockByIndexDirection(blkIn, idx);
+		long D = blkInSlice.getNonZeros() + 1;
+		int M = getHashSpace(D);
+		int k = getK(D);
+
+		if(blkInSlice.isEmpty()) {
+			// There can only be 1 distinct value (zero); getMatrixValue() returns 1 when nHashes = 0.
+			setSketchMetadata(sketchMetaMB, idx, 0, k, D);
+			return;
+		}
+		// A non-empty 1x1 slice is hashed like any other slice, so unions with neighboring sketches still see its value.
+
+		SmallestPriorityQueue spq = getKSmallestHashes(blkInSlice, k, M);
+		int nHashes = spq.size();
+
+		// Polling yields the retained hashes from largest to smallest. The slice has at least nHashes cells, so the
+		// hashes fit into its row (or column) of sketchMB.
+		for(int i = 0; !spq.isEmpty(); i++) {
+			double toInsert = spq.poll();
+			if(op.getDirection().isCol())
+				sketchMB.setValue(i, idx, toInsert);
+			else
+				sketchMB.setValue(idx, i, toInsert);
+		}
+
+		setSketchMetadata(sketchMetaMB, idx, nHashes, k, D);
+	}
+
+	/** Write one (nHashes, k, D) metadata row. */
+	private static void setSketchMetadata(MatrixBlock meta, int idx, double nHashes, double k, double D) {
+		meta.setValue(idx, 0, nHashes);
+		meta.setValue(idx, 1, k);
+		meta.setValue(idx, 2, D);
+	}
+
+	@Override
+	public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+		// Both matrices are guaranteed to be row-/column-aligned. The larger input is reused as the output block
+		// because it has the most room for hashes. Note that this overwrites that input's sketch values.
+		MatrixBlock matrix0 = arg0.getValue();
+		MatrixBlock matrix1 = arg1.getValue();
+
+		if(op.getDirection().isCol()) {
+			// The number of columns always matches, so use the taller input.
+			MatrixBlock combined = matrix0.getNumRows() > matrix1.getNumRows() ? matrix0 : matrix1;
+			CorrMatrixBlock blkOut = new CorrMatrixBlock(combined, new MatrixBlock(matrix0.getNumColumns(), 3, false));
+			for(int j = 0; j < matrix0.getNumColumns(); ++j)
+				unionSketchByIndex(arg0, arg1, j, blkOut);
+			return blkOut;
+		}
+
+		// Row and RowCol: the number of rows always matches, so use the wider input.
+		MatrixBlock combined = matrix0.getNumColumns() > matrix1.getNumColumns() ? matrix0 : matrix1;
+		int nSketches = op.getDirection().isRow() ? matrix0.getNumRows() : 1;
+		CorrMatrixBlock blkOut = new CorrMatrixBlock(combined, new MatrixBlock(nSketches, 3, false));
+		for(int i = 0; i < nSketches; ++i)
+			unionSketchByIndex(arg0, arg1, i, blkOut);
+		return blkOut;
+	}
+
+	/** Union sketch idx of arg0 and arg1 into position idx of blkOut (and row idx of its metadata). */
+	public void unionSketchByIndex(CorrMatrixBlock arg0, CorrMatrixBlock arg1, int idx, CorrMatrixBlock blkOut) {
+		MatrixBlock corr0 = arg0.getCorrection();
+		MatrixBlock corr1 = arg1.getCorrection();
+
+		validateSketchMetadata(corr0);
+		validateSketchMetadata(corr1);
+
+		// Both matrices are guaranteed to be row-/column-aligned
+		MatrixBlock matrix0 = arg0.getValue();
+		MatrixBlock matrix1 = arg1.getValue();
+
+		if((op.getDirection().isRow() && matrix0.getNumRows() != matrix1.getNumRows()) ||
+			(op.getDirection().isCol() && matrix0.getNumColumns() != matrix1.getNumColumns())) {
+			throw new DMLRuntimeException("Cannot take the union of sketches: rows/columns are not aligned");
+		}
+
+		MatrixBlock combined = blkOut.getValue();
+		MatrixBlock combinedCorr = blkOut.getCorrection();
+		boolean colWise = op.getDirection().isCol();
+
+		double nHashes0 = corr0.getValue(idx, 0);
+		double k0 = corr0.getValue(idx, 1);
+		double D0 = corr0.getValue(idx, 2);
+
+		double nHashes1 = corr1.getValue(idx, 0);
+		double k1 = corr1.getValue(idx, 1);
+		double D1 = corr1.getValue(idx, 2);
+
+		double D = D0 + D1 - 1;
+
+		// Keep the k smallest hashes of the combined set. Sizing the queue by k, not by the inputs' hash counts,
+		// keeps the distinct hashes of both inputs. It also avoids a zero-sized queue when neither input has hashes.
+		// k is capped by the room available in the output block.
+		int room = colWise ? combined.getNumRows() : combined.getNumColumns();
+		int k = (int) Math.min(Math.max(k0, k1), room);
+
+		// Read both inputs completely before writing: combined is one of them.
+		SmallestPriorityQueue hashUnion = new SmallestPriorityQueue(k);
+		addSketchHashes(matrix0, idx, (int) nHashes0, hashUnion);
+		addSketchHashes(matrix1, idx, (int) nHashes1, hashUnion);
+		int nHashes = hashUnion.size();
+
+		for(int i = 0; !hashUnion.isEmpty(); i++) {
+			double val = hashUnion.poll();
+			if(colWise)
+				combined.setValue(i, idx, val);
+			else
+				combined.setValue(idx, i, val);
+		}
+
+		setSketchMetadata(combinedCorr, idx, nHashes, k, D);
+	}
+
+	/** Offer the first nHashes hashes stored at position idx of sketch to target. */
+	private void addSketchHashes(MatrixBlock sketch, int idx, int nHashes, SmallestPriorityQueue target) {
+		boolean colWise = op.getDirection().isCol();
+		for(int i = 0; i < nHashes; i++)
+			target.add(colWise ? sketch.getValue(i, idx) : sketch.getValue(idx, i));
+	}
+
+	@Override
+	public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+		throw new NotImplementedException(
+			String.format("%s intersection has not been implemented yet", KMVSketch.class.getSimpleName()));
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
new file mode 100644
index 0000000..0a29028
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.PriorityQueue;
+import java.util.Set;
+
+/**
+ * Retains the k smallest distinct values added to it. Despite the name, it is not a general priority queue.
+ *
+ * The values are kept in a max-heap, so {@link #peek()} and {@link #poll()} return the largest retained value. Once
+ * the queue is full, that is the k-th smallest value seen. Duplicate values are ignored.
+ *
+ * TODO: Replace Standard Java Set and Priority Queue with optimized versions.
+ */
+public class SmallestPriorityQueue {
+	private static final Log LOG = LogFactory.getLog(SmallestPriorityQueue.class.getName());
+
+	/** The retained values, for constant-time duplicate checks. */
+	private final Set<Double> containedSet;
+	/** Max-heap of the retained values; its head is the largest retained value. */
+	private final PriorityQueue<Double> smallestHashes;
+	/** Maximum number of values to retain. */
+	private final int k;
+
+	/**
+	 * @param k maximum number of values to retain; with k <= 0 every added value is ignored
+	 */
+	public SmallestPriorityQueue(int k) {
+		// PriorityQueue rejects an initial capacity below 1, but an always-empty queue is a valid request.
+		int capacity = Math.max(1, k);
+		smallestHashes = new PriorityQueue<>(capacity, Collections.reverseOrder());
+		// Sized for capacity + 1 entries (add() briefly holds one extra) without rehashing at the default load factor.
+		containedSet = new HashSet<>((int) ((capacity + 1) / 0.75f) + 1);
+		this.k = k;
+	}
+
+	/**
+	 * Offer a value. It is retained if it is not already present and either fewer than k values are held or it is
+	 * smaller than the largest retained value. In the second case the largest retained value is evicted.
+	 *
+	 * @param v value to offer
+	 */
+	public void add(double v) {
+		if(k <= 0 || containedSet.contains(v))
+			return;
+
+		if(smallestHashes.size() < k) {
+			smallestHashes.add(v);
+			containedSet.add(v);
+		}
+		else if(v < smallestHashes.peek()) {
+			// This runs in the hashing hot loop; only build the message when tracing is enabled.
+			if(LOG.isTraceEnabled())
+				LOG.trace(smallestHashes.peek() + " -- " + v);
+			smallestHashes.add(v);
+			containedSet.add(v);
+			containedSet.remove(smallestHashes.poll());
+		}
+	}
+
+	/** @return the number of retained values */
+	public int size() {
+		return smallestHashes.size();
+	}
+
+	/**
+	 * @return the largest retained value, without removing it
+	 * @throws IllegalStateException if the queue is empty
+	 */
+	public double peek() {
+		return requireValue(smallestHashes.peek());
+	}
+
+	/**
+	 * Remove and return the largest retained value. The value may be retained again by a later {@link #add(double)}.
+	 *
+	 * @return the largest retained value
+	 * @throws IllegalStateException if the queue is empty
+	 */
+	public double poll() {
+		double largest = requireValue(smallestHashes.poll());
+		containedSet.remove(largest);
+		return largest;
+	}
+
+	/** Unbox a heap head, failing with a descriptive error instead of a NullPointerException when empty. */
+	private static double requireValue(Double value) {
+		if(value == null)
+			throw new IllegalStateException("SmallestPriorityQueue is empty");
+		return value;
+	}
+
+	/** @return true if no value is retained */
+	public boolean isEmpty() {
+		return this.size() == 0;
+	}
+
+	@Override
+	public String toString() {
+		return smallestHashes.toString();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
index 3f63ef9..1c430c9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -19,30 +19,28 @@
package org.apache.sysds.runtime.matrix.operators;
+import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
import
org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
import org.apache.sysds.utils.Hash.HashType;
public class CountDistinctOperator extends Operator {
private static final long serialVersionUID = 7615123453265129670L;
- public final CountDistinctTypes operatorType;
- public final HashType hashType;
-
- public enum CountDistinctTypes { // The different supported types of
counting.
- COUNT, // Baseline naive implementation, iterate though, add to
hashMap.
- KMV, // K-Minimum Values algorithm.
- HLL // HyperLogLog algorithm.
- }
+ private final CountDistinctOperatorTypes operatorType;
+ private final HashType hashType;
+ private Types.Direction direction;
+ private IndexFunction indexFunction;
public CountDistinctOperator(AUType opType) {
super(true);
- switch (opType) {
+ switch(opType) {
case COUNT_DISTINCT:
- this.operatorType = CountDistinctTypes.COUNT;
+ this.operatorType =
CountDistinctOperatorTypes.COUNT;
break;
case COUNT_DISTINCT_APPROX:
- this.operatorType = CountDistinctTypes.KMV;
+ this.operatorType =
CountDistinctOperatorTypes.KMV;
break;
default:
throw new DMLRuntimeException(opType + " not
supported for CountDistinct Operator");
@@ -50,15 +48,49 @@ public class CountDistinctOperator extends Operator {
this.hashType = HashType.LinearHash;
}
- public CountDistinctOperator(CountDistinctTypes operatorType) {
+ public CountDistinctOperator(CountDistinctOperatorTypes operatorType) {
super(true);
this.operatorType = operatorType;
this.hashType = HashType.StandardJava;
}
- public CountDistinctOperator(CountDistinctTypes operatorType, HashType
hashType) {
+ public CountDistinctOperator(CountDistinctOperatorTypes operatorType,
HashType hashType) {
+ super(true);
+ this.operatorType = operatorType;
+ this.hashType = hashType;
+ }
+
+ public CountDistinctOperator(CountDistinctOperatorTypes operatorType,
IndexFunction indexFunction,
+ HashType hashType) {
super(true);
this.operatorType = operatorType;
+ this.indexFunction = indexFunction;
this.hashType = hashType;
}
-}
\ No newline at end of file
+
+ public CountDistinctOperatorTypes getOperatorType() {
+ return operatorType;
+ }
+
+ public HashType getHashType() {
+ return hashType;
+ }
+
+ public IndexFunction getIndexFunction() {
+ return indexFunction;
+ }
+
+ public CountDistinctOperator setIndexFunction(IndexFunction
indexFunction) {
+ this.indexFunction = indexFunction;
+ return this;
+ }
+
+ public Types.Direction getDirection() {
+ return direction;
+ }
+
+ public CountDistinctOperator setDirection(Types.Direction direction) {
+ this.direction = direction;
+ return this;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
similarity index 53%
copy from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
copy to src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
index 9581bc8..520b02e 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
@@ -17,33 +17,10 @@
* under the License.
*/
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.runtime.matrix.operators;
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
-
-public class CountDistinct extends CountDistinctBase {
-
- public String TEST_NAME = "countDistinct";
- public String TEST_DIR = "functions/countDistinct/";
- public String TEST_CLASS_DIR = TEST_DIR +
CountDistinct.class.getSimpleName() + "/";
-
- protected String getTestClassDir() {
- return TEST_CLASS_DIR;
- }
-
- protected String getTestName() {
- return TEST_NAME;
- }
-
- protected String getTestDir() {
- return TEST_DIR;
- }
-
- @Test
- public void testSimple1by1() {
- // test simple 1 by 1.
- ExecType ex = ExecType.CP;
- countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
- }
-}
\ No newline at end of file
+/**
+ * Algorithms a {@link CountDistinctOperator} can use to count distinct values. Its AUType constructor maps
+ * COUNT_DISTINCT to COUNT and COUNT_DISTINCT_APPROX to KMV. Do not reorder the constants: their ordinals may be
+ * relied upon elsewhere (unverified).
+ */
+public enum CountDistinctOperatorTypes { // The different supported types of counting.
+ COUNT, // Baseline naive implementation, iterate through, add to hashMap.
+ KMV, // K-Minimum Values algorithm.
+ HLL // HyperLogLog algorithm. NOTE: CountDistinctTest expects estimation with HLL to throw NotImplementedException.
+}
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index 308aaaa..cd20b67 100644
--- a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -20,6 +20,7 @@
package org.apache.sysds.test.component.matrix;
import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.ArrayList;
@@ -27,10 +28,11 @@ import java.util.Collection;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Hash.HashType;
@@ -42,9 +44,9 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(value = Parameterized.class)
public class CountDistinctTest {
- private static CountDistinctTypes[] esT = new CountDistinctTypes[] {
+ private static CountDistinctOperatorTypes[] esT = new
CountDistinctOperatorTypes[] {
// The different types of Estimators
- CountDistinctTypes.COUNT, CountDistinctTypes.KMV,
CountDistinctTypes.HLL};
+ CountDistinctOperatorTypes.COUNT,
CountDistinctOperatorTypes.KMV, CountDistinctOperatorTypes.HLL};
@Parameters
public static Collection<Object[]> data() {
@@ -86,26 +88,26 @@ public class CountDistinctTest {
inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024,
10241, 0, 3000, 0.1, 7)));
actualUnique.add(3000L);
- for(CountDistinctTypes et : esT) {
+ for(CountDistinctOperatorTypes et : esT) {
for(HashType ht : HashType.values()) {
- if((ht == HashType.ExpHash && et ==
CountDistinctTypes.KMV) ||
- (ht == HashType.StandardJava && et ==
CountDistinctTypes.KMV)) {
+ if((ht == HashType.ExpHash && et ==
CountDistinctOperatorTypes.KMV) ||
+ (ht == HashType.StandardJava && et ==
CountDistinctOperatorTypes.KMV)) {
String errorMessage = "Invalid hashing
configuration using " + ht + " and " + et;
- tests.add(new Object[] {et,
inputs.get(0), actualUnique.get(0), ht, new DMLException(),
- errorMessage, 0.0});
+ tests.add(
+ new Object[] {et,
inputs.get(0), actualUnique.get(0), ht, new DMLException(), errorMessage,
0.0});
}
- else if(et == CountDistinctTypes.HLL) {
+ else if(et == CountDistinctOperatorTypes.HLL) {
tests.add(new Object[] {et,
inputs.get(0), actualUnique.get(0), ht, new NotImplementedException(),
"HyperLogLog not implemented",
0.0});
}
- else if(et != CountDistinctTypes.COUNT) {
+ else if(et != CountDistinctOperatorTypes.COUNT)
{
for(int i = 0; i < inputs.size(); i++) {
// allowing the estimate to be
15% off
tests.add(new Object[] {et,
inputs.get(i), actualUnique.get(i), ht, null, null, 0.15});
}
}
}
- if(et == CountDistinctTypes.COUNT) {
+ if(et == CountDistinctOperatorTypes.COUNT) {
for(int i = 0; i < inputs.size(); i++) {
tests.add(new Object[] {et,
inputs.get(i), actualUnique.get(i), null, null, null, 0.0001});
}
@@ -115,7 +117,7 @@ public class CountDistinctTest {
}
@Parameterized.Parameter
- public CountDistinctTypes et;
+ public CountDistinctOperatorTypes et;
@Parameterized.Parameter(1)
public MatrixBlock in;
@Parameterized.Parameter(2)
@@ -135,36 +137,28 @@ public class CountDistinctTest {
@Test
public void testEstimation() {
-
- Integer out = 0;
- CountDistinctOperator op = new CountDistinctOperator(et, ht);
try {
- if(expectedException != null){
- assertThrows(expectedException.getClass(), ()
-> {LibMatrixCountDistinct.estimateDistinctValues(in, op);});
- return;
+ CountDistinctOperator op = new
CountDistinctOperator(et, ht).setDirection(Types.Direction.RowCol);
+ if(expectedException != null) {
+ assertThrows(expectedException.getClass(), ()
-> {
+
LibMatrixCountDistinct.estimateDistinctValues(in, op);
+ });
+ }
+ else {
+ int out =
LibMatrixCountDistinct.estimateDistinctValues(in, op);
+ int count = out;
+ boolean success = Math.abs(nrUnique - count) <=
nrUnique * epsilon;
+ StringBuilder sb = new StringBuilder();
+ sb.append(this.toString());
+ sb.append("\n" + count + " unique values,
actual:" + nrUnique + " with eps of " + epsilon);
+ assertTrue(sb.toString(), success);
}
- else
- out =
LibMatrixCountDistinct.estimateDistinctValues(in, op);
- }
- catch(DMLException e) {
- throw e;
- }
- catch(NotImplementedException e) {
- throw e;
}
catch(Exception e) {
e.printStackTrace();
fail(this.toString());
}
- int count = out;
- boolean success = Math.abs(nrUnique - count) <= nrUnique *
epsilon;
- if(!success){
- StringBuilder sb = new StringBuilder();
- sb.append(this.toString());
- sb.append("\n" + count + " unique values, actual:" +
nrUnique + " with eps of " + epsilon);
- fail(sb.toString());
- }
}
@Override
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
similarity index 65%
rename from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
rename to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index a15019d..e808cb5 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -19,32 +19,13 @@
package org.apache.sysds.test.functions.countDistinct;
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
+import org.apache.sysds.common.Types;
-public class CountDistinctApprox extends CountDistinctBase {
+public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
- private final static String TEST_NAME = "countDistinctApprox";
+ private final static String TEST_NAME = "countDistinctApproxCol";
private final static String TEST_DIR = "functions/countDistinct/";
- private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctApprox.class.getSimpleName() + "/";
-
- public CountDistinctApprox() {
- percentTolerance = 0.1;
- }
-
- @Test
- public void testXXLarge() {
- ExecType ex = ExecType.CP;
- double tolerance = 9000 * percentTolerance;
- countDistinctTest(9000, 10000, 5000, 0.1, ex, tolerance);
- }
-
- @Test
- public void testSparse500Unique(){
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001 + 120 * percentTolerance;
- countDistinctTest(500, 100, 100000, 0.1, ex, tolerance);
- }
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctApproxCol.class.getSimpleName() + "/";
@Override
protected String getTestClassDir() {
@@ -60,4 +41,14 @@ public class CountDistinctApprox extends CountDistinctBase {
protected String getTestDir() {
return TEST_DIR;
}
+
+ @Override
+ protected Types.Direction getDirection() {
+ return Types.Direction.Col;
+ }
+
+ @Override
+ public void setUp() {
+ super.addTestConfiguration();
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
similarity index 65%
copy from
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
copy to
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index 9581bc8..05a1256 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -19,31 +19,36 @@
package org.apache.sysds.test.functions.countDistinct;
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
+import org.apache.sysds.common.Types;
-public class CountDistinct extends CountDistinctBase {
+public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
- public String TEST_NAME = "countDistinct";
- public String TEST_DIR = "functions/countDistinct/";
- public String TEST_CLASS_DIR = TEST_DIR +
CountDistinct.class.getSimpleName() + "/";
+ private final static String TEST_NAME = "countDistinctApproxRow";
+ private final static String TEST_DIR = "functions/countDistinct/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctApproxRow.class.getSimpleName() + "/";
+ @Override
protected String getTestClassDir() {
return TEST_CLASS_DIR;
}
+ @Override
protected String getTestName() {
return TEST_NAME;
}
+ @Override
protected String getTestDir() {
return TEST_DIR;
}
- @Test
- public void testSimple1by1() {
- // test simple 1 by 1.
- ExecType ex = ExecType.CP;
- countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
+ @Override
+ protected Types.Direction getDirection() {
+ return Types.Direction.Row;
}
-}
\ No newline at end of file
+
+ @Override
+ public void setUp() {
+ super.addTestConfiguration();
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
new file mode 100644
index 0000000..e59b002
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.junit.Test;
+
+/**
+ * Tests the approximate distinct count over all cells of a matrix, which produces a single scalar (the DML
+ * script is {@code countDistinctApproxRowCol.dml}). The CP cases inherited from {@link CountDistinctRowColBase}
+ * are complemented here by edge cases and by Spark counterparts.
+ */
+public class CountDistinctApproxRowCol extends CountDistinctRowColBase {
+
+	private static final String TEST_NAME = "countDistinctApproxRowCol";
+	private static final String TEST_DIR = "functions/countDistinct/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowCol.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		super.addTestConfiguration();
+		// Relative error allowed by the cases whose tolerance scales with the true distinct count.
+		super.percentTolerance = 0.2;
+	}
+
+	@Test
+	public void testCPSparseLarge() {
+		countDistinctScalarTest(9000, 10000, 5000, 0.1, ExecType.CP, 9000 * percentTolerance);
+	}
+
+	@Test
+	public void testSparkSparseLarge() {
+		countDistinctScalarTest(9000, 10000, 5000, 0.1, ExecType.SPARK, 9000 * percentTolerance);
+	}
+
+	@Test
+	public void testCPSparseSmall() {
+		countDistinctScalarTest(9000, 999, 999, 0.1, ExecType.CP, 9000 * percentTolerance);
+	}
+
+	@Test
+	public void testSparkSparseSmall() {
+		countDistinctScalarTest(9000, 999, 999, 0.1, ExecType.SPARK, 9000 * percentTolerance);
+	}
+
+	// NOTE(review): the CP case uses 5 cols x 10 rows while its Spark counterpart uses 10 cols x 5 rows
+	// (arguments are numberDistinct, cols, rows); confirm whether the transposed shapes are intentional.
+	@Test
+	public void testCPDenseXSmall() {
+		countDistinctScalarTest(5, 5, 10, 1.0, ExecType.CP, 5 * percentTolerance);
+	}
+
+	@Test
+	public void testSparkDenseXSmall() {
+		countDistinctScalarTest(5, 10, 5, 1.0, ExecType.SPARK, 5 * percentTolerance);
+	}
+
+	@Test
+	public void testCPEmpty() {
+		countDistinctScalarTest(1, 0, 0, 0.1, ExecType.CP, 0);
+	}
+
+	@Test
+	public void testSparkEmpty() {
+		countDistinctScalarTest(1, 0, 0, 0.1, ExecType.SPARK, 0);
+	}
+
+	@Test
+	public void testCPSingleValue() {
+		countDistinctScalarTest(1, 1, 1, 1.0, ExecType.CP, 0);
+	}
+
+	@Test
+	public void testSparkSingleValue() {
+		countDistinctScalarTest(1, 1, 1, 1.0, ExecType.SPARK, 0);
+	}
+
+	// Spark counterparts of CP cases defined in CountDistinctRowColBase.
+
+	@Test
+	public void testSparkDense1Unique() {
+		countDistinctScalarTest(1, 100, 1000, 1.0, ExecType.SPARK, 0.00001);
+	}
+
+	@Test
+	public void testSparkDense2Unique() {
+		countDistinctScalarTest(2, 100, 1000, 1.0, ExecType.SPARK, 0.00001);
+	}
+
+	@Test
+	public void testSparkDense120Unique() {
+		countDistinctScalarTest(120, 100, 1000, 1.0, ExecType.SPARK, 0.00001 + 120 * percentTolerance);
+	}
+
+	@Override
+	protected String getTestClassDir() {
+		return TEST_CLASS_DIR;
+	}
+
+	@Override
+	protected String getTestName() {
+		return TEST_NAME;
+	}
+
+	@Override
+	protected String getTestDir() {
+		return TEST_DIR;
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 9d7e940..041cf51 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -22,13 +22,13 @@ package org.apache.sysds.test.functions.countDistinct;
import static org.junit.Assert.assertTrue;
import org.apache.sysds.common.Types;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Test;
public abstract class CountDistinctBase extends AutomatedTestBase {
+ protected double percentTolerance = 0.0;
+ protected double baseTolerance = 0.0001;
protected abstract String getTestClassDir();
@@ -36,86 +36,47 @@ public abstract class CountDistinctBase extends
AutomatedTestBase {
protected abstract String getTestDir();
- @Override
- public void setUp() {
+ protected void addTestConfiguration() {
TestUtils.clearAssertionInformation();
addTestConfiguration(getTestName(),
new TestConfiguration(getTestClassDir(), getTestName(),
new String[] {"A.scalar"}));
}
- protected double percentTolerance = 0.0;
- protected double baseTolerance = 0.0001;
-
- @Test
- public void testSmall() {
- ExecType ex = ExecType.CP;
- double tolerance = baseTolerance + 50 * percentTolerance;
- countDistinctTest(50, 50, 50, 1.0, ex, tolerance);
- }
-
- @Test
- public void testLarge() {
- ExecType ex = ExecType.CP;
- double tolerance = baseTolerance + 800 * percentTolerance;
- countDistinctTest(800, 1000, 1000, 1.0, ex, tolerance);
- }
-
- @Test
- public void testXLarge() {
- ExecType ex = ExecType.CP;
- double tolerance = baseTolerance + 1723 * percentTolerance;
- countDistinctTest(1723, 5000, 2000, 1.0, ex, tolerance);
- }
-
- @Test
- public void test1Unique() {
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001;
- countDistinctTest(1, 100, 1000, 1.0, ex, tolerance);
- }
-
- @Test
- public void test2Unique() {
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001;
- countDistinctTest(2, 100, 1000, 1.0, ex, tolerance);
- }
+ @Override
+ public abstract void setUp();
- @Test
- public void test120Unique() {
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001 + 120 * percentTolerance;
- countDistinctTest(120, 100, 1000, 1.0, ex, tolerance);
+ public void countDistinctScalarTest(long numberDistinct, int cols, int
rows, double sparsity,
+ Types.ExecType instType, double tolerance) {
+ countDistinctTest(Types.Direction.RowCol, numberDistinct, cols,
rows, sparsity, instType, tolerance);
}
- @Test
- public void testSparse500Unique() {
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001 + 500 * percentTolerance;
- countDistinctTest(500, 100, 640000, 0.1, ex, tolerance);
+ public void countDistinctMatrixTest(Types.Direction dir, long
numberDistinct, int cols, int rows, double sparsity,
+ Types.ExecType instType, double tolerance) {
+ countDistinctTest(dir, numberDistinct, cols, rows, sparsity,
instType, tolerance);
}
- @Test
- public void testSparse120Unique(){
- ExecType ex = ExecType.CP;
- double tolerance = 0.00001 + 120 * percentTolerance;
- countDistinctTest(120, 100, 64000, 0.1, ex, tolerance);
- }
+ public void countDistinctTest(Types.Direction dir, long numberDistinct,
int cols, int rows, double sparsity,
+ Types.ExecType instType, double tolerance) {
- public void countDistinctTest(int numberDistinct, int cols, int rows,
double sparsity,
- ExecType instType, double tolerance) {
Types.ExecMode platformOld = setExecMode(instType);
try {
loadTestConfiguration(getTestConfiguration(getTestName()));
String HOME = SCRIPT_DIR + getTestDir();
fullDMLScriptName = HOME + getTestName() + ".dml";
- String out = output("A");
- System.out.println(out);
+ String outputPath = output("A");
+
programArgs = new String[] {"-args",
String.valueOf(numberDistinct), String.valueOf(rows),
- String.valueOf(cols), String.valueOf(sparsity),
out};
+ String.valueOf(cols), String.valueOf(sparsity),
outputPath};
runTest(true, false, null, -1);
- writeExpectedScalar("A", numberDistinct);
+
+ if(dir.isRowCol()) {
+ writeExpectedScalar("A", numberDistinct);
+ }
+ else {
+ double[][] expectedMatrix =
getExpectedMatrixRowOrCol(dir, cols, rows, numberDistinct);
+ writeExpectedMatrix("A", expectedMatrix);
+ }
compareResults(tolerance);
}
catch(Exception e) {
@@ -126,4 +87,22 @@ public abstract class CountDistinctBase extends
AutomatedTestBase {
rtplatform = platformOld;
}
}
-}
\ No newline at end of file
+
+ private double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int
cols, int rows, long expectedValue) {
+ double[][] expectedResult;
+ if(dir.isRow()) {
+ expectedResult = new double[rows][1];
+ for(int i = 0; i < rows; ++i) {
+ expectedResult[i][0] = expectedValue;
+ }
+ }
+ else {
+ expectedResult = new double[1][cols];
+ for(int i = 0; i < cols; ++i) {
+ expectedResult[0][i] = expectedValue;
+ }
+ }
+
+ return expectedResult;
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
similarity index 80%
rename from
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
rename to
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
index 9581bc8..3de4a61 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
@@ -22,11 +22,11 @@ package org.apache.sysds.test.functions.countDistinct;
import org.apache.sysds.common.Types.ExecType;
import org.junit.Test;
-public class CountDistinct extends CountDistinctBase {
+public class CountDistinctRowCol extends CountDistinctRowColBase {
public String TEST_NAME = "countDistinct";
public String TEST_DIR = "functions/countDistinct/";
- public String TEST_CLASS_DIR = TEST_DIR +
CountDistinct.class.getSimpleName() + "/";
+ public String TEST_CLASS_DIR = TEST_DIR +
CountDistinctRowCol.class.getSimpleName() + "/";
protected String getTestClassDir() {
return TEST_CLASS_DIR;
@@ -40,10 +40,16 @@ public class CountDistinct extends CountDistinctBase {
return TEST_DIR;
}
+ @Override
+ public void setUp() {
+ super.addTestConfiguration();
+ super.percentTolerance = 0.0;
+ }
+
@Test
public void testSimple1by1() {
// test simple 1 by 1.
ExecType ex = ExecType.CP;
- countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
+ countDistinctScalarTest(1, 1, 1, 1.0, ex, 0.00001);
}
-}
\ No newline at end of file
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
new file mode 100644
index 0000000..6b20075
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.junit.Test;
+
+/**
+ * Shared CP test cases for distinct counts over the whole matrix, each producing a single scalar.
+ * <p>
+ * Most cases allow an error of a fixed base plus {@code percentTolerance} times the true distinct count. This lets
+ * an exact implementation (percentTolerance = 0) and an approximate one (percentTolerance &gt; 0) reuse the same
+ * cases; subclasses set {@code percentTolerance} in {@code setUp()}.
+ */
+public abstract class CountDistinctRowColBase extends CountDistinctBase {
+
+	@Test
+	public void testCPDenseSmall() {
+		countDistinctScalarTest(50, 50, 50, 1.0, ExecType.CP, baseTolerance + 50 * percentTolerance);
+	}
+
+	@Test
+	public void testCPDenseLarge() {
+		countDistinctScalarTest(800, 1000, 1000, 1.0, ExecType.CP, baseTolerance + 800 * percentTolerance);
+	}
+
+	@Test
+	public void testCPDenseXLarge() {
+		countDistinctScalarTest(1723, 5000, 2000, 1.0, ExecType.CP, baseTolerance + 1723 * percentTolerance);
+	}
+
+	// With one or two distinct values the count must be exact regardless of percentTolerance.
+	@Test
+	public void testCPDense1Unique() {
+		countDistinctScalarTest(1, 100, 1000, 1.0, ExecType.CP, 0.00001);
+	}
+
+	@Test
+	public void testCPDense2Unique() {
+		countDistinctScalarTest(2, 100, 1000, 1.0, ExecType.CP, 0.00001);
+	}
+
+	@Test
+	public void testCPDense120Unique() {
+		countDistinctScalarTest(120, 100, 1000, 1.0, ExecType.CP, 0.00001 + 120 * percentTolerance);
+	}
+
+	@Test
+	public void testCPSparse500Unique() {
+		countDistinctScalarTest(500, 100, 640000, 0.1, ExecType.CP, 0.00001 + 500 * percentTolerance);
+	}
+
+	@Test
+	public void testCPSparse120Unique() {
+		countDistinctScalarTest(120, 100, 64000, 0.1, ExecType.CP, 0.00001 + 120 * percentTolerance);
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
new file mode 100644
index 0000000..df2ea8a
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Base class for distinct-count tests that produce one count per row or per column, as selected by
+ * {@link #getDirection()}.
+ * <p>
+ * Each test expects every row/column count to equal the number of candidate values used to generate the input,
+ * within a tolerance of {@code actualDistinctCount * percentTolerance}.
+ */
+public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
+
+	@Override
+	protected abstract String getTestClassDir();
+
+	@Override
+	protected abstract String getTestName();
+
+	@Override
+	protected abstract String getTestDir();
+
+	/**
+	 * @return the direction under test; the expected-result builder in {@link CountDistinctBase} treats anything
+	 *         that is not a row direction as column-wise
+	 */
+	protected abstract Types.Direction getDirection();
+
+	/**
+	 * Registers the test configuration with a matrix output {@code "A"}, instead of the {@code "A.scalar"} output
+	 * registered by {@link CountDistinctBase}, and sets the relative tolerance used by all tests below.
+	 */
+	@Override
+	protected void addTestConfiguration() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(getTestName(), new TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+		this.percentTolerance = 0.2;
+	}
+
+	@Test
+	public void testCPSparseLarge() {
+		runMatrixTest(Types.ExecType.CP, 10, 10000, 1000, 0.1);
+	}
+
+	@Test
+	public void testCPDenseLarge() {
+		runMatrixTest(Types.ExecType.CP, 100, 10000, 1000, 0.9);
+	}
+
+	@Test
+	public void testCPSparseSmall() {
+		runMatrixTest(Types.ExecType.CP, 10, 1000, 1000, 0.1);
+	}
+
+	@Test
+	public void testCPDenseSmall() {
+		runMatrixTest(Types.ExecType.CP, 10, 1000, 1000, 0.9);
+	}
+
+	// NOTE(review): 1001 columns (vs. 1000 in the "NoneAggregation" cases) presumably forces partial results from
+	// more than one block to be aggregated; confirm against the configured block size.
+	@Test
+	public void testSparkSparseLargeMultiBlockAggregation() {
+		runMatrixTest(Types.ExecType.SPARK, 10, 10000, 1001, 0.1);
+	}
+
+	@Test
+	public void testSparkDenseLargeMultiBlockAggregation() {
+		runMatrixTest(Types.ExecType.SPARK, 10, 10000, 1001, 0.9);
+	}
+
+	@Test
+	public void testSparkSparseLargeNoneAggregation() {
+		runMatrixTest(Types.ExecType.SPARK, 10, 10000, 1000, 0.1);
+	}
+
+	@Test
+	public void testSparkDenseLargeNoneAggregation() {
+		runMatrixTest(Types.ExecType.SPARK, 10, 10000, 1000, 0.9);
+	}
+
+	/**
+	 * Runs one row- or column-wise test in the direction returned by {@link #getDirection()}.
+	 *
+	 * @param execType            execution backend (CP or SPARK)
+	 * @param actualDistinctCount number of candidate values, which is also the expected count per row/column
+	 * @param rows                number of input rows
+	 * @param cols                number of input columns
+	 * @param sparsity            fraction of non-zero cells in the generated input
+	 */
+	private void runMatrixTest(Types.ExecType execType, int actualDistinctCount, int rows, int cols,
+		double sparsity) {
+		double tolerance = actualDistinctCount * this.percentTolerance;
+		// countDistinctMatrixTest takes cols before rows
+		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, execType, tolerance);
+	}
+}
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml
b/src/test/scripts/functions/countDistinct/countDistinct.dml
index a0da780..3b21bc8 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinct.dml
@@ -21,5 +21,4 @@
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
res = countDistinct(input)
-print(res)
write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinctApprox.dml
b/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
similarity index 92%
rename from src/test/scripts/functions/countDistinct/countDistinctApprox.dml
rename to src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
index eeb5bfc..777a56a 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctApprox.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApprox(input)
-write(res, $5, format="text")
\ No newline at end of file
+res = countDistinctApprox(input, dir="c", type="KMV")
+write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml
b/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
similarity index 92%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
index a0da780..38c8b9c 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
@@ -19,7 +19,6 @@
#
#-------------------------------------------------------------
-input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinct(input)
-print(res)
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
+res = countDistinctApprox(input, dir="r", type="KMV")
write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml
b/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
similarity index 92%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
index a0da780..2c5b6cf 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
@@ -19,7 +19,6 @@
#
#-------------------------------------------------------------
-input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinct(input)
-print(res)
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
+res = countDistinctApprox(input, dir="rc", type="KMV")
write(res, $5, format="text")