This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new fe179649a6 [SYSTEMDS-3018] Fed Planner Extended 4
fe179649a6 is described below
commit fe179649a6b6959f6be9daad1aeeef20e0c40c33
Author: sebwrede <[email protected]>
AuthorDate: Fri Apr 29 15:59:14 2022 +0200
[SYSTEMDS-3018] Fed Planner Extended 4
This commit does the following:
- Edit Output FType for AggBinary (BROADCAST for Matching TSMM Patterns) and Remove PART Processing in Tsmm
- Add FedOut Flag Compilation for Central Moment, Covariance, and Quantiles (Sort/Pick)
- Add FedOut Support in Instruction String for Binary Operations
- Call updateETFed from the Cost-Based FedPlanner and Remove Calls from within
Hop Subclasses
- Add ExecType Check to AggUnaryOp and Add Log Trace in Cost-Based
FedPlanner
Closes #1603.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 2 --
.../java/org/apache/sysds/hops/AggUnaryOp.java | 4 +--
src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 +-
src/main/java/org/apache/sysds/hops/DataOp.java | 2 --
src/main/java/org/apache/sysds/hops/Hop.java | 15 +++++++--
src/main/java/org/apache/sysds/hops/ReorgOp.java | 2 --
src/main/java/org/apache/sysds/hops/TernaryOp.java | 2 --
src/main/java/org/apache/sysds/hops/UnaryOp.java | 1 -
.../sysds/hops/fedplanner/AFederatedPlanner.java | 5 +++
.../hops/fedplanner/FederatedPlannerCostbased.java | 3 ++
src/main/java/org/apache/sysds/lops/Append.java | 5 ++-
.../java/org/apache/sysds/lops/BinaryScalar.java | 5 ++-
.../java/org/apache/sysds/lops/CentralMoment.java | 8 +++--
.../java/org/apache/sysds/lops/CoVariance.java | 6 +++-
.../java/org/apache/sysds/lops/PickByCount.java | 10 +++++-
src/main/java/org/apache/sysds/lops/SortKeys.java | 31 +++++++++----------
.../runtime/instructions/FEDInstructionParser.java | 18 +++++++++++
.../instructions/fed/AppendFEDInstruction.java | 36 +++++++++++++++++++---
.../fed/CentralMomentFEDInstruction.java | 20 +++++++++++-
.../instructions/fed/CovarianceFEDInstruction.java | 19 +++++++++++-
.../runtime/instructions/fed/FEDInstruction.java | 2 ++
.../instructions/fed/FEDInstructionUtils.java | 22 ++++++-------
.../fed/QuantilePickFEDInstruction.java | 26 +++++++++++-----
.../fed/QuantileSortFEDInstruction.java | 20 +++++++++---
.../instructions/fed/TernaryFEDInstruction.java | 2 +-
.../instructions/fed/TsmmFEDInstruction.java | 16 ----------
26 files changed, 202 insertions(+), 83 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 3eb5c2a41e..403a0466f0 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -434,8 +434,6 @@ public class AggBinaryOp extends MultiThreadedHop {
_etype = ExecType.SPARK;
}
- updateETFed();
-
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 923503b170..c461b69bac 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -128,7 +128,7 @@ public class AggUnaryOp extends MultiThreadedHop
if( isTernaryAggregateRewriteApplicable() ) {
agg1 =
constructLopsTernaryAggregateRewrite(et);
}
- else if(
isUnaryAggregateOuterCPRewriteApplicable() )
+ else if( et != ExecType.FED &&
isUnaryAggregateOuterCPRewriteApplicable() )
{
BinaryOp binput =
(BinaryOp)getInput().get(0);
agg1 = new UAggOuterChain(
binput.getInput().get(0).constructLops(),
@@ -385,8 +385,6 @@ public class AggUnaryOp extends MultiThreadedHop
_etype = ExecType.SPARK;
}
- updateETFed();
-
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 20178afc73..791c3bdfbd 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -247,6 +247,7 @@ public class BinaryOp extends MultiThreadedHop {
getInput().get(0).getDim2(),
getInput().get(0).getBlocksize(),
getInput().get(0).getNnz());
+ updateLopFedOut(sort);
PickByCount pick = new PickByCount(
sort,
null,
@@ -466,7 +467,7 @@ public class BinaryOp extends MultiThreadedHop {
boolean isLeftXGt0 = isLeftXGt &&
HopRewriteUtils.isLiteralOfValue(potentialZero, 0);
- if(op == OpOp2.MULT && isLeftXGt0 &&
+ if(et != ExecType.FED && op == OpOp2.MULT &&
isLeftXGt0 &&
!getInput().get(0).isVector() &&
!getInput().get(1).isVector()
&& getInput().get(0).dimsKnown() &&
getInput().get(1).dimsKnown()) {
binary = new
DnnTransform(getInput().get(0).getInput().get(0).constructLops(),
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java
b/src/main/java/org/apache/sysds/hops/DataOp.java
index 548417deec..6035d59f87 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -498,8 +498,6 @@ public class DataOp extends Hop {
_etype = letype;
}
- updateETFed();
-
return _etype;
}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 344bb3065c..4ce9a4b90f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -228,9 +228,15 @@ public abstract class Hop implements ParseInfo {
public void setForcedExecType(ExecType etype)
{
+ logForcedETCall(etype);
_etypeForced = etype;
}
+ private void logForcedETCall(ExecType newEType){
+ if ( LOG.isDebugEnabled() && _etypeForced != null && newEType
!= _etypeForced )
+ LOG.debug("Forced ExecType of " + this + " changed from
" + _etypeForced + " to " + newEType);
+ }
+
public abstract boolean allowsAllExecTypes();
/**
@@ -908,12 +914,17 @@ public abstract class Hop implements ParseInfo {
* This method only has an effect if FEDERATED_COMPILATION is activated.
* Federated compilation is activated in OptimizerUtils.
*/
- protected void updateETFed() {
+ public void updateETFed() {
boolean localOut = hasLocalOutput();
boolean fedIn = getInput().stream().anyMatch(
in -> in.hasFederatedOutput() &&
!(in.prefetchActivated() && localOut));
- if( isFederatedDataOp() || fedIn )
+ if( isFederatedDataOp() || fedIn ){
+ setForcedExecType(ExecType.FED);
+ //TODO: Temporary solution where _etype is set directly
+ // since forcedExecType for BinaryOp may be overwritten
+ // if updateETFed is not called from optFindExecType.
_etype = ExecType.FED;
+ }
}
/**
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index e01fe4dead..057bdac782 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -374,8 +374,6 @@ public class ReorgOp extends MultiThreadedHop
checkAndSetInvalidCPDimsAndSize();
}
- updateETFed();
-
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index a754f1be23..bca5938051 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -511,8 +511,6 @@ public class TernaryOp extends MultiThreadedHop
checkAndSetInvalidCPDimsAndSize();
}
- updateETFed();
-
//mark for recompile (forever)
// additional condition: when execType=CP and additional
dimension inputs
// are provided (and those values are unknown at initial
compile time).
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 57df31a958..38a20fffc0 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -523,7 +523,6 @@ public class UnaryOp extends MultiThreadedHop
{
_etype = ExecType.CP;
} else {
- updateETFed();
setRequiresRecompileIfNecessary();
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index 3403cc4bbe..6486ead712 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -34,6 +34,7 @@ import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
@@ -117,6 +118,10 @@ public abstract class AFederatedPlanner {
if ( hop.isScalar() )
return null;
if( hop instanceof AggBinaryOp ) {
+ MMTSJType mmtsj = ((AggBinaryOp)
hop).checkTransposeSelf() ; //determine tsmm pattern
+ if ( mmtsj != MMTSJType.NONE &&
+ (( mmtsj.isLeft() && ft[0] == FType.ROW ) || (
mmtsj.isRight() && ft[0] == FType.COL ) ))
+ return FType.BROADCAST;
if( ft[0] != null )
return ft[0] == FType.ROW ? FType.ROW : null;
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index ee39e468bd..a809d2bafd 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -215,6 +215,7 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
updateFederatedOutput(root, rootHopRel);
visitInputDependency(rootHopRel);
}
+ root.updateETFed();
}
/**
@@ -358,6 +359,8 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
} else {
foutHopRelMap.put(outputFType,
alt);
}
+ } else {
+ LOG.trace("Allows federated, but FOUT
is not allowed: " + currentHop + " input FTypes: " + inputCombination);
}
} else {
LOG.trace("Does not allow federated: " +
currentHop + " input FTypes: " + inputCombination);
diff --git a/src/main/java/org/apache/sysds/lops/Append.java
b/src/main/java/org/apache/sysds/lops/Append.java
index c8a04c61e2..c32da1e8b4 100644
--- a/src/main/java/org/apache/sysds/lops/Append.java
+++ b/src/main/java/org/apache/sysds/lops/Append.java
@@ -59,7 +59,7 @@ public class Append extends Lop
//called when append executes in CP
@Override
public String getInstructions(String input1, String input2, String
input3, String output) {
- return InstructionUtils.concatOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(),
"append",
getInputs().get(0).prepInputOperand(input1),
@@ -67,5 +67,8 @@ public class Append extends Lop
getInputs().get(2).prepScalarInputOperand(getExecType()),
prepOutputOperand(output),
String.valueOf(_cbind));
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret,
_fedOutput.name());
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/BinaryScalar.java
b/src/main/java/org/apache/sysds/lops/BinaryScalar.java
index ac5b195764..9b3fa7d960 100644
--- a/src/main/java/org/apache/sysds/lops/BinaryScalar.java
+++ b/src/main/java/org/apache/sysds/lops/BinaryScalar.java
@@ -69,10 +69,13 @@ public class BinaryScalar extends Lop
@Override
public String getInstructions(String input1, String input2, String
output) {
- return InstructionUtils.concatOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(), operation.toString(),
getInputs().get(0).prepScalarInputOperand(getExecType()),
getInputs().get(1).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret,
_fedOutput.name());
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java
b/src/main/java/org/apache/sysds/lops/CentralMoment.java
index 78ed5434cc..bb825fe5cf 100644
--- a/src/main/java/org/apache/sysds/lops/CentralMoment.java
+++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java
@@ -96,9 +96,13 @@ public class CentralMoment extends Lop
getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
prepOutputOperand(output)));
}
- if( getExecType() == ExecType.CP ) {
+ if( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) {
sb.append(OPERAND_DELIMITOR);
- sb.append(String.valueOf(_numThreads));
+ sb.append(_numThreads);
+ if ( getExecType() == ExecType.FED ){
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(_fedOutput);
+ }
}
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/lops/CoVariance.java
b/src/main/java/org/apache/sysds/lops/CoVariance.java
index 235738071f..a68844fa35 100644
--- a/src/main/java/org/apache/sysds/lops/CoVariance.java
+++ b/src/main/java/org/apache/sysds/lops/CoVariance.java
@@ -97,9 +97,13 @@ public class CoVariance extends Lop
}
sb.append( prepOutputOperand(output));
- if( getExecType() == ExecType.CP ) {
+ if( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) {
sb.append( OPERAND_DELIMITOR );
sb.append(_numThreads);
+ if ( getExecType() == ExecType.FED ){
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( _fedOutput );
+ }
}
return sb.toString();
diff --git a/src/main/java/org/apache/sysds/lops/PickByCount.java
b/src/main/java/org/apache/sysds/lops/PickByCount.java
index b94a82efd8..2319a5c94d 100644
--- a/src/main/java/org/apache/sysds/lops/PickByCount.java
+++ b/src/main/java/org/apache/sysds/lops/PickByCount.java
@@ -109,6 +109,11 @@ public class PickByCount extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append(inMemoryInput);
+
+ if ( getExecType() == ExecType.FED ){
+ sb.append( OPERAND_DELIMITOR );
+ sb.append(_fedOutput.name());
+ }
return sb.toString();
}
@@ -121,12 +126,15 @@ public class PickByCount extends Lop
*/
@Override
public String getInstructions(String input, String output) {
- return InstructionUtils.concatOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
getInputs().get(0).prepInputOperand(input),
prepOutputOperand(output),
operation.name(),
String.valueOf(inMemoryInput));
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret,
_fedOutput.name());
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/SortKeys.java
b/src/main/java/org/apache/sysds/lops/SortKeys.java
index f7a0f82388..58754a86eb 100644
--- a/src/main/java/org/apache/sysds/lops/SortKeys.java
+++ b/src/main/java/org/apache/sysds/lops/SortKeys.java
@@ -85,35 +85,34 @@ public class SortKeys extends Lop
@Override
public String getInstructions(String input, String output) {
- StringBuilder sb = new StringBuilder();
- sb.append(InstructionUtils.concatOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
getInputs().get(0).prepInputOperand(input),
- prepOutputOperand(output)));
+ prepOutputOperand(output));
- if( getExecType() == ExecType.CP ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append(_numThreads);
- }
- return sb.toString();
+ if( getExecType() == ExecType.CP ) {
+ ret = InstructionUtils.concatOperands(ret,
Integer.toString(_numThreads));
+ }
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret,
Integer.toString(_numThreads), _fedOutput.name());
+ return ret;
}
@Override
public String getInstructions(String input1, String input2, String
output) {
- StringBuilder sb = new StringBuilder();
- sb.append(InstructionUtils.concatOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output)));
+ prepOutputOperand(output));
- if( getExecType() == ExecType.CP ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append(_numThreads);
- }
- return sb.toString();
+ if( getExecType() == ExecType.CP )
+ ret = InstructionUtils.concatOperands(ret,
Integer.toString(_numThreads));
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret,
Integer.toString(_numThreads), _fedOutput.name());
+ return ret;
}
// This method is invoked in two cases:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index e983e06f5e..2fde0a0fbc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -26,9 +26,13 @@ import
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.CentralMomentFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.CovarianceFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.QuantilePickFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.QuantileSortFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.TernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.TsmmFEDInstruction;
@@ -81,6 +85,12 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
String2FEDInstructionType.put( "-*" , FEDType.Ternary);
+ //central moment, covariance, quantiles (sort/pick)
+ String2FEDInstructionType.put( "cm", FEDType.CentralMoment);
+ String2FEDInstructionType.put( "cov", FEDType.Covariance);
+ String2FEDInstructionType.put( "qsort", FEDType.QSort);
+ String2FEDInstructionType.put( "qpick", FEDType.QPick);
+
String2FEDInstructionType.put(Append.OPCODE, FEDType.Append);
}
@@ -118,6 +128,14 @@ public class FEDInstructionParser extends InstructionParser
return
AppendFEDInstruction.parseInstruction(str);
case AggregateTernary:
return
AggregateTernaryFEDInstruction.parseInstruction(str);
+ case CentralMoment:
+ return
CentralMomentFEDInstruction.parseInstruction(str);
+ case Covariance:
+ return
CovarianceFEDInstruction.parseInstruction(str);
+ case QSort:
+ return
QuantileSortFEDInstruction.parseInstruction(str, true);
+ case QPick:
+ return
QuantilePickFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid
FEDERATED Instruction Type: " + fedtype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 6add1c613b..8e672447b4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -30,8 +30,11 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
+import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -47,18 +50,43 @@ public class AppendFEDInstruction extends
BinaryFEDInstruction {
_cbind = cbind;
}
+ protected AppendFEDInstruction(Operator op, CPOperand in1, CPOperand
in2, CPOperand out, boolean cbind,
+ String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.Append, op, in1, in2, out, opcode, istr, fedOut);
+ _cbind = cbind;
+ }
+
+ public static AppendFEDInstruction parseInstruction(Instruction inst){
+ if ( inst instanceof CPInstruction || inst instanceof
SPInstruction ){
+ String instStr = inst.getInstructionString();
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instStr);
+ InstructionUtils.checkNumFields(parts, 6, 5, 4);
+
+ String opcode = parts[0];
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[parts.length - 2]);
+ boolean cbind = Boolean.parseBoolean(parts[parts.length
- 1]);
+
+ Operator op = new
ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
+ return new AppendFEDInstruction(op, in1, in2, out,
cbind, opcode, instStr);
+ }
+ else return parseInstruction(inst.getInstructionString());
+ }
+
public static AppendFEDInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 6, 5, 4);
+ InstructionUtils.checkNumFields(parts, 7, 6, 5);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
- CPOperand out = new CPOperand(parts[parts.length - 2]);
- boolean cbind = Boolean.parseBoolean(parts[parts.length - 1]);
+ CPOperand out = new CPOperand(parts[parts.length - 3]);
+ boolean cbind = Boolean.parseBoolean(parts[parts.length - 2]);
+ FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length-1]);
Operator op = new
ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
- return new AppendFEDInstruction(op, in1, in2, out, cbind,
opcode, str);
+ return new AppendFEDInstruction(op, in1, in2, out, cbind,
opcode, str, fedOut);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
index ab9c9ed10a..7180fec353 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
@@ -32,12 +32,15 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -52,7 +55,22 @@ public class CentralMomentFEDInstruction extends
AggregateUnaryFEDInstruction {
}
public static CentralMomentFEDInstruction parseInstruction(String str) {
- return
parseInstruction(CentralMomentCPInstruction.parseInstruction(str));
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length-1]);
+ String cleanInstStr = InstructionUtils.removeFEDOutputFlag(str);
+ CentralMomentCPInstruction cpInst =
CentralMomentCPInstruction.parseInstruction(cleanInstStr);
+ CentralMomentFEDInstruction fedInst = parseInstruction(cpInst);
+ fedInst._fedOut = fedOut;
+ return fedInst;
+ }
+
+ public static CentralMomentFEDInstruction parseInstruction(Instruction
inst){
+ if ( inst instanceof CentralMomentCPInstruction)
+ return parseInstruction((CentralMomentCPInstruction)
inst);
+ else if ( inst instanceof SPInstruction )
+ return
parseInstruction(CentralMomentCPInstruction.parseInstruction(inst.getInstructionString()));
+ else
+ return parseInstruction(inst.getInstructionString());
}
public static CentralMomentFEDInstruction
parseInstruction(CentralMomentCPInstruction inst) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index 18dfe75ed7..2c41ee2e62 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -37,12 +37,15 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.COVOperator;
@@ -57,7 +60,21 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
}
public static CovarianceFEDInstruction parseInstruction(String str) {
- return
parseInstruction(CovarianceCPInstruction.parseInstruction(str));
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length-1]);
+ String cleanInstStr = InstructionUtils.removeFEDOutputFlag(str);
+ CovarianceFEDInstruction fedInst =
parseInstruction(CovarianceCPInstruction.parseInstruction(cleanInstStr));
+ fedInst._fedOut = fedOut;
+ return fedInst;
+ }
+
+ public static CovarianceFEDInstruction parseInstruction(Instruction
inst){
+ if ( inst instanceof CovarianceCPInstruction )
+ return parseInstruction((CovarianceCPInstruction) inst);
+ else if ( inst instanceof SPInstruction )
+ return
parseInstruction(CovarianceCPInstruction.parseInstruction(inst.getInstructionString()));
+ else
+ return parseInstruction(inst.getInstructionString());
}
public static CovarianceFEDInstruction
parseInstruction(CovarianceCPInstruction inst) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 010b8b2a07..955013870f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -33,7 +33,9 @@ public abstract class FEDInstruction extends Instruction {
Append,
Binary,
Cast,
+ CentralMoment,
Checkpoint,
+ Covariance,
CSVReblock,
Ctable,
CumulativeAggregate,
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 584f293021..5e3b7a9f8b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -156,10 +156,10 @@ public class FEDInstructionUtils {
MatrixObject mo1 =
ec.getMatrixObject(instruction.input1);
if(
mo1.isFederatedExcept(FType.BROADCAST) ) {
if(instruction.getOpcode().equalsIgnoreCase("cm"))
- fedinst =
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
CentralMomentFEDInstruction.parseInstruction(inst);
else
if(inst.getOpcode().equalsIgnoreCase("qsort")) {
if(mo1.isFederated(FType.ROW) ||
mo1.getFedMapping().getFederatedRanges().length == 1 &&
mo1.isFederated(FType.COL))
- fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
}
else
if(inst.getOpcode().equalsIgnoreCase("rshape"))
fedinst =
ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -180,12 +180,12 @@ public class FEDInstructionUtils {
if( (instruction.input1.isMatrix() &&
ec.getMatrixObject(instruction.input1).isFederatedExcept(FType.BROADCAST))
|| (instruction.input2.isMatrix() &&
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
if(instruction.getOpcode().equals("append") )
- fedinst =
AppendFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
AppendFEDInstruction.parseInstruction(inst);
else
if(instruction.getOpcode().equals("qpick"))
- fedinst =
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
QuantilePickFEDInstruction.parseInstruction(inst);
else
if("cov".equals(instruction.getOpcode()) &&
(ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst =
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
CovarianceFEDInstruction.parseInstruction(inst);
else
fedinst =
BinaryFEDInstruction.parseInstruction(
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
@@ -329,12 +329,12 @@ public class FEDInstructionUtils {
CentralMomentSPInstruction cinstruction =
(CentralMomentSPInstruction) inst;
Data data = ec.getVariable(cinstruction.input1);
if (data instanceof MatrixObject &&
((MatrixObject) data).isFederated() && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
- fedinst =
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
CentralMomentFEDInstruction.parseInstruction(inst);
} else if (inst instanceof QuantileSortSPInstruction) {
QuantileSortSPInstruction qinstruction =
(QuantileSortSPInstruction) inst;
Data data = ec.getVariable(qinstruction.input1);
if (data instanceof MatrixObject &&
((MatrixObject) data).isFederated() && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
- fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
}
else if (inst instanceof AggregateUnarySPInstruction) {
AggregateUnarySPInstruction auinstruction =
(AggregateUnarySPInstruction) inst;
@@ -369,7 +369,7 @@ public class FEDInstructionUtils {
fedinst =
CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction)inst);
else
if(inst.getOpcode().equalsIgnoreCase("qsort")) {
if(mo1.getFedMapping().getFederatedRanges().length == 1)
- fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
}
else
if(inst.getOpcode().equalsIgnoreCase("rshape")) {
fedinst =
ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -394,7 +394,7 @@ public class FEDInstructionUtils {
QuantilePickSPInstruction qinstruction =
(QuantilePickSPInstruction) inst;
Data data = ec.getVariable(qinstruction.input1);
if(data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst =
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
QuantilePickFEDInstruction.parseInstruction(inst);
}
else if (inst instanceof AppendGAlignedSPInstruction ||
inst instanceof AppendGSPInstruction
|| inst instanceof AppendMSPInstruction || inst
instanceof AppendRSPInstruction) {
@@ -403,7 +403,7 @@ public class FEDInstructionUtils {
Data data2 =
ec.getVariable(ainstruction.input2);
if ((data1 instanceof MatrixObject &&
((MatrixObject) data1).isFederatedExcept(FType.BROADCAST))
|| (data2 instanceof MatrixObject &&
((MatrixObject) data2).isFederatedExcept(FType.BROADCAST))) {
- fedinst =
AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ fedinst =
AppendFEDInstruction.parseInstruction(instruction);
}
}
else if (inst instanceof BinaryMatrixScalarSPInstruction
@@ -422,7 +422,7 @@ public class FEDInstructionUtils {
|| (instruction.input2.isMatrix() &&
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
if("cov".equals(instruction.getOpcode()) &&
(ec.getMatrixObject(instruction.input1)
.isFederated(FType.ROW) ||
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst =
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst =
CovarianceFEDInstruction.parseInstruction(inst);
else if(inst instanceof
CumulativeOffsetSPInstruction) {
fedinst =
CumulativeOffsetFEDInstruction.parseInstruction(inst.getInstructionString());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index cce571bbf7..2fa9c30e26 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -34,6 +34,7 @@ import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PickByCount.OperationTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -47,6 +48,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -78,37 +80,45 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
this(op, in, in2, out, type, inmem, opcode, istr,
FederatedOutput.NONE);
}
+ public static QuantilePickFEDInstruction parseInstruction( Instruction
inst ){
+ return parseInstruction(inst.getInstructionString() +
Lop.OPERAND_DELIMITOR + FederatedOutput.NONE);
+ }
+
public static QuantilePickFEDInstruction parseInstruction ( String str
) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
if ( !opcode.equalsIgnoreCase("qpick") )
throw new DMLRuntimeException("Unknown opcode while
parsing a QuantilePickCPInstruction: " + str);
+ FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length-1]);
+ QuantilePickFEDInstruction inst = null;
//instruction parsing
- if( parts.length == 4 ) {
- //instructions of length 4 originate from unary - mr-iqm
+ if( parts.length == 5 ) {
+ //instructions of length 5 originate from unary - mr-iqm
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
OperationTypes ptype = OperationTypes.IQM;
boolean inmem = false;
- return new QuantilePickFEDInstruction(null, in1, in2,
out, ptype, inmem, opcode, str);
+ inst = new QuantilePickFEDInstruction(null, in1, in2,
out, ptype, inmem, opcode, str);
}
- else if( parts.length == 5 ) {
+ else if( parts.length == 6 ) {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
OperationTypes ptype = OperationTypes.valueOf(parts[3]);
boolean inmem = Boolean.parseBoolean(parts[4]);
- return new QuantilePickFEDInstruction(null, in1, out,
ptype, inmem, opcode, str);
+ inst = new QuantilePickFEDInstruction(null, in1, out,
ptype, inmem, opcode, str);
}
- else if( parts.length == 6 ) {
+ else if( parts.length == 7 ) {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
OperationTypes ptype = OperationTypes.valueOf(parts[4]);
boolean inmem = Boolean.parseBoolean(parts[5]);
- return new QuantilePickFEDInstruction(null, in1, in2,
out, ptype, inmem, opcode, str);
+ inst = new QuantilePickFEDInstruction(null, in1, in2,
out, ptype, inmem, opcode, str);
}
- return null;
+ if ( inst != null )
+ inst._fedOut = fedOut;
+ return inst;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index b77dd83e6c..ece09392f6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -70,7 +70,7 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction {
}
}
- public static QuantileSortFEDInstruction parseInstruction ( String str
) {
+ public static QuantileSortFEDInstruction parseInstruction ( String str
, boolean hasFedOut) {
CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
CPOperand in2 = null;
CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
@@ -78,7 +78,17 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
boolean isSpark = str.startsWith("SPARK");
- int k = isSpark ? 1 : Integer.parseInt(parts[parts.length-1]);
+ int k;
+ FederatedOutput fedOut;
+ if ( hasFedOut){
+ k = isSpark ? 1 :
Integer.parseInt(parts[parts.length-2]);
+ fedOut = FederatedOutput.valueOf(parts[parts.length-1]);
+ } else {
+ k = isSpark ? 1 :
Integer.parseInt(parts[parts.length-1]);
+ fedOut = FederatedOutput.NONE;
+ }
+
+ QuantileSortFEDInstruction inst;
if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) {
int oneInputLength = isSpark ? 3 : 4;
@@ -86,14 +96,14 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction {
if ( parts.length == oneInputLength ) {
// Example: sort:mVar1:mVar2 (input=mVar1,
output=mVar2)
parseUnaryInstruction(str, in1, out);
- return new QuantileSortFEDInstruction(in1, out,
opcode, str, k);
+ inst = new QuantileSortFEDInstruction(in1, out,
opcode, str, k);
}
else if ( parts.length == twoInputLength ) {
// Example: sort:mVar1:mVar2:mVar3
(input=mVar1, weights=mVar2, output=mVar3)
in2 = new CPOperand("",
Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
InstructionUtils.checkNumFields(str,
twoInputLength-1);
parseInstruction(str, in1, in2, out);
- return new QuantileSortFEDInstruction(in1, in2,
out, opcode, str, k);
+ inst = new QuantileSortFEDInstruction(in1, in2,
out, opcode, str, k);
}
else {
throw new DMLRuntimeException("Invalid number
of operands in instruction: " + str);
@@ -102,6 +112,8 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction {
else {
throw new DMLRuntimeException("Unknown opcode while
parsing a QuantileSortFEDInstruction: " + str);
}
+ inst._fedOut = fedOut;
+ return inst;
}
@Override
public void processInstruction(ExecutionContext ec) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 29c59abc85..c9d518f226 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -52,7 +52,7 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
int numThreads = parts.length>5 & !opcode.contains("map") ?
Integer.parseInt(parts[5]) : 1;
- FederatedOutput fedOut = parts.length>7 &&
!opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) :
FederatedOutput.NONE;
+ FederatedOutput fedOut = parts.length>=7 &&
!opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) :
FederatedOutput.NONE;
TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode, numThreads);
if( operand1.isFrame() && operand2.isScalar() ||
operand2.isFrame() && operand1.isScalar() )
return new TernaryFrameScalarFEDInstruction(op,
operand1, operand2, operand3, outOperand, opcode,
InstructionUtils.removeFEDOutputFlag(str), fedOut);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 11eefb46f2..7ad48cfa25 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -72,8 +72,6 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
MatrixObject mo1 = ec.getMatrixObject(input1);
if((_type.isLeft() && mo1.isFederated(FType.ROW)) ||
(mo1.isFederated(FType.COL) && _type.isRight()))
processRowCol(ec, mo1);
- else if ( mo1.isFederated(FType.PART) )
- processPart(ec, mo1);
else { //other combinations
String exMessage = (!mo1.isFederated() ||
mo1.getFedMapping() == null) ?
"Federated Tsmm does not support non-federated
input" :
@@ -82,20 +80,6 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction
{
}
}
- private void processPart(ExecutionContext ec, MatrixObject mo1){
- if (_fedOut.isForcedFederated()){
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo1);
- FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1}, new
long[]{mo1.getFedMapping().getID()}, true);
- mo1.getFedMapping().execute(getTID(), fr1, fr2);
- setOutputFederated(ec, mo1, fr2, FType.BROADCAST);
- } else {
- mo1.acquireReadAndRelease();
- CPInstruction tsmmCPInst =
CPInstructionParser.parseSingleInstruction(instString);
- tsmmCPInst.processInstruction(ec);
- }
- }
-
private void processRowCol(ExecutionContext ec, MatrixObject mo1){
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{mo1.getFedMapping().getID()}, true);