This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new f22e999 [SYSTEMDS-331,332] Fix robustness lineage cache (deadlocks, correctness) f22e999 is described below commit f22e9991e2370dc30a1fed01c3142c27071da42c Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Apr 10 16:28:39 2020 +0200 [SYSTEMDS-331,332] Fix robustness lineage cache (deadlocks, correctness) This patch fixes the robustness of lineage-based caching, especially in multi-threaded parfor programs. This includes: 1) Deadlock prevention: With multi-level caching, the placeholders that prevent concurrent computation of redundant intermediates led to deadlocks because subsequent threads blocked inside the critical region, and thus any caching by the thread that was producing the intermediate (via a complex DAG of operations) was blocked. 2) Deadlocks on wrong data types: With the introduction of scalar caching, each thread had to decide to pull either a scalar or a matrix on the placeholders. Since this decision was made based on the data item (which might not be available yet in parfor), threads were blocking on the wrong type and thus again producing deadlocks. 3) Correctness: The loop iteration variable of parfor was not yet integrated with lineage tracing, leading to incorrect reuse for different parfor iterations that depended on the iteration variable. Furthermore, this patch also cleans up an unnecessarily wide public API of the lineage cache in order to facilitate a correct internal implementation. However, there are still a number of remaining issues, e.g., with the computation of compensation plans and probing logic. 
--- docs/Tasks.txt | 2 +- .../org/apache/sysds/parser/StatementBlock.java | 43 ++- .../runtime/controlprogram/BasicProgramBlock.java | 16 +- .../runtime/controlprogram/parfor/ParWorker.java | 42 ++- .../instructions/cp/FunctionCallCPInstruction.java | 8 +- .../apache/sysds/runtime/lineage/LineageCache.java | 388 +++++++++++---------- .../sysds/runtime/lineage/LineageCacheConfig.java | 30 ++ .../sysds/runtime/lineage/LineageRewriteReuse.java | 54 +-- .../functions/lineage/FunctionFullReuseTest.java | 42 ++- .../functions/lineage/FunctionFullReuse6.dml | 37 ++ .../functions/lineage/FunctionFullReuse7.dml | 37 ++ 11 files changed, 412 insertions(+), 287 deletions(-) diff --git a/docs/Tasks.txt b/docs/Tasks.txt index 42741da..d19672f 100644 --- a/docs/Tasks.txt +++ b/docs/Tasks.txt @@ -239,7 +239,7 @@ SYSTEMDS-320 Merge SystemDS into Apache SystemML OK SYSTEMDS-330 Lineage Tracing, Reuse and Integration * 331 Cache and reuse scalar outputs (instruction and multi-level) OK - * 332 Parfor integration with multi-level reuse + * 332 Parfor integration with multi-level reuse OK * 333 Use exact execution time for cost based eviction SYSTEMDS-340 Compiler Assisted Lineage Caching and Reuse diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java index 2e87909..5991315 100644 --- a/src/main/java/org/apache/sysds/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java @@ -43,12 +43,12 @@ import org.apache.sysds.utils.MLContextProxy; public class StatementBlock extends LiveVariableAnalysis implements ParseInfo { - protected static final Log LOG = LogFactory.getLog(StatementBlock.class.getName()); protected static IDSequence _seq = new IDSequence(); private static IDSequence _seqSBID = new IDSequence(); protected final long _ID; - + protected final String _name; + protected DMLProgram _dmlProg; protected ArrayList<Statement> _statements; ArrayList<Hop> _hops = null; @@ 
-62,6 +62,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo public StatementBlock() { _ID = getNextSBID(); + _name = "SB"+_ID; _dmlProg = null; _statements = new ArrayList<>(); _read = new VariableSet(); @@ -96,6 +97,10 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo public long getSBID() { return _ID; } + + public String getName() { + return _name; + } public void addStatement(Statement s) { _statements.add(s); @@ -399,8 +404,9 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo return inputs; } - public ArrayList<String> getOutputsofSB() { - ArrayList<String> outputs = _liveOut != null && _updated != null ? new ArrayList<>() : null; + public ArrayList<String> getOutputNamesofSB() { + ArrayList<String> outputs = _liveOut != null + && _updated != null ? new ArrayList<>() : null; if (_liveOut != null && _updated != null) { for (String varName : _updated.getVariables().keySet()) { if (_liveOut.containsVariable(varName)) @@ -409,6 +415,18 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo } return outputs; } + + public ArrayList<DataIdentifier> getOutputsofSB() { + ArrayList<DataIdentifier> outputs = _liveOut != null + && _updated != null ? 
new ArrayList<>() : null; + if (_liveOut != null && _updated != null) { + for (String varName : _updated.getVariables().keySet()) { + if (_liveOut.containsVariable(varName)) + outputs.add(_liveOut.getVariable(varName)); + } + } + return outputs; + } public static ArrayList<StatementBlock> mergeStatementBlocks(ArrayList<StatementBlock> sb){ if (sb == null || sb.isEmpty()) @@ -683,29 +701,20 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo AssignmentStatement as = (AssignmentStatement) current; if ((as.getTargetList().size() == 1) && (as.getTargetList().get(0) != null)) { raiseValidateError("Function '" + fcall.getName() - + "' does not return a value but is assigned to " + as.getTargetList().get(0), - true); + + "' does not return a value but is assigned to " + as.getTargetList().get(0), true); } } - } else if (current instanceof MultiAssignmentStatement) { + } + else if (current instanceof MultiAssignmentStatement) { if (fstmt.getOutputParams().size() == 0) { MultiAssignmentStatement mas = (MultiAssignmentStatement) current; raiseValidateError("Function '" + fcall.getName() - + "' does not return a value but is assigned to " + mas.getTargetList(), true); + + "' does not return a value but is assigned to " + mas.getTargetList(), true); } } // handle returns by appending name mappings, but with special handling of // statements that contain function calls or multi-return builtin expressions (but disabled) -// Statement lastAdd = newStatements.get(newStatements.size()-1); -// if( isOutputBindingViaFunctionCall(lastAdd, prefix, fstmt) && lastAdd instanceof AssignmentStatement ) -// ((AssignmentStatement)lastAdd).setTarget(((AssignmentStatement)current).getTarget()); -// else if ( isOutputBindingViaFunctionCall(lastAdd, prefix, fstmt) && lastAdd instanceof MultiAssignmentStatement ) -// if( current instanceof MultiAssignmentStatement ) -// 
((MultiAssignmentStatement)lastAdd).setTargetList(((MultiAssignmentStatement)current).getTargetList()); -// else //correct for multi-assignment to assignment transform -// newStatements.set(newStatements.size()-1, createNewPartialMultiAssignment(lastAdd, current, prefix, fstmt)); -// else appendOutputAssignments(current, prefix, fstmt, newStatements); } return newStatements; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java index 1f52a75..5f44ac3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.controlprogram; import java.util.ArrayList; +import java.util.List; import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; @@ -29,7 +30,6 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.lineage.LineageCache; import org.apache.sysds.runtime.lineage.LineageCacheConfig; -import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; import org.apache.sysds.runtime.lineage.LineageCacheStatistics; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageItemUtils; @@ -108,12 +108,10 @@ public class BasicProgramBlock extends ProgramBlock //statement-block-level, lineage-based reuse LineageItem[] liInputs = null; - if (_sb != null - && !ReuseCacheType.isNone() - && LineageCacheConfig.getCacheType().isMultilevelReuse()) { - String name = "SB" + _sb.getSBID(); + if (_sb != null && LineageCacheConfig.isMultiLevelReuse()) { liInputs = LineageItemUtils.getLineageItemInputstoSB(_sb.getInputstoSB(), ec); - if( LineageCache.reuse(_sb.getOutputsofSB(), _sb.getOutputsofSB().size(), liInputs, name, 
ec) ) { + List<String> outNames = _sb.getOutputNamesofSB(); + if( LineageCache.reuse(outNames, _sb.getOutputsofSB(), outNames.size(), liInputs, _sb.getName(), ec) ) { if( DMLScript.STATISTICS ) LineageCacheStatistics.incrementSBHits(); return; @@ -124,9 +122,7 @@ public class BasicProgramBlock extends ProgramBlock executeInstructions(tmp, ec); //statement-block-level, lineage-based caching - if (_sb != null && liInputs != null) { - String name = "SB" + _sb.getSBID(); - LineageCache.putValue(_sb.getOutputsofSB(), _sb.getOutputsofSB().size(), liInputs, name, ec); - } + if (_sb != null && liInputs != null) + LineageCache.putValue(_sb.getOutputsofSB(), liInputs, _sb.getName(), ec); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java index 9f8fbb4..7b74ace 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java @@ -24,6 +24,7 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.parser.ParForStatementBlock.ResultVar; import org.apache.sysds.runtime.controlprogram.LocalVariableMap; import org.apache.sysds.runtime.controlprogram.ProgramBlock; @@ -32,8 +33,10 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.stat.Stat; import org.apache.sysds.runtime.controlprogram.parfor.stat.StatisticMonitor; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.IntObject; +import org.apache.sysds.runtime.lineage.Lineage; /** * Super class for master/worker pattern implementations. 
Central place to @@ -113,22 +116,20 @@ public abstract class ParWorker protected void executeTask( Task task ) { LOG.trace("EXECUTE PARFOR_WORKER ID="+_workerID+" for task "+task.toCompactString()); - switch( task.getType() ) - { + switch( task.getType() ) { case SET: executeSetTask( task ); break; case RANGE: executeRangeTask( task ); - break; + break; } } private void executeSetTask( Task task ) { //monitoring start Timing time1 = null, time2 = null; - if( _monitor ) - { + if( _monitor ) { time1 = new Timing(true); time2 = new Timing(true); } @@ -143,6 +144,10 @@ public abstract class ParWorker //set index values _ec.setVariable(lVarName, indexVal); + if (DMLScript.LINEAGE) { + Lineage li = _ec.getLineage(); + li.set(lVarName, li.getOrCreate(new CPOperand(indexVal))); + } // for each program block for (ProgramBlock pb : _childBlocks) @@ -157,8 +162,7 @@ public abstract class ParWorker _numTasks++; //monitoring end - if( _monitor ) - { + if( _monitor ) { StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_TASKSIZE, task.size()); StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_TASK_T, time2.stop()); } @@ -167,10 +171,9 @@ public abstract class ParWorker private void executeRangeTask( Task task ) { //monitoring start Timing time1 = null, time2 = null; - if( _monitor ) - { - time1 = new Timing(true); - time2 = new Timing(true); + if( _monitor ) { + time1 = new Timing(true); + time2 = new Timing(true); } //core execution @@ -183,28 +186,29 @@ public abstract class ParWorker for( long i=lFrom; i<=lTo; i+=lIncr ) { //set index values - _ec.setVariable(lVarName, new IntObject(i)); + IntObject indexVal = new IntObject(i); + _ec.setVariable(lVarName, indexVal); + if (DMLScript.LINEAGE) { + Lineage li = _ec.getLineage(); + li.set(lVarName, li.getOrCreate(new CPOperand(indexVal))); + } // for each program block for (ProgramBlock pb : _childBlocks) pb.execute(_ec); - + _numIters++; if( _monitor ) - StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_ITER_T, time1.stop()); + 
StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_ITER_T, time1.stop()); } _numTasks++; //monitoring end - if( _monitor ) - { + if( _monitor ) { StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_TASKSIZE, task.size()); StatisticMonitor.putPWStat(_workerID, Stat.PARWRK_TASK_T, time2.stop()); } } - } - - diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java index 5a24ad8..e605a55 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.Lineage; import org.apache.sysds.runtime.lineage.LineageCache; +import org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageCacheStatistics; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageItemUtils; @@ -115,7 +116,7 @@ public class FunctionCallCPInstruction extends CPInstruction { } // check if function outputs can be reused from cache - LineageItem[] liInputs = DMLScript.LINEAGE ? + LineageItem[] liInputs = DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ? 
LineageItemUtils.getLineage(ec, _boundInputs) : null; if( reuseFunctionOutputs(liInputs, fpb, ec) ) return; //only if all the outputs are found in cache @@ -224,7 +225,8 @@ public class FunctionCallCPInstruction extends CPInstruction { } //update lineage cache with the functions outputs - LineageCache.putValue(_boundOutputNames, numOutputs, liInputs, _functionName, ec); + if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ) + LineageCache.putValue(fpb.getOutputParams(), liInputs, _functionName, ec); } @Override @@ -261,7 +263,7 @@ public class FunctionCallCPInstruction extends CPInstruction { private boolean reuseFunctionOutputs(LineageItem[] liInputs, FunctionProgramBlock fpb, ExecutionContext ec) { int numOutputs = Math.min(_boundOutputNames.size(), fpb.getOutputParams().size()); - boolean reuse = LineageCache.reuse(_boundOutputNames, numOutputs, liInputs, _functionName, ec); + boolean reuse = LineageCache.reuse(_boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, _functionName, ec); if (reuse && DMLScript.STATISTICS) { //decrement the call count for this function diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java index 789b9f7..2741b70 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java @@ -20,10 +20,12 @@ package org.apache.sysds.runtime.lineage; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.cost.CostEstimatorStaticRuntime; import org.apache.sysds.lops.MMTSJ.MMTSJType; +import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.parser.Statement; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; @@ -51,7 
+53,8 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -public class LineageCache { +public class LineageCache +{ private static final Map<LineageItem, Entry> _cache = new HashMap<>(); private static final Map<LineageItem, SpilledItem> _spillList = new HashMap<>(); private static final HashSet<LineageItem> _removelist = new HashSet<>(); @@ -67,7 +70,18 @@ public class LineageCache { CACHE_LIMIT = (long)(CACHE_FRAC * maxMem); } - //--------------------- CACHE LOGIC METHODS ---------------------- + // Cache Synchronization Approach: + // The central static cache is only synchronized in a fine-grained manner + // for short get, put, or remove calls or during eviction. All blocking of + // threads for computing the values of placeholders is done on the individual + // entry objects which reduces contention and prevents deadlocks in case of + // function/statement block placeholders which computation itself might be + // a complex workflow of operations that accesses the cache as well. 
+ + + /////////////////////////////////////// + // Public Cache API (keep it narrow) // + /////////////////////////////////////// public static boolean reuse(Instruction inst, ExecutionContext ec) { if (ReuseCacheType.isNone()) @@ -76,77 +90,85 @@ public class LineageCache { boolean reuse = false; //NOTE: the check for computation CP instructions ensures that the output // will always fit in memory and hence can be pinned unconditionally - if (inst instanceof ComputationCPInstruction && LineageCache.isReusable(inst, ec)) { - LineageItem item = ((ComputationCPInstruction) inst).getLineageItems(ec)[0]; + if (LineageCacheConfig.isReusable(inst, ec)) { + ComputationCPInstruction cinst = (ComputationCPInstruction) inst; + LineageItem item = cinst.getLineageItems(ec)[0]; + //atomic try reuse full/partial and set placeholder, without + //obtaining value to avoid blocking in critical section + Entry e = null; synchronized( _cache ) { //try to reuse full or partial intermediates if (LineageCacheConfig.getCacheType().isFullReuse()) - reuse = fullReuse(item, (ComputationCPInstruction)inst, ec); - if (LineageCacheConfig.getCacheType().isPartialReuse()) - reuse |= LineageRewriteReuse.executeRewrites(inst, ec); - - if (reuse && DMLScript.STATISTICS) - LineageCacheStatistics.incrementInstHits(); + e = LineageCache.probe(item) ? 
getIntern(item) : null; + //TODO need to also move execution of compensation plan out of here + //(create lazily evaluated entry) + if (e == null && LineageCacheConfig.getCacheType().isPartialReuse()) + if( LineageRewriteReuse.executeRewrites(inst, ec) ) + e = getIntern(item); + reuse = (e != null); //create a placeholder if no reuse to avoid redundancy //(e.g., concurrent threads that try to start the computation) - if(!reuse && isMarkedForCaching(inst, ec)) - putIntern(item, null, null, 0); + if(!reuse && isMarkedForCaching(inst, ec)) { + putIntern(item, cinst.output.getDataType(), null, null, 0); + } + } + + if( reuse ) { //reuse + //put reuse value into symbol table (w/ blocking on placeholders) + if (e.isMatrixValue()) + ec.setMatrixOutput(cinst.output.getName(), e.getMBValue()); + else + ec.setScalarOutput(cinst.output.getName(), e.getSOValue()); + if (DMLScript.STATISTICS) + LineageCacheStatistics.incrementInstHits(); + reuse = true; } } return reuse; } - public static Entry reuse(LineageItem item) { - if (ReuseCacheType.isNone()) - return null; - - Entry e = null; - synchronized( _cache ) { - if (LineageCache.probe(item)) - e = LineageCache.get(item); - else - //create a placeholder if no reuse to avoid redundancy - //(e.g., concurrent threads that try to start the computation) - putIntern(item, null, null, 0); - //FIXME: parfor - every thread gets different function names - } - return e; - } - - public static boolean reuse(List<String> outputs, int numOutputs, LineageItem[] liInputs, String name, ExecutionContext ec) + public static boolean reuse(List<String> outNames, List<DataIdentifier> outParams, int numOutputs, LineageItem[] liInputs, String name, ExecutionContext ec) { - if( ReuseCacheType.isNone() || !LineageCacheConfig.getCacheType().isMultilevelReuse()) + if( !LineageCacheConfig.isMultiLevelReuse()) return false; - - boolean reuse = (numOutputs != 0); + + boolean reuse = (outParams.size() != 0); HashMap<String, Data> funcOutputs = new HashMap<>(); 
HashMap<String, LineageItem> funcLIs = new HashMap<>(); for (int i=0; i<numOutputs; i++) { String opcode = name + String.valueOf(i+1); - LineageItem li = new LineageItem(outputs.get(i), opcode, liInputs); - Entry cachedValue = LineageCache.reuse(li); + LineageItem li = new LineageItem(outNames.get(i), opcode, liInputs); + Entry e = null; + synchronized( _cache ) { + if (LineageCache.probe(li)) + e = LineageCache.getIntern(li); + else + //create a placeholder if no reuse to avoid redundancy + //(e.g., concurrent threads that try to start the computation) + putIntern(li, outParams.get(i).getDataType(), null, null, 0); + //FIXME: parfor - every thread gets different function names + } //TODO: handling of recursive calls - if (cachedValue != null && !cachedValue.isNullVal()) { - String boundVarName = outputs.get(i); + if (e != null && !e.isNullVal()) { + String boundVarName = outNames.get(i); Data boundValue = null; //convert to matrix object - if (cachedValue.isMatrixValue()) { - MetaDataFormat md = new MetaDataFormat(cachedValue.getMBValue().getDataCharacteristics(), - OutputInfo.BinaryCellOutputInfo, InputInfo.BinaryCellInputInfo); + if (e.isMatrixValue()) { + MetaDataFormat md = new MetaDataFormat(e.getMBValue().getDataCharacteristics(), + OutputInfo.BinaryCellOutputInfo, InputInfo.BinaryCellInputInfo); boundValue = new MatrixObject(ValueType.FP64, boundVarName, md); - ((MatrixObject)boundValue).acquireModify(cachedValue.getMBValue()); + ((MatrixObject)boundValue).acquireModify(e.getMBValue()); ((MatrixObject)boundValue).release(); } else - boundValue = cachedValue.getSOValue(); + boundValue = e.getSOValue(); funcOutputs.put(boundVarName, boundValue); - - LineageItem orig = _cache.get(li)._origItem; //FIXME: synchronize + LineageItem orig = e._origItem; funcLIs.put(boundVarName, orig); } else { @@ -169,17 +191,37 @@ public class LineageCache { //map original lineage items return to the calling site funcLIs.forEach((var, li) -> ec.getLineage().set(var, li)); } + 
return reuse; } + public static boolean probe(LineageItem key) { + //TODO problematic as after probe the matrix might be kicked out of cache + boolean p = (_cache.containsKey(key) || _spillList.containsKey(key)); + if (!p && DMLScript.STATISTICS && _removelist.contains(key)) + // The sought entry was in cache but removed later + LineageCacheStatistics.incrementDelHits(); + return p; + } + + public static MatrixBlock getMatrix(LineageItem key) { + Entry e = null; + synchronized( _cache ) { + e = getIntern(key); + } + return e.getMBValue(); + } + //NOTE: safe to pin the object in memory as coming from CPInstruction - public static void put(Instruction inst, ExecutionContext ec) { - if (inst instanceof ComputationCPInstruction && isReusable(inst, ec) ) { + //TODO why do we need both of these public put methods + public static void putMatrix(Instruction inst, ExecutionContext ec) { + if (LineageCacheConfig.isReusable(inst, ec) ) { LineageItem item = ((LineageTraceable) inst).getLineageItems(ec)[0]; //This method is called only to put matrix value MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction) inst).output); synchronized( _cache ) { - putIntern(item, mo.acquireReadAndRelease(), null, getRecomputeEstimate(inst, ec)); + putIntern(item, DataType.MATRIX, mo.acquireReadAndRelease(), + null, getRecomputeEstimate(inst, ec)); } } } @@ -187,18 +229,18 @@ public class LineageCache { public static void putValue(Instruction inst, ExecutionContext ec) { if (ReuseCacheType.isNone()) return; - if (inst instanceof ComputationCPInstruction && isReusable(inst, ec) ) { - if (!isMarkedForCaching(inst, ec)) return; + if (LineageCacheConfig.isReusable(inst, ec) ) { + //if (!isMarkedForCaching(inst, ec)) return; LineageItem item = ((LineageTraceable) inst).getLineageItems(ec)[0]; - //MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction) inst).output); Data data = ec.getVariable(((ComputationCPInstruction) inst).output); - MatrixObject mo = data instanceof 
MatrixObject ? (MatrixObject)data : null; - ScalarObject so = data instanceof ScalarObject ? (ScalarObject)data : null; - MatrixBlock Mval = mo != null ? mo.acquireReadAndRelease() : null; - _cache.get(item).setValue(Mval, so, getRecomputeEstimate(inst, ec)); //outside sync to prevent deadlocks - long size = _cache.get(item).getSize(); - + double cest = getRecomputeEstimate(inst, ec); synchronized( _cache ) { + if( data instanceof MatrixObject ) + _cache.get(item).setValue(((MatrixObject)data).acquireReadAndRelease(), cest); + else + _cache.get(item).setValue((ScalarObject)data, cest); + long size = _cache.get(item).getSize(); + if( !isBelowThreshold(size) ) makeSpace(size); updateSize(size, true); @@ -206,42 +248,17 @@ public class LineageCache { } } - public static void putValue(LineageItem item, LineageItem probeItem) { - if (ReuseCacheType.isNone()) - return; - if (LineageCache.probe(probeItem)) { - Entry oe = LineageCache.get(probeItem); - Entry e = _cache.get(item); - //TODO: compute estimate for function - if (oe.isMatrixValue()) - e.setValue(oe.getMBValue(), null, 0); - else - e.setValue(null, oe.getSOValue(), 0); - e._origItem = probeItem; - - long size = oe.getSize(); - synchronized( _cache ) { - if(!isBelowThreshold(size)) - makeSpace(size); - updateSize(size, true); - } - } - else - removeEntry(item); //remove the placeholder - - } - - public static void putValue(List<String> outputs, int numOutputs, LineageItem[] liInputs, String name, ExecutionContext ec) + public static void putValue(List<DataIdentifier> outputs, LineageItem[] liInputs, String name, ExecutionContext ec) { - if( ReuseCacheType.isNone() || !LineageCacheConfig.getCacheType().isMultilevelReuse()) + if( !LineageCacheConfig.isMultiLevelReuse()) return; HashMap<LineageItem, LineageItem> FuncLIMap = new HashMap<>(); boolean AllOutputsCacheable = true; - for (int i=0; i<numOutputs; i++) { + for (int i=0; i<outputs.size(); i++) { String opcode = name + String.valueOf(i+1); - LineageItem li = 
new LineageItem(outputs.get(i), opcode, liInputs); - String boundVarName = outputs.get(i); + LineageItem li = new LineageItem(outputs.get(i).getName(), opcode, liInputs); + String boundVarName = outputs.get(i).getName(); LineageItem boundLI = ec.getLineage().get(boundVarName); if (boundLI != null) boundLI.resetVisitStatus(); @@ -254,23 +271,42 @@ public class LineageCache { } //cache either all the outputs, or none. - if(AllOutputsCacheable) - FuncLIMap.forEach((Li, boundLI) -> LineageCache.putValue(Li, boundLI)); - else - //remove all the placeholders - FuncLIMap.forEach((Li, boundLI) -> LineageCache.removeEntry(Li)); + synchronized( _cache ) { + //move or remove placeholders + if(AllOutputsCacheable) + FuncLIMap.forEach((Li, boundLI) -> mvIntern(Li, boundLI)); + else + FuncLIMap.forEach((Li, boundLI) -> removeEntry(Li)); + } return; } - private static void putIntern(LineageItem key, MatrixBlock Mval, ScalarObject Sval, double compcost) { + public static void resetCache() { + synchronized( _cache ) { + _cache.clear(); + _spillList.clear(); + _head = null; + _end = null; + // reset cache size, otherwise the cache clear leads to unusable + // space which means evictions could run into endless loops + _cachesize = 0; + if (DMLScript.STATISTICS) + _removelist.clear(); + } + } + + ///////////////////////////////////////// + // Internal Cache Logic Implementation // + ///////////////////////////////////////// + + private static void putIntern(LineageItem key, DataType dt, MatrixBlock Mval, ScalarObject Sval, double compcost) { if (_cache.containsKey(key)) //can come here if reuse_partial option is enabled - return; - //throw new DMLRuntimeException("Redundant lineage caching detected: "+inst); + return; // Create a new entry. - Entry newItem = new Entry(key, Mval, Sval, compcost); + Entry newItem = new Entry(key, dt, Mval, Sval, compcost); // Make space by removing or spilling LRU entries. 
if( Mval != null || Sval != null ) { @@ -290,40 +326,7 @@ public class LineageCache { LineageCacheStatistics.incrementMemWrites(); } - protected static boolean probe(LineageItem key) { - boolean p = (_cache.containsKey(key) || _spillList.containsKey(key)); - if (!p && DMLScript.STATISTICS && _removelist.contains(key)) - // The sought entry was in cache but removed later - LineageCacheStatistics.incrementDelHits(); - return p; - } - - public static void resetCache() { - _cache.clear(); - _spillList.clear(); - _head = null; - _end = null; - // reset cache size, otherwise the cache clear leads to unusable - // space which means evictions could run into endless loops - _cachesize = 0; - if (DMLScript.STATISTICS) - _removelist.clear(); - } - - - private static boolean fullReuse (LineageItem item, ComputationCPInstruction inst, ExecutionContext ec) { - if (LineageCache.probe(item)) { - Entry e = LineageCache.get(item); - if (e.isMatrixValue()) - ec.setMatrixOutput(inst.output.getName(), e.getMBValue()); - else - ec.setScalarOutput(inst.output.getName(), e.getSOValue()); - return true; - } - return false; - } - - protected static Entry get(LineageItem key) { + private static Entry getIntern(LineageItem key) { // This method is called only when entry is present either in cache or in local FS. if (_cache.containsKey(key)) { // Read and put the entry at head. 
@@ -337,44 +340,39 @@ public class LineageCache { else return readFromLocalFS(key); } + - public static boolean isReusable (Instruction inst, ExecutionContext ec) { - // TODO: Move this to the new class LineageCacheConfig and extend - return inst.getOpcode().equalsIgnoreCase("tsmm") - || inst.getOpcode().equalsIgnoreCase("ba+*") - || inst.getOpcode().equalsIgnoreCase("*") - || inst.getOpcode().equalsIgnoreCase("/") - || inst.getOpcode().equalsIgnoreCase("+") - || inst.getOpcode().equalsIgnoreCase("nrow") - || inst.getOpcode().equalsIgnoreCase("ncol") - || inst.getOpcode().equalsIgnoreCase("rightIndex") - || inst.getOpcode().equalsIgnoreCase("leftIndex") - || inst.getOpcode().equalsIgnoreCase("groupedagg") - || inst.getOpcode().equalsIgnoreCase("r'") - || (inst.getOpcode().equalsIgnoreCase("append") && isVectorAppend(inst, ec)) - || inst.getOpcode().equalsIgnoreCase("solve") - || inst.getOpcode().contains("spoof"); - } - - private static boolean isVectorAppend(Instruction inst, ExecutionContext ec) { - ComputationCPInstruction cpinst = (ComputationCPInstruction) inst; - if( !cpinst.input1.isMatrix() || !cpinst.input2.isMatrix() ) - return false; - long c1 = ec.getMatrixObject(cpinst.input1).getNumColumns(); - long c2 = ec.getMatrixObject(cpinst.input2).getNumColumns(); - return(c1 == 1 || c2 == 1); + private static void mvIntern(LineageItem item, LineageItem probeItem) { + if (ReuseCacheType.isNone()) + return; + if (LineageCache.probe(probeItem)) { + Entry oe = getIntern(probeItem); + Entry e = _cache.get(item); + //TODO: compute estimate for function + if (oe.isMatrixValue()) + e.setValue(oe.getMBValue(), 0); + else + e.setValue(oe.getSOValue(), 0); + e._origItem = probeItem; + + long size = oe.getSize(); + if(!isBelowThreshold(size)) + makeSpace(size); + updateSize(size, true); + } + else + removeEntry(item); //remove the placeholder } - public static boolean isMarkedForCaching (Instruction inst, ExecutionContext ec) { + private static boolean isMarkedForCaching 
(Instruction inst, ExecutionContext ec) { if (!LineageCacheConfig.getCompAssRW()) return true; if (((ComputationCPInstruction)inst).output.isMatrix()) { MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction)inst).output); //limit this to full reuse as partial reuse is applicable even for loop dependent operation - boolean marked = (LineageCacheConfig.getCacheType() == ReuseCacheType.REUSE_FULL - && !mo.isMarked()) ? false : true; - return marked; + return !(LineageCacheConfig.getCacheType() == ReuseCacheType.REUSE_FULL + && !mo.isMarked()); } else return true; @@ -397,7 +395,6 @@ public class LineageCache { continue; } - double reduction = _cache.get(_end._key).getSize(); if (_cache.get(_end._key).isMatrixValue()) { //spill matrix blocks only if (_cache.get(_end._key)._compEst > getDiskSpillEstimate() && LineageCacheConfig.isSetSpill()) @@ -410,8 +407,8 @@ public class LineageCache { setEnd2Head(_end); continue; } - removeEntry(reduction); - } + removeLastEntry(); + } } private static void updateSize(long space, boolean addspace) { @@ -617,7 +614,7 @@ public class LineageCache { } // Restore to cache LocalFileUtils.deleteFileIfExists(_spillList.get(key)._outfile, true); - putIntern(key, mb, null, _spillList.get(key)._compEst); + putIntern(key, DataType.MATRIX, mb, null, _spillList.get(key)._compEst); _spillList.remove(key); if (DMLScript.STATISTICS) { long t1 = System.nanoTime(); @@ -627,7 +624,30 @@ public class LineageCache { return _cache.get(key); } - //------------------ LINKEDLIST MAINTENANCE METHODS ------------------- + //////////////////////////////////////////// + // Cache Maintenance and Lookup Functions // + //////////////////////////////////////////// + + private static void removeLastEntry() { + if (DMLScript.STATISTICS) + _removelist.add(_end._key); + Entry e = _cache.remove(_end._key); + _cachesize -= e.getSize(); + delete(_end); + } + + private static void removeEntry(LineageItem key) { + // Remove the entry for key + if 
(!_cache.containsKey(key)) + return; + delete(_cache.get(key)); + _cache.remove(key); + } + + private static void setEnd2Head(Entry entry) { + delete(entry); + setHead(entry); + } private static void delete(Entry entry) { if (entry._prev != null) @@ -650,29 +670,13 @@ public class LineageCache { _end = _head; } - private static void setEnd2Head(Entry entry) { - delete(entry); - setHead(entry); - } - - private static void removeEntry(double space) { - if (DMLScript.STATISTICS) - _removelist.add(_end._key); - _cache.remove(_end._key); - _cachesize -= space; - delete(_end); - } - - public static void removeEntry(LineageItem key) { - // Remove the entry for key - if (!_cache.containsKey(key)) - return; - delete(_cache.get(key)); - _cache.remove(key); - } + //////////////////////////////////// + // Internal Cache Data Structures // + //////////////////////////////////// - static class Entry { + private static class Entry { private final LineageItem _key; + private final DataType _dt; private MatrixBlock _MBval; private ScalarObject _SOval; double _compEst; @@ -680,8 +684,9 @@ public class LineageCache { private Entry _next; private LineageItem _origItem; - public Entry(LineageItem key, MatrixBlock Mval, ScalarObject Sval, double computecost) { + public Entry(LineageItem key, DataType dt, MatrixBlock Mval, ScalarObject Sval, double computecost) { _key = key; + _dt = dt; _MBval = Mval; _SOval = Sval; _compEst = computecost; @@ -725,19 +730,20 @@ public class LineageCache { } public boolean isMatrixValue() { - return(_MBval != null); + return _dt.isMatrix(); } public synchronized void setValue(MatrixBlock val, double compEst) { _MBval = val; _compEst = compEst; + //resume all threads waiting for val notifyAll(); } - public synchronized void setValue(MatrixBlock mval, ScalarObject so, double compEst) { - _MBval = mval; - _SOval = so; + public synchronized void setValue(ScalarObject val, double compEst) { + _SOval = val; _compEst = compEst; + //resume all threads waiting for 
val notifyAll(); } } @@ -747,8 +753,8 @@ public class LineageCache { double _compEst; public SpilledItem(String outfile, double computecost) { - this._outfile = outfile; - this._compEst = computecost; + _outfile = outfile; + _compEst = computecost; } } } diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java index e4ce09b..75a305a 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java @@ -19,11 +19,21 @@ package org.apache.sysds.runtime.lineage; +import org.apache.commons.lang3.ArrayUtils; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; + import java.util.ArrayList; public class LineageCacheConfig { + private static final String[] REUSE_OPCODES = new String[] { + "tsmm", "ba+*", "*", "/", "+", "nrow", "ncol", + "rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof" + }; + public enum ReuseCacheType { REUSE_FULL, REUSE_PARTIAL, @@ -69,6 +79,21 @@ public class LineageCacheConfig { setSpill(false); //disable spilling of cache entries to disk } + public static boolean isReusable (Instruction inst, ExecutionContext ec) { + return inst instanceof ComputationCPInstruction + && (ArrayUtils.contains(REUSE_OPCODES, inst.getOpcode()) + || (inst.getOpcode().equals("append") && isVectorAppend(inst, ec))); + } + + private static boolean isVectorAppend(Instruction inst, ExecutionContext ec) { + ComputationCPInstruction cpinst = (ComputationCPInstruction) inst; + if( !cpinst.input1.isMatrix() || !cpinst.input2.isMatrix() ) + return false; + long c1 = ec.getMatrixObject(cpinst.input1).getNumColumns(); + long c2 = ec.getMatrixObject(cpinst.input2).getNumColumns(); + return(c1
== 1 || c2 == 1); + } + public static void setConfigTsmmCbind(ReuseCacheType ct) { _cacheType = ct; _itemH = CachedItemHead.TSMM; @@ -110,6 +135,11 @@ public class LineageCacheConfig { public static ReuseCacheType getCacheType() { return _cacheType; } + + public static boolean isMultiLevelReuse() { + return !ReuseCacheType.isNone() + && _cacheType.isMultilevelReuse(); + } public static CachedItemHead getCachedItemHead() { return _itemH; diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java index fb5a21f..f1b8c58 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java @@ -99,7 +99,7 @@ public class LineageRewriteReuse ec.setVariable(((ComputationCPInstruction)curr).output.getName(), lrwec.getVariable(LR_VAR)); //put the result into the cache - LineageCache.put(curr, ec); + LineageCache.putMatrix(curr, ec); DMLScript.EXPLAIN = et; //TODO can't change this here //cleanup execution context @@ -529,7 +529,7 @@ public class LineageRewriteReuse private static boolean isTsmmCbind(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) { + if (!LineageCacheConfig.isReusable(curr, ec)) { return false; } @@ -543,10 +543,10 @@ public class LineageRewriteReuse LineageItem input1 = source.getInputs()[0]; LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {input1}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended column in cache if (LineageCache.probe(source.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(source.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(source.getInputs()[1])); } // return true only if the last 
tsmm is found return inCache.containsKey("lastMatrix") ? true : false; @@ -554,7 +554,7 @@ public class LineageRewriteReuse private static boolean isTsmmRbind(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; // If the input to tsmm came from rbind, look for both the inputs in cache. @@ -566,10 +566,10 @@ public class LineageRewriteReuse LineageItem input1 = source.getInputs()[0]; LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {input1}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended column in cache if (LineageCache.probe(source.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(source.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(source.getInputs()[1])); } // return true only if the last tsmm is found return inCache.containsKey("lastMatrix") ? 
true : false; @@ -577,7 +577,7 @@ public class LineageRewriteReuse private static boolean isTsmm2Cbind (Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; //TODO: support nary cbind @@ -593,10 +593,10 @@ public class LineageRewriteReuse LineageItem tmp = new LineageItem("comb", "cbind", new LineageItem[] {L2appin1, source.getInputs()[1]}); LineageItem toProbe = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {tmp}); if (LineageCache.probe(toProbe)) - inCache.put("lastMatrix", LineageCache.get(toProbe).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(toProbe)); // look for the appended column in cache if (LineageCache.probe(input.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(input.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(input.getInputs()[1])); } } // return true only if the last tsmm is found @@ -605,7 +605,7 @@ public class LineageRewriteReuse private static boolean isMatMulRbindLeft(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; // If the left input to ba+* came from rbind, look for both the inputs in cache. 
@@ -618,10 +618,10 @@ public class LineageRewriteReuse // create ba+* lineage on top of the input of last append LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {leftSource, right}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended column in cache if (LineageCache.probe(left.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(left.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(left.getInputs()[1])); } } // return true only if the last tsmm is found @@ -630,7 +630,7 @@ public class LineageRewriteReuse private static boolean isMatMulCbindRight(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; // If the right input to ba+* came from cbind, look for both the inputs in cache. @@ -643,10 +643,10 @@ public class LineageRewriteReuse // create ba+* lineage on top of the input of last append LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {left, rightSource}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended column in cache if (LineageCache.probe(right.getInputs()[1])) - inCache.put("deltaY", LineageCache.get(right.getInputs()[1]).getMBValue()); + inCache.put("deltaY", LineageCache.getMatrix(right.getInputs()[1])); } } return inCache.containsKey("lastMatrix") ? 
true : false; @@ -654,7 +654,7 @@ public class LineageRewriteReuse private static boolean isElementMulRbind(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; // If the inputs to * came from rbind, look for both the inputs in cache. @@ -668,12 +668,12 @@ public class LineageRewriteReuse // create * lineage on top of the input of last append LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {leftSource, rightSource}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended rows in cache if (LineageCache.probe(left.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(left.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(left.getInputs()[1])); if (LineageCache.probe(right.getInputs()[1])) - inCache.put("deltaY", LineageCache.get(right.getInputs()[1]).getMBValue()); + inCache.put("deltaY", LineageCache.getMatrix(right.getInputs()[1])); } } return inCache.containsKey("lastMatrix") ? true : false; @@ -681,7 +681,7 @@ public class LineageRewriteReuse private static boolean isElementMulCbind(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) + if (!LineageCacheConfig.isReusable(curr, ec)) return false; // If the inputs to * came from cbind, look for both the inputs in cache. 
@@ -695,12 +695,12 @@ public class LineageRewriteReuse // create * lineage on top of the input of last append LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {leftSource, rightSource}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended columns in cache if (LineageCache.probe(left.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(left.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(left.getInputs()[1])); if (LineageCache.probe(right.getInputs()[1])) - inCache.put("deltaY", LineageCache.get(right.getInputs()[1]).getMBValue()); + inCache.put("deltaY", LineageCache.getMatrix(right.getInputs()[1])); } } return inCache.containsKey("lastMatrix") ? true : false; @@ -708,7 +708,7 @@ public class LineageRewriteReuse private static boolean isAggCbind (Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) { - if (!LineageCache.isReusable(curr, ec)) { + if (!LineageCacheConfig.isReusable(curr, ec)) { return false; } @@ -726,10 +726,10 @@ public class LineageRewriteReuse LineageItem tmp = new LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {input1, groups, weights, fn, ngroups}); if (LineageCache.probe(tmp)) - inCache.put("lastMatrix", LineageCache.get(tmp).getMBValue()); + inCache.put("lastMatrix", LineageCache.getMatrix(tmp)); // look for the appended column in cache if (LineageCache.probe(target.getInputs()[1])) - inCache.put("deltaX", LineageCache.get(target.getInputs()[1]).getMBValue()); + inCache.put("deltaX", LineageCache.getMatrix(target.getInputs()[1])); } } // return true only if the last tsmm is found diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java index 9819fe0..22740f3 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java @@ -35,50 +35,55 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; -public class FunctionFullReuseTest extends AutomatedTestBase { - +public class FunctionFullReuseTest extends AutomatedTestBase +{ protected static final String TEST_DIR = "functions/lineage/"; - protected static final String TEST_NAME1 = "FunctionFullReuse1"; - protected static final String TEST_NAME2 = "FunctionFullReuse2"; - protected static final String TEST_NAME3 = "FunctionFullReuse3"; - protected static final String TEST_NAME4 = "FunctionFullReuse4"; - protected static final String TEST_NAME5 = "FunctionFullReuse5"; + protected static final String TEST_NAME = "FunctionFullReuse"; + protected static final int TEST_VARIANTS = 7; + protected String TEST_CLASS_DIR = TEST_DIR + FunctionFullReuseTest.class.getSimpleName() + "/"; @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1)); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2)); - addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3)); - addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4)); - addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5)); + for( int i=1; i<=TEST_VARIANTS; i++ ) + addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i)); } @Test public void testCacheHit() { - testLineageTrace(TEST_NAME1); + testLineageTrace(TEST_NAME+"1"); } @Test public void testCacheMiss() { - testLineageTrace(TEST_NAME2); + testLineageTrace(TEST_NAME+"2"); } @Test public void testMultipleReturns() { - testLineageTrace(TEST_NAME3); + testLineageTrace(TEST_NAME+"3"); } @Test public void testNestedFunc() 
{ - testLineageTrace(TEST_NAME4); + testLineageTrace(TEST_NAME+"4"); } @Test public void testStepLM() { - testLineageTrace(TEST_NAME5); - } + testLineageTrace(TEST_NAME+"5"); + } + + @Test + public void testParforIssue1() { + testLineageTrace(TEST_NAME+"6"); + } + + @Test + public void testParforIssue2() { + testLineageTrace(TEST_NAME+"7"); + } public void testLineageTrace(String testname) { boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -132,4 +137,3 @@ public class FunctionFullReuseTest extends AutomatedTestBase { } } } - diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml new file mode 100644 index 0000000..2b025d5 --- /dev/null +++ b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo = function(Matrix[Double] X) return (Matrix[Double] R) { + y = X + X - 2 * sqrt(X) + X * X; + while(FALSE){} + R = rowSums(y)*colSums(y); +} + +X = rand(rows=100, cols=10, seed=7); +while(FALSE){} +X = X + 1; + +R = matrix(0, 1, ncol(X)); +parfor(i in 1:10) { + R[,i] = sum(foo(X)); +} + +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse7.dml b/src/test/scripts/functions/lineage/FunctionFullReuse7.dml new file mode 100644 index 0000000..e4b64d8 --- /dev/null +++ b/src/test/scripts/functions/lineage/FunctionFullReuse7.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +foo = function(Matrix[Double] X) return (Matrix[Double] R) { + y = X + X - 2 * sqrt(X) + X * X; + while(FALSE){} + R = rowSums(y)*colSums(y); +} + +X = rand(rows=100, cols=10, seed=7); +while(FALSE){} +X = X + 1; + +R = matrix(0, 1, ncol(X)); +parfor(i in 1:10) { + R[,i] = sum(foo(X[,i])); +} + +write(R, $1, format="text");