This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 33de453e87523253c6a1e853804a2bbc40021f04
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Aug 10 17:32:45 2023 +0200

    [SYSTEMDS-3606] Performance of shuffle-based Spark quaternary operations
    
    This patch significantly improves the performance of shuffle-based
    Spark quaternary operations, where more than one input is an RDD
    (too large to broadcast). Instead of replicating the factor blocks, we
    now use custom join keys, enabling Spark to perform more efficient
    1:M joins. With appropriate function abstractions, the implementation
    also became simpler and thus easier to maintain.
    
    On the scenario mentioned in the JIRA task, the original implementation
    had not finished any task of the first shuffle phase after >9000s, while
    with the new implementation the entire script (with two shuffle-based
    quaternary operators) finishes in 1276s. Here are the stats:
    
    SystemDS Statistics:
    Total elapsed time:             1276.917 sec.
    Total compilation time:         2.338 sec.
    Total execution time:           1274.578 sec.
    Number of compiled Spark inst:  4.
    Number of executed Spark inst:  4.
    Cache hits (Mem/Li/WB/FS/HDFS): 13/2/0/1/0.
    Cache writes (Li/WB/FS/HDFS):   4/6/4/1.
    Cache times (ACQr/m, RLS, EXP): 1209.517/0.001/10.926/8.589 sec.
    HOP DAGs recompiled (PRED, SB): 0/1.
    HOP DAGs recompile time:        0.006 sec.
    Functions recompiled:           1.
    Functions recompile time:       0.011 sec.
    Spark ctx create time (lazy):   19.302 sec.
    Spark trans counts (par,bc,col):0/3/1.
    Spark trans times (par,bc,col): 0.000/13.671/644.719 secs.
    Spark async. count (pf,bc,op):  0/0/0.
    Total JIT compile time:         73.677 sec.
    Total JVM GC count:             188.
    Total JVM GC time:              23.182 sec.
    Heavy hitter instructions:
      1  m_pnmf        714.304      1
      2  r'            653.012      5
      3  uak+          560.027      2
      4  sp_redwdivmm   42.446      2
      5  rand            9.414      4
      6  *               3.544      1
      7  /               3.491      1
      8  uack+           3.466      1
      9  uark+           2.146      1
     10  rmvar           0.246     15
---
 .../spark/QuaternarySPInstruction.java             | 174 +++++++--------------
 1 file changed, 59 insertions(+), 115 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java
index 9c9a063d31..dbb71c724f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
@@ -44,7 +45,6 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
 import 
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
-import 
org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
@@ -198,11 +198,6 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
                JavaPairRDD<MatrixIndexes,MatrixBlock> in = 
sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
                JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
                
-               DataCharacteristics inMc = sec.getDataCharacteristics( 
input1.getName() );
-               long rlen = inMc.getRows();
-               long clen = inMc.getCols();
-               int blen = inMc.getBlocksize();
-               
                //pre-filter empty blocks (ultra-sparse matrices) for full 
aggregates
                //(map/redwsloss, map/redwcemm); safe because theses ops 
produce a scalar
                if( qop.wtype1 != null || qop.wtype4 != null ) {
@@ -237,42 +232,25 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
                        JavaPairRDD<MatrixIndexes,MatrixBlock> inW = 
(qop.hasFourInputs() && !_input4.isLiteral()) ? 
                                        
sec.getBinaryMatrixBlockRDDHandleForVariable( _input4.getName() ) : null;
 
-                       //preparation of transposed and replicated U
-                       if( inU != null )
-                               inU = inU.flatMapToPair(new 
ReplicateBlockFunction(clen, blen, true));
-
-                       //preparation of transposed and replicated V
-                       if( inV != null )
-                               inV = inV.mapToPair(new 
TransposeFactorIndexesFunction())
-                                        .flatMapToPair(new 
ReplicateBlockFunction(rlen, blen, false));
+                       //join X and W on original indexes if W existing
+                       JavaPairRDD<MatrixIndexes,MatrixBlock[]> tmp = (inW != 
null) ?
+                               in.join(inW).mapToPair(new ToArray()) :
+                               in.mapValues(mb -> new MatrixBlock[]{mb, null});
                        
-                       //functions calls w/ two rdd inputs             
-                       if( inU != null && inV == null && inW == null )
-                               out = in.join(inU)
-                                       .mapToPair(new 
RDDQuaternaryFunction2(qop, bc1, bc2));
-                       else if( inU == null && inV != null && inW == null )
-                               out = in.join(inV)
-                                       .mapToPair(new 
RDDQuaternaryFunction2(qop, bc1, bc2));
-                       else if( inU == null && inV == null && inW != null )
-                               out = in.join(inW)
-                                       .mapToPair(new 
RDDQuaternaryFunction2(qop, bc1, bc2));
-                       //function calls w/ three rdd inputs
-                       else if( inU != null && inV != null && inW == null )
-                               out = in.join(inU).join(inV)
-                                       .mapToPair(new 
RDDQuaternaryFunction3(qop, bc1, bc2));
-                       else if( inU != null && inV == null && inW != null )
-                               out = in.join(inU).join(inW)
-                                       .mapToPair(new 
RDDQuaternaryFunction3(qop, bc1, bc2));
-                       else if( inU == null && inV != null && inW != null )
-                               out = in.join(inV).join(inW)
-                                       .mapToPair(new 
RDDQuaternaryFunction3(qop, bc1, bc2));
-                       else if( inU == null && inV == null && inW == null ) {
-                               out = in.mapPartitionsToPair(new 
RDDQuaternaryFunction1(qop, bc1, bc2), false);
-                       }
-                       //function call w/ four rdd inputs
-                       else //need keys in case of wdivmm 
-                               out = in.join(inU).join(inV).join(inW)
-                                       .mapToPair(new 
RDDQuaternaryFunction4(qop));
+                       //join lhs U on row-block indexes of X/W
+                       tmp = ( inU != null ) ?
+                               tmp.mapToPair(new ExtractIndexWith(true))
+                                       .join(inU.mapToPair(new 
ExtractIndex(true))).mapToPair(new Unpack()) :
+                               tmp.mapValues(mb -> ArrayUtils.add(mb, null));
+                       
+                       //join rhs V on column-block indexes X/W (note V 
transposed input, so rows)
+                       tmp = ( inV != null ) ?
+                               tmp.mapToPair(new ExtractIndexWith(false))
+                                       .join(inV.mapToPair(new 
ExtractIndex(true))).mapToPair(new Unpack()) :
+                               tmp.mapValues(mb -> ArrayUtils.add(mb, null));
+                       
+                       //execute quaternary block operations on joined inputs
+                       out = tmp.mapToPair(new RDDQuaternaryFunction2(qop, 
bc1, bc2));
                        
                        //keep variable names for lineage maintenance
                        if( inU == null ) bcVars.add(input2.getName()); else 
rddVars.add(input2.getName());
@@ -374,12 +352,11 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
                        protected Tuple2<MatrixIndexes, MatrixBlock> 
computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
                                MatrixIndexes ixIn = arg._1();
                                MatrixBlock blkIn = arg._2();
-                               MatrixBlock blkOut = new MatrixBlock();
                                MatrixBlock mbU = 
_pmU.getBlock((int)ixIn.getRowIndex(), 1);
                                MatrixBlock mbV = 
_pmV.getBlock((int)ixIn.getColumnIndex(), 1);
                                
                                //execute core operation
-                               blkIn.quaternaryOperations(_qop, mbU, mbV, 
null, blkOut);
+                               MatrixBlock blkOut = 
blkIn.quaternaryOperations(_qop, mbU, mbV, null, new MatrixBlock());
                                
                                //create return tuple
                                MatrixIndexes ixOut = createOutputIndexes(ixIn);
@@ -389,7 +366,7 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
        }
 
        private static class RDDQuaternaryFunction2 extends 
RDDQuaternaryBaseFunction //two rdd input
-               implements PairFunction<Tuple2<MatrixIndexes, 
Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock>
+               implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, 
MatrixIndexes, MatrixBlock>
        {
                private static final long serialVersionUID = 
7493974462943080693L;
                
@@ -398,17 +375,15 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
                }
 
                @Override
-               public Tuple2<MatrixIndexes, MatrixBlock> 
call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) {
+               public Tuple2<MatrixIndexes, MatrixBlock> 
call(Tuple2<MatrixIndexes, MatrixBlock[]> arg0) {
                        MatrixIndexes ixIn = arg0._1();
-                       MatrixBlock blkIn1 = arg0._2()._1();
-                       MatrixBlock blkIn2 = arg0._2()._2();
-                       MatrixBlock blkOut = new MatrixBlock();
-                       MatrixBlock mbU = 
(_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
-                       MatrixBlock mbV = 
(_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blkIn2;
-                       MatrixBlock mbW = (_qop.hasFourInputs()) ? blkIn2 : 
null;
+                       MatrixBlock[] blks = arg0._2();
+                       MatrixBlock mbU = 
(_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blks[2];
+                       MatrixBlock mbV = 
(_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blks[3];
+                       MatrixBlock mbW = (_qop.hasFourInputs()) ? blks[1] : 
null;
                        
                        //execute core operation
-                       blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, 
blkOut);
+                       MatrixBlock blkOut = blks[0].quaternaryOperations(_qop, 
mbU, mbV, mbW, new MatrixBlock());
                        
                        //create return tuple
                        MatrixIndexes ixOut = createOutputIndexes(ixIn);
@@ -416,82 +391,51 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction {
                }
        }
 
-       private static class RDDQuaternaryFunction3 extends 
RDDQuaternaryBaseFunction //three rdd input
-               implements PairFunction<Tuple2<MatrixIndexes, 
Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, 
MatrixBlock>
-       {
-               private static final long serialVersionUID = 
-2294086455843773095L;
-               
-               public RDDQuaternaryFunction3( QuaternaryOperator qop, 
PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
-                       super(qop, bcU, bcV);
+       private static class ExtractIndex implements 
PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, Long, MatrixBlock> {
+               private static final long serialVersionUID = 
-6542246824481788376L;
+               private final boolean _row;
+               public ExtractIndex(boolean row) {
+                       _row = row;
                }
-
                @Override
-               public Tuple2<MatrixIndexes, MatrixBlock> 
call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, 
MatrixBlock>> arg0) {
-                       MatrixIndexes ixIn = arg0._1();
-                       MatrixBlock blkIn1 = arg0._2()._1()._1();
-                       MatrixBlock blkIn2 = arg0._2()._1()._2();
-                       MatrixBlock blkIn3 = arg0._2()._2();
-                       MatrixBlock blkOut = new MatrixBlock();
-                       MatrixBlock mbU = 
(_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
-                       MatrixBlock mbV = 
(_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : 
-                                             (_pmU!=null)? blkIn2 : blkIn3;
-                       MatrixBlock mbW = (_qop.hasFourInputs())? blkIn3 : null;
-                       
-                       //execute core operation
-                       blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, 
blkOut);
-                       
-                       //create return tuple
-                       MatrixIndexes ixOut = createOutputIndexes(ixIn);
-                       return new Tuple2<>(ixOut, blkOut);
+               public Tuple2<Long, MatrixBlock> call(Tuple2<MatrixIndexes, 
MatrixBlock> arg) throws Exception {
+                       return new 
Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg._2());
                }
        }
        
-       /**
-        * Note: never called for wsigmoid/wdivmm (only wsloss)
-        */
-       private static class RDDQuaternaryFunction4 extends 
RDDQuaternaryBaseFunction //four rdd input
-               implements 
PairFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>,MatrixBlock>>,MatrixIndexes,MatrixBlock>
-       {
-               private static final long serialVersionUID = 
7328911771600289250L;
-               
-               public RDDQuaternaryFunction4( QuaternaryOperator qop ) {
-                       super(qop, null, null);
+       private static class ExtractIndexWith implements 
PairFunction<Tuple2<MatrixIndexes,MatrixBlock[]>, Long, 
Tuple2<MatrixIndexes,MatrixBlock[]>> {
+               private static final long serialVersionUID = 
-966212318512764461L;
+               private final boolean _row;
+               public ExtractIndexWith(boolean row) {
+                       _row = row;
                }
+               @Override
+               public Tuple2<Long, Tuple2<MatrixIndexes, MatrixBlock[]>> 
call(Tuple2<MatrixIndexes, MatrixBlock[]> arg)
+                       throws Exception
+               {
+                       return new 
Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg);
+               }
+       }
+       
+       private static class ToArray implements 
PairFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock,MatrixBlock>>, 
MatrixIndexes, MatrixBlock[]> {
+               private static final long serialVersionUID = 
-4856316007590144978L;
 
                @Override
-               public Tuple2<MatrixIndexes, MatrixBlock> 
call(Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, 
MatrixBlock>, MatrixBlock>> arg0)
+               public Tuple2<MatrixIndexes, MatrixBlock[]> 
call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
+                       throws Exception 
                {
-                       MatrixIndexes ixIn1 = arg0._1();
-                       MatrixBlock blkIn1 = arg0._2()._1()._1()._1();
-                       MatrixBlock mbU = arg0._2()._1()._1()._2();
-                       MatrixBlock mbV = arg0._2()._1()._2();
-                       MatrixBlock mbW = arg0._2()._2();
-                       MatrixBlock blkOut = new MatrixBlock();
-                       
-                       //execute core operation
-                       blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, 
blkOut);
-                       
-                       //create return tuple
-                       MatrixIndexes ixOut = createOutputIndexes(ixIn1);
-                       return new Tuple2<>(ixOut, blkOut);
+                       return new Tuple2<>(arg._1(), new 
MatrixBlock[]{arg._2()._1(),arg._2()._2()});
                }
        }
        
-       private static class TransposeFactorIndexesFunction implements 
PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
-       {
-               private static final long serialVersionUID = 
-2571724736131823708L;
-               
+       private static class Unpack implements PairFunction<Tuple2<Long, 
Tuple2<Tuple2<MatrixIndexes,MatrixBlock[]>,MatrixBlock>>, MatrixIndexes, 
MatrixBlock[]> {
+               private static final long serialVersionUID = 
3812660351709830714L;
                @Override
-               public Tuple2<MatrixIndexes, MatrixBlock> call( 
Tuple2<MatrixIndexes, MatrixBlock> arg0 ) {
-                       MatrixIndexes ixIn = arg0._1();
-                       MatrixBlock blkIn = arg0._2();
-
-                       //swap the matrix indexes
-                       MatrixIndexes ixOut = new 
MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex());
-                       MatrixBlock blkOut = new MatrixBlock(blkIn);
-                       
-                       //output new tuple
-                       return new Tuple2<>(ixOut,blkOut);
+               public Tuple2<MatrixIndexes, MatrixBlock[]> call(
+                       Tuple2<Long, Tuple2<Tuple2<MatrixIndexes, 
MatrixBlock[]>, MatrixBlock>> arg) throws Exception
+               {
+                       return new Tuple2<>(arg._2()._1()._1(),                 
   //matrix indexes
+                               ArrayUtils.addAll(arg._2()._1()._2(), 
arg._2()._2())); //array of matrix blocks
                }
        }
 }

Reply via email to