This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 693ef52fa8 [SYSTEMDS-3790] Extended optimizer for federated execution
plans
693ef52fa8 is described below
commit 693ef52fa8b137aaf91bf45cee57cc895bce5aef
Author: min-guk <[email protected]>
AuthorDate: Fri Apr 18 11:52:42 2025 +0200
[SYSTEMDS-3790] Extended optimizer for federated execution plans
Closes #2238.
---
scripts/staging/fedplanner/graph.py | 268 ++++++++
.../sysds/hops/fedplanner/FederatedMemoTable.java | 173 ++---
.../hops/fedplanner/FederatedMemoTablePrinter.java | 153 +++--
.../fedplanner/FederatedPlanCostEnumerator.java | 757 +++++++++++++++------
.../fedplanner/FederatedPlanCostEstimator.java | 466 +++++++------
.../federated/FederatedPlanCostEnumeratorTest.java | 157 +++--
.../component/federated/FederatedPlanVisualizer.py | 268 ++++++++
.../privacy/FederatedPlanCostEnumeratorTest10.dml | 33 +
.../privacy/FederatedPlanCostEnumeratorTest4.dml | 28 +
.../privacy/FederatedPlanCostEnumeratorTest5.dml | 26 +
.../privacy/FederatedPlanCostEnumeratorTest6.dml | 34 +
.../privacy/FederatedPlanCostEnumeratorTest7.dml | 28 +
.../privacy/FederatedPlanCostEnumeratorTest8.dml | 49 ++
.../privacy/FederatedPlanCostEnumeratorTest9.dml | 58 ++
14 files changed, 1842 insertions(+), 656 deletions(-)
diff --git a/scripts/staging/fedplanner/graph.py
b/scripts/staging/fedplanner/graph.py
new file mode 100644
index 0000000000..b083c77913
--- /dev/null
+++ b/scripts/staging/fedplanner/graph.py
@@ -0,0 +1,268 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import sys
+import re
+import networkx as nx
+import matplotlib.pyplot as plt
+
+try:
+ import pygraphviz
+ from networkx.drawing.nx_agraph import graphviz_layout
+ HAS_PYGRAPHVIZ = True
+except ImportError:
+ HAS_PYGRAPHVIZ = False
+ print("[WARNING] pygraphviz not found. Please install via 'pip install
pygraphviz'.\n"
+ "If not installed, we will use an alternative layout
(spring_layout).")
+
+
+def parse_line(line: str):
+    """
+    Parse a single line of the federated plan trace.
+
+    The expected layout is
+        (<id>) <hop name> [<kind>] (<ref>,<ref>,...) {Total: .., Weight: ..} ...
+    where the reference list and the cost block are optional.
+
+    Args:
+        line: One line of the trace file.
+
+    Returns:
+        None if the line does not start with "(R)" or "(<number>)". Otherwise
+        a dict with the string-valued keys 'node_id', 'operation', 'kind',
+        'total' and 'weight' (empty string when absent) and 'refs', the list
+        of referenced node IDs as strings.
+    """
+
+    # 1) Node ID: "(R)" for the root or "(<number>)" for a hop
+    match_id = re.match(r'^\((R|\d+)\)', line)
+    if not match_id:
+        return None
+    node_id = match_id.group(1)
+
+    # 2) The remaining string after the node ID
+    after_id = line[match_id.end():].strip()
+
+    # Operation (hop name) is everything before the first "[", or the whole
+    # remainder if the line contains no bracket at all
+    match_label = re.search(r'^(.*?)\s*\[', after_id)
+    operation = match_label.group(1).strip() if match_label else after_id.strip()
+
+    # 3) Kind: content inside the first pair of brackets "[]"
+    match_bracket = re.search(r'\[([^\]]+)\]', after_id)
+    kind = match_bracket.group(1).strip() if match_bracket else ""
+
+    # 4) Total and weight from the content inside the first "{...}"
+    total = ""
+    weight = ""
+    match_curly = re.search(r'\{([^}]+)\}', after_id)
+    if match_curly:
+        curly_content = match_curly.group(1)
+        m_total = re.search(r'Total:\s*([\d\.]+)', curly_content)
+        m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content)
+        if m_total:
+            total = m_total.group(1)
+        if m_weight:
+            weight = m_weight.group(1)
+
+    # 5) Reference nodes: restrict the search to the text between "[kind]" and
+    #    the cost block, so that a parenthesised number inside the hop name
+    #    (e.g. "f(2)") is not mistaken for the reference list.
+    refs_end = match_curly.start() if match_curly else len(after_id)
+    if match_bracket and match_bracket.end() <= refs_end:
+        refs_start = match_bracket.end()
+    else:
+        # No kind bracket in front of the cost block: scan from the start
+        refs_start = 0
+    match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id[refs_start:refs_end])
+    refs = match_refs.group(1).split(',') if match_refs else []
+
+    return {
+        'node_id': node_id,
+        'operation': operation,
+        'kind': kind,
+        'total': total,
+        'weight': weight,
+        'refs': refs
+    }
+
+
+def build_dag_from_file(filename: str):
+    """
+    Build a NetworkX directed graph from a federated plan trace file.
+
+    Every line that parse_line() accepts becomes a node carrying the
+    attributes 'label', 'kind', 'total' and 'weight'. For each referenced ID
+    an edge "reference -> node" is added; a referenced ID that has not been
+    seen yet is inserted as a placeholder node with empty attributes.
+
+    Args:
+        filename: Path of the UTF-8 encoded trace file.
+
+    Returns:
+        The populated nx.DiGraph.
+    """
+    dag = nx.DiGraph()
+    with open(filename, 'r', encoding='utf-8') as trace:
+        for raw in trace:
+            text = raw.strip()
+            # Blank lines and lines without a node ID carry no information
+            info = parse_line(text) if text else None
+            if not info:
+                continue
+
+            target = info['node_id']
+            dag.add_node(target, label=info['operation'], kind=info['kind'],
+                         total=info['total'], weight=info['weight'])
+
+            # Edges point from each referenced node to the current node
+            for source in info['refs']:
+                if source not in dag:
+                    dag.add_node(source, label=source, kind="", total="", weight="")
+                dag.add_edge(source, target)
+    return dag
+
+
+def main():
+    """
+    Command-line entry point.
+
+    Reads the trace filename from sys.argv[1], builds the DAG, draws it with
+    matplotlib, saves the figure as "<input path without extension>.png" and
+    finally shows it. Exits with status 1 if no filename is given.
+    """
+    # Local import: only needed here to derive script and output names
+    import os
+
+    # Get filename from command-line argument
+    if len(sys.argv) < 2:
+        # Report the script name actually invoked instead of a hard-coded one
+        script = os.path.basename(sys.argv[0]) or "graph.py"
+        print(f"[ERROR] No filename provided.\nUsage: python {script} <filename>")
+        sys.exit(1)
+    filename = sys.argv[1]
+
+    print(f"[INFO] Running with filename '{filename}'")
+
+    # Build the DAG
+    G = build_dag_from_file(filename)
+
+    # Print debug info: nodes and edges
+    print("Nodes:", G.nodes(data=True))
+    print("Edges:", list(G.edges()))
+
+    # Decide on layout
+    if HAS_PYGRAPHVIZ:
+        # graphviz_layout with rankdir=BT (bottom to top), etc.
+        pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8')
+    else:
+        # Fallback layout if pygraphviz is not installed
+        pos = nx.spring_layout(G, seed=42)
+
+    # Dynamically adjust figure size based on number of nodes
+    node_count = len(G.nodes())
+    fig_width = 10 + node_count / 10.0
+    fig_height = 6 + node_count / 10.0
+    plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
+    ax = plt.gca()
+    ax.set_facecolor('white')
+
+    # Generate labels for each node in the format:
+    #   node_id: operation_name
+    #   C<total> (W<weight>)
+    labels = {
+        n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})"
+        for n in G.nodes()
+    }
+
+    # Function to determine color based on 'kind'
+    def get_color(n):
+        k = G.nodes[n].get('kind', '').lower()
+        if k == 'fout':
+            return 'tomato'
+        elif k == 'lout':
+            return 'dodgerblue'
+        elif k == 'nref':
+            return 'mediumpurple'
+        else:
+            return 'mediumseagreen'
+
+    # Determine node shapes based on operation name:
+    #   - '^' (triangle) if the label contains "twrite"
+    #   - 's' (square) if the label contains "tread"
+    #   - 'o' (circle) otherwise
+    triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()]
+    square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()]
+    other_nodes = [
+        n for n in G.nodes()
+        if 'twrite' not in G.nodes[n].get('label', '').lower() and
+        'tread' not in G.nodes[n].get('label', '').lower()
+    ]
+
+    # Colors for each group
+    triangle_colors = [get_color(n) for n in triangle_nodes]
+    square_colors = [get_color(n) for n in square_nodes]
+    other_colors = [get_color(n) for n in other_nodes]
+
+    # Draw nodes group-wise
+    node_collection_triangle = nx.draw_networkx_nodes(
+        G, pos, nodelist=triangle_nodes, node_size=800,
+        node_color=triangle_colors, node_shape='^', ax=ax
+    )
+    node_collection_square = nx.draw_networkx_nodes(
+        G, pos, nodelist=square_nodes, node_size=800,
+        node_color=square_colors, node_shape='s', ax=ax
+    )
+    node_collection_other = nx.draw_networkx_nodes(
+        G, pos, nodelist=other_nodes, node_size=800,
+        node_color=other_colors, node_shape='o', ax=ax
+    )
+
+    # Set z-order so that nodes are below edges, and edges below labels
+    node_collection_triangle.set_zorder(1)
+    node_collection_square.set_zorder(1)
+    node_collection_other.set_zorder(1)
+
+    edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax)
+    # The return value is handled both as a list of artists and as a single one
+    if isinstance(edge_collection, list):
+        for ec in edge_collection:
+            ec.set_zorder(2)
+    else:
+        edge_collection.set_zorder(2)
+
+    label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax)
+    for text in label_dict.values():
+        text.set_zorder(3)
+
+    # Set the title
+    plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")
+
+    # Explain the label format in the top-right corner
+    plt.text(1, 1,
+             "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
+             fontsize=12, ha='right', va='top', transform=ax.transAxes)
+
+    # Mini-legend for the different 'kind' values in the top-left corner
+    plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes)
+    plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes)
+    plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes)
+
+    plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes)
+    plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes)
+    plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes)
+
+    plt.axis("off")
+
+    # Save the plot next to the input file, replacing only a real file
+    # extension with ".png". splitext ignores dots in directory names and a
+    # leading dot, which a plain split on '.' would cut at.
+    output_filename = f"{os.path.splitext(filename)[0]}.png"
+    plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')
+
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
index b2b58871f6..b35723b817 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
@@ -19,15 +19,15 @@
package org.apache.sysds.hops.fedplanner;
-import org.apache.sysds.hops.Hop;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
+import org.apache.sysds.hops.Hop;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
/**
* A Memoization Table for managing federated plans (FedPlan) based on
combinations of Hops and fedOutTypes.
@@ -38,48 +38,8 @@ public class FederatedMemoTable {
// Maps Hop ID and fedOutType pairs to their plan variants
private final Map<Pair<Long, FederatedOutput>, FedPlanVariants>
hopMemoTable = new HashMap<>();
- /**
- * Adds a new federated plan to the memo table.
- * Creates a new variant list if none exists for the given Hop and
fedOutType.
- *
- * @param hop The Hop node
- * @param fedOutType The federated output type
- * @param planChilds List of child plan references
- * @return The newly created FedPlan
- */
- public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType,
List<Pair<Long, FederatedOutput>> planChilds) {
- long hopID = hop.getHopID();
- FedPlanVariants fedPlanVariantList;
-
- if (contains(hopID, fedOutType)) {
- fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(hopID, fedOutType));
- } else {
- fedPlanVariantList = new FedPlanVariants(hop,
fedOutType);
- hopMemoTable.put(new ImmutablePair<>(hopID,
fedOutType), fedPlanVariantList);
- }
-
- FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
- fedPlanVariantList.addFedPlan(newPlan);
-
- return newPlan;
- }
-
- /**
- * Retrieves the minimum cost child plan considering the parent's
output type.
- * The cost is calculated using getParentViewCost to account for
potential type mismatches.
- *
- * @param fedPlanPair ???
- * @return min cost fed plan
- */
- public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput>
fedPlanPair) {
- FedPlanVariants fedPlanVariantList =
hopMemoTable.get(fedPlanPair);
- return fedPlanVariantList._fedPlanVariants.stream()
-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
- .orElse(null);
- }
-
- public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput
fedOutType) {
- return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
+ public void addFedPlanVariants(long hopID, FederatedOutput fedOutType,
FedPlanVariants fedPlanVariants) {
+ hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType),
fedPlanVariants);
}
public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput>
fedPlanPair) {
@@ -87,53 +47,47 @@ public class FederatedMemoTable {
}
public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput
fedOutType) {
- // Todo: Consider whether to verify if pruning has been
performed
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(hopID, fedOutType));
return fedPlanVariantList._fedPlanVariants.get(0);
}
public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput>
fedPlanPair) {
- // Todo: Consider whether to verify if pruning has been
performed
FedPlanVariants fedPlanVariantList =
hopMemoTable.get(fedPlanPair);
return fedPlanVariantList._fedPlanVariants.get(0);
}
- /**
- * Checks if the memo table contains an entry for a given Hop and
fedOutType.
- *
- * @param hopID The Hop ID.
- * @param fedOutType The associated fedOutType.
- * @return True if the entry exists, false otherwise.
- */
public boolean contains(long hopID, FederatedOutput fedOutType) {
return hopMemoTable.containsKey(new ImmutablePair<>(hopID,
fedOutType));
}
/**
- * Prunes the specified entry in the memo table, retaining only the
minimum-cost
- * FedPlan for the given Hop ID and federated output type.
- *
- * @param hopID The ID of the Hop to prune
- * @param federatedOutput The federated output type associated with the
Hop
- */
- public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
- hopMemoTable.get(new ImmutablePair<>(hopID,
federatedOutput)).prune();
- }
-
- /**
- * Represents common properties and costs associated with a Hop.
- * This class holds a reference to the Hop and tracks its execution and
network transfer costs.
+ * Represents a single federated execution plan with its associated
costs and dependencies.
+ * This class contains:
+ * 1. selfCost: Cost of the current hop (computation + input/output
memory access).
+ * 2. cumulativeCost: Total cost including this plan's selfCost and all
child plans' cumulativeCost.
+ * 3. forwardingCost: Network transfer cost for this plan to the parent
plan.
+ *
+ * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon
to manage common properties and costs.
*/
- public static class HopCommon {
- protected final Hop hopRef; // Reference to the
associated Hop
- protected double selfCost; // Current execution cost
(compute + memory access)
- protected double netTransferCost; // Network transfer cost
+ public static class FedPlan {
+ private double cumulativeCost; // Total cost =
sum of selfCost + cumulativeCost of child plans
+ private final FedPlanVariants fedPlanVariants; // Reference to
variant list
+ private final List<Pair<Long, FederatedOutput>> childFedPlans;
// Child plan references
- protected HopCommon(Hop hopRef) {
- this.hopRef = hopRef;
- this.selfCost = 0;
- this.netTransferCost = 0;
+ public FedPlan(double cumulativeCost, FedPlanVariants
fedPlanVariants, List<Pair<Long, FederatedOutput>> childFedPlans) {
+ this.cumulativeCost = cumulativeCost;
+ this.fedPlanVariants = fedPlanVariants;
+ this.childFedPlans = childFedPlans;
}
+
+ public Hop getHopRef() {return
fedPlanVariants.hopCommon.getHopRef();}
+ public long getHopID() {return
fedPlanVariants.hopCommon.getHopRef().getHopID();}
+ public FederatedOutput getFedOutType() {return
fedPlanVariants.getFedOutType();}
+ public double getCumulativeCost() {return cumulativeCost;}
+ public double getSelfCost() {return
fedPlanVariants.hopCommon.getSelfCost();}
+ public double getForwardingCost() {return
fedPlanVariants.hopCommon.getForwardingCost();}
+ public double getWeight() {return
fedPlanVariants.hopCommon.getWeight();}
+ public List<Pair<Long, FederatedOutput>> getChildFedPlans()
{return childFedPlans;}
}
/**
@@ -146,21 +100,22 @@ public class FederatedMemoTable {
private final FederatedOutput fedOutType; // Output type
(FOUT/LOUT)
protected List<FedPlan> _fedPlanVariants; // List of plan
variants
- public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
- this.hopCommon = new HopCommon(hopRef);
+ public FedPlanVariants(HopCommon hopCommon, FederatedOutput
fedOutType) {
+ this.hopCommon = hopCommon;
this.fedOutType = fedOutType;
this._fedPlanVariants = new ArrayList<>();
}
+ public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
public void addFedPlan(FedPlan fedPlan)
{_fedPlanVariants.add(fedPlan);}
public List<FedPlan> getFedPlanVariants() {return
_fedPlanVariants;}
- public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
+ public FederatedOutput getFedOutType() {return fedOutType;}
- public void prune() {
+ public void pruneFedPlans() {
if (_fedPlanVariants.size() > 1) {
- // Find the FedPlan with the minimum cost
+ // Find the FedPlan with the minimum cumulative
cost
FedPlan minCostPlan = _fedPlanVariants.stream()
-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+
.min(Comparator.comparingDouble(FedPlan::getCumulativeCost))
.orElse(null);
// Retain only the minimum cost plan
@@ -171,46 +126,28 @@ public class FederatedMemoTable {
}
/**
- * Represents a single federated execution plan with its associated
costs and dependencies.
- * This class contains:
- * 1. selfCost: Cost of current hop (compute + input/output memory
access)
- * 2. totalCost: Cumulative cost including this plan and all child plans
- * 3. netTransferCost: Network transfer cost for this plan to parent
plan.
- *
- * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon
to manage common properties and costs.
+ * Represents common properties and costs associated with a Hop.
+ * This class holds a reference to the Hop and tracks its execution and
network forwarding (transfer) costs.
*/
- public static class FedPlan {
- private double totalCost; // Total cost
including child plans
- private final FedPlanVariants fedPlanVariants; // Reference to
variant list
- private final List<Pair<Long, FederatedOutput>> childFedPlans;
// Child plan references
+ public static class HopCommon {
+ protected final Hop hopRef; // Reference to the associated Hop
+ protected double selfCost; // Cost of the hop's computation and
memory access
+ protected double forwardingCost; // Cost of forwarding the
hop's output to its parent
+ protected double weight; // Weight used to calculate cost based
on hop execution frequency
- public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans,
FedPlanVariants fedPlanVariants) {
- this.totalCost = 0;
- this.childFedPlans = childFedPlans;
- this.fedPlanVariants = fedPlanVariants;
+ public HopCommon(Hop hopRef, double weight) {
+ this.hopRef = hopRef;
+ this.selfCost = 0;
+ this.forwardingCost = 0;
+ this.weight = weight;
}
- public void setTotalCost(double totalCost) {this.totalCost =
totalCost;}
- public void setSelfCost(double selfCost)
{fedPlanVariants.hopCommon.selfCost = selfCost;}
- public void setNetTransferCost(double netTransferCost)
{fedPlanVariants.hopCommon.netTransferCost = netTransferCost;}
-
- public Hop getHopRef() {return
fedPlanVariants.hopCommon.hopRef;}
- public long getHopID() {return
fedPlanVariants.hopCommon.hopRef.getHopID();}
- public FederatedOutput getFedOutType() {return
fedPlanVariants.fedOutType;}
- public double getTotalCost() {return totalCost;}
- public double getSelfCost() {return
fedPlanVariants.hopCommon.selfCost;}
- public double getNetTransferCost() {return
fedPlanVariants.hopCommon.netTransferCost;}
- public List<Pair<Long, FederatedOutput>> getChildFedPlans()
{return childFedPlans;}
+ public Hop getHopRef() {return hopRef;}
+ public double getSelfCost() {return selfCost;}
+ public double getForwardingCost() {return forwardingCost;}
+ public double getWeight() {return weight;}
- /**
- * Calculates the conditional network transfer cost based on
output type compatibility.
- * Returns 0 if output types match, otherwise returns the
network transfer cost.
- * @param parentFedOutType The federated output type of the
parent plan.
- * @return The conditional network transfer cost.
- */
- public double getCondNetTransferCost(FederatedOutput
parentFedOutType) {
- if (parentFedOutType == getFedOutType()) return 0;
- return fedPlanVariants.hopCommon.netTransferCost;
- }
+ protected void setSelfCost(double selfCost) {this.selfCost =
selfCost;}
+ protected void setForwardingCost(double forwardingCost)
{this.forwardingCost = forwardingCost;}
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
index f7b3343a98..2841256607 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
@@ -1,28 +1,11 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
package org.apache.sysds.hops.fedplanner;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import java.util.HashSet;
import java.util.List;
@@ -35,14 +18,52 @@ public class FederatedMemoTablePrinter {
* Additionally, prints the additional total cost once at the beginning.
*
* @param rootFedPlan The starting point FedPlan to print
+ * @param rootHopStatSet Hops whose pruned LOUT plans are printed after the root tree as not-referenced (NRef) subtrees
* @param memoTable The memoization table containing FedPlan variants
* @param additionalTotalCost The additional cost to be printed once
*/
- public static void printFedPlanTree(FederatedMemoTable.FedPlan
rootFedPlan, FederatedMemoTable memoTable,
-
double additionalTotalCost) {
+ public static void printFedPlanTree(FederatedMemoTable.FedPlan
rootFedPlan, Set<Hop> rootHopStatSet,
+
FederatedMemoTable memoTable, double additionalTotalCost) {
System.out.println("Additional Cost: " + additionalTotalCost);
- Set<FederatedMemoTable.FedPlan> visited = new HashSet<>();
+ Set<Long> visited = new HashSet<>();
printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
+
+ for (Hop hop : rootHopStatSet) {
+ FedPlan plan =
memoTable.getFedPlanAfterPrune(hop.getHopID(), FederatedOutput.LOUT);
+ printNotReferencedFedPlanRecursive(plan, memoTable,
visited, 1);
+ }
+ }
+
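+ /**
+ * Helper method to recursively print a FedPlan subtree whose nodes are
+ * labeled as not referenced (NRef).
+ *
+ * @param plan The current FedPlan to print
+ * @param memoTable The memoization table used to look up the child FedPlan variants
+ * @param visited Set of hop IDs that were already printed (prevents printing a hop twice)
+ * @param depth The current depth level for indentation
+ */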
+ private static void
printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan,
FederatedMemoTable memoTable,
+
Set<Long> visited, int depth) {
+ long hopID = plan.getHopRef().getHopID();
+
+ if (visited.contains(hopID)) {
+ return;
+ }
+
+ visited.add(hopID);
+ printFedPlan(plan, depth, true);
+
+ // Process child nodes
+ List<Pair<Long, FEDInstruction.FederatedOutput>>
childFedPlanPairs = plan.getChildFedPlans();
+ for (int i = 0; i < childFedPlanPairs.size(); i++) {
+ Pair<Long, FEDInstruction.FederatedOutput>
childFedPlanPair = childFedPlanPairs.get(i);
+ FederatedMemoTable.FedPlanVariants childVariants =
memoTable.getFedPlanVariants(childFedPlanPair);
+ if (childVariants == null || childVariants.isEmpty())
+ continue;
+
+ for (FederatedMemoTable.FedPlan childPlan :
childVariants.getFedPlanVariants()) {
+ printNotReferencedFedPlanRecursive(childPlan,
memoTable, visited, depth + 1);
+ }
+ }
}
/**
@@ -53,40 +74,83 @@ public class FederatedMemoTablePrinter {
* @param depth The current depth level for indentation
*/
private static void
printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable
memoTable,
-
Set<FederatedMemoTable.FedPlan> visited, int depth) {
- if (plan == null || visited.contains(plan)) {
+
Set<Long> visited, int depth) {
+ long hopID = 0;
+
+ if (depth == 0) {
+ hopID = -1;
+ } else {
+ hopID = plan.getHopRef().getHopID();
+ }
+
+ if (visited.contains(hopID)) {
return;
}
- visited.add(plan);
+ visited.add(hopID);
+ printFedPlan(plan, depth, false);
- Hop hop = plan.getHopRef();
- StringBuilder sb = new StringBuilder();
+ // Process child nodes
+ List<Pair<Long, FEDInstruction.FederatedOutput>>
childFedPlanPairs = plan.getChildFedPlans();
+ for (int i = 0; i < childFedPlanPairs.size(); i++) {
+ Pair<Long, FEDInstruction.FederatedOutput>
childFedPlanPair = childFedPlanPairs.get(i);
+ FederatedMemoTable.FedPlanVariants childVariants =
memoTable.getFedPlanVariants(childFedPlanPair);
+ if (childVariants == null || childVariants.isEmpty())
+ continue;
+
+ for (FederatedMemoTable.FedPlan childPlan :
childVariants.getFedPlanVariants()) {
+ printFedPlanTreeRecursive(childPlan, memoTable,
visited, depth + 1);
+ }
+ }
+ }
- // Add FedPlan information
- sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
- .append(plan.getHopRef().getOpString())
- .append(" [")
- .append(plan.getFedOutType())
- .append("]");
+ private static void printFedPlan(FederatedMemoTable.FedPlan plan, int
depth, boolean isNotReferenced) {
+ StringBuilder sb = new StringBuilder();
+ Hop hop = null;
+
+ if (depth == 0){
+ sb.append("(R) ROOT [Root]");
+ } else {
+ hop = plan.getHopRef();
+ // Add FedPlan information
+ sb.append(String.format("(%d) ", hop.getHopID()))
+ .append(hop.getOpString())
+ .append(" [");
+
+ if (isNotReferenced) {
+ sb.append("NRef");
+ } else{
+ sb.append(plan.getFedOutType());
+ }
+ sb.append("]");
+ }
StringBuilder childs = new StringBuilder();
childs.append(" (");
+
boolean childAdded = false;
- for( Hop input : hop.getInput()){
+ for (Pair<Long, FederatedOutput> childPair :
plan.getChildFedPlans()){
childs.append(childAdded?",":"");
- childs.append(input.getHopID());
+ childs.append(childPair.getLeft());
childAdded = true;
}
+
childs.append(")");
+
if( childAdded )
sb.append(childs.toString());
+ if (depth == 0){
+ sb.append(String.format(" {Total: %.1f}",
plan.getCumulativeCost()));
+ System.out.println(sb);
+ return;
+ }
- sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
- plan.getTotalCost(),
+ sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f,
Weight: %.1f}",
+ plan.getCumulativeCost(),
plan.getSelfCost(),
- plan.getNetTransferCost()));
+ plan.getForwardingCost(),
+ plan.getWeight()));
// Add matrix characteristics
sb.append(" [")
@@ -122,18 +186,5 @@ public class FederatedMemoTablePrinter {
}
System.out.println(sb);
-
- // Process child nodes
- List<Pair<Long, FEDInstruction.FederatedOutput>>
childFedPlanPairs = plan.getChildFedPlans();
- for (int i = 0; i < childFedPlanPairs.size(); i++) {
- Pair<Long, FEDInstruction.FederatedOutput>
childFedPlanPair = childFedPlanPairs.get(i);
- FederatedMemoTable.FedPlanVariants childVariants =
memoTable.getFedPlanVariants(childFedPlanPair);
- if (childVariants == null || childVariants.isEmpty())
- continue;
-
- for (FederatedMemoTable.FedPlan childPlan :
childVariants.getFedPlanVariants()) {
- printFedPlanTreeRecursive(childPlan, memoTable,
visited, depth + 1);
- }
- }
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
index be1cfa7cdf..f3e8cc286d 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
@@ -17,218 +17,581 @@
* under the License.
*/
-package org.apache.sysds.hops.fedplanner;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Objects;
-import java.util.LinkedHashMap;
-
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
-import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
-import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
-
-/**
- * Enumerates and evaluates all possible federated execution plans for a given
Hop DAG.
- * Works with FederatedMemoTable to store plan variants and
FederatedPlanCostEstimator
- * to compute their costs.
- */
-public class FederatedPlanCostEnumerator {
- /**
- * Entry point for federated plan enumeration. This method creates a
memo table
- * and returns the minimum cost plan for the entire Directed Acyclic
Graph (DAG).
- * It also resolves conflicts where FedPlans have different
FederatedOutput types.
- *
- * @param rootHop The root Hop node from which to start the plan
enumeration.
- * @param printTree A boolean flag indicating whether to print the
federated plan tree.
- * @return The optimal FedPlan with the minimum cost for the entire DAG.
- */
- public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean
printTree) {
- // Create new memo table to store all plan variants
- FederatedMemoTable memoTable = new FederatedMemoTable();
-
- // Recursively enumerate all possible plans
- enumerateFederatedPlanCost(rootHop, memoTable);
-
- // Return the minimum cost plan for the root node
- FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(),
memoTable);
-
- // Detect conflicts in the federated plans where different
FedPlans have different FederatedOutput types
- double additionalTotalCost =
detectAndResolveConflictFedPlan(optimalPlan, memoTable);
-
- // Optionally print the federated plan tree if requested
- if (printTree)
FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable,
additionalTotalCost);
-
- return optimalPlan;
- }
-
- /**
- * Recursively enumerates all possible federated execution plans for a
Hop DAG.
- * For each node:
- * 1. First processes all input nodes recursively if not already
processed
- * 2. Generates all possible combinations of federation types
(FOUT/LOUT) for inputs
- * 3. Creates and evaluates both FOUT and LOUT variants for current
node with each input combination
- *
- * The enumeration uses a bottom-up approach where:
- * - Each input combination is represented by a binary number (i)
- * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
- * - Total number of combinations is 2^numInputs
- *
- * @param hop ?
- * @param memoTable ?
- */
- private static void enumerateFederatedPlanCost(Hop hop,
FederatedMemoTable memoTable) {
- int numInputs = hop.getInput().size();
-
+ package org.apache.sysds.hops.fedplanner;
+ import java.util.ArrayList;
+ import java.util.List;
+ import java.util.Map;
+ import java.util.HashMap;
+ import java.util.LinkedHashMap;
+ import java.util.Optional;
+ import java.util.Set;
+ import java.util.HashSet;
+
+ import org.apache.commons.lang3.tuple.Pair;
+
+ import org.apache.commons.lang3.tuple.ImmutablePair;
+ import org.apache.sysds.common.Types;
+ import org.apache.sysds.hops.DataOp;
+ import org.apache.sysds.hops.Hop;
+ import org.apache.sysds.hops.LiteralOp;
+ import org.apache.sysds.hops.UnaryOp;
+ import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon;
+ import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+ import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
+ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+ import org.apache.sysds.parser.DMLProgram;
+ import org.apache.sysds.parser.ForStatement;
+ import org.apache.sysds.parser.ForStatementBlock;
+ import org.apache.sysds.parser.FunctionStatement;
+ import org.apache.sysds.parser.FunctionStatementBlock;
+ import org.apache.sysds.parser.IfStatement;
+ import org.apache.sysds.parser.IfStatementBlock;
+ import org.apache.sysds.parser.StatementBlock;
+ import org.apache.sysds.parser.WhileStatement;
+ import org.apache.sysds.parser.WhileStatementBlock;
+ import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+ import org.apache.sysds.runtime.util.UtilFunctions;
+
+ public class FederatedPlanCostEnumerator {
+ private static final double DEFAULT_LOOP_WEIGHT = 10.0;
+ private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5;
+
+    /**
+     * Entry point: enumerates federated execution plans for a whole DML program.
+     * All top-level statement blocks are enumerated into one shared memo table.
+     * The transient writes reported by each block are merged into the outer
+     * transient table, which later blocks consult when rewiring transient reads.
+     * Afterwards the minimum-cost root plan is looked up, conflicting federated
+     * output types are resolved, and the plan tree is printed on request.
+     *
+     * @param prog The DML program to enumerate.
+     * @param isPrint Whether to print the resulting federated plan tree.
+     */
+    public static void enumerateProgram(DMLProgram prog, boolean isPrint) {
+        final FederatedMemoTable memoTable = new FederatedMemoTable();
+
+        // Transient-write bookkeeping shared by all top-level statement blocks
+        final Map<String, List<Hop>> outerTransTable = new HashMap<>();
+        final Map<String, List<Hop>> formerInnerTransTable = new HashMap<>();
+        // Hops that get connected to the root dummy node
+        final Set<Hop> progRootHopSet = new HashSet<>();
+        // TODO: Just for debug, remove later
+        // Statement-level root hops that were not connected to the root dummy node
+        final Set<Hop> statRootHopSet = new HashSet<>();
+
+        for (StatementBlock block : prog.getStatementBlocks()) {
+            Map<String, List<Hop>> blockTransWrites = enumerateStatementBlock(block, memoTable,
+                outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false);
+            if (blockTransWrites != null)
+                outerTransTable.putAll(blockTransWrites);
+        }
+
+        final FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable);
+
+        // Detect conflicts where the same hop is requested with different FederatedOutput types
+        final double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable);
+
+        if (!isPrint)
+            return;
+        FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost);
+    }
+
+
+    /**
+     * Enumerates one statement block and returns the transient writes (TWrite hops,
+     * keyed by variable name) recorded while processing it.
+     * <p>
+     * If, For (incl. parfor), While and Function blocks are handled recursively;
+     * any other block is treated as a last-level block whose hop DAGs are enumerated
+     * directly. The weight passed down to nested blocks is adjusted as follows:
+     * <ul>
+     * <li>if/else bodies: {@code DEFAULT_IF_ELSE_WEIGHT * weight}</li>
+     * <li>for loops: {@code weight * iterations} when from/to/increment are literals,
+     *     otherwise {@code weight * DEFAULT_LOOP_WEIGHT}</li>
+     * <li>while loops: {@code weight * DEFAULT_LOOP_WEIGHT}</li>
+     * <li>function bodies: unchanged</li>
+     * </ul>
+     *
+     * @param sb The statement block to enumerate.
+     * @param memoTable The memoization table to store plan variants.
+     * @param outerTransTable The table of outer transient writes.
+     * @param formerInnerTransTable The table of transient writes of preceding inner blocks.
+     * @param progRootHopSet The set of hops to connect to the root dummy node.
+     * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced).
+     * @param weight The weight associated with the current block.
+     * @param isInnerBlock A boolean indicating if the current block is an inner block.
+     * @return A map of inner transient writes.
+     */
+    public static Map<String, List<Hop>> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable,
+            Map<String, List<Hop>> outerTransTable, Map<String, List<Hop>> formerInnerTransTable,
+            Set<Hop> progRootHopSet, Set<Hop> statRootHopSet, double weight, boolean isInnerBlock) {
+        Map<String, List<Hop>> innerTransTable = new HashMap<>();
+
+        if (sb instanceof IfStatementBlock) {
+            IfStatementBlock isb = (IfStatementBlock) sb;
+            IfStatement istmt = (IfStatement)isb.getStatement(0);
+
+            enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+
+            // Each branch works on its own (shallow) copy of formerInnerTransTable so that
+            // writes of the if-body are not visible inside the else-body and vice versa.
+            Map<String, List<Hop>> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable);
+            Map<String, List<Hop>> elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable);
+
+            for (StatementBlock csb : istmt.getIfBody()){
+                ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable,
+                    ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true));
+            }
+
+            for (StatementBlock csb : istmt.getElseBody()){
+                elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable,
+                    elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true));
+            }
+
+            // Merge the else-branch writes into the if-branch table.
+            // The two copies are shallow, so a variable touched by neither branch maps to the
+            // very same list in both tables: keep it as is instead of appending it to itself.
+            // Otherwise build a fresh list so that lists owned by other tables are not mutated.
+            elseFormerInnerTransTable.forEach((key, elseValue) ->
+                ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> {
+                    if (ifValue == newValue)
+                        return ifValue;
+                    List<Hop> merged = new ArrayList<>(ifValue);
+                    merged.addAll(newValue);
+                    return merged;
+                }));
+            // Update innerTransTable
+            innerTransTable.putAll(ifFormerInnerTransTable);
+        }
+        else if (sb instanceof ForStatementBlock) { //incl parfor
+            ForStatementBlock fsb = (ForStatementBlock) sb;
+            ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+
+            // The increment is optional; default to a literal step of 1 when it is absent
+            Hop incrHops = fsb.getIncrementHops();
+            double loopWeight = DEFAULT_LOOP_WEIGHT;
+            Hop from = fsb.getFromHops().getInput().get(0);
+            Hop to = fsb.getToHops().getInput().get(0);
+            Hop incr = (incrHops != null) ? incrHops.getInput().get(0) : new LiteralOp(1);
+
+            // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values)
+            if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) {
+                double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from);
+                double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to);
+                double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr);
+                // a descending range with the default step counts downwards
+                if( dfrom > dto && dincr == 1 )
+                    dincr = -1;
+                loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false);
+            }
+            weight *= loopWeight;
+
+            enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+            enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+            // Only enumerate the increment DAG when the loop actually has one
+            if( incrHops != null ) {
+                enumerateHopDAG(incrHops, memoTable, outerTransTable, formerInnerTransTable,
+                    innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+            }
+
+            enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight);
+        }
+        else if (sb instanceof WhileStatementBlock) {
+            WhileStatementBlock wsb = (WhileStatementBlock) sb;
+            WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
+            // The iteration count of a while loop is unknown, use the default loop weight
+            weight *= DEFAULT_LOOP_WEIGHT;
+
+            enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+            enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight);
+        }
+        else if (sb instanceof FunctionStatementBlock) {
+            FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+            FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+
+            // TODO: Do not descend for visited functions (use a hash set for functions using their names)
+            enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable,
+                innerTransTable, progRootHopSet, statRootHopSet, weight);
+        }
+        else { //generic (last-level)
+            if( sb.getHops() != null ){
+                // If isInnerBlock, TWrites are recorded in innerTransTable, otherwise directly in outerTransTable
+                for(Hop c : sb.getHops())
+                    enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable,
+                        innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock);
+            }
+        }
+        return innerTransTable;
+    }
+
+    /**
+     * Enumerates the statement blocks of a body (loop or function body) in order.
+     * Every block of the body is enumerated as an inner block. The transient writes
+     * it returns are folded into formerInnerTransTable before the next block is
+     * processed, and the accumulated table is finally copied into innerTransTable
+     * of the statement block that owns the body.
+     *
+     * @param sbList The list of statement blocks to enumerate.
+     * @param memoTable The memoization table to store plan variants.
+     * @param outerTransTable The table of outer transient writes.
+     * @param formerInnerTransTable The table of former inner transient writes (updated by this method).
+     * @param innerTransTable The table of inner transient writes (updated by this method).
+     * @param progRootHopSet The set of hops to connect to the root dummy node.
+     * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced).
+     * @param weight The weight associated with the blocks of the body.
+     */
+    public static void enumerateStatementBlockBody(List<StatementBlock> sbList, FederatedMemoTable memoTable,
+            Map<String, List<Hop>> outerTransTable, Map<String, List<Hop>> formerInnerTransTable,
+            Map<String, List<Hop>> innerTransTable, Set<Hop> progRootHopSet, Set<Hop> statRootHopSet,
+            double weight) {
+        for (StatementBlock bodyBlock : sbList) {
+            // TWrites of this block become visible to the following blocks of the body
+            Map<String, List<Hop>> blockTransWrites = enumerateStatementBlock(bodyBlock, memoTable,
+                outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true);
+            formerInnerTransTable.putAll(blockTransWrites);
+        }
+
+        // Hand the accumulated writes over to the statement block containing the body
+        innerTransTable.putAll(formerInnerTransTable);
+    }
+
+ /**
+ * Enumerates the statement hop DAG within a statement block.
+ * This method recursively enumerates all possible federated execution
plans
+ * and identifies hops to connect to the root dummy node.
+ *
+ * @param rootHop The root Hop of the DAG to enumerate.
+ * @param memoTable The memoization table to store plan variants.
+ * @param outerTransTable The table to track transient writes.
+ * @param formerInnerTransTable The table to track immutable inner
transient writes.
+ * @param innerTransTable The table to track inner transient writes.
+ * @param progRootHopSet The set of hops to connect to the root dummy
node.
+ * @param statRootHopSet The set of root hops for debugging purposes.
+ * @param weight The weight associated with the current Hop.
+ * @param isInnerBlock A boolean indicating if the current block is an
inner block.
+ */
+ public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable
memoTable, Map<String, List<Hop>> outerTransTable,
+
Map<String, List<Hop>> formerInnerTransTable, Map<String,List<Hop>>
innerTransTable, Set<Hop> progRootHopSet, Set<Hop> statRootHopSet, double
weight, boolean isInnerBlock) {
+ // Recursively enumerate all possible plans
+ rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable,
formerInnerTransTable, innerTransTable, weight, isInnerBlock);
+
+ // Identify hops to connect to the root dummy node
+
+ if ((rootHop instanceof DataOp &&
(rootHop.getName().equals("__pred"))) // TWrite "__pred"
+ || (rootHop instanceof UnaryOp &&
((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print)
+ // Connect TWrite pred and u(print) to the root dummy
node
+ // TODO: Should we check all statement-level root hops
to see if they are not referenced?
+ progRootHopSet.add(rootHop);
+ } else {
+ // TODO: Just for debug, remove later
+ // For identifying TWrites that are not referenced
later
+ statRootHopSet.add(rootHop);
+ }
+ }
+
+ /**
+ * Rewires and enumerates federated execution plans for a given Hop.
+ * The method works bottom-up: it first recurses into every input hop that has
+ * neither a FOUT nor a LOUT entry in the memo table yet, then asks
+ * rewireTransReadWrite for the (possibly extended) child list of this hop,
+ * and finally calls enumerateFedPlan with that child list.
+ * The same weight and isInner flag are passed unchanged to all inputs.
+ *
+ * @param hop The Hop for which to rewire and enumerate federated plans.
+ * @param memoTable The memoization table to store plan variants.
+ * @param outerTransTable The table to track transient writes.
+ * @param formerInnerTransTable The table to track immutable inner transient writes.
+ * @param innerTransTable The table to track inner transient writes.
+ * @param weight The weight associated with the current Hop.
+ * @param isInner A boolean indicating if the current block is an inner block.
+ */
+ private static void rewireAndEnumerateFedPlan(Hop hop,
FederatedMemoTable memoTable, Map<String, List<Hop>> outerTransTable,
+
Map<String, List<Hop>> formerInnerTransTable, Map<String,
List<Hop>> innerTransTable,
+
double weight, boolean isInner) {
// Process all input nodes first if not already in memo table
for (Hop inputHop : hop.getInput()) {
- if (!memoTable.contains(inputHop.getHopID(),
FederatedOutput.FOUT)
- && !memoTable.contains(inputHop.getHopID(),
FederatedOutput.LOUT)) {
- enumerateFederatedPlanCost(inputHop,
memoTable);
+ long inputHopID = inputHop.getHopID();
+ if (!memoTable.contains(inputHopID,
FederatedOutput.FOUT)
+ && !memoTable.contains(inputHopID,
FederatedOutput.LOUT)) {
+ rewireAndEnumerateFedPlan(inputHop, memoTable,
outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner);
}
}
- // Generate all possible input combinations using binary
representation
- // i represents a specific combination of FOUT/LOUT for inputs
- for (int i = 0; i < (1 << numInputs); i++) {
- List<Pair<Long, FederatedOutput>> planChilds = new
ArrayList<>();
-
- // For each input, determine if it should be FOUT or
LOUT based on bit j in i
- for (int j = 0; j < numInputs; j++) {
- Hop inputHop = hop.getInput().get(j);
- // If bit j is set (1), use FOUT; otherwise use
LOUT
- FederatedOutput childType = ((i & (1 << j)) !=
0) ?
- FederatedOutput.FOUT :
FederatedOutput.LOUT;
- planChilds.add(Pair.of(inputHop.getHopID(),
childType));
- }
-
- // Create and evaluate FOUT variant for current input
combination
- FedPlan fOutPlan = memoTable.addFedPlan(hop,
FederatedOutput.FOUT, planChilds);
-
FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);
-
- // Create and evaluate LOUT variant for current input
combination
- FedPlan lOutPlan = memoTable.addFedPlan(hop,
FederatedOutput.LOUT, planChilds);
-
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
- }
+ // Determine modified child hops based on DataOp type and
transient operations
+ List<Hop> childHops = rewireTransReadWrite(hop,
outerTransTable, formerInnerTransTable, innerTransTable, isInner);
- // Prune MemoTable for hop.
- memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT);
- memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT);
+ // Enumerate the federated plan for the current Hop
+ enumerateFedPlan(hop, memoTable, childHops, weight);
}
- /**
- * Returns the minimum cost plan for the root Hop, comparing both FOUT
and LOUT variants.
- * Used to select the final execution plan after enumeration.
- *
- * @param HopID ?
- * @param memoTable ?
- * @return ?
- */
- private static FedPlan getMinCostRootFedPlan(long HopID,
FederatedMemoTable memoTable) {
- FedPlanVariants fOutFedPlanVariants =
memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT);
- FedPlanVariants lOutFedPlanVariants =
memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT);
-
- FedPlan minFOutFedPlan =
fOutFedPlanVariants._fedPlanVariants.stream()
-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
-
.orElse(null);
- FedPlan minlOutFedPlan =
lOutFedPlanVariants._fedPlanVariants.stream()
-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
-
.orElse(null);
-
- if (Objects.requireNonNull(minFOutFedPlan).getTotalCost()
- <
Objects.requireNonNull(minlOutFedPlan).getTotalCost()) {
- return minFOutFedPlan;
+ private static List<Hop> rewireTransReadWrite(Hop hop, Map<String,
List<Hop>> outerTransTable,
+
Map<String, List<Hop>> formerInnerTransTable,
+
Map<String, List<Hop>> innerTransTable, boolean
isInner) { // returns the child hops to use for plan enumeration of hop
+ List<Hop> childHops = hop.getInput(); // default: the hop's own inputs, unchanged
+
+ if (!(hop instanceof DataOp) || hop.getName().equals("__pred"))
{
+ return childHops; // Only DataOps other than "__pred" take part in TWrite/TRead rewiring
}
- return minlOutFedPlan;
- }
-
- /**
- * Detects and resolves conflicts in federated plans starting from the
root plan.
- * This function performs a breadth-first search (BFS) to traverse the
federated plan tree.
- * It identifies conflicts where the same plan ID has different
federated output types.
- * For each conflict, it records the plan ID and its conflicting parent
plans.
- * The function ensures that each plan ID is associated with a
consistent federated output type
- * by resolving these conflicts iteratively.
- *
- * The process involves:
- * - Using a map to track conflicts, associating each plan ID with its
federated output type
- * and a list of parent plans.
- * - Storing detected conflicts in a linked map, each entry containing
a plan ID and its
- * conflicting parent plans.
- * - Performing BFS traversal starting from the root plan, checking
each child plan for conflicts.
- * - If a conflict is detected (i.e., a plan ID has different output
types), the conflicting plan
- * is removed from the BFS queue and added to the conflict map to
prevent duplicate calculations.
- * - Resolving conflicts by ensuring a consistent federated output type
across the plan.
- * - Re-running BFS with resolved conflicts to ensure all
inconsistencies are addressed.
- *
- * @param rootPlan The root federated plan from which to start the
conflict detection.
- * @param memoTable The memoization table used to retrieve pruned
federated plans.
- * @return The cumulative additional cost for resolving conflicts.
- */
- private static double detectAndResolveConflictFedPlan(FedPlan rootPlan,
FederatedMemoTable memoTable) {
- // Map to track conflicts: maps a plan ID to its federated
output type and list of parent plans
- Map<Long, Pair<FederatedOutput, List<FedPlan>>>
conflictCheckMap = new HashMap<>();
-
- // LinkedMap to store detected conflicts, each with a plan ID
and its conflicting parent plans
- LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new
LinkedHashMap<>();
- // LinkedMap for BFS traversal starting from the root plan (Do
not use value (boolean))
- LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new
LinkedHashMap<>();
- bfsLinkedMap.put(rootPlan, true);
+ DataOp dataOp = (DataOp) hop;
+ Types.OpOpData opType = dataOp.getOp();
+ String hopName = dataOp.getName(); // key used in all transient tables
- // Array to store cumulative additional cost for resolving
conflicts
- double[] cumulativeAdditionalCost = new double[]{0.0};
-
- while (!bfsLinkedMap.isEmpty()) {
- // Perform BFS to detect conflicts in federated plans
- while (!bfsLinkedMap.isEmpty()) {
- FedPlan currentPlan =
bfsLinkedMap.keySet().iterator().next();
- bfsLinkedMap.remove(currentPlan);
+ if (isInner && opType == Types.OpOpData.TRANSIENTWRITE) { // inner TWrite: record under its name in innerTransTable
+ innerTransTable.computeIfAbsent(hopName, k -> new
ArrayList<>()).add(hop);
+ }
+ else if (isInner && opType == Types.OpOpData.TRANSIENTREAD) { // inner TRead: append the matching TWrites as extra children
+ childHops = rewireInnerTransRead(childHops, hopName,
+ innerTransTable, formerInnerTransTable,
outerTransTable);
+ }
+ else if (!isInner && opType == Types.OpOpData.TRANSIENTWRITE) { // outer TWrite: record under its name in outerTransTable
+ outerTransTable.computeIfAbsent(hopName, k -> new
ArrayList<>()).add(hop);
+ }
+ else if (!isInner && opType == Types.OpOpData.TRANSIENTREAD) { // outer TRead: only outerTransTable is consulted
+ childHops = rewireOuterTransRead(childHops, hopName,
outerTransTable);
+ }
- // Iterate over each child plan of the current
plan
- for (Pair<Long, FederatedOutput> childPlanPair
: currentPlan.getChildFedPlans()) {
- FedPlan childFedPlan =
memoTable.getFedPlanAfterPrune(childPlanPair);
+ return childHops; // unchanged for TWrites and for any other DataOp type
+ }
- // Check if the child plan ID is
already visited
- if
(conflictCheckMap.containsKey(childPlanPair.getLeft())) {
- // Retrieve the existing
conflict pair for the child plan
- Pair<FederatedOutput,
List<FedPlan>> conflictChildPlanPair =
conflictCheckMap.get(childPlanPair.getLeft());
- // Add the current plan to the
list of parent plans
-
conflictChildPlanPair.getRight().add(currentPlan);
+ private static List<Hop> rewireInnerTransRead(List<Hop> childHops,
String hopName, Map<String, List<Hop>> innerTransTable,
+
Map<String, List<Hop>> formerInnerTransTable,
Map<String, List<Hop>> outerTransTable) { // returns a new list: childHops plus the TWrites found for hopName
+ List<Hop> newChildHops = new ArrayList<>(childHops); // copy, so the passed-in list is not modified
- // If the federated output type
differs, a conflict is detected
- if
(conflictChildPlanPair.getLeft() != childPlanPair.getRight()) {
- // If this is the first
detection, remove conflictChildFedPlan from the BFS queue and add it to the
conflict linked map (queue)
- // If the existing
FedPlan is not removed from the bfsqueue or both actions are performed,
duplicate calculations for the same FedPlan and its children occur
- if
(!conflictLinkedMap.containsKey(childPlanPair.getLeft())) {
-
conflictLinkedMap.put(childPlanPair.getLeft(),
conflictChildPlanPair.getRight());
-
bfsLinkedMap.remove(childFedPlan);
- }
- }
- } else {
- // If no conflict exists,
create a new entry in the conflict check map
- List<FedPlan> parentFedPlanList
= new ArrayList<>();
-
parentFedPlanList.add(currentPlan);
+ // Look up hopName by priority: inner -> formerInner -> outer (first table with an entry wins)
+ List<Hop> additionalChildHops = innerTransTable.get(hopName);
+ if (additionalChildHops == null) {
+ additionalChildHops =
formerInnerTransTable.get(hopName);
+ }
+ if (additionalChildHops == null) {
+ additionalChildHops = outerTransTable.get(hopName);
+ }
- // Map the child plan ID to its
output type and list of parent plans
-
conflictCheckMap.put(childPlanPair.getLeft(), new
ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList));
- // Add the child plan to the
BFS queue
- bfsLinkedMap.put(childFedPlan,
true);
- }
- }
- }
- // Resolve these conflicts to ensure a consistent
federated output type across the plan
- // Re-run BFS with resolved conflicts
- bfsLinkedMap =
FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap,
cumulativeAdditionalCost);
- conflictLinkedMap.clear();
+ if (additionalChildHops != null) { // no entry in any table: children stay as they are
+ newChildHops.addAll(additionalChildHops);
}
+ return newChildHops;
+ }
- // Return the cumulative additional cost for resolving conflicts
- return cumulativeAdditionalCost[0];
+ private static List<Hop> rewireOuterTransRead(List<Hop> childHops,
String hopName, Map<String, List<Hop>> outerTransTable) {
+ List<Hop> newChildHops = new ArrayList<>(childHops);
+ List<Hop> additionalChildHops = outerTransTable.get(hopName);
+ if (additionalChildHops != null) {
+ newChildHops.addAll(additionalChildHops);
+ }
+ return newChildHops;
}
-}
+
+ /**
+ * Enumerates federated execution plans for a given Hop.
+ * This method calculates the self cost and child costs for the Hop,
+ * generates federated plan variants for both LOUT and FOUT output
types,
+ * and prunes redundant plans before adding them to the memo table.
+ *
+ * @param hop The Hop for which to enumerate federated plans.
+ * @param memoTable The memoization table to store plan variants.
+ * @param childHops The list of child hops.
+ * @param weight The weight associated with the current Hop.
+ */
+ private static void enumerateFedPlan(Hop hop, FederatedMemoTable
memoTable, List<Hop> childHops, double weight){
+ long hopID = hop.getHopID();
+ HopCommon hopCommon = new HopCommon(hop, weight);
+ double selfCost =
FederatedPlanCostEstimator.computeHopCost(hopCommon);
+
+ FedPlanVariants lOutFedPlanVariants = new
FedPlanVariants(hopCommon, FederatedOutput.LOUT);
+ FedPlanVariants fOutFedPlanVariants = new
FedPlanVariants(hopCommon, FederatedOutput.FOUT);
+
+ int numInputs = childHops.size();
+ int numInitInputs = hop.getInput().size();
+
+ double[][] childCumulativeCost = new double[numInputs][2]; //
# of child, LOUT/FOUT of child
+ double[] childForwardingCost = new double[numInputs]; // # of
child
+
+ // The self cost follows its own weight, while the forwarding
cost follows the parent's weight.
+ FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable,
childHops, childCumulativeCost, childForwardingCost);
+
+ if (numInitInputs == numInputs){
+ enumerateOnlyInitChildFedPlan(lOutFedPlanVariants,
fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost,
childForwardingCost, selfCost);
+ } else {
+ enumerateTReadInitChildFedPlan(lOutFedPlanVariants,
fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost,
childForwardingCost, selfCost);
+ }
+
+ // Prune the FedPlans to remove redundant plans
+ lOutFedPlanVariants.pruneFedPlans();
+ fOutFedPlanVariants.pruneFedPlans();
+
+ // Add the FedPlanVariants to the memo table
+ memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT,
lOutFedPlanVariants);
+ memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT,
fOutFedPlanVariants);
+ }
+
+ /**
+ * Enumerates federated execution plans for initial child hops only.
+ * This method generates all possible combinations of federated output
types (LOUT and FOUT)
+ * for the initial child hops and calculates their cumulative costs.
+ *
+ * @param lOutFedPlanVariants The FedPlanVariants object for LOUT
output type.
+ * @param fOutFedPlanVariants The FedPlanVariants object for FOUT
output type.
+ * @param numInitInputs The number of initial input hops.
+ * @param childHops The list of child hops.
+ * @param childCumulativeCost The cumulative costs for each child hop.
+ * @param childForwardingCost The forwarding costs for each child hop.
+ * @param selfCost The self cost of the current hop.
+ */
+ private static void enumerateOnlyInitChildFedPlan(FedPlanVariants
lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs,
List<Hop> childHops,
+ double[][] childCumulativeCost, double[]
childForwardingCost, double selfCost){
+ // Iterate 2^n times, generating two FedPlans (LOUT, FOUT)
each time.
+ for (int i = 0; i < (1 << numInitInputs); i++) {
+ double[] cumulativeCost = new double[]{selfCost,
selfCost};
+ List<Pair<Long, FederatedOutput>> planChilds = new
ArrayList<>();
+ // LOUT and FOUT share the same planChilds in each
iteration (only forwarding cost differs).
+ enumerateInitChildFedPlan(numInitInputs, childHops,
planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i);
+
+ lOutFedPlanVariants.addFedPlan(new
FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds));
+ fOutFedPlanVariants.addFedPlan(new
FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds));
+ }
+ }
+
+    /**
+     * Enumerates federated execution plans for a TRead hop.
+     * The children at positions numInitInputs..numInputs-1 are the rewired TWrite
+     * hops. Their cumulative costs are summed once up front, separately for LOUT
+     * and FOUT, because a TWrite child is always paired with the same output type
+     * as the TRead plan. The initial children are enumerated in all
+     * 2^numInitInputs LOUT/FOUT combinations and the pre-computed TWrite costs
+     * and children are appended to every combination.
+     * <p>
+     * The self cost is counted exactly once per plan: it seeds the per-plan cost
+     * array, so the TWrite accumulators start at zero.
+     *
+     * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type.
+     * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type.
+     * @param numInitInputs The number of initial input hops.
+     * @param numInputs The total number of input hops, including additional TWrite hops.
+     * @param childHops The list of child hops.
+     * @param childCumulativeCost The cumulative costs for each child hop.
+     * @param childForwardingCost The forwarding costs for each child hop.
+     * @param selfCost The self cost of the current hop.
+     */
+    private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants,
+            FedPlanVariants fOutFedPlanVariants, int numInitInputs, int numInputs, List<Hop> childHops,
+            double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){
+        // Cost of the TWrite children only; selfCost is added per plan below
+        double lOutTWriteCumulativeCost = 0.0;
+        double fOutTWriteCumulativeCost = 0.0;
+
+        List<Pair<Long, FederatedOutput>> lOutTReadPlanChilds = new ArrayList<>();
+        List<Pair<Long, FederatedOutput>> fOutTReadPlanChilds = new ArrayList<>();
+
+        // Pre-calculate the cost for the additional child hops, which are TWrite hops, of the TRead hop.
+        // Constraint: TWrite must have the same FedOutType as TRead.
+        for (int j = numInitInputs; j < numInputs; j++) {
+            Hop inputHop = childHops.get(j);
+            lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT));
+            fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT));
+
+            lOutTWriteCumulativeCost += childCumulativeCost[j][0];
+            fOutTWriteCumulativeCost += childCumulativeCost[j][1];
+            // No forwarding cost for TWrite -> TRead as they have the same FedOutType.
+        }
+
+        for (int i = 0; i < (1 << numInitInputs); i++) {
+            // [0] = LOUT plan cost, [1] = FOUT plan cost; both start with the hop's self cost
+            double[] cumulativeCost = new double[]{selfCost, selfCost};
+            List<Pair<Long, FederatedOutput>> lOutPlanChilds = new ArrayList<>();
+            enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds,
+                childCumulativeCost, childForwardingCost, cumulativeCost, i);
+
+            // The initial children are shared; copy them before the type-specific TWrite children are appended
+            List<Pair<Long, FederatedOutput>> fOutPlanChilds = new ArrayList<>(lOutPlanChilds);
+
+            lOutPlanChilds.addAll(lOutTReadPlanChilds);
+            fOutPlanChilds.addAll(fOutTReadPlanChilds);
+
+            cumulativeCost[0] += lOutTWriteCumulativeCost;
+            cumulativeCost[1] += fOutTWriteCumulativeCost;
+
+            lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds));
+            fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds));
+        }
+    }
+
+ // Appends one (hopID, output type) pair per initial child to planChilds and adds the
+ // matching child costs to cumulativeCost ([0] = LOUT parent, [1] = FOUT parent).
+ // Bit j of `i` selects the child's type: 1 -> FOUT, 0 -> LOUT.
+ private static void enumerateInitChildFedPlan(int numInitInputs, List<Hop> childHops, List<Pair<Long, FederatedOutput>> planChilds,
+         double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){
+     int j = 0;
+     while (j < numInitInputs) {
+         // Extract bit j of the combination mask.
+         final int bit = (i >>> j) & 1;
+
+         FederatedOutput childType = FederatedOutput.LOUT;
+         if (bit == 1)
+             childType = FederatedOutput.FOUT;
+         planChilds.add(Pair.of(childHops.get(j).getHopID(), childType));
+
+         final double childCost = childCumulativeCost[j][bit];
+         final double forwarding = childForwardingCost[j];
+
+         // A LOUT parent pays forwarding for a FOUT child (bit == 1);
+         // a FOUT parent pays forwarding for a LOUT child (bit == 0).
+         cumulativeCost[0] += childCost + forwarding * bit;
+         cumulativeCost[1] += childCost + forwarding * (1 - bit);
+
+         j++;
+     }
+ }
+
+ // Builds the dummy root FedPlan: for every program root hop the cheaper of its pruned
+ // LOUT/FOUT plans is chosen (LOUT on ties), its cost is summed, and the choice is recorded
+ // as a child of the root. The dummy root itself has no LOUT/FOUT (its variants are null).
+ private static FedPlan getMinCostRootFedPlan(Set<Hop> progRootHopSet, FederatedMemoTable memoTable) {
+     List<Pair<Long, FederatedOutput>> rootChildren = new ArrayList<>();
+     double totalCost = 0;
+
+     for (Hop rootHop : progRootHopSet){
+         final long rootID = rootHop.getHopID();
+
+         // Pruned plans of this root hop for both output types
+         final FedPlan lOutPlan = memoTable.getFedPlanAfterPrune(rootID, FederatedOutput.LOUT);
+         final FedPlan fOutPlan = memoTable.getFedPlanAfterPrune(rootID, FederatedOutput.FOUT);
+
+         final double lOutCost = lOutPlan.getCumulativeCost();
+         final double fOutCost = fOutPlan.getCumulativeCost();
+
+         // Prefer LOUT unless FOUT is strictly cheaper
+         final boolean pickLOut = lOutCost <= fOutCost;
+         totalCost += pickLOut ? lOutCost : fOutCost;
+         rootChildren.add(Pair.of(rootID, pickLOut ? FederatedOutput.LOUT : FederatedOutput.FOUT));
+     }
+
+     return new FedPlan(totalCost, null, rootChildren);
+ }
+
+ /**
+ * Detects and resolves conflicts in federated plans starting from the
root plan.
+ * This function performs a breadth-first search (BFS) to traverse the
federated plan tree.
+ * It identifies conflicts where the same plan ID has different
federated output types.
+ * For each conflict, it records the plan ID and its conflicting
parent plans.
+ * The function ensures that each plan ID is associated with a
consistent federated output type
+ * by resolving these conflicts iteratively.
+ *
+ * The process involves:
+ * - Using a map to track conflicts, associating each plan ID with its
federated output type
+ * and a list of parent plans.
+ * - Storing detected conflicts in a linked map, each entry containing
a plan ID and its
+ * conflicting parent plans.
+ * - Performing BFS traversal starting from the root plan, checking
each child plan for conflicts.
+ * - If a conflict is detected (i.e., a plan ID has different output
types), the conflicting plan
+ * is removed from the BFS queue and added to the conflict map to
prevent duplicate calculations.
+ * - Resolving conflicts by ensuring a consistent federated output
type across the plan.
+ * - Re-running BFS with resolved conflicts to ensure all
inconsistencies are addressed.
+ *
+ * @param rootPlan The root federated plan from which to start the
conflict detection.
+ * @param memoTable The memoization table used to retrieve pruned
federated plans.
+ * @return The cumulative additional cost for resolving conflicts.
+ */
+ private static double detectAndResolveConflictFedPlan(FedPlan
rootPlan, FederatedMemoTable memoTable) {
+ // Walks the plan tree from rootPlan, records hop IDs that are reached with two different
+ // output types, and hands them to FederatedPlanCostEstimator.resolveConflictFedPlan;
+ // repeats with the plans returned by that call until nothing is left to traverse.
+ // Map to track conflicts: maps a plan ID to its federated
output type and list of parent plans
+ Map<Long, Pair<FederatedOutput, List<FedPlan>>>
conflictCheckMap = new HashMap<>();
+
+ // LinkedMap to store detected conflicts, each with a plan ID
and its conflicting parent plans
+ LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new
LinkedHashMap<>();
+
+ // LinkedMap for BFS traversal starting from the root plan (Do
not use value (boolean))
+ LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new
LinkedHashMap<>();
+ bfsLinkedMap.put(rootPlan, true);
+
+ // Array to store cumulative additional cost for resolving
conflicts
+ // (single-element array so that resolveConflictFedPlan can add to index 0 in place)
+ double[] cumulativeAdditionalCost = new double[]{0.0};
+
+ // Outer loop: alternate between a detection pass (inner loop) and a resolution step
+ // until the resolution step returns an empty traversal map.
+ while (!bfsLinkedMap.isEmpty()) {
+ // Perform BFS to detect conflicts in federated plans
+ while (!bfsLinkedMap.isEmpty()) {
+ FedPlan currentPlan =
bfsLinkedMap.keySet().iterator().next();
+ bfsLinkedMap.remove(currentPlan);
+
+ // Iterate over each child plan of the current
plan
+ for (Pair<Long, FederatedOutput> childPlanPair
: currentPlan.getChildFedPlans()) {
+ FedPlan childFedPlan =
memoTable.getFedPlanAfterPrune(childPlanPair);
+
+ // Check if the child plan ID is
already visited
+ if
(conflictCheckMap.containsKey(childPlanPair.getLeft())) {
+ // Retrieve the existing
conflict pair for the child plan
+ Pair<FederatedOutput,
List<FedPlan>> conflictChildPlanPair =
conflictCheckMap.get(childPlanPair.getLeft());
+ // Add the current plan to the
list of parent plans
+
conflictChildPlanPair.getRight().add(currentPlan);
+
+ // If the federated output
type differs, a conflict is detected
+ if
(conflictChildPlanPair.getLeft() != childPlanPair.getRight()) {
+ // If this is the
first detection, remove conflictChildFedPlan from the BFS queue and add it to
the conflict linked map (queue)
+ // If the existing
FedPlan is not removed from the bfsqueue or both actions are performed,
duplicate calculations for the same FedPlan and its children occur
+ if
(!conflictLinkedMap.containsKey(childPlanPair.getLeft())) {
+
conflictLinkedMap.put(childPlanPair.getLeft(),
conflictChildPlanPair.getRight());
+ // NOTE(review): childFedPlan was looked up with the current pair's output type, while the
+ // plan queued earlier for this ID was looked up with the first-seen type stored in
+ // conflictCheckMap. Unless FedPlan equality ignores the output type, this remove() may not
+ // hit the queued plan the comment above refers to - confirm against FedPlan.equals/hashCode.
+
bfsLinkedMap.remove(childFedPlan);
+ }
+ }
+ } else {
+ // If no conflict exists,
create a new entry in the conflict check map
+ List<FedPlan>
parentFedPlanList = new ArrayList<>();
+
parentFedPlanList.add(currentPlan);
+
+ // Map the child plan ID to
its output type and list of parent plans
+
conflictCheckMap.put(childPlanPair.getLeft(), new
ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList));
+ // Add the child plan to the
BFS queue
+ bfsLinkedMap.put(childFedPlan,
true);
+ }
+ }
+ }
+ // Resolve these conflicts to ensure a consistent
federated output type across the plan
+ // Re-run BFS with resolved conflicts
+ // The returned map replaces the (now empty) traversal queue for the next outer iteration.
+ bfsLinkedMap =
FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap,
cumulativeAdditionalCost);
+ conflictLinkedMap.clear();
+ }
+
+ // Return the cumulative additional cost for resolving
conflicts
+ return cumulativeAdditionalCost[0];
+ }
+ }
+
\ No newline at end of file
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
index 7bc7339563..9ff405ab28 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -17,224 +17,248 @@
* under the License.
*/
-package org.apache.sysds.hops.fedplanner;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.cost.ComputeCost;
-import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
-import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
-
-import java.util.LinkedHashMap;
-import java.util.NoSuchElementException;
-import java.util.List;
-import java.util.Map;
-
-/**
- * Cost estimator for federated execution plans.
- * Calculates computation, memory access, and network transfer costs for
federated operations.
- * Works in conjunction with FederatedMemoTable to evaluate different
execution plan variants.
- */
-public class FederatedPlanCostEstimator {
- // Default value is used as a reasonable estimate since we only need
- // to compare relative costs between different federated plans
- // Memory bandwidth for local computations (25 GB/s)
- private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0;
- // Network bandwidth for data transfers between federated sites (1 Gbps)
- private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0;
-
- /**
- * Computes total cost of federated plan by:
- * 1. Computing current node cost (if not cached)
- * 2. Adding minimum-cost child plans
- * 3. Including network transfer costs when needed
- *
- * @param currentPlan Plan to compute cost for
- * @param memoTable Table containing all plan variants
- */
- public static void computeFederatedPlanCost(FedPlan currentPlan,
FederatedMemoTable memoTable) {
- double totalCost;
- Hop currentHop = currentPlan.getHopRef();
-
- // Step 1: Calculate current node costs if not already computed
- if (currentPlan.getSelfCost() == 0) {
- // Compute cost for current node (computation + memory
access)
- totalCost = computeCurrentCost(currentHop);
- currentPlan.setSelfCost(totalCost);
- // Calculate potential network transfer cost if
federation type changes
-
currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
- } else {
- totalCost = currentPlan.getSelfCost();
- }
-
- // Step 2: Process each child plan and add their costs
- for (Pair<Long, FederatedOutput> childPlanPair :
currentPlan.getChildFedPlans()) {
- // Find minimum cost child plan considering federation
type compatibility
- // Note: This approach might lead to suboptimal or
wrong solutions when a child has multiple parents
- // because we're selecting child plans independently
for each parent
- FedPlan planRef =
memoTable.getMinCostFedPlan(childPlanPair);
-
- // Add child plan cost (includes network transfer cost
if federation types differ)
- totalCost += planRef.getTotalCost() +
planRef.getCondNetTransferCost(currentPlan.getFedOutType());
- }
-
- // Step 3: Set final cumulative cost including current node
- currentPlan.setTotalCost(totalCost);
- }
-
- /**
- * Resolves conflicts in federated plans where different plans have
different FederatedOutput types.
- * This function traverses the list of conflicting plans in reverse
order to ensure that conflicts
- * are resolved from the bottom-up, allowing for consistent federated
output types across the plan.
- * It calculates additional costs for each potential resolution and
updates the cumulative additional cost.
- *
- * @param memoTable The FederatedMemoTable containing all federated
plan variants.
- * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent
plans with conflicting federated outputs.
- * @param cumulativeAdditionalCost An array to store the cumulative
additional cost incurred by resolving conflicts.
- * @return A LinkedHashMap of resolved federated plans, marked with a
boolean indicating resolution status.
- */
- public static LinkedHashMap<FedPlan, Boolean>
resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long,
List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) {
- // LinkedHashMap to store resolved federated plans for BFS
traversal.
- LinkedHashMap<FedPlan, Boolean> resolvedFedPlanLinkedMap = new
LinkedHashMap<>();
-
- // Traverse the conflictFedPlanList in reverse order after BFS
to resolve conflicts
- for (Map.Entry<Long, List<FedPlan>> conflictFedPlanPair :
conflictFedPlanLinkedMap.entrySet()) {
- long conflictHopID = conflictFedPlanPair.getKey();
- List<FedPlan> conflictParentFedPlans =
conflictFedPlanPair.getValue();
-
- // Retrieve the conflicting federated plans for LOUT
and FOUT types
- FedPlan confilctLOutFedPlan =
memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT);
- FedPlan confilctFOutFedPlan =
memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT);
-
- // Variables to store additional costs for LOUT and
FOUT types
- double lOutAdditionalCost = 0;
- double fOutAdditionalCost = 0;
-
- // Flags to check if the plan involves network transfer
- // Network transfer cost is calculated only once, even
if it occurs multiple times
- boolean isLOutNetTransfer = false;
- boolean isFOutNetTransfer = false;
-
- // Determine the optimal federated output type based on
the calculated costs
- FederatedOutput optimalFedOutType;
-
- // Iterate over each parent federated plan in the
current conflict pair
- for (FedPlan conflictParentFedPlan :
conflictParentFedPlans) {
- // Find the calculated FedOutType of the child
plan
- Pair<Long, FederatedOutput>
cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream()
- .filter(pair ->
pair.getLeft().equals(conflictHopID))
- .findFirst()
- .orElseThrow(() -> new
NoSuchElementException("No matching pair found for ID: " + conflictHopID));
-
- // CASE 1. Calculated LOUT / Parent LOUT /
Current LOUT: Total cost remains unchanged.
- // CASE 2. Calculated LOUT / Parent FOUT /
Current LOUT: Total cost remains unchanged, subtract net cost, add net cost
later.
- // CASE 3. Calculated FOUT / Parent LOUT /
Current LOUT: Change total cost, subtract net cost.
- // CASE 4. Calculated FOUT / Parent FOUT /
Current LOUT: Change total cost, add net cost later.
- // CASE 5. Calculated LOUT / Parent LOUT /
Current FOUT: Change total cost, add net cost later.
- // CASE 6. Calculated LOUT / Parent FOUT /
Current FOUT: Change total cost, subtract net cost.
- // CASE 7. Calculated FOUT / Parent LOUT /
Current FOUT: Total cost remains unchanged, subtract net cost, add net cost
later.
- // CASE 8. Calculated FOUT / Parent FOUT /
Current FOUT: Total cost remains unchanged.
-
- // Adjust LOUT, FOUT costs based on the
calculated plan's output type
- if (cacluatedConflictPlanPair.getRight() ==
FederatedOutput.LOUT) {
- // When changing from calculated LOUT
to current FOUT, subtract the existing LOUT total cost and add the FOUT total
cost
- // When maintaining calculated LOUT to
current LOUT, the total cost remains unchanged.
- fOutAdditionalCost +=
confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost();
-
- if
(conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) {
- // (CASE 1) Previously,
calculated was LOUT and parent was LOUT, so no network transfer cost occurred
- // (CASE 5) If changing from
calculated LOUT to current FOUT, network transfer cost occurs, but calculated
later
- isFOutNetTransfer = true;
- } else {
- // Previously, calculated was
LOUT and parent was FOUT, so network transfer cost occurred
- // (CASE 2) If maintaining calculated LOUT to current
LOUT, subtract existing network transfer cost and calculate later
- isLOutNetTransfer = true;
- lOutAdditionalCost -=
confilctLOutFedPlan.getNetTransferCost();
-
- // (CASE 6) If changing from
calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it
- fOutAdditionalCost -=
confilctLOutFedPlan.getNetTransferCost();
- }
- } else {
- lOutAdditionalCost +=
confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost();
-
- if
(conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) {
- isLOutNetTransfer = true;
- } else {
- isFOutNetTransfer = true;
- lOutAdditionalCost -=
confilctLOutFedPlan.getNetTransferCost();
- fOutAdditionalCost -=
confilctLOutFedPlan.getNetTransferCost();
- }
- }
- }
-
- // Add network transfer costs if applicable
- if (isLOutNetTransfer) {
- lOutAdditionalCost +=
confilctLOutFedPlan.getNetTransferCost();
- }
- if (isFOutNetTransfer) {
- fOutAdditionalCost +=
confilctFOutFedPlan.getNetTransferCost();
- }
-
- // Determine the optimal federated output type based on
the calculated costs
- if (lOutAdditionalCost <= fOutAdditionalCost) {
- optimalFedOutType = FederatedOutput.LOUT;
- cumulativeAdditionalCost[0] +=
lOutAdditionalCost;
-
resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true);
- } else {
- optimalFedOutType = FederatedOutput.FOUT;
- cumulativeAdditionalCost[0] +=
fOutAdditionalCost;
-
resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true);
- }
-
- // Update only the optimal federated output type, not
the cost itself or recursively
- for (FedPlan conflictParentFedPlan :
conflictParentFedPlans) {
- for (Pair<Long, FederatedOutput> childPlanPair
: conflictParentFedPlan.getChildFedPlans()) {
- if (childPlanPair.getLeft() ==
conflictHopID && childPlanPair.getRight() != optimalFedOutType) {
- int index =
conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair);
-
conflictParentFedPlan.getChildFedPlans().set(index,
-
Pair.of(childPlanPair.getLeft(), optimalFedOutType));
- break;
- }
- }
- }
- }
- return resolvedFedPlanLinkedMap;
- }
-
- /**
- * Computes the cost for the current Hop node.
- *
- * @param currentHop The Hop node whose cost needs to be computed
- * @return The total cost for the current node's operation
- */
- private static double computeCurrentCost(Hop currentHop){
- double computeCost = ComputeCost.getHOPComputeCost(currentHop);
- double inputAccessCost =
computeHopMemoryAccessCost(currentHop.getInputMemEstimate());
- double ouputAccessCost =
computeHopMemoryAccessCost(currentHop.getOutputMemEstimate());
-
- // Compute total cost assuming:
- // 1. Computation and input access can be overlapped (hence
taking max)
- // 2. Output access must wait for both to complete (hence
adding)
- return Math.max(computeCost, inputAccessCost) + ouputAccessCost;
- }
-
- /**
- * Calculates the memory access cost based on data size and memory
bandwidth.
- *
- * @param memSize Size of data to be accessed (in bytes)
- * @return Time cost for memory access (in seconds)
- */
- private static double computeHopMemoryAccessCost(double memSize) {
- return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH;
- }
-
- /**
- * Calculates the network transfer cost based on data size and network
bandwidth.
- * Used when federation status changes between parent and child plans.
- *
- * @param memSize Size of data to be transferred (in bytes)
- * @return Time cost for network transfer (in seconds)
- */
- private static double computeHopNetworkAccessCost(double memSize) {
- return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH;
- }
-}
+ package org.apache.sysds.hops.fedplanner;
+ import org.apache.commons.lang3.tuple.Pair;
+ import org.apache.sysds.common.Types;
+ import org.apache.sysds.hops.DataOp;
+ import org.apache.sysds.hops.Hop;
+ import org.apache.sysds.hops.cost.ComputeCost;
+ import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+ import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon;
+ import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+ import java.util.LinkedHashMap;
+ import java.util.NoSuchElementException;
+ import java.util.List;
+ import java.util.Map;
+
+ /**
+ * Cost estimator for federated execution plans.
+ * Calculates computation, memory access, and network transfer costs for
federated operations.
+ * Works in conjunction with FederatedMemoTable to evaluate different
execution plan variants.
+ */
+ public class FederatedPlanCostEstimator {
+ // Default value is used as a reasonable estimate since we only need
+ // to compare relative costs between different federated plans
+ // Memory bandwidth for local computations (25 GB/s)
+ private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0;
+ // Network bandwidth for data transfers between federated sites (1
Gbps)
+ private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0;
+
+ // Looks up the pruned LOUT and FOUT plan of every input hop and writes their costs into
+ // the caller-provided arrays: childCumulativeCost[i] = {LOUT cost, FOUT cost} and
+ // childForwardingCost[i] = parent weight * forwarding cost of the child's LOUT plan.
+ public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List<Hop> inputHops,
+         double[][] childCumulativeCost, double[] childForwardingCost) {
+     int idx = 0;
+     for (Hop inputHop : inputHops) {
+         final long inputID = inputHop.getHopID();
+
+         final FedPlan lOutPlan = memoTable.getFedPlanAfterPrune(inputID, FederatedOutput.LOUT);
+         final FedPlan fOutPlan = memoTable.getFedPlanAfterPrune(inputID, FederatedOutput.FOUT);
+
+         // The cumulative cost of the child already includes the weight
+         childCumulativeCost[idx][0] = lOutPlan.getCumulativeCost();
+         childCumulativeCost[idx][1] = fOutPlan.getCumulativeCost();
+
+         // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements?
+         childForwardingCost[idx] = hopCommon.weight * lOutPlan.getForwardingCost();
+
+         idx++;
+     }
+ }
+
+ /**
+  * Computes the self cost and forwarding cost of a hop and stores both on the given HopCommon.
+  * <ul>
+  * <li>TRANSIENTWRITE DataOp: self cost 0, forwarding cost 0.</li>
+  * <li>TRANSIENTREAD DataOp: self cost 0, forwarding cost derived from the output memory estimate.</li>
+  * <li>Any other hop: self cost is the hop weight times its compute/memory cost, forwarding cost is
+  * derived from the output memory estimate; both are divided by the number of parents when the
+  * hop has two or more parents.</li>
+  * </ul>
+  *
+  * @param hopCommon The HopCommon object containing the Hop and its properties.
+  * @return The self cost that was stored on {@code hopCommon}.
+  */
+ public static double computeHopCost(HopCommon hopCommon){
+     final Hop hop = hopCommon.hopRef;
+
+     // TWrite and TRead are meta-data operations, hence selfCost is zero
+     if (hop instanceof DataOp){
+         final DataOp dataOp = (DataOp) hop;
+         final boolean isTWrite = dataOp.getOp() == Types.OpOpData.TRANSIENTWRITE;
+         if (isTWrite || dataOp.getOp() == Types.OpOpData.TRANSIENTREAD) {
+             hopCommon.setSelfCost(0);
+             // TWrite shares its FedOutType with the TRead, so nothing is forwarded;
+             // TRead may differ from its parent, so its forwarding cost is computed.
+             hopCommon.setForwardingCost(isTWrite ? 0 : computeHopForwardingCost(hop.getOutputMemEstimate()));
+             return 0;
+         }
+     }
+
+     // In loops, selfCost is repeated, but forwarding may not be
+     // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?)
+     double selfCost = hopCommon.weight * computeSelfCost(hop);
+     double forwardingCost = computeHopForwardingCost(hop.getOutputMemEstimate());
+
+     // Split both costs evenly across parents when the hop is shared
+     final int parentCount = hop.getParent().size();
+     if (parentCount > 1) {
+         selfCost /= parentCount;
+         forwardingCost /= parentCount;
+     }
+
+     hopCommon.setSelfCost(selfCost);
+     hopCommon.setForwardingCost(forwardingCost);
+     return selfCost;
+ }
+
+ /**
+  * Estimates the cost of executing a single Hop: compute cost overlapped with input
+  * memory access, followed by output memory access.
+  *
+  * @param currentHop The Hop node whose cost needs to be computed
+  * @return max(compute cost, input access cost) + output access cost
+  */
+ private static double computeSelfCost(Hop currentHop){
+     final double compute = ComputeCost.getHOPComputeCost(currentHop);
+     final double readCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate());
+     final double writeCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate());
+
+     // Assumptions:
+     // 1. Computation and input access can be overlapped (hence the max)
+     // 2. Output access must wait for both to complete (hence the addition)
+     final double overlapped = Math.max(compute, readCost);
+     return overlapped + writeCost;
+ }
+
+ /**
+  * Converts a data size into a memory access time using DEFAULT_MBS_MEMORY_BANDWIDTH.
+  *
+  * @param memSize Size of data to be accessed (in bytes)
+  * @return Time cost for memory access (in seconds)
+  */
+ private static double computeHopMemoryAccessCost(double memSize) {
+     // bytes -> MB, then MB / (MB/s)
+     final double sizeInMB = memSize / (1024*1024);
+     return sizeInMB / DEFAULT_MBS_MEMORY_BANDWIDTH;
+ }
+
+ /**
+  * Converts a data size into a network transfer time using DEFAULT_MBS_NETWORK_BANDWIDTH.
+  * Used when federation status changes between parent and child plans.
+  *
+  * @param memSize Size of data to be transferred (in bytes)
+  * @return Time cost for network transfer (in seconds)
+  */
+ private static double computeHopForwardingCost(double memSize) {
+     // bytes -> MB, then MB / (MB/s)
+     final double sizeInMB = memSize / (1024*1024);
+     return sizeInMB / DEFAULT_MBS_NETWORK_BANDWIDTH;
+ }
+
+ /**
+  * Settles every hop that was reached with both LOUT and FOUT by picking one output type for it.
+  * <p>
+  * Conflicts are processed in the iteration order of the given map. For each conflicting hop the
+  * extra cost of forcing LOUT and of forcing FOUT is accumulated over all recorded parent plans;
+  * the cheaper type wins (LOUT on ties), its extra cost is added to
+  * {@code cumulativeAdditionalCost[0]}, and in each parent the first child entry for that hop
+  * whose type differs is rewritten to the winning type. Plan costs are not recomputed here.
+  *
+  * @param memoTable The FederatedMemoTable containing all federated plan variants.
+  * @param conflictFedPlanLinkedMap A map of hop IDs to the parent plans that reference that hop.
+  * @param cumulativeAdditionalCost Single-element accumulator; index 0 is increased by the cost of each resolution.
+  * @return A LinkedHashMap whose keys are the chosen plans of the conflicting hops (values are always true).
+  * @throws NoSuchElementException if a listed parent plan has no child entry for the conflicting hop ID.
+  */
+ public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) {
+     // Chosen plans, handed back to the caller for the next BFS pass
+     LinkedHashMap<FedPlan, Boolean> nextBfsPlans = new LinkedHashMap<>();
+
+     for (Map.Entry<Long, List<FedPlan>> conflictEntry : conflictFedPlanLinkedMap.entrySet()) {
+         long conflictHopID = conflictEntry.getKey();
+         List<FedPlan> parentPlans = conflictEntry.getValue();
+
+         // Both candidate plans of the conflicting hop
+         FedPlan lOutPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT);
+         FedPlan fOutPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT);
+
+         // Extra cost of settling on LOUT / FOUT, relative to what the parents already accounted for
+         double lOutDelta = 0;
+         double fOutDelta = 0;
+
+         // Forwarding is charged at most once per candidate type, after all parents were inspected
+         boolean lOutNeedsForwarding = false;
+         boolean fOutNeedsForwarding = false;
+
+         for (FedPlan parentPlan : parentPlans) {
+             // Output type this parent assumed for the conflicting child
+             Pair<Long, FederatedOutput> assumedPair = null;
+             for (Pair<Long, FederatedOutput> candidate : parentPlan.getChildFedPlans()) {
+                 if (candidate.getLeft().equals(conflictHopID)) {
+                     assumedPair = candidate;
+                     break;
+                 }
+             }
+             if (assumedPair == null)
+                 throw new NoSuchElementException("No matching pair found for ID: " + conflictHopID);
+
+             // (assumed child type / parent type / settled type)
+             // CASE 1. LOUT / LOUT / LOUT: cost unchanged.
+             // CASE 2. LOUT / FOUT / LOUT: cost unchanged, subtract forwarding, add it back once later.
+             // CASE 3. FOUT / LOUT / LOUT: switch cost, subtract forwarding.
+             // CASE 4. FOUT / FOUT / LOUT: switch cost, add forwarding once later.
+             // CASE 5. LOUT / LOUT / FOUT: switch cost, add forwarding once later.
+             // CASE 6. LOUT / FOUT / FOUT: switch cost, subtract forwarding.
+             // CASE 7. FOUT / LOUT / FOUT: cost unchanged, subtract forwarding, add it back once later.
+             // CASE 8. FOUT / FOUT / FOUT: cost unchanged.
+             if (assumedPair.getRight() == FederatedOutput.LOUT) {
+                 // Settling on FOUT swaps the LOUT cumulative cost for the FOUT one
+                 fOutDelta += fOutPlan.getCumulativeCost() - lOutPlan.getCumulativeCost();
+
+                 if (parentPlan.getFedOutType() != FederatedOutput.LOUT) {
+                     // Forwarding was already charged for this parent (CASE 2 / CASE 6)
+                     lOutNeedsForwarding = true;
+                     lOutDelta -= lOutPlan.getForwardingCost();
+                     fOutDelta -= lOutPlan.getForwardingCost();
+                 } else {
+                     // No forwarding so far; FOUT would introduce it (CASE 1 / CASE 5)
+                     fOutNeedsForwarding = true;
+                 }
+             } else {
+                 // Settling on LOUT swaps the FOUT cumulative cost for the LOUT one
+                 lOutDelta += lOutPlan.getCumulativeCost() - fOutPlan.getCumulativeCost();
+
+                 if (parentPlan.getFedOutType() != FederatedOutput.FOUT) {
+                     // Forwarding was already charged for this parent (CASE 3 / CASE 7)
+                     fOutNeedsForwarding = true;
+                     lOutDelta -= lOutPlan.getForwardingCost();
+                     fOutDelta -= lOutPlan.getForwardingCost();
+                 } else {
+                     // No forwarding so far; LOUT would introduce it (CASE 4 / CASE 8)
+                     lOutNeedsForwarding = true;
+                 }
+             }
+         }
+
+         // Charge forwarding once per candidate type if any parent requires it
+         if (lOutNeedsForwarding)
+             lOutDelta += lOutPlan.getForwardingCost();
+         if (fOutNeedsForwarding)
+             fOutDelta += fOutPlan.getForwardingCost();
+
+         // Cheaper candidate wins; LOUT on ties
+         final boolean preferLOut = lOutDelta <= fOutDelta;
+         final FederatedOutput chosenType = preferLOut ? FederatedOutput.LOUT : FederatedOutput.FOUT;
+         cumulativeAdditionalCost[0] += preferLOut ? lOutDelta : fOutDelta;
+         nextBfsPlans.put(preferLOut ? lOutPlan : fOutPlan, true);
+
+         // Update only the recorded output type in each parent; costs are not touched and nothing is done recursively
+         for (FedPlan parentPlan : parentPlans) {
+             List<Pair<Long, FederatedOutput>> children = parentPlan.getChildFedPlans();
+             for (int idx = 0; idx < children.size(); idx++) {
+                 Pair<Long, FederatedOutput> child = children.get(idx);
+                 if (child.getLeft() == conflictHopID && child.getRight() != chosenType) {
+                     children.set(idx, Pair.of(child.getLeft(), chosenType));
+                     break;
+                 }
+             }
+         }
+     }
+     return nextBfsPlans;
+ }
+ }
+
\ No newline at end of file
diff --git
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
index 20485588d3..0bc7d9f84f 100644
---
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
+++
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
@@ -17,75 +17,94 @@
* under the License.
*/
-package org.apache.sysds.test.component.federated;
+ package org.apache.sysds.test.component.federated;
-import java.io.IOException;
-import java.util.HashMap;
+ import java.io.IOException;
+ import java.util.HashMap;
+ import org.junit.Assert;
+ import org.junit.Test;
+ import org.apache.sysds.api.DMLScript;
+ import org.apache.sysds.conf.ConfigurationManager;
+ import org.apache.sysds.conf.DMLConfig;
+ import org.apache.sysds.parser.DMLProgram;
+ import org.apache.sysds.parser.DMLTranslator;
+ import org.apache.sysds.parser.ParserFactory;
+ import org.apache.sysds.parser.ParserWrapper;
+ import org.apache.sysds.test.AutomatedTestBase;
+ import org.apache.sysds.test.TestConfiguration;
+ import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
+
+ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
+ {
+ private static final String TEST_DIR = "functions/federated/privacy/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {}
+
+ @Test
+ public void testFederatedPlanCostEnumerator1() {
runTest("FederatedPlanCostEnumeratorTest1.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator2() {
runTest("FederatedPlanCostEnumeratorTest2.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator3() {
runTest("FederatedPlanCostEnumeratorTest3.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator4() {
runTest("FederatedPlanCostEnumeratorTest4.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator5() {
runTest("FederatedPlanCostEnumeratorTest5.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator6() {
runTest("FederatedPlanCostEnumeratorTest6.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator7() {
runTest("FederatedPlanCostEnumeratorTest7.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator8() {
runTest("FederatedPlanCostEnumeratorTest8.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator9() {
runTest("FederatedPlanCostEnumeratorTest9.dml"); }
-import org.apache.sysds.hops.Hop;
-import org.junit.Assert;
-import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.parser.DMLProgram;
-import org.apache.sysds.parser.DMLTranslator;
-import org.apache.sysds.parser.ParserFactory;
-import org.apache.sysds.parser.ParserWrapper;
-import org.apache.sysds.test.AutomatedTestBase;
-import org.apache.sysds.test.TestConfiguration;
-import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
+ @Test
+ public void testFederatedPlanCostEnumerator10() {
runTest("FederatedPlanCostEnumeratorTest10.dml"); }
-
-public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
-{
- private static final String TEST_DIR = "functions/federated/privacy/";
- private static final String HOME = SCRIPT_DIR + TEST_DIR;
- private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
-
- @Override
- public void setUp() {}
-
- @Test
- public void testFederatedPlanCostEnumerator1() {
runTest("FederatedPlanCostEnumeratorTest1.dml"); }
-
- @Test
- public void testFederatedPlanCostEnumerator2() {
runTest("FederatedPlanCostEnumeratorTest2.dml"); }
-
- @Test
- public void testFederatedPlanCostEnumerator3() {
runTest("FederatedPlanCostEnumeratorTest3.dml"); }
-
- // Todo: Need to write test scripts for the federated version
- private void runTest( String scriptFilename ) {
- int index = scriptFilename.lastIndexOf(".dml");
- String testName = scriptFilename.substring(0, index > 0 ? index
: scriptFilename.length());
- TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
- addTestConfiguration(testName, testConfig);
- loadTestConfiguration(testConfig);
-
- try {
- DMLConfig conf = new
DMLConfig(getCurConfigFile().getPath());
- ConfigurationManager.setLocalConfig(conf);
-
- //read script
- String dmlScriptString = DMLScript.readDMLScript(true,
HOME + scriptFilename);
-
- //parsing and dependency analysis
- ParserWrapper parser = ParserFactory.createParser();
- DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new
HashMap<>());
- DMLTranslator dmlt = new DMLTranslator(prog);
- dmlt.liveVariableAnalysis(prog);
- dmlt.validateParseTree(prog);
- dmlt.constructHops(prog);
- dmlt.rewriteHopsDAG(prog);
- dmlt.constructLops(prog);
-
- Hop hops =
prog.getStatementBlocks().get(0).getHops().get(0);
-
FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true);
- }
- catch (IOException e) {
- e.printStackTrace();
- Assert.fail();
- }
- }
-}
+ // Todo: Need to write test scripts for the federated version
+ private void runTest( String scriptFilename ) {
+ int index = scriptFilename.lastIndexOf(".dml");
+ String testName = scriptFilename.substring(0, index > 0 ?
index : scriptFilename.length());
+ TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+ addTestConfiguration(testName, testConfig);
+ loadTestConfiguration(testConfig);
+
+ try {
+ DMLConfig conf = new
DMLConfig(getCurConfigFile().getPath());
+ ConfigurationManager.setLocalConfig(conf);
+
+ //read script
+ String dmlScriptString = DMLScript.readDMLScript(true,
HOME + scriptFilename);
+
+ //parsing and dependency analysis
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new
HashMap<>());
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ dmlt.rewriteHopsDAG(prog);
+ dmlt.constructLops(prog);
+ dmlt.rewriteLopDAG(prog);
+
+ FederatedPlanCostEnumerator.enumerateProgram(prog,
true);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ }
+ }
+
\ No newline at end of file
diff --git
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py
new file mode 100644
index 0000000000..b083c77913
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py
@@ -0,0 +1,268 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import sys
+import re
+import networkx as nx
+import matplotlib.pyplot as plt
+
+try:
+ import pygraphviz
+ from networkx.drawing.nx_agraph import graphviz_layout
+ HAS_PYGRAPHVIZ = True
+except ImportError:
+ HAS_PYGRAPHVIZ = False
+ print("[WARNING] pygraphviz not found. Please install via 'pip install
pygraphviz'.\n"
+ "If not installed, we will use an alternative layout
(spring_layout).")
+
+
+def parse_line(line: str):
+ """
+ Parse a single line from the trace file to extract:
+ - Node ID
+ - Operation (hop name)
+ - Kind (e.g., FOUT, LOUT, NREF)
+ - Total cost
+ - Weight
+ - Refs (list of IDs that this node depends on)
+ """
+
+ # 1) Match a node ID in the form of "(R)" or "(<number>)"
+ match_id = re.match(r'^\((R|\d+)\)', line)
+ if not match_id:
+ return None
+ node_id = match_id.group(1)
+
+ # 2) The remaining string after the node ID
+ after_id = line[match_id.end():].strip()
+
+ # Extract operation (hop name) before the first "["
+ match_label = re.search(r'^(.*?)\s*\[', after_id)
+ if match_label:
+ operation = match_label.group(1).strip()
+ else:
+ operation = after_id.strip()
+
+ # 3) Extract the kind (content inside the first pair of brackets "[]")
+ match_bracket = re.search(r'\[([^\]]+)\]', after_id)
+ if match_bracket:
+ kind = match_bracket.group(1).strip()
+ else:
+ kind = ""
+
+ # 4) Extract total and weight from the content inside curly braces "{}"
+ total = ""
+ weight = ""
+ match_curly = re.search(r'\{([^}]+)\}', line)
+ if match_curly:
+ curly_content = match_curly.group(1)
+ m_total = re.search(r'Total:\s*([\d\.]+)', curly_content)
+ m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content)
+ if m_total:
+ total = m_total.group(1)
+ if m_weight:
+ weight = m_weight.group(1)
+
+ # 5) Extract reference nodes: take the first parenthesized, comma-separated
list of numbers found after the node ID
+ match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id)
+ if match_refs:
+ ref_str = match_refs.group(1)
+ refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()]
+ else:
+ refs = []
+
+ return {
+ 'node_id': node_id,
+ 'operation': operation,
+ 'kind': kind,
+ 'total': total,
+ 'weight': weight,
+ 'refs': refs
+ }
+
+
+def build_dag_from_file(filename: str):
+ """
+ Read a trace file line by line and build a directed acyclic graph (DAG)
using NetworkX.
+ """
+ G = nx.DiGraph()
+ with open(filename, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+
+ info = parse_line(line)
+ if not info:
+ continue
+
+ node_id = info['node_id']
+ operation = info['operation']
+ kind = info['kind']
+ total = info['total']
+ weight = info['weight']
+ refs = info['refs']
+
+ # Add node with attributes
+ G.add_node(node_id, label=operation, kind=kind, total=total,
weight=weight)
+
+ # Add edges from references to this node
+ for r in refs:
+ if r not in G:
+ G.add_node(r, label=r, kind="", total="", weight="")
+ G.add_edge(r, node_id)
+ return G
+
+
+def main():
+ """
+ Main function that:
+ - Reads a filename from command-line arguments
+ - Builds a DAG from the file
+ - Draws and displays the DAG using matplotlib
+ """
+
+ # Get filename from command-line argument
+ if len(sys.argv) < 2:
+ print("[ERROR] No filename provided.\nUsage: python
plot_federated_dag.py <filename>")
+ sys.exit(1)
+ filename = sys.argv[1]
+
+ print(f"[INFO] Running with filename '{filename}'")
+
+ # Build the DAG
+ G = build_dag_from_file(filename)
+
+ # Print debug info: nodes and edges
+ print("Nodes:", G.nodes(data=True))
+ print("Edges:", list(G.edges()))
+
+ # Decide on layout
+ if HAS_PYGRAPHVIZ:
+ # graphviz_layout with rankdir=BT (bottom to top), etc.
+ pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5
-Granksep=0.8')
+ else:
+ # Fallback layout if pygraphviz is not installed
+ pos = nx.spring_layout(G, seed=42)
+
+ # Dynamically adjust figure size based on number of nodes
+ node_count = len(G.nodes())
+ fig_width = 10 + node_count / 10.0
+ fig_height = 6 + node_count / 10.0
+ plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
+ ax = plt.gca()
+ ax.set_facecolor('white')
+
+ # Generate labels for each node in the format:
+ # node_id: operation_name
+ # C<total> (W<weight>)
+ labels = {
+ n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total',
'')} (W{G.nodes[n].get('weight', '')})"
+ for n in G.nodes()
+ }
+
+ # Function to determine color based on 'kind'
+ def get_color(n):
+ k = G.nodes[n].get('kind', '').lower()
+ if k == 'fout':
+ return 'tomato'
+ elif k == 'lout':
+ return 'dodgerblue'
+ elif k == 'nref':
+ return 'mediumpurple'
+ else:
+ return 'mediumseagreen'
+
+ # Determine node shapes based on operation name:
+ # - '^' (triangle) if the label contains "twrite"
+ # - 's' (square) if the label contains "tread"
+ # - 'o' (circle) otherwise
+ triangle_nodes = [n for n in G.nodes() if 'twrite' in
G.nodes[n].get('label', '').lower()]
+ square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label',
'').lower()]
+ other_nodes = [
+ n for n in G.nodes()
+ if 'twrite' not in G.nodes[n].get('label', '').lower() and
+ 'tread' not in G.nodes[n].get('label', '').lower()
+ ]
+
+ # Colors for each group
+ triangle_colors = [get_color(n) for n in triangle_nodes]
+ square_colors = [get_color(n) for n in square_nodes]
+ other_colors = [get_color(n) for n in other_nodes]
+
+ # Draw nodes group-wise
+ node_collection_triangle = nx.draw_networkx_nodes(
+ G, pos, nodelist=triangle_nodes, node_size=800,
+ node_color=triangle_colors, node_shape='^', ax=ax
+ )
+ node_collection_square = nx.draw_networkx_nodes(
+ G, pos, nodelist=square_nodes, node_size=800,
+ node_color=square_colors, node_shape='s', ax=ax
+ )
+ node_collection_other = nx.draw_networkx_nodes(
+ G, pos, nodelist=other_nodes, node_size=800,
+ node_color=other_colors, node_shape='o', ax=ax
+ )
+
+ # Set z-order for nodes, edges, and labels
+ node_collection_triangle.set_zorder(1)
+ node_collection_square.set_zorder(1)
+ node_collection_other.set_zorder(1)
+
+ edge_collection = nx.draw_networkx_edges(G, pos, arrows=True,
arrowstyle='->', ax=ax)
+ if isinstance(edge_collection, list):
+ for ec in edge_collection:
+ ec.set_zorder(2)
+ else:
+ edge_collection.set_zorder(2)
+
+ label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
ax=ax)
+ for text in label_dict.values():
+ text.set_zorder(3)
+
+ # Set the title
+ plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")
+
+ # Describe the node-label format in the top-right corner of the axes
+ plt.text(1, 1,
+ "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
+ fontsize=12, ha='right', va='top', transform=ax.transAxes)
+
+ # Color legend for the 'kind' values (LOUT, FOUT, NREF)
+ plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes)
+ plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes)
+ plt.scatter(0.31, 0.95, color='mediumpurple', s=200,
transform=ax.transAxes)
+
+ plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center',
transform=ax.transAxes)
+ plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center',
transform=ax.transAxes)
+ plt.text(0.34, 0.95, "NREF", fontsize=12, va='center',
transform=ax.transAxes)
+
+ plt.axis("off")
+
+ # Save the plot to a file with the same name as the input file, but with a
.png extension
+ output_filename = f"{filename.rsplit('.', 1)[0]}.png"
+ plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')
+
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml
new file mode 100644
index 0000000000..276de7bde9
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Recursive function: Calculate factorial
+factorialUser = function(int n) return (int result) {
+ if (n <= 1) {
+ result = 1; # base case
+ } else {
+ result = n * factorialUser(n - 1); # recursive call
+ }
+}
+
+number = 5;
+fact_result = factorialUser(number);
+print("Factorial of " + number + ": " + fact_result);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml
new file mode 100644
index 0000000000..06533df144
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(7,10,10);
+if (sum(a) > 0.5)
+ b = a * 2;
+else
+ b = a * 3;
+c = sqrt(b);
+print(sum(c));
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml
new file mode 100644
index 0000000000..2721bbcbaf
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+for( i in 1:100 )
+{
+ b = i + 1;
+ print(b);
+}
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml
new file mode 100644
index 0000000000..b95ae1b5bb
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(7, rows=10, cols=10)
+b = rand(rows = 1, cols = ncol(A), min = 1, max = 2);
+i = 0
+
+while (sum(b) < i) {
+ i = i + 1
+ b = b + i
+ A = A * A
+ s = b %*% A
+ print(mean(s))
+}
+c = sqrt(A)
+print(sum(c))
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml
new file mode 100644
index 0000000000..e3efaa2851
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = 1;
+
+parfor( i in 1:10 )
+{
+ b = i + a;
+ #print(b);
+}
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml
new file mode 100644
index 0000000000..1587ff613b
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = rand();
+b= rand();
+c= rand();
+d= rand();
+e= rand();
+f= rand();
+h= rand();
+i= rand();
+
+if (a < 30){
+ a = a + b;
+
+ if (a < 20) {
+ a = a * c;
+ } else {
+ a = a + d;
+
+ if (a < 10) {
+ a = a + e;
+ } else {
+ a = a + f;
+ }
+ }
+} else {
+ a = a + h;
+}
+c = a + i;
+print(mean(c))
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml
new file mode 100644
index 0000000000..b5713374f2
--- /dev/null
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Define UDFs
+meanUser = function (matrix[double] A) return (double m) {
+ m = sum(A)/nrow(A)
+}
+
+minMaxUser = function( matrix[double] M) return (double minVal, double maxVal)
{
+ minVal = min(M);
+ maxVal = max(M);
+}
+
+# Recursive function: Calculate factorial
+factorialUser = function(int n) return (int result) {
+ if (n <= 1) {
+ result = 1; # base case
+ } else {
+ result = n * factorialUser(n - 1); # recursive call
+ }
+}
+
+# Main script
+# 1. Create matrix and calculate statistics
+M = rand(rows=4, cols=4, min=1, max=5); # 4x4 random matrix
+avg = meanUser(M);
+[min_val, max_val] = minMaxUser(M);
+
+# 2. Call recursive function (factorial)
+number = 5;
+fact_result = factorialUser(number);
+
+# 3. Print results
+print("=== Matrix Statistics ===");
+print("Average: " + avg);
+print("Min: " + min_val + ", Max: " + max_val);
+
+print("\n=== Recursive Function ===");
+print("Factorial of " + number + ": " + fact_result);
\ No newline at end of file