This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 33effa8840 [SYSTEMDS-3343] Fix missing default handling in eval 
function calls
33effa8840 is described below

commit 33effa88408b6ed25def675c00ccbf02e05ea75b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat May 7 17:38:01 2022 +0200

    [SYSTEMDS-3343] Fix missing default handling in eval function calls
    
    This patch generalizes the existing eval function call mechanics by
    handling default arguments. For passed named lists, we now, in a
    best-effort manner, append constant scalar defaults if they do not
    exist in the passed list before checking for valid input arguments.
    
    In a second step, this will also enable default parameters in
    transformencode UDF functions.
---
 scripts/builtin/gridSearch.dml                     |  4 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |  3 ++
 .../instructions/cp/EvalNaryCPInstruction.java     | 27 +++++++++++++
 .../sysds/runtime/instructions/cp/ListObject.java  |  9 +++++
 .../builtin/part1/BuiltinGridSearchTest.java       | 44 +++++++++++++++++++---
 .../transform/TransformEncodeUDFTest.java          | 19 ++++++++--
 .../scripts/functions/builtin/GridSearchLM.dml     |  8 ++--
 .../functions/builtin/GridSearchMLogreg.dml        |  2 +-
 .../TransformEncodeUDF2.dml}                       | 36 ++++++++----------
 9 files changed, 117 insertions(+), 35 deletions(-)

diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index eb1f16e6c5..8e53502257 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -65,9 +65,9 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, 
String train, String
 {
   # Step 0) handling default arguments, which require access to passed data
   if( length(trainArgs) == 0 )
-    trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
+    trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1);
   if( length(dataArgs) == 0 )
-    dataArgs = list("X", "y");  
+    dataArgs = list("X", "y");
   if( length(predictArgs) == 0 )
     predictArgs = list(X, y);
   if( cv & cvk <= 1 ) {
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index c12e5e69d0..4d280a371a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -659,6 +659,9 @@ public class DMLTranslator
                        
                        retPB = rtpb;
                        
+                       // add statement block
+                       retPB.setStatementBlock(sb);
+                       
                        // add location information
                        retPB.setParseInfo(sb);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 6b8f890953..5c55264627 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -30,13 +30,17 @@ import java.util.stream.Collectors;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.parser.ConstIdentifier;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.Expression;
+import org.apache.sysds.parser.FunctionStatement;
 import org.apache.sysds.parser.FunctionStatementBlock;
 import org.apache.sysds.parser.dml.DmlSyntacticValidator;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -136,6 +140,7 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                        && !(fpb.getInputParams().size() == 1 && 
fpb.getInputParams().get(0).getDataType().isList()))
                {
                        ListObject lo = ec.getListObject(boundInputs[0]);
+                       lo = appendNamedDefaults(lo, 
(FunctionStatement)fpb.getStatementBlock().getStatement(0));
                        checkValidArguments(lo.getData(), lo.getNames(), 
fpb.getInputParamNames());
                        if( lo.isNamedList() )
                                lo = reorderNamedListForFunctionCall(lo, 
fpb.getInputParamNames());
@@ -271,6 +276,28 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                }
        }
        
+       private static ListObject appendNamedDefaults(ListObject params, 
FunctionStatement fstmt) {
+               if( !params.isNamedList() )
+                       return params;
+               
+               //best effort replacement of scalar literal defaults
+               ListObject ret = new ListObject(params);
+               for( int i=0; i<fstmt.getInputParams().size(); i++ ) {
+                       String param = fstmt.getInputParamNames()[i];
+                       if( !ret.contains(param)
+                               && fstmt.getInputDefaults().get(i) != null
+                               && 
fstmt.getInputParams().get(i).getDataType().isScalar() )
+                       {
+                               ValueType vt = 
fstmt.getInputParams().get(i).getValueType();
+                               Expression expr = 
fstmt.getInputDefaults().get(i);
+                               if( expr instanceof ConstIdentifier )
+                                       ret.add(param, 
ScalarObjectFactory.createScalarObject(vt, expr.toString()), null);
+                       }
+               }
+               
+               return ret;
+       }
+       
        private static void checkValidArguments(List<Data> loData, List<String> 
loNames, List<String> fArgNames) {
                //check number of parameters
                int listSize = (loNames != null) ? loNames.size() : 
loData.size();
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index f6ac15a8f7..38288178e4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -120,6 +120,10 @@ public class ListObject extends Data implements 
Externalizable {
                return _names;
        }
 
+       public void setNames(List<String> names) {
+               _names = names;
+       }
+       
        public String getName(int ix) {
                return (_names == null) ? null : _names.get(ix);
        }
@@ -155,6 +159,11 @@ public class ListObject extends Data implements 
Externalizable {
                        (lo == d || ((ListObject)lo).contains(d)) : lo == d);
        }
        
+       public boolean contains(String name) {
+               return _names != null
+                       && _names.contains(name);
+       }
+       
        public long getDataSize() {
                return _data.stream().filter(data -> data instanceof 
CacheableData)
                        .mapToLong(data -> ((CacheableData<?>) 
data).getDataSize()).sum();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
index 6cc6411106..5bba5312ef 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
 
 public class BuiltinGridSearchTest extends AutomatedTestBase
 {
@@ -71,6 +72,16 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                runGridSearch(TEST_NAME1, ExecMode.HYBRID, true);
        }
        
+       @Test
+       public void testGridSearchLmVerboseCP() {
+               runGridSearch(TEST_NAME1, ExecMode.SINGLE_NODE, false, true);
+       }
+       
+       @Test
+       public void testGridSearchLmVerboseHybrid() {
+               runGridSearch(TEST_NAME1, ExecMode.HYBRID, false, true);
+       }
+       
        @Test
        public void testGridSearchLmSpark() {
                runGridSearch(TEST_NAME1, ExecMode.SPARK, false);
@@ -86,6 +97,19 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                runGridSearch(TEST_NAME2, ExecMode.HYBRID, false);
        }
        
+       @Test
+       public void testGridSearchMLogregVerboseCP() {
+               //verbose default
+               runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, false, true);
+       }
+       
+       @Test
+       public void testGridSearchMLogregVerboseHybrid() {
+               //verbose default
+               runGridSearch(TEST_NAME2, ExecMode.HYBRID, false, true);
+       }
+       
+       
        @Test
        public void testGridSearchLm2CP() {
                runGridSearch(TEST_NAME3, ExecMode.SINGLE_NODE, false);
@@ -108,20 +132,24 @@ public class BuiltinGridSearchTest extends 
AutomatedTestBase
        
        @Test
        public void testGridSearchMLogreg4CP() {
-               runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false);
+               runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false, 
false);
        }
        
        @Test
        public void testGridSearchMLogreg4Hybrid() {
-               runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false);
+               runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false, false);
        }
        
        
        private void runGridSearch(String testname, ExecMode et, boolean 
codegen) {
-               runGridSearch(testname, et, _cols, 2, codegen); //binary 
classification
+               runGridSearch(testname, et, _cols, 2, codegen, false); //binary 
classification
        }
        
-       private void runGridSearch(String testname, ExecMode et, int cols, int 
nc, boolean codegen)
+       private void runGridSearch(String testname, ExecMode et, boolean 
codegen, boolean verbose) {
+               runGridSearch(testname, et, _cols, 2, codegen, verbose); 
//binary classification
+       }
+       
+       private void runGridSearch(String testname, ExecMode et, int cols, int 
nc, boolean codegen, boolean verbose)
        {
                ExecMode modeOld = setExecMode(et);
                _codegen = codegen;
@@ -131,7 +159,8 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                        String HOME = SCRIPT_DIR + TEST_DIR;
        
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] {"-args", input("X"), 
input("y"), output("R")};
+                       programArgs = new String[] {"-stats", "100", "-args",
+                               input("X"), input("y"), output("R"), 
String.valueOf(verbose).toUpperCase()};
                        double max = testname.equals(TEST_NAME2) ? nc : 2;
                        double[][] X = getRandomMatrix(_rows, cols, 0, 1, 0.8, 
7);
                        double[][] y = getRandomMatrix(_rows, 1, 1, max, 1, 1);
@@ -142,6 +171,11 @@ public class BuiltinGridSearchTest extends 
AutomatedTestBase
                        
                        //expected loss smaller than default invocation
                        
Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+                       
+                       //correct handling of verbose flag
+                       if( verbose ) // 2 prints outside, if verbose more
+                               
Assert.assertTrue(Statistics.getCPHeavyHitterCount("print")>100);
+                       
                        //Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
                        //TODO analyze influence of multiple subsequent tests
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
index aba517efbb..e7ccc8d582 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -31,7 +31,8 @@ import org.apache.sysds.utils.Statistics;
 
 public class TransformEncodeUDFTest extends AutomatedTestBase 
 {
-       private final static String TEST_NAME1 = "TransformEncodeUDF1";
+       private final static String TEST_NAME1 = "TransformEncodeUDF1"; 
//min-max
+       private final static String TEST_NAME2 = "TransformEncodeUDF2"; //scale 
w/ defaults
        private final static String TEST_DIR = "functions/transform/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformEncodeUDFTest.class.getSimpleName() + "/";
        
@@ -42,6 +43,7 @@ public class TransformEncodeUDFTest extends AutomatedTestBase
        public void setUp()  {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
        }
        
        @Test
@@ -53,6 +55,17 @@ public class TransformEncodeUDFTest extends AutomatedTestBase
        public void testUDF1Hybrid() {
                runTransformTest(ExecMode.HYBRID, TEST_NAME1);
        }
+
+// TODO default handling without named lists
+//     @Test
+//     public void testUDF2Singlenode() {
+//             runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
+//     }
+//     
+//     @Test
+//     public void testUDF2Hybrid() {
+//             runTransformTest(ExecMode.HYBRID, TEST_NAME2);
+//     }
        
        private void runTransformTest(ExecMode rt, String testname)
        {
@@ -61,10 +74,10 @@ public class TransformEncodeUDFTest extends 
AutomatedTestBase
                
                try
                {
-                       getAndLoadTestConfiguration(TEST_NAME1);
+                       getAndLoadTestConfiguration(testname);
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       fullDMLScriptName = HOME + testname + ".dml";
                        programArgs = new String[]{"-explain",
                                "-nvargs", "DATA=" + DATASET_DIR + DATASET, 
"R="+output("R")};
 
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml 
b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 4311eba719..d439a80959 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -25,6 +25,7 @@ l2norm = function(Matrix[Double] X, Matrix[Double] y, 
Matrix[Double] B) return (
 
 X = read($1);
 y = read($2);
+verbose = $4;
 
 N = 200;
 Xtrain = X[1:N,];
@@ -32,10 +33,11 @@ ytrain = y[1:N,];
 Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
-params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+args = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
+params = list("reg", "tol", "maxi", "verbose");
+paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3), 
as.matrix(as.double(verbose)));
 [B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
-  numB=ncol(X), params=params, paramValues=paramRanges);
+  numB=ncol(X), params=params, paramValues=paramRanges, trainArgs=args);
 B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
 
 l1 = l2norm(Xtest, ytest, B1);
diff --git a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml 
b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
index ce54d5d8be..08aebad83a 100644
--- a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
+++ b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
@@ -36,7 +36,7 @@ ytest = y[(N+1):nrow(X),];
 
 params = list("icpt", "reg", "maxii");
 paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
-trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, 
maxii=-1, verbose=FALSE);
+trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, 
maxii=-1);
 [B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", 
predict="accuracy", numB=(ncol(X)+1)*(nc-1),
   params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=TRUE);
 B2 = multiLogReg(X=Xtrain, Y=ytrain, verbose=TRUE);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml 
b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
similarity index 58%
copy from src/test/scripts/functions/builtin/GridSearchLM.dml
copy to src/test/scripts/functions/transform/TransformEncodeUDF2.dml
index 4311eba719..233f25b73c 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,27 +19,21 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return 
(Matrix[Double] loss) {
-  loss = as.matrix(sum((y - X%*%B)^2));
-}
+F1 = read($DATA, data_type="frame", format="csv");
+
+# reference solution with scale outside transformencode
+jspec = "{ids: true, recode: [1, 2, 7]}";
+[X, M] = transformencode(target=F1, spec=jspec);
+R1 = scale(X=X);
 
-X = read($1);
-y = read($2);
+while(FALSE){}
 
-N = 200;
-Xtrain = X[1:N,];
-ytrain = y[1:N,];
-Xtest = X[(N+1):nrow(X),];
-ytest = y[(N+1):nrow(X),];
+# solution under test with scale applied as a UDF inside transformencode
+jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scale, ids: [1, 2, 3, 4, 
5, 6, 7, 8, 9]}}";
+[R2, M2] = transformencode(target=F1, spec=jspec2);
 
-params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
-  numB=ncol(X), params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+while(FALSE){}
 
-l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
+R = sum(R1==R2);
+write(R, $R);
 
-write(R, $3)

Reply via email to