This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c5ab81c6cf [SYSTEMDS-3790] Extraction of optimal FedPlans and conflict 
handling
c5ab81c6cf is described below

commit c5ab81c6cf8d4288762b2355f060755433e7e720
Author: min-guk <[email protected]>
AuthorDate: Sat Jan 25 10:22:27 2025 +0100

    [SYSTEMDS-3790] Extraction of optimal FedPlans and conflict handling
    
    Closes #2175.
---
 .../sysds/hops/fedplanner/FederatedMemoTable.java  | 214 +++++++--------------
 .../hops/fedplanner/FederatedMemoTablePrinter.java | 139 +++++++++++++
 .../fedplanner/FederatedPlanCostEnumerator.java    | 112 ++++++++++-
 .../fedplanner/FederatedPlanCostEstimator.java     | 130 ++++++++++++-
 .../federated/FederatedPlanCostEnumeratorTest.java |  18 +-
 .../FederatedPlanCostEnumeratorTest1.dml}          |   0
 .../FederatedPlanCostEnumeratorTest2.dml}          |   5 +-
 .../FederatedPlanCostEnumeratorTest3.dml}          |   5 +-
 8 files changed, 460 insertions(+), 163 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
index 16240f0281..b2b58871f6 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.hops.fedplanner;
 
 import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
@@ -29,8 +28,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.ArrayList;
 import java.util.Map;
-import java.util.HashSet;
-import java.util.Set;
 
 /**
  * A Memoization Table for managing federated plans (FedPlan) based on 
combinations of Hops and fedOutTypes.
@@ -71,12 +68,11 @@ public class FederatedMemoTable {
         * Retrieves the minimum cost child plan considering the parent's 
output type.
         * The cost is calculated using getParentViewCost to account for 
potential type mismatches.
         * 
-        * @param childHopID ?
-        * @param childFedOutType ?
-        * @return ?
+        * @param fedPlanPair pair of hop ID and federated output type used as the memo table key
+        * @return min cost fed plan
         */
-       public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput 
childFedOutType) {
-               FedPlanVariants fedPlanVariantList = hopMemoTable.get(new 
ImmutablePair<>(childHopID, childFedOutType));
+       public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput> 
fedPlanPair) {
+               FedPlanVariants fedPlanVariantList = 
hopMemoTable.get(fedPlanPair);
                return fedPlanVariantList._fedPlanVariants.stream()
                                
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
                                .orElse(null);
@@ -86,6 +82,22 @@ public class FederatedMemoTable {
                return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
        }
 
+       public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> 
fedPlanPair) {
+               return hopMemoTable.get(fedPlanPair);
+       }
+
+       public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput 
fedOutType) {
+               // TODO: Consider verifying that pruning has been performed; get(0) is only the min-cost plan after prune()
+               FedPlanVariants fedPlanVariantList = hopMemoTable.get(new 
ImmutablePair<>(hopID, fedOutType));
+               return fedPlanVariantList._fedPlanVariants.get(0);
+       }
+
+       public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput> 
fedPlanPair) {
+               // TODO: Consider verifying that pruning has been performed; get(0) is only the min-cost plan after prune()
+               FedPlanVariants fedPlanVariantList = 
hopMemoTable.get(fedPlanPair);
+               return fedPlanVariantList._fedPlanVariants.get(0);
+       }
+
        /**
         * Checks if the memo table contains an entry for a given Hop and 
fedOutType.
         *
@@ -98,162 +110,77 @@ public class FederatedMemoTable {
        }
 
        /**
-        * Prunes all entries in the memo table, retaining only the minimum-cost
-        * FedPlan for each entry.
-        */
-       public void pruneMemoTable() {
-               for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> 
entry : hopMemoTable.entrySet()) {
-                       List<FedPlan> fedPlanList = 
entry.getValue().getFedPlanVariants();
-                       if (fedPlanList.size() > 1) {
-                               // Find the FedPlan with the minimum cost
-                               FedPlan minCostPlan = fedPlanList.stream()
-                                                       
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
-                                               .orElse(null);
-
-                               // Retain only the minimum cost plan
-                               fedPlanList.clear();
-                               fedPlanList.add(minCostPlan);
-                       }
-               }
-       }
-
-       /**
-        * Recursively prints a tree representation of the DAG starting from 
the given root FedPlan.
-        * Includes information about hopID, fedOutType, TotalCost, SelfCost, 
and NetCost for each node.
+        * Prunes the specified entry in the memo table, retaining only the 
minimum-cost
+        * FedPlan for the given Hop ID and federated output type.
         *
-        * @param rootFedPlan The starting point FedPlan to print
+        * @param hopID The ID of the Hop to prune
+        * @param federatedOutput The federated output type associated with the 
Hop
         */
-       public void printFedPlanTree(FedPlan rootFedPlan) {
-               Set<FedPlan> visited = new HashSet<>();
-               printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
+       public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
+               hopMemoTable.get(new ImmutablePair<>(hopID, 
federatedOutput)).prune();
        }
 
        /**
-        * Helper method to recursively print the FedPlan tree.
-        *
-        * @param plan  The current FedPlan to print
-        * @param visited Set to keep track of visited FedPlans (prevents 
cycles)
-        * @param depth   The current depth level for indentation
-        * @param isLast  Whether this node is the last child of its parent
+        * Represents common properties and costs associated with a Hop.
+        * This class holds a reference to the Hop and tracks its execution and 
network transfer costs.
         */
-       private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> 
visited, int depth, boolean isLast) {
-               if (plan == null || visited.contains(plan)) {
-                       return;
-               }
-
-               visited.add(plan);
-
-               Hop hop = plan.getHopRef();
-               StringBuilder sb = new StringBuilder();
-
-               // Add FedPlan information
-               sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
-                               .append(plan.getHopRef().getOpString())
-                               .append(" [")
-                               .append(plan.getFedOutType())
-                               .append("]");
-
-               StringBuilder childs = new StringBuilder();
-               childs.append(" (");
-               boolean childAdded = false;
-               for( Hop input : hop.getInput()){
-                       childs.append(childAdded?",":"");
-                       childs.append(input.getHopID());
-                       childAdded = true;
-               }
-               childs.append(")");
-               if( childAdded )
-                       sb.append(childs.toString());
-                
-                
-               sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
-                               plan.getTotalCost(),
-                               plan.getSelfCost(),
-                               plan.getNetTransferCost()));
-
-               // Add matrix characteristics
-               sb.append(" [")
-                       .append(hop.getDim1()).append(", ")
-                       .append(hop.getDim2()).append(", ")
-                       .append(hop.getBlocksize()).append(", ")
-                       .append(hop.getNnz());
-
-               if (hop.getUpdateType().isInPlace()) {
-                       sb.append(", 
").append(hop.getUpdateType().toString().toLowerCase());
-               }
-               sb.append("]");
-
-               // Add memory estimates
-               sb.append(" [")
-                       
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
-                       
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
-                       
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
-                       
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
-
-               // Add reblock and checkpoint requirements
-               if (hop.requiresReblock() && hop.requiresCheckpoint()) {
-                       sb.append(" [rblk, chkpt]");
-               } else if (hop.requiresReblock()) {
-                       sb.append(" [rblk]");
-               } else if (hop.requiresCheckpoint()) {
-                       sb.append(" [chkpt]");
-               }
-
-               // Add execution type
-               if (hop.getExecType() != null) {
-                       sb.append(", ").append(hop.getExecType());
-               }
-
-               System.out.println(sb);
-
-               // Process child nodes
-               List<Pair<Long, FederatedOutput>> childRefs = 
plan.getChildFedPlans();
-               for (int i = 0; i < childRefs.size(); i++) {
-                       Pair<Long, FederatedOutput> childRef = childRefs.get(i);
-                       FedPlanVariants childVariants = 
getFedPlanVariants(childRef.getLeft(), childRef.getRight());
-                       if (childVariants == null || 
childVariants.getFedPlanVariants().isEmpty())
-                               continue;
+       public static class HopCommon {
+               protected final Hop hopRef;         // Reference to the 
associated Hop
+               protected double selfCost;          // Current execution cost 
(compute + memory access)
+               protected double netTransferCost;   // Network transfer cost
 
-                       boolean isLastChild = (i == childRefs.size() - 1);
-                       for (FedPlan childPlan : 
childVariants.getFedPlanVariants()) {
-                               printFedPlanTreeRecursive(childPlan, visited, 
depth + 1, isLastChild);
-                       }
+               protected HopCommon(Hop hopRef) {
+                       this.hopRef = hopRef;
+                       this.selfCost = 0;
+                       this.netTransferCost = 0;
                }
        }
 
        /**
-        * Represents a collection of federated execution plan variants for a 
specific Hop.
-        * Contains cost information and references to the associated plans.
+        * Represents a collection of federated execution plan variants for a 
specific Hop and FederatedOutput.
+        * This class contains cost information and references to the 
associated plans.
+        * It uses HopCommon to store common properties and costs related to 
the Hop.
         */
        public static class FedPlanVariants {
-               protected final Hop hopRef;              // Reference to the 
associated Hop
-               protected double selfCost;         // Current execution cost 
(compute + memory access)
-               protected double netTransferCost;   // Network transfer cost
-               private final FederatedOutput fedOutType;          // Output 
type (FOUT/LOUT)
-               protected List<FedPlan> _fedPlanVariants;       // List of plan 
variants
+               protected HopCommon hopCommon;      // Common properties and 
costs for the Hop
+               private final FederatedOutput fedOutType;  // Output type 
(FOUT/LOUT)
+               protected List<FedPlan> _fedPlanVariants;  // List of plan 
variants
 
                public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
-                       this.hopRef = hopRef;
+                       this.hopCommon = new HopCommon(hopRef);
                        this.fedOutType = fedOutType;
-                       this.selfCost = 0;
-                       this.netTransferCost = 0;
                        this._fedPlanVariants = new ArrayList<>();
                }
 
-               public int size() {return _fedPlanVariants.size();}
                public void addFedPlan(FedPlan fedPlan) 
{_fedPlanVariants.add(fedPlan);}
                public List<FedPlan> getFedPlanVariants() {return 
_fedPlanVariants;}
+               public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
+
+               public void prune() {
+                       if (_fedPlanVariants.size() > 1) {
+                               // Find the FedPlan with the minimum cost
+                               FedPlan minCostPlan = _fedPlanVariants.stream()
+                                               
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+                                               .orElse(null);
+
+                               // Retain only the minimum cost plan
+                               _fedPlanVariants.clear();
+                               _fedPlanVariants.add(minCostPlan);
+                       }
+               }
        }
 
        /**
         * Represents a single federated execution plan with its associated 
costs and dependencies.
-        * Contains:
+        * This class contains:
         * 1. selfCost: Cost of current hop (compute + input/output memory 
access)
         * 2. totalCost: Cumulative cost including this plan and all child plans
         * 3. netTransferCost: Network transfer cost for this plan to parent 
plan.
+        * 
+        * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon 
to manage common properties and costs.
         */
        public static class FedPlan {
-               private double totalCost;                                 // 
Total cost including child plans
+               private double totalCost;                  // Total cost 
including child plans
                private final FedPlanVariants fedPlanVariants;  // Reference to 
variant list
                private final List<Pair<Long, FederatedOutput>> childFedPlans;  
// Child plan references
 
@@ -264,25 +191,26 @@ public class FederatedMemoTable {
                }
 
                public void setTotalCost(double totalCost) {this.totalCost = 
totalCost;}
-               public void setSelfCost(double selfCost) 
{fedPlanVariants.selfCost = selfCost;}
-               public void setNetTransferCost(double netTransferCost) 
{fedPlanVariants.netTransferCost = netTransferCost;}
-
-               public Hop getHopRef() {return fedPlanVariants.hopRef;}
+               public void setSelfCost(double selfCost) 
{fedPlanVariants.hopCommon.selfCost = selfCost;}
+               public void setNetTransferCost(double netTransferCost) 
{fedPlanVariants.hopCommon.netTransferCost = netTransferCost;}
+               
+               public Hop getHopRef() {return 
fedPlanVariants.hopCommon.hopRef;}
+               public long getHopID() {return 
fedPlanVariants.hopCommon.hopRef.getHopID();}
                public FederatedOutput getFedOutType() {return 
fedPlanVariants.fedOutType;}
                public double getTotalCost() {return totalCost;}
-               public double getSelfCost() {return fedPlanVariants.selfCost;}
-               private double getNetTransferCost() {return 
fedPlanVariants.netTransferCost;}
+               public double getSelfCost() {return 
fedPlanVariants.hopCommon.selfCost;}
+               public double getNetTransferCost() {return 
fedPlanVariants.hopCommon.netTransferCost;}
                public List<Pair<Long, FederatedOutput>> getChildFedPlans() 
{return childFedPlans;}
 
                /**
                 * Calculates the conditional network transfer cost based on 
output type compatibility.
                 * Returns 0 if output types match, otherwise returns the 
network transfer cost.
-                * @param parentFedOutType ?
-                * @return ?
+                * @param parentFedOutType The federated output type of the 
parent plan.
+                * @return The conditional network transfer cost.
                 */
                public double getCondNetTransferCost(FederatedOutput 
parentFedOutType) {
                        if (parentFedOutType == getFedOutType()) return 0;
-                       return fedPlanVariants.netTransferCost;
+                       return fedPlanVariants.hopCommon.netTransferCost;
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
new file mode 100644
index 0000000000..f7b3343a98
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class FederatedMemoTablePrinter {
+       /**
+        * Recursively prints a tree representation of the DAG starting from 
the given root FedPlan.
+        * Includes information about hopID, fedOutType, TotalCost, SelfCost, 
and NetCost for each node.
+        * Additionally, prints the additional total cost once at the beginning.
+        *
+        * @param rootFedPlan The starting point FedPlan to print
+        * @param memoTable The memoization table containing FedPlan variants
+        * @param additionalTotalCost The additional cost to be printed once
+        */
+       public static void printFedPlanTree(FederatedMemoTable.FedPlan 
rootFedPlan, FederatedMemoTable memoTable,
+                                                                               
double additionalTotalCost) {
+               System.out.println("Additional Cost: " + additionalTotalCost);
+               Set<FederatedMemoTable.FedPlan> visited = new HashSet<>();
+               printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
+       }
+
+       /**
+        * Helper method to recursively print the FedPlan tree.
+        *
+        * @param plan  The current FedPlan to print
+        * @param visited Set to keep track of visited FedPlans (prevents 
cycles)
+        * @param depth   The current depth level for indentation
+        */
+       private static void 
printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable 
memoTable,
+                                                                               
   Set<FederatedMemoTable.FedPlan> visited, int depth) {
+               if (plan == null || visited.contains(plan)) {
+                       return;
+               }
+
+               visited.add(plan);
+
+               Hop hop = plan.getHopRef();
+               StringBuilder sb = new StringBuilder();
+
+               // Add FedPlan information
+               sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
+                               .append(plan.getHopRef().getOpString())
+                               .append(" [")
+                               .append(plan.getFedOutType())
+                               .append("]");
+
+               StringBuilder childs = new StringBuilder();
+               childs.append(" (");
+               boolean childAdded = false;
+               for( Hop input : hop.getInput()){
+                       childs.append(childAdded?",":"");
+                       childs.append(input.getHopID());
+                       childAdded = true;
+               }
+               childs.append(")");
+               if( childAdded )
+                       sb.append(childs.toString());
+
+
+               sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
+                               plan.getTotalCost(),
+                               plan.getSelfCost(),
+                               plan.getNetTransferCost()));
+
+               // Add matrix characteristics
+               sb.append(" [")
+                               .append(hop.getDim1()).append(", ")
+                               .append(hop.getDim2()).append(", ")
+                               .append(hop.getBlocksize()).append(", ")
+                               .append(hop.getNnz());
+
+               if (hop.getUpdateType().isInPlace()) {
+                       sb.append(", 
").append(hop.getUpdateType().toString().toLowerCase());
+               }
+               sb.append("]");
+
+               // Add memory estimates
+               sb.append(" [")
+                               
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
+                               
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
+                               
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
+                               
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
+
+               // Add reblock and checkpoint requirements
+               if (hop.requiresReblock() && hop.requiresCheckpoint()) {
+                       sb.append(" [rblk, chkpt]");
+               } else if (hop.requiresReblock()) {
+                       sb.append(" [rblk]");
+               } else if (hop.requiresCheckpoint()) {
+                       sb.append(" [chkpt]");
+               }
+
+               // Add execution type
+               if (hop.getExecType() != null) {
+                       sb.append(", ").append(hop.getExecType());
+               }
+
+               System.out.println(sb);
+
+               // Process child nodes
+               List<Pair<Long, FEDInstruction.FederatedOutput>> 
childFedPlanPairs = plan.getChildFedPlans();
+               for (int i = 0; i < childFedPlanPairs.size(); i++) {
+                       Pair<Long, FEDInstruction.FederatedOutput> 
childFedPlanPair = childFedPlanPairs.get(i);
+                       FederatedMemoTable.FedPlanVariants childVariants = 
memoTable.getFedPlanVariants(childFedPlanPair);
+                       if (childVariants == null || childVariants.isEmpty())
+                               continue;
+
+                       for (FederatedMemoTable.FedPlan childPlan : 
childVariants.getFedPlanVariants()) {
+                               printFedPlanTreeRecursive(childPlan, memoTable, 
visited, depth + 1);
+                       }
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
index 73e8d5d693..be1cfa7cdf 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
@@ -20,10 +20,14 @@
 package org.apache.sysds.hops.fedplanner;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.Objects;
+import java.util.LinkedHashMap;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
 import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
@@ -36,12 +40,13 @@ import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
  */
 public class FederatedPlanCostEnumerator {
        /**
-        * Entry point for federated plan enumeration. Creates a memo table and 
returns
-        * the minimum cost plan for the entire DAG.
+        * Entry point for federated plan enumeration. This method creates a 
memo table
+        * and returns the minimum cost plan for the entire Directed Acyclic 
Graph (DAG).
+        * It also resolves conflicts where FedPlans have different 
FederatedOutput types.
         * 
-        * @param rootHop ?
-        * @param printTree ?
-        * @return ?
+        * @param rootHop The root Hop node from which to start the plan 
enumeration.
+        * @param printTree A boolean flag indicating whether to print the 
federated plan tree.
+        * @return The optimal FedPlan with the minimum cost for the entire DAG.
         */
        public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean 
printTree) {
                // Create new memo table to store all plan variants
@@ -52,8 +57,12 @@ public class FederatedPlanCostEnumerator {
 
                // Return the minimum cost plan for the root node
                FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), 
memoTable);
-               memoTable.pruneMemoTable();
-               if (printTree) memoTable.printFedPlanTree(optimalPlan);
+
+               // Detect and resolve conflicts where the same child hop is referenced with different FederatedOutput types by different parent plans
+               double additionalTotalCost = 
detectAndResolveConflictFedPlan(optimalPlan, memoTable);
+
+               // Optionally print the federated plan tree if requested
+               if (printTree) 
FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, 
additionalTotalCost);
 
                return optimalPlan;
        }
@@ -106,6 +115,10 @@ public class FederatedPlanCostEnumerator {
                        FedPlan lOutPlan = memoTable.addFedPlan(hop, 
FederatedOutput.LOUT, planChilds);
                        
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
                }
+
+               // Prune MemoTable for hop.
+               memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT);
+               memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT);
        }
 
        /**
@@ -133,4 +146,89 @@ public class FederatedPlanCostEnumerator {
                }
                return minlOutFedPlan;
        }
+
+       /**
+        * Detects and resolves conflicts in federated plans starting from the 
root plan.
+        * This function performs a breadth-first search (BFS) to traverse the 
federated plan tree.
+        * It identifies conflicts where the same plan ID has different 
federated output types.
+        * For each conflict, it records the plan ID and its conflicting parent 
plans.
+        * The function ensures that each plan ID is associated with a 
consistent federated output type
+        * by resolving these conflicts iteratively.
+        *
+        * The process involves:
+        * - Using a map to track conflicts, associating each plan ID with its 
federated output type
+        *   and a list of parent plans.
+        * - Storing detected conflicts in a linked map, each entry containing 
a plan ID and its
+        *   conflicting parent plans.
+        * - Performing BFS traversal starting from the root plan, checking 
each child plan for conflicts.
+        * - If a conflict is detected (i.e., a plan ID has different output 
types), the conflicting plan
+        *   is removed from the BFS queue and added to the conflict map to 
prevent duplicate calculations.
+        * - Resolving conflicts by ensuring a consistent federated output type 
across the plan.
+        * - Re-running BFS with resolved conflicts to ensure all 
inconsistencies are addressed.
+        *
+        * @param rootPlan The root federated plan from which to start the 
conflict detection.
+        * @param memoTable The memoization table used to retrieve pruned 
federated plans.
+        * @return The cumulative additional cost for resolving conflicts.
+        */
+       private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, 
FederatedMemoTable memoTable) {
+               // Map to track conflicts: maps a plan ID to its federated 
output type and list of parent plans
+               Map<Long, Pair<FederatedOutput, List<FedPlan>>> 
conflictCheckMap = new HashMap<>();
+
+               // LinkedMap to store detected conflicts, each with a plan ID 
and its conflicting parent plans
+               LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new 
LinkedHashMap<>();
+
+               // LinkedHashMap used as an insertion-ordered BFS queue starting from the root plan (the Boolean values are unused)
+               LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new 
LinkedHashMap<>();
+               bfsLinkedMap.put(rootPlan, true);
+
+               // Array to store cumulative additional cost for resolving 
conflicts
+               double[] cumulativeAdditionalCost = new double[]{0.0};
+
+               while (!bfsLinkedMap.isEmpty()) {
+                       // Perform BFS to detect conflicts in federated plans
+                       while (!bfsLinkedMap.isEmpty()) {
+                               FedPlan currentPlan = 
bfsLinkedMap.keySet().iterator().next();
+                               bfsLinkedMap.remove(currentPlan);
+
+                               // Iterate over each child plan of the current 
plan
+                               for (Pair<Long, FederatedOutput> childPlanPair 
: currentPlan.getChildFedPlans()) {
+                                       FedPlan childFedPlan = 
memoTable.getFedPlanAfterPrune(childPlanPair);
+
+                                       // Check if the child plan ID is 
already visited
+                                       if 
(conflictCheckMap.containsKey(childPlanPair.getLeft())) {
+                                               // Retrieve the existing 
conflict pair for the child plan
+                                               Pair<FederatedOutput, 
List<FedPlan>> conflictChildPlanPair = 
conflictCheckMap.get(childPlanPair.getLeft());
+                                               // Add the current plan to the 
list of parent plans
+                                               
conflictChildPlanPair.getRight().add(currentPlan);
+
+                                               // If the federated output type 
differs, a conflict is detected
+                                               if 
(conflictChildPlanPair.getLeft() != childPlanPair.getRight()) {
+                                                       // If this is the first 
detection, remove conflictChildFedPlan from the BFS queue and add it to the 
conflict linked map (queue)
+                                                       // If the existing 
FedPlan is not removed from the bfsqueue or both actions are performed, 
duplicate calculations for the same FedPlan and its children occur
+                                                       if 
(!conflictLinkedMap.containsKey(childPlanPair.getLeft())) {
+                                                               
conflictLinkedMap.put(childPlanPair.getLeft(), 
conflictChildPlanPair.getRight());
+                                                               
bfsLinkedMap.remove(childFedPlan);
+                                                       }
+                                               }
+                                       } else {
+                                               // First visit of this child plan ID: create a new entry in the conflict check map
+                                               List<FedPlan> parentFedPlanList 
= new ArrayList<>();
+                                               
parentFedPlanList.add(currentPlan);
+
+                                               // Map the child plan ID to its 
output type and list of parent plans
+                                               
conflictCheckMap.put(childPlanPair.getLeft(), new 
ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList));
+                                               // Add the child plan to the 
BFS queue
+                                               bfsLinkedMap.put(childFedPlan, 
true);
+                                       }
+                               }
+                       }
+                       // Resolve these conflicts to ensure a consistent 
federated output type across the plan
+                       // Re-run BFS with resolved conflicts
+                       bfsLinkedMap = 
FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, 
cumulativeAdditionalCost);
+                       conflictLinkedMap.clear();
+               }
+
+               // Return the cumulative additional cost for resolving conflicts
+               return cumulativeAdditionalCost[0];
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
index a716c3321d..7bc7339563 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -24,6 +24,11 @@ import org.apache.sysds.hops.cost.ComputeCost;
 import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 
+import java.util.LinkedHashMap;
+import java.util.NoSuchElementException;
+import java.util.List;
+import java.util.Map;
+
 /**
  * Cost estimator for federated execution plans.
  * Calculates computation, memory access, and network transfer costs for 
federated operations.
@@ -47,7 +52,7 @@ public class FederatedPlanCostEstimator {
         * @param memoTable Table containing all plan variants
         */
        public static void computeFederatedPlanCost(FedPlan currentPlan, 
FederatedMemoTable memoTable) {
-               double totalCost = 0;
+               double totalCost;
                Hop currentHop = currentPlan.getHopRef();
 
                // Step 1: Calculate current node costs if not already computed
@@ -62,11 +67,11 @@ public class FederatedPlanCostEstimator {
                }
                
                // Step 2: Process each child plan and add their costs
-               for (Pair<Long, FederatedOutput> planRefMeta : 
currentPlan.getChildFedPlans()) {
+               for (Pair<Long, FederatedOutput> childPlanPair : 
currentPlan.getChildFedPlans()) {
                        // Find minimum cost child plan considering federation 
type compatibility
                        // Note: This approach might lead to suboptimal or 
wrong solutions when a child has multiple parents
                        // because we're selecting child plans independently 
for each parent
-                       FedPlan planRef = 
memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
+                       FedPlan planRef = 
memoTable.getMinCostFedPlan(childPlanPair);
 
                        // Add child plan cost (includes network transfer cost 
if federation types differ)
                        totalCost += planRef.getTotalCost() + 
planRef.getCondNetTransferCost(currentPlan.getFedOutType());
@@ -76,6 +81,125 @@ public class FederatedPlanCostEstimator {
                currentPlan.setTotalCost(totalCost);
        }
 
+       /**
+        * Resolves conflicts in federated plans where different plans have 
different FederatedOutput types.
+        * This function visits the conflicting hops in the iteration order of 
the given map and selects a single
+        * federated output type for each of them, allowing for consistent federated 
output types across the plan.
+        * It calculates additional costs for each potential resolution and 
updates the cumulative additional cost.
+        *
+        * @param memoTable The FederatedMemoTable containing all federated 
plan variants.
+        * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent 
plans with conflicting federated outputs.
+        * @param cumulativeAdditionalCost An array to store the cumulative 
additional cost incurred by resolving conflicts.
+        * @return A LinkedHashMap holding the selected federated plan of each 
conflicting hop, each mapped to true.
+        */
+       public static LinkedHashMap<FedPlan, Boolean> 
resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, 
List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) {
+               // LinkedHashMap to store resolved federated plans for BFS 
traversal.
+               LinkedHashMap<FedPlan, Boolean> resolvedFedPlanLinkedMap = new 
LinkedHashMap<>();
+
+               // Traverse the conflict map in its iteration order after BFS 
to resolve conflicts
+               for (Map.Entry<Long, List<FedPlan>> conflictFedPlanPair : 
conflictFedPlanLinkedMap.entrySet()) {
+                       long conflictHopID = conflictFedPlanPair.getKey();
+                       List<FedPlan> conflictParentFedPlans = 
conflictFedPlanPair.getValue();
+
+                       // Retrieve the conflicting federated plans for LOUT 
and FOUT types
+                       FedPlan confilctLOutFedPlan = 
memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT);
+                       FedPlan confilctFOutFedPlan = 
memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT);
+
+                       // Variables to store additional costs for LOUT and 
FOUT types
+                       double lOutAdditionalCost = 0;
+                       double fOutAdditionalCost = 0;
+
+                       // Flags to check if the plan involves network transfer
+                       // Network transfer cost is calculated only once, even 
if it occurs multiple times
+                       boolean isLOutNetTransfer = false;
+                       boolean isFOutNetTransfer = false; 
+
+                       // Holds the federated output type selected for this hop 
once both additional costs are compared
+                       FederatedOutput optimalFedOutType;
+
+                       // Iterate over each parent federated plan in the 
current conflict pair
+                       for (FedPlan conflictParentFedPlan : 
conflictParentFedPlans) {
+                               // Find the calculated FedOutType of the child 
plan
+                               Pair<Long, FederatedOutput> 
cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream()
+                                       .filter(pair -> 
pair.getLeft().equals(conflictHopID))
+                                       .findFirst()
+                                       .orElseThrow(() -> new 
NoSuchElementException("No matching pair found for ID: " + conflictHopID));
+                                                       
+                               // CASE 1. Calculated LOUT / Parent LOUT / 
Current LOUT: Total cost remains unchanged.
+                               // CASE 2. Calculated LOUT / Parent FOUT / 
Current LOUT: Total cost remains unchanged, subtract net cost, add net cost 
later.
+                               // CASE 3. Calculated FOUT / Parent LOUT / 
Current LOUT: Change total cost, subtract net cost.
+                               // CASE 4. Calculated FOUT / Parent FOUT / 
Current LOUT: Change total cost, add net cost later.
+                               // CASE 5. Calculated LOUT / Parent LOUT / 
Current FOUT: Change total cost, add net cost later.
+                               // CASE 6. Calculated LOUT / Parent FOUT / 
Current FOUT: Change total cost, subtract net cost.
+                               // CASE 7. Calculated FOUT / Parent LOUT / 
Current FOUT: Total cost remains unchanged, subtract net cost, add net cost 
later.
+                               // CASE 8. Calculated FOUT / Parent FOUT / 
Current FOUT: Total cost remains unchanged.
+                               
+                               // Adjust LOUT, FOUT costs based on the 
calculated plan's output type
+                               if (cacluatedConflictPlanPair.getRight() == 
FederatedOutput.LOUT) {
+                                       // When changing from calculated LOUT 
to current FOUT, subtract the existing LOUT total cost and add the FOUT total 
cost
+                                       // When maintaining calculated LOUT to 
current LOUT, the total cost remains unchanged.
+                                       fOutAdditionalCost += 
confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost();
+
+                                       if 
(conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) {
+                                               // (CASE 1) Previously, 
calculated was LOUT and parent was LOUT, so no network transfer cost occurred
+                                               // (CASE 5) If changing from 
calculated LOUT to current FOUT, network transfer cost occurs, but calculated 
later
+                                               isFOutNetTransfer = true;
+                                       } else {
+                                               // Previously, calculated was 
LOUT and parent was FOUT, so network transfer cost occurred
+                                               // (CASE 2) If maintaining calculated LOUT to current 
LOUT, subtract the existing network transfer cost now; it is re-added once later
+                                               isLOutNetTransfer = true;
+                                               lOutAdditionalCost -= 
confilctLOutFedPlan.getNetTransferCost();
+
+                                               // (CASE 6) If changing from 
calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it
+                                               fOutAdditionalCost -= 
confilctLOutFedPlan.getNetTransferCost();
+                                       }
+                               } else {
+                                       lOutAdditionalCost += 
confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost();
+
+                                       if 
(conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) {
+                                               isLOutNetTransfer = true;
+                                       } else {
+                                               isFOutNetTransfer = true;
+                                               lOutAdditionalCost -= 
confilctLOutFedPlan.getNetTransferCost();
+                                               fOutAdditionalCost -= 
confilctLOutFedPlan.getNetTransferCost();
+                                       }
+                               }
+                       }
+
+                       // Add network transfer costs if applicable
+                       if (isLOutNetTransfer) {
+                               lOutAdditionalCost += 
confilctLOutFedPlan.getNetTransferCost();
+                       }
+                       if (isFOutNetTransfer) {
+                               fOutAdditionalCost += 
confilctFOutFedPlan.getNetTransferCost();
+                       }
+
+                       // Determine the optimal federated output type based on 
the calculated costs
+                       if (lOutAdditionalCost <= fOutAdditionalCost) {
+                               optimalFedOutType = FederatedOutput.LOUT;
+                               cumulativeAdditionalCost[0] += 
lOutAdditionalCost;
+                               
resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true);
+                       } else {
+                               optimalFedOutType = FederatedOutput.FOUT;
+                               cumulativeAdditionalCost[0] += 
fOutAdditionalCost;
+                               
resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true);
+                       }    
+
+                       // Update only the child output type stored in each parent plan; 
costs are not recomputed and the change is not propagated recursively
+                       for (FedPlan conflictParentFedPlan : 
conflictParentFedPlans) {
+                               for (Pair<Long, FederatedOutput> childPlanPair 
: conflictParentFedPlan.getChildFedPlans()) {
+                                       if (childPlanPair.getLeft() == 
conflictHopID && childPlanPair.getRight() != optimalFedOutType) {
+                                               int index = 
conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair);
+                                               
conflictParentFedPlan.getChildFedPlans().set(index, 
+                                                       
Pair.of(childPlanPair.getLeft(), optimalFedOutType));
+                                               break;
+                                       }
+                               }
+                       }
+               }
+               return resolvedFedPlanLinkedMap;
+       }
+       
        /**
         * Computes the cost for the current Hop node.
         * 
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
index 56de8cf3c4..20485588d3 100644
--- 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
@@ -39,7 +39,7 @@ import 
org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
 
 public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
 {
-       private static final String TEST_DIR = "functions/federated/";
+       private static final String TEST_DIR = "functions/federated/privacy/";
        private static final String HOME = SCRIPT_DIR + TEST_DIR;
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
        
@@ -47,8 +47,15 @@ public class FederatedPlanCostEnumeratorTest extends 
AutomatedTestBase
        public void setUp() {}
        
        @Test
-       public void testDependencyAnalysis1() { runTest("cost.dml"); }
-       
+       public void testFederatedPlanCostEnumerator1() { 
runTest("FederatedPlanCostEnumeratorTest1.dml"); }
+
+       @Test
+       public void testFederatedPlanCostEnumerator2() { 
runTest("FederatedPlanCostEnumeratorTest2.dml"); }
+
+       @Test
+       public void testFederatedPlanCostEnumerator3() { 
runTest("FederatedPlanCostEnumeratorTest3.dml"); }
+
+       // TODO: Write test scripts for the federated version
        private void runTest( String scriptFilename ) {
                int index = scriptFilename.lastIndexOf(".dml");
                String testName = scriptFilename.substring(0, index > 0 ? index 
: scriptFilename.length());
@@ -72,10 +79,7 @@ public class FederatedPlanCostEnumeratorTest extends 
AutomatedTestBase
                        dmlt.constructHops(prog);
                        dmlt.rewriteHopsDAG(prog);
                        dmlt.constructLops(prog);
-                       /* TODO) In the current DAG, Hop's _outputMemEstimate 
is not initialized
-                       // This leads to incorrect fedplan generation, so test 
code needs to be modified
-                       // If needed, modify costEstimator to handle cases 
where _outputMemEstimate is not initialized
-                       */
+
                        Hop hops = 
prog.getStatementBlocks().get(0).getHops().get(0);
                        
FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true);
                }
diff --git a/src/test/scripts/functions/federated/cost.dml 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest1.dml
similarity index 100%
copy from src/test/scripts/functions/federated/cost.dml
copy to 
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest1.dml
diff --git a/src/test/scripts/functions/federated/cost.dml 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
similarity index 95%
copy from src/test/scripts/functions/federated/cost.dml
copy to 
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
index ec34d45bb6..3cc07eeb01 100644
--- a/src/test/scripts/functions/federated/cost.dml
+++ 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
@@ -21,5 +21,6 @@
 
 a = matrix(7,10,10);
 b = a + a^2;
-c = sqrt(b);
-print(sum(c));
\ No newline at end of file
+c = a * b;
+d = b + sqrt(c);
+print(sum(d));
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/cost.dml 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
similarity index 93%
rename from src/test/scripts/functions/federated/cost.dml
rename to 
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
index ec34d45bb6..7fe002df75 100644
--- a/src/test/scripts/functions/federated/cost.dml
+++ 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
@@ -20,6 +20,9 @@
 #-------------------------------------------------------------
 
 a = matrix(7,10,10);
-b = a + a^2;
+if (sum(a) > 0.5)
+    b = a + a^2;
+else
+    b = a * a;
 c = sqrt(b);
 print(sum(c));
\ No newline at end of file


Reply via email to