This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new fd343a8 [SYSTEMDS-2932] Fix eval function loading (partially existing
functions)
fd343a8 is described below
commit fd343a8160c49d64157babd719771a809487ef6b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 17 20:15:50 2021 +0200
[SYSTEMDS-2932] Fix eval function loading (partially existing functions)
This patch fixes issues with on-demand function loading when called
through eval(). Specifically, loading complex functions that internally
call other builtin functions caused problems if a subset of these
functions were already loaded. We now do a more fine-grained check for
the entire tree of loaded functions, and add a test with hyperband and
gridsearch, where hyperband loads lmCG and gridsearch tries to load lm,
which also tries to bring in lmCG and lmDS.
---
.../instructions/cp/EvalNaryCPInstruction.java | 25 ++++++-----
.../functions/builtin/BuiltinHyperbandTest.java | 8 ++++
.../scripts/functions/builtin/HyperbandLM3.dml | 49 ++++++++++++++++++++++
3 files changed, 71 insertions(+), 11 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index e2ac467..8fd7867 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -25,6 +25,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.stream.Collectors;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
@@ -146,16 +147,20 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
if( fsbs.isEmpty() )
throw new DMLRuntimeException("Failed to compile
function '"+name+"'.");
+ DMLProgram dmlp = (prog.getDMLProg() != null) ?
prog.getDMLProg() :
+ fsbs.get(Builtins.getInternalFName(name,
dt)).getDMLProg();
+
+ //filter already existing functions (e.g., already loaded
internally-called functions)
+ fsbs = (dmlp.getBuiltinFunctionDictionary() == null) ? fsbs :
fsbs.entrySet().stream()
+ .filter(e ->
!dmlp.getBuiltinFunctionDictionary().containsFunction(e.getKey()))
+ .collect(Collectors.toMap(e -> e.getKey(), e ->
e.getValue()));
+
// prepare common data structures, including a consolidated dml
program
// to facilitate function validation which tries to inline
lazily loaded
// and existing functions.
- DMLProgram dmlp = (prog.getDMLProg() != null) ?
prog.getDMLProg() :
- fsbs.get(Builtins.getInternalFName(name,
dt)).getDMLProg();
for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet()
) {
dmlp.createNamespace(nsName); // create namespace on
demand
- if(
!dmlp.getBuiltinFunctionDictionary().containsFunction(fsb.getKey()) ) {
- dmlp.addFunctionStatementBlock(nsName,
fsb.getKey(), fsb.getValue());
- }
+ dmlp.addFunctionStatementBlock(nsName, fsb.getKey(),
fsb.getValue());
fsb.getValue().setDMLProg(dmlp);
}
DMLTranslator dmlt = new DMLTranslator(dmlp);
@@ -186,12 +191,10 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
// compile runtime program
for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet()
) {
- if( !prog.containsFunctionProgramBlock(null,
fsb.getKey(), false) ) {
- FunctionProgramBlock fpb =
(FunctionProgramBlock) dmlt
- .createRuntimeProgramBlock(prog,
fsb.getValue(), ConfigurationManager.getDMLConfig());
- prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, true); // optimized
- prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, false); // unoptimized -> eval
- }
+ FunctionProgramBlock fpb = (FunctionProgramBlock) dmlt
+ .createRuntimeProgramBlock(prog,
fsb.getValue(), ConfigurationManager.getDMLConfig());
+ prog.addFunctionProgramBlock(nsName, fsb.getKey(), fpb,
true); // optimized
+ prog.addFunctionProgramBlock(nsName, fsb.getKey(), fpb,
false); // unoptimized -> eval
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
index 45cd8cb..064b83d 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
@@ -33,6 +33,8 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "HyperbandLM";
private final static String TEST_NAME2 = "HyperbandLM2";
+ private final static String TEST_NAME3 = "HyperbandLM3";
+
private final static String TEST_DIR = "functions/builtin/";
private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinHyperbandTest.class.getSimpleName() + "/";
@@ -43,6 +45,7 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
public void setUp() {
addTestConfiguration(TEST_NAME1,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"R"}));
addTestConfiguration(TEST_NAME2,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
+ addTestConfiguration(TEST_NAME3,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[]{"R"}));
}
@Test
@@ -56,6 +59,11 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
}
@Test
+ public void testHyperbandNoCompare2CP() {
+ runHyperband(TEST_NAME3, ExecType.CP);
+ }
+
+ @Test
public void testHyperbandSpark() {
runHyperband(TEST_NAME2, ExecType.SPARK);
}
diff --git a/src/test/scripts/functions/builtin/HyperbandLM3.dml
b/src/test/scripts/functions/builtin/HyperbandLM3.dml
new file mode 100644
index 0000000..9dbdbba
--- /dev/null
+++ b/src/test/scripts/functions/builtin/HyperbandLM3.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
(Matrix[Double] loss) {
+ loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+X = read($1);
+y = read($2);
+numTrSamples = 100;
+numValSamples = 100;
+
+X_train = X[1:numTrSamples,];
+y_train = y[1:numTrSamples,];
+X_val = X[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+y_val = y[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+X_test = X[(numTrSamples+numValSamples+2):nrow(X),];
+y_test = y[(numTrSamples+numValSamples+2):nrow(X),];
+
+params = list("reg");
+paramRanges = matrix("0 20", rows=1, cols=2);
+
+[bestWeights, optHyperParams] = hyperband(X_train=X_train, y_train=y_train,
+ X_val=X_val, y_val=y_val, params=params, paramRanges=paramRanges);
+
+paramRanges2 = list(10^seq(0,-4))
+[bestWeights, optHyperParams2] = gridSearch(X=X_train, y=y_train,
+ train="lm", predict="l2norm", params=params, paramValues=paramRanges2);
+
+print(toString(optHyperParams))
+print(toString(optHyperParams2))