This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7f49828 [SYSTEMDS-3309/10] Federated planners fed_all and
fed_heuristic
7f49828 is described below
commit 7f49828865c3239c220c6d8078eaeec3af3a17af
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 12 22:33:46 2022 +0100
[SYSTEMDS-3309/10] Federated planners fed_all and fed_heuristic
Besides the existing runtime conversion of federated instructions, this
patch introduces the first generation of federated planners,
specifically two heuristic planners FED_ALL and FED_HEURISTIC, which
compile federated instructions during initial compilation, and without
runtime conversion. FED_ALL tries to keep everything it can federated,
while FED_HEURISTIC consolidates federated vectors (aggregates) in order
to avoid the high latency of these federated vector operations.
Furthermore, this patch includes a basic refactoring that centralizes the
shared runtime and compiler types, introduces the necessary planner
abstractions, and fixes selected aspects (hop-local rewrites) of
compiling federated runtime instructions.
---
conf/SystemDS-config.xml.template | 2 +-
src/main/java/org/apache/sysds/conf/DMLConfig.java | 2 +-
.../java/org/apache/sysds/hops/AggBinaryOp.java | 8 +-
src/main/java/org/apache/sysds/hops/Hop.java | 2 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 2 +-
.../sysds/hops/cost/FederatedCostEstimator.java | 2 +-
.../java/org/apache/sysds/hops/cost/HopRel.java | 2 +-
.../sysds/hops/fedplanner/AFederatedPlanner.java | 112 +++++++
.../org/apache/sysds/hops/fedplanner/FTypes.java | 138 ++++++++
.../hops/fedplanner/FederatedPlannerAllFed.java | 137 ++++++++
.../FederatedPlannerCostbased.java} | 52 +--
.../hops/fedplanner/FederatedPlannerHeuristic.java | 45 +++
.../sysds/hops/{ipa => fedplanner}/MemoTable.java | 2 +-
.../hops/ipa/IPAPassRewriteFederatedPlan.java | 373 +--------------------
.../hops/rewrite/RewriteFederatedExecution.java | 11 +-
.../controlprogram/caching/CacheableData.java | 2 +-
.../controlprogram/caching/MatrixObject.java | 3 +-
.../controlprogram/context/ExecutionContext.java | 2 +-
.../controlprogram/federated/FederationMap.java | 84 +----
.../controlprogram/federated/FederationUtils.java | 30 +-
.../paramserv/dp/DataPartitionFederatedScheme.java | 5 +-
.../fed/AggregateBinaryFEDInstruction.java | 4 +-
.../fed/AggregateTernaryFEDInstruction.java | 4 +-
.../fed/AggregateUnaryFEDInstruction.java | 19 +-
.../instructions/fed/AppendFEDInstruction.java | 4 +-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 4 +-
.../instructions/fed/CtableFEDInstruction.java | 2 +-
.../fed/CumulativeOffsetFEDInstruction.java | 4 +-
.../instructions/fed/FEDInstructionUtils.java | 2 +-
.../instructions/fed/IndexingFEDInstruction.java | 7 +-
.../instructions/fed/InitFEDInstruction.java | 2 +-
.../instructions/fed/MMChainFEDInstruction.java | 2 +-
.../runtime/instructions/fed/MMFEDInstruction.java | 4 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 11 +-
.../fed/QuantilePickFEDInstruction.java | 3 +-
.../fed/QuantileSortFEDInstruction.java | 3 +-
.../fed/QuaternaryWCeMMFEDInstruction.java | 4 +-
.../fed/QuaternaryWDivMMFEDInstruction.java | 4 +-
.../fed/QuaternaryWSLossFEDInstruction.java | 4 +-
.../fed/QuaternaryWSigmoidFEDInstruction.java | 4 +-
.../fed/QuaternaryWUMMFEDInstruction.java | 4 +-
.../instructions/fed/ReorgFEDInstruction.java | 11 +-
.../instructions/fed/SpoofFEDInstruction.java | 4 +-
.../instructions/fed/TernaryFEDInstruction.java | 10 +-
.../instructions/fed/TsmmFEDInstruction.java | 4 +-
.../fed/UnaryMatrixFEDInstruction.java | 4 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 3 +-
.../pipelines/BuiltinTopkEvaluateTest.java | 1 -
.../fedplanning/FederatedCostEstimatorTest.java | 4 +-
.../fedplanning/FederatedL2SVMPlanningTest.java | 1 +
50 files changed, 574 insertions(+), 579 deletions(-)
diff --git a/conf/SystemDS-config.xml.template
b/conf/SystemDS-config.xml.template
index b4ba733..c43ba71 100644
--- a/conf/SystemDS-config.xml.template
+++ b/conf/SystemDS-config.xml.template
@@ -100,6 +100,6 @@
<!-- enables compiler assisted partial rewrites (e.g. Append-TSMM) -->
<sysds.lineage.compilerassisted>true</sysds.lineage.compilerassisted>
- <!-- set the federated plan generator (none, [runtime], compile_allfed,
compile_heuristic, compile_costbased) -->
+ <!-- set the federated plan generator (none, [runtime], compile_fed_all,
compile_fed_heuristic, compile_cost_based) -->
<sysds.federated.planner>runtime</sysds.federated.planner>
</root>
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index ea6bef0..f9400fa 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -43,7 +43,7 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.SpoofCompiler.CompilerType;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
import org.apache.sysds.hops.codegen.SpoofCompiler.PlanSelector;
-import
org.apache.sysds.hops.rewrite.RewriteFederatedExecution.FederatedPlanner;
+import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.lops.compile.linearization.ILinearize.DagLinearization;
import org.apache.sysds.parser.ParseException;
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 31309c1..6d0cff4 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -434,7 +434,7 @@ public class AggBinaryOp extends MultiThreadedHop {
}
updateETFed();
-
+
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
@@ -640,7 +640,7 @@ public class AggBinaryOp extends MultiThreadedHop {
}
else {
if( isLeftTransposeRewriteApplicable(true) ) {
- matmultCP =
constructCPLopsMMWithLeftTransposeRewrite();
+ matmultCP =
constructCPLopsMMWithLeftTransposeRewrite(et);
}
else {
int k =
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
@@ -655,7 +655,7 @@ public class AggBinaryOp extends MultiThreadedHop {
setLops(matmultCP);
}
- private Lop constructCPLopsMMWithLeftTransposeRewrite()
+ private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et)
{
Hop X = getInput().get(0).getInput().get(0); //guaranteed to
exists
Hop Y = getInput().get(1);
@@ -671,7 +671,7 @@ public class AggBinaryOp extends MultiThreadedHop {
updateLopFedOut(tY);
//matrix mult
- Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(),
getValueType(), ExecType.CP, k);
+ Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(),
getValueType(), et, k); //CP or FED
mult.getOutputParameters().setDimensions(Y.getDim2(),
X.getDim2(), getBlocksize(), getNnz());
setLineNumbers(mult);
updateLopFedOut(mult);
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index f91fad9..af89b4d 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -254,7 +254,7 @@ public abstract class Hop implements ParseInfo {
{
if(DMLScript.USE_ACCELERATOR && DMLScript.FORCE_ACCELERATOR &&
isGPUEnabled())
_etypeForced = ExecType.GPU; // enabled with -gpu force
option
- else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
) {
+ else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
&& _etypeForced != ExecType.FED ) {
if(OptimizerUtils.isMemoryBasedOptLevel() &&
DMLScript.USE_ACCELERATOR && isGPUEnabled()) {
// enabled with -exec singlenode -gpu option
_etypeForced = findExecTypeByMemEstimate();
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 19d0dec..f96cd02 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -34,8 +34,8 @@ import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
-import
org.apache.sysds.hops.rewrite.RewriteFederatedExecution.FederatedPlanner;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
diff --git
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 400caa9..d0d7b5f 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -22,7 +22,7 @@ package org.apache.sysds.hops.cost;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.ipa.MemoTable;
+import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index bd5ee85..77d8356 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -21,7 +21,7 @@ package org.apache.sysds.hops.cost;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.ipa.MemoTable;
+import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
new file mode 100644
index 0000000..50c5f46
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import java.util.Map;
+
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataExpression;
+
+public abstract class AFederatedPlanner {
+
+ /**
+ * Selects a federated execution plan for the given program
+ * by setting the forced execution type.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @param fcallSizes function call graph sizes
+ */
+ public abstract void rewriteProgram( DMLProgram prog,
+ FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes );
+
+
+ protected boolean allowsFederated(Hop hop, Map<Long, FType> fedHops) {
+ //generically obtain the input FTypes
+ FType[] ft = new FType[hop.getInput().size()];
+ for( int i=0; i<hop.getInput().size(); i++ )
+ ft[i] = fedHops.get(hop.getInput(i).getHopID());
+
+ //handle specific operators
+ if( hop instanceof AggBinaryOp ) {
+ return (ft[0] != null && ft[1] == null)
+ || (ft[0] == null && ft[1] != null);
+ }
+ else if( hop instanceof BinaryOp &&
!hop.getDataType().isScalar() ) {
+ return (ft[0] != null && ft[1] == null)
+ || (ft[0] == null && ft[1] != null)
+ || (ft[0] != null && ft[0] == ft[1]);
+ }
+ else if(ft.length==1 && ft[0] != null) {
+ return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS);
+ }
+
+ return false;
+ }
+
+ protected FType getFederatedOut(Hop hop, Map<Long, FType> fedHops) {
+ //generically obtain the input FTypes
+ FType[] ft = new FType[hop.getInput().size()];
+ for( int i=0; i<hop.getInput().size(); i++ )
+ ft[i] = fedHops.get(hop.getInput(i).getHopID());
+
+ //handle specific operators
+ if( hop instanceof AggBinaryOp ) {
+ if( ft[0] != null )
+ return ft[0] == FType.ROW ? FType.ROW : null;
+ else if( ft[0] != null )
+ return ft[0] == FType.COL ? FType.COL : null;
+ }
+ else if( hop instanceof BinaryOp )
+ return ft[0] != null ? ft[0] : ft[1];
+ else if( HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS) )
+ return ft[0] == FType.ROW ? FType.COL : FType.COL;
+
+ return null;
+ }
+
+ protected FType deriveFType(DataOp fedInit) {
+ Hop ranges =
fedInit.getInput(fedInit.getParameterIndex(DataExpression.FED_RANGES));
+ boolean rowPartitioned = true;
+ boolean colPartitioned = true;
+ for( int i=0; i<ranges.getInput().size()/2; i++ ) { // workers
+ Hop beg = ranges.getInput(2*i);
+ Hop end = ranges.getInput(2*i+1);
+ long rl =
HopRewriteUtils.getIntValueSafe(beg.getInput(0));
+ long ru =
HopRewriteUtils.getIntValueSafe(end.getInput(0));
+ long cl =
HopRewriteUtils.getIntValueSafe(beg.getInput(1));
+ long cu =
HopRewriteUtils.getIntValueSafe(end.getInput(1));
+ rowPartitioned &= (cu-cl == fedInit.getDim2());
+ colPartitioned &= (ru-rl == fedInit.getDim1());
+ }
+ return rowPartitioned && colPartitioned ?
+ FType.FULL : rowPartitioned ? FType.ROW :
+ colPartitioned ? FType.COL : FType.OTHER;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
new file mode 100644
index 0000000..98de495
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+public class FTypes
+{
+ public enum FederatedPlanner {
+ NONE,
+ RUNTIME,
+ COMPILE_FED_ALL,
+ COMPILE_FED_HEURISTIC,
+ COMPILE_COST_BASED;
+ public AFederatedPlanner getPlanner() {
+ switch( this ) {
+ case COMPILE_FED_ALL:
+ return new FederatedPlannerAllFed();
+ case COMPILE_FED_HEURISTIC:
+ return new FederatedPlannerHeuristic();
+ case COMPILE_COST_BASED:
+ return new FederatedPlannerCostbased();
+ case NONE:
+ case RUNTIME:
+ default:
+ return null;
+ }
+ }
+ public boolean isCompiled() {
+ return this != NONE && this != RUNTIME;
+ }
+ public static boolean isCompiled(String planner) {
+ try {
+ return
FederatedPlanner.valueOf(planner.toUpperCase()).isCompiled();
+ }
+ catch(Exception ex) {
+ ex.printStackTrace();
+ return false;
+ }
+ }
+ }
+
+ public enum FPartitioning {
+ ROW, //row partitioned, groups of entire rows
+ COL, //column partitioned, groups of entire columns
+ MIXED, //arbitrary rectangles
+ NONE, //entire data in a location
+ }
+
+ public enum FReplication {
+ NONE, //every data item in a separate location
+ FULL, //every data item at every location
+ OVERLAP, //every data item partially at every location, w/
addition as aggregation method
+ }
+
+ public enum FType {
+ ROW(FPartitioning.ROW, FReplication.NONE),
+ COL(FPartitioning.COL, FReplication.NONE),
+ FULL(FPartitioning.NONE, FReplication.NONE),
+ BROADCAST(FPartitioning.NONE, FReplication.FULL),
+ PART(FPartitioning.NONE, FReplication.OVERLAP),
+ OTHER(FPartitioning.MIXED, FReplication.NONE);
+
+ private final FPartitioning _partType;
+ @SuppressWarnings("unused") //not yet
+ private final FReplication _repType;
+
+ private FType(FPartitioning ptype, FReplication rtype) {
+ _partType = ptype;
+ _repType = rtype;
+ }
+
+ public boolean isRowPartitioned() {
+ return _partType == FPartitioning.ROW
+ || _partType == FPartitioning.NONE;
+ }
+
+ public boolean isColPartitioned() {
+ return _partType == FPartitioning.COL
+ || _partType == FPartitioning.NONE;
+ }
+
+ public FPartitioning getPartType() {
+ return this._partType;
+ }
+
+ public boolean isType(FType t) {
+ switch(t) {
+ case ROW:
+ return isRowPartitioned();
+ case COL:
+ return isColPartitioned();
+ case FULL:
+ case OTHER:
+ default:
+ return t == this;
+ }
+ }
+ }
+
+ // Alignment Check Type
+ public enum AlignType {
+ FULL, // exact matching dimensions of partitions on the same
federated worker
+ ROW, // matching rows of partitions on the same federated worker
+ COL, // matching columns of partitions on the same federated
worker
+ FULL_T, // matching dimensions with transposed dimensions of
partitions on the same federated worker
+ ROW_T, // matching rows with columns of partitions on the same
federated worker
+ COL_T; // matching columns with rows of partitions on the same
federated worker
+
+ public boolean isTransposed() {
+ return (this == FULL_T || this == ROW_T || this ==
COL_T);
+ }
+ public boolean isFullType() {
+ return (this == FULL || this == FULL_T);
+ }
+ public boolean isRowType() {
+ return (this == ROW || this == ROW_T);
+ }
+ public boolean isColType() {
+ return (this == COL || this == COL_T);
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
new file mode 100644
index 0000000..a35d94c
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+/**
+ * Baseline federated planner that compiles all hops
+ * that support federated execution on federated inputs to
+ * forced federated operations.
+ */
+public class FederatedPlannerAllFed extends AFederatedPlanner {
+
+ @Override
+ public void rewriteProgram( DMLProgram prog,
+ FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes )
+ {
+ // handle main program
+ Map<String, FType> fedVars = new HashMap<>();
+ for(StatementBlock sb : prog.getStatementBlocks())
+ rRewriteStatementBlock(sb, fedVars);
+ }
+
+ private void rRewriteStatementBlock(StatementBlock sb, Map<String,
FType> fedVars) {
+ //TODO currently this rewrite assumes consistent decisions in
conditional control flow
+
+ if (sb instanceof FunctionStatementBlock) {
+ FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+ FunctionStatement fstmt =
(FunctionStatement)fsb.getStatement(0);
+ for (StatementBlock csb : fstmt.getBody())
+ rRewriteStatementBlock(csb, fedVars);
+ }
+ else if (sb instanceof WhileStatementBlock) {
+ WhileStatementBlock wsb = (WhileStatementBlock) sb;
+ WhileStatement wstmt =
(WhileStatement)wsb.getStatement(0);
+ rRewriteHop(wsb.getPredicateHops(), new HashMap<>(),
Collections.emptyMap());
+ for (StatementBlock csb : wstmt.getBody())
+ rRewriteStatementBlock(csb, fedVars);
+ }
+ else if (sb instanceof IfStatementBlock) {
+ IfStatementBlock isb = (IfStatementBlock) sb;
+ IfStatement istmt = (IfStatement)isb.getStatement(0);
+ rRewriteHop(isb.getPredicateHops(), new HashMap<>(),
Collections.emptyMap());
+ for (StatementBlock csb : istmt.getIfBody())
+ rRewriteStatementBlock(csb, fedVars);
+ for (StatementBlock csb : istmt.getElseBody())
+ rRewriteStatementBlock(csb, fedVars);
+ }
+ else if (sb instanceof ForStatementBlock) { //incl parfor
+ ForStatementBlock fsb = (ForStatementBlock) sb;
+ ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+ rRewriteHop(fsb.getFromHops(), new HashMap<>(),
Collections.emptyMap());
+ rRewriteHop(fsb.getToHops(), new HashMap<>(),
Collections.emptyMap());
+ rRewriteHop(fsb.getIncrementHops(), new HashMap<>(),
Collections.emptyMap());
+ for (StatementBlock csb : fstmt.getBody())
+ rRewriteStatementBlock(csb, fedVars);
+ }
+ else //generic (last-level)
+ {
+ //process entire hop DAGs with memoization
+ Map<Long, FType> fedHops = new HashMap<>();
+ if( sb.getHops() != null )
+ for( Hop c : sb.getHops() )
+ rRewriteHop(c, fedHops, fedVars);
+
+ //TODO handle function calls
+
+ //propagate federated outputs across DAGs
+ if( sb.getHops() != null )
+ for( Hop c : sb.getHops() )
+ if( HopRewriteUtils.isData(c,
OpOpData.TRANSIENTWRITE) )
+ fedVars.put(c.getName(),
fedHops.get(c.getInput(0).getHopID()));
+ }
+ }
+
+ private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String,
FType> fedVars) {
+ if( memo.containsKey(hop.getHopID()) )
+ return; //already processed
+
+ //process children first
+ for( Hop c : hop.getInput() )
+ rRewriteHop(c, memo, fedVars);
+
+ //handle specific operators (except transient writes)
+ if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
+ memo.put(hop.getHopID(), deriveFType((DataOp)hop));
+ else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
+ memo.put(hop.getHopID(), fedVars.get(hop.getName()));
+ else if( allowsFederated(hop, memo) ) {
+ hop.setForcedExecType(ExecType.FED);
+ memo.put(hop.getHopID(), getFederatedOut(hop, memo));
+ if( memo.get(hop.getHopID()) != null )
+ hop.setFederatedOutput(FederatedOutput.FOUT);
+ }
+ else // memoization as processed, but not federated
+ memo.put(hop.getHopID(), null);
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
similarity index 91%
copy from
src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
copy to
src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 383be42..04532f3 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -17,7 +17,13 @@
* under the License.
*/
-package org.apache.sysds.hops.ipa;
+package org.apache.sysds.hops.fedplanner;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -31,6 +37,8 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
@@ -45,19 +53,8 @@ import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-
-/**
- * This rewrite generates a federated execution plan by estimating and setting
costs and the FederatedOutput values of
- * all relevant hops in the DML program.
- * The rewrite is only applied if federated compilation is activated in
OptimizerUtils.
- */
-public class IPAPassRewriteFederatedPlan extends IPAPass {
- private static final Log LOG =
LogFactory.getLog(IPAPassRewriteFederatedPlan.class.getName());
+public class FederatedPlannerCostbased extends AFederatedPlanner {
+ private static final Log LOG =
LogFactory.getLog(FederatedPlannerCostbased.class.getName());
private final static MemoTable hopRelMemo = new MemoTable();
/**
@@ -73,36 +70,13 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
return terminalHops;
}
- /**
- * Indicates if an IPA pass is applicable for the current configuration.
- * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
- *
- * @param fgraph function call graph
- * @return true if federated compilation is activated.
- */
@Override
- public boolean isApplicable(FunctionCallGraph fgraph) {
- return OptimizerUtils.FEDERATED_COMPILATION;
- }
-
- /**
- * Estimates cost and selects a federated execution plan
- * by setting the federated output value of each hop in the program.
- *
- * @param prog dml program
- * @param fgraph function call graph
- * @param fcallSizes function call size infos
- * @return false since the function call graph never has to be rebuilt
- */
- @Override
- public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
- FunctionCallSizeInfo fcallSizes) {
+ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes ) {
prog.updateRepetitionEstimates();
rewriteStatementBlocks(prog, prog.getStatementBlocks());
setFinalFedouts();
- return false;
}
-
+
/**
* Estimates cost and enumerates federated execution plans in
hopRelMemo.
* The method calls the contained statement blocks recursively.
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
new file mode 100644
index 0000000..15b12ac
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import java.util.Map;
+
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
+
+public class FederatedPlannerHeuristic extends FederatedPlannerAllFed {
+
+ @Override
+ protected FType getFederatedOut(Hop hop, Map<Long, FType> fedHops) {
+ FType ret = super.getFederatedOut(hop, fedHops); // FedAll
+
+ //apply operator-specific heuristics
+ if( hop instanceof AggBinaryOp) {
+ if( (ret == FType.ROW && hop.getDim2()==1)
+ || (ret == FType.COL && hop.getDim1()==1) )
+ {
+ ret = null; //get local vectors
+ }
+ }
+
+ return ret;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
similarity index 99%
rename from src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
rename to src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index fc95c29..6b3eb53 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.sysds.hops.ipa;
+package org.apache.sysds.hops.fedplanner;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 383be42..6be3b9c 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -19,37 +19,11 @@
package org.apache.sysds.hops.ipa;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.hops.AggBinaryOp;
-import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DataOp;
-import org.apache.sysds.hops.FunctionOp;
-import org.apache.sysds.hops.Hop;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.cost.HopRel;
-import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.parser.DMLProgram;
-import org.apache.sysds.parser.ForStatement;
-import org.apache.sysds.parser.ForStatementBlock;
-import org.apache.sysds.parser.FunctionStatement;
-import org.apache.sysds.parser.FunctionStatementBlock;
-import org.apache.sysds.parser.IfStatement;
-import org.apache.sysds.parser.IfStatementBlock;
-import org.apache.sysds.parser.Statement;
-import org.apache.sysds.parser.StatementBlock;
-import org.apache.sysds.parser.WhileStatement;
-import org.apache.sysds.parser.WhileStatementBlock;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
/**
* This rewrite generates a federated execution plan by estimating and setting
costs and the FederatedOutput values of
@@ -57,21 +31,6 @@ import java.util.Set;
* The rewrite is only applied if federated compilation is activated in
OptimizerUtils.
*/
public class IPAPassRewriteFederatedPlan extends IPAPass {
- private static final Log LOG =
LogFactory.getLog(IPAPassRewriteFederatedPlan.class.getName());
-
- private final static MemoTable hopRelMemo = new MemoTable();
- /**
- * IDs of hops for which the final fedout value has been set.
- */
- private final static Set<Long> hopRelUpdatedFinal = new HashSet<>();
- /**
- * Terminal hops in DML program given to this rewriter.
- */
- private final static List<Hop> terminalHops = new ArrayList<>();
-
- public List<Hop> getTerminalHops(){
- return terminalHops;
- }
/**
* Indicates if an IPA pass is applicable for the current configuration.
@@ -82,7 +41,10 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
*/
@Override
public boolean isApplicable(FunctionCallGraph fgraph) {
- return OptimizerUtils.FEDERATED_COMPILATION;
+ String planner = ConfigurationManager.getDMLConfig()
+ .getTextValue(DMLConfig.FEDERATED_PLANNER);
+ return OptimizerUtils.FEDERATED_COMPILATION
+ || FederatedPlanner.isCompiled(planner);
}
/**
@@ -95,316 +57,17 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @return false since the function call graph never has to be rebuilt
*/
@Override
- public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
- FunctionCallSizeInfo fcallSizes) {
- prog.updateRepetitionEstimates();
- rewriteStatementBlocks(prog, prog.getStatementBlocks());
- setFinalFedouts();
+ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph
fgraph, FunctionCallSizeInfo fcallSizes) {
+ // obtain planner instance according to config
+ String splanner = ConfigurationManager.getDMLConfig()
+ .getTextValue(DMLConfig.FEDERATED_PLANNER);
+ FederatedPlanner planner =
FederatedPlanner.isCompiled(splanner) ?
+ FederatedPlanner.valueOf(splanner.toUpperCase()) :
+ FederatedPlanner.COMPILE_COST_BASED;
+
+ // run planner rewrite with forced federated exec types
+ planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes);
+
return false;
}
-
- /**
- * Estimates cost and enumerates federated execution plans in
hopRelMemo.
- * The method calls the contained statement blocks recursively.
- *
- * @param prog dml program
- * @param sbs list of statement blocks
- * @return list of statement blocks with the federated output value
updated for each hop
- */
- private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram
prog, List<StatementBlock> sbs) {
- ArrayList<StatementBlock> rewrittenStmBlocks = new
ArrayList<>();
- for(StatementBlock stmBlock : sbs)
- rewrittenStmBlocks.addAll(rewriteStatementBlock(prog,
stmBlock));
- return rewrittenStmBlocks;
- }
-
- /**
- * Estimates cost and enumerates federated execution plans in
hopRelMemo.
- * The method calls the contained statement blocks recursively.
- *
- * @param prog dml program
- * @param sb statement block
- * @return list of statement blocks with the federated output value
updated for each hop
- */
- public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog,
StatementBlock sb) {
- if(sb instanceof WhileStatementBlock)
- return rewriteWhileStatementBlock(prog,
(WhileStatementBlock) sb);
- else if(sb instanceof IfStatementBlock)
- return rewriteIfStatementBlock(prog, (IfStatementBlock)
sb);
- else if(sb instanceof ForStatementBlock) {
- // This also includes ParForStatementBlocks
- return rewriteForStatementBlock(prog,
(ForStatementBlock) sb);
- }
- else if(sb instanceof FunctionStatementBlock)
- return rewriteFunctionStatementBlock(prog,
(FunctionStatementBlock) sb);
- else {
- // StatementBlock type (no subclass)
- return rewriteDefaultStatementBlock(prog, sb);
- }
- }
-
- private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram
prog, WhileStatementBlock whileSB) {
- Hop whilePredicateHop = whileSB.getPredicateHops();
- selectFederatedExecutionPlan(whilePredicateHop);
- for(Statement stm : whileSB.getStatements()) {
- WhileStatement whileStm = (WhileStatement) stm;
- whileStm.setBody(rewriteStatementBlocks(prog,
whileStm.getBody()));
- }
- return new ArrayList<>(Collections.singletonList(whileSB));
- }
-
- private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram
prog, IfStatementBlock ifSB) {
- selectFederatedExecutionPlan(ifSB.getPredicateHops());
- for(Statement statement : ifSB.getStatements()) {
- IfStatement ifStatement = (IfStatement) statement;
- ifStatement.setIfBody(rewriteStatementBlocks(prog,
ifStatement.getIfBody()));
- ifStatement.setElseBody(rewriteStatementBlocks(prog,
ifStatement.getElseBody()));
- }
- return new ArrayList<>(Collections.singletonList(ifSB));
- }
-
- private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram
prog, ForStatementBlock forSB) {
- selectFederatedExecutionPlan(forSB.getFromHops());
- selectFederatedExecutionPlan(forSB.getToHops());
- selectFederatedExecutionPlan(forSB.getIncrementHops());
- for(Statement statement : forSB.getStatements()) {
- ForStatement forStatement = ((ForStatement) statement);
- forStatement.setBody(rewriteStatementBlocks(prog,
forStatement.getBody()));
- }
- return new ArrayList<>(Collections.singletonList(forSB));
- }
-
- private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB) {
- for(Statement statement : funcSB.getStatements()) {
- FunctionStatement funcStm = (FunctionStatement)
statement;
- funcStm.setBody(rewriteStatementBlocks(prog,
funcStm.getBody()));
- }
- return new ArrayList<>(Collections.singletonList(funcSB));
- }
-
- private ArrayList<StatementBlock>
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
- if(sb.hasHops()) {
- for(Hop sbHop : sb.getHops()) {
- if(sbHop instanceof FunctionOp) {
- String funcName = ((FunctionOp)
sbHop).getFunctionName();
- FunctionStatementBlock sbFuncBlock =
prog.getBuiltinFunctionDictionary().getFunction(funcName);
- rewriteStatementBlock(prog,
sbFuncBlock);
- }
- else
- selectFederatedExecutionPlan(sbHop);
- }
- }
- return new ArrayList<>(Collections.singletonList(sb));
- }
-
- /**
- * Set final fedouts of all hops starting from terminal hops.
- */
- private void setFinalFedouts(){
- for ( Hop root : terminalHops)
- setFinalFedout(root);
- }
-
- /**
- * Sets FederatedOutput field of all hops in DAG starting from given
root.
- * The FederatedOutput chosen for root is the minimum cost HopRel found
in memo table for the given root.
- * The FederatedOutput values chosen for the inputs to the root are
chosen based on the input dependencies.
- *
- * @param root hop for which FederatedOutput needs to be set
- */
- private void setFinalFedout(Hop root) {
- HopRel optimalRootHopRel =
hopRelMemo.getMinCostAlternative(root);
- setFinalFedout(root, optimalRootHopRel);
- }
-
- /**
- * Update the FederatedOutput value and cost based on information
stored in given rootHopRel.
- *
- * @param root hop for which FederatedOutput is set
- * @param rootHopRel from which FederatedOutput value and cost is
retrieved
- */
- private void setFinalFedout(Hop root, HopRel rootHopRel) {
- if ( hopRelUpdatedFinal.contains(root.getHopID()) ){
- if((rootHopRel.hasLocalOutput() ^
root.hasLocalOutput()) && hopRelMemo.hasFederatedOutputAlternative(root)){
- // Update with FOUT alternative without
visiting inputs
- updateFederatedOutput(root,
hopRelMemo.getFederatedOutputAlternative(root));
- root.activatePrefetch();
- }
- else {
- // Update without visiting inputs
- updateFederatedOutput(root, rootHopRel);
- }
- }
- else {
- updateFederatedOutput(root, rootHopRel);
- visitInputDependency(rootHopRel);
- }
- }
-
- /**
- * Sets FederatedOutput value for each of the inputs of rootHopRel
- *
- * @param rootHopRel which has its input values updated
- */
- private void visitInputDependency(HopRel rootHopRel) {
- List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
- for(HopRel input : hopRelInputs)
- setFinalFedout(input.getHopRef(), input);
- }
-
- /**
- * Updates FederatedOutput value and cost estimate based on
updateHopRel values.
- *
- * @param root which has its values updated
- * @param updateHopRel from which the values are retrieved
- */
- private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
- root.setFederatedOutput(updateHopRel.getFederatedOutput());
- root.setFederatedCost(updateHopRel.getCostObject());
- forceFixedFedOut(root);
- hopRelUpdatedFinal.add(root.getHopID());
- }
-
- /**
- * Set federated output to fixed value if FEDERATED_SPECS is activated
for root hop.
- * @param root hop set to fixed fedout value as loaded from
FEDERATED_SPECS
- */
- private void forceFixedFedOut(Hop root){
- if (
OptimizerUtils.FEDERATED_SPECS.containsKey(root.getBeginLine()) ){
- FEDInstruction.FederatedOutput fedOutSpec =
OptimizerUtils.FEDERATED_SPECS.get(root.getBeginLine());
- root.setFederatedOutput(fedOutSpec);
- if ( fedOutSpec.isForcedFederated() )
- root.deactivatePrefetch();
- }
- }
-
- /**
- * Select federated execution plan for every Hop in the DAG starting
from given roots.
- * The cost estimates of the hops are also updated when FederatedOutput
is updated in the hops.
- *
- * @param roots starting point for going through the Hop DAG to update
the FederatedOutput fields.
- */
- @SuppressWarnings("unused")
- private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
- for ( Hop root : roots )
- selectFederatedExecutionPlan(root);
- }
-
- /**
- * Select federated execution plan for every Hop in the DAG starting
from given root.
- *
- * @param root starting point for going through the Hop DAG to update
the federatedOutput fields
- */
- private void selectFederatedExecutionPlan(Hop root) {
- if ( root != null ){
- visitFedPlanHop(root);
- if ( HopRewriteUtils.isTerminalHop(root) )
- terminalHops.add(root);
- }
- }
-
- /**
- * Go through the Hop DAG and set the FederatedOutput field and cost
estimate for each Hop from leaf to given currentHop.
- *
- * @param currentHop the Hop from which the DAG is visited
- */
- private void visitFedPlanHop(Hop currentHop) {
- // If the currentHop is in the hopRelMemo table, it means that
it has been visited
- if(hopRelMemo.containsHop(currentHop))
- return;
- // If the currentHop has input, then the input should be
visited depth-first
- if(currentHop.getInput() != null &&
currentHop.getInput().size() > 0) {
- debugLog(currentHop);
- for(Hop input : currentHop.getInput())
- visitFedPlanHop(input);
- }
- // Put FOUT, LOUT, and None HopRels into the memo table
- ArrayList<HopRel> hopRels = new ArrayList<>();
- if(isFedInstSupportedHop(currentHop)) {
- for(FEDInstruction.FederatedOutput fedoutValue :
FEDInstruction.FederatedOutput.values())
- if(isFedOutSupported(currentHop, fedoutValue))
- hopRels.add(new HopRel(currentHop,
fedoutValue, hopRelMemo));
- }
- if(hopRels.isEmpty())
- hopRels.add(new HopRel(currentHop,
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
- hopRelMemo.put(currentHop, hopRels);
- }
-
- /**
- * Write HOP visit to debug log if debug is activated.
- * @param currentHop hop written to log
- */
- private void debugLog(Hop currentHop){
- if ( LOG.isDebugEnabled() ){
- LOG.debug("Visiting HOP: " + currentHop + " Input size:
" + currentHop.getInput().size());
- int index = 0;
- for ( Hop hop : currentHop.getInput()){
- if ( hop == null )
- LOG.debug("Input at index is null: " +
index);
- else
- LOG.debug("HOP input: " + hop + " at
index " + index + " of " + currentHop);
- index++;
- }
- }
- }
-
- /**
- * Checks if the instructions related to the given hop supports
FOUT/LOUT processing.
- *
- * @param hop to check for federated support
- * @return true if federated instructions related to hop supports
FOUT/LOUT processing
- */
- private boolean isFedInstSupportedHop(Hop hop) {
- // The following operations are supported given that the above
conditions have not returned already
- return (hop instanceof AggBinaryOp || hop instanceof BinaryOp
|| hop instanceof ReorgOp
- || hop instanceof AggUnaryOp || hop instanceof
TernaryOp || hop instanceof DataOp);
- }
-
- /**
- * Checks if the associatedHop supports the given federated output
value.
- *
- * @param associatedHop to check support of
- * @param fedOut federated output value
- * @return true if associatedHop supports fedOut
- */
- private boolean isFedOutSupported(Hop associatedHop,
FEDInstruction.FederatedOutput fedOut) {
- switch(fedOut) {
- case FOUT:
- return isFOUTSupported(associatedHop);
- case LOUT:
- return isLOUTSupported(associatedHop);
- case NONE:
- return false;
- default:
- return true;
- }
- }
-
- /**
- * Checks to see if the associatedHop supports FOUT.
- *
- * @param associatedHop for which FOUT support is checked
- * @return true if FOUT is supported by the associatedHop
- */
- private boolean isFOUTSupported(Hop associatedHop) {
- // If the output of AggUnaryOp is a scalar, the operation
cannot be FOUT
- if(associatedHop instanceof AggUnaryOp &&
associatedHop.isScalar())
- return false;
- // It can only be FOUT if at least one of the inputs are FOUT,
except if it is a federated DataOp
-
if(associatedHop.getInput().stream().noneMatch(hopRelMemo::hasFederatedOutputAlternative)
- && !associatedHop.isFederatedDataOp())
- return false;
- return true;
- }
-
- /**
- * Checks to see if the associatedHop supports LOUT.
- * It supports LOUT if the output has no privacy constraints.
- *
- * @param associatedHop for which LOUT support is checked.
- * @return true if LOUT is supported by the associatedHop
- */
- private boolean isLOUTSupported(Hop associatedHop) {
- return associatedHop.getPrivacy() == null ||
!associatedHop.getPrivacy().hasConstraints();
- }
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index c70df67..822b4b5 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -55,14 +55,6 @@ import java.util.concurrent.Future;
public class RewriteFederatedExecution extends HopRewriteRule {
private static final Logger LOG =
Logger.getLogger(RewriteFederatedExecution.class);
-
- public enum FederatedPlanner {
- NONE,
- RUNTIME,
- COMPILE_ALLFED,
- COMPILE_HEURISTIC,
- COMPILE_COSTBASED,
- }
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state) {
@@ -72,7 +64,8 @@ public class RewriteFederatedExecution extends HopRewriteRule
{
return roots;
}
- @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus
state) {
+ @Override
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
if ( root != null )
visitHop(root);
return root;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f9438be..925db3b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -37,11 +37,11 @@ import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index b09b295..6afefcf 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -32,6 +32,7 @@ import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import
org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
@@ -518,7 +519,7 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
// TODO sparse optimization
List<Pair<FederatedRange, Future<FederatedResponse>>>
readResponses = fedMap.requestFederatedData();
try {
- if(fedMap.getType() == FederationMap.FType.PART)
+ if(fedMap.getType() == FType.PART)
return
FederationUtils.aggregateResponses(readResponses);
else
return
FederationUtils.bindResponses(readResponses, dims);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 8735d05..c15a07d 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
@@ -36,7 +37,6 @@ import
org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 1676078..d55d463 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -33,6 +33,8 @@ import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -46,86 +48,6 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
public class FederationMap {
- public enum FPartitioning {
- ROW, //row partitioned, groups of entire rows
- COL, //column partitioned, groups of entire columns
- MIXED, //arbitrary rectangles
- NONE, //entire data in a location
- }
-
- public enum FReplication {
- NONE, //every data item in a separate location
- FULL, //every data item at every location
- OVERLAP, //every data item partially at every location, w/
addition as aggregation method
- }
-
- public enum FType {
- ROW(FPartitioning.ROW, FReplication.NONE),
- COL(FPartitioning.COL, FReplication.NONE),
- FULL(FPartitioning.NONE, FReplication.NONE),
- BROADCAST(FPartitioning.NONE, FReplication.FULL),
- PART(FPartitioning.NONE, FReplication.OVERLAP),
- OTHER(FPartitioning.MIXED, FReplication.NONE);
-
- private final FPartitioning _partType;
- @SuppressWarnings("unused") //not yet
- private final FReplication _repType;
-
- private FType(FPartitioning ptype, FReplication rtype) {
- _partType = ptype;
- _repType = rtype;
- }
-
- public boolean isRowPartitioned() {
- return _partType == FPartitioning.ROW
- || _partType == FPartitioning.NONE;
- }
-
- public boolean isColPartitioned() {
- return _partType == FPartitioning.COL
- || _partType == FPartitioning.NONE;
- }
-
- public FPartitioning getPartType() {
- return this._partType;
- }
-
- public boolean isType(FType t) {
- switch(t) {
- case ROW:
- return isRowPartitioned();
- case COL:
- return isColPartitioned();
- case FULL:
- case OTHER:
- default:
- return t == this;
- }
- }
- }
-
- // Alignment Check Type
- public enum AlignType {
- FULL, // exact matching dimensions of partitions on the same
federated worker
- ROW, // matching rows of partitions on the same federated worker
- COL, // matching columns of partitions on the same federated
worker
- FULL_T, // matching dimensions with transposed dimensions of
partitions on the same federated worker
- ROW_T, // matching rows with columns of partitions on the same
federated worker
- COL_T; // matching columns with rows of partitions on the same
federated worker
-
- public boolean isTransposed() {
- return (this == FULL_T || this == ROW_T || this ==
COL_T);
- }
- public boolean isFullType() {
- return (this == FULL || this == FULL_T);
- }
- public boolean isRowType() {
- return (this == ROW || this == ROW_T);
- }
- public boolean isColType() {
- return (this == COL || this == COL_T);
- }
- }
private long _ID = -1;
private final List<Pair<FederatedRange, FederatedData>> _fedMap;
@@ -317,7 +239,7 @@ public class FederationMap {
public boolean isAligned(FederationMap that, boolean transposed) {
boolean ret = true;
//TODO support operations with fully broadcast objects
- if (_type == FederationMap.FType.BROADCAST)
+ if (_type == FType.BROADCAST)
return false;
for(Pair<FederatedRange, FederatedData> e : _fedMap) {
FederatedRange range = !transposed ? e.getKey() : new
FederatedRange(e.getKey()).transpose();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index e2430bb..606b6c0 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -28,6 +28,8 @@ import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FPartitioning;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -68,7 +70,7 @@ public class FederationUtils {
public static void checkFedMapType(MatrixObject mo) {
FederationMap fedMap = mo.getFedMapping();
- FederationMap.FType oldType = fedMap.getType();
+ FType oldType = fedMap.getType();
boolean isRow = true;
long prev = 0;
@@ -78,10 +80,10 @@ public class FederationUtils {
else
isRow = false;
}
- if(isRow && oldType.getPartType() ==
FederationMap.FPartitioning.COL)
- fedMap.setType(FederationMap.FType.ROW);
- else if(!isRow && oldType.getPartType() ==
FederationMap.FPartitioning.ROW)
- fedMap.setType(FederationMap.FType.COL);
+ if(isRow && oldType.getPartType() == FPartitioning.COL)
+ fedMap.setType(FType.ROW);
+ else if(!isRow && oldType.getPartType() == FPartitioning.ROW)
+ fedMap.setType(FType.COL);
}
//TODO remove rmFedOutFlag, once all federated instructions have this
flag, then unconditionally remove
@@ -217,9 +219,9 @@ public class FederationUtils {
}
}
- public static MatrixBlock aggMinMax(Future<FederatedResponse>[] ffr,
boolean isMin, boolean isScalar, Optional<FederationMap.FType> fedType) {
+ public static MatrixBlock aggMinMax(Future<FederatedResponse>[] ffr,
boolean isMin, boolean isScalar, Optional<FType> fedType) {
try {
- if (!fedType.isPresent() || fedType.get() ==
FederationMap.FType.OTHER) {
+ if (!fedType.isPresent() || fedType.get() ==
FType.OTHER) {
double res = isMin ? Double.MAX_VALUE :
-Double.MAX_VALUE;
for (Future<FederatedResponse> fr : ffr) {
double v = isScalar ? ((ScalarObject)
fr.get().getData()[0]).getDoubleValue() :
@@ -229,11 +231,11 @@ public class FederationUtils {
return new MatrixBlock(1, 1, res);
} else {
MatrixBlock[] tmp = getResults(ffr);
- int dim = fedType.get() ==
FederationMap.FType.COL ? tmp[0].getNumRows() : tmp[0].getNumColumns();
+ int dim = fedType.get() == FType.COL ?
tmp[0].getNumRows() : tmp[0].getNumColumns();
for (int i = 0; i < ffr.length - 1; i++)
for (int j = 0; j < dim; j++)
- if (fedType.get() ==
FederationMap.FType.COL)
+ if (fedType.get() == FType.COL)
tmp[i + 1].setValue(j,
0, isMin ? Double.min(tmp[i].getValue(j, 0), tmp[i + 1].getValue(j, 0)) :
Double.max(tmp[i].getValue(j, 0), tmp[i + 1].getValue(j, 0)));
else tmp[i + 1].setValue(0, j,
isMin ? Double.min(tmp[i].getValue(0, j), tmp[i + 1].getValue(0, j)) :
@@ -248,7 +250,7 @@ public class FederationUtils {
public static MatrixBlock aggProd(Future<FederatedResponse>[] ffr,
FederationMap fedMap, AggregateUnaryOperator aop) {
try {
- boolean rowFed = fedMap.getType() ==
FederationMap.FType.ROW;
+ boolean rowFed = fedMap.getType() == FType.ROW;
MatrixBlock ret = aop.isFullAggregate() ? (rowFed ?
new MatrixBlock(ffr.length, 1, 1.0) : new
MatrixBlock(1, ffr.length, 1.0)) :
(rowFed ?
@@ -395,9 +397,9 @@ public class FederationUtils {
}
public static MatrixBlock aggMatrix(AggregateUnaryOperator aop,
Future<FederatedResponse>[] ffr, Future<FederatedResponse>[] meanFfr,
FederationMap map) {
- if (aop.isRowAggregate() && map.getType() ==
FederationMap.FType.ROW)
+ if (aop.isRowAggregate() && map.getType() == FType.ROW)
return bind(ffr, false);
- else if (aop.isColAggregate() && map.getType() ==
FederationMap.FType.COL)
+ else if (aop.isColAggregate() && map.getType() == FType.COL)
return bind(ffr, true);
if (aop.aggOp.increOp.fn instanceof KahanFunction)
@@ -473,9 +475,9 @@ public class FederationUtils {
}
public static MatrixBlock aggMatrix(AggregateUnaryOperator aop,
Future<FederatedResponse>[] ffr, FederationMap map) {
- if (aop.isRowAggregate() && map.getType() ==
FederationMap.FType.ROW)
+ if (aop.isRowAggregate() && map.getType() == FType.ROW)
return bind(ffr, false);
- else if (aop.isColAggregate() && map.getType() ==
FederationMap.FType.COL)
+ else if (aop.isColAggregate() && map.getType() == FType.COL)
return bind(ffr, true);
if (aop.aggOp.increOp.fn instanceof KahanFunction)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index e9bec6c..5bb3e12 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -77,7 +78,7 @@ public abstract class DataPartitionFederatedScheme {
* @param fedMatrix the federated input matrix
*/
static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
- if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
+ if (fedMatrix.isFederated(FType.ROW)) {
List<MatrixObject> slices =
Collections.synchronizedList(new ArrayList<>());
fedMatrix.getFedMapping().forEachParallel((range, data)
-> {
// Create sliced matrix object
@@ -91,7 +92,7 @@ public abstract class DataPartitionFederatedScheme {
List<Pair<FederatedRange, FederatedData>>
newFedHashMap = new ArrayList<>();
newFedHashMap.add(Pair.of(range, data));
slice.setFedMapping(new
FederationMap(fedMatrix.getFedMapping().getID(), newFedHashMap));
-
slice.getFedMapping().setType(FederationMap.FType.ROW);
+ slice.getFedMapping().setType(FType.ROW);
slices.add(slice);
return null;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 8ea77c4..ef447df 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -22,6 +22,8 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -29,8 +31,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index cfe0baf..a9efb89 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -21,14 +21,14 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 88a066a..668fd0b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -102,7 +103,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap map = in.getFedMapping();
- if((instOpcode.equalsIgnoreCase("uarimax") ||
instOpcode.equalsIgnoreCase("uarimin")) &&
in.isFederated(FederationMap.FType.COL))
+ if((instOpcode.equalsIgnoreCase("uarimax") ||
instOpcode.equalsIgnoreCase("uarimin")) && in.isFederated(FType.COL))
instString =
InstructionUtils.replaceOperand(instString, 5, "2");
// create federated commands for aggregation
@@ -150,7 +151,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
throw new DMLRuntimeException("Operation " + instOpcode
+ " is unknown to FOUT processing");
boolean isColAgg = instOpcode.equals("uack+");
//Get partition type
- FederationMap.FType inFtype = in.getFedMapping().getType();
+ FType inFtype = in.getFedMapping().getType();
//Get fedmap from in
FederationMap inputFedMapCopy =
in.getFedMapping().copyWithNewID(fr1.getID());
@@ -160,7 +161,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
if ( inFtype.isRowPartitioned() && !isColAgg ){
for ( FederatedRange range :
inputFedMapCopy.getFederatedRanges() )
range.setEndDim(1,out.getNumColumns());
- inputFedMapCopy.setType(FederationMap.FType.ROW);
+ inputFedMapCopy.setType(FType.ROW);
}
//if partition type is row and aggregation type is col
// then get row and col dimension from out and use those
dimensions for both federated workers
@@ -175,7 +176,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
range.setEndDim(0,out.getNumRows());
range.setEndDim(1,out.getNumColumns());
}
- inputFedMapCopy.setType(FederationMap.FType.PART);
+ inputFedMapCopy.setType(FType.PART);
}
//if partition type is col and aggregation type is col
// then set row dimension to output and col dimension to in
col split
@@ -183,7 +184,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
if ( inFtype.isColPartitioned() && isColAgg ){
for ( FederatedRange range :
inputFedMapCopy.getFederatedRanges() )
range.setEndDim(0,out.getNumRows());
- inputFedMapCopy.setType(FederationMap.FType.COL);
+ inputFedMapCopy.setType(FType.COL);
}
//set out fedmap in the end
@@ -231,7 +232,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
id = tmpRequest.getID();
}
else {
- if((map.getType() == FederationMap.FType.COL &&
aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW &&
aop.isRowAggregate()))
+ if((map.getType() == FType.COL &&
aop.isColAggregate()) || (map.getType() == FType.ROW && aop.isRowAggregate()))
tmpRequest = new
FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1),
in.getDataType());
else {
DataCharacteristics dc =
ec.getDataCharacteristics(output.getName());
@@ -274,8 +275,8 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest fr1;
long id = FederationUtils.getNextFedDataID();
- if((map.getType() == FederationMap.FType.COL &&
aop.isColAggregate()) ||
- (map.getType() == FederationMap.FType.ROW &&
aop.isRowAggregate()))
+ if((map.getType() == FType.COL && aop.isColAggregate()) ||
+ (map.getType() == FType.ROW && aop.isRowAggregate()))
fr1 = new FederatedRequest(RequestType.PUT_VAR, id, new
MatrixCharacteristics(-1, -1), in.getDataType());
else
fr1 = new FederatedRequest(RequestType.PUT_VAR, id, dc,
in.getDataType());
@@ -299,7 +300,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
id = fr1.getID();
}
else {
- if((map.getType() == FederationMap.FType.COL &&
aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW &&
aop.isRowAggregate()))
+ if((map.getType() == FType.COL && aop.isColAggregate())
|| (map.getType() == FType.ROW && aop.isRowAggregate()))
fr1 = new FederatedRequest(RequestType.PUT_VAR,
id, new MatrixCharacteristics(-1, -1), in.getDataType());
else
fr1 = new FederatedRequest(RequestType.PUT_VAR,
id, dc, in.getDataType());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 56671a1..84ff4f1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -19,14 +19,14 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 813379a..e631c48 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -19,13 +19,13 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index e644ef1..e016f2b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -30,6 +30,7 @@ import java.util.TreeMap;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -39,7 +40,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
index 67288c3..b870aaf 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
@@ -23,12 +23,12 @@ import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -74,7 +74,7 @@ public class CumulativeOffsetFEDInstruction extends
BinaryFEDInstruction
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
- if(getOpcode().startsWith("bcumoff") &&
mo1.isFederated(FederationMap.FType.ROW))
+ if(getOpcode().startsWith("bcumoff") &&
mo1.isFederated(FType.ROW))
processCumulativeInstruction(ec);
else {
//federated execution on arbitrary row/column partitions
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index e1c4587..fc6e651 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
@@ -30,7 +31,6 @@ import
org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 92968d8..f5e4861 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -27,6 +27,7 @@ import java.util.Objects;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.RightIndex;
@@ -169,7 +170,7 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
if(input1.isFrame()) {
//modify frame schema
- if(in.isFederated(FederationMap.FType.ROW))
+ if(in.isFederated(FType.ROW))
schema = Arrays.asList(((FrameObject)
in).getSchema((int) csn, (int) cen));
else
Collections.addAll(schema,
((FrameObject) in).getSchema((int) csn, (int) cen));
@@ -239,7 +240,7 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
// find ranges where to apply leftIndex
long to;
- if(in1.isFederated(FederationMap.FType.ROW) && (to =
(prev + ren - rsn)) >= 0 &&
+ if(in1.isFederated(FType.ROW) && (to = (prev + ren -
rsn)) >= 0 &&
to < in2.getNumRows() && ixrange.rowStart <=
re) {
sliceIxs[i] = new int[] { prev, (int) to, 0,
(int) in2.getNumColumns()-1};
prev = (int) (to + 1);
@@ -248,7 +249,7 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
ranges[i] = range;
from = Math.min(i, from);
}
- else if(in1.isFederated(FederationMap.FType.COL) && (to
= (prev + cen - csn)) >= 0 &&
+ else if(in1.isFederated(FType.COL) && (to = (prev + cen
- csn)) >= 0 &&
to < in2.getNumColumns() && ixrange.colStart <=
ce) {
sliceIxs[i] = new int[] {0, (int)
in2.getNumRows() - 1, prev, (int) to};
prev = (int) (to + 1);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 29b2a17..6e18115 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -41,6 +41,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -49,7 +50,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 5e08c0e..9719298 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -19,13 +19,13 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
index 865696b..e0550c9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
@@ -25,6 +25,8 @@ import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -34,8 +36,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 410fc4f..d1b773d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -40,6 +40,7 @@ import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheDataOutput;
@@ -196,11 +197,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
long cols = 0;
for(int i = 0; i < ffr.length; i++) {
try {
- if(in.isFederated(FederationMap.FType.COL)) {
+ if(in.isFederated(FType.COL)) {
out.getFedMapping().getFederatedRanges()[i + 1].setBeginDim(1, cols);
cols += ((ScalarObject)
ffr[i].get().getData()[0]).getLongValue();
}
- else if(in.isFederated(FederationMap.FType.ROW))
+ else if(in.isFederated(FType.ROW))
cols = ((ScalarObject)
ffr[i].get().getData()[0]).getLongValue();
out.getFedMapping().getFederatedRanges()[i].setEndDim(1, cols);
}
@@ -223,7 +224,7 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
MatrixObject mo = (MatrixObject) getTarget(ec);
FederationMap fedMap = mo.getFedMapping();
- boolean rowFed = mo.isFederated(FederationMap.FType.ROW);
+ boolean rowFed = mo.isFederated(FType.ROW);
long varID = FederationUtils.getNextFedDataID();
FederationMap diagFedMap;
@@ -458,7 +459,7 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
dcs = finalDcs1;
out.getDataCharacteristics().set(mo.getDataCharacteristics());
- int len = marginRow ? mo.getSchema().length : (int)
(mo.isFederated(FederationMap.FType.ROW) ? s
+ int len = marginRow ? mo.getSchema().length : (int)
(mo.isFederated(FType.ROW) ? s
.getNonZeros() :
finalSchema.values().stream().mapToInt(e -> e.length).sum());
ValueType[] schema = new ValueType[len];
int pos = 0;
@@ -467,7 +468,7 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
if(marginRow) {
schema = mo.getSchema();
- } else if(mo.isFederated(FederationMap.FType.ROW)) {
+ } else if(mo.isFederated(FType.ROW)) {
schema = finalSchema.get(federatedRange);
} else {
ValueType[] tmp =
finalSchema.get(federatedRange);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index 83a1360..c6a5b08 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -31,6 +31,7 @@ import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.PickByCount.OperationTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -109,7 +110,7 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
-
if(ec.getMatrixObject(input1).isFederated(FederationMap.FType.COL) ||
ec.getMatrixObject(input1).isFederated(FederationMap.FType.FULL))
+ if(ec.getMatrixObject(input1).isFederated(FType.COL) ||
ec.getMatrixObject(input1).isFederated(FType.FULL))
processColumnQPick(ec);
else
processRowQPick(ec);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index cb76404..ded83e7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -77,7 +78,7 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction{
}
@Override
public void processInstruction(ExecutionContext ec) {
-
if(ec.getMatrixObject(input1).isFederated(FederationMap.FType.COL) ||
ec.getMatrixObject(input1).isFederated(FederationMap.FType.FULL))
+ if(ec.getMatrixObject(input1).isFederated(FType.COL) ||
ec.getMatrixObject(input1).isFederated(FType.FULL))
processColumnQSort(ec);
else
processRowQSort(ec);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
index 5731078..65a9a81 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -24,14 +24,14 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.Reques
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index c3fbb08..58950ea 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -23,6 +23,8 @@ import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -30,8 +32,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
index 8fb1ae9..53c7d4e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
@@ -26,9 +26,9 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
index 5d9c608..7196901 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
@@ -22,13 +22,13 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
index 4f929af..c188c8c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
@@ -22,13 +22,13 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 4202498..2a8308d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -31,6 +31,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -103,7 +104,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
+ "Federated input expected, but invoked w/
"+mo1.isFederated());
- if ( !( mo1.isFederated(FederationMap.FType.COL) ||
mo1.isFederated(FederationMap.FType.ROW)) )
+ if ( !( mo1.isFederated(FType.COL) ||
mo1.isFederated(FType.ROW)) )
throw new DMLRuntimeException("Federation type " +
mo1.getFedMapping().getType()
+ " is not supported for Reorg processing");
@@ -125,7 +126,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest getRequest = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
Future<FederatedResponse>[] execResponse =
mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
ec.setMatrixOutput(output.getName(),
- FederationUtils.bind(execResponse,
mo1.isFederated(FederationMap.FType.COL)));
+ FederationUtils.bind(execResponse,
mo1.isFederated(FType.COL)));
}
}
else if(instOpcode.equalsIgnoreCase("rev")) {
@@ -137,7 +138,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
new long[] {mo1.getFedMapping().getID()},
isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
mo1.getFedMapping().execute(getTID(), true, fr, fr1);
- if(mo1.isFederated(FederationMap.FType.ROW))
+ if(mo1.isFederated(FType.ROW))
mo1.getFedMapping().reverseFedMap();
//derive output federated mapping
@@ -225,7 +226,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
private RdiagResult rdiagV2M (MatrixObject mo1, ReorgOperator r_op) {
FederationMap fedMap = mo1.getFedMapping();
- boolean rowFed = mo1.isFederated(FederationMap.FType.ROW);
+ boolean rowFed = mo1.isFederated(FType.ROW);
long varID = FederationUtils.getNextFedDataID();
Map<FederatedRange, int[]> dcs = new HashMap<>();
@@ -257,7 +258,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
FederationMap fedMap = mo1.getFedMapping();
- boolean rowFed = mo1.isFederated(FederationMap.FType.ROW);
+ boolean rowFed = mo1.isFederated(FType.ROW);
long varID = FederationUtils.getNextFedDataID();
Map<FederatedRange, int[]> dcs = new HashMap<>();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index cd05ad5..257b9b5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -36,9 +36,9 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index cb8d9fa..12436fb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -24,11 +24,11 @@ import java.util.Objects;
import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -193,7 +193,7 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
Future<FederatedResponse>[] executionResponse =
fedMapObj.getFedMapping().execute(
getTID(), true, federatedSlices1, federatedSlices2,
collectRequests(federatedRequests, getRequest));
ec.setMatrixOutput(output.getName(),
FederationUtils.bind(executionResponse,
- fedMapObj.isFederated(FederationMap.FType.COL)));
+ fedMapObj.isFederated(FType.COL)));
}
/**
@@ -274,17 +274,17 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
boolean allAligned = mo1.isFederated() && mo2.isFederated() &&
mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)
&&
mo1.getFedMapping().isAligned(mo3.getFedMapping(),
false);
boolean twoAligned = false;
- if(!allAligned && mo1.isFederated() &&
!mo1.isFederated(FederationMap.FType.BROADCAST) && mo2.isFederated() &&
+ if(!allAligned && mo1.isFederated() &&
!mo1.isFederated(FType.BROADCAST) && mo2.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(),
false)) {
twoAligned = true;
fr = mo1.getFedMapping().broadcastSliced(mo3, false);
vars = new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr[0].getID()};
- } else if(!allAligned && mo1.isFederated() &&
!mo1.isFederated(FederationMap.FType.BROADCAST) &&
+ } else if(!allAligned && mo1.isFederated() &&
!mo1.isFederated(FType.BROADCAST) &&
mo3.isFederated() &&
mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {
twoAligned = true;
fr = mo1.getFedMapping().broadcastSliced(mo2, false);
vars = new long[] {mo1.getFedMapping().getID(),
fr[0].getID(), mo3.getFedMapping().getID()};
- } else if(!mo1.isFederated(FederationMap.FType.BROADCAST) &&
mo2.isFederated() && mo3.isFederated() &&
mo2.getFedMapping().isAligned(mo3.getFedMapping(), false) && !allAligned) {
+ } else if(!mo1.isFederated(FType.BROADCAST) &&
mo2.isFederated() && mo3.isFederated() &&
mo2.getFedMapping().isAligned(mo3.getFedMapping(), false) && !allAligned) {
twoAligned = true;
mo1 = mo2;
mo2 = mo3;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index aef46ce..41ec2a8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -28,7 +29,6 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -67,7 +67,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
- if((_type.isLeft() && mo1.isFederated(FederationMap.FType.ROW))
|| (mo1.isFederated(FederationMap.FType.COL) && _type.isRight())) {
+ if((_type.isLeft() && mo1.isFederated(FType.ROW)) ||
(mo1.isFederated(FType.COL) && _type.isRight())) {
//construct commands: fed tsmm, retrieve results
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{mo1.getFedMapping().getID()});
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 1aac3cc..615e94a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -25,12 +25,12 @@ import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -76,7 +76,7 @@ public class UnaryMatrixFEDInstruction extends
UnaryFEDInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
- if(getOpcode().startsWith("ucum") &&
mo1.isFederated(FederationMap.FType.ROW))
+ if(getOpcode().startsWith("ucum") && mo1.isFederated(FType.ROW))
processCumulativeInstruction(ec, mo1);
else {
//federated execution on arbitrary row/column partitions
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 5944947..8aca1f3 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -59,6 +59,7 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DataExpression;
@@ -685,7 +686,7 @@ public abstract class AutomatedTestBase {
}
federatedMatrixObject.setFedMapping(new
FederationMap(FederationUtils.getNextFedDataID(), fedHashMap));
-
federatedMatrixObject.getFedMapping().setType(FederationMap.FType.ROW);
+ federatedMatrixObject.getFedMapping().setType(FType.ROW);
writeInputFederatedWithMTD(name, federatedMatrixObject, null);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
index 5413b9b..33ee0e1 100644
---
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
@@ -25,7 +25,6 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Ignore;
-import org.junit.Test;
public class BuiltinTopkEvaluateTest extends AutomatedTestBase {
// private final static String TEST_NAME1 = "prioritized";
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index b8ad989..d2ec111 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -33,8 +33,8 @@ import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
-import org.apache.sysds.hops.ipa.IPAPassRewriteFederatedPlan;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.LanguageException;
@@ -293,7 +293,7 @@ public class FederatedCostEstimatorTest extends
AutomatedTestBase {
}
private void compareResults(DMLProgram prog) {
- IPAPassRewriteFederatedPlan rewriter = new
IPAPassRewriteFederatedPlan();
+ FederatedPlannerCostbased rewriter = new
FederatedPlannerCostbased();
rewriter.rewriteProgram(prog, new FunctionCallGraph(prog),
null);
double actualCost = 0;
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index 75fc236..cc3c5dd 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -71,6 +71,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
}
+ @SuppressWarnings("unused")
private void writeStandardMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
writeStandardMatrix(matrixName, seed, rows, privacyConstraint);
}