This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1c60cc5e35 [MINOR] Improve parameter validation of countDistinct
1c60cc5e35 is described below
commit 1c60cc5e35f472fefbe1df8cf3c6877fa2fe4ba8
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sat Nov 5 19:27:06 2022 -0700
[MINOR] Improve parameter validation of countDistinct
This patch improves the parameter validation of the aliases for
countDistinct() and countDistinctApprox(). The aliases have also been
renamed for consistency with other builtin functions:
- countDistinctRow() -> rowCountDistinct()
- countDistinctCol() -> colCountDistinct()
- countDistinctApproxRow() -> rowCountDistinctApprox()
- countDistinctApproxCol() -> colCountDistinctApprox()
rowCountDistinctApprox() and colCountDistinctApprox() accept only a
single additional, optional parameter: type (default: KMV).
rowCountDistinct() and colCountDistinct() (formerly countDistinctRow() and
countDistinctCol()) have been converted from parameterized builtin
functions to non-parameterized builtins, because the aliases specify
their respective directions implicitly.
Closes #1722
---
.../java/org/apache/sysds/common/Builtins.java | 8 +--
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../org/apache/sysds/lops/PartialAggregate.java | 4 +-
.../sysds/parser/BuiltinFunctionExpression.java | 22 +++++-
.../org/apache/sysds/parser/DMLTranslator.java | 13 +++-
.../ParameterizedBuiltinFunctionExpression.java | 54 +++++++++------
.../functions/countDistinct/CountDistinctBase.java | 6 +-
.../CountDistinctColAliasException.java | 77 +++++++++++++++++++++
.../CountDistinctRowAliasException.java | 77 +++++++++++++++++++++
.../CountDistinctApproxColAliasException.java | 78 ++++++++++++++++++++++
.../CountDistinctApproxRowAliasException.java | 78 ++++++++++++++++++++++
.../countDistinct/countDistinctColAlias.dml | 2 +-
.../countDistinctColAliasException.dml} | 2 +-
.../countDistinct/countDistinctRowAlias.dml | 2 +-
.../countDistinctRowAliasException.dml} | 2 +-
.../countDistinctApproxColAlias.dml | 2 +-
...ml => countDistinctApproxColAliasException.dml} | 2 +-
.../countDistinctApproxRowAlias.dml | 2 +-
...ml => countDistinctApproxRowAliasException.dml} | 2 +-
19 files changed, 395 insertions(+), 40 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 262212570e..5afef9c308 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -74,6 +74,7 @@ public enum Builtins {
CBIND("cbind", "append", false),
CEIL("ceil", "ceiling", false),
CHOLESKY("cholesky", false),
+ COL_COUNT_DISTINCT("colCountDistinct",false),
COLMAX("colMaxs", false),
COLMEAN("colMeans", false),
COLMIN("colMins", false),
@@ -245,6 +246,7 @@ public enum Builtins {
REMOVE("remove", false, ReturnType.MULTI_RETURN),
REV("rev", false),
ROUND("round", false),
+ ROW_COUNT_DISTINCT("rowCountDistinct",false),
ROWINDEXMAX("rowIndexMax", false),
ROWINDEXMIN("rowIndexMin", false),
ROWMAX("rowMaxs", false),
@@ -308,11 +310,9 @@ public enum Builtins {
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
COUNT_DISTINCT("countDistinct",false, true),
- COUNT_DISTINCT_ROW("countDistinctRow",false, true),
- COUNT_DISTINCT_COL("countDistinctCol",false, true),
COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
- COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true),
- COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true),
+ COUNT_DISTINCT_APPROX_ROW("rowCountDistinctApprox", false, true),
+ COUNT_DISTINCT_APPROX_COL("colCountDistinctApprox", false, true),
CVLM("cvlm", true, false),
GROUPEDAGG("aggregate", "groupedAggregate", false, true),
INVCDF("icdf", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 991284c13e..7c3a3f1e53 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -198,7 +198,7 @@ public class Types
PROD(4), SUM_PROD(5),
TRACE(6), MEAN(7), VAR(8),
MAXINDEX(9), MININDEX(10),
- COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12),
COUNT_DISTINCT_COL(13),
+ COUNT_DISTINCT(11), ROW_COUNT_DISTINCT(12),
COL_COUNT_DISTINCT(13),
COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15),
COUNT_DISTINCT_APPROX_COL(16);
@Override
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 0481c7373a..1a7d22b989 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -352,10 +352,10 @@ public class PartialAggregate extends Lop
}
}
- case COUNT_DISTINCT_ROW:
+ case ROW_COUNT_DISTINCT:
return "uacdr";
- case COUNT_DISTINCT_COL:
+ case COL_COUNT_DISTINCT:
return "uacdc";
case COUNT_DISTINCT_APPROX: {
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index c3aca47d38..5634e02be9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1606,6 +1606,26 @@ public class BuiltinFunctionExpression extends
DataIdentifier
else
raiseValidateError("Compress/DeCompress
instruction not allowed in dml script");
break;
+ case ROW_COUNT_DISTINCT:
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(id.getDim1(), 1);
+ output.setBlocksize (id.getBlocksize());
+ output.setValueType(ValueType.INT64);
+ output.setNnz(id.getDim1());
+ break;
+
+ case COL_COUNT_DISTINCT:
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(1, id.getDim2());
+ output.setBlocksize (id.getBlocksize());
+ output.setValueType(ValueType.INT64);
+ output.setNnz(id.getDim2());
+ break;
+
default:
if( isMathFunction() ) {
checkMathFunctionParam();
@@ -1637,7 +1657,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
}
}
}
-
+
private void setBinaryOutputProperties(DataIdentifier output) {
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 553bf56fc5..0c3a6dfd8f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2064,13 +2064,11 @@ public class DMLTranslator
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
break;
- case COUNT_DISTINCT_ROW:
case COUNT_DISTINCT_APPROX_ROW:
currBuiltinOp = new
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
break;
- case COUNT_DISTINCT_COL:
case COUNT_DISTINCT_APPROX_COL:
currBuiltinOp = new
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
@@ -2758,6 +2756,17 @@ public class DMLTranslator
setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
break;
}
+
+ case ROW_COUNT_DISTINCT:
+ currBuiltinOp = new AggUnaryOp(target.getName(),
DataType.MATRIX, target.getValueType(),
+
AggOp.valueOf(source.getOpCode().name()), Direction.Row, expr);
+ break;
+
+ case COL_COUNT_DISTINCT:
+ currBuiltinOp = new AggUnaryOp(target.getName(),
DataType.MATRIX, target.getValueType(),
+
AggOp.valueOf(source.getOpCode().name()), Direction.Col, expr);
+ break;
+
default:
throw new ParseException("Unsupported builtin function
type: "+source.getOpCode());
}
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index bdfd38c5a4..7ef19badde 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -247,15 +247,16 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
break;
case COUNT_DISTINCT:
- case COUNT_DISTINCT_ROW:
- case COUNT_DISTINCT_COL:
validateCountDistinct(output, conditional);
break;
case COUNT_DISTINCT_APPROX:
+ validateCountDistinctApprox(output, conditional, false);
+ break;
+
case COUNT_DISTINCT_APPROX_ROW:
case COUNT_DISTINCT_APPROX_COL:
- validateCountDistinctApprox(output, conditional);
+ validateCountDistinctApprox(output, conditional, true);
break;
default: //always unconditional (because unsupported operation)
@@ -400,7 +401,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
validateAggregationDirection(dataId, output);
}
- private void validateCountDistinctApprox(DataIdentifier output, boolean
conditional) {
+ private void validateCountDistinctApprox(DataIdentifier output, boolean
conditional, boolean isDirectionAlias) {
Set<String> validTypeNames = CollectionUtils.asSet("KMV");
HashMap<String, Expression> varParams = getVarParams();
@@ -411,13 +412,26 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
// Validate the number of parameters
String fname = getOpCode().getName();
- String usageMessage = "function " + fname + " takes at least 1
and at most 3 parameters";
- if (varParams.size() < 1) {
- raiseValidateError("Too few parameters: " +
usageMessage, conditional);
- }
+ if (!isDirectionAlias) {
+ // Function is not an alias, so we have to check for
all 3 permissible parameters
+ String usageMessage = "function " + fname + " takes at
least 1 and at most 3 parameters";
+ if (varParams.size() < 1) {
+ raiseValidateError("Too few parameters: " +
usageMessage, conditional);
+ }
- if (varParams.size() > 3) {
- raiseValidateError("Too many parameters: " +
usageMessage, conditional);
+ if (varParams.size() > 3) {
+ raiseValidateError("Too many parameters: " +
usageMessage, conditional);
+ }
+ } else {
+ // The direction is fixed for function aliases
+ String usageMessage = "function " + fname + " takes at
least 1 and at most 2 parameters";
+ if (varParams.size() < 1) {
+ raiseValidateError("Too few parameters: " +
usageMessage, conditional);
+ }
+
+ if (varParams.size() > 2) {
+ raiseValidateError("Too many parameters: " +
usageMessage, conditional);
+ }
}
// Check parameter names are valid
@@ -447,20 +461,22 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
addVarParam("type", new StringIdentifier("KMV", this));
}
- checkStringParam(true, fname, "dir", conditional);
- // Check data value of "dir" parameter
- validateAggregationDirection(dataId, output);
+ if (!isDirectionAlias) {
+ checkStringParam(true, fname, "dir", conditional);
+ // Check data value of "dir" parameter
+ validateAggregationDirection(dataId, output);
+ }
}
private void validateAggregationDirection(Identifier dataId,
DataIdentifier output) {
HashMap<String, Expression> varParams = getVarParams();
if (varParams.containsKey("dir")) {
- String directionString =
varParams.get("dir").toString().toUpperCase();
+ String inputDirectionString =
varParams.get("dir").toString().toUpperCase();
// Set output type and dimensions based on direction
// "r" -> count across all rows, resulting in a Mx1
matrix
- if
(directionString.equals(Types.Direction.Row.toString())) {
+ if
(inputDirectionString.equals(Types.Direction.Row.toString())) {
output.setDataType(DataType.MATRIX);
output.setDimensions(dataId.getDim1(), 1);
output.setBlocksize(dataId.getBlocksize());
@@ -468,7 +484,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
output.setNnz(dataId.getDim1());
// "c" -> count across all cols, resulting in a 1xN
matrix
- } else if
(directionString.equals(Types.Direction.Col.toString())) {
+ } else if
(inputDirectionString.equals(Types.Direction.Col.toString())) {
output.setDataType(DataType.MATRIX);
output.setDimensions(1, dataId.getDim2());
output.setBlocksize(dataId.getBlocksize());
@@ -476,16 +492,16 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
output.setNnz(dataId.getDim2());
// "rc" -> count across all rows and cols in input
matrix, resulting in a single value
- } else if
(directionString.equals(Types.Direction.RowCol.toString())) {
+ } else if
(inputDirectionString.equals(Types.Direction.RowCol.toString())) {
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
output.setNnz(1);
- // unrecognized value for "dir" parameter, should "cr"
be valid?
+ // unrecognized value for "dir" parameter
} else {
- raiseValidateError("Invalid argument: " +
directionString + " is not recognized");
+ raiseValidateError("Invalid argument: " +
inputDirectionString + " is not recognized");
}
} else { // default to dir="rc"
output.setDataType(DataType.SCALAR);
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 5bf850d49a..72838b63c9 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -46,17 +46,17 @@ public abstract class CountDistinctBase extends
AutomatedTestBase {
public abstract void setUp();
public void countDistinctScalarTest(long numberDistinct, int cols, int
rows, double sparsity,
- Types.ExecType instType, double tolerance) {
+
Types.ExecType instType, double tolerance) {
countDistinctTest(Types.Direction.RowCol, numberDistinct, cols,
rows, sparsity, instType, tolerance);
}
public void countDistinctMatrixTest(Types.Direction dir, long
numberDistinct, int cols, int rows, double sparsity,
- Types.ExecType instType, double tolerance) {
+
Types.ExecType instType, double tolerance) {
countDistinctTest(dir, numberDistinct, cols, rows, sparsity,
instType, tolerance);
}
public void countDistinctTest(Types.Direction dir, long numberDistinct,
int cols, int rows, double sparsity,
- Types.ExecType instType, double tolerance) {
+
Types.ExecType instType, double tolerance) {
Types.ExecMode platformOld = setExecMode(instType);
try {
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
new file mode 100644
index 0000000000..8af98ea790
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctColAliasException extends CountDistinctBase {
+
+ @Rule
+ public ExpectedException exceptionRule = ExpectedException.none();
+
+ private final static String TEST_NAME =
"countDistinctColAliasException";
+ private final static String TEST_DIR = "functions/countDistinct/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctColAliasException.class.getSimpleName() + "/";
+
+ private final Types.Direction DIRECTION = Types.Direction.Row;
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+ this.percentTolerance = 0.2;
+ }
+
+ @Test
+ public void testCPSparseSmall() {
+ exceptionRule.expect(AssertionError.class);
+ exceptionRule.expectMessage("Invalid number of arguments for
function col_count_distinct(). " +
+ "This function only takes 1 or 2 arguments.");
+
+ Types.ExecType execType = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+ countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols,
rows, sparsity, execType, tolerance);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
new file mode 100644
index 0000000000..dd5c4c2a05
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctRowAliasException extends CountDistinctBase {
+
+ @Rule
+ public ExpectedException exceptionRule = ExpectedException.none();
+
+ private final static String TEST_NAME =
"countDistinctRowAliasException";
+ private final static String TEST_DIR = "functions/countDistinct/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctRowAliasException.class.getSimpleName() + "/";
+
+ private final Types.Direction DIRECTION = Types.Direction.Row;
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+ this.percentTolerance = 0.2;
+ }
+
+ @Test
+ public void testCPSparseSmall() {
+ exceptionRule.expect(AssertionError.class);
+ exceptionRule.expectMessage("Invalid number of arguments for
function row_count_distinct(). " +
+ "This function only takes 1 or 2 arguments.");
+
+ Types.ExecType execType = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+ countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols,
rows, sparsity, execType, tolerance);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
new file mode 100644
index 0000000000..8ea94a3a88
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinctApprox;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctBase;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctApproxColAliasException extends CountDistinctBase {
+
+ @Rule
+ public ExpectedException exceptionRule = ExpectedException.none();
+
+ private final static String TEST_NAME =
"countDistinctApproxColAliasException";
+ private final static String TEST_DIR = "functions/countDistinctApprox/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctApproxColAliasException.class.getSimpleName() + "/";
+
+ private final Types.Direction DIRECTION = Types.Direction.Row;
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+ this.percentTolerance = 0.2;
+ }
+
+ @Test
+ public void testCPSparseSmall() {
+ exceptionRule.expect(AssertionError.class);
+ exceptionRule.expectMessage("Too many parameters: function
colCountDistinctApprox takes at least 1" +
+ " and at most 2 parameters");
+
+ Types.ExecType execType = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+ countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols,
rows, sparsity, execType, tolerance);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
new file mode 100644
index 0000000000..5693985a91
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinctApprox;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctBase;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctApproxRowAliasException extends CountDistinctBase {
+
+ @Rule
+ public ExpectedException exceptionRule = ExpectedException.none();
+
+ private final static String TEST_NAME =
"countDistinctApproxRowAliasException";
+ private final static String TEST_DIR = "functions/countDistinctApprox/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
CountDistinctApproxRowAliasException.class.getSimpleName() + "/";
+
+ private final Types.Direction DIRECTION = Types.Direction.Row;
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+ this.percentTolerance = 0.2;
+ }
+
+ @Test
+ public void testCPSparseSmall() {
+ exceptionRule.expect(AssertionError.class);
+ exceptionRule.expectMessage("Too many parameters: function
rowCountDistinctApprox takes at least 1" +
+ " and at most 2 parameters");
+
+ Types.ExecType execType = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+ countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols,
rows, sparsity, execType, tolerance);
+ }
+}
diff --git a/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
index 3eeb8ed54a..2522fbd1a5 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctCol(input, dir="c")
+res = colCountDistinct(input)
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
b/src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
similarity index 94%
copy from
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to
src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
index 83a9f5070c..45caeb85af 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++
b/src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinct(input, dir="x")
write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
index 62d7196ce1..685221ffbe 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctRow(input, dir="r")
+res = rowCountDistinct(input)
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
b/src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
similarity index 94%
copy from
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to
src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
index 83a9f5070c..3b1cabfe98 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++
b/src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = rowCountDistinct(input, dir="x")
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
index 83a9f5070c..0eda3fb989 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinctApprox(input, type="KMV")
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
similarity index 94%
copy from
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
index 83a9f5070c..8428cd061b 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinctApprox(input, dir="x", type="KMV")
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
index f4be480156..f2c226e62e 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
+++
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxRow(input, dir="r", type="KMV")
+res = rowCountDistinctApprox(input, type="KMV")
write(res, $5, format="text")
diff --git
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
similarity index 94%
copy from
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
index 83a9f5070c..05526c9ce8 100644
---
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = rowCountDistinctApprox(input, dir="x", type="KMV")
write(res, $5, format="text")