This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e5c247  [MINOR] Fix broadcast handling federated ternary-aggregate
7e5c247 is described below

commit 7e5c2472c4c98f411784d0ae87798fec108d2e1b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 13 16:35:57 2022 +0100

    [MINOR] Fix broadcast handling federated ternary-aggregate
---
 .../runtime/instructions/fed/AggregateTernaryFEDInstruction.java     | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index b253f69..f981454 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -104,7 +104,9 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
                                new CPOperand[] {input1, input2, input3},
                                new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1[0].getID()}, true);
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
-                       FederatedRequest fr4 = 
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+                       FederatedRequest fr4 = (mo3 == null) ? 
+                               mo2.getFedMapping().cleanup(getTID(), 
fr1[0].getID(), fr2.getID()) :
+                               mo2.getFedMapping().cleanup(getTID(), 
fr2.getID()); //no cleanup of broadcasts
                        Future<FederatedResponse>[] tmp = (mo3 == null) ?
                                mo1.getFedMapping().execute(getTID(), fr1[0], 
fr2, fr3, fr4) :
                                mo1.getFedMapping().execute(getTID(), fr1, fr2, 
fr3, fr4);
@@ -118,7 +120,6 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
                                        catch(Exception e) {
                                                throw new 
DMLRuntimeException("Federated Get data failed with exception on 
TernaryFedInstruction", e);
                                        }
-
                                ec.setScalarOutput(output.getName(), new 
DoubleObject(sum));
                        }
                        else {

Reply via email to