This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a070f63b46 [SYSTEMDS-3585] Reuse lineage traces from lineage cache
a070f63b46 is described below

commit a070f63b46611127dbf371f9fdc40730c5619f9c
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Jun 29 13:25:09 2023 +0200

    [SYSTEMDS-3585] Reuse lineage traces from lineage cache
    
    This commit adds a small but useful extension to lineage-based reuse.
    We now also reuse the lineage traces corresponding to the reused
    intermediates by replacing the live lineage traces with the cached ones.
    This change increases the use of the same lineage items across many lineage
    DAGs, which in turn reduces probing cost and memory overhead.
    This extension is disabled for parfor and deduplicated lineage traces.
    Integrating with those requires more thought.
    
    Closes #1853
---
 .../runtime/controlprogram/ParForProgramBlock.java      |  5 +++++
 .../controlprogram/context/ExecutionContext.java        | 13 +++++++++++++
 .../instructions/spark/MapmmChainSPInstruction.java     | 16 ++++++++++++++--
 .../java/org/apache/sysds/runtime/lineage/Lineage.java  |  1 +
 .../org/apache/sysds/runtime/lineage/LineageCache.java  | 14 +++++++++++---
 .../sysds/runtime/lineage/LineageCacheConfig.java       | 17 +++++++++++++----
 .../sysds/runtime/lineage/LineageCacheStatistics.java   |  8 ++++++++
 .../org/apache/sysds/runtime/lineage/LineageItem.java   |  5 ++++-
 src/main/java/org/apache/sysds/utils/Statistics.java    |  2 +-
 9 files changed, 70 insertions(+), 11 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 4df1a7052e..2fc12c4c26 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -84,6 +84,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -800,6 +801,7 @@ public class ParForProgramBlock extends ForProgramBlock
                                StatisticMonitor.putPFStat(_ID, 
Stat.PARFOR_INIT_TASKS_T, time.stop());
                        
                        // Step 3) join all threads (wait for finished work)
+                       LineageCacheConfig.setReuseLineageTraces(false); 
//disable lineage trace reuse
                        for( Thread thread : threads )
                                thread.join();
                        
@@ -823,6 +825,7 @@ public class ParForProgramBlock extends ForProgramBlock
                                .map(w -> w.getExecutionContext().getLineage())
                                .toArray(Lineage[]::new);
                        mergeLineage(ec, lineages);
+                       //LineageCacheConfig.setReuseLineageTraces(true);
 
                        //consolidate results into global symbol table
                        consolidateAndCheckResults( ec, numIterations, 
numCreatedTasks,
@@ -900,6 +903,7 @@ public class ParForProgramBlock extends ForProgramBlock
                exportMatricesToHDFS(ec, brVars);
                
                // Step 3) submit Spark parfor job (no lazy evaluation, since 
collect on result)
+               LineageCacheConfig.setReuseLineageTraces(false); //disable 
lineage trace reuse
                boolean topLevelPF = OptimizerUtils.isTopLevelParFor();
                RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, 
program,
                        clsMap, tasks, ec, brVars, _resultVars, 
_enableCPCaching, _numThreads, topLevelPF);
@@ -913,6 +917,7 @@ public class ParForProgramBlock extends ForProgramBlock
                
                //lineage maintenance
                mergeLineage(ec, ret.getLineages());
+               //LineageCacheConfig.setReuseLineageTraces(true);
                // TODO: remove duplicate lineage items in ec.getLineage()
                
                //consolidate results into global symbol table
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index cdedfb9e45..9c8547f615 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -167,6 +167,7 @@ public class ExecutionContext {
        }
 
        /**
+        *
         * Get the i-th GPUContext
         * @param index index of the GPUContext
         * @return a valid GPUContext or null if the indexed GPUContext does 
not exist.
@@ -924,6 +925,18 @@ public class ExecutionContext {
                        throw new DMLRuntimeException("Lineage Trace 
unavailable.");
                return _lineage.getOrCreate(input);
        }
+
+       public void replaceLineageItem(String varname, LineageItem li) {
+               if (!LineageCacheConfig.isLineageTraceReuse())
+                       return;
+               if( _lineage == null )
+                       throw new DMLRuntimeException("Lineage Trace 
unavailable.");
+               if (_lineage.get(varname) == null)
+                       throw new DMLRuntimeException("Lineage item does not 
exist for "+varname);
+               //Passed lineage trace should be equivalent to the live lineage 
trace
+               //corresponding to varname. Replacing reduces memory and 
probing overheads.
+               _lineage.set(varname, li);
+       }
        
        private static String getNonExistingVarError(String varname) {
                return "Variable '" + varname + "' does not exist in the symbol 
table.";
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
index b1c8248579..e2f4e5d270 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
@@ -20,9 +20,11 @@
 package org.apache.sysds.runtime.instructions.spark;
 
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.lops.MapMultChain;
 import org.apache.sysds.lops.MapMultChain.ChainType;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -32,12 +34,15 @@ import 
org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import scala.Tuple2;
 
-public class MapmmChainSPInstruction extends SPInstruction {
+public class MapmmChainSPInstruction extends SPInstruction implements 
LineageTraceable {
        private ChainType _chainType = null;
        private CPOperand _input1 = null;
        private CPOperand _input2 = null;
@@ -116,7 +121,14 @@ public class MapmmChainSPInstruction extends SPInstruction 
{
                //this also includes implicit maintenance of matrix 
characteristics
                sec.setMatrixOutput(_output.getName(), out);
        }
-       
+
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               CPOperand chainT = new CPOperand(_chainType.name(), 
Types.ValueType.INT64, Types.DataType.SCALAR, true);
+               return Pair.of(_output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, _input1, _input2, 
_input3, chainT)));
+       }
+
        /**
         * This function implements the chain type XtXv which requires just one 
broadcast and
         * no access to any indexes of matrix blocks.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java 
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index c233e55e23..866ebc8120 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -132,6 +132,7 @@ public class Lineage {
        public void initializeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
                if( !(pb instanceof ForProgramBlock || pb instanceof 
WhileProgramBlock) )
                        throw new DMLRuntimeException("Invalid deduplication 
block: "+ pb.getClass().getSimpleName());
+               LineageCacheConfig.setReuseLineageTraces(false);
                if (!_dedupBlocks.containsKey(pb)) {
                        // valid only if doesn't contain a nested loop
                        boolean valid = LineageDedupUtils.isValidDedupBlock(pb, 
false);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index a51c0ae9e3..1a3b12d7a9 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -108,7 +108,8 @@ public class LineageCache
                                //try to reuse full or partial intermediates 
(CPU and FED only)
                                for (MutablePair<LineageItem,LineageCacheEntry> 
item : liList) {
                                        if 
(LineageCacheConfig.getCacheType().isFullReuse())
-                                               e = 
LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
+                                               //e = 
LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
+                                               e = getIntern(item.getKey()); 
//avoid double probing (containsKey + get)
                                        //TODO need to also move execution of 
compensation plan out of here
                                        //(create lazily evaluated entry)
                                        if (e == null && 
LineageCacheConfig.getCacheType().isPartialReuse()
@@ -162,6 +163,7 @@ public class LineageCache
                                                                //Even not 
persisted, reuse the rdd locally for shuffle operations
                                                                if 
(!LineageCacheConfig.isShuffleOp(inst))
                                                                        return 
false;
+
                                                                
((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
                                                                break;
                                                        case PERSISTEDRDD:
@@ -184,6 +186,8 @@ public class LineageCache
                                                //Increment the live count for 
this pointer
                                                
LineageGPUCacheEviction.incrementLiveCount(e.getGPUPointer());
                                        }
+                                       //Replace the live lineage trace with 
the cached one (if not parfor, dedup)
+                                       ec.replaceLineageItem(outName, e._key);
                                }
                                maintainReuseStatistics(ec, inst, 
liList.get(0).getValue());
                        }
@@ -444,6 +448,7 @@ public class LineageCache
                if (!p && DMLScript.STATISTICS && 
LineageCacheEviction._removelist.containsKey(key))
                        // The sought entry was in cache but removed later 
                        LineageCacheStatistics.incrementDelHits();
+
                return p;
        }
 
@@ -949,9 +954,11 @@ public class LineageCache
        }
        
        private static LineageCacheEntry getIntern(LineageItem key) {
-               // This method is called only when entry is present either in 
cache or in local FS.
                LineageCacheEntry e = _cache.get(key);
-               if (e != null && e.getCacheStatus() != 
LineageCacheStatus.SPILLED) {
+               if (e == null)
+                       return null;
+
+               if (e.getCacheStatus() != LineageCacheStatus.SPILLED) {
                        if (DMLScript.STATISTICS)
                                // Increment hit count.
                                LineageCacheStatistics.incrementMemHits();
@@ -1222,6 +1229,7 @@ public class LineageCache
                //TODO: Replace with generic type
 
                List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
+               //FIXME: Replace getLineageItem with get/getOrCreate to avoid 
creating a new LI object
                LineageItem instLI = (cinst != null) ? 
cinst.getLineageItem(ec).getValue()
                        : (cfinst != null) ? 
cfinst.getLineageItem(ec).getValue()
                        : (cspinst != null) ? 
cspinst.getLineageItem(ec).getValue()
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index d0e32570b9..63863f7029 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -50,11 +50,11 @@ public class LineageCacheConfig
        private static final String[] OPCODES = new String[] {
                "tsmm", "ba+*", "*", "/", "+", "||", "nrow", "ncol", "round", 
"exp", "log",
                "rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof",
-               "uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", 
"<=",
+               "uamean", "max", "min", "ifelse", "-", "sqrt", "<", ">", 
"uak+", "<=",
                "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", 
"replace",
-               "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", 
"qsort", 
+               "^2", "*2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", 
"uarimax", "qsort",
                "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", 
"lowertri",
-               "prefetch", "mapmm"
+               "prefetch", "mapmm", "contains", "mmchain", "mapmmchain", "+*"
                //TODO: Reuse everything.
        };
 
@@ -70,7 +70,7 @@ public class LineageCacheConfig
 
        // Relatively inexpensive instructions.
        private static final String[] PERSIST_OPCODES2 = new String[] {
-               "mapmm"
+               "mapmm,"
        };
 
        private static String[] REUSE_OPCODES  = new String[] {};
@@ -104,6 +104,7 @@ public class LineageCacheConfig
        private static CachedItemTail _itemT = null;
        private static boolean _compilerAssistedRW = false;
        private static boolean _onlyEstimate = false;
+       private static boolean _reuseLineageTraces = true;
 
        //-------------DISK SPILLING RELATED CONFIGURATIONS--------------//
 
@@ -368,6 +369,14 @@ public class LineageCacheConfig
                return _compilerAssistedRW;
        }
 
+       public static void setReuseLineageTraces(boolean reuseTrace) {
+               _reuseLineageTraces = reuseTrace;
+       }
+
+       public static boolean isLineageTraceReuse() {
+               return _reuseLineageTraces;
+       }
+
        public static void setCachePolicy(LineageCachePolicy policy) {
                // TODO: Automatic tuning of weights.
                switch(policy) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index c7bbd6a00d..fee5b8a0dd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -41,6 +41,7 @@ public class LineageCacheStatistics {
        private static final LongAdder _ctimeFSWrite    = new LongAdder();
        private static final LongAdder _ctimeSaved      = new LongAdder();
        private static final LongAdder _ctimeMissed     = new LongAdder();
+       private static final LongAdder _ctimeProbe      = new LongAdder();
        // Bellow entries are specific to gpu lineage cache
        private static final LongAdder _numHitsGpu      = new LongAdder();
        private static final LongAdder _numAsyncEvictGpu= new LongAdder();
@@ -70,6 +71,7 @@ public class LineageCacheStatistics {
                _ctimeFSWrite.reset();
                _ctimeSaved.reset();
                _ctimeMissed.reset();
+               _ctimeProbe.reset();
                _evtimeGpu.reset();
                _numHitsGpu.reset();
                _numAsyncEvictGpu.reset();
@@ -191,6 +193,10 @@ public class LineageCacheStatistics {
                _ctimeMissed.add(delta);
        }
 
+       public static void incrementProbeTime(long delta) {
+               _ctimeProbe.add(delta);
+       }
+
        public static long getMultiLevelFnHits() {
                return _numHitsFunc.longValue();
        }
@@ -303,6 +309,8 @@ public class LineageCacheStatistics {
                sb.append(String.format("%.3f", 
((double)_ctimeSaved.longValue())/1000000000)); //in sec
                sb.append("/");
                sb.append(String.format("%.3f", 
((double)_ctimeMissed.longValue())/1000000000)); //in sec
+               sb.append("/");
+               sb.append(String.format("%.3f", 
((double)_ctimeProbe.longValue())/1000000000)); //in sec
                return sb.toString();
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 311dae2a86..943f497937 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Stack;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -269,8 +270,8 @@ public class LineageItem {
                Stack<LineageItem> s2 = new Stack<>();
                s1.push(this);
                s2.push(that);
-               //boolean ret = false;
                boolean ret = true;
+               long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                while (!s1.empty() && !s2.empty()) {
                        LineageItem li1 = s1.pop();
                        LineageItem li2 = s2.pop();
@@ -356,6 +357,8 @@ public class LineageItem {
                                }
                        li1.setVisited();
                }
+               if (DMLScript.STATISTICS) //increment probing time
+                       
LineageCacheStatistics.incrementProbeTime(System.nanoTime() - t0);
                return ret;
        }
        
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 9d931f3ce3..6978507179 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -649,7 +649,7 @@ public class Statistics
                                }
                                sb.append("LinCache writes (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayWtrites() + ".\n");
                                sb.append("LinCache FStimes (Rd/Wr): \t" + 
LineageCacheStatistics.displayFSTime() + " sec.\n");
-                               sb.append("LinCache Computetime (S/M): \t" + 
LineageCacheStatistics.displayComputeTime() + " sec.\n");
+                               sb.append("LinCache Computetime (S/M/P): \t" + 
LineageCacheStatistics.displayComputeTime() + " sec.\n");
                                sb.append("LinCache Rewrites:    \t\t" + 
LineageCacheStatistics.displayRewrites() + ".\n");
                        }
 

Reply via email to