This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ccd6a36  [SYSTEMDS-3018] Conflict handling federated plan enumeration
ccd6a36 is described below

commit ccd6a360dc072e62f36f4f76e76234c8c3079487
Author: sebwrede <[email protected]>
AuthorDate: Mon Dec 27 21:34:25 2021 +0100

    [SYSTEMDS-3018] Conflict handling federated plan enumeration
    
    Federated plan enumeration builds a global data flow graph and computes
    optimal plans per interesting property (fed-out, local-out). In trees,
    we could purely compose optimal plans from optimal plans of inputs, but
    in DAGs optimal input plans of n-ary operations might not agree on the
    decisions for common subexpressions.
    
    We mitigate this issue (conflicting fed-out vs. local-out decisions) by
    keeping the data federated, but additionally spawning an asynchronous
    prefetch operation that also brings the data into local memory if at
    least one subplan prefers local intermediates.
    
    Closes #1476.
    
    Co-authored-by: arnabp <[email protected]>
---
 scripts/builtin/normalize.dml                      |   2 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |  32 +++++-
 .../sysds/hops/cost/FederatedCostEstimator.java    |   8 +-
 .../java/org/apache/sysds/hops/cost/HopRel.java    |  27 ++---
 .../hops/ipa/IPAPassRewriteFederatedPlan.java      |  45 +++++---
 .../java/org/apache/sysds/hops/ipa/MemoTable.java  | 118 +++++++++++++++++++++
 src/main/java/org/apache/sysds/lops/Lop.java       |  20 ++++
 .../java/org/apache/sysds/lops/compile/Dag.java    |  42 +++++++-
 .../controlprogram/federated/FederationMap.java    |   7 +-
 .../instructions/cp/BroadcastCPInstruction.java    |   6 +-
 .../instructions/cp/PrefetchCPInstruction.java     |   6 +-
 ...sTask.java => TriggerRemoteOperationsTask.java} |  15 ++-
 .../sysds/runtime/util/CommonThreadPool.java       |   8 +-
 .../java/org/apache/sysds/utils/Statistics.java    |   8 ++
 .../fedplanning/FederatedMultiplyPlanningTest.java |  56 ++++++++--
 .../fedplanning/FederatedMultiplyPlanningTest7.dml |  29 +++++
 .../FederatedMultiplyPlanningTest7Reference.dml    |  27 +++++
 .../fedplanning/FederatedMultiplyPlanningTest8.dml |  31 ++++++
 .../FederatedMultiplyPlanningTest8Reference.dml    |  29 +++++
 19 files changed, 445 insertions(+), 71 deletions(-)

diff --git a/scripts/builtin/normalize.dml b/scripts/builtin/normalize.dml
index e2a32be..f7b86c2 100644
--- a/scripts/builtin/normalize.dml
+++ b/scripts/builtin/normalize.dml
@@ -39,6 +39,6 @@ m_normalize = function(Matrix[Double] X)
   # compute feature ranges for transformations
   cmin = colMins(X);
   cmax = colMaxs(X);
-       # normalize features to range [0,1]
+  # normalize features to range [0,1]
   Y = normalizeApply(X, cmin, cmax);
 }
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 336beb0..f47fcff 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -93,6 +93,14 @@ public abstract class Hop implements ParseInfo {
         */
        protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
        protected FederatedCost _federatedCost = new FederatedCost();
+
+       /**
+        * Field defining if prefetch should be activated for operation.
+        * When prefetch is activated, the output will be transferred from
+        * remote federated sites to local before one of the subsequent
+        * local operations.
+        */
+       protected boolean activatePrefetch;
        
        // Estimated size for the output produced from this Hop in bytes
        protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
@@ -187,6 +195,21 @@ public abstract class Hop implements ParseInfo {
        public void setFederatedOutput(FederatedOutput federatedOutput){
                _federatedOutput = federatedOutput;
        }
+
+       /**
+        * Activate prefetch of HOP.
+        */
+       public void activatePrefetch(){
+               activatePrefetch = true;
+       }
+
+       /**
+        * Checks if prefetch is activated for this hop.
+        * @return true if prefetch is activated
+        */
+       public boolean prefetchActivated(){
+               return activatePrefetch;
+       }
        
        public void resetExecType()
        {
@@ -352,6 +375,8 @@ public abstract class Hop implements ParseInfo {
                //propagate federated output configuration to lops
                if( isFederated() )
                        getLops().setFederatedOutput(_federatedOutput);
+               if ( prefetchActivated() )
+                       getLops().activatePrefetch();
                
                //Step 1: construct reblock lop if required (output of hop)
                constructAndSetReblockLopIfRequired();
@@ -869,8 +894,11 @@ public abstract class Hop implements ParseInfo {
         * This method only has an effect if FEDERATED_COMPILATION is activated.
         * Federated compilation is activated in OptimizerUtils.
         */
-       protected void updateETFed(){
-               if ( someInputFederated() || isFederatedDataOp() )
+       protected void updateETFed() {
+               boolean localOut = hasLocalOutput();
+               boolean fedIn = getInput().stream().anyMatch(
+                       in -> in.hasFederatedOutput() && 
!(in.prefetchActivated() && localOut));
+               if( isFederatedDataOp() || fedIn )
                        _etype = ExecType.FED;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 7089ed8..96a33d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.hops.cost;
 
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ipa.MemoTable;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.ForStatement;
 import org.apache.sysds.parser.ForStatementBlock;
@@ -33,8 +34,6 @@ import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 
 import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
 
 /**
  * Cost estimator for federated executions with methods and constants for 
going through DML programs to estimate costs.
@@ -200,10 +199,9 @@ public class FederatedCostEstimator {
         * @param hopRelMemo memo table of HopRels for calculating input costs
         * @return cost estimation of Hop DAG starting from given root HopRel
         */
-       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+       public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
                // Check if root is in memo table.
-               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
-                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+               if ( hopRelMemo.containsHopRel(root) ){
                        return root.getCostObject();
                }
                else {
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java 
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index 6191a6c..b1cc6dd 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -21,15 +21,14 @@ package org.apache.sysds.hops.cost;
 
 import org.apache.sysds.api.DMLException;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ipa.MemoTable;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -52,7 +51,7 @@ public class HopRel {
         * @param fedOut FederatedOutput value assigned to this HopRel
         * @param hopRelMemo memo table storing other HopRels including the 
inputs of associatedHop
         */
-       public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
Map<Long, List<HopRel>> hopRelMemo){
+       public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
MemoTable hopRelMemo){
                hopRef = associatedHop;
                this.fedOut = fedOut;
                setInputDependency(hopRelMemo);
@@ -108,27 +107,15 @@ public class HopRel {
         * @param hopRelMemo memo table storing HopRels
         * @return FOUT HopRel found in hopRelMemo
         */
-       private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> 
hopRelMemo){
-               return 
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
-       }
-
-       /**
-        * Get the HopRel with minimum cost for given hop
-        * @param hopRelMemo memo table storing HopRels
-        * @param input hop for which minimum cost HopRel is found
-        * @return HopRel with minimum cost for given hop
-        */
-       private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop 
input){
-               return hopRelMemo.get(input.getHopID()).stream()
-                       .min(Comparator.comparingDouble(a -> a.cost.getTotal()))
-                       .orElseThrow(() -> new DMLException("No element in Memo 
Table found for input"));
+       private HopRel getFOUTHopRel(Hop hop, MemoTable hopRelMemo){
+               return hopRelMemo.getFederatedOutputAlternativeOrNull(hop);
        }
 
        /**
         * Set valid and optimal input dependency for this HopRel as a field.
         * @param hopRelMemo memo table storing input HopRels
         */
-       private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+       private void setInputDependency(MemoTable hopRelMemo){
                if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
                        if ( fedOut == FederatedOutput.FOUT && 
!hopRef.isFederatedDataOp() ) {
                                int lowestFOUTIndex = 0;
@@ -152,7 +139,7 @@ public class HopRel {
                                for(int i = 0; i < hopRef.getInput().size(); 
i++) {
                                        if(i != lowestFOUTIndex) {
                                                Hop input = hopRef.getInput(i);
-                                               inputHopRels[i] = 
getMinOfInput(hopRelMemo, input);
+                                               inputHopRels[i] = 
hopRelMemo.getMinCostAlternative(input);
                                        }
                                        else {
                                                inputHopRels[i] = 
lowestFOUTHopRel;
@@ -162,7 +149,7 @@ public class HopRel {
                        } else {
                                inputDependency.addAll(
                                        hopRef.getInput().stream()
-                                               .map(input -> 
getMinOfInput(hopRelMemo, input))
+                                               
.map(hopRelMemo::getMinCostAlternative)
                                                .collect(Collectors.toList()));
                        }
                }
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 8c8df49..59333ab 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -19,7 +19,6 @@
 
 package org.apache.sysds.hops.ipa;
 
-import org.apache.sysds.api.DMLException;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.BinaryOp;
@@ -45,10 +44,9 @@ import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
+import java.util.Set;
 
 /**
  * This rewrite generates a federated execution plan by estimating and setting 
costs and the FederatedOutput values of
@@ -57,7 +55,8 @@ import java.util.Map;
  */
 public class IPAPassRewriteFederatedPlan extends IPAPass {
 
-       private final static Map<Long, List<HopRel>> hopRelMemo = new 
HashMap<>();
+       private final static MemoTable hopRelMemo = new MemoTable();
+       private final static Set<Long> hopRelUpdatedFinal = new HashSet<>();
 
        /**
         * Indicates if an IPA pass is applicable for the current configuration.
@@ -66,7 +65,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * @param fgraph function call graph
         * @return true if federated compilation is activated.
         */
-       @Override public boolean isApplicable(FunctionCallGraph fgraph) {
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return OptimizerUtils.FEDERATED_COMPILATION;
        }
 
@@ -79,7 +79,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * @param fcallSizes function call size infos
         * @return false since the function call graph never has to be rebuilt
         */
-       @Override public boolean rewriteProgram(DMLProgram prog, 
FunctionCallGraph fgraph,
+       @Override
+       public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
                FunctionCallSizeInfo fcallSizes) {
                rewriteStatementBlocks(prog, prog.getStatementBlocks());
                return false;
@@ -189,9 +190,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * @param root hop for which FederatedOutput needs to be set
         */
        private void setFinalFedout(Hop root) {
-               HopRel optimalRootHopRel = 
hopRelMemo.get(root.getHopID()).stream()
-                       .min(Comparator.comparingDouble(HopRel::getCost))
-                       .orElseThrow(() -> new DMLException("Hop root " + root 
+ " has no feasible federated output alternatives"));
+               HopRel optimalRootHopRel = 
hopRelMemo.getMinCostAlternative(root);
                setFinalFedout(root, optimalRootHopRel);
        }
 
@@ -202,8 +201,21 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * @param rootHopRel from which FederatedOutput value and cost is 
retrieved
         */
        private void setFinalFedout(Hop root, HopRel rootHopRel) {
-               updateFederatedOutput(root, rootHopRel);
-               visitInputDependency(rootHopRel);
+               if ( hopRelUpdatedFinal.contains(root.getHopID()) ){
+                       if((rootHopRel.hasLocalOutput() ^ 
root.hasLocalOutput()) && hopRelMemo.hasFederatedOutputAlternative(root)){
+                               // Update with FOUT alternative without 
visiting inputs
+                               updateFederatedOutput(root, 
hopRelMemo.getFederatedOutputAlternative(root));
+                               root.activatePrefetch();
+                       }
+                       else {
+                               // Update without visiting inputs
+                               updateFederatedOutput(root, rootHopRel);
+                       }
+               }
+               else {
+                       updateFederatedOutput(root, rootHopRel);
+                       visitInputDependency(rootHopRel);
+               }
        }
 
        /**
@@ -226,6 +238,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
        private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
                root.setFederatedOutput(updateHopRel.getFederatedOutput());
                root.setFederatedCost(updateHopRel.getCostObject());
+               hopRelUpdatedFinal.add(root.getHopID());
        }
 
        /**
@@ -257,7 +270,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         */
        private void visitFedPlanHop(Hop currentHop) {
                // If the currentHop is in the hopRelMemo table, it means that 
it has been visited
-               if(hopRelMemo.containsKey(currentHop.getHopID()))
+               if(hopRelMemo.containsHop(currentHop))
                        return;
                // If the currentHop has input, then the input should be 
visited depth-first
                if(currentHop.getInput() != null && 
currentHop.getInput().size() > 0) {
@@ -273,7 +286,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
                }
                if(hopRels.isEmpty())
                        hopRels.add(new HopRel(currentHop, 
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
-               hopRelMemo.put(currentHop.getHopID(), hopRels);
+               hopRelMemo.put(currentHop, hopRels);
        }
 
        /**
@@ -319,8 +332,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
                if(associatedHop instanceof AggUnaryOp && 
associatedHop.isScalar())
                        return false;
                // It can only be FOUT if at least one of the inputs are FOUT, 
except if it is a federated DataOp
-               if(associatedHop.getInput().stream().noneMatch(input -> 
hopRelMemo.get(input.getHopID()).stream()
-                       .anyMatch(HopRel::hasFederatedOutput)) && 
!associatedHop.isFederatedDataOp())
+               
if(associatedHop.getInput().stream().noneMatch(hopRelMemo::hasFederatedOutputAlternative)
+                       && !associatedHop.isFederatedDataOp())
                        return false;
                return true;
        }
diff --git a/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java 
b/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
new file mode 100644
index 0000000..c1aeff6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.cost.HopRel;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Memoization of federated execution alternatives.
+ * This memoization data structure is used when generating optimal federated 
execution plans.
+ * The alternative executions are stored as HopRels and the methods of this 
class are used to
+ * add, update, and retrieve the alternatives.
+ */
+public class MemoTable {
+       //TODO refactoring: could we generalize the privacy and codegen memo 
tables into 
+       // a generic implementation (e.g., MemoTable<HopRel>) that can be 
reused in both? 
+       
+       /**
+        * Map holding the relation between Hop IDs and execution plan 
alternatives.
+        */
+       private final static Map<Long, List<HopRel>> hopRelMemo = new 
HashMap<>();
+
+       /**
+        * Get the HopRel with minimum cost for given root hop
+        * @param root hop for which minimum cost HopRel is found
+        * @return HopRel with minimum cost for given hop
+        */
+       public HopRel getMinCostAlternative(Hop root){
+               return hopRelMemo.get(root.getHopID()).stream()
+                       .min(Comparator.comparingDouble(HopRel::getCost))
+                       .orElseThrow(() -> new DMLException("Hop root " + root 
+ " has no feasible federated output alternatives"));
+       }
+
+       /**
+        * Checks if any of the federated execution alternatives for the given 
root hop has federated output.
+        * @param root hop for which execution alternatives are checked
+        * @return true if root has federated output as an execution alternative
+        */
+       public boolean hasFederatedOutputAlternative(Hop root){
+               return 
hopRelMemo.get(root.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput);
+       }
+
+       /**
+        * Get the federated output alternative for given root hop or throw 
exception if not found.
+        * @param root hop for which federated output HopRel is returned
+        * @return federated output HopRel for given root hop
+        */
+       public HopRel getFederatedOutputAlternative(Hop root){
+               return getFederatedOutputAlternativeOptional(root).orElseThrow(
+                       () -> new DMLException("Hop root " + root + " has no 
FOUT alternative"));
+       }
+
+       /**
+        * Get the federated output alternative for given root hop or null if 
not found.
+        * @param root hop for which federated output HopRel is returned
+        * @return federated output HopRel for given root hop
+        */
+       public HopRel getFederatedOutputAlternativeOrNull(Hop root){
+               return getFederatedOutputAlternativeOptional(root).orElse(null);
+       }
+
+       private Optional<HopRel> getFederatedOutputAlternativeOptional(Hop 
root){
+               return 
hopRelMemo.get(root.getHopID()).stream().filter(HopRel::hasFederatedOutput).findFirst();
+       }
+
+       /**
+        * Memoize hopRels related to given root.
+        * @param root for which hopRels are added
+        * @param hopRels execution alternatives related to the given root
+        */
+       public void put(Hop root, List<HopRel> hopRels){
+               hopRelMemo.put(root.getHopID(), hopRels);
+       }
+
+       /**
+        * Checks if root hop has been added to memo.
+        * @param root hop
+        * @return true if root has been added to memo.
+        */
+       public boolean containsHop(Hop root){
+               return hopRelMemo.containsKey(root.getHopID());
+       }
+
+       /**
+        * Checks if given HopRel has been added to memo.
+        * @param root HopRel
+        * @return true if root HopRel has been added to memo.
+        */
+       public boolean containsHopRel(HopRel root){
+               return containsHop(root.getHopRef())
+                       && hopRelMemo.get(root.getHopRef().getHopID()).stream()
+                       .anyMatch(h -> h.getFederatedOutput() == 
root.getFederatedOutput());
+       }
+}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index 7da091f..dda7cdd 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -117,6 +117,14 @@ public abstract class Lop
        protected PrivacyConstraint privacyConstraint;
 
        /**
+        * Field defining if prefetch should be activated for operation.
+        * When prefetch is activated, the output will be transferred from
+        * remote federated sites to local before one of the subsequent
+        * local operations.
+        */
+       protected boolean activatePrefetch;
+
+       /**
         * Enum defining if the output of the operation should be forced 
federated, forced local or neither.
         * If it is FOUT, the output should be kept at federated sites.
         * If it is LOUT, the output should be retrieved by the coordinator.
@@ -316,9 +324,21 @@ public abstract class Lop
                return privacyConstraint;
        }
 
+       public void activatePrefetch(){
+               activatePrefetch = true;
+       }
+
+       public boolean prefetchActivated(){
+               return activatePrefetch;
+       }
+
        public void setFederatedOutput(FederatedOutput fedOutput){
                _fedOutput = fedOutput;
        }
+
+       public FederatedOutput getFederatedOutput(){
+               return _fedOutput;
+       }
        
        public void setConsumerCount(int cc) {
                consumerCount = cc;
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 9b7f1e5..76090f0 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -203,7 +203,9 @@ public class Dag<N extends Lop>
                List<Lop> node_pf = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS 
? addPrefetchLop(node_v) : node_v;
                List<Lop> node_bc = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS 
? addBroadcastLop(node_pf) : node_pf;
                // TODO: Merge via a single traversal of the nodes
-               
+
+               prefetchFederated(node_bc);
+
                // do greedy grouping of operations
                ArrayList<Instruction> inst =
                        doPlainInstructionGen(sb, node_bc);
@@ -211,6 +213,42 @@ public class Dag<N extends Lop>
                // cleanup instruction (e.g., create packed rmvar instructions)
                return cleanupInstructions(inst);
        }
+
+       /**
+        * Checks if the given input needs to be prefetched before executing 
given lop.
+        * @param input to check for prefetch
+        * @param lop which possibly needs the input prefetched
+        * @return true if given input needs to be prefetched before lop
+        */
+       private boolean inputNeedsPrefetch(Lop input, Lop lop){
+               return input.prefetchActivated() && lop.getExecType() != 
ExecType.FED
+                       && input.getFederatedOutput().isForcedFederated();
+       }
+
+       /**
+        * Add prefetch lop between input and lop.
+        * @param input to be prefetched
+        * @param lop for which the given input needs to be prefetched
+        */
+       private void addFedPrefetchLop(Lop input, Lop lop){
+               UnaryCP prefetch = new UnaryCP(input, OpOp1.PREFETCH, 
input.getDataType(), input.getValueType(), ExecType.CP);
+               prefetch.addOutput(lop);
+               lop.replaceInput(input, prefetch);
+               input.removeOutput(lop);
+       }
+
+       /**
+        * Add prefetch lops where needed.
+        * @param lops for which prefetch lops could be added.
+        */
+       private void prefetchFederated(List<Lop> lops){
+               for ( Lop lop : lops ){
+                       for ( Lop input : lop.getInputs() ){
+                               if ( inputNeedsPrefetch(input, lop) )
+                                       addFedPrefetchLop(input, lop);
+                       }
+               }
+       }
        
        private static List<Lop> doTopologicalSortTwoLevelOrder(List<Lop> v) {
                //partition nodes into leaf/inner nodes and dag root nodes,
@@ -251,7 +289,7 @@ public class Dag<N extends Lop>
                }
                return nodesWithPrefetch;
        }
-       
+
        private static List<Lop> addBroadcastLop(List<Lop> nodes) {
                List<Lop> nodesWithBroadcast = new ArrayList<>();
                
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 39309d6..680f608 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -591,17 +591,16 @@ public class FederationMap {
                }
                // derive output type
                switch(_type) {
-                       case FULL:
-                               _type = FType.FULL;
-                               break;
                        case ROW:
                                _type = FType.COL;
                                break;
                        case COL:
                                _type = FType.ROW;
                                break;
+                       case FULL:
                        case PART:
-                               _type = FType.PART;
+                               // FULL and PART are not changed
+                               break;
                        default:
                                _type = FType.OTHER;
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index d29ef4c..2cc9d7c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -44,8 +44,8 @@ public class BroadcastCPInstruction extends 
UnaryCPInstruction {
        public void processInstruction(ExecutionContext ec) {
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
 
-               if (CommonThreadPool.triggerRDDPool == null)
-                       CommonThreadPool.triggerRDDPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRDDPool.submit(new 
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+               if (CommonThreadPool.triggerRemoteOPsPool == null)
+                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 00f8ac2..9d95a58 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -49,8 +49,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
                // If the next instruction which takes this output as an input 
comes before
                // the prefetch thread triggers, that instruction will start 
the operations.
                // In that case this Prefetch instruction will act like a NOOP. 
-               if (CommonThreadPool.triggerRDDPool == null)
-                       CommonThreadPool.triggerRDDPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRDDPool.submit(new 
TriggerRDDOperationsTask(ec.getMatrixObject(output)));
+               if (CommonThreadPool.triggerRemoteOPsPool == null)
+                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
similarity index 76%
rename from src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java
rename to src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
index 0a4d1b5..6eea8c9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
@@ -23,10 +23,10 @@ import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.utils.Statistics;
 
-public class TriggerRDDOperationsTask implements Runnable {
+public class TriggerRemoteOperationsTask implements Runnable {
        MatrixObject _prefetchMO;
 
-       public TriggerRDDOperationsTask(MatrixObject mo) {
+       public TriggerRemoteOperationsTask(MatrixObject mo) {
                _prefetchMO = mo;
        }
 
@@ -36,14 +36,19 @@ public class TriggerRDDOperationsTask implements Runnable {
                synchronized (_prefetchMO) {
                        // Having this check if operations are pending inside the 
                        // critical section safeguards against concurrent rmVar.
-                       if (_prefetchMO.isPendingRDDOps()) {
+                       if (_prefetchMO.isPendingRDDOps() || _prefetchMO.isFederated()) {
+                               // TODO: Add robust runtime constraints for federated prefetch
                                // Execute and bring the result to local
                                _prefetchMO.acquireReadAndRelease();
                                prefetched = true;
                        }
                }
-               if (DMLScript.STATISTICS && prefetched)
-                       Statistics.incSparkAsyncPrefetchCount(1);
+               if (DMLScript.STATISTICS && prefetched) {
+                       if (_prefetchMO.isFederated())
+                               Statistics.incFedAsyncPrefetchCount(1);
+                       else
+                               Statistics.incSparkAsyncPrefetchCount(1);
+               }
        }
 
 }
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index 2fc8049..abb1ced 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -49,7 +49,7 @@ public class CommonThreadPool implements ExecutorService
        private static final int size = InfrastructureAnalyzer.getLocalParallelism();
        private static final ExecutorService shared = ForkJoinPool.commonPool();
        private final ExecutorService _pool;
-       public static ExecutorService triggerRDDPool = null;
+       public static ExecutorService triggerRemoteOPsPool = null;
 
        public CommonThreadPool(ExecutorService pool) {
                _pool = pool;
@@ -80,10 +80,10 @@ public class CommonThreadPool implements ExecutorService
        }
 
        public static void shutdownAsyncRDDPool() {
-               if (triggerRDDPool != null) {
+               if (triggerRemoteOPsPool != null) {
                        //shutdown prefetch/broadcast thread pool
-                       triggerRDDPool.shutdown();
-                       triggerRDDPool = null;
+                       triggerRemoteOPsPool.shutdown();
+                       triggerRemoteOPsPool = null;
                }
        }
 
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 3940921..a3da2a7 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -177,6 +177,7 @@ public class Statistics
        private static final LongAdder federatedGetCount = new LongAdder();
        private static final LongAdder federatedExecuteInstructionCount = new LongAdder();
        private static final LongAdder federatedExecuteUDFCount = new LongAdder();
+       private static final LongAdder fedAsyncPrefetchCount = new LongAdder();
 
        private static LongAdder numNativeFailures = new LongAdder();
        public static LongAdder numNativeLibMatrixMultCalls = new LongAdder();
@@ -443,6 +444,10 @@ public class Statistics
                }
        }
 
+       public static void incFedAsyncPrefetchCount(long c) {
+               fedAsyncPrefetchCount.add(c);
+       }
+
        public static void startCompileTimer() {
                if( DMLScript.STATISTICS )
                        compileStartTime = System.nanoTime();
@@ -550,6 +555,7 @@ public class Statistics
                federatedGetCount.reset();
                federatedExecuteInstructionCount.reset();
                federatedExecuteUDFCount.reset();
+               fedAsyncPrefetchCount.reset();
 
                DMLCompressionStatistics.reset();
        }
@@ -1220,6 +1226,8 @@ public class Statistics
                                sb.append("Federated Execute (Inst, UDF):\t" +
                                        federatedExecuteInstructionCount.longValue() + "/" +
                                        federatedExecuteUDFCount.longValue() + ".\n");
+                               sb.append("Federated prefetch count:\t" +
+                                       fedAsyncPrefetchCount.longValue() + ".\n");
                        }
                        if( transformEncoderCount.longValue() > 0) {
                                //TODO: Cleanup and condense
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e0ef884..1e59b86 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -47,6 +47,8 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
        private final static String TEST_NAME_4 = "FederatedMultiplyPlanningTest4";
        private final static String TEST_NAME_5 = "FederatedMultiplyPlanningTest5";
        private final static String TEST_NAME_6 = "FederatedMultiplyPlanningTest6";
+       private final static String TEST_NAME_7 = "FederatedMultiplyPlanningTest7";
+       private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8";
        private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
@@ -64,6 +66,8 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME_4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
                addTestConfiguration(TEST_NAME_5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
                addTestConfiguration(TEST_NAME_6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
        }
 
        @Parameterized.Parameters
@@ -112,6 +116,18 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters);
        }
 
+       @Test
+       public void federatedMultiplyDoubleHop() {
+               String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_7, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyDoubleHop2() {
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_ba+*"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_8, expectedHeavyHitters);
+       }
+
        private void writeStandardMatrix(String matrixName, long seed){
                writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
        }
@@ -158,6 +174,14 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                        writeRowFederatedVector("Y1", 44);
                        writeRowFederatedVector("Y2", 21);
                }
+               else if ( testName.equals(TEST_NAME_8) ){
+                       writeColStandardMatrix("X1", 42, null);
+                       writeColStandardMatrix("X2", 1340, null);
+                       writeColStandardMatrix("Y1", 44, null);
+                       writeColStandardMatrix("Y2", 21, null);
+                       writeColStandardMatrix("W1", 76, null);
+                       writeColStandardMatrix("W2", 11, null);
+               }
                else {
                        writeStandardMatrix("X1", 42);
                        writeStandardMatrix("X2", 1340);
@@ -201,12 +225,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                        "X2=" + TestUtils.federatedAddress(port2, input("X2")),
                        "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
                        "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
-               if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){
-                       programArgs = new String[] {"-stats","-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
-                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
-                               "Y1=" + input("Y1"),
-                               "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
-               }
+               rewriteRealProgramArgs(testName, port1, port2);
                runTest(true, false, null, -1);
 
                OptimizerUtils.FEDERATED_COMPILATION = false;
@@ -215,6 +234,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                fullDMLScriptName = HOME + testName + "Reference.dml";
                programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
                        "Y2=" + input("Y2"), "Z=" + expected("Z")};
+               rewriteReferenceProgramArgs(testName);
                runTest(true, false, null, -1);
 
                // compare via files
@@ -228,5 +248,29 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
+
+       private void rewriteRealProgramArgs(String testName, int port1, int port2){
+               if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){
+                       programArgs = new String[] {"-stats","-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               } else if ( testName.equals(TEST_NAME_8) ){
+                       programArgs = new String[] {"-stats","-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                               "Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+                               "W1=" + input("W1"),
+                               "W2=" + input("W2"),
+                               "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               }
+       }
+
+       private void rewriteReferenceProgramArgs(String testName){
+               if ( testName.equals(TEST_NAME_8) ){
+                       programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" + input("W2"), "Z=" + expected("Z")};
+               }
+       }
 }
 
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml
new file mode 100644
index 0000000..5d4a1d3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+              ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
+Z0 = X * Y
+Z = t(Z0) %*% X
+Z1 = Z %*% t(colSums(Z0))
+write(Z1, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml
new file mode 100644
index 0000000..76212a0
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+Z0 = X * Y
+Z = t(Z0) %*% X
+Z1 = Z %*% t(colSums(Z0))
+write(Z1, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml
new file mode 100644
index 0000000..5f3223c
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+              ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2), list($r, $c)))
+W = cbind(read($W1), read($W2))
+Z1 = Y
+Z2 = Z1 %*% t(X)
+Z3 = Z1 %*% t(W)
+Z4 = sum(Z3) * sum(Z2)
+write(Z4, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml
new file mode 100644
index 0000000..c8c1797
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = cbind(read($X1), read($X2))
+Y = cbind(read($Y1), read($Y2))
+W = cbind(read($W1), read($W2))
+Z1 = Y
+Z2 = Z1 %*% t(X)
+Z3 = Z1 %*% t(W)
+Z4 = sum(Z3) * sum(Z2)
+write(Z4, $Z)

Reply via email to