This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f89a38da04 [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
f89a38da04 is described below

commit f89a38da041e685651afe016aafbe70288a31a66
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Dec 21 23:56:24 2022 +0100

    [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
    
    This patch adds the compiler flags and runtime support to checkpoint
    the output of any Spark instruction that is marked for caching. During
    postprocessing of a marked instruction, we first persist the RDD in
    place and then store the RDD handle in the local Lineage cache for
    reuse. This patch also fixes a bug from the last commit, which
    unpersisted the locally cached RDDs during rmvar. Future commits will
    add rewrites that mark the Spark instructions for caching in a
    cost-based manner.
    Hyperparameter tuning of LmDS with 2.5k columns improves by 22x by
    caching the cpmm results in the executors.
    
    Closes #1756
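    
    For illustration, below is a minimal standalone sketch of the persist-then-cache
    pattern described above (a simplified assumption, not SystemDS code: the class
    name, the lineage-key string, and the toy squared-values RDD are made up). The
    output RDD of a marked operation is persisted at MEMORY_AND_DISK and its handle
    is kept in a local cache keyed by a lineage-like string, so a later execution
    with the same key reuses the persisted RDD instead of recomputing it.

    import java.util.Arrays;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.storage.StorageLevel;

    public class PersistAndCacheSketch {
        // local "lineage cache": lineage key -> handle of the persisted RDD
        private static final Map<String, JavaRDD<Double>> cache = new ConcurrentHashMap<>();

        static JavaRDD<Double> computeOrReuse(JavaSparkContext sc, String lineageKey) {
            return cache.computeIfAbsent(lineageKey, k -> {
                // stand-in for the output of a marked instruction (e.g., cpmm/rmm)
                JavaRDD<Double> out = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0)).map(x -> x * x);
                // persist in place so the blocks stay on the executors ...
                out.persist(StorageLevel.MEMORY_AND_DISK());
                return out; // ... and keep the handle locally for reuse
            });
        }

        public static void main(String[] args) {
            try (JavaSparkContext sc = new JavaSparkContext("local[*]", "persist-and-cache-sketch")) {
                double s1 = computeOrReuse(sc, "cpmm(X,y)").reduce(Double::sum); // computes and persists
                double s2 = computeOrReuse(sc, "cpmm(X,y)").reduce(Double::sum); // reuses the persisted RDD
                System.out.println(s1 + " " + s2);
            }
        }
    }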
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |   4 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |   8 +
 src/main/java/org/apache/sysds/lops/MMCJ.java      |   2 +
 src/main/java/org/apache/sysds/lops/MMRJ.java      |   4 +-
 .../org/apache/sysds/lops/OutputParameters.java    |   6 +-
 .../controlprogram/caching/MatrixObject.java       |   2 +-
 .../context/SparkExecutionContext.java             |   4 +-
 .../spark/AggregateBinarySPInstruction.java        |   8 +-
 .../instructions/spark/BinarySPInstruction.java    |   7 +-
 .../spark/ComputationSPInstruction.java            |  40 ++++
 .../instructions/spark/CpmmSPInstruction.java      |   8 +-
 .../instructions/spark/RmmSPInstruction.java       |   8 +-
 .../instructions/spark/data/LineageObject.java     |  10 +
 .../runtime/instructions/spark/data/RDDObject.java |   5 +
 .../apache/sysds/runtime/lineage/LineageCache.java | 215 +++++++++++----------
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  15 +-
 .../sysds/runtime/lineage/LineageCacheEntry.java   |   4 +
 .../runtime/lineage/LineageCacheStatistics.java    |  26 ++-
 .../java/org/apache/sysds/utils/Statistics.java    |   1 +
 .../functions/async/LineageReuseSparkTest.java     |  12 ++
 .../scripts/functions/async/LineageReuseSpark2.dml |  53 +++++
 21 files changed, 316 insertions(+), 126 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index dd04307229..9f80f7a683 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -217,7 +217,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                                                input1.getDim1(), 
input1.getDim2(), input1.getBlocksize(), input1.getNnz(),
                                                input2.getDim1(), 
input2.getDim2(), input2.getBlocksize(), input2.getNnz(),
                                                mmtsj, chain, _hasLeftPMInput, 
tmmRewrite );
-                               //dispatch SPARK lops construction 
+                               //dispatch SPARK lops construction
                                switch( _method )
                                {
                                        case TSMM:
@@ -790,6 +790,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                        Lop cpmm = new MMCJ(getInput().get(0).constructLops(), 
getInput().get(1).constructLops(),
                                getDataType(), getValueType(), 
_outputEmptyBlocks, aggtype, ExecType.SPARK);
                        setOutputDimensions( cpmm );
+                       //setMarkForLineageCaching(cpmm);
                        setLineNumbers( cpmm );
                        setLops( cpmm );
                }
@@ -823,6 +824,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                Lop rmm = new 
MMRJ(getInput().get(0).constructLops(),getInput().get(1).constructLops(), 
                        getDataType(), getValueType(), ExecType.SPARK);
                setOutputDimensions(rmm);
+               //setMarkForLineageCaching(rmm);
                setLineNumbers( rmm );
                setLops(rmm);
        }
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 3988a6b59f..fa911749ee 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -57,6 +57,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -1235,6 +1236,13 @@ public abstract class Hop implements ParseInfo {
                        getDim1(), getDim2(), getBlocksize(), getNnz(), 
getUpdateType());
        }
 
+       protected void setMarkForLineageCaching(Lop lop) {
+               //TODO: set the flag in the HOP via a rewrite
+               //lop.getOutputParameters().setLineageCacheCandidate(requiresLineageCaching());
+               if (!LineageCacheConfig.ReuseCacheType.isNone())
+                       lop.getOutputParameters().setLineageCacheCandidate(true);
+       }
+
        protected void setOutputDimensionsIncludeCompressedSize(Lop lop) {
                lop.getOutputParameters().setDimensions(
                        getDim1(), getDim2(), getBlocksize(), getNnz(), 
getUpdateType(), getCompressedSize());
diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java 
b/src/main/java/org/apache/sysds/lops/MMCJ.java
index e804f84e9b..544e89341b 100644
--- a/src/main/java/org/apache/sysds/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMCJ.java
@@ -109,6 +109,8 @@ public class MMCJ extends Lop
                }
                else
                        sb.append(_type.name());
+               sb.append( OPERAND_DELIMITOR );
+               sb.append(getOutputParameters().getLinCacheMarking());
                
                return sb.toString();
        }
diff --git a/src/main/java/org/apache/sysds/lops/MMRJ.java 
b/src/main/java/org/apache/sysds/lops/MMRJ.java
index e35cdead7e..21577eeddb 100644
--- a/src/main/java/org/apache/sysds/lops/MMRJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMRJ.java
@@ -59,11 +59,13 @@ public class MMRJ extends Lop
 
        @Override
        public String getInstructions(String input1, String input2, String output) {
+               boolean toCache = getOutputParameters().getLinCacheMarking();
                return InstructionUtils.concatOperands(
                        getExecType().name(),
                        "rmm",
                        getInputs().get(0).prepInputOperand(input1),
                        getInputs().get(1).prepInputOperand(input2),
-                       prepOutputOperand(output));
+                       prepOutputOperand(output),
+                       Boolean.toString(toCache));
        }
 }
diff --git a/src/main/java/org/apache/sysds/lops/OutputParameters.java 
b/src/main/java/org/apache/sysds/lops/OutputParameters.java
index 64ba755395..9454f19a13 100644
--- a/src/main/java/org/apache/sysds/lops/OutputParameters.java
+++ b/src/main/java/org/apache/sysds/lops/OutputParameters.java
@@ -39,7 +39,7 @@ public class OutputParameters
        private long _blocksize = -1;
        private String _file_name = null;
        private String _file_label = null;
-       private boolean _linCacheCandidate = true;
+       private boolean _linCacheCandidate = false;
        private long _compressedSize = -1;
 
        FileFormat matrix_format = FileFormat.BINARY;
@@ -162,6 +162,10 @@ public class OutputParameters
        public void setUpdateType(UpdateType update) {
                _updateType = update;
        }
+
+       public void setLineageCacheCandidate(boolean reqCaching) {
+               _linCacheCandidate = reqCaching;
+       }
        
        public boolean getLinCacheMarking() {
                return _linCacheCandidate;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index c723cc56fa..e0139f2a62 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -476,7 +476,7 @@ public class MatrixObject extends 
CacheableData<MatrixBlock> {
                FileFormat fmt = iimd.getFileFormat();
                MatrixBlock mb = null;
                try {
-                       // prevent unnecessary collect through rdd checkpoint
+                       // prevent unnecessary collect through rdd checkpoint (unless lineage cached)
                        if(rdd.allowsShortCircuitCollect()) {
                                lrdd = (RDDObject) 
rdd.getLineageChilds().get(0);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 48778cb4d4..77eca47640 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1522,7 +1522,9 @@ public class SparkExecutionContext extends 
ExecutionContext
                if( lob instanceof RDDObject ) {
                        RDDObject rdd = (RDDObject)lob;
                        int rddID = rdd.getRDD().id();
-                       cleanupRDDVariable(rdd.getRDD());
+                       //skip unpersisting if locally cached
+                       if (!lob.isInLineageCache())
+                               cleanupRDDVariable(rdd.getRDD());
                        if( rdd.getHDFSFilename()!=null ) { //deferred file 
removal
                                
HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
index 80de732505..4575b06252 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
@@ -26,8 +26,12 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
  * Class to group the different MM <code>SPInstruction</code>s together.
  */
 public abstract class AggregateBinarySPInstruction extends BinarySPInstruction {
-       protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
-                       String opcode, String istr) {
+       protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
                super(type, op, in1, in2, out, opcode, istr);
        }
+
+       protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+               String opcode, boolean toCache, String istr) {
+               super(type, op, in1, in2, out, opcode, toCache, istr);
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 16196d5e0d..3c70d4021a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -56,7 +56,12 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction {
        protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
                super(type, op, in1, in2, out, opcode, istr);
        }
-       
+
+       protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2,
+               CPOperand out, String opcode, boolean toCache, String istr) {
+               super(type, op, in1, in2, out, opcode, toCache, istr);
+       }
+
        public static BinarySPInstruction parseInstruction ( String str ) {
                CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
                CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
index d380d913b2..465ce35820 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
@@ -20,7 +20,10 @@
 package org.apache.sysds.runtime.instructions.spark;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.functionobjects.IndexFunction;
@@ -28,15 +31,22 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.lineage.LineageTraceable;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
+import java.util.Map;
+
 public abstract class ComputationSPInstruction extends SPInstruction 
implements LineageTraceable {
        public CPOperand output;
        public CPOperand input1, input2, input3;
+       private boolean toPersistAndCache;
 
        protected ComputationSPInstruction(SPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand out, String opcode, String istr) {
                super(type, op, opcode, istr);
@@ -46,6 +56,15 @@ public abstract class ComputationSPInstruction extends 
SPInstruction implements
                output = out;
        }
 
+       protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, boolean toCache, String istr) {
+               super(type, op, opcode, istr);
+               input1 = in1;
+               input2 = in2;
+               input3 = null;
+               output = out;
+               toPersistAndCache = toCache;
+       }
+
        protected ComputationSPInstruction(SPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
                super(type, op, opcode, istr);
                input1 = in1;
@@ -126,6 +145,27 @@ public abstract class ComputationSPInstruction extends 
SPInstruction implements
                                mcOut.set(1, mc1.getCols(), mc1.getBlocksize(), 
mc1.getBlocksize());
                }
        }
+
+       public boolean isRDDtoCache() {
+               return toPersistAndCache;
+       }
+
+       public void checkpointRDD(ExecutionContext ec) {
+               if (!toPersistAndCache)
+                       return;
+
+               SparkExecutionContext sec = (SparkExecutionContext)ec;
+               CacheableData<?> cd = sec.getCacheableData(output.getName());
+               RDDObject inro =  cd.getRDDHandle();
+               JavaPairRDD<?,?> outrdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)inro.getRDD(), false);
+               //TODO: remove shallow copying as short-circuit collect is disabled if locally cached
+               outrdd = outrdd.persist((StorageLevel.MEMORY_AND_DISK()));
+               RDDObject outro = new RDDObject(outrdd); //create new rdd object
+               outro.setCheckpointRDD(true);            //mark as checkpointed
+               outro.addLineageChild(inro);             //keep lineage to prevent cycles on cleanup
+               cd.setRDDHandle(outro);
+               sec.setVariable(output.getName(), cd);
+       }
        
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 79832eabe2..253480b482 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -66,8 +66,9 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
        private final boolean _outputEmptyBlocks;
        private final SparkAggType _aggtype;
        
-       private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, String istr) {
-               super(SPType.CPMM, op, in1, in2, out, opcode, istr);
+       private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+               boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, boolean toCache, String istr) {
+               super(SPType.CPMM, op, in1, in2, out, opcode, toCache, istr);
                _outputEmptyBlocks = outputEmptyBlocks;
                _aggtype = aggtype;
        }
@@ -83,7 +84,8 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                AggregateBinaryOperator aggbin = 
InstructionUtils.getMatMultOperator(1);
                boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
                SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
-               return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
+               boolean toCache = parts.length == 7 ? Boolean.parseBoolean(parts[6]) : false;
+               return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, toCache, str);
        }
        
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
index 70d4bef1dd..130045a890 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
@@ -49,8 +49,9 @@ import java.util.LinkedList;
 
 public class RmmSPInstruction extends AggregateBinarySPInstruction {
 
-       private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
-               super(SPType.RMM, op, in1, in2, out, opcode, istr);
+       private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+               String opcode, boolean toCache, String istr) {
+               super(SPType.RMM, op, in1, in2, out, opcode, toCache, istr);
        }
 
        public static RmmSPInstruction parseInstruction( String str ) {
@@ -61,8 +62,9 @@ public class RmmSPInstruction extends 
AggregateBinarySPInstruction {
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
                        CPOperand out = new CPOperand(parts[3]);
+                       boolean toCache = parts.length == 5 ? Boolean.parseBoolean(parts[4]) : false;
                        
-                       return new RmmSPInstruction(null, in1, in2, out, opcode, str);
+                       return new RmmSPInstruction(null, in1, in2, out, opcode, toCache, str);
                } 
                else {
                        throw new 
DMLRuntimeException("RmmSPInstruction.parseInstruction():: Unknown opcode " + 
opcode);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
index 0882dd1d9a..f4b99bb03e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
@@ -28,6 +28,7 @@ public abstract class LineageObject
 {
        //basic lineage information
        protected int _numRef = -1;
+       protected boolean _lineageCached = false;
        protected final List<LineageObject> _childs;
        
        //N:1 back reference to matrix/frame object
@@ -35,6 +36,7 @@ public abstract class LineageObject
        
        protected LineageObject() {
                _numRef = 0;
+               _lineageCached = false;
                _childs = new ArrayList<>();
        }
        
@@ -49,6 +51,14 @@ public abstract class LineageObject
        public boolean hasBackReference() {
                return (_cd != null);
        }
+
+       public void setLineageCached() {
+               _lineageCached = true;
+       }
+
+       public boolean isInLineageCache() {
+               return _lineageCached;
+       }
        
        public void incrementNumReferences() {
                _numRef++;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
index 2b03a00d31..04d021b6ff 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
@@ -103,6 +103,11 @@ public class RDDObject extends LineageObject
 
        public boolean allowsShortCircuitCollect()
        {
+               // If the RDD is marked to be persisted and cached locally, collect through
+               // this RDD (no short-circuit) so that the persisted blocks can be reused next time.
+               if (isInLineageCache())
+                       return false;
+
                return ( isCheckpointRDD() && getLineageChilds().size() == 1
                             && getLineageChilds().get(0) instanceof RDDObject 
);
        }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d4eb4b8f92..0391583b4c 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -53,7 +53,6 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
-import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
@@ -95,31 +94,10 @@ public class LineageCache
                        return false;
                
                boolean reuse = false;
-               //NOTE: the check for computation CP instructions ensures that 
the output
-               // will always fit in memory and hence can be pinned 
unconditionally
-               if (LineageCacheConfig.isReusable(inst, ec)) {
-                       ComputationCPInstruction cinst = inst instanceof 
ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
-                       ComputationFEDInstruction cfinst = inst instanceof 
ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
-                       ComputationSPInstruction cspinst = inst instanceof 
ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
-                       GPUInstruction gpuinst = inst instanceof GPUInstruction 
? (GPUInstruction)inst : null;
-                       //TODO: Replace with generic type
-                               
-                       LineageItem instLI = (cinst != null) ? 
cinst.getLineageItem(ec).getValue()
-                                       : (cfinst != null) ? 
cfinst.getLineageItem(ec).getValue()
-                                       : (cspinst != null) ? 
cspinst.getLineageItem(ec).getValue()
-                                       : gpuinst.getLineageItem(ec).getValue();
-                       List<MutablePair<LineageItem, LineageCacheEntry>> 
liList = null;
-                       if (inst instanceof MultiReturnBuiltinCPInstruction) {
-                               liList = new ArrayList<>();
-                               MultiReturnBuiltinCPInstruction mrInst = 
(MultiReturnBuiltinCPInstruction)inst;
-                               for (int i=0; i<mrInst.getNumOutputs(); i++) {
-                                       String opcode = instLI.getOpcode() + 
String.valueOf(i);
-                                       liList.add(MutablePair.of(new 
LineageItem(opcode, instLI.getInputs()), null));
-                               }
-                       }
-                       else
-                               liList = Arrays.asList(MutablePair.of(instLI, 
null));
-                       
+               if (LineageCacheConfig.isReusable(inst, ec))
+               {
+                       List<MutablePair<LineageItem, LineageCacheEntry>> 
liList = getLineageItems(inst, ec);
+
                        //atomic try reuse full/partial and set placeholder, 
without
                        //obtaining value to avoid blocking in critical section
                        LineageCacheEntry e = null;
@@ -131,49 +109,28 @@ public class LineageCache
                                                e = 
LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
                                        //TODO need to also move execution of 
compensation plan out of here
                                        //(create lazily evaluated entry)
-                                       if (e == null && 
LineageCacheConfig.getCacheType().isPartialReuse() && cspinst == null)
+                                       if (e == null && 
LineageCacheConfig.getCacheType().isPartialReuse()
+                                               && !(inst instanceof 
ComputationSPInstruction))
                                                if( 
LineageRewriteReuse.executeRewrites(inst, ec) )
                                                        e = 
getIntern(item.getKey());
-                                       //TODO: Partial reuse for Spark 
instructions
                                        reuseAll &= (e != null);
                                        item.setValue(e);
                                        
                                        //create a placeholder if no reuse to 
avoid redundancy
                                        //(e.g., concurrent threads that try to 
start the computation)
-                                       if(e == null && 
isMarkedForCaching(inst, ec)) {
-                                               if (cinst != null)
-                                                       
putIntern(item.getKey(), cinst.output.getDataType(), null, null,  0);
-                                               else if (cfinst != null)
-                                                       
putIntern(item.getKey(), cfinst.output.getDataType(), null, null,  0);
-                                               else if (cspinst != null)
-                                                       
putIntern(item.getKey(), cspinst.output.getDataType(), null, null,  0);
-                                               else if (gpuinst != null)
-                                                       
putIntern(item.getKey(), gpuinst._output.getDataType(), null, null,  0);
-                                               //FIXME: different o/p 
datatypes for MultiReturnBuiltins.
-                                       }
+                                       if(e == null && 
isMarkedForCaching(inst, ec))
+                                               putInternPlaceholder(inst, 
item.getKey());
                                }
                        }
                        reuse = reuseAll;
                        
                        if(reuse) { //reuse
-                               boolean gpuReuse = false;
-                               //put reuse value into symbol table (w/ 
blocking on placeholders)
+                               //put reused value into symbol table (w/ 
blocking on placeholders)
                                for (MutablePair<LineageItem, 
LineageCacheEntry> entry : liList) {
                                        e = entry.getValue();
-                                       String outName = null;
-                                       if (inst instanceof 
MultiReturnBuiltinCPInstruction)
-                                               outName = 
((MultiReturnBuiltinCPInstruction)inst).
-                                                       
getOutput(entry.getKey().getOpcode().charAt(entry.getKey().getOpcode().length()-1)-'0').getName();
 
-                                       else if (inst instanceof 
ComputationCPInstruction)
-                                               outName = 
cinst.output.getName();
-                                       else if (inst instanceof 
ComputationFEDInstruction)
-                                               outName = 
cfinst.output.getName();
-                                       else if (inst instanceof 
ComputationSPInstruction)
-                                               outName = 
cspinst.output.getName();
-                                       else if (inst instanceof GPUInstruction)
-                                               outName = 
gpuinst._output.getName();
-                                       
-                                       if (e.isMatrixValue() && e._gpuObject 
== null) {
+                                       String outName = getOutputName(inst, 
entry.getKey());
+
+                                       if (e.isMatrixValue() && 
!e.isGPUObject()) {
                                                MatrixBlock mb = 
e.getMBValue(); //wait if another thread is executing the same inst.
                                                if (mb == null && 
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
                                                        return false;  //the 
executing thread removed this entry from cache
@@ -190,10 +147,13 @@ public class LineageCache
                                        else if (e.isRDDPersist()) {
                                                //Reuse the RDD which is also 
persisted in Spark
                                                RDDObject rdd = 
e.getRDDObject();
+                                               if (!((SparkExecutionContext) 
ec).isRDDCached(rdd.getRDD().id()))
+                                                       //Return if the RDD is 
not cached in the executors
+                                                       return false;
                                                if (rdd == null && 
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
                                                        return false;
                                                else
-                                                       
((SparkExecutionContext)ec).setRDDHandleForVariable(outName, rdd);
+                                                       
((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
                                        }
                                        else { //TODO handle locks on gpu 
objects
                                                //shallow copy the cached 
GPUObj to the output MatrixObject
@@ -201,26 +161,15 @@ public class LineageCache
                                                                
ec.getGPUContext(0).shallowCopyGPUObject(e._gpuObject, 
ec.getMatrixObject(outName)));
                                                //Set dirty to true, so that it 
is later copied to the host for write
                                                
ec.getMatrixObject(outName).getGPUObject(ec.getGPUContext(0)).setDirty(true);
-                                               gpuReuse = true;
                                        }
-
-                                       reuse = true;
-
-                                       if (DMLScript.STATISTICS) //increment 
saved time
-                                               
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
-                               }
-                               if (DMLScript.STATISTICS) {
-                                       if (gpuReuse)
-                                               
LineageCacheStatistics.incrementGpuHits();
-                                       else
-                                               
LineageCacheStatistics.incrementInstHits();
                                }
+                               maintainReuseStatistics(inst, 
liList.get(0).getValue());
                        }
                }
                
                return reuse;
        }
-       
+
        public static boolean reuse(List<String> outNames, List<DataIdentifier> 
outParams, 
                        int numOutputs, LineageItem[] liInputs, String name, 
ExecutionContext ec)
        {
@@ -532,9 +481,7 @@ public class LineageCache
                        //if (!isMarkedForCaching(inst, ec)) return;
                        List<Pair<LineageItem, Data>> liData = null;
                        GPUObject liGpuObj = null;
-                       RDDObject rddObj = null;
                        LineageItem instLI = ((LineageTraceable) 
inst).getLineageItem(ec).getValue();
-                       LineageItem instInputLI = null;
                        if (inst instanceof MultiReturnBuiltinCPInstruction) {
                                liData = new ArrayList<>();
                                MultiReturnBuiltinCPInstruction mrInst = 
(MultiReturnBuiltinCPInstruction)inst;
@@ -556,14 +503,9 @@ public class LineageCache
                                if (liGpuObj == null)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((GPUInstruction)inst)._output)));
                        }
-                       else if (inst instanceof CheckpointSPInstruction) {
-                               // Get the lineage of the instruction being 
checkpointed
-                               instInputLI = 
ec.getLineageItem(((ComputationSPInstruction)inst).input1);
-                               // Get the RDD handle of the persisted RDD
-                               CacheableData<?> cd = 
ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
-                               rddObj = ((CacheableData<?>) cd).getRDDHandle();
-                               // Remove the lineage item of the chkpoint 
instruction
-                               removePlaceholder(instLI);
+                       else if (inst instanceof ComputationSPInstruction && 
((ComputationSPInstruction) inst).isRDDtoCache()) {
+                               putValueRDD(inst, instLI, ec, computetime);
+                               return;
                        }
                        else
                                if (inst instanceof ComputationCPInstruction)
@@ -573,12 +515,10 @@ public class LineageCache
                                else if (inst instanceof 
ComputationSPInstruction)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationSPInstruction) inst).output)));
 
-                       if (liGpuObj == null && rddObj == null)
+                       if (liGpuObj == null)
                                putValueCPU(inst, liData, computetime);
-                       if (liGpuObj != null)
+                       else
                                putValueGPU(liGpuObj, instLI, computetime);
-                       if (rddObj != null)
-                               putValueRDD(rddObj, instInputLI, computetime);
                }
        }
        
@@ -607,13 +547,6 @@ public class LineageCache
                                        continue;
                                }
 
-                               if (LineageCacheConfig.isToPersist(inst) && 
LineageCacheConfig.getCompAssRW()) {
-                                       // The immediately following 
instruction must be a checkpoint, which will
-                                       // fill the rdd in this cache entry.
-                                       // TODO: Instead check if this 
instruction is marked for checkpointing
-                                       continue;
-                               }
-
                                if (data instanceof MatrixObject && 
((MatrixObject) data).hasRDDHandle()) {
                                        // Avoid triggering pre-matured Spark 
instruction chains
                                        removePlaceholder(item);
@@ -672,18 +605,21 @@ public class LineageCache
                }
        }
 
-       private static void putValueRDD(RDDObject rdd, LineageItem instLI, long computetime) {
+       private static void putValueRDD(Instruction inst, LineageItem instLI, ExecutionContext ec, long computetime) {
                synchronized( _cache ) {
-                       // Not available in the cache indicates this RDD is not marked for caching
                        if (!probe(instLI))
                                return;
+                       // Call persist on the output RDD
+                       ((ComputationSPInstruction) inst).checkpointRDD(ec);
+                       // Get the RDD handle of the persisted RDD
+                       CacheableData<?> cd = ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
+                       RDDObject rddObj = ((CacheableData<?>) cd).getRDDHandle();
 
                        LineageCacheEntry centry = _cache.get(instLI);
-                       if (centry.isRDDPersist() && centry.getRDDObject().isCheckpointRDD())
-                               // Do nothing if the cached RDD is already checkpointed
-                               return;
-
-                       centry.setRDDValue(rdd, computetime);
+                       // Set the RDD object in the cache
+                       // TODO: Make space in the executors
+                       rddObj.setLineageCached();
+                       centry.setRDDValue(rddObj, computetime);
                        // Maintain order for eviction
                        LineageCacheEviction.addEntry(centry);
                }
@@ -879,7 +815,24 @@ public class LineageCache
 
        
        //----------------- INTERNAL CACHE LOGIC IMPLEMENTATION --------------//
-       
+
+       private static void putInternPlaceholder(Instruction inst, LineageItem 
key) {
+               ComputationCPInstruction cinst = inst instanceof 
ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+               ComputationFEDInstruction cfinst = inst instanceof 
ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+               ComputationSPInstruction cspinst = inst instanceof 
ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+               GPUInstruction gpuinst = inst instanceof GPUInstruction ? 
(GPUInstruction)inst : null;
+
+               if (cinst != null)
+                       putIntern(key, cinst.output.getDataType(), null, null,  
0);
+               else if (cfinst != null)
+                       putIntern(key, cfinst.output.getDataType(), null, null, 
 0);
+               else if (cspinst != null)
+                       putIntern(key, cspinst.output.getDataType(), null, 
null,  0);
+               else if (gpuinst != null)
+                       putIntern(key, gpuinst._output.getDataType(), null, 
null,  0);
+               //FIXME: different o/p datatypes for MultiReturnBuiltins.
+       }
+
        private static void putIntern(LineageItem key, DataType dt, MatrixBlock 
Mval, ScalarObject Sval, long computetime) {
                if (_cache.containsKey(key))
                        //can come here if reuse_partial option is enabled
@@ -1130,4 +1083,70 @@ public class LineageCache
                
                return nflops / (2L * 1024 * 1024 * 1024);
        }
+
+
+       //----------------- UTILITY FUNCTIONS --------------------//
+
+       private static List<MutablePair<LineageItem, LineageCacheEntry>> 
getLineageItems(Instruction inst, ExecutionContext ec) {
+               ComputationCPInstruction cinst = inst instanceof 
ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+               ComputationFEDInstruction cfinst = inst instanceof 
ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+               ComputationSPInstruction cspinst = inst instanceof 
ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+               GPUInstruction gpuinst = inst instanceof GPUInstruction ? 
(GPUInstruction)inst : null;
+               //TODO: Replace with generic type
+
+               List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
+               LineageItem instLI = (cinst != null) ? 
cinst.getLineageItem(ec).getValue()
+                       : (cfinst != null) ? 
cfinst.getLineageItem(ec).getValue()
+                       : (cspinst != null) ? 
cspinst.getLineageItem(ec).getValue()
+                       : gpuinst.getLineageItem(ec).getValue();
+               if (inst instanceof MultiReturnBuiltinCPInstruction) {
+                       liList = new ArrayList<>();
+                       MultiReturnBuiltinCPInstruction mrInst = 
(MultiReturnBuiltinCPInstruction)inst;
+                       for (int i=0; i<mrInst.getNumOutputs(); i++) {
+                               String opcode = instLI.getOpcode() + 
String.valueOf(i);
+                               liList.add(MutablePair.of(new 
LineageItem(opcode, instLI.getInputs()), null));
+                       }
+               }
+               else
+                       liList = List.of(MutablePair.of(instLI, null));
+
+               return liList;
+       }
+
+       private static String getOutputName(Instruction inst, LineageItem li) {
+               ComputationCPInstruction cinst = inst instanceof 
ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+               ComputationFEDInstruction cfinst = inst instanceof 
ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+               ComputationSPInstruction cspinst = inst instanceof 
ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+               GPUInstruction gpuinst = inst instanceof GPUInstruction ? 
(GPUInstruction)inst : null;
+
+               String outName = null;
+               if (inst instanceof MultiReturnBuiltinCPInstruction)
+                       outName = ((MultiReturnBuiltinCPInstruction)inst).
+                               
getOutput(li.getOpcode().charAt(li.getOpcode().length()-1)-'0').getName();
+               else if (inst instanceof ComputationCPInstruction)
+                       outName = cinst.output.getName();
+               else if (inst instanceof ComputationFEDInstruction)
+                       outName = cfinst.output.getName();
+               else if (inst instanceof ComputationSPInstruction)
+                       outName = cspinst.output.getName();
+               else if (inst instanceof GPUInstruction)
+                       outName = gpuinst._output.getName();
+
+               return outName;
+       }
+       private static void maintainReuseStatistics(Instruction inst, 
LineageCacheEntry e) {
+               if (!DMLScript.STATISTICS)
+                       return;
+
+               
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
+               if (e.isGPUObject()) LineageCacheStatistics.incrementGpuHits();
+               if (e.isRDDPersist()) LineageCacheStatistics.incrementRDDHits();
+               if (e.isMatrixValue() || e.isScalarValue()) {
+                       if (inst instanceof ComputationSPInstruction || 
inst.getOpcode().equals("prefetch"))
+                               // Single_block Spark instructions (sync/async) 
and prefetch
+                               
LineageCacheStatistics.incrementSparkCollectHits();
+                       else
+                               LineageCacheStatistics.incrementInstHits();
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 52b8399fcc..04ed47093d 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -58,13 +58,7 @@ public class LineageCacheConfig
                //TODO: Reuse everything. 
        };
 
-       private static final String[] OPCODES_CP = new String[] {
-               "cpmm", "rmm"
-               //TODO: Instead mark an instruction to be checkpointed
-       };
-
        private static String[] REUSE_OPCODES  = new String[] {};
-       private static String[] OPCODES_CHECKPOINTS  = new String[] {};
 
        public enum ReuseCacheType {
                REUSE_FULL,
@@ -196,10 +190,9 @@ public class LineageCacheConfig
        static {
                //setup static configuration parameters
                REUSE_OPCODES = OPCODES;
-               OPCODES_CHECKPOINTS = OPCODES_CP;
-               //setSpill(true); 
+               //setSpill(true);
                setCachePolicy(LineageCachePolicy.COSTNSIZE);
-               setCompAssRW(false);
+               setCompAssRW(true);
        }
 
        public static void setReusableOpcodes(String... ops) {
@@ -210,10 +203,6 @@ public class LineageCacheConfig
                return REUSE_OPCODES;
        }
 
-       public static boolean isToPersist(Instruction inst) {
-               return ArrayUtils.contains(OPCODES_CHECKPOINTS, 
inst.getOpcode());
-       }
-       
        public static void resetReusableOpcodes() {
                REUSE_OPCODES = OPCODES;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 0042674e56..8efe57a162 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -161,6 +161,10 @@ public class LineageCacheEntry {
                return _rddObject != null;
        }
 
+       public boolean isGPUObject() {
+               return _gpuObject != null;
+       }
+
        public boolean isSerializedBytes() {
                return _dt.isUnknown() && 
_key.getOpcode().equals(LineageItemUtils.SERIALIZATION_OPCODE);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index fc34f7341a..01f2177b33 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -41,10 +41,13 @@ public class LineageCacheStatistics {
        private static final LongAdder _ctimeFSWrite    = new LongAdder();
        private static final LongAdder _ctimeSaved      = new LongAdder();
        private static final LongAdder _ctimeMissed     = new LongAdder();
-       // Bellow entries are for specific to gpu lineage cache
+       // Below entries are specific to the gpu lineage cache
        private static final LongAdder _numHitsGpu      = new LongAdder();
        private static final LongAdder _numAsyncEvictGpu= new LongAdder();
        private static final LongAdder _numSyncEvictGpu = new LongAdder();
+       // Below entries are specific to Spark instructions
+       private static final LongAdder _numHitsRdd      = new LongAdder();
+       private static final LongAdder _numHitsSparkActions = new LongAdder();
 
        public static void reset() {
                _numHitsMem.reset();
@@ -64,6 +67,8 @@ public class LineageCacheStatistics {
                _numHitsGpu.reset();
                _numAsyncEvictGpu.reset();
                _numSyncEvictGpu.reset();
+               _numHitsRdd.reset();
+               _numHitsSparkActions.reset();
        }
        
        public static void incrementMemHits() {
@@ -197,6 +202,17 @@ public class LineageCacheStatistics {
                _numSyncEvictGpu.increment();
        }
 
+       public static void incrementRDDHits() {
+               // Number of times a persisted RDD is reused.
+               _numHitsRdd.increment();
+       }
+
+       public static void incrementSparkCollectHits() {
+               // Spark instructions that bring intermediates back to the local driver,
+               // both synchronous and asynchronous (e.g., tsmm, prefetch).
+               _numHitsSparkActions.increment();
+       }
+
        public static String displayHits() {
                StringBuilder sb = new StringBuilder();
                sb.append(_numHitsMem.longValue());
@@ -257,4 +273,12 @@ public class LineageCacheStatistics {
                sb.append(_numSyncEvictGpu.longValue());
                return sb.toString();
        }
+
+       public static String displaySparkStats() {
+               StringBuilder sb = new StringBuilder();
+               sb.append(_numHitsSparkActions.longValue());
+               sb.append("/");
+               sb.append(_numHitsRdd.longValue());
+               return sb.toString();
+       }
 }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index ca359a46e6..fbbce8049b 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,6 +639,7 @@ public class Statistics
                                sb.append("LinCache hits (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayHits() + ".\n");
                                sb.append("LinCache MultiLevel (Ins/SB/Fn):" + 
LineageCacheStatistics.displayMultiLevelHits() + ".\n");
                                sb.append("LinCache GPU (Hit/Async/Sync): \t" + 
LineageCacheStatistics.displayGpuStats() + ".\n");
+                               sb.append("LinCache Spark (Col/RDD): \t\t" + 
LineageCacheStatistics.displaySparkStats() + ".\n");
                                sb.append("LinCache writes (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayWtrites() + ".\n");
                                sb.append("LinCache FStimes (Rd/Wr): \t" + 
LineageCacheStatistics.displayFSTime() + " sec.\n");
                                sb.append("LinCache Computetime (S/M): \t" + 
LineageCacheStatistics.displayComputeTime() + " sec.\n");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 5b49bb82fa..98c4d14c83 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -35,6 +35,7 @@ package org.apache.sysds.test.functions.async;
        import org.apache.sysds.test.TestUtils;
        import org.apache.sysds.utils.Statistics;
        import org.junit.Assert;
+       import org.junit.Ignore;
        import org.junit.Test;
 
 public class LineageReuseSparkTest extends AutomatedTestBase {
@@ -62,6 +63,13 @@ public class LineageReuseSparkTest extends AutomatedTestBase 
{
                runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
        }
 
+       @Ignore
+       @Test
+       public void testlmdsRDD() {
+               // Persist and cache RDDs of shuffle-based Spark operations (e.g., rmm, cpmm)
+               runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
+       }
+
        public void runTest(String testname, ExecMode execMode, int testId) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -90,6 +98,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
                        long numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
                        long numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       long numRmm = 
Statistics.getCPHeavyHitterCount("sp_rmm");
 
                        proArgs.clear();
                        proArgs.add("-explain");
@@ -105,6 +114,7 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> R_reused = 
readDMLScalarFromOutputDir("R");
                        long numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
                        long numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       long numRmm_r = 
Statistics.getCPHeavyHitterCount("sp_rmm");
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
@@ -114,6 +124,8 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                                Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
                                Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        }
+                       if (testId == 2)
+                               Assert.assertTrue("Violated sp_rmm reuse count: 
" + numRmm_r + " < " + numRmm, numRmm_r < numRmm);
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml 
b/src/test/scripts/functions/async/LineageReuseSpark2.dml
new file mode 100644
index 0000000000..22f127c07d
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark2.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, 
Integer N) return (Matrix[double] beta)
+{
+  # Reuse sp_tsmm and sp_mapmm if not future-based
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=1500, cols=1500, seed=42);
+y = rand(rows=1500, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  #beta = lmDS(X=X, y=y, reg=lamda);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+
+R = sum(R);
+write(R, $1, format="text");
+

