This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cc24dc36c1 [SYSTEMDS-3386] Refactor replacement CP Cleanup
cc24dc36c1 is described below
commit cc24dc36c1d026f3ce96ef0c0285f6fd21c2e2d9
Author: baunsgaard <[email protected]>
AuthorDate: Thu Aug 18 18:37:35 2022 +0200
[SYSTEMDS-3386] Refactor replacement CP Cleanup
This commit moves the logic that converts CP or SP instructions into federated
(FED) instructions out of FEDInstructionUtils and into the corresponding FED
instruction classes, giving a cleaner design.
Closes #1680
---
.../fed/AggregateBinaryFEDInstruction.java | 22 +-
.../fed/AggregateTernaryFEDInstruction.java | 24 +-
.../fed/AggregateUnaryFEDInstruction.java | 4 +-
.../instructions/fed/BinaryFEDInstruction.java | 133 +++++-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 4 +-
.../instructions/fed/CastFEDInstruction.java | 14 +-
.../instructions/fed/CtableFEDInstruction.java | 46 +-
.../instructions/fed/FEDInstructionUtils.java | 479 +++------------------
.../instructions/fed/MMChainFEDInstruction.java | 19 +-
...tiReturnParameterizedBuiltinFEDInstruction.java | 28 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 71 +--
.../fed/QuantileSortFEDInstruction.java | 1 +
.../instructions/fed/QuaternaryFEDInstruction.java | 43 +-
.../instructions/fed/ReorgFEDInstruction.java | 29 +-
.../instructions/fed/SpoofFEDInstruction.java | 24 +-
.../instructions/fed/TernaryFEDInstruction.java | 51 ++-
.../instructions/fed/TsmmFEDInstruction.java | 14 +-
.../instructions/fed/UnaryFEDInstruction.java | 143 +++++-
.../fed/UnaryMatrixFEDInstruction.java | 8 +-
.../instructions/fed/VariableFEDInstruction.java | 19 +-
20 files changed, 629 insertions(+), 547 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6128df20e0..9340e9fb12 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -38,7 +38,6 @@ import
org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import
org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -55,14 +54,23 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
super(FEDType.AggregateBinary, op, in1, in2, out, opcode, istr,
fedOut);
}
- public static AggregateBinaryFEDInstruction
parseInstruction(AggregateBinaryCPInstruction instr) {
- return new AggregateBinaryFEDInstruction(instr.getOperator(),
instr.input1, instr.input2, instr.output,
- instr.getOpcode(), instr.getInstructionString(),
FederatedOutput.NONE);
+ public static AggregateBinaryFEDInstruction
parseInstruction(AggregateBinaryCPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() && inst.input2.isMatrix()) {
+ MatrixObject mo1 = ec.getMatrixObject(inst.input1);
+ MatrixObject mo2 = ec.getMatrixObject(inst.input2);
+ if((mo1.isFederated(FType.ROW) &&
mo1.isFederatedExcept(FType.BROADCAST)) ||
+ (mo2.isFederated(FType.ROW) &&
mo2.isFederatedExcept(FType.BROADCAST)) ||
+ (mo1.isFederated(FType.COL) &&
mo1.isFederatedExcept(FType.BROADCAST))) {
+ return
AggregateBinaryFEDInstruction.parseInstruction(inst);
+ }
+ }
+ return null;
}
- public static AggregateBinaryFEDInstruction
parseInstruction(AggregateBinarySPInstruction instr) {
+ private static AggregateBinaryFEDInstruction
parseInstruction(AggregateBinaryCPInstruction instr) {
return new AggregateBinaryFEDInstruction(instr.getOperator(),
instr.input1, instr.input2, instr.output,
- instr.getOpcode(),
instr.getInstructionString(), FederatedOutput.NONE);
+ instr.getOpcode(), instr.getInstructionString(),
FederatedOutput.NONE);
}
public static AggregateBinaryFEDInstruction parseInstruction(String
str) {
@@ -70,7 +78,7 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
String opcode = parts[0];
if(!opcode.equalsIgnoreCase("ba+*"))
throw new
DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown
opcode " + opcode);
-
+
InstructionUtils.checkNumFields(parts, 5);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 5a54fc6374..f8e8f8ad22 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -48,12 +48,30 @@ public class AggregateTernaryFEDInstruction extends
ComputationFEDInstruction {
super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode,
istr, fedOut);
}
- public static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernaryCPInstruction instr) {
+ public static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernaryCPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() &&
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST) &&
+ inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) {
+ return parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernarySPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() &&
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST) &&
+ inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) {
+ return parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernaryCPInstruction instr) {
return new AggregateTernaryFEDInstruction(instr.getOperator(),
instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(),
instr.getInstructionString(), FederatedOutput.NONE);
}
- public static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernarySPInstruction instr) {
+ private static AggregateTernaryFEDInstruction
parseInstruction(AggregateTernarySPInstruction instr) {
return new AggregateTernaryFEDInstruction(instr.getOperator(),
instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(),
instr.getInstructionString(), FederatedOutput.NONE);
}
@@ -79,8 +97,8 @@ public class AggregateTernaryFEDInstruction extends
ComputationFEDInstruction {
}
else {
throw new
DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown
opcode " + opcode);
- }
}
+}
@Override
public void processInstruction(ExecutionContext ec) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index b4b729a96d..55554240b9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -78,7 +78,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
return new AggregateUnaryFEDInstruction(instr.getOperator(),
instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(),
instr.getInstructionString());
}
-
+
public static AggregateUnaryFEDInstruction parseInstruction(String str)
{
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -101,7 +101,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
fedOut = FederatedOutput.valueOf(parts[5]);
return new AggregateUnaryFEDInstruction(aggun, in1, out,
opcode, str, fedOut);
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
if (getOpcode().contains("var")) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 20b378ac8e..10a907d78a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -21,31 +21,126 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.BinaryM.VectorType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import
org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
+import
org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CovarianceSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
- protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, FederatedOutput fedOut) {
+ protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator
op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr, FederatedOutput fedOut) {
super(type, op, in1, in2, out, opcode, istr, fedOut);
}
- protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
+ protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator
op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr) {
this(type, op, in1, in2, out, opcode, istr,
FederatedOutput.NONE);
}
- public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String istr) {
+ public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand in3,
+ CPOperand out, String opcode, String istr) {
super(type, op, in1, in2, in3, out, opcode, istr);
}
+ public static BinaryFEDInstruction parseInstruction(BinaryCPInstruction
inst, ExecutionContext ec) {
+ if((inst.input1.isMatrix() &&
ec.getMatrixObject(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2 != null && inst.input2.isMatrix() &&
+
ec.getMatrixObject(inst.input2).isFederatedExcept(FType.BROADCAST))) {
+ if(inst instanceof AppendCPInstruction)
+ return
AppendFEDInstruction.parseInstruction((AppendCPInstruction) inst);
+ else if(inst instanceof QuantilePickCPInstruction)
+ return
QuantilePickFEDInstruction.parseInstruction((QuantilePickCPInstruction) inst);
+ else if(inst instanceof CovarianceCPInstruction &&
(ec.getMatrixObject(inst.input1).isFederated(FType.ROW) ||
+
ec.getMatrixObject(inst.input2).isFederated(FType.ROW)))
+ return
CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction) inst);
+ else if(inst instanceof BinaryMatrixMatrixCPInstruction)
+ return
BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixCPInstruction)
inst);
+ else if(inst instanceof BinaryMatrixScalarCPInstruction)
+ return
BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarCPInstruction)
inst);
+ }
+ return null;
+ }
+
+ public static BinaryFEDInstruction parseInstruction(BinarySPInstruction
inst, ExecutionContext ec) {
+ if(inst instanceof MapmmSPInstruction || inst instanceof
CpmmSPInstruction || inst instanceof RmmSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST)) {
+ return
MMFEDInstruction.parseInstruction((AggregateBinarySPInstruction) inst);
+ }
+ }
+ else if(inst instanceof QuantilePickSPInstruction) {
+ QuantilePickSPInstruction qinstruction =
(QuantilePickSPInstruction) inst;
+ Data data = ec.getVariable(qinstruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
+ return
QuantilePickFEDInstruction.parseInstruction(qinstruction);
+ }
+ else if(inst instanceof AppendGAlignedSPInstruction || inst
instanceof AppendGSPInstruction ||
+ inst instanceof AppendMSPInstruction || inst instanceof
AppendRSPInstruction) {
+ BinarySPInstruction ainstruction =
(BinarySPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if((data1 instanceof MatrixObject && ((MatrixObject)
data1).isFederatedExcept(FType.BROADCAST)) ||
+ (data2 instanceof MatrixObject &&
((MatrixObject) data2).isFederatedExcept(FType.BROADCAST))) {
+ return
AppendFEDInstruction.parseInstruction((AppendSPInstruction) inst);
+ }
+ }
+ else if(inst instanceof BinaryMatrixScalarSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST)) {
+ return
BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarSPInstruction)
inst);
+ }
+ }
+ else if(inst instanceof BinaryMatrixMatrixSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST)) {
+ return
BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixSPInstruction)
inst);
+ }
+ }
+ else if((inst.input1.isMatrix() &&
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() &&
ec.getMatrixObject(inst.input2).isFederatedExcept(FType.BROADCAST))) {
+ if(inst instanceof CovarianceSPInstruction &&
(ec.getMatrixObject(inst.input1).isFederated(FType.ROW) ||
+
ec.getMatrixObject(inst.input2).isFederated(FType.ROW)))
+ return
CovarianceFEDInstruction.parseInstruction((CovarianceSPInstruction) inst);
+ else if(inst instanceof CumulativeOffsetSPInstruction) {
+ return
CumulativeOffsetFEDInstruction.parseInstruction((CumulativeOffsetSPInstruction)
inst);
+ }
+ else
+ return
BinaryFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),
+
FEDInstruction.FederatedOutput.NONE.name()));
+ }
+ return null;
+ }
+
public static BinaryFEDInstruction parseInstruction(String str) {
+ // TODO remove
if(str.startsWith(ExecType.SPARK.name())) {
// rewrite the spark instruction to a cp instruction
str = rewriteSparkInstructionToCP(str);
@@ -57,20 +152,20 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
- FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length-1]);
+ FederatedOutput fedOut =
FederatedOutput.valueOf(parts[parts.length - 1]);
checkOutputDataType(in1, in2, out);
Operator operator =
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
- //Operator operator =
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
+ // Operator operator =
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
// TODO different binary instructions
- if( in1.getDataType() == DataType.SCALAR && in2.getDataType()
== DataType.SCALAR )
+ if(in1.getDataType() == DataType.SCALAR && in2.getDataType() ==
DataType.SCALAR)
throw new DMLRuntimeException("Federated binary scalar
scalar operations not yet supported");
- else if( in1.getDataType() == DataType.MATRIX &&
in2.getDataType() == DataType.MATRIX )
+ else if(in1.getDataType() == DataType.MATRIX &&
in2.getDataType() == DataType.MATRIX)
return new BinaryMatrixMatrixFEDInstruction(operator,
in1, in2, out, opcode, str, fedOut);
- else if( in1.getDataType() == DataType.TENSOR &&
in2.getDataType() == DataType.TENSOR )
+ else if(in1.getDataType() == DataType.TENSOR &&
in2.getDataType() == DataType.TENSOR)
throw new DMLRuntimeException("Federated binary tensor
tensor operations not yet supported");
- else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() &&
in1.isScalar() )
+ else if(in1.isMatrix() && in2.isScalar() || in2.isMatrix() &&
in1.isScalar())
return new BinaryMatrixScalarFEDInstruction(operator,
in1, in2, out, opcode, str, fedOut);
else
throw new DMLRuntimeException("Federated binary
operations not yet supported:" + opcode);
@@ -78,7 +173,7 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand out) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 3, 4 );
+ InstructionUtils.checkNumFields(parts, 3, 4);
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
@@ -86,9 +181,10 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
return opcode;
}
- protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand in3, CPOperand out) {
+ protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand in3,
+ CPOperand out) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 4 );
+ InstructionUtils.checkNumFields(parts, 4);
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
@@ -99,9 +195,10 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
protected static void checkOutputDataType(CPOperand in1, CPOperand in2,
CPOperand out) {
// check for valid data type of output
- if( (in1.getDataType() == DataType.MATRIX || in2.getDataType()
== DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
- throw new DMLRuntimeException("Element-wise matrix
operations between variables " + in1.getName() +
- " and " + in2.getName() + " must produce a
matrix, which " + out.getName() + " is not");
+ if((in1.getDataType() == DataType.MATRIX || in2.getDataType()
== DataType.MATRIX) &&
+ out.getDataType() != DataType.MATRIX)
+ throw new DMLRuntimeException("Element-wise matrix
operations between variables " + in1.getName() + " and "
+ + in2.getName() + " must produce a matrix,
which " + out.getName() + " is not");
}
protected static String rewriteSparkInstructionToCP(String inst_str) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 250b5d193a..fb8455ea9e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,12 +42,12 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
super(FEDType.Binary, op, in1, in2, out, opcode, istr, fedOut);
}
- public static BinaryMatrixMatrixFEDInstruction
parseInstruction(BinaryMatrixMatrixCPInstruction instr) {
+ protected static BinaryMatrixMatrixFEDInstruction
parseInstruction(BinaryMatrixMatrixCPInstruction instr) {
return new
BinaryMatrixMatrixFEDInstruction(instr.getOperator(), instr.input1,
instr.input2, instr.output,
instr.getOpcode(), instr.getInstructionString(),
FederatedOutput.NONE);
}
- public static BinaryMatrixMatrixFEDInstruction
parseInstruction(BinaryMatrixMatrixSPInstruction instr) {
+ protected static BinaryMatrixMatrixFEDInstruction
parseInstruction(BinaryMatrixMatrixSPInstruction instr) {
String instrStr =
rewriteSparkInstructionToCP(instr.getInstructionString());
String opcode =
InstructionUtils.getInstructionPartsWithValueType(instrStr)[0];
return new
BinaryMatrixMatrixFEDInstruction(instr.getOperator(), instr.input1,
instr.input2, instr.output,
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
index df2fe11e12..89edb9a4dc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
@@ -28,6 +28,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -49,12 +50,21 @@ public class CastFEDInstruction extends UnaryFEDInstruction
{
super(FEDInstruction.FEDType.Cast, op, in, out, opcode, istr);
}
- public static CastFEDInstruction parseInstruction(CastSPInstruction
spInstruction) {
+ public static CastFEDInstruction parseInstruction(CastSPInstruction
inst, ExecutionContext ec) {
+
if((inst.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString()) ||
+
inst.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_MATRIX.toString())) &&
inst.input1.isMatrix() &&
+
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) {
+ return CastFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static CastFEDInstruction parseInstruction(CastSPInstruction
spInstruction) {
return new CastFEDInstruction(spInstruction.getOperator(),
spInstruction.input1, spInstruction.output,
spInstruction.getOpcode(),
spInstruction.getInstructionString());
}
- public static CastFEDInstruction parseInstruction ( String str ) {
+ public static CastFEDInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(parts, 2);
String opcode = parts[0];
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 3f87668492..0ca04788e2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -21,16 +21,17 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.Collections;
-import java.util.concurrent.Future;
import java.util.Iterator;
import java.util.SortedMap;
-import java.util.stream.IntStream;
import java.util.TreeMap;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -56,25 +57,38 @@ import
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
public class CtableFEDInstruction extends ComputationFEDInstruction {
private final CPOperand _outDim1;
private final CPOperand _outDim2;
- //private final boolean _isExpand;
- //private final boolean _ignoreZeros;
private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand
in3, CPOperand out, CPOperand outDim1,
CPOperand outDim2, boolean isExpand, boolean ignoreZeros,
String opcode, String istr) {
super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
_outDim1 = outDim1;
_outDim2 = outDim2;
- // _isExpand = isExpand;
- // _ignoreZeros = ignoreZeros;
}
- public static CtableFEDInstruction parseInstruction(CtableCPInstruction
instr) {
+ public static CtableFEDInstruction parseInstruction(CtableCPInstruction
inst, ExecutionContext ec) {
+ if((inst.getOpcode().equalsIgnoreCase("ctable") ||
inst.getOpcode().equalsIgnoreCase("ctableexpand")) &&
+
(ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
+ (inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederated(FType.ROW)) ||
+ (inst.input3.isMatrix() &&
ec.getCacheableData(inst.input3).isFederated(FType.ROW))))
+ return CtableFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ private static CtableFEDInstruction
parseInstruction(CtableCPInstruction instr) {
return new CtableFEDInstruction(instr.input1, instr.input2,
instr.input3, instr.output, instr.getOutDim1(),
instr.getOutDim2(), instr.getIsExpand(),
instr.getIgnoreZeros(), instr.getOpcode(),
instr.getInstructionString());
}
+
+ public static CtableFEDInstruction parseInstruction(CtableSPInstruction
inst, ExecutionContext ec) {
+ if(inst.getOpcode().equalsIgnoreCase("ctable") &&
(ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
+ (inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederated(FType.ROW)) ||
+ (inst.input3.isMatrix() &&
ec.getCacheableData(inst.input3).isFederated(FType.ROW))))
+ return CtableFEDInstruction.parseInstruction(inst);
+ return null;
+ }
- public static CtableFEDInstruction parseInstruction(CtableSPInstruction
instr) {
+ private static CtableFEDInstruction
parseInstruction(CtableSPInstruction instr) {
return new CtableFEDInstruction(instr.input1, instr.input2,
instr.input3, instr.output, instr.getOutDim1(),
instr.getOutDim2(), instr.getIsExpand(),
instr.getIgnoreZeros(), instr.getOpcode(),
instr.getInstructionString());
@@ -83,33 +97,27 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
public static CtableFEDInstruction parseInstruction(String inst) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(inst);
InstructionUtils.checkNumFields(parts, 7);
-
String opcode = parts[0];
-
- //handle opcode
+ // handle opcode
if(!(opcode.equalsIgnoreCase("ctable")) &&
!(opcode.equalsIgnoreCase("ctableexpand"))) {
throw new DMLRuntimeException("Unexpected opcode in
CtableFEDInstruction: " + inst);
}
-
- //handle operands
+ // handle operands
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
-
- //handle known dimension information
+ // handle known dimension information
String[] dim1Fields =
parts[4].split(Instruction.LITERAL_PREFIX);
String[] dim2Fields =
parts[5].split(Instruction.LITERAL_PREFIX);
-
CPOperand out = new CPOperand(parts[6]);
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
-
+
boolean dim1Literal = Boolean.parseBoolean(dim1Fields[1]);
CPOperand outDim1 = new CPOperand(dim1Fields[0],
ValueType.FP64, DataType.SCALAR, dim1Literal);
boolean dim2Literal = Boolean.parseBoolean(dim2Fields[1]);
CPOperand outDim2 = new CPOperand(dim2Fields[0],
ValueType.FP64, DataType.SCALAR, dim2Literal);
// ctable does not require any operator, so we simply pass-in a
dummy operator with null functionobject
- return new CtableFEDInstruction(in1,
- in2, in3, out, outDim1, outDim2, false, ignoreZeros,
opcode, inst);
+ return new CtableFEDInstruction(in1, in2, in3, out, outDim1,
outDim2, false, ignoreZeros, opcode, inst);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index eb93ff9d59..c5ada38032 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,476 +19,127 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.hops.fedplanner.FTypes.FType;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
-import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
-import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
-import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
-import
org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
-import
org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
-import
org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
-import
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
-import
org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import
org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
-import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
-import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CovarianceSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
-import
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
-import
org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
public class FEDInstructionUtils {
-
- private static final String[] PARAM_BUILTINS = new String[]{
- "replace", "rmempty", "lowertri", "uppertri",
"transformdecode", "transformapply", "tokenize"};
public static boolean noFedRuntimeConversion = false;
-
- // private static final Log LOG =
LogFactory.getLog(FEDInstructionUtils.class.getName());
-
- // This is currently a rather simplistic to our solution of replacing
instructions with their correct federated
- // counterpart, since we do not propagate the information that a matrix
is federated, therefore we can not decide
- // to choose a federated instruction earlier.
/**
* Check and replace CP instructions with federated instructions if the
instruction match criteria.
*
- * @param inst The instruction to analyse
- * @param ec The Execution Context
+ * @param inst The instruction to analyze
+ * @param ec The Execution Context
* @return The potentially modified instruction
*/
public static Instruction checkAndReplaceCP(Instruction inst,
ExecutionContext ec) {
- if ( !noFedRuntimeConversion ){
- FEDInstruction fedinst = null;
- if (inst instanceof AggregateBinaryCPInstruction) {
- AggregateBinaryCPInstruction instruction =
(AggregateBinaryCPInstruction) inst;
- if( instruction.input1.isMatrix() &&
instruction.input2.isMatrix()) {
- MatrixObject mo1 =
ec.getMatrixObject(instruction.input1);
- MatrixObject mo2 =
ec.getMatrixObject(instruction.input2);
- if ( (mo1.isFederated(FType.ROW) &&
mo1.isFederatedExcept(FType.BROADCAST))
- || (mo2.isFederated(FType.ROW)
&& mo2.isFederatedExcept(FType.BROADCAST))
- || (mo1.isFederated(FType.COL)
&& mo1.isFederatedExcept(FType.BROADCAST))) {
- fedinst =
AggregateBinaryFEDInstruction.parseInstruction(instruction);
- }
- }
- }
- else if( inst instanceof MMChainCPInstruction) {
- MMChainCPInstruction linst =
(MMChainCPInstruction) inst;
- MatrixObject mo =
ec.getMatrixObject(linst.input1);
- if( mo.isFederated(FType.ROW) )
- fedinst =
MMChainFEDInstruction.parseInstruction(linst);
- }
- else if( inst instanceof MMTSJCPInstruction ) {
- MMTSJCPInstruction linst = (MMTSJCPInstruction)
inst;
- MatrixObject mo =
ec.getMatrixObject(linst.input1);
- if( (mo.isFederated(FType.ROW) &&
mo.isFederatedExcept(FType.BROADCAST) && linst.getMMTSJType().isLeft()) ||
- (mo.isFederated(FType.COL) &&
mo.isFederatedExcept(FType.BROADCAST) && linst.getMMTSJType().isRight()))
- fedinst =
TsmmFEDInstruction.parseInstruction(linst);
- }
- else if (inst instanceof UnaryCPInstruction && ! (inst
instanceof IndexingCPInstruction)) {
- UnaryCPInstruction instruction =
(UnaryCPInstruction) inst;
- if(inst instanceof ReorgCPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
- || inst.getOpcode().equals("rev"))) {
- ReorgCPInstruction rinst =
(ReorgCPInstruction) inst;
- CacheableData<?> mo =
ec.getCacheableData(rinst.input1);
+ if(noFedRuntimeConversion)
+ return inst;
- if((mo instanceof MatrixObject || mo
instanceof FrameObject)
- &&
mo.isFederatedExcept(FType.BROADCAST) )
- fedinst =
ReorgFEDInstruction.parseInstruction(rinst);
- }
- else if(instruction.input1 != null &&
instruction.input1.isMatrix()
- &&
ec.containsVariable(instruction.input1)) {
+ FEDInstruction fedinst = null;
- MatrixObject mo1 =
ec.getMatrixObject(instruction.input1);
- if(
mo1.isFederatedExcept(FType.BROADCAST) ) {
- if(instruction instanceof
CentralMomentCPInstruction)
- fedinst =
CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction) inst);
- else if(inst instanceof
QuantileSortCPInstruction) {
-
if(mo1.isFederated(FType.ROW) ||
mo1.getFedMapping().getFederatedRanges().length == 1 &&
mo1.isFederated(FType.COL))
- fedinst =
QuantileSortFEDInstruction.parseInstruction((QuantileSortCPInstruction) inst);
- }
- else if(inst instanceof
ReshapeCPInstruction)
- fedinst =
ReshapeFEDInstruction.parseInstruction((ReshapeCPInstruction) inst);
- else if(inst instanceof
AggregateUnaryCPInstruction &&
-
((AggregateUnaryCPInstruction) instruction).getAUType() ==
AggregateUnaryCPInstruction.AUType.DEFAULT)
- fedinst =
AggregateUnaryFEDInstruction.parseInstruction((AggregateUnaryCPInstruction)
inst);
- else if(inst instanceof
UnaryMatrixCPInstruction) {
-
if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()) &&
-
!(inst.getOpcode().equalsIgnoreCase("ucumk+*") && mo1.isFederated(FType.COL)))
- fedinst =
UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixCPInstruction) inst);
- }
- }
- }
- }
- else if (inst instanceof BinaryCPInstruction) {
- BinaryCPInstruction instruction =
(BinaryCPInstruction) inst;
- if((instruction.input1.isMatrix() &&
-
ec.getMatrixObject(instruction.input1).isFederatedExcept(FType.BROADCAST)) ||
- (instruction.input2 != null &&
instruction.input2.isMatrix() &&
-
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
- if(instruction instanceof
AppendCPInstruction)
- fedinst =
AppendFEDInstruction.parseInstruction((AppendCPInstruction) inst);
- else if(instruction instanceof
QuantilePickCPInstruction)
- fedinst =
QuantilePickFEDInstruction.parseInstruction((QuantilePickCPInstruction) inst);
- else if(instruction instanceof
CovarianceCPInstruction &&
-
(ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
-
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst =
CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction) inst);
- else if(instruction instanceof
BinaryMatrixMatrixCPInstruction)
- fedinst =
BinaryMatrixMatrixFEDInstruction
-
.parseInstruction((BinaryMatrixMatrixCPInstruction) inst);
- else if(instruction instanceof
BinaryMatrixScalarCPInstruction)
- fedinst =
BinaryMatrixScalarFEDInstruction
-
.parseInstruction((BinaryMatrixScalarCPInstruction) inst);
- }
- }
- else if( inst instanceof
ParameterizedBuiltinCPInstruction ) {
- ParameterizedBuiltinCPInstruction pinst =
(ParameterizedBuiltinCPInstruction) inst;
- if( ArrayUtils.contains(PARAM_BUILTINS,
pinst.getOpcode()) && pinst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
- fedinst =
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst);
- }
- else if (inst instanceof
MultiReturnParameterizedBuiltinCPInstruction) {
- MultiReturnParameterizedBuiltinCPInstruction
minst = (MultiReturnParameterizedBuiltinCPInstruction) inst;
- if(minst.getOpcode().equals("transformencode")
&& minst.input1.isFrame()) {
- CacheableData<?> fo =
ec.getCacheableData(minst.input1);
-
if(fo.isFederatedExcept(FType.BROADCAST)) {
- fedinst =
MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(minst);
- }
- }
- }
- else if(inst instanceof IndexingCPInstruction) {
- // matrix and frame indexing
- IndexingCPInstruction minst =
(IndexingCPInstruction) inst;
- if((minst.input1.isMatrix() ||
minst.input1.isFrame())
- &&
ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
IndexingFEDInstruction.parseInstruction(minst);
- }
- }
- else if(inst instanceof TernaryCPInstruction) {
- TernaryCPInstruction tinst =
(TernaryCPInstruction) inst;
- if(inst.getOpcode().equals("_map") && inst
instanceof TernaryFrameScalarCPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
- && tinst.input1.isFrame() &&
ec.getFrameObject(tinst.input1).isFederated()) {
- long margin =
ec.getScalarInput(tinst.input3).getLongValue();
- FrameObject fo =
ec.getFrameObject(tinst.input1);
- if(margin == 0 ||
(fo.isFederated(FType.ROW) && margin == 1) || (fo.isFederated(FType.COL) &&
margin == 2))
- fedinst =
TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarCPInstruction)
inst);
- }
- else if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
- || (tinst.input2.isMatrix() &&
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
- || (tinst.input3.isMatrix() &&
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
- fedinst =
TernaryFEDInstruction.parseInstruction(tinst);
- }
- }
- else if(inst instanceof VariableCPInstruction ){
- VariableCPInstruction ins =
(VariableCPInstruction) inst;
- if(ins.getVariableOpcode() ==
VariableOperationCode.Write
- && ins.getInput1().isMatrix()
- &&
ins.getInput3().getName().contains("federated")){
- fedinst =
VariableFEDInstruction.parseInstruction(ins);
- }
- else if(ins.getVariableOpcode() ==
VariableOperationCode.CastAsFrameVariable
- && ins.getInput1().isMatrix()
- &&
ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){
- fedinst =
VariableFEDInstruction.parseInstruction(ins);
- }
- else if(ins.getVariableOpcode() ==
VariableOperationCode.CastAsMatrixVariable
- && ins.getInput1().isFrame()
- &&
ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){
- fedinst =
VariableFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof AggregateTernaryCPInstruction){
- AggregateTernaryCPInstruction ins =
(AggregateTernaryCPInstruction) inst;
- if(ins.input1.isMatrix() &&
ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)
- && ins.input2.isMatrix() &&
ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
AggregateTernaryFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof QuaternaryCPInstruction) {
- QuaternaryCPInstruction instruction =
(QuaternaryCPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst =
QuaternaryFEDInstruction.parseInstruction(instruction);
- }
- else if(inst instanceof SpoofCPInstruction) {
- SpoofCPInstruction ins = (SpoofCPInstruction)
inst;
- Class<?> scla =
ins.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
- && SpoofFEDInstruction.isFederated(ec,
ins.getInputs(), scla))
- || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
- fedinst =
SpoofFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof CtableCPInstruction) {
- CtableCPInstruction cinst =
(CtableCPInstruction) inst;
- if((inst.getOpcode().equalsIgnoreCase("ctable")
|| inst.getOpcode().equalsIgnoreCase("ctableexpand"))
- && (
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
- || (cinst.input2.isMatrix() &&
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
- || (cinst.input3.isMatrix() &&
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
- fedinst =
CtableFEDInstruction.parseInstruction(cinst);
- }
+ if(inst instanceof AggregateBinaryCPInstruction)
+ fedinst =
AggregateBinaryFEDInstruction.parseInstruction((AggregateBinaryCPInstruction)
inst, ec);
+ else if(inst instanceof MMChainCPInstruction)
+ fedinst =
MMChainFEDInstruction.parseInstruction((MMChainCPInstruction) inst, ec);
+ else if(inst instanceof MMTSJCPInstruction)
+ fedinst =
TsmmFEDInstruction.parseInstruction((MMTSJCPInstruction) inst, ec);
+ else if(inst instanceof UnaryCPInstruction)
+ fedinst =
UnaryFEDInstruction.parseInstruction((UnaryCPInstruction) inst, ec);
+ else if(inst instanceof BinaryCPInstruction)
+ fedinst =
BinaryFEDInstruction.parseInstruction((BinaryCPInstruction) inst, ec);
+ else if(inst instanceof ParameterizedBuiltinCPInstruction)
+ fedinst =
ParameterizedBuiltinFEDInstruction.parseInstruction((ParameterizedBuiltinCPInstruction)
inst, ec);
+ else if(inst instanceof
MultiReturnParameterizedBuiltinCPInstruction)
+ fedinst = MultiReturnParameterizedBuiltinFEDInstruction
+
.parseInstruction((MultiReturnParameterizedBuiltinCPInstruction) inst, ec);
+ else if(inst instanceof TernaryCPInstruction)
+ fedinst =
TernaryFEDInstruction.parseInstruction((TernaryCPInstruction) inst, ec);
+ else if(inst instanceof VariableCPInstruction)
+ fedinst =
VariableFEDInstruction.parseInstruction((VariableCPInstruction) inst, ec);
+ else if(inst instanceof AggregateTernaryCPInstruction)
+ fedinst =
AggregateTernaryFEDInstruction.parseInstruction((AggregateTernaryCPInstruction)
inst, ec);
+ else if(inst instanceof QuaternaryCPInstruction)
+ fedinst =
QuaternaryFEDInstruction.parseInstruction((QuaternaryCPInstruction) inst, ec);
+ else if(inst instanceof SpoofCPInstruction)
+ fedinst =
SpoofFEDInstruction.parseInstruction((SpoofCPInstruction) inst, ec);
+ else if(inst instanceof CtableCPInstruction)
+ fedinst =
CtableFEDInstruction.parseInstruction((CtableCPInstruction) inst, ec);
- //set thread id for federated context management
- if( fedinst != null ) {
- fedinst.setTID(ec.getTID());
- return fedinst;
- }
+ // set thread id for federated context management
+ if(fedinst != null) {
+ fedinst.setTID(ec.getTID());
+ return fedinst;
}
-
+
return inst;
+
}
public static Instruction checkAndReplaceSP(Instruction inst,
ExecutionContext ec) {
+ if(noFedRuntimeConversion)
+ return inst;
FEDInstruction fedinst = null;
- if(inst instanceof CastSPInstruction){
- CastSPInstruction ins = (CastSPInstruction) inst;
-
if((ins.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString())
- ||
ins.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_MATRIX.toString()))
- && ins.input1.isMatrix() &&
ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)){
- fedinst =
CastFEDInstruction.parseInstruction(ins);
- }
- }
- else if (inst instanceof WriteSPInstruction) {
+ if(inst instanceof CastSPInstruction)
+ fedinst =
CastFEDInstruction.parseInstruction((CastSPInstruction) inst, ec);
+ else if(inst instanceof WriteSPInstruction) {
WriteSPInstruction instruction = (WriteSPInstruction)
inst;
Data data = ec.getVariable(instruction.input1);
- if (data instanceof CacheableData &&
((CacheableData<?>) data).isFederated()) {
+ if(data instanceof CacheableData && ((CacheableData<?>)
data).isFederated()) {
// Write spark instruction can not be executed
for federated matrix objects (tries to get rdds which do
// not exist), therefore we replace the
instruction with the VariableCPInstruction.
return
VariableCPInstruction.parseInstruction(instruction.getInstructionString());
}
}
- else if(inst instanceof QuaternarySPInstruction) {
- QuaternarySPInstruction instruction =
(QuaternarySPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
- fedinst =
QuaternaryFEDInstruction.parseInstruction(instruction);
- }
- else if(inst instanceof SpoofSPInstruction) {
- SpoofSPInstruction ins = (SpoofSPInstruction) inst;
- Class<?> scla = ins.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
- && SpoofFEDInstruction.isFederated(ec,
ins.getInputs(), scla))
- || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
- fedinst =
SpoofFEDInstruction.parseInstruction(ins);
- }
- }
- else if (inst instanceof UnarySPInstruction && ! (inst
instanceof IndexingSPInstruction)) {
- UnarySPInstruction instruction = (UnarySPInstruction)
inst;
- if (inst instanceof CentralMomentSPInstruction) {
- CentralMomentSPInstruction cinstruction =
(CentralMomentSPInstruction) inst;
- Data data = ec.getVariable(cinstruction.input1);
- if (data instanceof MatrixObject &&
((MatrixObject) data).isFederated() && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
- fedinst =
CentralMomentFEDInstruction.parseInstruction(cinstruction);
- } else if (inst instanceof QuantileSortSPInstruction) {
- QuantileSortSPInstruction qinstruction =
(QuantileSortSPInstruction) inst;
- Data data = ec.getVariable(qinstruction.input1);
- if (data instanceof MatrixObject &&
((MatrixObject) data).isFederated() && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
- fedinst =
QuantileSortFEDInstruction.parseInstruction(qinstruction);
- }
- else if (inst instanceof AggregateUnarySPInstruction) {
- AggregateUnarySPInstruction auinstruction =
(AggregateUnarySPInstruction) inst;
- Data data =
ec.getVariable(auinstruction.input1);
- if(data instanceof MatrixObject &&
((MatrixObject) data).isFederated() && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
- if(ArrayUtils.contains(new
String[]{"uarimin", "uarimax"}, auinstruction.getOpcode())) {
- if(((MatrixObject)
data).getFedMapping().getType() == FType.ROW)
- fedinst =
AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
- }
- else
- fedinst =
AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
- }
- else if(inst instanceof ReorgSPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
- || inst.getOpcode().equals("rev"))) {
- ReorgSPInstruction rinst = (ReorgSPInstruction)
inst;
- CacheableData<?> mo =
ec.getCacheableData(rinst.input1);
- if((mo instanceof MatrixObject || mo instanceof
FrameObject) && mo.isFederated() && mo.isFederatedExcept(FType.BROADCAST))
- fedinst =
ReorgFEDInstruction.parseInstruction(rinst);
- }
- else if(inst instanceof ReblockSPInstruction &&
instruction.input1 != null && (instruction.input1.isFrame() ||
instruction.input1.isMatrix())) {
- ReblockSPInstruction rinst =
(ReblockSPInstruction) instruction;
- CacheableData<?> data =
ec.getCacheableData(rinst.input1);
- if(data.isFederatedExcept(FType.BROADCAST))
- fedinst =
ReblockFEDInstruction.parseInstruction((ReblockSPInstruction) inst);
- }
- else if(instruction.input1 != null &&
instruction.input1.isMatrix() && ec.containsVariable(instruction.input1)) {
- MatrixObject mo1 =
ec.getMatrixObject(instruction.input1);
- if(mo1.isFederatedExcept(FType.BROADCAST)) {
-
if(instruction.getOpcode().equalsIgnoreCase("cm"))
- fedinst =
CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction)inst);
- else
if(inst.getOpcode().equalsIgnoreCase("qsort")) {
-
if(mo1.getFedMapping().getFederatedRanges().length == 1)
- fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
- }
- else
if(inst.getOpcode().equalsIgnoreCase("rshape")) {
- fedinst =
ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
- }
- else if(inst instanceof
UnaryMatrixSPInstruction) {
-
if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
- fedinst =
UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixSPInstruction) inst);
- }
- }
- }
- }
- else if (inst instanceof BinarySPInstruction) {
- BinarySPInstruction instruction = (BinarySPInstruction)
inst;
- if (inst instanceof MapmmSPInstruction || inst
instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
MMFEDInstruction.parseInstruction((AggregateBinarySPInstruction) instruction);
- }
- }
- else
- if(inst instanceof QuantilePickSPInstruction) {
- QuantilePickSPInstruction qinstruction =
(QuantilePickSPInstruction) inst;
- Data data = ec.getVariable(qinstruction.input1);
- if(data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst =
QuantilePickFEDInstruction.parseInstruction(qinstruction);
- }
- else if (inst instanceof AppendGAlignedSPInstruction ||
inst instanceof AppendGSPInstruction
- || inst instanceof AppendMSPInstruction || inst
instanceof AppendRSPInstruction) {
- BinarySPInstruction ainstruction =
(BinarySPInstruction) inst;
- Data data1 =
ec.getVariable(ainstruction.input1);
- Data data2 =
ec.getVariable(ainstruction.input2);
- if ((data1 instanceof MatrixObject &&
((MatrixObject) data1).isFederatedExcept(FType.BROADCAST))
- || (data2 instanceof MatrixObject &&
((MatrixObject) data2).isFederatedExcept(FType.BROADCAST))) {
- fedinst =
AppendFEDInstruction.parseInstruction((AppendSPInstruction) instruction);
- }
- }
- else if (inst instanceof
BinaryMatrixScalarSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject &&
((MatrixObject)data).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarSPInstruction)
inst);
- }
- }
- else if (inst instanceof
BinaryMatrixMatrixSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject &&
((MatrixObject)data).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixSPInstruction)
inst);
- }
- }
- else if( (instruction.input1.isMatrix() &&
ec.getCacheableData(instruction.input1).isFederatedExcept(FType.BROADCAST))
- || (instruction.input2.isMatrix() &&
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
- if(inst instanceof CovarianceSPInstruction &&
(ec.getMatrixObject(instruction.input1)
- .isFederated(FType.ROW) ||
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst =
CovarianceFEDInstruction.parseInstruction((CovarianceSPInstruction) inst);
- else if(inst instanceof
CumulativeOffsetSPInstruction) {
- fedinst =
CumulativeOffsetFEDInstruction.parseInstruction((CumulativeOffsetSPInstruction)
inst);
- }
- else
- fedinst =
BinaryFEDInstruction.parseInstruction(InstructionUtils
-
.concatOperands(inst.getInstructionString(),
FEDInstruction.FederatedOutput.NONE.name()));
- }
- }
- else if( inst instanceof ParameterizedBuiltinSPInstruction) {
- ParameterizedBuiltinSPInstruction pinst =
(ParameterizedBuiltinSPInstruction) inst;
- if( pinst.getOpcode().equalsIgnoreCase("replace") &&
pinst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
- fedinst =
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst);
- }
- else if (inst instanceof
MultiReturnParameterizedBuiltinSPInstruction) {
- MultiReturnParameterizedBuiltinSPInstruction minst =
(MultiReturnParameterizedBuiltinSPInstruction) inst;
- if(minst.getOpcode().equals("transformencode") &&
minst.input1.isFrame()) {
- CacheableData<?> fo =
ec.getCacheableData(minst.input1);
- if(fo.isFederatedExcept(FType.BROADCAST)) {
- fedinst =
MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(minst);
- }
- }
- }
- else if(inst instanceof IndexingSPInstruction) {
- // matrix and frame indexing
- IndexingSPInstruction minst = (IndexingSPInstruction)
inst;
- if((minst.input1.isMatrix() || minst.input1.isFrame())
- &&
ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
IndexingFEDInstruction.parseInstruction(minst);
- }
- }
- else if(inst instanceof TernarySPInstruction) {
- TernarySPInstruction tinst = (TernarySPInstruction)
inst;
- if(inst.getOpcode().equals("_map") && inst instanceof
TernaryFrameScalarSPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
- && tinst.input1.isFrame() &&
ec.getFrameObject(tinst.input1).isFederated()) {
- long margin =
ec.getScalarInput(tinst.input3).getLongValue();
- FrameObject fo =
ec.getFrameObject(tinst.input1);
- if(margin == 0 || (fo.isFederated(FType.ROW) &&
margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
- fedinst =
TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarSPInstruction)
tinst);
- } else if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
- || (tinst.input2.isMatrix() &&
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
- || (tinst.input3.isMatrix() &&
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
- fedinst =
TernaryFEDInstruction.parseInstruction(tinst);
- }
- }
- else if(inst instanceof AggregateTernarySPInstruction){
- AggregateTernarySPInstruction ins =
(AggregateTernarySPInstruction) inst;
- if(ins.input1.isMatrix() &&
ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST) &&
ins.input2.isMatrix() &&
-
ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
AggregateTernaryFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof CtableSPInstruction) {
- CtableSPInstruction cinst = (CtableSPInstruction) inst;
- if(inst.getOpcode().equalsIgnoreCase("ctable")
- && (
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
- || (cinst.input2.isMatrix() &&
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
- || (cinst.input3.isMatrix() &&
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
- fedinst =
CtableFEDInstruction.parseInstruction(cinst);
- }
+ else if(inst instanceof QuaternarySPInstruction)
+ fedinst =
QuaternaryFEDInstruction.parseInstruction((QuaternarySPInstruction) inst, ec);
+ else if(inst instanceof SpoofSPInstruction)
+ fedinst =
SpoofFEDInstruction.parseInstruction((SpoofSPInstruction) inst, ec);
+ else if(inst instanceof UnarySPInstruction)
+ fedinst =
UnaryFEDInstruction.parseInstruction((UnarySPInstruction) inst, ec);
+ else if(inst instanceof BinarySPInstruction)
+ fedinst =
BinaryFEDInstruction.parseInstruction((BinarySPInstruction) inst, ec);
+ else if(inst instanceof ParameterizedBuiltinSPInstruction)
+ fedinst =
ParameterizedBuiltinFEDInstruction.parseInstruction((ParameterizedBuiltinSPInstruction)
inst, ec);
+ else if(inst instanceof
MultiReturnParameterizedBuiltinSPInstruction)
+ fedinst = MultiReturnParameterizedBuiltinFEDInstruction
+
.parseInstruction((MultiReturnParameterizedBuiltinSPInstruction) inst, ec);
+ else if(inst instanceof TernarySPInstruction)
+ fedinst =
TernaryFEDInstruction.parseInstruction((TernarySPInstruction) inst, ec);
+ else if(inst instanceof AggregateTernarySPInstruction)
+ fedinst =
AggregateTernaryFEDInstruction.parseInstruction((AggregateTernarySPInstruction)
inst, ec);
+ else if(inst instanceof CtableSPInstruction)
+ fedinst =
CtableFEDInstruction.parseInstruction((CtableSPInstruction) inst, ec);
- //set thread id for federated context management
- if( fedinst != null ) {
+ // set thread id for federated context management
+ if(fedinst != null) {
fedinst.setTID(ec.getTID());
return fedinst;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 5ddc46d899..cf5af3d9c8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -19,23 +19,24 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.concurrent.Future;
+
import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import java.util.concurrent.Future;
-
public class MMChainFEDInstruction extends UnaryFEDInstruction {
public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand
in3,
@@ -50,7 +51,15 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
return _type;
}
- public static MMChainFEDInstruction
parseInstruction(MMChainCPInstruction instr) {
+ public static MMChainFEDInstruction
parseInstruction(MMChainCPInstruction inst, ExecutionContext ec) {
+ MMChainCPInstruction linst = (MMChainCPInstruction) inst;
+ MatrixObject mo = ec.getMatrixObject(linst.input1);
+ if( mo.isFederated(FType.ROW) )
+ return MMChainFEDInstruction.parseInstruction(linst);
+ return null;
+ }
+
+ private static MMChainFEDInstruction
parseInstruction(MMChainCPInstruction instr) {
return new MMChainFEDInstruction(instr.input1, instr.input2,
instr.input3, instr.output, instr.getMMChainType(),
instr.getNumThreads(), instr.getOpcode(),
instr.getInstructionString());
}
@@ -62,7 +71,7 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
-
+
if( parts.length==6 ) {
CPOperand out= new CPOperand(parts[3]);
ChainType type = ChainType.valueOf(parts[4]);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 16471a1497..c9135eb013 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -35,8 +35,10 @@ import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -79,15 +81,35 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
}
public static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(
+ MultiReturnParameterizedBuiltinCPInstruction inst,
ExecutionContext ec) {
+ if(inst.getOpcode().equals("transformencode") &&
inst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(inst.input1);
+ if(fo.isFederatedExcept(FType.BROADCAST))
+ return
MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(
+ MultiReturnParameterizedBuiltinSPInstruction inst,
ExecutionContext ec) {
+ if(inst.getOpcode().equals("transformencode") &&
inst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(inst.input1);
+ if(fo.isFederatedExcept(FType.BROADCAST))
+ return
MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(
MultiReturnParameterizedBuiltinCPInstruction instr) {
return new
MultiReturnParameterizedBuiltinFEDInstruction(instr.getOperator(),
instr.input1, instr.input2,
instr.getOutputs(), instr.getOpcode(),
instr.getInstructionString());
}
- public static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(
- MultiReturnParameterizedBuiltinSPInstruction instr) {
+ private static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(
+ MultiReturnParameterizedBuiltinSPInstruction instr) {
return new
MultiReturnParameterizedBuiltinFEDInstruction(instr.getOperator(),
instr.input1, instr.input2,
- instr.getOutputs(), instr.getOpcode(),
instr.getInstructionString());
+ instr.getOutputs(), instr.getOpcode(),
instr.getInstructionString());
}
public static MultiReturnParameterizedBuiltinFEDInstruction
parseInstruction(String str) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c0b60c557a..b9794b413e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -84,12 +84,40 @@ import org.apache.sysds.runtime.util.UtilFunctions;
public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstruction {
protected final HashMap<String, String> params;
+ private static final String[] PARAM_BUILTINS = new String[]{
+ "replace", "rmempty", "lowertri", "uppertri",
"transformdecode", "transformapply", "tokenize"};
+
+
protected ParameterizedBuiltinFEDInstruction(Operator op,
HashMap<String, String> paramsMap, CPOperand out,
String opcode, String istr) {
super(FEDType.ParameterizedBuiltin, op, null, null, out,
opcode, istr);
params = paramsMap;
}
+ public static ParameterizedBuiltinFEDInstruction
parseInstruction(String str) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ // first part is always the opcode
+ String opcode = parts[0];
+ // last part is always the output
+ CPOperand out = new CPOperand(parts[parts.length - 1]);
+
+ // process remaining parts and build a hash map
+ LinkedHashMap<String, String> paramsMap =
constructParameterMap(parts);
+
+ // determine the appropriate value function
+ if(opcode.equalsIgnoreCase("replace") ||
opcode.equalsIgnoreCase("rmempty") ||
+ opcode.equalsIgnoreCase("lowertri") ||
opcode.equalsIgnoreCase("uppertri")) {
+ ValueFunction func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinFEDInstruction(new
SimpleOperator(func), paramsMap, out, opcode, str);
+ }
+ else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode") || opcode.equals("tokenize")) {
+ return new ParameterizedBuiltinFEDInstruction(null,
paramsMap, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported opcode (" +
opcode + ") for ParameterizedBuiltinFEDInstruction.");
+ }
+ }
+
public HashMap<String, String> getParameterMap() {
return params;
}
@@ -112,39 +140,28 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
return paramMap;
}
- public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinCPInstruction instr) {
- return new
ParameterizedBuiltinFEDInstruction(instr.getOperator(),
instr.getParameterMap(), instr.output,
- instr.getOpcode(), instr.getInstructionString());
+ public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinCPInstruction inst,
+ ExecutionContext ec) {
+ if(ArrayUtils.contains(PARAM_BUILTINS, inst.getOpcode()) &&
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST))
+ return
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinSPInstruction inst,
+ ExecutionContext ec) {
+ if( inst.getOpcode().equalsIgnoreCase("replace") &&
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
+ return
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ return null;
}
- public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinSPInstruction instr) {
+ private static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinCPInstruction instr) {
return new
ParameterizedBuiltinFEDInstruction(instr.getOperator(),
instr.getParameterMap(), instr.output,
instr.getOpcode(), instr.getInstructionString());
}
- public static ParameterizedBuiltinFEDInstruction
parseInstruction(String str) {
- String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- // first part is always the opcode
- String opcode = parts[0];
- // last part is always the output
- CPOperand out = new CPOperand(parts[parts.length - 1]);
-
- // process remaining parts and build a hash map
- LinkedHashMap<String, String> paramsMap =
constructParameterMap(parts);
-
- // determine the appropriate value function
- if(opcode.equalsIgnoreCase("replace") ||
opcode.equalsIgnoreCase("rmempty") ||
- opcode.equalsIgnoreCase("lowertri") ||
opcode.equalsIgnoreCase("uppertri")) {
- ValueFunction func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- return new ParameterizedBuiltinFEDInstruction(new
SimpleOperator(func), paramsMap, out, opcode, str);
- }
- else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode") || opcode.equals("tokenize")) {
- return new ParameterizedBuiltinFEDInstruction(null,
paramsMap, out, opcode, str);
- }
- else {
- throw new DMLRuntimeException(
- "Unsupported opcode (" + opcode + ") for
ParameterizedBuiltinFEDInstruction.");
- }
+ private static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinSPInstruction instr) {
+ return new
ParameterizedBuiltinFEDInstruction(instr.getOperator(),
instr.getParameterMap(), instr.output,
+ instr.getOpcode(), instr.getInstructionString());
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index f817c4c2a6..871128a83f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -127,6 +127,7 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction {
inst._fedOut = fedOut;
return inst;
}
+
@Override
public void processInstruction(ExecutionContext ec) {
if(ec.getMatrixObject(input1).isFederated(FType.COL) ||
ec.getMatrixObject(input1).isFederated(FType.FULL))
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
index 9b5014e6d7..ee89c72485 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -21,17 +21,18 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedDivMM;
-import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysds.lops.WeightedSquaredLoss;
-import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
+import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -39,6 +40,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -58,7 +60,21 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
_input4 = in4;
}
- public static QuaternaryFEDInstruction
parseInstruction(QuaternaryCPInstruction instr) {
+ public static QuaternaryFEDInstruction
parseInstruction(QuaternaryCPInstruction inst, ExecutionContext ec) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST))
+ return QuaternaryFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ public static QuaternaryFEDInstruction
parseInstruction(QuaternarySPInstruction inst, ExecutionContext ec) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
+ return QuaternaryFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ private static QuaternaryFEDInstruction
parseInstruction(QuaternaryCPInstruction instr) {
QuaternaryOperator qop = (QuaternaryOperator)
instr.getOperator();
if(qop.wtype1 != null)
return
QuaternaryWSLossFEDInstruction.parseInstruction(instr);
@@ -74,7 +90,7 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
return null;
}
- public static QuaternaryFEDInstruction
parseInstruction(QuaternarySPInstruction instr) {
+ private static QuaternaryFEDInstruction
parseInstruction(QuaternarySPInstruction instr) {
QuaternaryOperator qop = (QuaternaryOperator)
instr.getOperator();
if(qop.wtype1 != null)
return
QuaternaryWSLossFEDInstruction.parseInstruction(instr);
@@ -99,7 +115,8 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
- int addInput4 = (opcode.equals(WeightedCrossEntropy.OPCODE_CP)
|| opcode.equals(WeightedSquaredLoss.OPCODE_CP) ||
opcode.equals(WeightedDivMM.OPCODE_CP)) ? 1 : 0;
+ int addInput4 = (opcode.equals(WeightedCrossEntropy.OPCODE_CP)
|| opcode.equals(WeightedSquaredLoss.OPCODE_CP) ||
+ opcode.equals(WeightedDivMM.OPCODE_CP)) ? 1 : 0;
int addUOpcode = (opcode.equals(WeightedUnaryMM.OPCODE_CP) ? 1
: 0);
InstructionUtils.checkNumFields(parts, 6 + addInput4 +
addUOpcode);
@@ -124,11 +141,10 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
Double.parseDouble(in4.getName())) :
new QuaternaryOperator(wcemm_type));
return new QuaternaryWCeMMFEDInstruction(qop,
in1, in2, in3, in4, out, opcode, str);
}
- else if(opcode.equals(WeightedDivMM.OPCODE_CP))
- {
+ else if(opcode.equals(WeightedDivMM.OPCODE_CP)) {
final WDivMMType wdivmm_type =
WDivMMType.valueOf(parts[6]);
if(wdivmm_type.hasFourInputs())
- checkDataTypes(new
DataType[]{DataType.SCALAR, DataType.MATRIX}, in4);
+ checkDataTypes(new DataType[]
{DataType.SCALAR, DataType.MATRIX}, in4);
qop = new QuaternaryOperator(wdivmm_type);
return new QuaternaryWDivMMFEDInstruction(qop,
in1, in2, in3, in4, out, opcode, str);
}
@@ -145,8 +161,7 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
qop = new QuaternaryOperator(wsigmoid_type);
return new QuaternaryWSigmoidFEDInstruction(qop, in1,
in2, in3, out, opcode, str);
}
- else if(opcode.equals(WeightedUnaryMM.OPCODE_CP))
- {
+ else if(opcode.equals(WeightedUnaryMM.OPCODE_CP)) {
final WUMMType wumm_type = WUMMType.valueOf(parts[6]);
String uopcode = parts[1];
qop = new QuaternaryOperator(wumm_type, uopcode);
@@ -179,7 +194,7 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
protected static String rewriteSparkInstructionToCP(String inst_str) {
// TODO: don't perform replacement over the whole instruction
string, possibly changing string literals,
- // instead only at positions of ExecType and Opcode
+ // instead only at positions of ExecType and Opcode
// rewrite the spark instruction to a cp instruction
inst_str = inst_str.replace(ExecType.SPARK.name(),
ExecType.CP.name());
if(inst_str.contains(WeightedCrossEntropy.OPCODE))
@@ -203,11 +218,11 @@ public abstract class QuaternaryFEDInstruction extends
ComputationFEDInstruction
return inst_str;
}
-
+
protected void setOutputDataCharacteristics(MatrixObject X,
MatrixObject U, MatrixObject V, ExecutionContext ec) {
long rows = X.getNumRows() > 1 ? X.getNumRows() :
U.getNumRows();
- long cols = X.getNumColumns() > 1 ? X.getNumColumns()
- : (U.getNumColumns() == V.getNumRows() ?
V.getNumColumns() : V.getNumRows());
+ long cols = X.getNumColumns() > 1 ? X
+ .getNumColumns() : (U.getNumColumns() == V.getNumRows()
? V.getNumColumns() : V.getNumRows());
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(rows, cols, (int)
X.getBlocksize());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index bf3632f1dd..0b173d7fe8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -76,37 +76,40 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
rinst.getInstructionString(), FederatedOutput.NONE);
}
- public static ReorgFEDInstruction parseInstruction ( String str ) {
+ public static ReorgFEDInstruction parseInstruction(String str) {
CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
FederatedOutput fedOut;
- if ( opcode.equalsIgnoreCase("r'") ) {
+ if(opcode.equalsIgnoreCase("r'")) {
InstructionUtils.checkNumFields(str, 2, 3, 4);
in.split(parts[1]);
out.split(parts[2]);
int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0
: Integer.parseInt(parts[3]);
- fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
- FederatedOutput.valueOf(parts[3]) :
FederatedOutput.valueOf(parts[4]);
- return new ReorgFEDInstruction(new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str,
fedOut);
+ fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
FederatedOutput.valueOf(parts[3]) : FederatedOutput
+ .valueOf(parts[4]);
+ return new ReorgFEDInstruction(new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str,
+ fedOut);
}
- else if ( opcode.equalsIgnoreCase("rdiag") ) {
- parseUnaryInstruction(str, in, out); //max 2 operands
+ else if(opcode.equalsIgnoreCase("rdiag")) {
+ parseUnaryInstruction(str, in, out); // max 2 operands
fedOut = parseFedOutFlag(str, 3);
- return new ReorgFEDInstruction(new
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str, fedOut);
+ return new ReorgFEDInstruction(new
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str,
+ fedOut);
}
- else if ( opcode.equalsIgnoreCase("rev") ) {
- parseUnaryInstruction(str, in, out); //max 2 operands
+ else if(opcode.equalsIgnoreCase("rev")) {
+ parseUnaryInstruction(str, in, out); // max 2 operands
fedOut = parseFedOutFlag(str, 3);
- return new ReorgFEDInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut);
+ return new ReorgFEDInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
+ fedOut);
}
else {
- throw new DMLRuntimeException("ReorgFEDInstruction:
unsupported opcode: "+opcode);
+ throw new DMLRuntimeException("ReorgFEDInstruction:
unsupported opcode: " + opcode);
}
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index e5af25ef02..14b16111e8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -69,12 +69,32 @@ public class SpoofFEDInstruction extends FEDInstruction
_output = out;
}
- public static SpoofFEDInstruction parseInstruction(SpoofCPInstruction
instr) {
+ public static SpoofFEDInstruction parseInstruction(SpoofCPInstruction
inst, ExecutionContext ec){
+ Class<?> scla = inst.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec,
inst.getInputs(), scla))
+ || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, inst.getInputs(), scla))) {
+ return SpoofFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static SpoofFEDInstruction parseInstruction(SpoofSPInstruction
inst, ExecutionContext ec){
+ Class<?> scla = inst.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec,
inst.getInputs(), scla))
+ || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, inst.getInputs(), scla))) {
+ return SpoofFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static SpoofFEDInstruction parseInstruction(SpoofCPInstruction
instr) {
return new SpoofFEDInstruction(instr.getSpoofOperator(),
instr.getInputs(), instr.getOutput(),
instr.getOpcode(), instr.getInstructionString());
}
- public static SpoofFEDInstruction parseInstruction(SpoofSPInstruction
instr) {
+ private static SpoofFEDInstruction parseInstruction(SpoofSPInstruction
instr) {
SpoofOperator op =
CodegenUtils.createInstance(instr.getOperatorClass());
return new SpoofFEDInstruction(op, instr.getInputs(),
instr.getOutput(), instr.getOpcode(),
instr.getInstructionString());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 342faf3296..0883e6fe02 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -34,6 +35,8 @@ import
org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import
org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -46,12 +49,46 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out,
opcode, str, fedOut);
}
- public static TernaryFEDInstruction
parseInstruction(TernaryCPInstruction instr) {
+ public static TernaryFEDInstruction
parseInstruction(TernaryCPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("_map") && inst instanceof
TernaryFrameScalarCPInstruction &&
+ !inst.getInstructionString().contains("UtilFunctions")
&& inst.input1.isFrame() &&
+ ec.getFrameObject(inst.input1).isFederated()) {
+ long margin =
ec.getScalarInput(inst.input3).getLongValue();
+ FrameObject fo = ec.getFrameObject(inst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) && margin
== 1) || (fo.isFederated(FType.COL) && margin == 2))
+ return
TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarCPInstruction)
inst);
+ }
+ else if((inst.input1.isMatrix() &&
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input3.isMatrix() &&
ec.getCacheableData(inst.input3).isFederatedExcept(FType.BROADCAST))) {
+ return TernaryFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static TernaryFEDInstruction
parseInstruction(TernarySPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("_map") && inst instanceof
TernaryFrameScalarSPInstruction &&
+ !inst.getInstructionString().contains("UtilFunctions")
&& inst.input1.isFrame() &&
+ ec.getFrameObject(inst.input1).isFederated()) {
+ long margin =
ec.getScalarInput(inst.input3).getLongValue();
+ FrameObject fo = ec.getFrameObject(inst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) && margin
== 1) || (fo.isFederated(FType.COL) && margin == 2))
+ return
TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarSPInstruction)
inst);
+ }
+ else if((inst.input1.isMatrix() &&
ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() &&
ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input3.isMatrix() &&
ec.getCacheableData(inst.input3).isFederatedExcept(FType.BROADCAST))) {
+ return TernaryFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static TernaryFEDInstruction
parseInstruction(TernaryCPInstruction instr) {
return new TernaryFEDInstruction((TernaryOperator)
instr.getOperator(), instr.input1, instr.input2,
instr.input3, instr.output, instr.getOpcode(),
instr.getInstructionString(), FederatedOutput.NONE);
}
- public static TernaryFEDInstruction
parseInstruction(TernarySPInstruction instr) {
+ private static TernaryFEDInstruction
parseInstruction(TernarySPInstruction instr) {
return new TernaryFEDInstruction((TernaryOperator)
instr.getOperator(), instr.input1, instr.input2,
instr.input3, instr.output, instr.getOpcode(),
instr.getInstructionString(), FederatedOutput.NONE);
}
@@ -63,11 +100,13 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
CPOperand operand2 = new CPOperand(parts[2]);
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
- int numThreads = parts.length>5 & !opcode.contains("map") ?
Integer.parseInt(parts[5]) : 1;
- FederatedOutput fedOut = parts.length>=7 &&
!opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) :
FederatedOutput.NONE;
+ int numThreads = parts.length > 5 & !opcode.contains("map") ?
Integer.parseInt(parts[5]) : 1;
+ FederatedOutput fedOut = parts.length >= 7 &&
!opcode.contains("map") ? FederatedOutput
+ .valueOf(parts[6]) : FederatedOutput.NONE;
TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode, numThreads);
- if( operand1.isFrame() && operand2.isScalar() ||
operand2.isFrame() && operand1.isScalar() )
- return new TernaryFrameScalarFEDInstruction(op,
operand1, operand2, operand3, outOperand, opcode,
InstructionUtils.removeFEDOutputFlag(str), fedOut);
+ if(operand1.isFrame() && operand2.isScalar() ||
operand2.isFrame() && operand1.isScalar())
+ return new TernaryFrameScalarFEDInstruction(op,
operand1, operand2, operand3, outOperand, opcode,
+ InstructionUtils.removeFEDOutputFlag(str),
fedOut);
return new TernaryFEDInstruction(op, operand1, operand2,
operand3, outOperand, opcode, str, fedOut);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 3b15b273db..3d34338049 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -51,7 +51,15 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction
{
this(in, out, type, k, opcode, istr, FederatedOutput.NONE);
}
- public static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction
instr) {
+ public static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction
inst, ExecutionContext ec) {
+ MatrixObject mo = ec.getMatrixObject(inst.input1);
+ if( (mo.isFederated(FType.ROW) &&
mo.isFederatedExcept(FType.BROADCAST) && inst.getMMTSJType().isLeft()) ||
+ (mo.isFederated(FType.COL) &&
mo.isFederatedExcept(FType.BROADCAST) && inst.getMMTSJType().isRight()))
+ return parseInstruction(inst);
+ return null;
+ }
+
+ private static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction
instr) {
return new TsmmFEDInstruction(instr.input1, instr.getOutput(),
instr.getMMTSJType(), instr.getNumThreads(),
instr.getOpcode(), instr.getInstructionString());
}
@@ -61,7 +69,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
String opcode = parts[0];
if(!opcode.equalsIgnoreCase("tsmm"))
throw new
DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " +
opcode);
-
+
InstructionUtils.checkNumFields(parts, 3, 4, 5);
CPOperand in = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
@@ -70,7 +78,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
FederatedOutput fedOut = (parts.length > 5) ?
FederatedOutput.valueOf(parts[5]) : FederatedOutput.NONE;
return new TsmmFEDInstruction(in, out, type, k, opcode, str,
fedOut);
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index 1c66e77768..623872e963 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -19,9 +19,32 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
@@ -55,7 +78,121 @@ public abstract class UnaryFEDInstruction extends
ComputationFEDInstruction {
super(type, op, in1, in2, in3, out, opcode, instr, fedOut);
}
- static String parseUnaryInstruction(String instr, CPOperand in,
CPOperand out) {
+	/**
+	 * Attempts to replace a CP unary instruction with a federated counterpart.
+	 *
+	 * <p>A replacement is only produced when the instruction's first input is federated
+	 * (checked via {@code isFederatedExcept(FType.BROADCAST)}) and the operation-specific
+	 * preconditions below hold.</p>
+	 *
+	 * @param inst CP unary instruction to inspect
+	 * @param ec   execution context holding the instruction's input variables
+	 * @return the federated replacement, or {@code null} if the instruction stays in CP
+	 */
+	public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, ExecutionContext ec) {
+		// matrix and frame indexing: decided here, never falls through to the generic checks
+		if(inst instanceof IndexingCPInstruction) {
+			final IndexingCPInstruction indexing = (IndexingCPInstruction) inst;
+			final boolean matrixOrFrame = indexing.input1.isMatrix() || indexing.input1.isFrame();
+			if(matrixOrFrame && ec.getCacheableData(indexing.input1).isFederatedExcept(FType.BROADCAST))
+				return IndexingFEDInstruction.parseInstruction(indexing);
+			return null;
+		}
+
+		// r', rdiag and rev are decided here; other reorg opcodes use the generic checks below
+		if(inst instanceof ReorgCPInstruction) {
+			final String reorgOpcode = inst.getOpcode();
+			if(reorgOpcode.equals("r'") || reorgOpcode.equals("rdiag") || reorgOpcode.equals("rev")) {
+				final ReorgCPInstruction reorg = (ReorgCPInstruction) inst;
+				final CacheableData<?> data = ec.getCacheableData(reorg.input1);
+				final boolean matrixOrFrame = data instanceof MatrixObject || data instanceof FrameObject;
+				if(matrixOrFrame && data.isFederatedExcept(FType.BROADCAST))
+					return ReorgFEDInstruction.parseInstruction(reorg);
+				return null;
+			}
+		}
+
+		// every remaining case needs a matrix input that exists in the context and is federated
+		if(inst.input1 == null || !inst.input1.isMatrix() || !ec.containsVariable(inst.input1))
+			return null;
+		final MatrixObject input = ec.getMatrixObject(inst.input1);
+		if(!input.isFederatedExcept(FType.BROADCAST))
+			return null;
+
+		if(inst instanceof CentralMomentCPInstruction)
+			return CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction) inst);
+
+		if(inst instanceof QuantileSortCPInstruction) {
+			// ROW-federated, or COL-federated with exactly one federated range
+			final boolean supported = input.isFederated(FType.ROW)
+				|| (input.getFedMapping().getFederatedRanges().length == 1 && input.isFederated(FType.COL));
+			return supported ? QuantileSortFEDInstruction.parseInstruction((QuantileSortCPInstruction) inst) : null;
+		}
+
+		if(inst instanceof ReshapeCPInstruction)
+			return ReshapeFEDInstruction.parseInstruction((ReshapeCPInstruction) inst);
+
+		if(inst instanceof AggregateUnaryCPInstruction
+			&& ((AggregateUnaryCPInstruction) inst).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
+			return AggregateUnaryFEDInstruction.parseInstruction((AggregateUnaryCPInstruction) inst);
+
+		if(!(inst instanceof UnaryMatrixCPInstruction))
+			return null;
+		final String opcode = inst.getOpcode();
+		if(!UnaryMatrixFEDInstruction.isValidOpcode(opcode))
+			return null;
+		// ucumk+* is not replaced when the input is COL-federated
+		if(opcode.equalsIgnoreCase("ucumk+*") && input.isFederated(FType.COL))
+			return null;
+		return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixCPInstruction) inst);
+	}
+
+	/**
+	 * Attempts to replace a Spark unary instruction with a federated counterpart.
+	 *
+	 * <p>A replacement is only produced when the instruction's first input is federated
+	 * (checked via {@code isFederatedExcept(FType.BROADCAST)}) and the operation-specific
+	 * preconditions below hold.</p>
+	 *
+	 * @param inst Spark unary instruction to inspect
+	 * @param ec   execution context holding the instruction's input variables
+	 * @return the federated replacement, or {@code null} if the instruction stays on Spark
+	 */
+	public static UnaryFEDInstruction parseInstruction(UnarySPInstruction inst, ExecutionContext ec) {
+		if(inst instanceof IndexingSPInstruction) {
+			// matrix and frame indexing
+			IndexingSPInstruction minst = (IndexingSPInstruction) inst;
+			if((minst.input1.isMatrix() || minst.input1.isFrame()) &&
+				ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
+				return IndexingFEDInstruction.parseInstruction(minst);
+			}
+		}
+		else if(inst instanceof CentralMomentSPInstruction) {
+			CentralMomentSPInstruction cinstruction = (CentralMomentSPInstruction) inst;
+			Data data = ec.getVariable(cinstruction.input1);
+			// isFederatedExcept alone, as in the CP overload and the other branches here
+			// (an additional isFederated() check was redundant)
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+				return CentralMomentFEDInstruction.parseInstruction(cinstruction);
+		}
+		else if(inst instanceof QuantileSortSPInstruction) {
+			QuantileSortSPInstruction qinstruction = (QuantileSortSPInstruction) inst;
+			Data data = ec.getVariable(qinstruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+				return QuantileSortFEDInstruction.parseInstruction(qinstruction);
+		}
+		else if(inst instanceof AggregateUnarySPInstruction) {
+			AggregateUnarySPInstruction auinstruction = (AggregateUnarySPInstruction) inst;
+			Data data = ec.getVariable(auinstruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+				// uarimin/uarimax are only replaced for ROW-federated inputs; all other opcodes always.
+				// Braced explicitly: the previous unbraced form relied on dangling-else binding.
+				boolean indexMinMax = ArrayUtils.contains(new String[] {"uarimin", "uarimax"},
+					auinstruction.getOpcode());
+				if(!indexMinMax || ((MatrixObject) data).getFedMapping().getType() == FType.ROW)
+					return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
+			}
+		}
+		else if(inst instanceof ReorgSPInstruction &&
+			(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+			ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
+			CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+			if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederatedExcept(FType.BROADCAST))
+				return ReorgFEDInstruction.parseInstruction(rinst);
+		}
+		else if(inst instanceof ReblockSPInstruction && inst.input1 != null &&
+			(inst.input1.isFrame() || inst.input1.isMatrix())) {
+			ReblockSPInstruction rinst = (ReblockSPInstruction) inst;
+			CacheableData<?> data = ec.getCacheableData(rinst.input1);
+			if(data.isFederatedExcept(FType.BROADCAST))
+				return ReblockFEDInstruction.parseInstruction(rinst);
+		}
+		else if(inst.input1 != null && inst.input1.isMatrix() && ec.containsVariable(inst.input1)) {
+			MatrixObject mo1 = ec.getMatrixObject(inst.input1);
+			if(mo1.isFederatedExcept(FType.BROADCAST)) {
+				// No "cm"/"qsort" checks here: CentralMomentSPInstruction and QuantileSortSPInstruction
+				// are always consumed by the dedicated branches above, so such checks were unreachable.
+				if(inst.getOpcode().equalsIgnoreCase("rshape"))
+					return ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
+				else if(inst instanceof UnaryMatrixSPInstruction &&
+					UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
+					return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixSPInstruction) inst);
+			}
+		}
+		return null;
+	}
+
+ protected static String parseUnaryInstruction(String instr, CPOperand
in, CPOperand out) {
//TODO: simplify once all fed instructions have consistent flags
int num = InstructionUtils.checkNumFields(instr, 2, 3, 4);
if(num == 2)
@@ -69,12 +206,12 @@ public abstract class UnaryFEDInstruction extends
ComputationFEDInstruction {
}
}
- static String parseUnaryInstruction(String instr, CPOperand in1,
CPOperand in2, CPOperand out) {
+	/**
+	 * Parses an instruction string with two inputs and one output.
+	 *
+	 * <p>Checks the field count via {@code InstructionUtils.checkNumFields(instr, 3)} and then
+	 * delegates to {@code parse}, passing {@code null} for the third input.</p>
+	 *
+	 * @param instr instruction string to parse
+	 * @param in1   first input operand (presumably filled in place by {@code parse} — confirm)
+	 * @param in2   second input operand
+	 * @param out   output operand
+	 * @return the string returned by {@code parse} (presumably the opcode — confirm)
+	 */
+ protected static String parseUnaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand out) {
InstructionUtils.checkNumFields(instr, 3);
return parse(instr, in1, in2, null, out);
}
- static String parseUnaryInstruction(String instr, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out) {
+	/**
+	 * Parses an instruction string with three inputs and one output.
+	 *
+	 * <p>Checks the field count via {@code InstructionUtils.checkNumFields(instr, 4)} and then
+	 * delegates to {@code parse}.</p>
+	 *
+	 * @param instr instruction string to parse
+	 * @param in1   first input operand (presumably filled in place by {@code parse} — confirm)
+	 * @param in2   second input operand
+	 * @param in3   third input operand
+	 * @param out   output operand
+	 * @return the string returned by {@code parse} (presumably the opcode — confirm)
+	 */
+ protected static String parseUnaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand in3, CPOperand out) {
InstructionUtils.checkNumFields(instr, 4);
return parse(instr, in1, in2, in3, out);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index c3c2111641..890b681cef 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -70,12 +70,14 @@ public class UnaryMatrixFEDInstruction extends
UnaryFEDInstruction {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
- if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") ||
opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
+ if(parts.length == 5 &&
+ (opcode.equalsIgnoreCase("exp") ||
opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
in.split(parts[1]);
out.split(parts[2]);
ValueFunction func = Builtin.getBuiltinFnObject(opcode);
- if( Arrays.asList(new
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
){
- UnaryOperator op = new
UnaryOperator(func,Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+ if(Arrays.asList(new String[] {"ucumk+", "ucum*",
"ucumk+*", "ucummin", "ucummax", "exp", "log", "sigmoid"})
+ .contains(opcode)) {
+ UnaryOperator op = new UnaryOperator(func,
Integer.parseInt(parts[3]), Boolean.parseBoolean(parts[4]));
return new UnaryMatrixFEDInstruction(op, in,
out, opcode, str);
}
else
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 4a51f49083..f89c32f374 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -29,6 +29,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -55,7 +56,23 @@ public class VariableFEDInstruction extends FEDInstruction
implements LineageTra
_in = in;
}
- public static VariableFEDInstruction
parseInstruction(VariableCPInstruction cpInstruction) {
+	/**
+	 * Returns a federated replacement for the given CP variable instruction, if one applies.
+	 *
+	 * <p>Three cases are replaced:
+	 * <ul>
+	 * <li>a {@code Write} of a matrix whose third operand name contains {@code "federated"};</li>
+	 * <li>a {@code CastAsFrameVariable} of a matrix;</li>
+	 * <li>a {@code CastAsMatrixVariable} of a frame.</li>
+	 * </ul>
+	 * For both casts the input must satisfy {@code isFederatedExcept(FType.BROADCAST)}.</p>
+	 *
+	 * @param inst CP variable instruction to inspect
+	 * @param ec   execution context used to look up the cast input
+	 * @return the federated instruction, or {@code null} if no replacement applies
+	 */
+	public static VariableFEDInstruction parseInstruction(VariableCPInstruction inst, ExecutionContext ec) {
+		final VariableOperationCode opcode = inst.getVariableOpcode();
+
+		final boolean federatedWrite = opcode == VariableOperationCode.Write
+			&& inst.getInput1().isMatrix()
+			&& inst.getInput3().getName().contains("federated");
+		if(federatedWrite)
+			return parseInstruction(inst);
+
+		final boolean matrixToFrame = opcode == VariableOperationCode.CastAsFrameVariable
+			&& inst.getInput1().isMatrix()
+			&& ec.getCacheableData(inst.getInput1()).isFederatedExcept(FType.BROADCAST);
+		if(matrixToFrame)
+			return parseInstruction(inst);
+
+		final boolean frameToMatrix = opcode == VariableOperationCode.CastAsMatrixVariable
+			&& inst.getInput1().isFrame()
+			&& ec.getCacheableData(inst.getInput1()).isFederatedExcept(FType.BROADCAST);
+		return frameToMatrix ? parseInstruction(inst) : null;
+	}
+
+	/**
+	 * Creates a new {@link VariableFEDInstruction} from the given CP variable instruction.
+	 * Callers decide beforehand whether the federated replacement applies.
+	 *
+	 * @param cpInstruction the CP instruction to wrap
+	 * @return a new federated variable instruction
+	 */
+ private static VariableFEDInstruction
parseInstruction(VariableCPInstruction cpInstruction) {
return new VariableFEDInstruction(cpInstruction);
}