kev-inn commented on code in PR #1628:
URL: https://github.com/apache/systemds/pull/1628#discussion_r891284083
##########
src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java:
##########
@@ -65,7 +68,7 @@ else if( parts.length == 5 ) {
CPOperand out = new CPOperand(parts[2]);
OperationTypes ptype = OperationTypes.valueOf(parts[3]);
boolean inmem = Boolean.parseBoolean(parts[4]);
- return new QuantilePickCPInstruction(null, in1, new CPOperand(), out, ptype, inmem, opcode, str);
+ return new QuantilePickCPInstruction(null, in1, out, ptype, inmem, opcode, str);
Review Comment:
Our new approach lets us remove this workaround (passing a dummy `new CPOperand()` as the second input).
##########
src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java:
##########
@@ -82,4 +85,28 @@ public void processInstruction(ExecutionContext ec)
double val = covobj.getRequiredResult(_optr);
ec.setScalarOutput(output_name, new DoubleObject(val));
}
+
+ @Override
+ protected FEDInstruction tryReplaceWithFederated(ExecutionContext ec) {
+ /* NOTE: the requirement before was:
+ if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederatedExcept(FTypes.FType.BROADCAST))
+ || (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederatedExcept(FTypes.FType.BROADCAST))) {
+ if(instruction.getOpcode().equals(...) )
+ ...
+ else if("cov".equals(instruction.getOpcode()) && (ec.getMatrixObject(instruction.input1).isFederated(
+ FTypes.FType.ROW) ||
+ ec.getMatrixObject(instruction.input2).isFederated(FTypes.FType.ROW)))
+ ...
+ else
+ ...
+
+ This allowed `input1(Broadcast & Row) & input2(!Broadcast & ...)`; I am not sure whether that is intentional.
+ */
+ if((input1.isMatrix() && ec.getMatrixObject(input1).isFederated(FTypes.FType.ROW) &&
+ ec.getMatrixObject(input1).isFederatedExcept(FTypes.FType.BROADCAST)) ||
+ (input2.isMatrix() && ec.getMatrixObject(input2).isFederated(FTypes.FType.ROW) &&
+ ec.getMatrixObject(input2).isFederatedExcept(FTypes.FType.BROADCAST)))
Review Comment:
I am not sure about this; I don't think any test case covers the case explained at the bottom of the comment.
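To make the difference concrete, here is a minimal, self-contained sketch that evaluates the old and the new condition on a simplified model. The `Fed` class and its flags are hypothetical stand-ins, not the SystemDS `MatrixObject` API, and the sketch assumes (as the comment's wording suggests) that a broadcast-federated input can also report `isFederated(ROW)`. The `isMatrix()` checks and the non-`cov` opcode branches are left out.

```java
public class CovPredicateSketch {
    // Hypothetical, simplified stand-in for a matrix's federation state
    // (not the SystemDS MatrixObject API). ROW and BROADCAST are independent
    // flags because the comment assumes an input can report both at once.
    static final class Fed {
        final boolean federated, row, broadcast;

        Fed(boolean federated, boolean row, boolean broadcast) {
            this.federated = federated;
            this.row = row;
            this.broadcast = broadcast;
        }

        boolean isFederatedRow() { return federated && row; }                    // ~ isFederated(ROW)
        boolean isFederatedExceptBroadcast() { return federated && !broadcast; } // ~ isFederatedExcept(BROADCAST)
    }

    // Old logic: outer check on either input, then the "cov" branch on either input.
    static boolean oldCondition(Fed in1, Fed in2) {
        boolean outer = in1.isFederatedExceptBroadcast() || in2.isFederatedExceptBroadcast();
        boolean covBranch = in1.isFederatedRow() || in2.isFederatedRow();
        return outer && covBranch;
    }

    // New logic: both requirements must hold for the same input.
    static boolean newCondition(Fed in1, Fed in2) {
        return (in1.isFederatedRow() && in1.isFederatedExceptBroadcast())
            || (in2.isFederatedRow() && in2.isFederatedExceptBroadcast());
    }

    public static void main(String[] args) {
        Fed in1 = new Fed(true, true, true);   // broadcast & row
        Fed in2 = new Fed(true, false, false); // federated, neither row nor broadcast
        System.out.println("old=" + oldCondition(in1, in2) + " new=" + newCondition(in1, in2));
    }
}
```

With these inputs it prints `old=true new=false`: exactly the combination the old code accepted and the new check rejects. A test with such inputs would pin down which behavior we actually want.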
##########
src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryScalarCPInstruction.java:
##########
@@ -26,10 +26,10 @@
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-public class UnaryScalarCPInstruction extends UnaryMatrixCPInstruction {
+public class UnaryScalarCPInstruction extends UnaryCPInstruction {
Review Comment:
This is a problem caused by inheritance: we replace `UnaryMatrixCPInstruction`s with FED instructions, but not `UnaryScalarCPInstruction`s. I think this inheritance made no sense and was probably a bug.
If it *wasn't*, we can instead override `tryReplaceWithFederated()` here.
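A minimal sketch of the inheritance issue, using hypothetical, simplified stand-ins for the SystemDS classes. The real `tryReplaceWithFederated()` takes an `ExecutionContext` and returns a `FEDInstruction`; here it returns a `String` or `null`. A subclass of the matrix instruction silently inherits its override, so a scalar instruction would be replaced as well:

```java
public class InheritanceSketch {
    // Hypothetical stand-ins; names shortened from the SystemDS classes.
    static class UnaryCP {
        // Default: no federated replacement.
        String tryReplaceWithFederated() { return null; }
    }

    static class UnaryMatrixCP extends UnaryCP {
        // Matrix instructions get replaced by a FED instruction.
        @Override
        String tryReplaceWithFederated() { return "UnaryMatrixFED"; }
    }

    // Old hierarchy: inherits the matrix override, so scalars are replaced too.
    static class UnaryScalarCPOld extends UnaryMatrixCP { }

    // Hierarchy as in this diff: inherits the default, so no replacement.
    static class UnaryScalarCPNew extends UnaryCP { }

    // Alternative if the old inheritance was intentional: override explicitly.
    static class UnaryScalarCPOverride extends UnaryMatrixCP {
        @Override
        String tryReplaceWithFederated() { return null; }
    }

    public static void main(String[] args) {
        System.out.println(new UnaryScalarCPOld().tryReplaceWithFederated());      // UnaryMatrixFED
        System.out.println(new UnaryScalarCPNew().tryReplaceWithFederated());      // null
        System.out.println(new UnaryScalarCPOverride().tryReplaceWithFederated()); // null
    }
}
```

Both the changed base class and an explicit override avoid the replacement; which one to pick depends only on whether the old inheritance was intentional.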
##########
src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java:
##########
@@ -119,4 +123,23 @@ public void processInstruction(ExecutionContext ec) {
//set and release output
ec.setMatrixOutput(output.getName(), resultBlock);
}
+
+ public int getNumThreads() {
+ return _numThreads;
+ }
+
+ @Override
+ protected FEDInstruction tryReplaceWithFederated(ExecutionContext ec) {
+ if(!input1.isMatrix())
+ return null;
+
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ if(!mo1.isFederatedExcept(FType.BROADCAST))
+ return null;
+
+ if(mo1.isFederated(FType.ROW) ||
+ (mo1.getFedMapping().getFederatedRanges().length == 1 && mo1.isFederated(FType.COL)))
+ return new QuantileSortFEDInstruction(this);
+ return null;
Review Comment:
As you may notice, in all implementations I opted for early-return checks to reduce nesting. The last `if`, on the other hand, checks whether replacement is appropriate. This is inconsistent with the preceding checks, which test for failure, but in my experience it is much more readable, and therefore more maintainable, than a confusing inverted condition.
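To check that mixing the two styles does not change behavior, here is a small sketch that compares the early-return form (final check stated positively) against a single inverted failure condition, exhaustively over all flag combinations. The boolean parameters are hypothetical stand-ins for `input1.isMatrix()`, `mo1.isFederatedExcept(FType.BROADCAST)`, `mo1.isFederated(FType.ROW)`, `mo1.isFederated(FType.COL)` and the number of federated ranges; `true`/`false` stand for returning the FED instruction or `null`.

```java
public class EarlyReturnSketch {
    // Early-return style: failure checks first, final check stated positively.
    // Parameters are hypothetical stand-ins for the QuantileSort checks.
    static boolean earlyReturn(boolean isMatrix, boolean fedExceptBcast,
            boolean row, boolean col, int numRanges) {
        if(!isMatrix)
            return false;
        if(!fedExceptBcast)
            return false;
        if(row || (numRanges == 1 && col))
            return true;  // would return the FED instruction
        return false;     // would return null
    }

    // Same logic with every check inverted into one failure condition.
    static boolean inverted(boolean isMatrix, boolean fedExceptBcast,
            boolean row, boolean col, int numRanges) {
        if(!isMatrix || !fedExceptBcast || (!row && (numRanges != 1 || !col)))
            return false;
        return true;
    }

    // Compare both forms over all flag combinations and 1 or 2 ranges.
    static int countMismatches() {
        boolean[] tf = {false, true};
        int mismatches = 0;
        for(boolean m : tf)
            for(boolean f : tf)
                for(boolean r : tf)
                    for(boolean c : tf)
                        for(int n = 1; n <= 2; n++)
                            if(earlyReturn(m, f, r, c, n) != inverted(m, f, r, c, n))
                                mismatches++;
        return mismatches;
    }

    public static void main(String[] args) {
        System.out.println("mismatches: " + countMismatches());
    }
}
```

It prints `mismatches: 0`. Reading the inverted form correctly requires applying De Morgan's law in your head, which is the readability cost described above. In the real code the early returns also guard the `getFedMapping()` call, which presumably is only valid for federated inputs.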
##########
src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java:
##########
@@ -53,6 +56,12 @@ public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2,
super(FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
}
+ public AggregateBinaryFEDInstruction(AggregateBinaryCPInstruction instruction) {
+ // TODO @KEVIN: should we update the instruction string?
+ this(instruction.getOperator(), instruction.input1, instruction.input2, instruction.output,
+ instruction.getOpcode(), instruction.getInstructionString(), FederatedOutput.NONE);
Review Comment:
I currently keep the instruction string unchanged everywhere, but we might want to replace the leading `CP` with `FED` and append the `FederatedOutput` type (always `NONE`, as it is decided at runtime).
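A minimal sketch of what such a rewrite could look like. The helper `toFedInstructionString` is hypothetical, the local `FederatedOutput` enum is a stand-in containing only `NONE`, and the sketch assumes the instruction string uses `°` as its field delimiter with the execution type as the first field. Appending the `FederatedOutput` as the last field is also an assumption about what the FED parser expects.

```java
public class InstStringSketch {
    // Assumption: fields are separated by '°' and the first field is the
    // execution type ("CP" for control-program instructions).
    static final String DELIM = "\u00b0";

    // Hypothetical stand-in for FEDInstruction.FederatedOutput; only NONE is needed.
    enum FederatedOutput { NONE }

    // Hypothetical helper: swap the leading "CP" for "FED" and append the output type.
    static String toFedInstructionString(String cpInst, FederatedOutput fedOut) {
        String prefix = "CP" + DELIM;
        if(!cpInst.startsWith(prefix))
            throw new IllegalArgumentException("not a CP instruction: " + cpInst);
        return "FED" + DELIM + cpInst.substring(prefix.length()) + DELIM + fedOut.name();
    }

    public static void main(String[] args) {
        // Simplified operands; real operands typically also carry data and value types.
        String cp = String.join(DELIM, "CP", "ba+*", "_mVar1", "_mVar2", "_mVar3");
        System.out.println(toFedInstructionString(cp, FederatedOutput.NONE));
    }
}
```

This prints `FED°ba+*°_mVar1°_mVar2°_mVar3°NONE`. Rejecting non-`CP` strings up front keeps the helper from silently producing a malformed instruction.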
##########
src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java:
##########
@@ -127,4 +130,22 @@ public void processInstruction(ExecutionContext ec) {
throw new DMLRuntimeException("Unsupported qpick operation type: "+_type);
}
}
+
+ public OperationTypes getOperationType() {
+ return _type;
+ }
+
+ public boolean isInMem() {
+ return _inmem;
+ }
+
+ @Override
+ protected FEDInstruction tryReplaceWithFederated(ExecutionContext ec) {
+ // Should we federate if the quantiles are given as a federated matrix? (input2)
+ // This might be useful for privacy, and for performance only in extreme corner cases.
+ // The current implementation does not support it.
Review Comment:
This comment could be removed, since we only replace instructions for which an equivalent FED implementation exists anyway, but this case is worth discussing (or at least remembering).
##########
src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java:
##########
@@ -50,6 +52,11 @@ public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, St
this(in, out, type, k, opcode, istr, FederatedOutput.NONE);
}
+ public TsmmFEDInstruction(MMTSJCPInstruction instruction) {
+ this(instruction.getInputs()[0], instruction.getOutput(), instruction.getMMTSJType(), -1 /*ignore numThreads*/,
Review Comment:
Again, I think `numThreads` has no impact here.
##########
src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java:
##########
@@ -69,6 +78,22 @@ public static AggregateBinaryFEDInstruction parseInstruction(String str) {
InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str, fedOut);
}
+ public static AggregateBinaryFEDInstruction tryReplace(Instruction inst, ExecutionContext ec) {
+ if (!(inst instanceof AggregateBinaryCPInstruction)) {
+ return null;
+ }
+ AggregateBinaryCPInstruction instruction = (AggregateBinaryCPInstruction) inst;
+ if(!instruction.input1.isMatrix() || !instruction.input2.isMatrix()) {
+ return null;
+ }
+ MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
+ MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
+ if (!(mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW) || mo1.isFederated(FType.COL))) {
+ return null;
+ }
+ return new AggregateBinaryFEDInstruction(instruction);
+ }
+
Review Comment:
This is outdated from my first attempt, will be removed.
##########
src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java:
##########
@@ -43,6 +44,12 @@ public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3,
_type = type;
}
+ public MMChainFEDInstruction(MMChainCPInstruction instruction) {
+ this(instruction.input1, instruction.input2, instruction.input3, instruction.output,
+ instruction.getMMChainType(), -1 /* ignore numThreads */, instruction.getOpcode(),
+ instruction.getInstructionString());
Review Comment:
I will update this to pass the same `numThreads`, but I don't think it has any effect on the federated implementation.
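For reference, forwarding the thread count instead of the `-1` placeholder could look like the following sketch. `MMChainCP`, `MMChainFED` and the `getNumThreads()` accessor are hypothetical, simplified stand-ins (the diff only adds such an accessor to `QuantileSortCPInstruction`); the same pattern would apply to the `TsmmFEDInstruction` constructor.

```java
public class NumThreadsSketch {
    // Hypothetical, simplified stand-in for MMChainCPInstruction.
    static class MMChainCP {
        private final int _numThreads;

        MMChainCP(int k) { _numThreads = k; }

        // Hypothetical accessor, analogous to QuantileSortCPInstruction.getNumThreads().
        int getNumThreads() { return _numThreads; }
    }

    // Hypothetical, simplified stand-in for MMChainFEDInstruction.
    static class MMChainFED {
        // Kept for consistency, even if the federated path may not use it.
        final int numThreads;

        MMChainFED(int k) { numThreads = k; }

        // Conversion constructor: forward the CP thread count instead of -1.
        MMChainFED(MMChainCP cp) { this(cp.getNumThreads()); }
    }

    public static void main(String[] args) {
        System.out.println(new MMChainFED(new MMChainCP(8)).numThreads); // 8
    }
}
```

Forwarding the value costs nothing and keeps the FED instruction consistent with its CP origin, even if the federated implementation ends up ignoring it.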
--
This is an automated message from the Apache Git Service.