This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new 9bd68ff [SYSTEMDS-233] Fix multi-level lineage caching (parfor, determinism) 9bd68ff is described below commit 9bd68ffc5d211583a2ebcfe5be514abf4cc29b69 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Wed Apr 15 21:46:16 2020 +0200 [SYSTEMDS-233] Fix multi-level lineage caching (parfor, determinism) This patch fixes some issues with multi-level lineage caching in parfor, specifically (1) it allows function reuse despite differently named parfor worker functions, and (2) it corrects the check for deterministic function results, which incorrectly probed too far and thus missed reuse opportunities. However, down the road we should add an IPA pass which determines once for all functions if they are deterministic and passes this information down to the runtime, in order to avoid scenarios where threads are already blocking on placeholders that are later removed due to non-deterministic functions. --- .../apache/sysds/hops/recompile/Recompiler.java | 10 +++++----- src/main/java/org/apache/sysds/lops/Lop.java | 2 +- .../sysds/runtime/controlprogram/ProgramBlock.java | 17 ++++++++++++++-- .../instructions/cp/FunctionCallCPInstruction.java | 23 +++++++++++++++++----- .../apache/sysds/runtime/lineage/LineageCache.java | 16 +++++++++------ .../runtime/lineage/LineageCacheStatistics.java | 10 +++++++++- .../sysds/runtime/lineage/LineageItemUtils.java | 10 +++------- .../java/org/apache/sysds/utils/Statistics.java | 2 +- .../functions/lineage/FunctionFullReuseTest.java | 7 +++++++ .../functions/lineage/FunctionFullReuse6.dml | 4 ++-- 10 files changed, 71 insertions(+), 30 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java index 2b11c73..d058c6a 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java @@ -155,7 +155,7 @@ public class 
Recompiler } // replace thread ids in new instructions - if( tid != 0 ) //only in parfor context + if( ProgramBlock.isThreadID(tid) ) //only in parfor context newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false); // remove writes if called through mlcontext or jmlc @@ -187,7 +187,7 @@ public class Recompiler } // replace thread ids in new instructions - if( tid != 0 ) //only in parfor context + if( ProgramBlock.isThreadID(tid) ) //only in parfor context newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false); // explain recompiled instructions @@ -209,7 +209,7 @@ public class Recompiler } // replace thread ids in new instructions - if( tid != 0 ) //only in parfor context + if( ProgramBlock.isThreadID(tid) ) //only in parfor context newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false); // explain recompiled instructions @@ -231,7 +231,7 @@ public class Recompiler } // replace thread ids in new instructions - if( tid != 0 ) //only in parfor context + if( ProgramBlock.isThreadID(tid) ) //only in parfor context newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false); // explain recompiled hops / instructions @@ -253,7 +253,7 @@ public class Recompiler } // replace thread ids in new instructions - if( tid != 0 ) //only in parfor context + if( ProgramBlock.isThreadID(tid) ) //only in parfor context newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false); // explain recompiled hops / instructions diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index fa25000..8bb7e1a 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -82,7 +82,7 @@ public abstract class Lop public static final String PROCESS_PREFIX = "_p"; 
public static final String CP_ROOT_THREAD_ID = "_t0"; public static final String CP_CHILD_THREAD = "_t"; - public static final double SAMPLE_FRACTION = 0.01; // for row sampling in distributed frame meta operations + public static final double SAMPLE_FRACTION = 0.01; // for row sampling in distributed frame meta operations //special delimiters w/ extended ASCII characters to avoid collisions public static final String INSTRUCTION_DELIMITOR = "\u2021"; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index 4f5ef85..5cde84e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -82,14 +82,27 @@ public abstract class ProgramBlock implements ParseInfo return _sb; } - public void setStatementBlock( StatementBlock sb ){ + public void setStatementBlock(StatementBlock sb){ _sb = sb; } - public void setThreadID( long id ){ + public void setThreadID(long id){ _tid = id; } + public boolean hasThreadID() { + return _tid != 0; + } + + public static boolean isThreadID (long tid) { + return tid != 0; + } + + public long getThreadID() { + return _tid; + } + + /** * Get the list of child program blocks if nested; * otherwise this method returns null. 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java index 9c1eac0..5d7feee 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -225,8 +225,10 @@ public class FunctionCallCPInstruction extends CPInstruction { } //update lineage cache with the functions outputs - if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ) - LineageCache.putValue(fpb.getOutputParams(), liInputs, _functionName, ec); + if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ) { + LineageCache.putValue(fpb.getOutputParams(), + liInputs, getCacheFunctionName(_functionName, fpb), ec); + } } @Override @@ -249,7 +251,7 @@ public class FunctionCallCPInstruction extends CPInstruction { //split current instruction String[] parts = instString.split(Lop.OPERAND_DELIMITOR); if( parts[3].equals(pattern) ) - parts[3] = replace; + parts[3] = replace; //construct and set modified instruction StringBuilder sb = new StringBuilder(); @@ -262,14 +264,25 @@ public class FunctionCallCPInstruction extends CPInstruction { } private boolean reuseFunctionOutputs(LineageItem[] liInputs, FunctionProgramBlock fpb, ExecutionContext ec) { + //prepare lineage cache probing + String funcName = getCacheFunctionName(_functionName, fpb); int numOutputs = Math.min(_boundOutputNames.size(), fpb.getOutputParams().size()); - boolean reuse = LineageCache.reuse(_boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, _functionName, ec); + + //reuse of function outputs + boolean reuse = LineageCache.reuse( + _boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, funcName, ec); + //statistics maintenance if (reuse && DMLScript.STATISTICS) { //decrement the call count for this function - 
Statistics.maintainCPFuncCallStats(this.getExtendedOpcode()); + Statistics.maintainCPFuncCallStats(getExtendedOpcode()); LineageCacheStatistics.incrementFuncHits(); } return reuse; } + + private static String getCacheFunctionName(String fname, FunctionProgramBlock fpb) { + return !fpb.hasThreadID() ? fname : + fname.substring(0, fname.lastIndexOf(Lop.CP_CHILD_THREAD+fpb.getThreadID())); + } } diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java index 2741b70..0d93699 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java @@ -143,17 +143,18 @@ public class LineageCache LineageItem li = new LineageItem(outNames.get(i), opcode, liInputs); Entry e = null; synchronized( _cache ) { - if (LineageCache.probe(li)) + if (LineageCache.probe(li)) { e = LineageCache.getIntern(li); - else + } + else { //create a placeholder if no reuse to avoid redundancy //(e.g., concurrent threads that try to start the computation) putIntern(li, outParams.get(i).getDataType(), null, null, 0); - //FIXME: parfor - every thread gets different function names + } } //TODO: handling of recursive calls - if (e != null && !e.isNullVal()) { + if ( e != null ) { String boundVarName = outNames.get(i); Data boundValue = null; //convert to matrix object @@ -164,8 +165,9 @@ public class LineageCache ((MatrixObject)boundValue).acquireModify(e.getMBValue()); ((MatrixObject)boundValue).release(); } - else + else { boundValue = e.getSOValue(); + } funcOutputs.put(boundVarName, boundValue); LineageItem orig = e._origItem; @@ -250,7 +252,7 @@ public class LineageCache public static void putValue(List<DataIdentifier> outputs, LineageItem[] liInputs, String name, ExecutionContext ec) { - if( !LineageCacheConfig.isMultiLevelReuse()) + if( !LineageCacheConfig.isMultiLevelReuse() ) return; HashMap<LineageItem, LineageItem> FuncLIMap = 
new HashMap<>(); @@ -264,6 +266,8 @@ public class LineageCache boundLI.resetVisitStatus(); if (boundLI == null || !LineageCache.probe(li) + //TODO remove this brittle constraint (if the placeholder is removed + //it might crash threads that are already waiting for its results) || LineageItemUtils.containsRandDataGen(new HashSet<>(Arrays.asList(liInputs)), boundLI)) { AllOutputsCacheable = false; } diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java index 98ad75e..9704797 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java @@ -122,6 +122,14 @@ public class LineageCacheStatistics { // Total time spent compiling lineage rewrites. _ctimeRewrite.add(delta); } + + public static long getMultiLevelFnHits() { + return _numHitsFunc.longValue(); + } + + public static long getMultiLevelSBHits() { + return _numHitsSB.longValue(); + } public static void incrementPRwExecTime(long delta) { // Total time spent executing lineage rewrites. 
@@ -138,7 +146,7 @@ public class LineageCacheStatistics { return sb.toString(); } - public static String displayMultiLvlHits() { + public static String displayMultiLevelHits() { StringBuilder sb = new StringBuilder(); sb.append(_numHitsInst.longValue()); sb.append("/"); diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java index eade225..aeeacd3 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java @@ -612,16 +612,12 @@ public class LineageItemUtils { public static boolean containsRandDataGen(HashSet<LineageItem> entries, LineageItem root) { if (entries.contains(root) || root.isVisited()) return false; - - boolean isRand = false; - if (isNonDeterministic(root)) - isRand |= true; - if (!root.isLeaf()) + boolean isRand = isNonDeterministic(root); + if (!root.isLeaf() && !isRand) for (LineageItem input : root.getInputs()) - isRand = isRand ? 
true : containsRandDataGen(entries, input); + isRand |= containsRandDataGen(entries, input); root.setVisited(); return isRand; - //TODO: unmark for caching in compile time } private static boolean isNonDeterministic(LineageItem li) { diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index 9445caf..4c3cbef 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -945,7 +945,7 @@ public class Statistics } if (DMLScript.LINEAGE && !ReuseCacheType.isNone()) { sb.append("LinCache hits (Mem/FS/Del): \t" + LineageCacheStatistics.displayHits() + ".\n"); - sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLvlHits() + ".\n"); + sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLevelHits() + ".\n"); sb.append("LinCache writes (Mem/FS): \t" + LineageCacheStatistics.displayWtrites() + ".\n"); sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayTime() + " sec.\n"); sb.append("LinCache costing time: \t" + LineageCacheStatistics.displayCostingTime() + " sec.\n"); diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java index 22740f3..8fc7f78 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java @@ -19,13 +19,16 @@ package org.apache.sysds.test.functions.lineage; +import org.junit.Assert; import org.junit.Test; + import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.lops.LopProperties.ExecType; import org.apache.sysds.runtime.lineage.Lineage; import 
org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; +import org.apache.sysds.runtime.lineage.LineageCacheStatistics; import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; @@ -128,6 +131,10 @@ public class FunctionFullReuseTest extends AutomatedTestBase Lineage.setLinReuseNone(); TestUtils.compareMatrices(X_orig, X_reused, 1e-6, "Origin", "Reused"); + if( testname.endsWith("6") ) { // parfor fn reuse + Assert.assertEquals(9L, LineageCacheStatistics.getMultiLevelFnHits() + + LineageCacheStatistics.getMultiLevelSBHits()); + } } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification; diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml index 2b025d5..02af351 100644 --- a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml +++ b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml @@ -20,9 +20,9 @@ #------------------------------------------------------------- foo = function(Matrix[Double] X) return (Matrix[Double] R) { - y = X + X - 2 * sqrt(X) + X * X; + Y = X + X - 2 * sqrt(X) + X * X; while(FALSE){} - R = rowSums(y)*colSums(y); + R = rowSums(Y)%*%colSums(Y); } X = rand(rows=100, cols=10, seed=7);