This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 275d442  [SYSTEMDS-3010] Add lineage support for eval and list operations
275d442 is described below

commit 275d4423a69252d3af9d568e62450262acf9e2ce
Author: arnabp <[email protected]>
AuthorDate: Mon Jun 7 14:17:51 2021 +0200

    [SYSTEMDS-3010] Add lineage support for eval and list operations
    
    This patch adds lineage support for eval and list operations.
    Every list object now materializes the lineage traces corresponding
    to its data items, and list operations maintain these lineage items.
    The eval instruction gathers the lineage items from the argument
    list and passes them to the called function.
    In addition, this patch adds tests that apply lineage-based full reuse
    to the gridSearch builtin for LM and MultiLogReg.
---
 .../sysds/runtime/instructions/Instruction.java    |  4 +-
 .../instructions/cp/EvalNaryCPInstruction.java     | 15 +++-
 .../instructions/cp/FunctionCallCPInstruction.java | 18 ++++-
 .../instructions/cp/ListIndexingCPInstruction.java | 13 +++-
 .../sysds/runtime/instructions/cp/ListObject.java  | 19 +++++
 .../cp/ParameterizedBuiltinCPInstruction.java      | 24 +++++-
 .../cp/ScalarBuiltinNaryCPInstruction.java         |  8 +-
 .../org/apache/sysds/runtime/lineage/Lineage.java  |  4 +
 .../apache/sysds/runtime/lineage/LineageCache.java | 13 ++++
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  4 +-
 .../apache/sysds/runtime/lineage/LineageMap.java   |  3 +-
 .../test/functions/lineage/LineageReuseAlg.java    |  3 +-
 ...eageReuseAlg.java => LineageReuseEvalTest.java} | 88 +++++++---------------
 .../scripts/functions/builtin/GridSearchLM2.dml    |  2 +-
 .../LineageReuseEval1.dml}                         | 19 ++---
 .../LineageReuseEval2.dml}                         | 30 ++++----
 16 files changed, 156 insertions(+), 111 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index 7be54e0..e0fbeca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -229,11 +229,11 @@ public abstract class Instruction
         * @param ec execution context
         * @return instruction
         */
-       public Instruction preprocessInstruction(ExecutionContext ec){
+       public Instruction preprocessInstruction(ExecutionContext ec) {
                // Lineage tracing
                if (DMLScript.LINEAGE)
                        ec.traceLineage(this);
-               //return instruction ifself
+               //return the instruction itself
                return this;
        }
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 0970a37..4548b7b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -27,6 +27,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.stream.Collectors;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
@@ -46,6 +47,7 @@ import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.util.DataConverter;
@@ -120,6 +122,7 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                
                //4. expand list arguments if needed
                CPOperand[] boundInputs2 = null;
+               LineageItem[] lineageInputs = null;
                if( boundInputs.length == 1 && 
boundInputs[0].getDataType().isList()
                        && !(fpb.getInputParams().size() == 1 && 
fpb.getInputParams().get(0).getDataType().isList()))
                {
@@ -135,11 +138,13 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                                boundInputs2[i] = new CPOperand(varName, in);
                        }
                        boundInputs = boundInputs2;
+                       lineageInputs = DMLScript.LINEAGE 
+                                       ? lo.getLineageItems().toArray(new 
LineageItem[lo.getLength()]) : null;
                }
                
                //5. call the function (to unoptimized function)
                FunctionCallCPInstruction fcpi = new 
FunctionCallCPInstruction(nsName, funcName,
-                       false, boundInputs, fpb.getInputParamNames(), 
boundOutputNames, "eval func");
+                       false, boundInputs, lineageInputs, 
fpb.getInputParamNames(), boundOutputNames, "eval func");
                fcpi.processInstruction(ec);
                
                //6. convert the result to matrix
@@ -251,8 +256,12 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
        
        private static ListObject reorderNamedListForFunctionCall(ListObject 
in, List<String> fArgNames) {
                List<Data> sortedData = new ArrayList<>();
-               for( String name : fArgNames )
+               List<LineageItem> sortedLI = DMLScript.LINEAGE ? new 
ArrayList<>() : null;
+               for( String name : fArgNames ) {
                        sortedData.add(in.getData(name));
-               return new ListObject(sortedData, new ArrayList<>(fArgNames));
+                       if (DMLScript.LINEAGE)
+                               sortedLI.add(in.getLineageItem(name));
+               }
+               return new ListObject(sortedData, new ArrayList<>(fArgNames), 
sortedLI);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index f8e0e6a..df46b35 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -56,23 +56,31 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
        private final String _namespace;
        private final boolean _opt;
        private final CPOperand[] _boundInputs;
+       private final LineageItem[] _lineageInputs;
        private final List<String> _boundInputNames;
        private final List<String> _funArgNames;
        private final List<String> _boundOutputNames;
 
        public FunctionCallCPInstruction(String namespace, String functName, 
boolean opt,
-               CPOperand[] boundInputs, List<String> funArgNames, List<String> 
boundOutputNames, String istr) {
+               CPOperand[] boundInputs, LineageItem[] lineageInputs, 
List<String> funArgNames, 
+               List<String> boundOutputNames, String istr) {
                super(CPType.FCall, null, functName, istr);
                _functionName = functName;
                _namespace = namespace;
                _opt = opt;
                _boundInputs = boundInputs;
+               _lineageInputs = lineageInputs;
                _boundInputNames = Arrays.stream(boundInputs).map(i -> 
i.getName())
                        .collect(Collectors.toCollection(ArrayList::new));
                _funArgNames = funArgNames;
                _boundOutputNames = boundOutputNames;
        }
 
+       public FunctionCallCPInstruction(String namespace, String functName, 
boolean opt,
+               CPOperand[] boundInputs, List<String> funArgNames, List<String> 
boundOutputNames, String istr) {
+               this(namespace, functName, opt, boundInputs, null, funArgNames, 
boundOutputNames, istr);
+       }
+
        public String getFunctionName() {
                return _functionName;
        }
@@ -125,8 +133,10 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                }
                
                // check if function outputs can be reused from cache
-               LineageItem[] liInputs = LineageCacheConfig.isMultiLevelReuse() 
|| DMLScript.LINEAGE_ESTIMATE ?
-                       LineageItemUtils.getLineage(ec, _boundInputs) : null;
+               LineageItem[] liInputs = _lineageInputs;
+               if (_lineageInputs == null)
+                       liInputs = (LineageCacheConfig.isMultiLevelReuse() || 
DMLScript.LINEAGE_ESTIMATE) 
+                               ? LineageItemUtils.getLineage(ec, _boundInputs) 
: null;
                if (!fpb.isNondeterministic() && reuseFunctionOutputs(liInputs, 
fpb, ec))
                        return; //only if all the outputs are found in cache
                
@@ -164,7 +174,7 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                        
                        //map lineage to function arguments
                        if( lineage != null ) {
-                               LineageItem litem = ec.getLineageItem(input);
+                               LineageItem litem = _lineageInputs == null ? 
ec.getLineageItem(input) : _lineageInputs[i];
                                lineage.set(currFormalParam.getName(), 
(litem!=null) ? 
                                        litem : 
ec.getLineage().getOrCreate(input));
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
index bf45f8c..86b5654 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListIndexingCPInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp;
 import org.apache.sysds.lops.LeftIndex;
 import org.apache.sysds.lops.RightIndex;
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -67,6 +68,7 @@ public final class ListIndexingCPInstruction extends 
IndexingCPInstruction {
                        
                        //execute right indexing operation and set output
                        if( input2.getDataType().isList() ) { //LIST <- LIST
+                               //TODO: copy the lineage trace of input2 list 
to input1 list
                                ListObject rin = (ListObject) 
ec.getVariable(input2.getName());
                                if( rl.getValueType()==ValueType.STRING || 
ru.getValueType()==ValueType.STRING  )
                                        ec.setVariable(output.getName(), 
lin.copy().set(rl.getStringValue(), ru.getStringValue(), rin));
@@ -75,18 +77,21 @@ public final class ListIndexingCPInstruction extends 
IndexingCPInstruction {
                        }
                        else if( input2.getDataType().isScalar() ) { //LIST <- 
SCALAR
                                ScalarObject scalar = ec.getScalarInput(input2);
+                               //LineageItem li = DMLScript.LINEAGE ? 
LineageItemUtils.getLineage(ec, input2)[0] : null; 
+                               LineageItem li = DMLScript.LINEAGE ? 
ec.getLineage().getOrCreate(input2) : null; 
                                if( rl.getValueType()==ValueType.STRING )
-                                       ec.setVariable(output.getName(), 
lin.copy().set(rl.getStringValue(), scalar));
+                                       ec.setVariable(output.getName(), 
lin.copy().set(rl.getStringValue(), scalar, li));
                                else
-                                       ec.setVariable(output.getName(), 
lin.copy().set((int)rl.getLongValue()-1, scalar));
+                                       ec.setVariable(output.getName(), 
lin.copy().set((int)rl.getLongValue()-1, scalar, li));
                        }
                        else if( input2.getDataType().isMatrix() ) { //LIST <- 
MATRIX/FRAME
                                CacheableData<?> dat = 
ec.getCacheableData(input2);
                                dat.enableCleanup(false);
+                               LineageItem li = DMLScript.LINEAGE ? 
ec.getLineage().get(input2) : null;
                                if( rl.getValueType()==ValueType.STRING )
-                                       ec.setVariable(output.getName(), 
lin.copy().set(rl.getStringValue(), dat));
+                                       ec.setVariable(output.getName(), 
lin.copy().set(rl.getStringValue(), dat, li));
                                else
-                                       ec.setVariable(output.getName(), 
lin.copy().set((int)rl.getLongValue()-1, dat));
+                                       ec.setVariable(output.getName(), 
lin.copy().set((int)rl.getLongValue()-1, dat, li));
                        }
                        else {
                                throw new DMLRuntimeException("Unsupported list 
"
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index edfd6cc..f6ac15a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -140,6 +140,12 @@ public class ListObject extends Data implements 
Externalizable {
                return slice(name);
        }
        
+       public LineageItem getLineageItem(String name) {
+               //lookup position by name, incl error handling
+               int pos = getPosForName(name);
+               return getLineageItem(pos);
+       }
+       
        public List<LineageItem> getLineageItems() {
                return _lineage;
        }
@@ -206,6 +212,12 @@ public class ListObject extends Data implements 
Externalizable {
                _data.set(ix, data);
                return this;
        }
+
+       public ListObject set(int ix, Data data, LineageItem li) {
+               _data.set(ix, data);
+               if (li != null) _lineage.set(ix, li);
+               return this;
+       }
        
        public ListObject set(int ix1, int ix2, ListObject data) {
                int range = ix2 - ix1 + 1;
@@ -242,6 +254,13 @@ public class ListObject extends Data implements 
Externalizable {
                return set(pos, data);
        }
        
+       public Data set(String name, Data data, LineageItem li) {
+               //lookup position by name, incl error handling
+               int pos = getPosForName(name);
+               //set entry into position
+               return set(pos, data, li);
+       }
+       
        public ListObject set(String name1, String name2, ListObject data) {
                //lookup positions by name, incl error handling
                int pos1 = getPosForName(name1);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index c8a1166..54a5339 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -30,6 +30,7 @@ import java.util.stream.IntStream;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.lops.Lop;
@@ -389,8 +390,19 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                                .collect(Collectors.toList());
                        List<String> names = new ArrayList<>(params.keySet());
 
-                       // create list object over all inputs
-                       ListObject list = new ListObject(data, names);
+                       ListObject list = null;
+                       if (DMLScript.LINEAGE) {
+                               CPOperand[] listOperands = names.stream().map(n 
-> ec.containsVariable(params.get(n)) 
+                                               ? new CPOperand(n, 
ec.getVariable(params.get(n))) 
+                                               : 
getStringLiteral(n)).toArray(CPOperand[]::new);
+                               LineageItem[] liList = 
LineageItemUtils.getLineage(ec, listOperands);
+                               // create list object over all inputs w/ the 
corresponding lineage items
+                               list = new ListObject(data, names, 
Arrays.asList(liList));
+                       }
+                       else
+                               // create list object over all inputs
+                               list = new ListObject(data, names);
+
                        list.deriveAndSetStatusFromData();
 
                        ec.setVariable(output.getName(), list);
@@ -479,6 +491,14 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        return Pair.of(output.getName(),
                                new LineageItem(getOpcode(), 
LineageItemUtils.getLineage(ec, target, meta, spec)));
                }
+               else if (opcode.equalsIgnoreCase("nvlist")) {
+                       List<String> names = new ArrayList<>(params.keySet());
+                       CPOperand[] listOperands = names.stream().map(n -> 
ec.containsVariable(params.get(n)) 
+                                       ? new CPOperand(n, 
ec.getVariable(params.get(n))) 
+                                       : 
getStringLiteral(n)).toArray(CPOperand[]::new);
+                       return Pair.of(output.getName(), 
+                               new LineageItem(getOpcode(), 
LineageItemUtils.getLineage(ec, listOperands)));
+               }
                else {
                        // NOTE: for now, we cannot have a generic fall through 
path, because the
                        // data and value types of parmeters are not compiled 
into the instruction
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
index 7fcef80..41fec41 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ScalarBuiltinNaryCPInstruction.java
@@ -95,11 +95,15 @@ public class ScalarBuiltinNaryCPInstruction extends 
BuiltinNaryCPInstruction imp
                }
                else if( "list".equals(getOpcode()) ) {
                        //obtain all input data objects, incl handling of 
literals
-                       List<Data> data = (inputs== null) ? new ArrayList<>() :
+                       List<Data> data = (inputs == null) ? new ArrayList<>() :
                                Arrays.stream(inputs).map(in -> 
ec.getVariable(in)).collect(Collectors.toList());
+                       List<LineageItem> li = null;
+                       if (DMLScript.LINEAGE)
+                               li = (inputs == null) ? new ArrayList<>() :
+                                       Arrays.stream(inputs).map(in -> 
ec.getLineage().get(in)).collect(Collectors.toList());
                        
                        //create list object over all inputs
-                       ListObject list = new ListObject(data);
+                       ListObject list = new ListObject(data, null, li);
                        list.deriveAndSetStatusFromData();
                        
                        ec.setVariable(output.getName(), list);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java 
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index 86bf0c5..af7b6e0 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -53,6 +53,10 @@ public class Lineage {
        }
        
        public void trace(Instruction inst, ExecutionContext ec) {
+               if (inst.getOpcode().equalsIgnoreCase("toString"))
+                       //Silently skip toString. TODO: trace toString
+                       return;
+
                if (_activeDedupBlock == null || 
!_activeDedupBlock.isAllPathsTaken() || 
!LineageCacheConfig.ReuseCacheType.isNone())
                        _map.trace(inst, ec);
        }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index b366edb..8450fb5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -234,6 +234,9 @@ public class LineageCache
                                Data boundValue = null;
                                //convert to matrix object
                                if (e.isMatrixValue()) {
+                                       MatrixBlock mb = e.getMBValue();
+                                       if (mb == null && e.getCacheStatus() == 
LineageCacheStatus.NOTCACHED)
+                                               return false;  //the executing 
thread removed this entry from cache
                                        MetaDataFormat md = new MetaDataFormat(
                                                
e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
                                        boundValue = new 
MatrixObject(ValueType.FP64, boundVarName, md);
@@ -242,6 +245,8 @@ public class LineageCache
                                }
                                else {
                                        boundValue = e.getSOValue();
+                                       if (boundValue == null && 
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
+                                               return false;  //the executing 
thread removed this entry from cache
                                }
 
                                funcOutputs.put(boundVarName, boundValue);
@@ -310,6 +315,11 @@ public class LineageCache
                        Data outValue = null;
                        //convert to matrix object
                        if (e.isMatrixValue()) {
+                               MatrixBlock mb = e.getMBValue();
+                               if (mb == null && e.getCacheStatus() == 
LineageCacheStatus.NOTCACHED)
+                                       //the executing thread removed this 
entry from cache
+                                       return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR);
+
                                MetaDataFormat md = new MetaDataFormat(
                                        
e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
                                outValue = new MatrixObject(ValueType.FP64, 
outName, md);
@@ -318,6 +328,9 @@ public class LineageCache
                        }
                        else {
                                outValue = e.getSOValue();
+                               if (outValue == null && e.getCacheStatus() == 
LineageCacheStatus.NOTCACHED)
+                                       //the executing thread removed this 
entry from cache
+                                       return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR);
                        }
                        udfOutputs.put(outName, outValue);
                        savedComputeTime = e._computeTime;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index b7b66a7..e2027f4 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -195,9 +195,9 @@ public class LineageCacheConfig
        }
 
        public static boolean isReusable (Instruction inst, ExecutionContext 
ec) {
-               boolean insttype = inst instanceof ComputationCPInstruction 
+               boolean insttype = (inst instanceof ComputationCPInstruction 
                        || inst instanceof ComputationFEDInstruction
-                       || inst instanceof GPUInstruction
+                       || inst instanceof GPUInstruction)
                        && !(inst instanceof ListIndexingCPInstruction);
                boolean rightop = (ArrayUtils.contains(REUSE_OPCODES, 
inst.getOpcode())
                        || (inst.getOpcode().equals("append") && 
isVectorAppend(inst, ec))
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
index e9d9368..f3c1c0a 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
@@ -167,7 +167,8 @@ public class LineageMap {
                                        break;
                                }
                                case Write: {
-                                       processWriteLI(vcp_inst.getInput1(), 
vcp_inst.getInput2(), ec);
+                                       if (!vcp_inst.getInput1().isLiteral())
+                                               
processWriteLI(vcp_inst.getInput1(), vcp_inst.getInput2(), ec);
                                        break;
                                }
                                case MoveVariable: {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
index 24cdc6a..46be9c1 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
@@ -26,7 +26,6 @@ import java.util.List;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
@@ -101,7 +100,7 @@ public class LineageReuseAlg extends LineageBase {
        public void testLineageTrace(String testname, ReuseCacheType reuseType) 
{
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
-               ExecMode platformOld = setExecMode(ExecType.CP);
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
                
                try {
                        LOG.debug("------------ BEGIN " + testname + 
"------------");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
similarity index 62%
copy from 
src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
copy to 
src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
index 24cdc6a..cb3b434 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseEvalTest.java
@@ -24,22 +24,22 @@ import java.util.HashMap;
 import java.util.List;
 
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Test;
 
-public class LineageReuseAlg extends LineageBase {
+public class LineageReuseEvalTest extends LineageBase {
        
        protected static final String TEST_DIR = "functions/lineage/";
-       protected static final String TEST_NAME = "LineageReuseAlg";
-       protected static final int TEST_VARIANTS = 6;
-       protected String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseAlg.class.getSimpleName() + "/";
+       protected static final String TEST_NAME = "LineageReuseEval";
+       protected static final int TEST_VARIANTS = 2;
+       protected String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseEvalTest.class.getSimpleName() + "/";
        
        @Override
        public void setUp() {
@@ -48,67 +48,25 @@ public class LineageReuseAlg extends LineageBase {
                        addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
        }
 
+       //FIXME: These tests fail/get stuck in a deadlock if MultiLevel/Hybrid 
is used.
+       //This problem is not yet reproducible locally. 
        @Test
-       public void testStepLMHybrid() {
-               testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_HYBRID);
-       }
-       
-       @Test
-       public void testGridSearchLMHybrid() {
-               testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_HYBRID);
-       }
-
-       @Test
-       public void testMultiLogRegHybrid() {
-               testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_HYBRID);
-       }
-
-       @Test
-       public void testPCAHybrid() {
-               testLineageTrace(TEST_NAME+"4", ReuseCacheType.REUSE_HYBRID);
-       }
-
-       @Test
-       public void testGridSearchL2svmHybrid() {
-               testLineageTrace(TEST_NAME+"5", ReuseCacheType.REUSE_HYBRID);
-       }
-
-       @Test
-       public void testPCA_LM_pipeline() {
-               testLineageTrace(TEST_NAME+"6", ReuseCacheType.REUSE_HYBRID);
-       }
-       
-       @Test
-       public void testStepLMFull() {
+       public void testGridsearchLM() {
                testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_FULL);
        }
-       
-       @Test
-       public void testGridSearchLMFull() {
-               testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_FULL);
-       }
 
        @Test
-       public void testMultiLogRegFull() {
-               testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_FULL);
-       }
-
-       @Test
-       public void testPCAFull() {
-               testLineageTrace(TEST_NAME+"4", ReuseCacheType.REUSE_FULL);
+       public void testGridSearchMLR() {
+               testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_FULL);
+               //FIXME: 2x slower with reuse. Heavy hitter function is 
lineageitem equals.
+               //This problem only exists with parfor.
        }
-
+       
        public void testLineageTrace(String testname, ReuseCacheType reuseType) 
{
-               boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
-               boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
-               ExecMode platformOld = setExecMode(ExecType.CP);
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
                
                try {
                        LOG.debug("------------ BEGIN " + testname + 
"------------");
-                       
-                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
-                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
-                       
                        getAndLoadTestConfiguration(testname);
                        fullDMLScriptName = getScript();
                        
@@ -121,6 +79,8 @@ public class LineageReuseAlg extends LineageBase {
                        Lineage.resetInternalState();
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> X_orig = 
readDMLMatrixFromOutputDir("X");
+                       //long numlmDS = 
Statistics.getCPHeavyHitterCount("m_lmDS");
+                       long numMM = Statistics.getCPHeavyHitterCount("ba+*");
                        
                        // With lineage-based reuse enabled
                        proArgs.clear();
@@ -131,19 +91,25 @@ public class LineageReuseAlg extends LineageBase {
                        proArgs.add(output("X"));
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
                        Lineage.resetInternalState();
-                       Lineage.setLinReuseFull();
                        
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> X_reused = 
readDMLMatrixFromOutputDir("X");
+                       //long numlmDS_reuse = 
Statistics.getCPHeavyHitterCount("m_lmDS");
+                       long numMM_reuse = 
Statistics.getCPHeavyHitterCount("ba+*");
                        
                        Lineage.setLinReuseNone();
                        TestUtils.compareMatrices(X_orig, X_reused, 1e-6, 
"Origin", "Reused");
+
+                       if (testname.equalsIgnoreCase("LineageReuseEval1")) {  
//gridSearchLM
+                               //lmDS call should be reused for all the 7 
values of tolerance 
+                               //Assert.assertTrue("Violated lmDS reuse count: 
7 * "+numlmDS_reuse+" == "+numlmDS, 
+                               //              7*numlmDS_reuse == numlmDS);
+                               Assert.assertTrue("Violated ba+* reuse count: 
"+numMM_reuse+" < "+numMM, numMM_reuse < numMM);
+                       }
                }
                finally {
-                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
-                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
                        rtplatform = platformOld;
                        Recompiler.reinitRecompiler();
                }
        }
-}
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml 
b/src/test/scripts/functions/builtin/GridSearchLM2.dml
index 278d94c..c172794 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM2.dml
@@ -36,7 +36,7 @@ Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
 params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+paramRanges = list(seq(0,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
 [B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
   numB=ncol(X)+1, params=params, paramValues=paramRanges);
 B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml 
b/src/test/scripts/functions/lineage/LineageReuseEval1.dml
similarity index 79%
copy from src/test/scripts/functions/builtin/GridSearchLM2.dml
copy to src/test/scripts/functions/lineage/LineageReuseEval1.dml
index 278d94c..00a1b39 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseEval1.dml
@@ -19,15 +19,15 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) 
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
   return (Matrix[Double] loss)
 {
   yhat = lmPredict(X=X, B=B, ytest=y)
   loss = as.matrix(sum((y - yhat)^2));
 }
 
-X = read($1);
-y = read($2);
+X = rand(rows=300, cols=20, sparsity=1.0, seed=1);
+y = rand(rows=300, cols=1, sparsity=1.0, seed=1);
 
 N = 200;
 Xtrain = X[1:N,];
@@ -35,14 +35,11 @@ ytrain = y[1:N,];
 Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
-params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+params = list("icpt","reg", "tol"); #numValues = 3, 5, 7
+paramRanges = list(seq(0,2), 10^seq(0,-4), 10^seq(-6,-12)); #3*5*7 = 105
 [B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
-  numB=ncol(X)+1, params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
-
+  numB=ncol(X)+1, params=params, paramValues=paramRanges, verbose=FALSE);
 l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
 
-write(R, $3)
+write(l1, $1, format="text");
+
diff --git a/src/test/scripts/functions/builtin/GridSearchLM2.dml 
b/src/test/scripts/functions/lineage/LineageReuseEval2.dml
similarity index 55%
copy from src/test/scripts/functions/builtin/GridSearchLM2.dml
copy to src/test/scripts/functions/lineage/LineageReuseEval2.dml
index 278d94c..2263d77 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseEval2.dml
@@ -19,15 +19,14 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) 
-  return (Matrix[Double] loss)
-{
-  yhat = lmPredict(X=X, B=B, ytest=y)
-  loss = as.matrix(sum((y - yhat)^2));
+accuracy = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) 
return (Matrix[Double] err) {
+  [M,yhat,acc] = multiLogRegPredict(X=X, B=B, Y=y, verbose=FALSE);
+  err = as.matrix(1-(acc/100));
 }
 
-X = read($1);
-y = read($2);
+X = rand(rows=300, cols=20, sparsity=1.0, seed=1);
+y = rand(rows=300, cols=1, min=1, max=3, sparsity=1.0, seed=1);
+y = floor(y);
 
 N = 200;
 Xtrain = X[1:N,];
@@ -35,14 +34,13 @@ ytrain = y[1:N,];
 Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
-params = list("icpt","reg", "tol", "maxi");
-paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
-  numB=ncol(X)+1, params=params, paramValues=paramRanges);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+params = list("icpt", "reg", "maxii");
+paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
+trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, 
maxii=-1, verbose=FALSE);
+[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", 
predict="accuracy", numB=ncol(X)+1,
+  params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=FALSE);
 
-l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
-R = as.scalar(l1 < l2);
+[M,yhat,acc] = multiLogRegPredict(X=Xtest, B=B1, Y=ytest, verbose=FALSE);
+
+write(yhat, $1, format="text");
 
-write(R, $3)

Reply via email to