This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0bca04715c [SYSTEMDS-3790] New Federated Planner MemoTable
0bca04715c is described below
commit 0bca04715c3280d5ab748976494b39ebba46889d
Author: min-guk <[email protected]>
AuthorDate: Thu Nov 21 09:24:36 2024 +0100
[SYSTEMDS-3790] New Federated Planner MemoTable
Closes #2141.
---
.../apache/sysds/hops/fedplanner/MemoTable.java | 160 ++++++++++++++++++
.../test/component/federated/MemoTableTest.java | 186 +++++++++++++++++++++
2 files changed, 346 insertions(+)
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
new file mode 100644
index 0000000000..8fce06b33e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+
+/**
+ * A Memoization Table for managing federated plans (`FedPlan`) based on
+ * combinations of Hops and FTypes. Each combination is mapped to a list
+ * of possible execution plans, allowing for pruning and optimization.
+ */
+public class MemoTable {
+
+ // Maps combinations of Hop ID and FType to lists of FedPlans
+ private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable
= new HashMap<>();
+
+ /**
+ * Represents a federated execution plan with its cost and associated
references.
+ */
+ public static class FedPlan {
+ @SuppressWarnings("unused")
+ private final Hop hopRef; // The
associated Hop object
+ private final double cost; // Cost of this
federated plan
+ @SuppressWarnings("unused")
+ private final List<Pair<Long, FType>> planRefs; // References
to dependent plans
+
+ public FedPlan(Hop hopRef, double cost, List<Pair<Long, FType>>
planRefs) {
+ this.hopRef = hopRef;
+ this.cost = cost;
+ this.planRefs = planRefs;
+ }
+
+ public double getCost() {
+ return cost;
+ }
+ }
+
+ /**
+ * Adds a single FedPlan to the memo table for a given Hop and FType.
+ * If the entry already exists, the new FedPlan is appended to the list.
+ *
+ * @param hop The Hop object.
+ * @param fType The associated FType.
+ * @param fedPlan The FedPlan to add.
+ */
+ public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) {
+ if (contains(hop, fType)) {
+ List<FedPlan> fedPlanList = get(hop, fType);
+ fedPlanList.add(fedPlan);
+ } else {
+ List<FedPlan> fedPlanList = new ArrayList<>();
+ fedPlanList.add(fedPlan);
+ hopMemoTable.put(new ImmutablePair<>(hop.getHopID(),
fType), fedPlanList);
+ }
+ }
+
+ /**
+ * Adds multiple FedPlans to the memo table for a given Hop and FType.
+ * If the entry already exists, the new FedPlans are appended to the
list.
+ *
+ * @param hop The Hop object.
+ * @param fType The associated FType.
+ * @param newFedPlanList The list of FedPlans to add.
+ */
+ public void addFedPlanList(Hop hop, FType fType, List<FedPlan>
fedPlanList) {
+ if (contains(hop, fType)) {
+ List<FedPlan> prevFedPlanList = get(hop, fType);
+ prevFedPlanList.addAll(fedPlanList);
+ } else {
+ hopMemoTable.put(new ImmutablePair<>(hop.getHopID(),
fType), fedPlanList);
+ }
+ }
+
+ /**
+ * Retrieves the list of FedPlans associated with a given Hop and FType.
+ *
+ * @param hop The Hop object.
+ * @param fType The associated FType.
+ * @return The list of FedPlans, or null if no entry exists.
+ */
+ public List<FedPlan> get(Hop hop, FType fType) {
+ return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(),
fType));
+ }
+
+ /**
+ * Checks if the memo table contains an entry for a given Hop and FType.
+ *
+ * @param hop The Hop object.
+ * @param fType The associated FType.
+ * @return True if the entry exists, false otherwise.
+ */
+ public boolean contains(Hop hop, FType fType) {
+ return hopMemoTable.containsKey(new
ImmutablePair<>(hop.getHopID(), fType));
+ }
+
+ /**
+ * Prunes the FedPlans associated with a specific Hop and FType,
+ * keeping only the plan with the minimum cost.
+ *
+ * @param hop The Hop object.
+ * @param fType The associated FType.
+ */
+ public void prunePlan(Hop hop, FType fType) {
+ prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(),
fType)));
+ }
+
+ /**
+ * Prunes all entries in the memo table, retaining only the minimum-cost
+ * FedPlan for each entry.
+ */
+ public void pruneAll() {
+ for (Map.Entry<Pair<Long, FType>, List<FedPlan>> entry :
hopMemoTable.entrySet()) {
+ prunePlan(entry.getValue());
+ }
+ }
+
+ /**
+ * Prunes the given list of FedPlans to retain only the plan with the
minimum cost.
+ *
+ * @param fedPlanList The list of FedPlans to prune.
+ */
+ private void prunePlan(List<FedPlan> fedPlanList) {
+ if (fedPlanList.size() > 1) {
+ // Find the FedPlan with the minimum cost
+ FedPlan minCostPlan = fedPlanList.stream()
+ .min(Comparator.comparingDouble(plan ->
plan.cost))
+ .orElse(null);
+
+ // Retain only the minimum cost plan
+ fedPlanList.clear();
+ fedPlanList.add(minCostPlan);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
new file mode 100644
index 0000000000..e3928c1263
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.federated;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.hops.fedplanner.MemoTable;
+import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan;
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.when;
+
+public class MemoTableTest {
+
+ private MemoTable memoTable;
+
+ @Mock
+ private Hop mockHop1;
+
+ @Mock
+ private Hop mockHop2;
+
+ private java.util.Random rand;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.openMocks(this);
+ memoTable = new MemoTable();
+
+ // Set up unique IDs for mock Hops
+ when(mockHop1.getHopID()).thenReturn(1L);
+ when(mockHop2.getHopID()).thenReturn(2L);
+
+ // Initialize random generator with fixed seed for reproducible
tests
+ rand = new java.util.Random(42);
+ }
+
+ @Test
+ public void testAddAndGetSingleFedPlan() {
+ // Initialize test data
+ List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+ FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
+
+ // Verify initial state
+ List<FedPlan> result = memoTable.get(mockHop1,
FTypes.FType.FULL);
+ assertNull("Initial FedPlan list should be null before adding
any plans", result);
+
+ // Add single FedPlan
+ memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
+
+ // Verify after addition
+ result = memoTable.get(mockHop1, FTypes.FType.FULL);
+ assertNotNull("FedPlan list should exist after adding a plan",
result);
+ assertEquals("FedPlan list should contain exactly one plan", 1,
result.size());
+ assertEquals("FedPlan cost should be exactly 10.0", 10.0,
result.get(0).getCost(), 0.001);
+ }
+
+ @Test
+ public void testAddMultipleDuplicatedFedPlans() {
+ // Initialize test data with duplicate costs
+ List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+ List<FedPlan> fedPlans = new ArrayList<>();
+ fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique
cost
+ fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First
duplicate
+ fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second
duplicate
+
+ // Add multiple plans including duplicates
+ memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans);
+
+ // Verify handling of duplicate plans
+ List<FedPlan> result = memoTable.get(mockHop1,
FTypes.FType.FULL);
+ assertNotNull("FedPlan list should exist after adding multiple
plans", result);
+ assertEquals("FedPlan list should maintain all plans including
duplicates", 3, result.size());
+ }
+
+ @Test
+ public void testContains() {
+ // Initialize test data
+ List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+ FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
+
+ // Verify initial state
+ assertFalse("MemoTable should not contain any entries
initially",
+ memoTable.contains(mockHop1, FTypes.FType.FULL));
+
+ // Add plan and verify presence
+ memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
+
+ assertTrue("MemoTable should contain entry after adding
FedPlan",
+ memoTable.contains(mockHop1, FTypes.FType.FULL));
+ assertFalse("MemoTable should not contain entries for different
Hop",
+ memoTable.contains(mockHop2, FTypes.FType.FULL));
+ }
+
+ @Test
+ public void testPrunePlanPruneAll() {
+ // Initialize base test data
+ List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+ // Create separate FedPlan lists for independent testing of
each Hop
+ List<FedPlan> fedPlans1 = new ArrayList<>(); // Plans for
mockHop1
+ List<FedPlan> fedPlans2 = new ArrayList<>(); // Plans for
mockHop2
+
+ // Generate random cost FedPlans for both Hops
+ double minCost = Double.MAX_VALUE;
+ int size = 100;
+ for(int i = 0; i < size; i++) {
+ double cost = rand.nextDouble() * 1000; // Random cost
between 0 and 1000
+ fedPlans1.add(new FedPlan(mockHop1, cost, planRefs));
+ fedPlans2.add(new FedPlan(mockHop2, cost, planRefs));
+ minCost = Math.min(minCost, cost);
+ }
+
+ // Add FedPlan lists to MemoTable
+ memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL,
fedPlans1);
+ memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL,
fedPlans2);
+
+ // Test selective pruning on mockHop1
+ memoTable.prunePlan(mockHop1, FTypes.FType.FULL);
+
+ // Get results for verification
+ List<FedPlan> result1 = memoTable.get(mockHop1,
FTypes.FType.FULL);
+ List<FedPlan> result2 = memoTable.get(mockHop2,
FTypes.FType.FULL);
+
+ // Verify selective pruning results
+ assertNotNull("Pruned mockHop1 should maintain a FedPlan list",
result1);
+ assertEquals("Pruned mockHop1 should contain exactly one
minimum cost plan", 1, result1.size());
+ assertEquals("Pruned mockHop1's plan should have the minimum
cost", minCost, result1.get(0).getCost(), 0.001);
+
+ // Verify unpruned Hop state
+ assertNotNull("Unpruned mockHop2 should maintain a FedPlan
list", result2);
+ assertEquals("Unpruned mockHop2 should maintain all original
plans", size, result2.size());
+
+ // Add additional plans to both Hops
+ for(int i = 0; i < size; i++) {
+ double cost = rand.nextDouble() * 1000;
+ memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new
FedPlan(mockHop1, cost, planRefs));
+ memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new
FedPlan(mockHop2, cost, planRefs));
+ minCost = Math.min(minCost, cost);
+ }
+
+ // Test global pruning
+ memoTable.pruneAll();
+
+ // Verify global pruning results
+ assertNotNull("mockHop1 should maintain a FedPlan list after
global pruning", result1);
+ assertEquals("mockHop1 should contain exactly one minimum cost
plan after global pruning",
+ 1, result1.size());
+ assertEquals("mockHop1's plan should have the global minimum
cost",
+ minCost, result1.get(0).getCost(), 0.001);
+
+ assertNotNull("mockHop2 should maintain a FedPlan list after
global pruning", result2);
+ assertEquals("mockHop2 should contain exactly one minimum cost
plan after global pruning",
+ 1, result2.size());
+ assertEquals("mockHop2's plan should have the global minimum
cost",
+ minCost, result2.get(0).getCost(), 0.001);
+ }
+}