This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 10b493c  [SYSTEMDS-2634] Reduced number of RPC calls in federated
backend
10b493c is described below

commit 10b493cb71c72a1c6f65470166a8ab4842c239a4
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 23 16:57:24 2020 +0200

    [SYSTEMDS-2634] Reduced number of RPC calls in federated backend
    
    This patch improves the performance of the federated runtime backend by
    merging the execution and cleanup RPC request batches into a single
    batch of requests. Since every batch returns only a single response, we
    now carefully select the appropriate GET_VAR, error, or other response
    to return (e.g., the GET_VAR result rather than the response of the
    trailing cleanup request). Overall, this reduces the number of RPC calls
    by almost 2x and removes unnecessary synchronization barriers.
---
 .../controlprogram/caching/CacheableData.java      |  2 +-
 .../federated/FederatedWorkerHandler.java          | 22 ++++++++++++++++++----
 .../controlprogram/federated/FederationMap.java    |  9 ++++++++-
 .../controlprogram/federated/FederationUtils.java  |  1 -
 .../fed/AggregateBinaryFEDInstruction.java         | 16 ++++++++--------
 .../fed/AggregateUnaryFEDInstruction.java          |  6 +++---
 .../fed/BinaryMatrixMatrixFEDInstruction.java      | 19 ++++++++-----------
 .../fed/BinaryMatrixScalarFEDInstruction.java      | 13 ++++++++-----
 .../instructions/fed/MMChainFEDInstruction.java    | 10 ++++++----
 .../instructions/fed/TsmmFEDInstruction.java       |  4 ++--
 10 files changed, 62 insertions(+), 40 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 720534a..4d0d5d9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -680,7 +680,7 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                
                //clear federated matrix
                if( _fedMapping != null )
-                       _fedMapping.cleanup(tid, _fedMapping.getID());
+                       _fedMapping.execCleanup(tid, _fedMapping.getID());
                
                // change object state EMPTY
                setDirty(false);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index f4af303..0dcb846 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -91,10 +91,24 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
                        PrivacyMonitor.clearCheckedConstraints();
        
-                       response = executeCommand(request);
-                       conditionalAddCheckedConstraints(request, response);
-                       if (!response.isSuccessful()){
-                               log.error("Command " + request.getType() + " 
failed: " + response.getErrorMessage() + "full command: \n" + 
request.toString());
+                       //execute command and handle privacy constraints
+                       FederatedResponse tmp = executeCommand(request);
+                       conditionalAddCheckedConstraints(request, tmp);
+                       
+                       //select the response for the entire batch of requests
+                       if (!tmp.isSuccessful()) {
+                               log.error("Command " + request.getType() + " 
failed: " 
+                                       + tmp.getErrorMessage() + "full 
command: \n" + request.toString());
+                               response = (response == null || 
response.isSuccessful()) 
+                                       ? tmp : response; //return first error
+                       }
+                       else if( request.getType() == RequestType.GET_VAR ) {
+                               if( response != null && response.isSuccessful() 
)
+                                       log.error("Multiple GET_VAR are not 
supported in single batch of requests.");
+                               response = tmp; //return last get result
+                       }
+                       else if( response == null && i == requests.length-1 ) {
+                               response = tmp; //return last
                        }
                }
                ctx.writeAndFlush(response).addListener(new CloseListener());
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b272bf9..72d1196 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -170,7 +170,14 @@ public class FederationMap
                return readResponses;
        }
        
-       public void cleanup(long tid, long... id) {
+       public FederatedRequest cleanup(long tid, long... id) {
+               FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
+                       
VariableCPInstruction.prepareRemoveInstruction(id).toString());
+               request.setTID(tid);
+               return request;
+       }
+       
+       public void execCleanup(long tid, long... id) {
                FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
                        
VariableCPInstruction.prepareRemoveInstruction(id).toString());
                request.setTID(tid);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index faae560..7df7c51 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -32,7 +32,6 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.Mean;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 34caec2..c28a163 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -68,10 +68,10 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                                new CPOperand[]{input1, input2},
                                new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID()});
                        FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+                       FederatedRequest fr3 = 
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2);
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo2.getFedMapping().cleanup(getTID(), fr1.getID(), 
fr2.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else if(mo1.isFederated(FType.ROW)) { // MV + MM
@@ -81,16 +81,16 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
                        if( mo2.getNumColumns() == 1 ) { //MV
                                FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                               FederatedRequest fr4 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
                                //execute federated operations and aggregate
-                               Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+                               Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
                                MatrixBlock ret = FederationUtils.rbind(tmp);
-                               mo1.getFedMapping().cleanup(getTID(), 
fr1.getID(), fr2.getID());
                                ec.setMatrixOutput(output.getName(), ret);
                        }
                        else { //MM
                                //execute federated operations and aggregate
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
-                               mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                                MatrixObject out = ec.getMatrixObject(output);
                                
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), 
(int)mo1.getBlocksize());
                                
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), 
mo2.getNumColumns()));
@@ -104,10 +104,10 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr4 = 
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+                       Future<FederatedResponse>[] tmp = 
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), 
fr2.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else { //other combinations
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e87bf57..60fe40b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -55,19 +55,19 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
        public void processInstruction(ExecutionContext ec) {
                AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
                MatrixObject in = ec.getMatrixObject(input1);
+               FederationMap map = in.getFedMapping();
                
                //create federated commands for aggregation
                FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
                        new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()});
                FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+               FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
                
                //execute federated commands and cleanups
-               FederationMap map = in.getFedMapping();
-               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, 
fr2);
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, 
fr2, fr3);
                if( output.isScalar() )
                        ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, tmp));
                else
                        ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, tmp, map));
-               map.cleanup(getTID(), fr1.getID());
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 63c2d71..bceb6ae 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,39 +42,36 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                FederatedRequest fr2 = null;
 
                if( mo2.isFederated() ) {
-                       if(mo1.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)){
+                       if(mo1.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
                                        new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID()});
                                mo1.getFedMapping().execute(getTID(), true, 
fr2);
-                               
-                       } else{
+                       }
+                       else {
                                throw new DMLRuntimeException("Matrix-matrix 
binary operations "
                                        + " with a federated right input are 
not supported yet.");
                        }
-
-               } 
+               }
                else {
                        //matrix-matrix binary oFederatedRequest fr2 = 
null;perations -> lhs fed input -> fed output
-                       
                        if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { 
//MV row vector
                                FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
                                        new long[]{mo1.getFedMapping().getID(), 
fr1[0].getID()});
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
                                //execute federated instruction and cleanup 
intermediates
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
-                               mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID());
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
                        else { //MM or MV col vector
                                FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
                                        new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                                //execute federated instruction and cleanup 
intermediates
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
-                               mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
                }
                
-               
                //derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
                out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 75bfe33..b6ea1fb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -39,17 +39,20 @@ public class BinaryMatrixScalarFEDInstruction extends 
BinaryFEDInstruction
                CPOperand scalar = input2.isScalar() ? input2 : input1;
                MatrixObject mo = ec.getMatrixObject(matrix);
                
-               //execute federated matrix-scalar operation and cleanups
+               //prepare federated request matrix-scalar
                FederatedRequest fr1 = !scalar.isLiteral() ?
                        mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) 
: null;
                FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                        new CPOperand[]{matrix, (fr1 != null)?scalar:null},
                        new long[]{mo.getFedMapping().getID(), (fr1 != 
null)?fr1.getID():-1});
                
-               mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
-                       new FederatedRequest[]{fr1, fr2}: new 
FederatedRequest[]{fr2});
-               if( fr1 != null )
-                       mo.getFedMapping().cleanup(getTID(), fr1.getID());
+               //execute federated matrix-scalar operation and cleanups
+               if( fr1 != null ) {
+                       FederatedRequest fr3 = 
mo.getFedMapping().cleanup(getTID(), fr1.getID());
+                       mo.getFedMapping().execute(getTID(), true, fr1, fr2, 
fr3);
+               }
+               else
+                       mo.getFedMapping().execute(getTID(), true, fr2);
                
                //derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 2dee64b..99a305b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -86,11 +86,12 @@ public class MMChainFEDInstruction extends 
UnaryFEDInstruction {
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr4 = mo1.getFedMapping()
+                               .cleanup(getTID(), fr1.getID(), fr2.getID());
                        
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo1.getFedMapping().cleanup(getTID(), fr1.getID(), 
fr2.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else { //XtwXv | XtXvy
@@ -101,11 +102,12 @@ public class MMChainFEDInstruction extends 
UnaryFEDInstruction {
                                new CPOperand[]{input1, input2, input3},
                                new long[]{mo1.getFedMapping().getID(), 
fr1.getID(), fr0[0].getID()});
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr4 = mo1.getFedMapping()
+                               .cleanup(getTID(), fr0[0].getID(), fr1.getID(), 
fr2.getID());
                        
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), 
fr1.getID(), fr2.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 292bced..fbe88d6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -67,11 +67,11 @@ public class TsmmFEDInstruction extends 
BinaryFEDInstruction {
                        FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1}, new 
long[]{mo1.getFedMapping().getID()});
                        FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                        
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2);
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else { //other combinations

Reply via email to