This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new fd343a8  [SYSTEMDS-2932] Fix eval function loading (partially existing 
functions)
fd343a8 is described below

commit fd343a8160c49d64157babd719771a809487ef6b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 17 20:15:50 2021 +0200

    [SYSTEMDS-2932] Fix eval function loading (partially existing functions)
    
    This patch fixes issues of on-demand function loading when called
    through eval(). Specifically, loading complex functions that internally
    call other builtin functions caused problems if a subset of these
    functions were already loaded. We now do a more fine-grained check for
    the entire tree of loaded functions, and add a test with hyperband and
    gridsearch, where hyperband loads lmCG and gridsearch tries to load lm,
    which also tries to bring in lmCG and lmDS.
---
 .../instructions/cp/EvalNaryCPInstruction.java     | 25 ++++++-----
 .../functions/builtin/BuiltinHyperbandTest.java    |  8 ++++
 .../scripts/functions/builtin/HyperbandLM3.dml     | 49 ++++++++++++++++++++++
 3 files changed, 71 insertions(+), 11 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index e2ac467..8fd7867 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -25,6 +25,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.stream.Collectors;
 
 import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
@@ -146,16 +147,20 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                if( fsbs.isEmpty() )
                        throw new DMLRuntimeException("Failed to compile 
function '"+name+"'.");
                
+               DMLProgram dmlp = (prog.getDMLProg() != null) ? 
prog.getDMLProg() :
+                       fsbs.get(Builtins.getInternalFName(name, 
dt)).getDMLProg();
+               
+               //filter already existing functions (e.g., already loaded 
internally-called functions)
+               fsbs = (dmlp.getBuiltinFunctionDictionary() == null) ? fsbs : 
fsbs.entrySet().stream()
+                       .filter(e -> 
!dmlp.getBuiltinFunctionDictionary().containsFunction(e.getKey()))
+                       .collect(Collectors.toMap(e -> e.getKey(), e -> 
e.getValue()));
+               
                // prepare common data structures, including a consolidated dml 
program
                // to facilitate function validation which tries to inline 
lazily loaded
                // and existing functions.
-               DMLProgram dmlp = (prog.getDMLProg() != null) ? 
prog.getDMLProg() :
-                       fsbs.get(Builtins.getInternalFName(name, 
dt)).getDMLProg();
                for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() 
) {
                        dmlp.createNamespace(nsName); // create namespace on 
demand
-                       if( 
!dmlp.getBuiltinFunctionDictionary().containsFunction(fsb.getKey()) ) {
-                               dmlp.addFunctionStatementBlock(nsName, 
fsb.getKey(), fsb.getValue());
-                       }
+                       dmlp.addFunctionStatementBlock(nsName, fsb.getKey(), 
fsb.getValue());
                        fsb.getValue().setDMLProg(dmlp);
                }
                DMLTranslator dmlt = new DMLTranslator(dmlp);
@@ -186,12 +191,10 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                
                // compile runtime program
                for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() 
) {
-                       if( !prog.containsFunctionProgramBlock(null, 
fsb.getKey(), false) ) {
-                               FunctionProgramBlock fpb = 
(FunctionProgramBlock) dmlt
-                                       .createRuntimeProgramBlock(prog, 
fsb.getValue(), ConfigurationManager.getDMLConfig());
-                               prog.addFunctionProgramBlock(nsName, 
fsb.getKey(), fpb, true);  // optimized
-                               prog.addFunctionProgramBlock(nsName, 
fsb.getKey(), fpb, false); // unoptimized -> eval
-                       }
+                       FunctionProgramBlock fpb = (FunctionProgramBlock) dmlt
+                               .createRuntimeProgramBlock(prog, 
fsb.getValue(), ConfigurationManager.getDMLConfig());
+                       prog.addFunctionProgramBlock(nsName, fsb.getKey(), fpb, 
true);  // optimized
+                       prog.addFunctionProgramBlock(nsName, fsb.getKey(), fpb, 
false); // unoptimized -> eval
                }
        }
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
index 45cd8cb..064b83d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
@@ -33,6 +33,8 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
 {
        private final static String TEST_NAME1 = "HyperbandLM";
        private final static String TEST_NAME2 = "HyperbandLM2";
+       private final static String TEST_NAME3 = "HyperbandLM3";
+       
        private final static String TEST_DIR = "functions/builtin/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
BuiltinHyperbandTest.class.getSimpleName() + "/";
        
@@ -43,6 +45,7 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
        public void setUp() {
                addTestConfiguration(TEST_NAME1,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"R"}));
                addTestConfiguration(TEST_NAME2,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
+               addTestConfiguration(TEST_NAME3,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[]{"R"}));
        }
        
        @Test
@@ -56,6 +59,11 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHyperbandNoCompare2CP() {
+               runHyperband(TEST_NAME3, ExecType.CP);
+       }
+       
+       @Test
        public void testHyperbandSpark() {
                runHyperband(TEST_NAME2, ExecType.SPARK);
        }
diff --git a/src/test/scripts/functions/builtin/HyperbandLM3.dml 
b/src/test/scripts/functions/builtin/HyperbandLM3.dml
new file mode 100644
index 0000000..9dbdbba
--- /dev/null
+++ b/src/test/scripts/functions/builtin/HyperbandLM3.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return 
(Matrix[Double] loss) {
+  loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+X = read($1);
+y = read($2);
+numTrSamples = 100;
+numValSamples = 100;
+
+X_train = X[1:numTrSamples,];
+y_train = y[1:numTrSamples,];
+X_val = X[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+y_val = y[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+X_test = X[(numTrSamples+numValSamples+2):nrow(X),];
+y_test = y[(numTrSamples+numValSamples+2):nrow(X),];
+
+params = list("reg");
+paramRanges = matrix("0 20", rows=1, cols=2);
+
+[bestWeights, optHyperParams] = hyperband(X_train=X_train, y_train=y_train,
+  X_val=X_val, y_val=y_val, params=params, paramRanges=paramRanges);
+
+paramRanges2 = list(10^seq(0,-4))
+[bestWeights, optHyperParams2] = gridSearch(X=X_train, y=y_train,
+  train="lm", predict="l2norm", params=params, paramValues=paramRanges2);
+
+print(toString(optHyperParams))
+print(toString(optHyperParams2))

Reply via email to