This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ccd6a36 [SYSTEMDS-3018] Conflict handling federated plan enumeration
ccd6a36 is described below
commit ccd6a360dc072e62f36f4f76e76234c8c3079487
Author: sebwrede <[email protected]>
AuthorDate: Mon Dec 27 21:34:25 2021 +0100
[SYSTEMDS-3018] Conflict handling federated plan enumeration
Federated plan enumeration builds a global data flow graph and computes
optimal plans per interesting property (fed-out, local-out). In trees,
we could compose optimal plans purely from the optimal plans of inputs, but
in DAGs the optimal input plans of n-ary operations might not agree on the
decisions for common subexpressions.
We mitigate such conflicting (fed-out vs. local-out) decisions by keeping the
data federated, but additionally spawning an asynchronous prefetch
operation that also brings the data into local memory if at least one
subplan prefers local intermediates.
Closes #1476.
Co-authored-by: arnabp <[email protected]>
---
scripts/builtin/normalize.dml | 2 +-
src/main/java/org/apache/sysds/hops/Hop.java | 32 +++++-
.../sysds/hops/cost/FederatedCostEstimator.java | 8 +-
.../java/org/apache/sysds/hops/cost/HopRel.java | 27 ++---
.../hops/ipa/IPAPassRewriteFederatedPlan.java | 45 +++++---
.../java/org/apache/sysds/hops/ipa/MemoTable.java | 118 +++++++++++++++++++++
src/main/java/org/apache/sysds/lops/Lop.java | 20 ++++
.../java/org/apache/sysds/lops/compile/Dag.java | 42 +++++++-
.../controlprogram/federated/FederationMap.java | 7 +-
.../instructions/cp/BroadcastCPInstruction.java | 6 +-
.../instructions/cp/PrefetchCPInstruction.java | 6 +-
...sTask.java => TriggerRemoteOperationsTask.java} | 15 ++-
.../sysds/runtime/util/CommonThreadPool.java | 8 +-
.../java/org/apache/sysds/utils/Statistics.java | 8 ++
.../fedplanning/FederatedMultiplyPlanningTest.java | 56 ++++++++--
.../fedplanning/FederatedMultiplyPlanningTest7.dml | 29 +++++
.../FederatedMultiplyPlanningTest7Reference.dml | 27 +++++
.../fedplanning/FederatedMultiplyPlanningTest8.dml | 31 ++++++
.../FederatedMultiplyPlanningTest8Reference.dml | 29 +++++
19 files changed, 445 insertions(+), 71 deletions(-)
diff --git a/scripts/builtin/normalize.dml b/scripts/builtin/normalize.dml
index e2a32be..f7b86c2 100644
--- a/scripts/builtin/normalize.dml
+++ b/scripts/builtin/normalize.dml
@@ -39,6 +39,6 @@ m_normalize = function(Matrix[Double] X)
# compute feature ranges for transformations
cmin = colMins(X);
cmax = colMaxs(X);
- # normalize features to range [0,1]
+ # normalize features to range [0,1]
Y = normalizeApply(X, cmin, cmax);
}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 336beb0..f47fcff 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -93,6 +93,14 @@ public abstract class Hop implements ParseInfo {
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
protected FederatedCost _federatedCost = new FederatedCost();
+
+ /**
+ * Field defining if prefetch should be activated for operation.
+ * When prefetch is activated, the output will be transferred from
+ * remote federated sites to local before one of the subsequent
+ * local operations.
+ */
+ protected boolean activatePrefetch;
// Estimated size for the output produced from this Hop in bytes
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
@@ -187,6 +195,21 @@ public abstract class Hop implements ParseInfo {
public void setFederatedOutput(FederatedOutput federatedOutput){
_federatedOutput = federatedOutput;
}
+
+ /**
+ * Activate prefetch of HOP.
+ */
+ public void activatePrefetch(){
+ activatePrefetch = true;
+ }
+
+ /**
+ * Checks if prefetch is activated for this hop.
+ * @return true if prefetch is activated
+ */
+ public boolean prefetchActivated(){
+ return activatePrefetch;
+ }
public void resetExecType()
{
@@ -352,6 +375,8 @@ public abstract class Hop implements ParseInfo {
//propagate federated output configuration to lops
if( isFederated() )
getLops().setFederatedOutput(_federatedOutput);
+ if ( prefetchActivated() )
+ getLops().activatePrefetch();
//Step 1: construct reblock lop if required (output of hop)
constructAndSetReblockLopIfRequired();
@@ -869,8 +894,11 @@ public abstract class Hop implements ParseInfo {
* This method only has an effect if FEDERATED_COMPILATION is activated.
* Federated compilation is activated in OptimizerUtils.
*/
- protected void updateETFed(){
- if ( someInputFederated() || isFederatedDataOp() )
+ protected void updateETFed() {
+ boolean localOut = hasLocalOutput();
+ boolean fedIn = getInput().stream().anyMatch(
+ in -> in.hasFederatedOutput() &&
!(in.prefetchActivated() && localOut));
+ if( isFederatedDataOp() || fedIn )
_etype = ExecType.FED;
}
diff --git
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 7089ed8..96a33d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.cost;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ipa.MemoTable;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
@@ -33,8 +34,6 @@ import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
/**
* Cost estimator for federated executions with methods and constants for
going through DML programs to estimate costs.
@@ -200,10 +199,9 @@ public class FederatedCostEstimator {
* @param hopRelMemo memo table of HopRels for calculating input costs
* @return cost estimation of Hop DAG starting from given root HopRel
*/
- public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
// Check if root is in memo table.
- if ( hopRelMemo.containsKey(root.hopRef.getHopID())
- &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ if ( hopRelMemo.containsHopRel(root) ){
return root.getCostObject();
}
else {
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index 6191a6c..b1cc6dd 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -21,15 +21,14 @@ package org.apache.sysds.hops.cost;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ipa.MemoTable;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
-import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@@ -52,7 +51,7 @@ public class HopRel {
* @param fedOut FederatedOutput value assigned to this HopRel
* @param hopRelMemo memo table storing other HopRels including the
inputs of associatedHop
*/
- public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut,
Map<Long, List<HopRel>> hopRelMemo){
+ public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut,
MemoTable hopRelMemo){
hopRef = associatedHop;
this.fedOut = fedOut;
setInputDependency(hopRelMemo);
@@ -108,27 +107,15 @@ public class HopRel {
* @param hopRelMemo memo table storing HopRels
* @return FOUT HopRel found in hopRelMemo
*/
- private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>>
hopRelMemo){
- return
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
- }
-
- /**
- * Get the HopRel with minimum cost for given hop
- * @param hopRelMemo memo table storing HopRels
- * @param input hop for which minimum cost HopRel is found
- * @return HopRel with minimum cost for given hop
- */
- private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop
input){
- return hopRelMemo.get(input.getHopID()).stream()
- .min(Comparator.comparingDouble(a -> a.cost.getTotal()))
- .orElseThrow(() -> new DMLException("No element in Memo
Table found for input"));
+ private HopRel getFOUTHopRel(Hop hop, MemoTable hopRelMemo){
+ return hopRelMemo.getFederatedOutputAlternativeOrNull(hop);
}
/**
* Set valid and optimal input dependency for this HopRel as a field.
* @param hopRelMemo memo table storing input HopRels
*/
- private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+ private void setInputDependency(MemoTable hopRelMemo){
if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
if ( fedOut == FederatedOutput.FOUT &&
!hopRef.isFederatedDataOp() ) {
int lowestFOUTIndex = 0;
@@ -152,7 +139,7 @@ public class HopRel {
for(int i = 0; i < hopRef.getInput().size();
i++) {
if(i != lowestFOUTIndex) {
Hop input = hopRef.getInput(i);
- inputHopRels[i] =
getMinOfInput(hopRelMemo, input);
+ inputHopRels[i] =
hopRelMemo.getMinCostAlternative(input);
}
else {
inputHopRels[i] =
lowestFOUTHopRel;
@@ -162,7 +149,7 @@ public class HopRel {
} else {
inputDependency.addAll(
hopRef.getInput().stream()
- .map(input ->
getMinOfInput(hopRelMemo, input))
+
.map(hopRelMemo::getMinCostAlternative)
.collect(Collectors.toList()));
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 8c8df49..59333ab 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -19,7 +19,6 @@
package org.apache.sysds.hops.ipa;
-import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
@@ -45,10 +44,9 @@ import
org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
-import java.util.Map;
+import java.util.Set;
/**
* This rewrite generates a federated execution plan by estimating and setting
costs and the FederatedOutput values of
@@ -57,7 +55,8 @@ import java.util.Map;
*/
public class IPAPassRewriteFederatedPlan extends IPAPass {
- private final static Map<Long, List<HopRel>> hopRelMemo = new
HashMap<>();
+ private final static MemoTable hopRelMemo = new MemoTable();
+ private final static Set<Long> hopRelUpdatedFinal = new HashSet<>();
/**
* Indicates if an IPA pass is applicable for the current configuration.
@@ -66,7 +65,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param fgraph function call graph
* @return true if federated compilation is activated.
*/
- @Override public boolean isApplicable(FunctionCallGraph fgraph) {
+ @Override
+ public boolean isApplicable(FunctionCallGraph fgraph) {
return OptimizerUtils.FEDERATED_COMPILATION;
}
@@ -79,7 +79,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param fcallSizes function call size infos
* @return false since the function call graph never has to be rebuilt
*/
- @Override public boolean rewriteProgram(DMLProgram prog,
FunctionCallGraph fgraph,
+ @Override
+ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes) {
rewriteStatementBlocks(prog, prog.getStatementBlocks());
return false;
@@ -189,9 +190,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param root hop for which FederatedOutput needs to be set
*/
private void setFinalFedout(Hop root) {
- HopRel optimalRootHopRel =
hopRelMemo.get(root.getHopID()).stream()
- .min(Comparator.comparingDouble(HopRel::getCost))
- .orElseThrow(() -> new DMLException("Hop root " + root
+ " has no feasible federated output alternatives"));
+ HopRel optimalRootHopRel =
hopRelMemo.getMinCostAlternative(root);
setFinalFedout(root, optimalRootHopRel);
}
@@ -202,8 +201,21 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param rootHopRel from which FederatedOutput value and cost is
retrieved
*/
private void setFinalFedout(Hop root, HopRel rootHopRel) {
- updateFederatedOutput(root, rootHopRel);
- visitInputDependency(rootHopRel);
+ if ( hopRelUpdatedFinal.contains(root.getHopID()) ){
+ if((rootHopRel.hasLocalOutput() ^
root.hasLocalOutput()) && hopRelMemo.hasFederatedOutputAlternative(root)){
+ // Update with FOUT alternative without
visiting inputs
+ updateFederatedOutput(root,
hopRelMemo.getFederatedOutputAlternative(root));
+ root.activatePrefetch();
+ }
+ else {
+ // Update without visiting inputs
+ updateFederatedOutput(root, rootHopRel);
+ }
+ }
+ else {
+ updateFederatedOutput(root, rootHopRel);
+ visitInputDependency(rootHopRel);
+ }
}
/**
@@ -226,6 +238,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
root.setFederatedOutput(updateHopRel.getFederatedOutput());
root.setFederatedCost(updateHopRel.getCostObject());
+ hopRelUpdatedFinal.add(root.getHopID());
}
/**
@@ -257,7 +270,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
*/
private void visitFedPlanHop(Hop currentHop) {
// If the currentHop is in the hopRelMemo table, it means that
it has been visited
- if(hopRelMemo.containsKey(currentHop.getHopID()))
+ if(hopRelMemo.containsHop(currentHop))
return;
// If the currentHop has input, then the input should be
visited depth-first
if(currentHop.getInput() != null &&
currentHop.getInput().size() > 0) {
@@ -273,7 +286,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
}
if(hopRels.isEmpty())
hopRels.add(new HopRel(currentHop,
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
- hopRelMemo.put(currentHop.getHopID(), hopRels);
+ hopRelMemo.put(currentHop, hopRels);
}
/**
@@ -319,8 +332,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
if(associatedHop instanceof AggUnaryOp &&
associatedHop.isScalar())
return false;
// It can only be FOUT if at least one of the inputs are FOUT,
except if it is a federated DataOp
- if(associatedHop.getInput().stream().noneMatch(input ->
hopRelMemo.get(input.getHopID()).stream()
- .anyMatch(HopRel::hasFederatedOutput)) &&
!associatedHop.isFederatedDataOp())
+
if(associatedHop.getInput().stream().noneMatch(hopRelMemo::hasFederatedOutputAlternative)
+ && !associatedHop.isFederatedDataOp())
return false;
return true;
}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
b/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
new file mode 100644
index 0000000..c1aeff6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.cost.HopRel;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Memoization of federated execution alternatives.
+ * This memoization data structure is used when generating optimal federated
execution plans.
+ * The alternative executions are stored as HopRels and the methods of this
class are used to
+ * add, update, and retrieve the alternatives.
+ */
+public class MemoTable {
+ //TODO refactoring: could we generalize the privacy and codegen memo
tables into
+ // a generic implementation (e.g., MemoTable<HopRel>) that can be
reused in both?
+
+ /**
+ * Map holding the relation between Hop IDs and execution plan
alternatives.
+ */
+ private final static Map<Long, List<HopRel>> hopRelMemo = new
HashMap<>();
+
+ /**
+ * Get the HopRel with minimum cost for given root hop
+ * @param root hop for which minimum cost HopRel is found
+ * @return HopRel with minimum cost for given hop
+ */
+ public HopRel getMinCostAlternative(Hop root){
+ return hopRelMemo.get(root.getHopID()).stream()
+ .min(Comparator.comparingDouble(HopRel::getCost))
+ .orElseThrow(() -> new DMLException("Hop root " + root
+ " has no feasible federated output alternatives"));
+ }
+
+ /**
+ * Checks if any of the federated execution alternatives for the given
root hop has federated output.
+ * @param root hop for which execution alternatives are checked
+ * @return true if root has federated output as an execution alternative
+ */
+ public boolean hasFederatedOutputAlternative(Hop root){
+ return
hopRelMemo.get(root.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput);
+ }
+
+ /**
+ * Get the federated output alternative for given root hop or throw
exception if not found.
+ * @param root hop for which federated output HopRel is returned
+ * @return federated output HopRel for given root hop
+ */
+ public HopRel getFederatedOutputAlternative(Hop root){
+ return getFederatedOutputAlternativeOptional(root).orElseThrow(
+ () -> new DMLException("Hop root " + root + " has no
FOUT alternative"));
+ }
+
+ /**
+ * Get the federated output alternative for given root hop or null if
not found.
+ * @param root hop for which federated output HopRel is returned
+ * @return federated output HopRel for given root hop
+ */
+ public HopRel getFederatedOutputAlternativeOrNull(Hop root){
+ return getFederatedOutputAlternativeOptional(root).orElse(null);
+ }
+
+ private Optional<HopRel> getFederatedOutputAlternativeOptional(Hop
root){
+ return
hopRelMemo.get(root.getHopID()).stream().filter(HopRel::hasFederatedOutput).findFirst();
+ }
+
+ /**
+ * Memoize hopRels related to given root.
+ * @param root for which hopRels are added
+ * @param hopRels execution alternatives related to the given root
+ */
+ public void put(Hop root, List<HopRel> hopRels){
+ hopRelMemo.put(root.getHopID(), hopRels);
+ }
+
+ /**
+ * Checks if root hop has been added to memo.
+ * @param root hop
+ * @return true if root has been added to memo.
+ */
+ public boolean containsHop(Hop root){
+ return hopRelMemo.containsKey(root.getHopID());
+ }
+
+ /**
+ * Checks if given HopRel has been added to memo.
+ * @param root HopRel
+ * @return true if root HopRel has been added to memo.
+ */
+ public boolean containsHopRel(HopRel root){
+ return containsHop(root.getHopRef())
+ && hopRelMemo.get(root.getHopRef().getHopID()).stream()
+ .anyMatch(h -> h.getFederatedOutput() ==
root.getFederatedOutput());
+ }
+}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java
b/src/main/java/org/apache/sysds/lops/Lop.java
index 7da091f..dda7cdd 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -117,6 +117,14 @@ public abstract class Lop
protected PrivacyConstraint privacyConstraint;
/**
+ * Field defining if prefetch should be activated for operation.
+ * When prefetch is activated, the output will be transferred from
+ * remote federated sites to local before one of the subsequent
+ * local operations.
+ */
+ protected boolean activatePrefetch;
+
+ /**
* Enum defining if the output of the operation should be forced
federated, forced local or neither.
* If it is FOUT, the output should be kept at federated sites.
* If it is LOUT, the output should be retrieved by the coordinator.
@@ -316,9 +324,21 @@ public abstract class Lop
return privacyConstraint;
}
+ public void activatePrefetch(){
+ activatePrefetch = true;
+ }
+
+ public boolean prefetchActivated(){
+ return activatePrefetch;
+ }
+
public void setFederatedOutput(FederatedOutput fedOutput){
_fedOutput = fedOutput;
}
+
+ public FederatedOutput getFederatedOutput(){
+ return _fedOutput;
+ }
public void setConsumerCount(int cc) {
consumerCount = cc;
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 9b7f1e5..76090f0 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -203,7 +203,9 @@ public class Dag<N extends Lop>
List<Lop> node_pf = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS
? addPrefetchLop(node_v) : node_v;
List<Lop> node_bc = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS
? addBroadcastLop(node_pf) : node_pf;
// TODO: Merge via a single traversal of the nodes
-
+
+ prefetchFederated(node_bc);
+
// do greedy grouping of operations
ArrayList<Instruction> inst =
doPlainInstructionGen(sb, node_bc);
@@ -211,6 +213,42 @@ public class Dag<N extends Lop>
// cleanup instruction (e.g., create packed rmvar instructions)
return cleanupInstructions(inst);
}
+
+ /**
+ * Checks if the given input needs to be prefetched before executing
given lop.
+ * @param input to check for prefetch
+ * @param lop which possibly needs the input prefetched
+ * @return true if given input needs to be prefetched before lop
+ */
+ private boolean inputNeedsPrefetch(Lop input, Lop lop){
+ return input.prefetchActivated() && lop.getExecType() !=
ExecType.FED
+ && input.getFederatedOutput().isForcedFederated();
+ }
+
+ /**
+ * Add prefetch lop between input and lop.
+ * @param input to be prefetched
+ * @param lop for which the given input needs to be prefetched
+ */
+ private void addFedPrefetchLop(Lop input, Lop lop){
+ UnaryCP prefetch = new UnaryCP(input, OpOp1.PREFETCH,
input.getDataType(), input.getValueType(), ExecType.CP);
+ prefetch.addOutput(lop);
+ lop.replaceInput(input, prefetch);
+ input.removeOutput(lop);
+ }
+
+ /**
+ * Add prefetch lops where needed.
+ * @param lops for which prefetch lops could be added.
+ */
+ private void prefetchFederated(List<Lop> lops){
+ for ( Lop lop : lops ){
+ for ( Lop input : lop.getInputs() ){
+ if ( inputNeedsPrefetch(input, lop) )
+ addFedPrefetchLop(input, lop);
+ }
+ }
+ }
private static List<Lop> doTopologicalSortTwoLevelOrder(List<Lop> v) {
//partition nodes into leaf/inner nodes and dag root nodes,
@@ -251,7 +289,7 @@ public class Dag<N extends Lop>
}
return nodesWithPrefetch;
}
-
+
private static List<Lop> addBroadcastLop(List<Lop> nodes) {
List<Lop> nodesWithBroadcast = new ArrayList<>();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 39309d6..680f608 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -591,17 +591,16 @@ public class FederationMap {
}
// derive output type
switch(_type) {
- case FULL:
- _type = FType.FULL;
- break;
case ROW:
_type = FType.COL;
break;
case COL:
_type = FType.ROW;
break;
+ case FULL:
case PART:
- _type = FType.PART;
+ // FULL and PART are not changed
+ break;
default:
_type = FType.OTHER;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index d29ef4c..2cc9d7c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -44,8 +44,8 @@ public class BroadcastCPInstruction extends
UnaryCPInstruction {
public void processInstruction(ExecutionContext ec) {
ec.setVariable(output.getName(), ec.getMatrixObject(input1));
- if (CommonThreadPool.triggerRDDPool == null)
- CommonThreadPool.triggerRDDPool =
Executors.newCachedThreadPool();
- CommonThreadPool.triggerRDDPool.submit(new
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+ if (CommonThreadPool.triggerRemoteOPsPool == null)
+ CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
+ CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 00f8ac2..9d95a58 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -49,8 +49,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
// If the next instruction which takes this output as an input
comes before
// the prefetch thread triggers, that instruction will start
the operations.
// In that case this Prefetch instruction will act like a NOOP.
- if (CommonThreadPool.triggerRDDPool == null)
- CommonThreadPool.triggerRDDPool =
Executors.newCachedThreadPool();
- CommonThreadPool.triggerRDDPool.submit(new
TriggerRDDOperationsTask(ec.getMatrixObject(output)));
+ if (CommonThreadPool.triggerRemoteOPsPool == null)
+ CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
+ CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
similarity index 76%
rename from
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
index 0a4d1b5..6eea8c9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRDDOperationsTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
@@ -23,10 +23,10 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.utils.Statistics;
-public class TriggerRDDOperationsTask implements Runnable {
+public class TriggerRemoteOperationsTask implements Runnable {
MatrixObject _prefetchMO;
- public TriggerRDDOperationsTask(MatrixObject mo) {
+ public TriggerRemoteOperationsTask(MatrixObject mo) {
_prefetchMO = mo;
}
@@ -36,14 +36,19 @@ public class TriggerRDDOperationsTask implements Runnable {
synchronized (_prefetchMO) {
// Having this check if operations are pending inside
the
// critical section safeguards against concurrent rmVar.
- if (_prefetchMO.isPendingRDDOps()) {
+ if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isFederated()) {
+ // TODO: Add robust runtime constraints for
federated prefetch
// Execute and bring the result to local
_prefetchMO.acquireReadAndRelease();
prefetched = true;
}
}
- if (DMLScript.STATISTICS && prefetched)
- Statistics.incSparkAsyncPrefetchCount(1);
+ if (DMLScript.STATISTICS && prefetched) {
+ if (_prefetchMO.isFederated())
+ Statistics.incFedAsyncPrefetchCount(1);
+ else
+ Statistics.incSparkAsyncPrefetchCount(1);
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index 2fc8049..abb1ced 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -49,7 +49,7 @@ public class CommonThreadPool implements ExecutorService
private static final int size =
InfrastructureAnalyzer.getLocalParallelism();
private static final ExecutorService shared = ForkJoinPool.commonPool();
private final ExecutorService _pool;
- public static ExecutorService triggerRDDPool = null;
+ public static ExecutorService triggerRemoteOPsPool = null;
public CommonThreadPool(ExecutorService pool) {
_pool = pool;
@@ -80,10 +80,10 @@ public class CommonThreadPool implements ExecutorService
}
public static void shutdownAsyncRDDPool() {
- if (triggerRDDPool != null) {
+ if (triggerRemoteOPsPool != null) {
//shutdown prefetch/broadcast thread pool
- triggerRDDPool.shutdown();
- triggerRDDPool = null;
+ triggerRemoteOPsPool.shutdown();
+ triggerRemoteOPsPool = null;
}
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 3940921..a3da2a7 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -177,6 +177,7 @@ public class Statistics
private static final LongAdder federatedGetCount = new LongAdder();
private static final LongAdder federatedExecuteInstructionCount = new
LongAdder();
private static final LongAdder federatedExecuteUDFCount = new
LongAdder();
+ private static final LongAdder fedAsyncPrefetchCount = new LongAdder();
private static LongAdder numNativeFailures = new LongAdder();
public static LongAdder numNativeLibMatrixMultCalls = new LongAdder();
@@ -443,6 +444,10 @@ public class Statistics
}
}
+ public static void incFedAsyncPrefetchCount(long c) {
+ fedAsyncPrefetchCount.add(c);
+ }
+
public static void startCompileTimer() {
if( DMLScript.STATISTICS )
compileStartTime = System.nanoTime();
@@ -550,6 +555,7 @@ public class Statistics
federatedGetCount.reset();
federatedExecuteInstructionCount.reset();
federatedExecuteUDFCount.reset();
+ fedAsyncPrefetchCount.reset();
DMLCompressionStatistics.reset();
}
@@ -1220,6 +1226,8 @@ public class Statistics
sb.append("Federated Execute (Inst, UDF):\t" +
federatedExecuteInstructionCount.longValue() + "/" +
federatedExecuteUDFCount.longValue() +
".\n");
+ sb.append("Federated prefetch count:\t" +
+ fedAsyncPrefetchCount.longValue() +
".\n");
}
if( transformEncoderCount.longValue() > 0) {
//TODO: Cleanup and condense
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e0ef884..1e59b86 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -47,6 +47,8 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
private final static String TEST_NAME_4 =
"FederatedMultiplyPlanningTest4";
private final static String TEST_NAME_5 =
"FederatedMultiplyPlanningTest5";
private final static String TEST_NAME_6 =
"FederatedMultiplyPlanningTest6";
+ private final static String TEST_NAME_7 =
"FederatedMultiplyPlanningTest7";
+ private final static String TEST_NAME_8 =
"FederatedMultiplyPlanningTest8";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -64,6 +66,8 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
addTestConfiguration(TEST_NAME_4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_7, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_8, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
}
@Parameterized.Parameters
@@ -112,6 +116,18 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
federatedTwoMatricesSingleNodeTest(TEST_NAME_6,
expectedHeavyHitters);
}
+ @Test
+ public void federatedMultiplyDoubleHop() {
+ String[] expectedHeavyHitters = new String[]{"fed_*",
"fed_fedinit", "fed_r'", "fed_ba+*"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_7,
expectedHeavyHitters);
+ }
+
+ @Test
+ public void federatedMultiplyDoubleHop2() {
+ String[] expectedHeavyHitters = new String[]{"fed_fedinit",
"fed_ba+*"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_8,
expectedHeavyHitters);
+ }
+
private void writeStandardMatrix(String matrixName, long seed){
writeStandardMatrix(matrixName, seed, new
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
@@ -158,6 +174,14 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
writeRowFederatedVector("Y1", 44);
writeRowFederatedVector("Y2", 21);
}
+ else if ( testName.equals(TEST_NAME_8) ){
+ writeColStandardMatrix("X1", 42, null);
+ writeColStandardMatrix("X2", 1340, null);
+ writeColStandardMatrix("Y1", 44, null);
+ writeColStandardMatrix("Y2", 21, null);
+ writeColStandardMatrix("W1", 76, null);
+ writeColStandardMatrix("W2", 11, null);
+ }
else {
writeStandardMatrix("X1", 42);
writeStandardMatrix("X2", 1340);
@@ -201,12 +225,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
- if ( testName.equals(TEST_NAME_4) ||
testName.equals(TEST_NAME_5) ){
- programArgs = new String[] {"-stats","-explain",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
- "Y1=" + input("Y1"),
- "Y2=" + input("Y2"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
- }
+ rewriteRealProgramArgs(testName, port1, port2);
runTest(true, false, null, -1);
OptimizerUtils.FEDERATED_COMPILATION = false;
@@ -215,6 +234,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
fullDMLScriptName = HOME + testName + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
+ rewriteReferenceProgramArgs(testName);
runTest(true, false, null, -1);
// compare via files
@@ -228,5 +248,29 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
+
+ private void rewriteRealProgramArgs(String testName, int port1, int
port2){
+ if ( testName.equals(TEST_NAME_4) ||
testName.equals(TEST_NAME_5) ){
+ programArgs = new String[] {"-stats","-explain",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+ } else if ( testName.equals(TEST_NAME_8) ){
+ programArgs = new String[] {"-stats","-explain",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")),
+ "W1=" + input("W1"),
+ "W2=" + input("W2"),
+ "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ }
+ }
+
+ private void rewriteReferenceProgramArgs(String testName){
+ if ( testName.equals(TEST_NAME_8) ){
+ programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "W1=" + input("W1"), "W2="
+ input("W2"), "Z=" + expected("Z")};
+ }
+ }
}
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml
new file mode 100644
index 0000000..5d4a1d3
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0),
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0),
list($r, $c)))
+Z0 = X * Y
+Z = t(Z0) %*% X
+Z1 = Z %*% t(colSums(Z0))
+write(Z1, $Z)
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml
new file mode 100644
index 0000000..76212a0
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest7Reference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+Z0 = X * Y
+Z = t(Z0) %*% X
+Z1 = Z %*% t(colSums(Z0))
+write(Z1, $Z)
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml
new file mode 100644
index 0000000..5f3223c
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2),
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2),
list($r, $c)))
+W = cbind(read($W1), read($W2))
+Z1 = Y
+Z2 = Z1 %*% t(X)
+Z3 = Z1 %*% t(W)
+Z4 = sum(Z3) * sum(Z2)
+write(Z4, $Z)
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml
new file mode 100644
index 0000000..c8c1797
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest8Reference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = cbind(read($X1), read($X2))
+Y = cbind(read($Y1), read($Y2))
+W = cbind(read($W1), read($W2))
+Z1 = Y
+Z2 = Z1 %*% t(X)
+Z3 = Z1 %*% t(W)
+Z4 = sum(Z3) * sum(Z2)
+write(Z4, $Z)