This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 11a23342cbd82c15e21a324c37b6d3b013a236f5
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Jan 30 21:01:23 2021 +0100

    [SYSTEMDS-2816] Fix incorrect tracking of spark broadcast sizes
    
    Running multiLogReg on one day of the Criteo dataset (~65GB,
    192215183 x 40) initially showed good performance, but after a while
    individual Spark jobs became significantly slower. The underlying issue
    is incorrect tracking of live broadcast objects (and their sizes), which
    are taken into account when deciding whether to collect the output of an
    RDD operation or to write it to HDFS and read it back in, which avoids
    the double memory requirement (list of blocks and target matrix).
    
    In detail, the root cause was that there are two kinds of broadcasts
    (partitioned and non-partitioned), which have different sizes for the
    same matrix. The removal bookkeeping used the non-partitioned sizes and
    thus ignored our default partitioned broadcasts. This patch simply
    cleans this up by using whichever size is available.
    
    This issue was introduced with SYSTEMML-1313 (Apr 2018), so every
    subsequent release was affected. The resulting performance penalty
    depends on the size of the broadcasts, the total number of iterations,
    and the performance difference between HDFS write/read and collect.
---
 .../controlprogram/context/SparkExecutionContext.java  | 18 +++++++++---------
 .../instructions/spark/MapmmChainSPInstruction.java    |  4 ++--
 .../instructions/spark/data/BroadcastObject.java       | 10 ++++------
 3 files changed, 15 insertions(+), 17 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 41ac510..c68bd9d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -594,7 +594,7 @@ public class SparkExecutionContext extends ExecutionContext
                        //create new broadcast handle (never created, evicted)
                        // account for overwritten invalid broadcast (e.g., 
evicted)
                        if (cd.getBroadcastHandle() != null)
-                               
CacheableData.addBroadcastSize(-cd.getBroadcastHandle().getNonPartitionedBroadcastSize());
+                               
CacheableData.addBroadcastSize(-cd.getBroadcastHandle().getSize());
 
                        // read the matrix block
                        CacheBlock cb = cd.acquireRead();
@@ -609,7 +609,7 @@ public class SparkExecutionContext extends ExecutionContext
                                }
                                
cd.getBroadcastHandle().setNonPartitionedBroadcast(brBlock,
                                        
OptimizerUtils.estimateSize(cd.getDataCharacteristics()));
-                               
CacheableData.addBroadcastSize(cd.getBroadcastHandle().getNonPartitionedBroadcastSize());
+                               
CacheableData.addBroadcastSize(cd.getBroadcastHandle().getSize());
 
                                if (DMLScript.STATISTICS) {
                                        
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
@@ -643,7 +643,7 @@ public class SparkExecutionContext extends ExecutionContext
                if (bret == null) {
                        //account for overwritten invalid broadcast (e.g., 
evicted)
                        if (mo.getBroadcastHandle() != null)
-                               
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getPartitionedBroadcastSize());
+                               
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
 
                        //obtain meta data for matrix
                        int blen = (int) mo.getBlocksize();
@@ -674,7 +674,7 @@ public class SparkExecutionContext extends ExecutionContext
                        }
                        mo.getBroadcastHandle().setPartitionedBroadcast(bret,
                                
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
-                       
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getPartitionedBroadcastSize());
+                       
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
                }
 
                if (DMLScript.STATISTICS) {
@@ -700,7 +700,7 @@ public class SparkExecutionContext extends ExecutionContext
                if (bret == null) {
                        //account for overwritten invalid broadcast (e.g., 
evicted)
                        if (to.getBroadcastHandle() != null)
-                               
CacheableData.addBroadcastSize(-to.getBroadcastHandle().getPartitionedBroadcastSize());
+                               
CacheableData.addBroadcastSize(-to.getBroadcastHandle().getSize());
 
                        //obtain meta data for matrix
                        DataCharacteristics dc = to.getDataCharacteristics();
@@ -731,7 +731,7 @@ public class SparkExecutionContext extends ExecutionContext
                        }
                        to.getBroadcastHandle().setPartitionedBroadcast(bret,
                                        
OptimizerUtils.estimatePartitionedSizeExactSparsity(to.getDataCharacteristics()));
-                       
CacheableData.addBroadcastSize(to.getBroadcastHandle().getPartitionedBroadcastSize());
+                       
CacheableData.addBroadcastSize(to.getBroadcastHandle().getSize());
                }
 
                if (DMLScript.STATISTICS) {
@@ -767,7 +767,7 @@ public class SparkExecutionContext extends ExecutionContext
                if (bret == null) {
                        //account for overwritten invalid broadcast (e.g., 
evicted)
                        if (fo.getBroadcastHandle() != null)
-                               
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getPartitionedBroadcastSize());
+                               
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize());
 
                        //obtain meta data for frame
                        int blen = OptimizerUtils.getDefaultFrameSize();
@@ -798,7 +798,7 @@ public class SparkExecutionContext extends ExecutionContext
                        
                        fo.getBroadcastHandle().setPartitionedBroadcast(bret,
                                
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getDataCharacteristics()));
-                       
CacheableData.addBroadcastSize(fo.getBroadcastHandle().getPartitionedBroadcastSize());
+                       
CacheableData.addBroadcastSize(fo.getBroadcastHandle().getSize());
                }
 
                if (DMLScript.STATISTICS) {
@@ -1417,7 +1417,7 @@ public class SparkExecutionContext extends 
ExecutionContext
                                if( bc != null ) //robustness evictions
                                        cleanupBroadcastVariable(bc);
                        }
-                       
CacheableData.addBroadcastSize(-bob.getNonPartitionedBroadcastSize());
+                       CacheableData.addBroadcastSize(-bob.getSize());
                }
 
                //recursively process lineage children
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
index 76f8597..b1c8248 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
@@ -104,12 +104,12 @@ public class MapmmChainSPInstruction extends 
SPInstruction {
                MatrixBlock out = null;
                if( _chainType == ChainType.XtXv ) {
                        JavaRDD<MatrixBlock> tmp = inX.values().map(new 
RDDMapMMChainFunction(inV));
-                       out = RDDAggregateUtils.sumStable(tmp);         
+                       out = RDDAggregateUtils.sumStable(tmp);
                }
                else { // ChainType.XtwXv / ChainType.XtXvy
                        PartitionedBroadcast<MatrixBlock> inW = 
sec.getBroadcastForVariable( _input3.getName() );
                        JavaRDD<MatrixBlock> tmp = inX.map(new 
RDDMapMMChainFunction2(inV, inW, _chainType));
-                       out = RDDAggregateUtils.sumStable(tmp);         
+                       out = RDDAggregateUtils.sumStable(tmp);
                }
                
                //put output block into symbol table (no lineage because single 
block)
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/BroadcastObject.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/BroadcastObject.java
index 10e0d1e..7de9171 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/BroadcastObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/BroadcastObject.java
@@ -39,11 +39,13 @@ public class BroadcastObject<T extends CacheBlock> extends 
LineageObject {
        public void setNonPartitionedBroadcast(Broadcast<T> bvar, long size) {
                _npbcRef = new SoftReference<>(bvar);
                _npbcSize = size;
+               _pbcSize = -1;
        }
 
        public void setPartitionedBroadcast(PartitionedBroadcast<T> bvar, long 
size) {
                _pbcRef = new SoftReference<>(bvar);
                _pbcSize = size;
+               _npbcSize = -1;
        }
 
        @SuppressWarnings("rawtypes")
@@ -55,12 +57,8 @@ public class BroadcastObject<T extends CacheBlock> extends 
LineageObject {
                return _npbcRef.get();
        }
 
-       public long getPartitionedBroadcastSize() {
-               return _pbcSize;
-       }
-
-       public long getNonPartitionedBroadcastSize() {
-               return _npbcSize;
+       public long getSize() {
+               return Math.max(_pbcSize, _npbcSize);
        }
 
        public boolean isPartitionedBroadcastValid() {

Reply via email to