This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 438a8ed69c [SYSTEMDS-3018] Dynamic Recompilation of Functions with Federated Planner
438a8ed69c is described below
commit 438a8ed69c2d311276615cf65964d3bce82a91ce
Author: Kevin Innerebner <[email protected]>
AuthorDate: Thu Jun 30 17:37:03 2022 +0200
[SYSTEMDS-3018] Dynamic Recompilation of Functions with Federated Planner
This commit recomputes the federated plan if a function is called multiple
times with different federated arguments.
Closes #1649.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 4 +-
src/main/java/org/apache/sysds/hops/Hop.java | 2 +
.../sysds/hops/fedplanner/AFederatedPlanner.java | 10 ++
.../hops/fedplanner/FederatedPlannerCostbased.java | 120 +++++++++++--
.../hops/fedplanner/FederatedPlannerFedAll.java | 19 ++-
.../apache/sysds/hops/fedplanner/MemoTable.java | 2 +-
.../RewriteAlgebraicSimplificationDynamic.java | 5 +-
.../controlprogram/FunctionProgramBlock.java | 54 ++++++
.../controlprogram/federated/FederationMap.java | 5 +-
.../parfor/opt/OptTreeConverter.java | 3 +
.../runtime/instructions/FEDInstructionParser.java | 9 +
.../fed/AggregateUnaryFEDInstruction.java | 6 +-
.../instructions/fed/IndexingFEDInstruction.java | 5 +-
.../sysds/runtime/util/ProgramConverter.java | 3 +-
.../fedplanning/FederatedDynamicPlanningTest.java | 187 +++++++++++++++++++++
.../FederatedDynamicFunctionPlanningTest.dml | 37 ++++
...deratedDynamicFunctionPlanningTestReference.dml | 35 ++++
17 files changed, 484 insertions(+), 22 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 16b42c4840..dd04307229 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -645,6 +645,7 @@ public class AggBinaryOp extends MultiThreadedHop {
int k =
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
matmultCP = new
MatMultCP(getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getDataType(), getValueType(), et, k);
+ updateLopFedOut(matmultCP);
}
setOutputDimensions(matmultCP);
}
@@ -668,7 +669,8 @@ public class AggBinaryOp extends MultiThreadedHop {
new Transform(lY, ReOrgOp.TRANS, getDataType(),
getValueType(), inputReorgExecType, k);
tY.getOutputParameters().setDimensions(Y.getDim2(),
Y.getDim1(), getBlocksize(), Y.getNnz());
setLineNumbers(tY);
- updateLopFedOut(tY);
+ if (Y.hasFederatedOutput())
+ updateLopFedOut(tY);
//matrix mult
Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(),
getValueType(), et, k); //CP or FED
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 2ee317f35e..4d1dff8f22 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -1594,6 +1594,8 @@ public abstract class Hop implements ParseInfo {
_etype = that._etype;
_etypeForced = that._etypeForced;
+ _federatedOutput = that._federatedOutput;
+ _federatedCost = that._federatedCost;
_outputMemEstimate = that._outputMemEstimate;
_memEstimate = that._memEstimate;
_processingMemEstimate = that._processingMemEstimate;
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index 6486ead712..1b4382bb05 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -37,6 +37,8 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
public abstract class AFederatedPlanner {
@@ -51,6 +53,14 @@ public abstract class AFederatedPlanner {
public abstract void rewriteProgram( DMLProgram prog,
FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes );
+ /**
+ * Selects a federated execution plan for the given function, taking
into account
+ * federation types of the arguments.
+ *
+ * @param function The function statement block to recompile.
+ * @param funcArgs The function arguments.
+ */
+ public abstract void rewriteFunctionDynamic(FunctionStatementBlock
function, LocalVariableMap funcArgs);
protected boolean allowsFederated(Hop hop, Map<Long, FType> fedHops) {
//generically obtain the input FTypes
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 3c33d783ab..368882793a 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -42,6 +42,7 @@ import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
@@ -53,6 +54,11 @@ import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.Explain.ExplainType;
@@ -60,16 +66,17 @@ import org.apache.sysds.utils.Explain.ExplainType;
public class FederatedPlannerCostbased extends AFederatedPlanner {
private static final Log LOG =
LogFactory.getLog(FederatedPlannerCostbased.class.getName());
- private final static MemoTable hopRelMemo = new MemoTable();
+ private final MemoTable hopRelMemo = new MemoTable();
/**
* IDs of hops for which the final fedout value has been set.
*/
- private final static Set<Long> hopRelUpdatedFinal = new HashSet<>();
+ private final Set<Long> hopRelUpdatedFinal = new HashSet<>();
/**
* Terminal hops in DML program given to this rewriter.
*/
- private final static List<Hop> terminalHops = new ArrayList<>();
- private final static Map<String, Hop> transientWrites = new HashMap<>();
+ private final List<Hop> terminalHops = new ArrayList<>();
+ private final Map<String, Hop> transientWrites = new HashMap<>();
+ private LocalVariableMap localVariableMap = new LocalVariableMap();
public List<Hop> getTerminalHops(){
return terminalHops;
@@ -82,7 +89,15 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
setFinalFedouts();
updateExplain();
}
-
+
+ @Override
+ public void rewriteFunctionDynamic(FunctionStatementBlock function,
LocalVariableMap funcArgs) {
+ localVariableMap = funcArgs;
+ rewriteStatementBlock(function.getDMLProg(), function, null);
+ setFinalFedouts();
+ updateExplain();
+ }
+
/**
* Estimates cost and enumerates federated execution plans in
hopRelMemo.
* The method calls the contained statement blocks recursively.
@@ -149,10 +164,18 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
selectFederatedExecutionPlan(forSB.getFromHops(), paramMap);
selectFederatedExecutionPlan(forSB.getToHops(), paramMap);
selectFederatedExecutionPlan(forSB.getIncrementHops(),
paramMap);
+
+ // add iter variable to local variable map allowing us to
reason over transient reads in the HOP DAG
+ DataIdentifier iterVar = ((ForStatement)
forSB.getStatement(0)).getIterablePredicate().getIterVar();
+ LocalVariableMap tmpLocalVariableMap = localVariableMap;
+ localVariableMap = (LocalVariableMap) localVariableMap.clone();
+ // value doesn't matter, localVariableMap is just used to check
if the variable is federated
+ localVariableMap.put(iterVar.getName(), new IntObject(-1));
for(Statement statement : forSB.getStatements()) {
ForStatement forStatement = ((ForStatement) statement);
forStatement.setBody(rewriteStatementBlocks(prog,
forStatement.getBody(), paramMap));
}
+ localVariableMap = tmpLocalVariableMap;
return new ArrayList<>(Collections.singletonList(forSB));
}
@@ -170,22 +193,64 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
selectFederatedExecutionPlan(sbHop, paramMap);
if(sbHop instanceof FunctionOp) {
String funcName = ((FunctionOp)
sbHop).getFunctionName();
+ String funcNamespace = ((FunctionOp)
sbHop).getFunctionNamespace();
Map<String, Hop> funcParamMap =
FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
if ( paramMap != null && funcParamMap
!= null)
funcParamMap.putAll(paramMap);
paramMap = funcParamMap;
- FunctionStatementBlock sbFuncBlock =
prog.getBuiltinFunctionDictionary().getFunction(funcName);
+ FunctionStatementBlock sbFuncBlock =
prog.getFunctionDictionary(funcNamespace)
+ .getFunction(funcName);
rewriteStatementBlock(prog,
sbFuncBlock, paramMap);
+
+ FunctionStatement funcStatement =
(FunctionStatement) sbFuncBlock.getStatement(0);
+ mapFunctionOutputs((FunctionOp) sbHop,
funcStatement);
}
}
}
return new ArrayList<>(Collections.singletonList(sb));
}
+ /**
+ * Saves the HOPs (TWrite) of the function return values for
+ * the variable name used when calling the function.
+ *
+ * Example:
+ * <code>
+ * f = function() return (matrix[double] model) {a = rand(1, 1);}
+ * b = f();
+ * </code>
+ * This function saves the HOP writing to <code>a</code> for identifier
<code>b</code>.
+ *
+ * @param sbHop The <code>FunctionOp</code> for the call
+ * @param funcStatement The <code>FunctionStatement</code> of the
called function
+ */
+ private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement
funcStatement) {
+ for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i)
{
+ Hop outputWrite =
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
+ transientWrites.put(sbHop.getOutputVariableNames()[i],
outputWrite);
+ }
+ }
+
+ /**
+ * Return parameter map containing the mapping from parameter name to
input hop
+ * for all parameters of the function hop.
+ * @param funcOp hop for which the mapping of parameter names to input
hops are made
+ * @return parameter map or empty map if function has no parameters
+ */
+ private Map<String,Hop> getParamMap(FunctionOp funcOp){
+ String[] inputNames = funcOp.getInputVariableNames();
+ Map<String,Hop> paramMap = new HashMap<>();
+ if ( inputNames != null ){
+ for ( int i = 0; i < funcOp.getInput().size(); i++ )
+ paramMap.put(inputNames[i],funcOp.getInput(i));
+ }
+ return paramMap;
+ }
+
/**
* Set final fedouts of all hops starting from terminal hops.
*/
- private void setFinalFedouts(){
+ public void setFinalFedouts(){
for ( Hop root : terminalHops)
setFinalFedout(root);
}
@@ -248,6 +313,7 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
root.setFederatedCost(updateHopRel.getCostObject());
root.setForcedExecType(updateHopRel.getExecType());
forceFixedFedOut(root);
+
LOG.trace("Updated fedOut to " +
updateHopRel.getFederatedOutput() + " for hop "
+ root.getHopID() + " opcode: " + root.getOpString());
hopRelUpdatedFinal.add(root.getHopID());
@@ -340,7 +406,14 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
*/
private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop>
paramMap){
ArrayList<HopRel> hopRels = new ArrayList<>();
- ArrayList<Hop> inputHops = getHopInputs(currentHop, paramMap);
+ ArrayList<Hop> inputHops = currentHop.getInput();
+ if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) ) {
+ inputHops = getTransientInputs(currentHop, paramMap);
+ if (inputHops == null) {
+ // check if transient read on a runtime
variable (only when planning during dynamic recompilation)
+ return createHopRelsFromRuntimeVars(currentHop,
hopRels);
+ }
+ }
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTWRITE) )
transientWrites.put(currentHop.getName(), currentHop);
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.FEDERATED) )
@@ -352,6 +425,24 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
return hopRels;
}
+ private ArrayList<HopRel> createHopRelsFromRuntimeVars(Hop currentHop,
ArrayList<HopRel> hopRels) {
+ Data variable = localVariableMap.get(currentHop.getName());
+ if (variable == null) {
+ throw new DMLRuntimeException("Transient write not
found for " + currentHop);
+ }
+ FederationMap fedMapping = null;
+ if (variable instanceof CacheableData<?>) {
+ CacheableData<?> cacheable = (CacheableData<?>)
variable;
+ fedMapping = cacheable.getFedMapping();
+ }
+ if(fedMapping != null)
+ hopRels.add(new HopRel(currentHop,
FederatedOutput.FOUT, fedMapping.getType(), hopRelMemo,
+ new ArrayList<>()));
+ else
+ hopRels.add(new HopRel(currentHop,
FederatedOutput.LOUT, hopRelMemo, new ArrayList<>()));
+ return hopRels;
+ }
+
/**
* Get transient inputs from either paramMap or transientWrites.
* Inputs from paramMap has higher priority than inputs from
transientWrites.
@@ -360,13 +451,20 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
* @return inputs of currentHop
*/
private ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String,
Hop> paramMap){
+ // FIXME: does not work for function calls (except when the
return names match the variables their results are assigned to)
+ // `model = l2svm(...)` works (because `m_l2svm =
function(...) return(Matrix[Double] model)`),
+ // `m = l2svm(...)` does not
Hop tWriteHop = null;
if ( paramMap != null)
tWriteHop = paramMap.get(currentHop.getName());
if ( tWriteHop == null )
tWriteHop = transientWrites.get(currentHop.getName());
- if ( tWriteHop == null )
- throw new DMLRuntimeException("Transient write not
found for " + currentHop);
+ if ( tWriteHop == null ) {
+ if(localVariableMap.get(currentHop.getName()) != null)
+ return null;
+ else
+ throw new DMLRuntimeException("Transient write
not found for " + currentHop);
+ }
else
return new
ArrayList<>(Collections.singletonList(tWriteHop));
}
@@ -431,7 +529,7 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
/**
* Add hopRelMemo to Explain class to get explain info related to
federated enumeration.
*/
- private void updateExplain(){
+ public void updateExplain(){
if (DMLScript.EXPLAIN == ExplainType.HOPS)
Explain.setMemo(hopRelMemo);
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
index f11bbd1c13..4bf2e5606a 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
@@ -41,6 +41,9 @@ import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.instructions.cp.Data;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
/**
@@ -59,7 +62,21 @@ public class FederatedPlannerFedAll extends
AFederatedPlanner {
for(StatementBlock sb : prog.getStatementBlocks())
rRewriteStatementBlock(sb, fedVars);
}
-
+
+ @Override
+ public void rewriteFunctionDynamic(FunctionStatementBlock function,
LocalVariableMap funcArgs) {
+ Map<String, FType> fedVars = new HashMap<>();
+ for(Map.Entry<String, Data> varName : funcArgs.entrySet()) {
+ Data data = varName.getValue();
+ FType fType = null;
+ if(data instanceof CacheableData<?> &&
((CacheableData<?>) data).isFederated()) {
+ fType = ((CacheableData<?>)
data).getFedMapping().getType();
+ }
+ fedVars.put(varName.getKey(), fType);
+ }
+ rRewriteStatementBlock(function, fedVars);
+ }
+
private void rRewriteStatementBlock(StatementBlock sb, Map<String,
FType> fedVars) {
//TODO currently this rewrite assumes consistent decisions in
conditional control flow
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index 5b399bd499..f84aecc5e8 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -45,7 +45,7 @@ public class MemoTable {
/**
* Map holding the relation between Hop IDs and execution plan
alternatives.
*/
- private final static Map<Long, List<HopRel>> hopRelMemo = new
HashMap<>();
+ private final Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
/**
* Get list of strings representing the different
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index c142dab12f..e181c60a78 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -19,6 +19,7 @@
package org.apache.sysds.hops.rewrite;
+import org.apache.sysds.common.Types;
import static org.apache.sysds.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
import java.util.ArrayList;
@@ -2398,7 +2399,9 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
}
//rewire parent-child operators if rewrite applied
- if( ternop != null ) {
+ if( ternop != null ) {
+ if (right.getForcedExecType() ==
Types.ExecType.FED)
+
ternop.setForcedExecType(Types.ExecType.FED);
HopRewriteUtils.replaceChildReference(parent,
hi, ternop, pos);
hi = ternop;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index f8cded464f..d03c459d0b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -29,11 +29,18 @@ import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FunctionBlock;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.AFederatedPlanner;
+import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.util.ProgramConverter;
@@ -50,6 +57,7 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
private boolean _recompileOnce = false;
private boolean _nondeterministic = false;
+ private boolean _isFedPlan = false;
public FunctionProgramBlock( Program prog, List<DataIdentifier>
inputParams, List<DataIdentifier> outputParams) {
super(prog);
@@ -121,7 +129,14 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
boolean codegen =
ConfigurationManager.isCodegenEnabled();
boolean singlenode =
DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
ResetType reset = (codegen || singlenode) ?
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
+
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false,
reset);
+ if (shouldRunFedPlanner(ec)) {
+
recompileFederatedPlan((LocalVariableMap) ec.getVariables().clone());
+ // recreate instructions/LOPs for new
updated HOPs
+
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false,
reset);
+ }
+
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
@@ -151,6 +166,45 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
checkOutputParameters(ec.getVariables());
}
+ private boolean shouldRunFedPlanner(ExecutionContext ec) {
+ String planner =
ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.FEDERATED_PLANNER);
+ if (!OptimizerUtils.FEDERATED_COMPILATION &&
!FTypes.FederatedPlanner.isCompiled(planner))
+ return false;
+ for (String varName : ec.getVariables().keySet()) {
+ Data variable = ec.getVariable(varName);
+ if (variable instanceof CacheableData<?> &&
((CacheableData<?>) variable).isFederated()) {
+ _isFedPlan = true;
+ return true;
+ }
+ }
+ if (_isFedPlan) {
+ _isFedPlan = false;
+ // current function uses HOPs with FED execution type.
Remove the forced FED execution type by running
+ // planner again
+ return true;
+ }
+ else {
+ return false;
+ }
+ }
+
+ /**
+ * Recompile the HOPs of the function, keeping federation in mind.
+ * @param variableMap The variable map for the function arguments
+ */
+ private void recompileFederatedPlan(LocalVariableMap variableMap) {
+ String splanner = ConfigurationManager.getDMLConfig()
+ .getTextValue(DMLConfig.FEDERATED_PLANNER);
+ AFederatedPlanner planner =
FTypes.FederatedPlanner.isCompiled(splanner) ?
+
FTypes.FederatedPlanner.valueOf(splanner.toUpperCase()).getPlanner() :
+ new FederatedPlannerCostbased();
+ if (planner == null)
+ // unreachable, if planner does not support compilation
cost based would be chosen
+ throw new DMLRuntimeException(
+ "Recompilation chose to apply federation
planner, but configured planner does not support compilation.");
+ planner.rewriteFunctionDynamic((FunctionStatementBlock) _sb,
variableMap);
+ }
+
protected void checkOutputParameters( LocalVariableMap vars )
{
for( DataIdentifier diOut : _outputParams ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index fcef0d7984..d1d766b958 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -652,8 +652,9 @@ public class FederationMap {
while(iter.hasNext()) {
Entry<FederatedRange, FederatedData> e = iter.next();
FederatedRange range = e.getKey();
- long rs = range.getBeginDims()[0], re =
range.getEndDims()[0], cs = range.getBeginDims()[1],
- ce = range.getEndDims()[1];
+ // ends converted from exclusive to inclusive
+ long rs = range.getBeginDims()[0], re =
range.getEndDims()[0] - 1, cs = range.getBeginDims()[1],
+ ce = range.getEndDims()[1] - 1;
boolean overlap = ((ixrange.colStart <= ce) &&
(ixrange.colEnd >= cs) && (ixrange.rowStart <= re) &&
(ixrange.rowEnd >= rs));
if(!overlap)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
index 296d40d72c..38e57429c7 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
@@ -464,6 +464,9 @@ public class OptTreeConverter
node.setExecType(ExecType.CP); break;
case SPARK:
node.setExecType(ExecType.SPARK); break;
+ // TODO: create execution mode for parfor loop
+ case FED:
+ node.setExecType(ExecType.CP); break;
default:
throw new
DMLRuntimeException("Unsupported optnode exec type: "+et);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 8e5e673e1d..17f448588e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -20,6 +20,8 @@
package org.apache.sysds.runtime.instructions;
import org.apache.sysds.lops.Append;
+import org.apache.sysds.lops.LeftIndex;
+import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
import
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
@@ -30,6 +32,7 @@ import
org.apache.sysds.runtime.instructions.fed.CentralMomentFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.CovarianceFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
+import org.apache.sysds.runtime.instructions.fed.IndexingFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuantilePickFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuantileSortFEDInstruction;
@@ -53,6 +56,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "uark+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uack+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uamax" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uacmax" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uamin" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uasqk+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uarsqk+" ,
FEDType.AggregateUnary );
@@ -93,6 +97,9 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "qsort", FEDType.QSort);
String2FEDInstructionType.put( "qpick", FEDType.QPick);
+ String2FEDInstructionType.put(RightIndex.OPCODE,
FEDType.MatrixIndexing);
+ String2FEDInstructionType.put(LeftIndex.OPCODE,
FEDType.MatrixIndexing);
+
String2FEDInstructionType.put(Append.OPCODE, FEDType.Append);
}
@@ -138,6 +145,8 @@ public class FEDInstructionParser extends InstructionParser
return
QuantileSortFEDInstruction.parseInstruction(str, true);
case QPick:
return
QuantilePickFEDInstruction.parseInstruction(str);
+ case MatrixIndexing:
+ return
IndexingFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid
FEDERATED Instruction Type: " + fedtype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 6a89a33eb5..1dfc6ed3b9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -151,9 +151,9 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
*/
private void deriveNewOutputFedMapping(MatrixObject in, MatrixObject
out, FederatedRequest fr1){
//Get agg type
- if ( !(instOpcode.equals("uack+") ||
instOpcode.equals("uark+")) )
- throw new DMLRuntimeException("Operation " + instOpcode
+ " is unknown to FOUT processing");
- boolean isColAgg = instOpcode.equals("uack+");
+ //if ( !(instOpcode.equals("uack+") ||
instOpcode.equals("uark+")) )
+ // throw new DMLRuntimeException("Operation " + instOpcode
+ " is unknown to FOUT processing");
+ boolean isColAgg = ((AggregateUnaryOperator)
_optr).isColAggregate();
//Get partition type
FType inFtype = in.getFedMapping().getType();
//Get fedmap from in
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 4e4448ba97..128fc1d4a6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -185,8 +185,11 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
long id = FederationUtils.getNextFedDataID();
FederatedRequest tmp = new
FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
in.getMetaData().getDataCharacteristics(), in.getDataType());
+ Types.ExecType execType =
InstructionUtils.getExecType(instString);
+ if (execType == Types.ExecType.FED)
+ execType = Types.ExecType.CP;
FederatedRequest[] fr1 =
FederationUtils.callInstruction(instStrings, output, id,
- new CPOperand[] {input1}, new long[] {fedMap.getID()},
InstructionUtils.getExecType(instString));
+ new CPOperand[] {input1}, new long[] {fedMap.getID()},
execType);
fedMap.execute(getTID(), true, tmp);
fedMap.execute(getTID(), true, fr1, new FederatedRequest[0]);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index f4a9b0ecea..8a4476bf63 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -87,6 +87,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.Lineage;
@@ -472,7 +473,7 @@ public class ProgramConverter
try
{
- if( oInst instanceof CPInstruction || oInst instanceof
SPInstruction
+ if( oInst instanceof CPInstruction || oInst instanceof
SPInstruction || oInst instanceof FEDInstruction
|| oInst instanceof GPUInstruction ) {
if( oInst instanceof FunctionCallCPInstruction
&& cpFunctions ) {
FunctionCallCPInstruction tmp =
(FunctionCallCPInstruction) oInst;
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedDynamicPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedDynamicPlanningTest.java
new file mode 100644
index 0000000000..196423afa1
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedDynamicPlanningTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedDynamicPlanningTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederatedDynamicPlanningTest.class.getName());
+
+	private final static String TEST_DIR = "functions/privacy/fedplanning/";
+	private final static String TEST_NAME = "FederatedDynamicFunctionPlanningTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedDynamicPlanningTest.class.getSimpleName() + "/";
+	private static File TEST_CONF_FILE;
+
+	private final static int blocksize = 1024;
+	public final int rows = 1000;
+	public final int cols = 1000;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Test
+	public void runDynamicFullFunctionTest() {
+		// compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+		// some rewrites not being applied. Might be a bug.
+		String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max",
+			"fed_1-*", "fed_>"};
+		setTestConf("SystemDS-config-fout.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	public void runDynamicHeuristicFunctionTest() {
+		// compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+		// some rewrites not being applied. Might be a bug.
+		String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*"};
+		setTestConf("SystemDS-config-heuristic.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	public void runDynamicCostBasedFunctionTest() {
+		// compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+		// some rewrites not being applied. Might be a bug.
+		String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max",
+			"fed_1-*", "fed_>"};
+		setTestConf("SystemDS-config-cost-based.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	private void setTestConf(String testConf) {
+		TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, testConf);
+	}
+
+	private void writeInputMatrices() {
+		writeBinaryVector("A", 42, rows, null);
+		writeStandardMatrix("B1", 65, rows / 2, cols, null);
+		writeStandardMatrix("B2", 75, rows / 2, cols, null);
+		writeStandardMatrix("C1", 13, rows, cols / 2, null);
+		writeStandardMatrix("C2", 17, rows, cols / 2, null);
+	}
+
+	private void writeBinaryVector(String matrixName, long seed, int numRows, PrivacyConstraint privacyConstraint){
+		double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed);
+		for(int i = 0; i < numRows; i++)
+			matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 1, blocksize, numRows);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint);
+	}
+
+	private void writeStandardMatrix(String matrixName, long seed, int numRows, int numCols,
+		PrivacyConstraint privacyConstraint) {
+		double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, seed);
+		writeStandardMatrix(matrixName, numRows, numCols, privacyConstraint, matrix);
+	}
+
+	private void writeStandardMatrix(String matrixName, int numRows, int numCols, PrivacyConstraint privacyConstraint,
+		double[][] matrix) {
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, numCols, blocksize, (long) numRows * numCols);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint);
+	}
+
+	private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
+		// run the federated script, then the local reference script; compare outputs and check heavy hitters
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		Thread t1 = null, t2 = null;
+
+		try {
+			getAndLoadTestConfiguration(testName);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			writeInputMatrices();
+
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			t2 = startLocalFedWorkerThread(port2);
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[] {"-stats", "-explain", "hops", "-nvargs",
+				"r=" + rows, "c=" + cols,
+				"A=" + input("A"),
+				"B1=" + TestUtils.federatedAddress(port1, input("B1")),
+				"B2=" + TestUtils.federatedAddress(port2, input("B2")),
+				"C1=" + TestUtils.federatedAddress(port1, input("C1")),
+				"C2=" + TestUtils.federatedAddress(port2, input("C2")),
+				"lB1=" + input("B1"),
+				"lB2=" + input("B2"),
+				"Z=" + output("Z")};
+			runTest(true, false, null, -1);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + testName + "Reference.dml";
+			programArgs = new String[] {"-nvargs",
+				"r=" + rows, "c=" + cols,
+				"A=" + input("A"),
+				"B1=" + input("B1"),
+				"B2=" + input("B2"),
+				"C1=" + input("C1"),
+				"C2=" + input("C2"),
+				"Z=" + expected("Z")};
+			runTest(true, false, null, -1);
+
+			// compare via files
+			compareResults(1e-9);
+			if(!heavyHittersContainsAllString(expectedHeavyHitters))
+				fail("The following expected heavy hitters are missing: "
+					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+		}
+		finally {
+			TestUtils.shutdownThreads(t1, t2);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	/**
+	 * Override default configuration with custom test configuration to ensure scratch space and local temporary
+	 * directory locations are also updated.
+	 */
+	@Override
+	protected File getConfigTemplateFile() {
+		// Instrumentation in this test's output log to show custom configuration file used for template.
+		LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+		return TEST_CONF_FILE;
+	}
+
+}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTest.dml
new file mode 100644
index 0000000000..17af190ca6
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTest.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($A)
+B = federated(addresses=list($B1, $B2),
+	ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+C = federated(addresses=list($C1, $C2),
+	ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2), list($r, $c)))
+D = rbind(read($lB1), read($lB2))
+# D holds B's data read locally, so l2svm is called once with a local and once with a federated X
+model = l2svm(X=D, Y=A)
+Z = model
+# TODO: call would perform `Y %*% X` internally due to rewrite, but MM with RHS column-federated is not implemented
+#model = l2svm(X=C, Y=A)
+#Z = rbind(Z, model)
+model = l2svm(X=B, Y=A)
+Z = rbind(Z, model)
+
+write(Z, $Z)
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTestReference.dml
new file mode 100644
index 0000000000..3b64f10ead
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedDynamicFunctionPlanningTestReference.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($A)
+B = rbind(read($B1), read($B2))
+C = cbind(read($C1), read($C2))
+D = rbind(read($B1), read($B2))
+
+# same call order as the federated script (D first, then B); B and D hold identical data here
+model = l2svm(X=D, Y=A)
+Z = model
+#model = l2svm(X=C, Y=A)
+#Z = rbind(Z, model)
+model = l2svm(X=B, Y=A)
+Z = rbind(Z, model)
+
+write(Z, $Z)