This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 33effa8840 [SYSTEMDS-3343] Fix missing default handling in eval
function calls
33effa8840 is described below
commit 33effa88408b6ed25def675c00ccbf02e05ea75b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat May 7 17:38:01 2022 +0200
[SYSTEMDS-3343] Fix missing default handling in eval function calls
This patch generalizes the existing eval function call mechanics by
handling default arguments. For passed named lists, we now, in a
best-effort manner, append constant scalar defaults for parameters that
are missing from the passed list before checking for valid input arguments.
In a second step, this will also enable default parameters in
transformencode UDF functions.
---
scripts/builtin/gridSearch.dml | 4 +-
.../org/apache/sysds/parser/DMLTranslator.java | 3 ++
.../instructions/cp/EvalNaryCPInstruction.java | 27 +++++++++++++
.../sysds/runtime/instructions/cp/ListObject.java | 9 +++++
.../builtin/part1/BuiltinGridSearchTest.java | 44 +++++++++++++++++++---
.../transform/TransformEncodeUDFTest.java | 19 ++++++++--
.../scripts/functions/builtin/GridSearchLM.dml | 8 ++--
.../functions/builtin/GridSearchMLogreg.dml | 2 +-
.../TransformEncodeUDF2.dml} | 36 ++++++++----------
9 files changed, 117 insertions(+), 35 deletions(-)
diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index eb1f16e6c5..8e53502257 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -65,9 +65,9 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y,
String train, String
{
# Step 0) handling default arguments, which require access to passed data
if( length(trainArgs) == 0 )
- trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
+ trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1);
if( length(dataArgs) == 0 )
- dataArgs = list("X", "y");
+ dataArgs = list("X", "y");
if( length(predictArgs) == 0 )
predictArgs = list(X, y);
if( cv & cvk <= 1 ) {
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index c12e5e69d0..4d280a371a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -659,6 +659,9 @@ public class DMLTranslator
retPB = rtpb;
+ // add statement block
+ retPB.setStatementBlock(sb);
+
// add location information
retPB.setParseInfo(sb);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 6b8f890953..5c55264627 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -30,13 +30,17 @@ import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.parser.ConstIdentifier;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.Expression;
+import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.dml.DmlSyntacticValidator;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -136,6 +140,7 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
&& !(fpb.getInputParams().size() == 1 &&
fpb.getInputParams().get(0).getDataType().isList()))
{
ListObject lo = ec.getListObject(boundInputs[0]);
+ lo = appendNamedDefaults(lo,
(FunctionStatement)fpb.getStatementBlock().getStatement(0));
checkValidArguments(lo.getData(), lo.getNames(),
fpb.getInputParamNames());
if( lo.isNamedList() )
lo = reorderNamedListForFunctionCall(lo,
fpb.getInputParamNames());
@@ -271,6 +276,28 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
}
}
+ private static ListObject appendNamedDefaults(ListObject params,
FunctionStatement fstmt) {
+ if( !params.isNamedList() )
+ return params;
+
+ //best effort replacement of scalar literal defaults
+ ListObject ret = new ListObject(params);
+ for( int i=0; i<fstmt.getInputParams().size(); i++ ) {
+ String param = fstmt.getInputParamNames()[i];
+ if( !ret.contains(param)
+ && fstmt.getInputDefaults().get(i) != null
+ &&
fstmt.getInputParams().get(i).getDataType().isScalar() )
+ {
+ ValueType vt =
fstmt.getInputParams().get(i).getValueType();
+ Expression expr =
fstmt.getInputDefaults().get(i);
+ if( expr instanceof ConstIdentifier )
+ ret.add(param,
ScalarObjectFactory.createScalarObject(vt, expr.toString()), null);
+ }
+ }
+
+ return ret;
+ }
+
private static void checkValidArguments(List<Data> loData, List<String>
loNames, List<String> fArgNames) {
//check number of parameters
int listSize = (loNames != null) ? loNames.size() :
loData.size();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index f6ac15a8f7..38288178e4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -120,6 +120,10 @@ public class ListObject extends Data implements
Externalizable {
return _names;
}
+ public void setNames(List<String> names) {
+ _names = names;
+ }
+
public String getName(int ix) {
return (_names == null) ? null : _names.get(ix);
}
@@ -155,6 +159,11 @@ public class ListObject extends Data implements
Externalizable {
(lo == d || ((ListObject)lo).contains(d)) : lo == d);
}
+ public boolean contains(String name) {
+ return _names != null
+ && _names.contains(name);
+ }
+
public long getDataSize() {
return _data.stream().filter(data -> data instanceof
CacheableData)
.mapToLong(data -> ((CacheableData<?>)
data).getDataSize()).sum();
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
index 6cc6411106..5bba5312ef 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
public class BuiltinGridSearchTest extends AutomatedTestBase
{
@@ -71,6 +72,16 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
runGridSearch(TEST_NAME1, ExecMode.HYBRID, true);
}
+ @Test
+ public void testGridSearchLmVerboseCP() {
+ runGridSearch(TEST_NAME1, ExecMode.SINGLE_NODE, false, true);
+ }
+
+ @Test
+ public void testGridSearchLmVerboseHybrid() {
+ runGridSearch(TEST_NAME1, ExecMode.HYBRID, false, true);
+ }
+
@Test
public void testGridSearchLmSpark() {
runGridSearch(TEST_NAME1, ExecMode.SPARK, false);
@@ -86,6 +97,19 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
runGridSearch(TEST_NAME2, ExecMode.HYBRID, false);
}
+ @Test
+ public void testGridSearchMLogregVerboseCP() {
+ //verbose default
+ runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, false, true);
+ }
+
+ @Test
+ public void testGridSearchMLogregVerboseHybrid() {
+ //verbose default
+ runGridSearch(TEST_NAME2, ExecMode.HYBRID, false, true);
+ }
+
+
@Test
public void testGridSearchLm2CP() {
runGridSearch(TEST_NAME3, ExecMode.SINGLE_NODE, false);
@@ -108,20 +132,24 @@ public class BuiltinGridSearchTest extends
AutomatedTestBase
@Test
public void testGridSearchMLogreg4CP() {
- runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false);
+ runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false,
false);
}
@Test
public void testGridSearchMLogreg4Hybrid() {
- runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false);
+ runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false, false);
}
private void runGridSearch(String testname, ExecMode et, boolean
codegen) {
- runGridSearch(testname, et, _cols, 2, codegen); //binary
classification
+ runGridSearch(testname, et, _cols, 2, codegen, false); //binary
classification
}
- private void runGridSearch(String testname, ExecMode et, int cols, int
nc, boolean codegen)
+ private void runGridSearch(String testname, ExecMode et, boolean
codegen, boolean verbose) {
+ runGridSearch(testname, et, _cols, 2, codegen, verbose);
//binary classification
+ }
+
+ private void runGridSearch(String testname, ExecMode et, int cols, int
nc, boolean codegen, boolean verbose)
{
ExecMode modeOld = setExecMode(et);
_codegen = codegen;
@@ -131,7 +159,8 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] {"-args", input("X"),
input("y"), output("R")};
+ programArgs = new String[] {"-stats", "100", "-args",
+ input("X"), input("y"), output("R"),
String.valueOf(verbose).toUpperCase()};
double max = testname.equals(TEST_NAME2) ? nc : 2;
double[][] X = getRandomMatrix(_rows, cols, 0, 1, 0.8,
7);
double[][] y = getRandomMatrix(_rows, 1, 1, max, 1, 1);
@@ -142,6 +171,11 @@ public class BuiltinGridSearchTest extends
AutomatedTestBase
//expected loss smaller than default invocation
Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+
+ //correct handling of verbose flag
+ if( verbose ) // 2 prints outside, if verbose more
+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("print")>100);
+
//Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
//TODO analyze influence of multiple subsequent tests
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
index aba517efbb..e7ccc8d582 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -31,7 +31,8 @@ import org.apache.sysds.utils.Statistics;
public class TransformEncodeUDFTest extends AutomatedTestBase
{
- private final static String TEST_NAME1 = "TransformEncodeUDF1";
+ private final static String TEST_NAME1 = "TransformEncodeUDF1";
//min-max
+ private final static String TEST_NAME2 = "TransformEncodeUDF2"; //scale
w/ defaults
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR +
TransformEncodeUDFTest.class.getSimpleName() + "/";
@@ -42,6 +43,7 @@ public class TransformEncodeUDFTest extends AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
}
@Test
@@ -53,6 +55,17 @@ public class TransformEncodeUDFTest extends AutomatedTestBase
public void testUDF1Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME1);
}
+
+// TODO default handling without named lists
+// @Test
+// public void testUDF2Singlenode() {
+// runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
+// }
+//
+// @Test
+// public void testUDF2Hybrid() {
+// runTransformTest(ExecMode.HYBRID, TEST_NAME2);
+// }
private void runTransformTest(ExecMode rt, String testname)
{
@@ -61,10 +74,10 @@ public class TransformEncodeUDFTest extends
AutomatedTestBase
try
{
- getAndLoadTestConfiguration(TEST_NAME1);
+ getAndLoadTestConfiguration(testname);
String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-explain",
"-nvargs", "DATA=" + DATASET_DIR + DATASET,
"R="+output("R")};
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml
b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 4311eba719..d439a80959 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -25,6 +25,7 @@ l2norm = function(Matrix[Double] X, Matrix[Double] y,
Matrix[Double] B) return (
X = read($1);
y = read($2);
+verbose = $4;
N = 200;
Xtrain = X[1:N,];
@@ -32,10 +33,11 @@ ytrain = y[1:N,];
Xtest = X[(N+1):nrow(X),];
ytest = y[(N+1):nrow(X),];
-params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+args = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
+params = list("reg", "tol", "maxi", "verbose");
+paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3),
as.matrix(as.double(verbose)));
[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
- numB=ncol(X), params=params, paramValues=paramRanges);
+ numB=ncol(X), params=params, paramValues=paramRanges, trainArgs=args);
B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
l1 = l2norm(Xtest, ytest, B1);
diff --git a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
index ce54d5d8be..08aebad83a 100644
--- a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
+++ b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
@@ -36,7 +36,7 @@ ytest = y[(N+1):nrow(X),];
params = list("icpt", "reg", "maxii");
paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
-trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100,
maxii=-1, verbose=FALSE);
+trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100,
maxii=-1);
[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg",
predict="accuracy", numB=(ncol(X)+1)*(nc-1),
params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=TRUE);
B2 = multiLogReg(X=Xtrain, Y=ytrain, verbose=TRUE);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml
b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
similarity index 58%
copy from src/test/scripts/functions/builtin/GridSearchLM.dml
copy to src/test/scripts/functions/transform/TransformEncodeUDF2.dml
index 4311eba719..233f25b73c 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,27 +19,21 @@
#
#-------------------------------------------------------------
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
(Matrix[Double] loss) {
- loss = as.matrix(sum((y - X%*%B)^2));
-}
+F1 = read($DATA, data_type="frame", format="csv");
+
+# reference solution with scale outside transformencode
+jspec = "{ids: true, recode: [1, 2, 7]}";
+[X, M] = transformencode(target=F1, spec=jspec);
+R1 = scale(X=X);
-X = read($1);
-y = read($2);
+while(FALSE){}
-N = 200;
-Xtrain = X[1:N,];
-ytrain = y[1:N,];
-Xtest = X[(N+1):nrow(X),];
-ytest = y[(N+1):nrow(X),];
+# solution with scale applied as UDF inside transformencode
+jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scale, ids: [1, 2, 3, 4,
5, 6, 7, 8, 9]}}";
+[R2, M2] = transformencode(target=F1, spec=jspec2);
-params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
- numB=ncol(X), params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+while(FALSE){}
-l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
+R = sum(R1==R2);
+write(R, $R);
-write(R, $3)