This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c93f9f9a4e [SYSTEMDS-3492] Lineage-based reuse of all RDDs
c93f9f9a4e is described below

commit c93f9f9a4e5172edb51be86cf9c3e98c035a9672
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Jan 25 22:58:55 2023 +0100

    [SYSTEMDS-3492] Lineage-based reuse of all RDDs
    
    This patch enables reuse of RDDs of redundant Spark operations.
    We also persist a subset of operations in the executors, while
    the rest are just cached locally. Reuse of even unpersisted RDDs
    allows Spark to apply optimizations and skip stages. In addition,
    this patch removes the compile-time flag to indicate reuse and
    instead reuses all RDDs. Local RDD caching is now disabled due to
    bugs.
    
    LinCache Spark (Col/Loc/Dist):  16/2/2. =>
    indicates the number of reused collects/prefetches (=16), local
    RDDs (=2) and persisted RDDs (=2).
    
    Closes #1777
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |  3 +-
 src/main/java/org/apache/sysds/lops/MMCJ.java      |  4 +-
 src/main/java/org/apache/sysds/lops/MMRJ.java      |  3 +-
 .../controlprogram/caching/CacheableData.java      |  4 ++
 .../context/SparkExecutionContext.java             |  8 ++-
 .../spark/AggregateBinarySPInstruction.java        |  8 +--
 .../instructions/spark/BinarySPInstruction.java    |  5 --
 .../spark/ComputationSPInstruction.java            | 17 -----
 .../instructions/spark/CpmmSPInstruction.java      |  7 +-
 .../instructions/spark/RmmSPInstruction.java       | 10 ++-
 .../apache/sysds/runtime/lineage/LineageCache.java | 76 ++++++++++++++++++----
 .../sysds/runtime/lineage/LineageCacheConfig.java  | 47 +++++++------
 .../runtime/lineage/LineageCacheStatistics.java    | 11 +++-
 .../java/org/apache/sysds/utils/Statistics.java    |  2 +-
 .../functions/async/LineageReuseSparkTest.java     | 17 +++--
 .../scripts/functions/async/LineageReuseSpark3.dml | 67 +++++++++++++++++++
 16 files changed, 196 insertions(+), 93 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 9f80f7a683..e8bf19f40b 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -543,7 +543,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                Lop matmultCP = new 
MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(),
                        getDataType(), getValueType(), et, mmtsj, false, k);
-               matmultCP.getOutputParameters().setDimensions(getDim1(), 
getDim2(), getBlocksize(), getNnz(), requiresLineageCaching());
+               matmultCP.getOutputParameters().setDimensions(getDim1(), 
getDim2(), getBlocksize(), getNnz());
                setLineNumbers( matmultCP );
                setLops(matmultCP);
        }
@@ -790,7 +790,6 @@ public class AggBinaryOp extends MultiThreadedHop {
                        Lop cpmm = new MMCJ(getInput().get(0).constructLops(), 
getInput().get(1).constructLops(),
                                getDataType(), getValueType(), 
_outputEmptyBlocks, aggtype, ExecType.SPARK);
                        setOutputDimensions( cpmm );
-                       //setMarkForLineageCaching(cpmm);
                        setLineNumbers( cpmm );
                        setLops( cpmm );
                }
diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java 
b/src/main/java/org/apache/sysds/lops/MMCJ.java
index 544e89341b..7659a6631f 100644
--- a/src/main/java/org/apache/sysds/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMCJ.java
@@ -109,9 +109,7 @@ public class MMCJ extends Lop
                }
                else
                        sb.append(_type.name());
-               sb.append( OPERAND_DELIMITOR );
-               sb.append(getOutputParameters().getLinCacheMarking());
-               
+
                return sb.toString();
        }
 }
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/lops/MMRJ.java 
b/src/main/java/org/apache/sysds/lops/MMRJ.java
index 21577eeddb..bf19c703f0 100644
--- a/src/main/java/org/apache/sysds/lops/MMRJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMRJ.java
@@ -65,7 +65,6 @@ public class MMRJ extends Lop
                        "rmm",
                        getInputs().get(0).prepInputOperand(input1),
                        getInputs().get(1).prepInputOperand(input2),
-                       prepOutputOperand(output),
-                       Boolean.toString(toCache));
+                       prepOutputOperand(output));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 9d195ea7fd..59c9be2639 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -451,6 +451,10 @@ public abstract class CacheableData<T extends 
CacheBlock<?>> extends Data
                return _bcHandle;
        }
 
+       public boolean hasBroadcastHandle() {
+               return  _bcHandle != null && _bcHandle.hasBackReference();
+       }
+
        @SuppressWarnings({ "rawtypes", "unchecked" })
        public void setBroadcastHandle( BroadcastObject bc ) {
                //cleanup potential old back reference
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 718f4d8613..1d4681ec75 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1516,14 +1516,16 @@ public class SparkExecutionContext extends 
ExecutionContext
                if( lob.hasBackReference() )
                        return;
 
+               //abort if a RDD is cached locally
+               if (lob.isInLineageCache())
+                       return;
+
                //cleanup current lineage object (from driver/executors)
                //incl deferred hdfs file removal (only if metadata set by 
cleanup call)
                if( lob instanceof RDDObject ) {
                        RDDObject rdd = (RDDObject)lob;
                        int rddID = rdd.getRDD().id();
-                       //skip unpersisting if locally cached
-                       if (!lob.isInLineageCache())
-                               cleanupRDDVariable(rdd.getRDD());
+                       cleanupRDDVariable(rdd.getRDD());
                        if( rdd.getHDFSFilename()!=null ) { //deferred file 
removal
                                
HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
index 4575b06252..c860bdcce1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
@@ -26,12 +26,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
  * Class to group the different MM <code>SPInstruction</code>s together.
  */
 public abstract class AggregateBinarySPInstruction extends BinarySPInstruction 
{
-       protected AggregateBinarySPInstruction(SPType type, Operator op, 
CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
-               super(type, op, in1, in2, out, opcode, istr);
-       }
-
        protected AggregateBinarySPInstruction(SPType type, Operator op, 
CPOperand in1, CPOperand in2, CPOperand out,
-               String opcode, boolean toCache, String istr) {
-               super(type, op, in1, in2, out, opcode, toCache, istr);
+                               String opcode, String istr) {
+               super(type, op, in1, in2, out, opcode, istr);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 3c70d4021a..5c7d7ec71d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -57,11 +57,6 @@ public abstract class BinarySPInstruction extends 
ComputationSPInstruction {
                super(type, op, in1, in2, out, opcode, istr);
        }
 
-       protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, 
CPOperand in2,
-               CPOperand out, String opcode, boolean toCache, String istr) {
-               super(type, op, in1, in2, out, opcode, toCache, istr);
-       }
-
        public static BinarySPInstruction parseInstruction ( String str ) {
                CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
                CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
index 09ac134702..f55a0b398e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
@@ -44,7 +44,6 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 public abstract class ComputationSPInstruction extends SPInstruction 
implements LineageTraceable {
        public CPOperand output;
        public CPOperand input1, input2, input3;
-       private boolean toPersistAndCache;
 
        protected ComputationSPInstruction(SPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand out, String opcode, String istr) {
                super(type, op, opcode, istr);
@@ -54,15 +53,6 @@ public abstract class ComputationSPInstruction extends 
SPInstruction implements
                output = out;
        }
 
-       protected ComputationSPInstruction(SPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand out, String opcode, boolean toCache, String istr) 
{
-               super(type, op, opcode, istr);
-               input1 = in1;
-               input2 = in2;
-               input3 = null;
-               output = out;
-               toPersistAndCache = toCache;
-       }
-
        protected ComputationSPInstruction(SPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
                super(type, op, opcode, istr);
                input1 = in1;
@@ -144,15 +134,8 @@ public abstract class ComputationSPInstruction extends 
SPInstruction implements
                }
        }
 
-       public boolean isRDDtoCache() {
-               return toPersistAndCache;
-       }
-
        @SuppressWarnings("unchecked")
        public void checkpointRDD(ExecutionContext ec) {
-               if (!toPersistAndCache)
-                       return;
-
                SparkExecutionContext sec = (SparkExecutionContext)ec;
                CacheableData<?> cd = sec.getCacheableData(output.getName());
                RDDObject inro =  cd.getRDDHandle();
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 253480b482..6425613583 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -67,8 +67,8 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
        private final SparkAggType _aggtype;
        
        private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out,
-               boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, 
boolean toCache, String istr) {
-               super(SPType.CPMM, op, in1, in2, out, opcode, toCache, istr);
+               boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, 
String istr) {
+               super(SPType.CPMM, op, in1, in2, out, opcode, istr);
                _outputEmptyBlocks = outputEmptyBlocks;
                _aggtype = aggtype;
        }
@@ -84,8 +84,7 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                AggregateBinaryOperator aggbin = 
InstructionUtils.getMatMultOperator(1);
                boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
                SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
-               boolean toCache = parts.length == 7 ? 
Boolean.parseBoolean(parts[6]) : false;
-               return new CpmmSPInstruction(aggbin, in1, in2, out, 
outputEmptyBlocks, aggtype, opcode, toCache, str);
+               return new CpmmSPInstruction(aggbin, in1, in2, out, 
outputEmptyBlocks, aggtype, opcode, str);
        }
        
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
index 130045a890..ae450da789 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
@@ -49,9 +49,8 @@ import java.util.LinkedList;
 
 public class RmmSPInstruction extends AggregateBinarySPInstruction {
 
-       private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out,
-               String opcode, boolean toCache, String istr) {
-               super(SPType.RMM, op, in1, in2, out, opcode, toCache, istr);
+       private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out, String opcode, String istr) {
+               super(SPType.RMM, op, in1, in2, out, opcode, istr);
        }
 
        public static RmmSPInstruction parseInstruction( String str ) {
@@ -62,9 +61,8 @@ public class RmmSPInstruction extends 
AggregateBinarySPInstruction {
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
                        CPOperand out = new CPOperand(parts[3]);
-                       boolean toCache = parts.length == 5 ? 
Boolean.parseBoolean(parts[4]) : false;
-                       
-                       return new RmmSPInstruction(null, in1, in2, out, 
opcode, toCache, str);
+
+                       return new RmmSPInstruction(null, in1, in2, out, 
opcode, str);
                } 
                else {
                        throw new 
DMLRuntimeException("RmmSPInstruction.parseInstruction():: Unknown opcode " + 
opcode);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index c38a07d43c..fb4f579986 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -145,11 +145,8 @@ public class LineageCache
                                                        
ec.setScalarOutput(outName, so);
                                        }
                                        else if (e.isRDDPersist()) {
-                                               //Reuse the RDD which is also 
persisted in Spark
+                                               //Reuse the cached RDD (local 
or persisted at the executors)
                                                RDDObject rdd = 
e.getRDDObject();
-                                               if (!((SparkExecutionContext) 
ec).isRDDCached(rdd.getRDD().id()))
-                                                       //Return if the RDD is 
not cached in the executors
-                                                       return false;
                                                ((SparkExecutionContext) 
ec).setRDDHandleForVariable(outName, rdd);
                                        }
                                        else { //TODO handle locks on gpu 
objects
@@ -160,7 +157,7 @@ public class LineageCache
                                                
ec.getMatrixObject(outName).getGPUObject(ec.getGPUContext(0)).setDirty(true);
                                        }
                                }
-                               maintainReuseStatistics(inst, 
liList.get(0).getValue());
+                               maintainReuseStatistics(ec, inst, 
liList.get(0).getValue());
                        }
                }
                
@@ -500,7 +497,9 @@ public class LineageCache
                                if (liGpuObj == null)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((GPUInstruction)inst)._output)));
                        }
-                       else if (inst instanceof ComputationSPInstruction && 
((ComputationSPInstruction) inst).isRDDtoCache()) {
+                       else if (inst instanceof ComputationSPInstruction
+                               && (ec.getVariable(((ComputationSPInstruction) 
inst).output) instanceof MatrixObject)
+                               && 
(ec.getCacheableData(((ComputationSPInstruction)inst).output.getName())).hasRDDHandle())
 {
                                putValueRDD(inst, instLI, ec, computetime);
                                return;
                        }
@@ -509,7 +508,7 @@ public class LineageCache
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationCPInstruction) inst).output)));
                                else if (inst instanceof 
ComputationFEDInstruction)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationFEDInstruction) inst).output)));
-                               else if (inst instanceof 
ComputationSPInstruction)
+                               else if (inst instanceof 
ComputationSPInstruction) //collects or prefetches
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationSPInstruction) inst).output)));
 
                        if (liGpuObj == null)
@@ -606,15 +605,37 @@ public class LineageCache
                synchronized( _cache ) {
                        if (!probe(instLI))
                                return;
-                       // Call persist on the output RDD
-                       ((ComputationSPInstruction) inst).checkpointRDD(ec);
-                       // Get the RDD handle of the persisted RDD
+                       // Avoid reuse chkpoint, which is unnecessary
+                       if (inst.getOpcode().equalsIgnoreCase("chkpoint")) {
+                               removePlaceholder(instLI);
+                               return;
+                       }
+                       boolean opToPersist = 
LineageCacheConfig.isReusableRDDType(inst);
+                       // Return if the intermediate is not to be persisted in 
the executors
+                       // and the local only RDD caching is disabled
+                       if (!opToPersist && 
!LineageCacheConfig.ENABLE_LOCAL_ONLY_RDD_CACHING) {
+                               removePlaceholder(instLI);
+                               return;
+                       }
+
+                       // Filter out Spark instructions with broadcast input
+                       // TODO: This code avoids one crash. Remove once fixed.
+                       if (!opToPersist && !allInputsSpark(inst, ec)) {
+                               removePlaceholder(instLI);
+                               return;
+                       }
+
+                       // Call persist on the output RDD if required
+                       if (opToPersist)
+                               ((ComputationSPInstruction) 
inst).checkpointRDD(ec);
+                       // Get the RDD handle of the RDD
                        CacheableData<?> cd = 
ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
-                       RDDObject rddObj = ((CacheableData<?>) 
cd).getRDDHandle();
+                       RDDObject rddObj = cd.getRDDHandle();
 
                        LineageCacheEntry centry = _cache.get(instLI);
                        // Set the RDD object in the cache
                        // TODO: Make space in the executors
+                       // TODO: Estimate the actual compute time for this 
operation
                        rddObj.setLineageCached();
                        centry.setRDDValue(rddObj, computetime);
                        // Maintain order for eviction
@@ -1131,13 +1152,42 @@ public class LineageCache
 
                return outName;
        }
-       private static void maintainReuseStatistics(Instruction inst, 
LineageCacheEntry e) {
+
+       private static boolean allInputsSpark(Instruction inst, 
ExecutionContext ec) {
+               CPOperand in1 = ((ComputationSPInstruction)inst).input1;
+               CPOperand in2 = ((ComputationSPInstruction)inst).input2;
+               CPOperand in3 = ((ComputationSPInstruction)inst).input3;
+
+               // All inputs must be matrices
+               if ((in1 != null && !in1.isMatrix()) || (in2 != null && 
!in2.isMatrix()) || (in3 != null && !in3.isMatrix()))
+                       return false;
+
+               // Filter out if any input is local
+               if (in1 != null && 
(!ec.getMatrixObject(in1.getName()).hasRDDHandle() ||
+                       ec.getMatrixObject(in1.getName()).hasBroadcastHandle()))
+                       return false;
+               if (in2 != null && 
(!ec.getMatrixObject(in2.getName()).hasRDDHandle() ||
+                       ec.getMatrixObject(in2.getName()).hasBroadcastHandle()))
+                       return false;
+               if (in3 != null && 
(!ec.getMatrixObject(in3.getName()).hasRDDHandle() ||
+                       ec.getMatrixObject(in3.getName()).hasBroadcastHandle()))
+                       return false;
+
+               return true;
+       }
+
+       private static void maintainReuseStatistics(ExecutionContext ec, 
Instruction inst, LineageCacheEntry e) {
                if (!DMLScript.STATISTICS)
                        return;
 
                
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
                if (e.isGPUObject()) LineageCacheStatistics.incrementGpuHits();
-               if (e.isRDDPersist()) LineageCacheStatistics.incrementRDDHits();
+               if (e.isRDDPersist()) {
+                       if (((SparkExecutionContext) 
ec).isRDDCached(e.getRDDObject().getRDD().id()))
+                               
LineageCacheStatistics.incrementRDDPersistHits(); //persisted in the executors
+                       else
+                               LineageCacheStatistics.incrementRDDHits();  
//only locally cached
+               }
                if (e.isMatrixValue() || e.isScalarValue()) {
                        if (inst instanceof ComputationSPInstruction || 
inst.getOpcode().equals("prefetch"))
                                // Single_block Spark instructions (sync/async) 
and prefetch
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 04ed47093d..b0c8a87db1 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -58,7 +58,12 @@ public class LineageCacheConfig
                //TODO: Reuse everything. 
        };
 
+       private static final String[] PERSIST_OPCODES = new String[] {
+               "mapmm", "cpmm", "rmm"
+       };
+
        private static String[] REUSE_OPCODES  = new String[] {};
+       private static String[] CHKPOINT_OPCODES  = new String[] {};
 
        public enum ReuseCacheType {
                REUSE_FULL,
@@ -185,11 +190,17 @@ public class LineageCacheConfig
                return ret;
        };
 
+
+       //-------------SPARK OPERATION RELATED CONFIGURATIONS--------------//
+
+       protected static boolean ENABLE_LOCAL_ONLY_RDD_CACHING = false;
+
        //----------------------------------------------------------------//
 
        static {
                //setup static configuration parameters
                REUSE_OPCODES = OPCODES;
+               CHKPOINT_OPCODES = PERSIST_OPCODES;
                //setSpill(true);
                setCachePolicy(LineageCachePolicy.COSTNSIZE);
                setCompAssRW(true);
@@ -258,31 +269,19 @@ public class LineageCacheConfig
                }
        }
 
-       // Check if the Spark instruction returns result back to local
-       @SuppressWarnings("unused")
-       private static boolean isRightSparkOp(Instruction inst) {
-               if (!(inst instanceof ComputationSPInstruction))
-                       return false;
-
-               boolean spAction = false;
-               if (inst instanceof MapmmSPInstruction &&
-                       ((MapmmSPInstruction) inst).getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK)
-                       spAction = true;
-               else if (inst instanceof TsmmSPInstruction)
-                       spAction = true;
-               else if (inst instanceof AggregateUnarySPInstruction &&
-                       ((AggregateUnarySPInstruction) inst).getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK)
-                       spAction = true;
-               else if (inst instanceof CpmmSPInstruction &&
-                       ((CpmmSPInstruction) inst).getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK)
-                       spAction = true;
-               else if (((ComputationSPInstruction) inst).output.getDataType() 
== Types.DataType.SCALAR)
-                       spAction = true;
-               //TODO: include other cases
-
-               return spAction;
+       protected static boolean isReusableRDDType(Instruction inst) {
+               boolean insttype = inst instanceof ComputationSPInstruction;
+               boolean rightOp = ArrayUtils.contains(CHKPOINT_OPCODES, 
inst.getOpcode());
+               if (rightOp && inst instanceof MapmmSPInstruction
+                       && ((MapmmSPInstruction) inst).getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK)
+                       rightOp = false;
+               if (rightOp && inst instanceof CpmmSPInstruction
+                       && ((CpmmSPInstruction) inst).getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK)
+                       rightOp = false;
+               return insttype && rightOp;
        }
-       
+
+
        public static boolean isOutputFederated(Instruction inst, Data data) {
                if (!(inst instanceof ComputationFEDInstruction))
                        return false;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 01f2177b33..fd708517e8 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -48,6 +48,7 @@ public class LineageCacheStatistics {
        // Below entries are specific to Spark instructions
        private static final LongAdder _numHitsRdd      = new LongAdder();
        private static final LongAdder _numHitsSparkActions = new LongAdder();
+       private static final LongAdder _numHitsRddPersist   = new LongAdder();
 
        public static void reset() {
                _numHitsMem.reset();
@@ -69,6 +70,7 @@ public class LineageCacheStatistics {
                _numSyncEvictGpu.reset();
                _numHitsRdd.reset();
                _numHitsSparkActions.reset();
+               _numHitsRddPersist.reset();
        }
        
        public static void incrementMemHits() {
@@ -203,7 +205,7 @@ public class LineageCacheStatistics {
        }
 
        public static void incrementRDDHits() {
-               // Number of times a persisted RDD are reused.
+               // Number of times a locally cached (but not persisted) RDD are 
reused.
                _numHitsRdd.increment();
        }
 
@@ -213,6 +215,11 @@ public class LineageCacheStatistics {
                _numHitsSparkActions.increment();
        }
 
+       public static void incrementRDDPersistHits() {
+               // Number of times a locally cached and persisted RDD are 
reused.
+               _numHitsRddPersist.increment();
+       }
+
        public static String displayHits() {
                StringBuilder sb = new StringBuilder();
                sb.append(_numHitsMem.longValue());
@@ -279,6 +286,8 @@ public class LineageCacheStatistics {
                sb.append(_numHitsSparkActions.longValue());
                sb.append("/");
                sb.append(_numHitsRdd.longValue());
+               sb.append("/");
+               sb.append(_numHitsRddPersist.longValue());
                return sb.toString();
        }
 }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 89ad98e337..da8c9fb444 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,7 +639,7 @@ public class Statistics
                                sb.append("LinCache hits (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayHits() + ".\n");
                                sb.append("LinCache MultiLevel (Ins/SB/Fn):" + 
LineageCacheStatistics.displayMultiLevelHits() + ".\n");
                                sb.append("LinCache GPU (Hit/Async/Sync): \t" + 
LineageCacheStatistics.displayGpuStats() + ".\n");
-                               sb.append("LinCache Spark (Col/RDD): \t" + 
LineageCacheStatistics.displaySparkStats() + ".\n");
+                               sb.append("LinCache Spark (Col/Loc/Dist): \t" + 
LineageCacheStatistics.displaySparkStats() + ".\n");
                                sb.append("LinCache writes (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayWtrites() + ".\n");
                                sb.append("LinCache FStimes (Rd/Wr): \t" + 
LineageCacheStatistics.displayFSTime() + " sec.\n");
                                sb.append("LinCache Computetime (S/M): \t" + 
LineageCacheStatistics.displayComputeTime() + " sec.\n");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 98c4d14c83..27b2bf24d2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -42,7 +42,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 
        protected static final String TEST_DIR = "functions/async/";
        protected static final String TEST_NAME = "LineageReuseSpark";
-       protected static final int TEST_VARIANTS = 2;
+       protected static final int TEST_VARIANTS = 3;
        protected static String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseSparkTest.class.getSimpleName() + "/";
 
        @Override
@@ -63,13 +63,17 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
        }
 
-       @Ignore
        @Test
        public void testlmdsRDD() {
-               // Persist and cache RDDs of shuffle-based Spark operations 
(eg. rmm, cpmm)
+               // Cache all RDDs and persist shuffle-based Spark operations 
(eg. rmm, cpmm)
                runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
        }
 
+       @Test
+       public void testL2svm() {
+               runTest(TEST_NAME+"3", ExecMode.SPARK, 3);
+       }
+
        public void runTest(String testname, ExecMode execMode, int testId) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -87,7 +91,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 
                        List<String> proArgs = new ArrayList<>();
 
-                       proArgs.add("-explain");
+                       //proArgs.add("-explain");
                        proArgs.add("-stats");
                        proArgs.add("-args");
                        proArgs.add(output("R"));
@@ -101,7 +105,8 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        long numRmm = 
Statistics.getCPHeavyHitterCount("sp_rmm");
 
                        proArgs.clear();
-                       proArgs.add("-explain");
+                       //proArgs.add("-explain");
+                       //proArgs.add("recompile_runtime");
                        proArgs.add("-stats");
                        proArgs.add("-lineage");
                        
proArgs.add(LineageCacheConfig.ReuseCacheType.REUSE_FULL.name().toLowerCase());
@@ -120,7 +125,7 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
                        if (!matchVal)
                                System.out.println("Value w/o reuse "+R+" w/ 
reuse "+R_reused);
-                       if (testId == 1) {
+                       if (testId == 1 || testId == 3) {
                                Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
                                Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        }
diff --git a/src/test/scripts/functions/async/LineageReuseSpark3.dml 
b/src/test/scripts/functions/async/LineageReuseSpark3.dml
new file mode 100644
index 0000000000..04c5461511
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark3.dml
@@ -0,0 +1,67 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B, 
Boolean icpt)
+return (Matrix[Double] loss) {
+  if (icpt)
+    X = cbind(X, matrix(1, nrow(X), 1));
+  loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+M = 100000;
+N = 20;
+sp = 1.0;
+no_lamda = 2;
+
+X = rand(rows=M, cols=N, sparsity=sp, seed=42);
+y = rand(rows=M, cols=1, min=0, max=2, seed=42);
+y = ceil(y);
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+Rbeta = matrix(0, rows=ncol(X)+1, cols=no_lamda*2);
+Rloss = matrix(0, rows=no_lamda*2, cols=1);
+i = 1;
+
+
+for (l in 1:no_lamda)
+{
+  beta = l2svm(X=X, Y=y, intercept=FALSE, epsilon=1e-12, maxIterations=1,
+      maxii=1, reg = lamda, verbose=FALSE);
+  Rbeta[1:nrow(beta),i] = beta;
+  Rloss[i,] = l2norm(X, y, beta, FALSE);
+  i = i + 1;
+
+  beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12, maxIterations=1,
+      maxii=1, reg = lamda, verbose=FALSE);
+  Rbeta[1:nrow(beta),i] = beta;
+  Rloss[i,] = l2norm(X, y, beta, TRUE);
+  i = i + 1;
+
+  lamda = lamda + stp;
+}
+
+leastLoss = rowIndexMin(t(Rloss));
+bestModel = Rbeta[,as.scalar(leastLoss)];
+
+R = sum(bestModel);
+write(R, $1, format="text");
+

Reply via email to