This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 275d442 [SYSTEMDS-3010] Add lineage support for eval and list
operations
275d442 is described below
commit 275d4423a69252d3af9d568e62450262acf9e2ce
Author: arnabp <[email protected]>
AuthorDate: Mon Jun 7 14:17:51 2021 +0200
[SYSTEMDS-3010] Add lineage support for eval and list operations
This patch adds lineage support for eval and list operations.
Every list object materializes the lineage traces corresponding
to the data items. List operations then maintain the lineage items.
The eval instruction gathers the lineage items from the argument
list and passes them to the function.
In addition, this patch adds tests to apply lineage-based full reuse
on the gridSearch builtin for LM and MultiLogReg.
---
.../sysds/runtime/instructions/Instruction.java | 4 +-
.../instructions/cp/EvalNaryCPInstruction.java | 15 +++-
.../instructions/cp/FunctionCallCPInstruction.java | 18 ++++-
.../instructions/cp/ListIndexingCPInstruction.java | 13 +++-
.../sysds/runtime/instructions/cp/ListObject.java | 19 +++++
.../cp/ParameterizedBuiltinCPInstruction.java | 24 +++++-
.../cp/ScalarBuiltinNaryCPInstruction.java | 8 +-
.../org/apache/sysds/runtime/lineage/Lineage.java | 4 +
.../apache/sysds/runtime/lineage/LineageCache.java | 13 ++++
.../sysds/runtime/lineage/LineageCacheConfig.java | 4 +-
.../apache/sysds/runtime/lineage/LineageMap.java | 3 +-
.../test/functions/lineage/LineageReuseAlg.java | 3 +-
...eageReuseAlg.java => LineageReuseEvalTest.java} | 88 +++++++---------------
.../scripts/functions/builtin/GridSearchLM2.dml | 2 +-
.../LineageReuseEval1.dml} | 19 ++---
.../LineageReuseEval2.dml} | 30 ++++----
16 files changed, 156 insertions(+), 111 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index 7be54e0..e0fbeca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -229,11 +229,11 @@ public abstract class Instruction
* @param ec execution context
* @return instruction
*/
- public Instruction preprocessInstruction(ExecutionContext ec){
+ public Instruction preprocessInstruction(ExecutionContext ec) {
// Lineage tracing
if (DMLScript.LINEAGE)
ec.traceLineage(this);
- //return instruction ifself
+ //return the instruction itself
return this;
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 0970a37..4548b7b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -27,6 +27,7 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.conf.ConfigurationManager;
@@ -46,6 +47,7 @@ import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.DataConverter;
@@ -120,6 +122,7 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
//4. expand list arguments if needed
CPOperand[] boundInputs2 = null;
+ LineageItem[] lineageInputs = null;
if( boundInputs.length == 1 &&
boundInputs[0].getDataType().isList()
&& !(fpb.getInputParams().size() == 1 &&
fpb.getInputParams().get(0).getDataType().isList()))
{
@@ -135,11 +138,13 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
boundInputs2[i] = new CPOperand(varName, in);
}
boundInputs = boundInputs2;
+ lineageInputs = DMLScript.LINEAGE
+ ? lo.getLineageItems().toArray(new
LineageItem[lo.getLength()]) : null;
}
//5. call the function (to unoptimized function)
FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(nsName, funcName,
- false, boundInputs, fpb.getInputParamNames(),
boundOutputNames, "eval func");
+ false, boundInputs, lineageInputs,
fpb.getInputParamNames(), boundOutputNames, "eval func");
fcpi.processInstruction(ec);
//6. convert the result to matrix
@@ -251,8 +256,12 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
private static ListObject reorderNamedListForFunctionCall(ListObject
in, List<String> fArgNames) {
List<Data> sortedData = new ArrayList<>();
- for( String name : fArgNames )
+ List<LineageItem> sortedLI = DMLScript.LINEAGE ? new
ArrayList<>() : null;
+ for( String name : fArgNames ) {
sortedData.add(in.getData(name));
- return new ListObject(sortedData, new ArrayList<>(fArgNames));
+ if (DMLScript.LINEAGE)
+ sortedLI.add(in.getLineageItem(name));
+ }
+ return new ListObject(sortedData, new ArrayList<>(fArgNames),
sortedLI);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index f8e0e6a..df46b35 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -56,23 +56,31 @@ public class FunctionCallCPInstruction extends
CPInstruction {
private final String _namespace;
private final boolean _opt;
private final CPOperand[] _boundInputs;
+ private final LineageItem[] _lineageInputs;
private final List<String> _boundInputNames;
private final List<String> _funArgNames;
private final List<String> _boundOutputNames;
public FunctionCallCPInstruction(String namespace, String functName,
boolean opt,
- CPOperand[] boundInputs, List<String> funArgNames, List<String>
boundOutputNames, String istr) {
+ CPOperand[] boundInputs, LineageItem[] lineageInputs,
List<String> funArgNames,
+ List<String> boundOutputNames, String istr) {
super(CPType.FCall, null, functName, istr);
_functionName = functName;
_namespace = namespace;
_opt = opt;
_boundInputs = boundInputs;
+ _lineageInputs = lineageInputs;
_boundInputNames = Arrays.stream(boundInputs).map(i ->
i.getName())
.collect(Collectors.toCollection(ArrayList::new));
_funArgNames = funArgNames;
_boundOutputNames = boundOutputNames;
}
+ public FunctionCallCPInstruction(String namespace, String functName,
boolean opt,
+ CPOperand[] boundInputs, List<String> funArgNames, List<String>
boundOutputNames, String istr) {
+ this(namespace, functName, opt, boundInputs, null, funArgNames,
boundOutputNames, istr);
+ }
+
public String getFunctionName() {
return _functionName;
}
@@ -125,8 +133,10 @@ public class FunctionCallCPInstruction extends
CPInstruction {
}
// check if function outputs can be reused from cache
- LineageItem[] liInputs = LineageCacheConfig.isMultiLevelReuse()
|| DMLScript.LINEAGE_ESTIMATE ?
- LineageItemUtils.getLineage(ec, _boundInputs) : null;
+ LineageItem[] liInputs = _lineageInputs;
+ if (_lineageInputs == null)
+ liInputs = (LineageCacheConfig.isMultiLevelReuse() ||
DMLScript.LINEAGE_ESTIMATE)
+ ? LineageItemUtils.getLineage(ec, _boundInputs)
: null;
if (!fpb.isNondeterministic() && reuseFunctionOutputs(liInputs,
fpb, ec))
return; //only if all the outputs are found in cache
@@ -164,7 +174,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
//map lineage to function arguments
if( lineage != null ) {
- LineageItem litem = ec.getLineageItem(input);
+ LineageItem litem = _lineageInputs == null ?
ec.getLineageItem(input) : _lineageInputs[i];
lineage.set(currFormalParam.getName(),
(litem!=null) ?
litem :
ec.getLineage().getOrCreate(input));
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
index bf45f8c..86b5654 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -67,6 +68,7 @@ public final class ListIndexingCPInstruction extends
IndexingCPInstruction {
//execute right indexing operation and set output
if( input2.getDataType().isList() ) { //LIST <- LIST
+ //TODO: copy the lineage trace of input2 list
to input1 list
ListObject rin = (ListObject)
ec.getVariable(input2.getName());
if( rl.getValueType()==ValueType.STRING ||
ru.getValueType()==ValueType.STRING )
ec.setVariable(output.getName(),
lin.copy().set(rl.getStringValue(), ru.getStringValue(), rin));
@@ -75,18 +77,21 @@ public final class ListIndexingCPInstruction extends
IndexingCPInstruction {
}
else if( input2.getDataType().isScalar() ) { //LIST <-
SCALAR
ScalarObject scalar = ec.getScalarInput(input2);
+ //LineageItem li = DMLScript.LINEAGE ?
LineageItemUtils.getLineage(ec, input2)[0] : null;
+ LineageItem li = DMLScript.LINEAGE ?
ec.getLineage().getOrCreate(input2) : null;
if( rl.getValueType()==ValueType.STRING )
- ec.setVariable(output.getName(),
lin.copy().set(rl.getStringValue(), scalar));
+ ec.setVariable(output.getName(),
lin.copy().set(rl.getStringValue(), scalar, li));
else
- ec.setVariable(output.getName(),
lin.copy().set((int)rl.getLongValue()-1, scalar));
+ ec.setVariable(output.getName(),
lin.copy().set((int)rl.getLongValue()-1, scalar, li));
}
else if( input2.getDataType().isMatrix() ) { //LIST <-
MATRIX/FRAME
CacheableData<?> dat =
ec.getCacheableData(input2);
dat.enableCleanup(false);
+ LineageItem li = DMLScript.LINEAGE ?
ec.getLineage().get(input2) : null;
if( rl.getValueType()==ValueType.STRING )
- ec.setVariable(output.getName(),
lin.copy().set(rl.getStringValue(), dat));
+ ec.setVariable(output.getName(),
lin.copy().set(rl.getStringValue(), dat, li));
else
- ec.setVariable(output.getName(),
lin.copy().set((int)rl.getLongValue()-1, dat));
+ ec.setVariable(output.getName(),
lin.copy().set((int)rl.getLongValue()-1, dat, li));
}
else {
throw new DMLRuntimeException("Unsupported list
"
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index edfd6cc..f6ac15a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -140,6 +140,12 @@ public class ListObject extends Data implements
Externalizable {
return slice(name);
}
+ public LineageItem getLineageItem(String name) {
+ //lookup position by name, incl error handling
+ int pos = getPosForName(name);
+ return getLineageItem(pos);
+ }
+
public List<LineageItem> getLineageItems() {
return _lineage;
}
@@ -206,6 +212,12 @@ public class ListObject extends Data implements
Externalizable {
_data.set(ix, data);
return this;
}
+
+ public ListObject set(int ix, Data data, LineageItem li) {
+ _data.set(ix, data);
+ if (li != null) _lineage.set(ix, li);
+ return this;
+ }
public ListObject set(int ix1, int ix2, ListObject data) {
int range = ix2 - ix1 + 1;
@@ -242,6 +254,13 @@ public class ListObject extends Data implements
Externalizable {
return set(pos, data);
}
+ public Data set(String name, Data data, LineageItem li) {
+ //lookup position by name, incl error handling
+ int pos = getPosForName(name);
+ //set entry into position
+ return set(pos, data, li);
+ }
+
public ListObject set(String name1, String name2, ListObject data) {
//lookup positions by name, incl error handling
int pos1 = getPosForName(name1);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index c8a1166..54a5339 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -30,6 +30,7 @@ import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
@@ -389,8 +390,19 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
.collect(Collectors.toList());
List<String> names = new ArrayList<>(params.keySet());
- // create list object over all inputs
- ListObject list = new ListObject(data, names);
+ ListObject list = null;
+ if (DMLScript.LINEAGE) {
+ CPOperand[] listOperands = names.stream().map(n
-> ec.containsVariable(params.get(n))
+ ? new CPOperand(n,
ec.getVariable(params.get(n)))
+ :
getStringLiteral(n)).toArray(CPOperand[]::new);
+ LineageItem[] liList =
LineageItemUtils.getLineage(ec, listOperands);
+ // create list object over all inputs w/ the
corresponding lineage items
+ list = new ListObject(data, names,
Arrays.asList(liList));
+ }
+ else
+ // create list object over all inputs
+ list = new ListObject(data, names);
+
list.deriveAndSetStatusFromData();
ec.setVariable(output.getName(), list);
@@ -479,6 +491,14 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
return Pair.of(output.getName(),
new LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, target, meta, spec)));
}
+ else if (opcode.equalsIgnoreCase("nvlist")) {
+ List<String> names = new ArrayList<>(params.keySet());
+ CPOperand[] listOperands = names.stream().map(n ->
ec.containsVariable(params.get(n))
+ ? new CPOperand(n,
ec.getVariable(params.get(n)))
+ :
getStringLiteral(n)).toArray(CPOperand[]::new);
+ return Pair.of(output.getName(),
+ new LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, listOperands)));
+ }
else {
// NOTE: for now, we cannot have a generic fall through
path, because the
// data and value types of parmeters are not compiled
into the instruction
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
index 7fcef80..41fec41 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
@@ -95,11 +95,15 @@ public class ScalarBuiltinNaryCPInstruction extends
BuiltinNaryCPInstruction imp
}
else if( "list".equals(getOpcode()) ) {
//obtain all input data objects, incl handling of
literals
- List<Data> data = (inputs== null) ? new ArrayList<>() :
+ List<Data> data = (inputs == null) ? new ArrayList<>() :
Arrays.stream(inputs).map(in ->
ec.getVariable(in)).collect(Collectors.toList());
+ List<LineageItem> li = null;
+ if (DMLScript.LINEAGE)
+ li = (inputs == null) ? new ArrayList<>() :
+ Arrays.stream(inputs).map(in ->
ec.getLineage().get(in)).collect(Collectors.toList());
//create list object over all inputs
- ListObject list = new ListObject(data);
+ ListObject list = new ListObject(data, null, li);
list.deriveAndSetStatusFromData();
ec.setVariable(output.getName(), list);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index 86bf0c5..af7b6e0 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -53,6 +53,10 @@ public class Lineage {
}
public void trace(Instruction inst, ExecutionContext ec) {
+ if (inst.getOpcode().equalsIgnoreCase("toString"))
+ //Silently skip toString. TODO: trace toString
+ return;
+
if (_activeDedupBlock == null ||
!_activeDedupBlock.isAllPathsTaken() ||
!LineageCacheConfig.ReuseCacheType.isNone())
_map.trace(inst, ec);
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index b366edb..8450fb5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -234,6 +234,9 @@ public class LineageCache
Data boundValue = null;
//convert to matrix object
if (e.isMatrixValue()) {
+ MatrixBlock mb = e.getMBValue();
+ if (mb == null && e.getCacheStatus() ==
LineageCacheStatus.NOTCACHED)
+ return false; //the executing
thread removed this entry from cache
MetaDataFormat md = new MetaDataFormat(
e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
boundValue = new
MatrixObject(ValueType.FP64, boundVarName, md);
@@ -242,6 +245,8 @@ public class LineageCache
}
else {
boundValue = e.getSOValue();
+ if (boundValue == null &&
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
+ return false; //the executing
thread removed this entry from cache
}
funcOutputs.put(boundVarName, boundValue);
@@ -310,6 +315,11 @@ public class LineageCache
Data outValue = null;
//convert to matrix object
if (e.isMatrixValue()) {
+ MatrixBlock mb = e.getMBValue();
+ if (mb == null && e.getCacheStatus() ==
LineageCacheStatus.NOTCACHED)
+ //the executing thread removed this
entry from cache
+ return new
FederatedResponse(FederatedResponse.ResponseType.ERROR);
+
MetaDataFormat md = new MetaDataFormat(
e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
outValue = new MatrixObject(ValueType.FP64,
outName, md);
@@ -318,6 +328,9 @@ public class LineageCache
}
else {
outValue = e.getSOValue();
+ if (outValue == null && e.getCacheStatus() ==
LineageCacheStatus.NOTCACHED)
+ //the executing thread removed this
entry from cache
+ return new
FederatedResponse(FederatedResponse.ResponseType.ERROR);
}
udfOutputs.put(outName, outValue);
savedComputeTime = e._computeTime;
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index b7b66a7..e2027f4 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -195,9 +195,9 @@ public class LineageCacheConfig
}
public static boolean isReusable (Instruction inst, ExecutionContext
ec) {
- boolean insttype = inst instanceof ComputationCPInstruction
+ boolean insttype = (inst instanceof ComputationCPInstruction
|| inst instanceof ComputationFEDInstruction
- || inst instanceof GPUInstruction
+ || inst instanceof GPUInstruction)
&& !(inst instanceof ListIndexingCPInstruction);
boolean rightop = (ArrayUtils.contains(REUSE_OPCODES,
inst.getOpcode())
|| (inst.getOpcode().equals("append") &&
isVectorAppend(inst, ec))
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
index e9d9368..f3c1c0a 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
@@ -167,7 +167,8 @@ public class LineageMap {
break;
}
case Write: {
- processWriteLI(vcp_inst.getInput1(),
vcp_inst.getInput2(), ec);
+ if (!vcp_inst.getInput1().isLiteral())
+
processWriteLI(vcp_inst.getInput1(), vcp_inst.getInput2(), ec);
break;
}
case MoveVariable: {
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
index 24cdc6a..46be9c1 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
@@ -26,7 +26,6 @@ import java.util.List;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
@@ -101,7 +100,7 @@ public class LineageReuseAlg extends LineageBase {
public void testLineageTrace(String testname, ReuseCacheType reuseType)
{
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- ExecMode platformOld = setExecMode(ExecType.CP);
+ ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
try {
LOG.debug("------------ BEGIN " + testname +
"------------");
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
similarity index 62%
copy from
src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
copy to
src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
index 24cdc6a..cb3b434 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
@@ -24,22 +24,22 @@ import java.util.HashMap;
import java.util.List;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
import org.junit.Test;
-public class LineageReuseAlg extends LineageBase {
+public class LineageReuseEvalTest extends LineageBase {
protected static final String TEST_DIR = "functions/lineage/";
- protected static final String TEST_NAME = "LineageReuseAlg";
- protected static final int TEST_VARIANTS = 6;
- protected String TEST_CLASS_DIR = TEST_DIR +
LineageReuseAlg.class.getSimpleName() + "/";
+ protected static final String TEST_NAME = "LineageReuseEval";
+ protected static final int TEST_VARIANTS = 2;
+ protected String TEST_CLASS_DIR = TEST_DIR +
LineageReuseEvalTest.class.getSimpleName() + "/";
@Override
public void setUp() {
@@ -48,67 +48,25 @@ public class LineageReuseAlg extends LineageBase {
addTestConfiguration(TEST_NAME+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
}
+ //FIXME: These tests fail/get stuck in a deadlock if MultiLevel/Hybrid
is used.
+ //This problem is not yet reproducible locally.
@Test
- public void testStepLMHybrid() {
- testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testGridSearchLMHybrid() {
- testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testMultiLogRegHybrid() {
- testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testPCAHybrid() {
- testLineageTrace(TEST_NAME+"4", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testGridSearchL2svmHybrid() {
- testLineageTrace(TEST_NAME+"5", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testPCA_LM_pipeline() {
- testLineageTrace(TEST_NAME+"6", ReuseCacheType.REUSE_HYBRID);
- }
-
- @Test
- public void testStepLMFull() {
+ public void testGridsearchLM() {
testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_FULL);
}
-
- @Test
- public void testGridSearchLMFull() {
- testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_FULL);
- }
@Test
- public void testMultiLogRegFull() {
- testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_FULL);
- }
-
- @Test
- public void testPCAFull() {
- testLineageTrace(TEST_NAME+"4", ReuseCacheType.REUSE_FULL);
+ public void testGridSearchMLR() {
+ testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_FULL);
+ //FIXME: 2x slower with reuse. Heavy hitter function is
lineageitem equals.
+ //This problem only exists with parfor.
}
-
+
public void testLineageTrace(String testname, ReuseCacheType reuseType)
{
- boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
- boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- ExecMode platformOld = setExecMode(ExecType.CP);
+ ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
try {
LOG.debug("------------ BEGIN " + testname +
"------------");
-
- OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
-
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
@@ -121,6 +79,8 @@ public class LineageReuseAlg extends LineageBase {
Lineage.resetInternalState();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> X_orig =
readDMLMatrixFromOutputDir("X");
+ //long numlmDS =
Statistics.getCPHeavyHitterCount("m_lmDS");
+ long numMM = Statistics.getCPHeavyHitterCount("ba+*");
// With lineage-based reuse enabled
proArgs.clear();
@@ -131,19 +91,25 @@ public class LineageReuseAlg extends LineageBase {
proArgs.add(output("X"));
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
- Lineage.setLinReuseFull();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> X_reused =
readDMLMatrixFromOutputDir("X");
+ //long numlmDS_reuse =
Statistics.getCPHeavyHitterCount("m_lmDS");
+ long numMM_reuse =
Statistics.getCPHeavyHitterCount("ba+*");
Lineage.setLinReuseNone();
TestUtils.compareMatrices(X_orig, X_reused, 1e-6,
"Origin", "Reused");
+
+ if (testname.equalsIgnoreCase("LineageReuseEval1")) {
//gridSearchLM
+ //lmDS call should be reused for all the 7
values of tolerance
+ //Assert.assertTrue("Violated lmDS reuse count:
7 * "+numlmDS_reuse+" == "+numlmDS,
+ // 7*numlmDS_reuse == numlmDS);
+ Assert.assertTrue("Violated ba+* reuse count:
"+numMM_reuse+" < "+numMM, numMM_reuse < numMM);
+ }
}
finally {
- OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
old_simplification;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES =
old_sum_product;
rtplatform = platformOld;
Recompiler.reinitRecompiler();
}
}
-}
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml
b/src/test/scripts/functions/builtin/GridSearchLM2.dml
index 278d94c..c172794 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM2.dml
@@ -36,7 +36,7 @@ Xtest = X[(N+1):nrow(X),];
ytest = y[(N+1):nrow(X),];
params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+paramRanges = list(seq(0,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
numB=ncol(X)+1, params=params, paramValues=paramRanges);
B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml
b/src/test/scripts/functions/lineage/LineageReuseEval1.dml
similarity index 79%
copy from src/test/scripts/functions/builtin/GridSearchLM2.dml
copy to src/test/scripts/functions/lineage/LineageReuseEval1.dml
index 278d94c..00a1b39 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseEval1.dml
@@ -19,15 +19,15 @@
#
#-------------------------------------------------------------
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
return (Matrix[Double] loss)
{
yhat = lmPredict(X=X, B=B, ytest=y)
loss = as.matrix(sum((y - yhat)^2));
}
-X = read($1);
-y = read($2);
+X = rand(rows=300, cols=20, sparsity=1.0, seed=1);
+y = rand(rows=300, cols=1, sparsity=1.0, seed=1);
N = 200;
Xtrain = X[1:N,];
@@ -35,14 +35,11 @@ ytrain = y[1:N,];
Xtest = X[(N+1):nrow(X),];
ytest = y[(N+1):nrow(X),];
-params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+params = list("icpt","reg", "tol"); #numValues = 3, 5, 7
+paramRanges = list(seq(0,2), 10^seq(0,-4), 10^seq(-6,-12)); #3*5*7 = 105
[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
- numB=ncol(X)+1, params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
-
+ numB=ncol(X)+1, params=params, paramValues=paramRanges, verbose=FALSE);
l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
-write(R, $3)
+write(l1, $1, format="text");
+
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml
b/src/test/scripts/functions/lineage/LineageReuseEval2.dml
similarity index 55%
copy from src/test/scripts/functions/builtin/GridSearchLM2.dml
copy to src/test/scripts/functions/lineage/LineageReuseEval2.dml
index 278d94c..2263d77 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseEval2.dml
@@ -19,15 +19,14 @@
#
#-------------------------------------------------------------
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
- return (Matrix[Double] loss)
-{
- yhat = lmPredict(X=X, B=B, ytest=y)
- loss = as.matrix(sum((y - yhat)^2));
+accuracy = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
return (Matrix[Double] err) {
+ [M,yhat,acc] = multiLogRegPredict(X=X, B=B, Y=y, verbose=FALSE);
+ err = as.matrix(1-(acc/100));
}
-X = read($1);
-y = read($2);
+X = rand(rows=300, cols=20, sparsity=1.0, seed=1);
+y = rand(rows=300, cols=1, min=1, max=3, sparsity=1.0, seed=1);
+y = floor(y);
N = 200;
Xtrain = X[1:N,];
@@ -35,14 +34,13 @@ ytrain = y[1:N,];
Xtest = X[(N+1):nrow(X),];
ytest = y[(N+1):nrow(X),];
-params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
- numB=ncol(X)+1, params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+params = list("icpt", "reg", "maxii");
+paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
+trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100,
maxii=-1, verbose=FALSE);
+[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg",
predict="accuracy", numB=ncol(X)+1,
+ params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=FALSE);
-l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
+[M,yhat,acc] = multiLogRegPredict(X=Xtest, B=B1, Y=ytest, verbose=FALSE);
+
+write(yhat, $1, format="text");
-write(R, $3)