This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4d48c00 [SYSTEMDS-2935] Fix race condition in eval function loading
4d48c00 is described below
commit 4d48c0040d6f561bfe599c32b70c559e5e3ac491
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 17 00:32:58 2021 +0200
[SYSTEMDS-2935] Fix race condition in eval function loading
This patch fixes a race condition where multiple parfor workers call
eval on a function that is not yet loaded, concurrently parsing it and
modifying the shared program. Due to parfor retries, this issue did not
cause end-to-end test failures, but it produced confusing error output.
We now synchronize on the shared program, add a dedicated test without
retries, and clean up some memory estimates related to lists.
---
src/main/java/org/apache/sysds/hops/Hop.java | 2 +-
src/main/java/org/apache/sysds/hops/NaryOp.java | 2 +-
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 33 +++++++++---------
.../runtime/controlprogram/ParForProgramBlock.java | 2 +-
.../instructions/cp/EvalNaryCPInstruction.java | 10 +++---
.../spark/CSVReblockSPInstruction.java | 6 +---
.../spark/utils/RDDConverterUtils.java | 4 ---
.../apache/sysds/runtime/util/UtilFunctions.java | 1 -
.../functions/builtin/BuiltinHyperbandTest.java | 31 ++++++++++++-----
.../pipelines/CleaningTestClassification.java | 1 -
.../scripts/functions/builtin/HyperbandLM2.dml | 40 ++++++++++++++++++++++
11 files changed, 88 insertions(+), 44 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index dcd258e..c4c2b5e 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -906,7 +906,7 @@ public abstract class Hop implements ParseInfo {
public boolean dimsKnown() {
return ( _dataType == DataType.SCALAR
- || ((_dataType==DataType.MATRIX ||
_dataType==DataType.FRAME)
+ || ((_dataType==DataType.MATRIX ||
_dataType==DataType.FRAME || _dataType==DataType.LIST)
&& _dc.rowsKnown() && _dc.colsKnown()) );
}
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java
b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 4ff005c..29dbea7 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -126,7 +126,7 @@ public class NaryOp extends Hop {
super.computeMemEstimate(memo);
//specific case for function call
- if( _op == OpOpN.EVAL ) {
+ if( _op == OpOpN.EVAL || _op == OpOpN.LIST ) {
_memEstimate = OptimizerUtils.INT_SIZE;
_outputMemEstimate = OptimizerUtils.INT_SIZE;
_processingMemEstimate = 0;
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index ee54561..18076be 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -468,18 +468,18 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
if (numNonZeroes < 0)
numNonZeroes = specifiedRows * specifiedCols;
long numRows = getInput().get(0).getDim1();
- if (numRows < 0) // If number of rows is not
known, set to default
+ if (numRows < 0) // If number of rows is not known, set
to default
numRows = specifiedRows;
long numCols = getInput().get(0).getDim2();
- if (numCols < 0) // If number of columns is not
known, set to default
+ if (numCols < 0) // If number of columns is not known,
set to default
numCols = specifiedCols;
// Assume Defaults : 100 * 100, sep = " ", linesep =
"\n", sparse = false
// String size in bytes is 36 + number_of_chars * 2
final long DEFAULT_SIZE = 36 + 2 *
- (100 * 100 * AVERAGE_CHARS_PER_VALUE
// Length for digits
- + 1 * 100 * 99
// Length for separator chars
- + 1* 100) ;
// Length for line separator chars
+ (100 * 100 * AVERAGE_CHARS_PER_VALUE //
Length for digits
+ + 1 * 100 * 99 //
Length for separator chars
+ + 1* 100); //
Length for line separator chars
try {
@@ -507,14 +507,14 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
long numberOfChars = -1;
if (sparsePrint){
- numberOfChars = AVERAGE_CHARS_PER_VALUE
* numNonZeroes // Length for value digits
- +
AVERAGE_CHARS_PER_INDEX * 2L * numNonZeroes // Length for row & column index
- +
sep.length() * 2L * numNonZeroes // Length for
separator chars
- +
linesep.length() * numNonZeroes; // Length for
line separator chars
+ numberOfChars = AVERAGE_CHARS_PER_VALUE
* numNonZeroes // Length for value digits
+ +
AVERAGE_CHARS_PER_INDEX * 2L * numNonZeroes // Length for row & column index
+ +
sep.length() * 2L * numNonZeroes // Length for separator chars
+ +
linesep.length() * numNonZeroes; // Length for line separator chars
} else {
- numberOfChars = AVERAGE_CHARS_PER_VALUE
* numRows * numCols // Length for digits
- +
sep.length() * numRows * (numCols - 1) // Length for separator
chars
- +
linesep.length() * numRows; // Length for
line separator chars
+ numberOfChars = AVERAGE_CHARS_PER_VALUE
* numRows * numCols // Length for digits
+ +
sep.length() * numRows * (numCols - 1) // Length for separator chars
+ +
linesep.length() * numRows; // Length for line separator chars
}
/*
@@ -528,14 +528,13 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
*/
return (36 + numberOfChars * 2);
-
- } catch (HopsException e){
+ }
+ catch (HopsException e){
LOG.warn("Invalid values when trying to compute
dims1, dims2 & nnz", e);
-
return DEFAULT_SIZE;
}
-
- } else {
+ }
+ else {
double sparsity = OptimizerUtils.getSparsity(dim1,
dim2, nnz);
return OptimizerUtils.estimateSizeExactSparsity(dim1,
dim2, sparsity);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 0c97d0a..7e318ae 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -293,7 +293,7 @@ public class ParForProgramBlock extends ForProgramBlock
public static final boolean CREATE_UNSCOPED_RESULTVARS = true;
public static boolean ALLOW_REUSE_PARTITION_VARS = true; //reuse
partition input matrices, applied only if read-only in surrounding loops
public static final int WRITE_REPLICATION_FACTOR = 1;
- public static final int MAX_RETRYS_ON_ERROR = 1;
+ public static int MAX_RETRYS_ON_ERROR = 1;
public static final boolean FORCE_CP_ON_REMOTE_SPARK = true; //
compile body to CP if exec type forced to Spark
public static final boolean LIVEVAR_AWARE_EXPORT = true; //
export only read variables according to live variable analysis
public static final boolean RESET_RECOMPILATION_FLAGs = true;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 9781d88..e2ac467 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -80,8 +80,10 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
String funcName2 = Builtins.getInternalFName(funcName, dt1);
if( !ec.getProgram().containsFunctionProgramBlock(nsName,
funcName)) {
nsName = DMLProgram.BUILTIN_NAMESPACE;
- if(
!ec.getProgram().containsFunctionProgramBlock(nsName, funcName2) )
- compileFunctionProgramBlock(funcName, dt1,
ec.getProgram());
+ synchronized(ec.getProgram()) { //prevent concurrent
recompile/prog modify
+ if(
!ec.getProgram().containsFunctionProgramBlock(nsName, funcName2) )
+ compileFunctionProgramBlock(funcName,
dt1, ec.getProgram());
+ }
funcName = funcName2;
}
@@ -187,8 +189,8 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
if( !prog.containsFunctionProgramBlock(null,
fsb.getKey(), false) ) {
FunctionProgramBlock fpb =
(FunctionProgramBlock) dmlt
.createRuntimeProgramBlock(prog,
fsb.getValue(), ConfigurationManager.getDMLConfig());
- prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, true); // optimized
- prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, false); // unoptimized -> eval
+ prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, true); // optimized
+ prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, false); // unoptimized -> eval
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
index 165fbf5..f0c81f4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
@@ -22,8 +22,6 @@ package org.apache.sysds.runtime.instructions.spark;
import java.util.HashSet;
import java.util.Set;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.JavaPairRDD;
@@ -50,9 +48,7 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.Statistics;
public class CSVReblockSPInstruction extends UnarySPInstruction {
-
- private static final Log LOG =
LogFactory.getLog(CSVReblockSPInstruction.class.getName());
-
+
private int _blen;
private boolean _hasHeader;
private String _delim;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
index 3e3939a..4fc175d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
@@ -26,8 +26,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.Set;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
@@ -79,8 +77,6 @@ import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;
public class RDDConverterUtils {
- private static final Log LOG =
LogFactory.getLog(RDDConverterUtils.class.getName());
-
public static final String DF_ID_COLUMN = "__INDEX";
public static JavaPairRDD<MatrixIndexes, MatrixBlock>
textCellToBinaryBlock(JavaSparkContext sc,
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index b6849ee..f038994 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -37,7 +37,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
index 80793c6..45cd8cb 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.builtin;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -30,7 +31,8 @@ import org.junit.Test;
public class BuiltinHyperbandTest extends AutomatedTestBase
{
- private final static String TEST_NAME = "HyperbandLM";
+ private final static String TEST_NAME1 = "HyperbandLM";
+ private final static String TEST_NAME2 = "HyperbandLM2";
private final static String TEST_DIR = "functions/builtin/";
private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinHyperbandTest.class.getSimpleName() + "/";
@@ -39,26 +41,35 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"R"}));
+ addTestConfiguration(TEST_NAME1,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"R"}));
+ addTestConfiguration(TEST_NAME2,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
}
@Test
public void testHyperbandCP() {
- runHyperband(ExecType.CP);
+ runHyperband(TEST_NAME1, ExecType.CP);
+ }
+
+ @Test
+ public void testHyperbandNoCompareCP() {
+ runHyperband(TEST_NAME2, ExecType.CP);
}
@Test
public void testHyperbandSpark() {
- runHyperband(ExecType.SPARK);
+ runHyperband(TEST_NAME2, ExecType.SPARK);
}
- private void runHyperband(ExecType et) {
+ private void runHyperband(String testname, ExecType et) {
ExecMode modeOld = setExecMode(et);
+ int retries = ParForProgramBlock.MAX_RETRYS_ON_ERROR;
+
try {
- loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ loadTestConfiguration(getTestConfiguration(testname));
String HOME = SCRIPT_DIR + TEST_DIR;
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ ParForProgramBlock.MAX_RETRYS_ON_ERROR = 0;
+
+ fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] {"-stats","-args",
input("X"), input("y"), output("R")};
double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8,
3);
double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, 7);
@@ -68,9 +79,11 @@ public class BuiltinHyperbandTest extends AutomatedTestBase
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
//expected loss smaller than default invocation
-
Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+ if( testname.equals(TEST_NAME1) )
+
Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
}
finally {
+ ParForProgramBlock.MAX_RETRYS_ON_ERROR = retries;
resetExecMode(modeOld);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/pipelines/CleaningTestClassification.java
b/src/test/java/org/apache/sysds/test/functions/pipelines/CleaningTestClassification.java
index 0f5f556..1a6e4b7 100644
---
a/src/test/java/org/apache/sysds/test/functions/pipelines/CleaningTestClassification.java
+++
b/src/test/java/org/apache/sysds/test/functions/pipelines/CleaningTestClassification.java
@@ -50,7 +50,6 @@ public class CleaningTestClassification extends
AutomatedTestBase {
addTestConfiguration(TEST_NAME2,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
}
-
@Ignore
public void testCP1() {
runFindPipelineTest(0.1, 5,10, 2,
diff --git a/src/test/scripts/functions/builtin/HyperbandLM2.dml
b/src/test/scripts/functions/builtin/HyperbandLM2.dml
new file mode 100644
index 0000000..0c851d0
--- /dev/null
+++ b/src/test/scripts/functions/builtin/HyperbandLM2.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+y = read($2);
+numTrSamples = 100;
+numValSamples = 100;
+
+X_train = X[1:numTrSamples,];
+y_train = y[1:numTrSamples,];
+X_val = X[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+y_val = y[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+X_test = X[(numTrSamples+numValSamples+2):nrow(X),];
+y_test = y[(numTrSamples+numValSamples+2):nrow(X),];
+
+params = list("reg");
+paramRanges = matrix("0 20", rows=1, cols=2);
+
+[bestWeights, optHyperParams] = hyperband(X_train=X_train, y_train=y_train,
+ X_val=X_val, y_val=y_val, params=params, paramRanges=paramRanges);
+
+print(toString(optHyperParams))