This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new dfc57dc [SYSTEMDS-2604] Compile federated instructions and federated output flag
dfc57dc is described below
commit dfc57dc4310d1dcc816cd7c2b5e7b1393b9abb78
Author: sebwrede <[email protected]>
AuthorDate: Mon Mar 22 19:34:59 2021 +0100
[SYSTEMDS-2604] Compile federated instructions and federated output flag
This commit adds a federated output flag to some operations and adds tests
for sequences of federated instructions with privacy constraints.
The supported operations are AggUnary, Binary, and Ternary. Closes #1229.
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 5 +-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 +-
src/main/java/org/apache/sysds/hops/TernaryOp.java | 4 +
src/main/java/org/apache/sysds/lops/Binary.java | 14 ++-
.../org/apache/sysds/lops/PartialAggregate.java | 41 +++----
src/main/java/org/apache/sysds/lops/Ternary.java | 34 +++---
.../runtime/instructions/FEDInstructionParser.java | 10 ++
.../runtime/instructions/InstructionUtils.java | 75 +++++++++++--
.../fed/AggregateUnaryFEDInstruction.java | 65 +++++++++--
.../runtime/instructions/fed/FEDInstruction.java | 2 +-
.../instructions/fed/TernaryFEDInstruction.java | 124 +++++++++++++++++----
.../federated/algorithms/FederatedVarTest.java | 12 +-
.../test/functions/privacy/FederatedLmCGTest.java | 1 +
.../fedplanning/FederatedMultiplyPlanningTest.java | 84 +++++++++++++-
.../privacy/FederatedMultiplyPlanningTest2.dml | 29 +++++
.../FederatedMultiplyPlanningTest2Reference.dml | 27 +++++
.../privacy/FederatedMultiplyPlanningTest3.dml | 31 ++++++
.../FederatedMultiplyPlanningTest3Reference.dml | 29 +++++
18 files changed, 492 insertions(+), 98 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index f842502..5d54535 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -59,6 +59,7 @@ public class AggUnaryOp extends MultiThreadedHop
_direction = idx;
getInput().add(0, inp);
inp.getParent().add(this);
+ updateETFed();
}
@Override
@@ -151,13 +152,15 @@ public class AggUnaryOp extends MultiThreadedHop
agg1 = new
PartialAggregate(input.constructLops(),
_op, _direction,
getDataType(),getValueType(), et, k);
}
-
+
setOutputDimensions(agg1);
setLineNumbers(agg1);
setLops(agg1);
if (getDataType() == DataType.SCALAR) {
agg1.getOutputParameters().setDimensions(1, 1, getBlocksize(), getNnz());
+ } else {
+ setFederatedOutput(agg1);
}
}
else if( et == ExecType.SPARK )
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 36cb051..897f707 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -224,6 +224,8 @@ public class BinaryOp extends MultiThreadedHop
constructLopsBinaryDefault();
}
+ setFederatedOutput(getLops());
+
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
@@ -464,7 +466,6 @@ public class BinaryOp extends MultiThreadedHop
op, getDataType(),
getValueType(), et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
- setFederatedOutput(binary);
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 6f5a55b..3a8d02b 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -80,6 +80,7 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(0, inp1);
getInput().add(1, inp2);
getInput().add(2, inp3);
+ updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
@@ -97,6 +98,7 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(3, inp4);
getInput().add(4, inp5);
getInput().add(5, inp6);
+ updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
@@ -193,6 +195,8 @@ public class TernaryOp extends MultiThreadedHop
catch(LopsException e) {
throw new HopsException(this.printErrorLocation() +
"error constructing Lops for TernaryOp Hop " , e);
}
+
+ setFederatedOutput(getLops());
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java
b/src/main/java/org/apache/sysds/lops/Binary.java
index 84bd033..5ba77bb 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -81,17 +81,19 @@ public class Binary extends Lop
@Override
public String getInstructions(String input1, String input2, String
output) {
- String baseInstruction = InstructionUtils.concatOperands(
+ InstructionUtils.concatBaseOperands(
getExecType().name(), getOpcode(),
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
prepOutputOperand(output)
);
- if( getExecType() == ExecType.CP || (!federatedOutput &&
getExecType() == ExecType.FED) )
- return InstructionUtils.concatOperands(baseInstruction,
String.valueOf(_numThreads));
- else if ( getExecType() == ExecType.FED )
- return InstructionUtils.concatOperands(baseInstruction,
String.valueOf(_numThreads), String.valueOf(federatedOutput));
- else return baseInstruction;
+ if ( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED){
+
InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
+ if ( federatedOutput )
+
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
+ }
+
+ return InstructionUtils.getInstructionString();
}
}
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index c291782..118a804 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -27,7 +27,7 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.lops.LopProperties.ExecType;
-
+import org.apache.sysds.runtime.instructions.InstructionUtils;
/**
* Lop to perform a partial aggregation. It was introduced to do some initial
@@ -217,33 +217,22 @@ public class PartialAggregate extends Lop
@Override
public String getInstructions(String input1, String output)
{
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
-
- sb.append( OPERAND_DELIMITOR );
- sb.append( getOpcode() );
-
- sb.append( OPERAND_DELIMITOR );
- sb.append( getInputs().get(0).prepInputOperand(input1) );
-
- sb.append( OPERAND_DELIMITOR );
- sb.append( prepOutputOperand(output) );
-
- //exec-type specific attributes
- sb.append( OPERAND_DELIMITOR );
- if( getExecType() == ExecType.SPARK )
- sb.append( _aggtype );
- else if( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) {
- sb.append(_numThreads);
+ InstructionUtils.concatBaseOperands(
+ getExecType().name(),
+ getOpcode(),
+ getInputs().get(0).prepInputOperand(input1),
+ prepOutputOperand(output));
- //number of outputs, valid for fed instruction
- if(getOpcode().equalsIgnoreCase("uarimin") ||
getOpcode().equalsIgnoreCase("uarimax")) {
- sb.append(OPERAND_DELIMITOR);
- sb.append("1");
- }
+ if ( getExecType() == ExecType.SPARK )
+
InstructionUtils.concatAdditionalOperand(_aggtype.toString());
+ else if ( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ){
+
InstructionUtils.concatAdditionalOperand(Integer.toString(_numThreads));
+ if ( getOpcode().equalsIgnoreCase("uarimin") ||
getOpcode().equalsIgnoreCase("uarimax") )
+ InstructionUtils.concatAdditionalOperand("1");
+ if ( getExecType() == ExecType.FED && operation !=
AggOp.VAR )
+
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
}
-
- return sb.toString();
+ return InstructionUtils.getInstructionString();
}
public static String getOpcode(AggOp op, Direction dir)
diff --git a/src/main/java/org/apache/sysds/lops/Ternary.java
b/src/main/java/org/apache/sysds/lops/Ternary.java
index a1a2d53..a6ad9d2 100644
--- a/src/main/java/org/apache/sysds/lops/Ternary.java
+++ b/src/main/java/org/apache/sysds/lops/Ternary.java
@@ -25,7 +25,7 @@ import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.ValueType;
-
+import org.apache.sysds.runtime.instructions.InstructionUtils;
/**
* Lop to perform Sum of a matrix with another matrix multiplied by Scalar.
@@ -59,25 +59,19 @@ public class Ternary extends Lop
@Override
public String getInstructions(String input1, String input2, String
input3, String output) {
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
- sb.append( OPERAND_DELIMITOR );
- sb.append( _op.toString() );
-
- //process three operands and output
- String[] inputs = new String[]{input1, input2, input3};
- for( int i=0; i<3; i++ ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append(
getInputs().get(i).prepInputOperand(inputs[i]) );
- }
- sb.append( OPERAND_DELIMITOR );
- sb.append( prepOutputOperand(output) );
-
- if( getExecType() == ExecType.CP && getDataType().isMatrix() ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append( _numThreads );
+ InstructionUtils.concatOperands(
+ getExecType().name(),
+ _op.toString(),
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ getInputs().get(2).prepInputOperand(input3),
+ prepOutputOperand(output)
+ );
+ if( (getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) && getDataType().isMatrix() ){
+
InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
+ if ( federatedOutput )
+
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
}
-
- return sb.toString();
+ return InstructionUtils.getInstructionString();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index e6f430d..34db155 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -27,6 +27,7 @@ import
org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.TernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.TsmmFEDInstruction;
import java.util.HashMap;
@@ -46,18 +47,25 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "uasqk+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uarsqk+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uacsqk+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uavar" ,
FEDType.AggregateUnary);
+ String2FEDInstructionType.put( "uarvar" ,
FEDType.AggregateUnary);
+ String2FEDInstructionType.put( "uacvar" ,
FEDType.AggregateUnary);
// Arithmetic Instruction Opcodes
String2FEDInstructionType.put( "+" , FEDType.Binary );
String2FEDInstructionType.put( "-" , FEDType.Binary );
String2FEDInstructionType.put( "*" , FEDType.Binary );
String2FEDInstructionType.put( "/" , FEDType.Binary );
+ String2FEDInstructionType.put( "1-*" , FEDType.Binary);
//special * case
// Reorg Instruction Opcodes (repositioning of existing values)
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+ // Ternary Instruction Opcodes
+ String2FEDInstructionType.put( "+*" , FEDType.Ternary);
+ String2FEDInstructionType.put( "-*" , FEDType.Ternary);
}
public static FEDInstruction parseSingleInstruction (String str ) {
@@ -86,6 +94,8 @@ public class FEDInstructionParser extends InstructionParser
return TsmmFEDInstruction.parseInstruction(str);
case Binary:
return
BinaryFEDInstruction.parseInstruction(str);
+ case Ternary:
+ return
TernaryFEDInstruction.parseInstruction(str);
case Reorg:
return
ReorgFEDInstruction.parseInstruction(str);
default:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index fb60d6b..7d386ec 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1033,16 +1033,78 @@ public class InstructionUtils
parts[parts.length-1] = newName;
return concatOperands(parts);
}
-
+
+ /**
+ * Concat the inputs as operands to generate the instruction string.
+ * The inputs are separated by the operand delimiter and appended
+ * using a ThreadLocal StringBuilder.
+ * @param inputs operand inputs given as strings
+ * @return the instruction string with the given inputs concatenated
+ */
public static String concatOperands(String... inputs) {
- return concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+ concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+ return _strBuilders.get().toString();
}
-
+
+ /**
+ * Concat the input parts with the value type delimiter.
+ * @param inputs input operand parts as strings
+ * @return concatenated input parts
+ */
public static String concatOperandParts(String... inputs) {
- return concatOperandsWithDelim(Instruction.VALUETYPE_PREFIX,
inputs);
+ concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX,
inputs);
+ return _strBuilders.get().toString();
}
-
- private static String concatOperandsWithDelim(String delim, String...
inputs) {
+
+ /**
+ * Concat the inputs as operands to generate the base instruction
string.
+ * The base instruction string can subsequently be extended with the
+ * concatAdditional methods. The concatenation will be done using a
+ * ThreadLocal StringBuilder, so the concatenation is local to the
thread.
+ * When all additional operands have been appended, the complete
instruction
+ * string can be retrieved by calling the getInstructionString method.
+ * @param inputs operand inputs given as strings
+ */
+ public static void concatBaseOperands(String... inputs){
+ concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+ }
+
+ /**
+ * Concat input as an additional operand to the current thread-local
base instruction string.
+ * @param input operand input given as string
+ */
+ public static void concatAdditionalOperand(String input){
+ StringBuilder sb = _strBuilders.get();
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(input);
+ }
+
+ /**
+ * Concat inputs as additional operands to the current thread-local
base instruction string.
+ * @param inputs operand inputs given as strings
+ */
+ public static void concatAdditionalOperands(String... inputs){
+ concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+ }
+
+ /**
+ * Returns the current thread-local instruction string.
+ * This instruction string is built using the concat methods.
+ * @return instruction string
+ */
+ public static String getInstructionString(){
+ return _strBuilders.get().toString();
+ }
+
+ private static void concatOperandsWithDelim(String delim, String...
inputs){
+ StringBuilder sb = _strBuilders.get();
+ for( int i=0; i<inputs.length; i++ ) {
+ sb.append(delim);
+ sb.append(inputs[i]);
+ }
+ }
+
+ private static void concatBaseOperandsWithDelim(String delim, String...
inputs){
StringBuilder sb = _strBuilders.get();
sb.setLength(0); //reuse allocated space
for( int i=0; i<inputs.length-1; i++ ) {
@@ -1050,7 +1112,6 @@ public class InstructionUtils
sb.append(delim);
}
sb.append(inputs[inputs.length-1]);
- return sb.toString();
}
public static String concatStrings(String... inputs) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 5745ccd..da68d07 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -35,11 +36,17 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
-
+
private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop,
- CPOperand in, CPOperand out, String opcode, String istr)
+ CPOperand in, CPOperand out, String opcode, String istr,
boolean federatedOutput)
{
- super(FEDType.AggregateUnary, auop, in, out, opcode, istr);
+ super(FEDType.AggregateUnary, auop, in, out, opcode, istr,
federatedOutput);
+ }
+
+ protected AggregateUnaryFEDInstruction(Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, boolean federatedOutput)
+ {
+ super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr,
federatedOutput);
}
protected AggregateUnaryFEDInstruction(Operator op,
@@ -68,7 +75,13 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
if(InstructionUtils.getExecType(str) == ExecType.SPARK)
str = InstructionUtils.replaceOperand(str, 4, "-1");
- return new AggregateUnaryFEDInstruction(aggun, in1, out,
opcode, str);
+
+ boolean federatedOutput = false;
+ if ( parts.length > 6 )
+ federatedOutput = Boolean.parseBoolean(parts[5]);
+ else if ( parts.length == 5 && !parts[4].equals("uarimin") &&
!parts[4].equals("uarimax") )
+ federatedOutput = Boolean.parseBoolean(parts[4]);
+ return new AggregateUnaryFEDInstruction(aggun, in1, out,
opcode, str, federatedOutput);
}
@Override
@@ -90,20 +103,58 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
instString =
InstructionUtils.replaceOperand(instString, 5, "2");
//create federated commands for aggregation
+
+ if ( _federatedOutput )
+ processFederatedOutput(map, in, ec);
+ else
+ processGetOutput(map, aop, ec, in);
+ }
+
+ /**
+ * Sends federated request with instruction without retrieving the
result from the workers.
+ * @param map federation map of the input
+ * @param in input matrix object
+ * @param ec execution context
+ */
+ private void processFederatedOutput(FederationMap map, MatrixObject in,
ExecutionContext ec){
+ if ( output.isScalar() )
+ throw new DMLRuntimeException("Output of FED
instruction, " + output.toString()
+ + ", is a scalar and the output is set to be
federated. Scalars cannot be federated. ");
+ FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()}, _federatedOutput);
+ map.execute(getTID(), fr1);
+
+ // derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+
out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+ }
+
+ /**
+ * Sends federated request with instruction and retrieves the result
from the workers.
+ * @param map federation map of input
+ * @param aggUOptr aggregate unary operator of the instruction
+ * @param ec execution context
+ * @param in input matrix object
+ */
+ private void processGetOutput(FederationMap map, AggregateUnaryOperator
aggUOptr, ExecutionContext ec, MatrixObject in){
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()});
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
-
+
//execute federated commands and cleanups
Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1,
fr2, fr3);
if( output.isScalar() )
- ec.setVariable(output.getName(),
FederationUtils.aggScalar(aop, tmp, map));
+ ec.setVariable(output.getName(),
FederationUtils.aggScalar(aggUOptr, tmp, map));
else
- ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, tmp, map));
+ ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aggUOptr, tmp, map));
}
private void processVar(ExecutionContext ec){
+ if ( _federatedOutput ){
+ throw new DMLRuntimeException("Output of " + toString()
+ " should not be federated "
+ + "since the instruction requires consolidation
of partial results to be computed.");
+ }
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
FederationMap map = in.getFedMapping();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index c75c798..a91cb0c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -38,6 +38,7 @@ public abstract class FEDInstruction extends Instruction {
MMChain,
MatrixIndexing,
Ternary,
+ Tsmm,
ParameterizedBuiltin,
Quaternary,
QSort,
@@ -45,7 +46,6 @@ public abstract class FEDInstruction extends Instruction {
Reorg,
Reshape,
SpoofFused,
- Tsmm,
Unary
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 02791b9..5b52dbe 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -20,12 +20,15 @@
package org.apache.sysds.runtime.instructions.fed;
import java.util.Objects;
+import java.util.concurrent.Future;
import com.sun.tools.javac.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -33,8 +36,9 @@ import
org.apache.sysds.runtime.matrix.operators.TernaryOperator;
public class TernaryFEDInstruction extends ComputationFEDInstruction {
- private TernaryFEDInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
- super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out,
opcode, str);
+ private TernaryFEDInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String str, boolean federatedOutput) {
+ super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out,
opcode, str, federatedOutput);
}
public static TernaryFEDInstruction parseInstruction(String str) {
@@ -44,8 +48,10 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
CPOperand operand2 = new CPOperand(parts[2]);
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
- TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode);
- return new TernaryFEDInstruction(op, operand1, operand2,
operand3, outOperand, opcode, str);
+ int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) :
1;
+ boolean federatedOutput = parts.length > 6 &&
parts[6].equals("true");
+ TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode, numThreads);
+ return new TernaryFEDInstruction(op, operand1, operand2,
operand3, outOperand, opcode, str, federatedOutput);
}
@Override
@@ -87,9 +93,7 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
private void processMatrixScalarInput(ExecutionContext ec, MatrixObject
mo1, CPOperand in) {
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output, new CPOperand[] {in}, new
long[] {mo1.getFedMapping().getID()});
- mo1.getFedMapping().execute(getTID(), true, fr1);
-
- setOutputFedMapping(ec, mo1, fr1.getID());
+ sendFederatedRequests(ec, mo1, fr1.getID(), fr1);
}
private void process2MatrixScalarInput(ExecutionContext ec,
MatrixObject mo1, MatrixObject mo2, CPOperand in1, CPOperand in2) {
@@ -112,19 +116,95 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
varNewIn = new long[]{fr1[0].getID(),
mo1.getFedMapping().getID()};
}
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
- FederatedRequest fr3;
// 2 aligned inputs
if(fr1 == null) {
- mo1.getFedMapping().execute(getTID(), true, fr2);
+ sendFederatedRequests(ec, mo1, fr2.getID(), fr2);
} else {
if(cleanupIn) {
- fr3 = mo1.getFedMapping().cleanup(getTID(),
fr1[0].getID());
- mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
- } else
- mo1.getFedMapping().execute(getTID(), true,
fr1, fr2);
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ sendFederatedRequests(ec, mo1, fr2.getID(),
fr1, fr2, fr3);
+ }
+ else
+ sendFederatedRequests(ec, mo1, fr2.getID(),
fr1, fr2);
+ }
+ }
+
+ /**
+ * Send federated requests and retrieve output if federated output flag
is set.
+ * @param ec execution context
+ * @param fedMapObj matrix object with federated mapping where
federated requests are sent to.
+ * @param fedOutputID ID of federated output
+ * @param federatedRequests federated requests for processing
instruction
+ */
+ private void sendFederatedRequests(ExecutionContext ec, MatrixObject
fedMapObj, long fedOutputID,
+ FederatedRequest... federatedRequests){
+ sendFederatedRequests(ec, fedMapObj, fedOutputID, null, null,
federatedRequests);
+ }
+
+ /**
+ * Send federated requests and retrieve output if federated output flag
is set.
+ * @param ec execution context
+ * @param fedMapObj matrix object with federated mapping where
federated requests are sent to.
+ * @param fedOutputID ID of federated output
+ * @param federatedSlices federated requests for broadcasting slices
before processing instruction
+ * @param federatedRequests federated requests for processing
instruction
+ */
+ private void sendFederatedRequests(ExecutionContext ec, MatrixObject
fedMapObj, long fedOutputID,
+ FederatedRequest[] federatedSlices, FederatedRequest...
federatedRequests){
+ sendFederatedRequests(ec, fedMapObj, fedOutputID,
federatedSlices, null, federatedRequests);
+ }
+
+ /**
+ * Send federated requests and retrieve output if federated output flag
is set.
+ * @param ec execution context
+ * @param fedMapObj matrix object with federated mapping where
federated requests are sent to.
+ * @param fedOutputID ID of federated output
+ * @param federatedSlices1 federated requests for broadcasting slices
before processing instruction
+ * @param federatedSlices2 federated requests for broadcasting slices
before processing instruction
+ * @param federatedRequests federated requests for processing
instruction
+ */
+ private void sendFederatedRequests(ExecutionContext ec, MatrixObject
fedMapObj, long fedOutputID,
+ FederatedRequest[] federatedSlices1, FederatedRequest[]
federatedSlices2, FederatedRequest... federatedRequests){
+ if ( _federatedOutput ){
+ fedMapObj.getFedMapping().execute(getTID(), true,
federatedSlices1, federatedSlices2, federatedRequests);
+ setOutputFedMapping(ec, fedMapObj, fedOutputID);
+ } else {
+ processAndRetrieve(ec, fedMapObj, fedOutputID,
federatedSlices1, federatedSlices2, federatedRequests);
}
- setOutputFedMapping(ec, mo1, fr2.getID());
+ }
+
+ /**
+ * Process instruction and get output from federated workers.
+ * @param ec execution context
+ * @param fedMapObj matrix object with federated mapping where
federated requests are sent to.
+ * @param fedOutputID ID of federated output
+ * @param federatedSlices1 federated requests for broadcasting slices
before processing instruction
+ * @param federatedSlices2 federated requests for broadcasting slices
before processing instruction
+ * @param federatedRequests federated requests for processing
instruction
+ */
+ private void processAndRetrieve(ExecutionContext ec, MatrixObject
fedMapObj, long fedOutputID,
+ FederatedRequest[] federatedSlices1, FederatedRequest[]
federatedSlices2, FederatedRequest... federatedRequests){
+ FederatedRequest getRequest = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fedOutputID);
+ Future<FederatedResponse>[] executionResponse =
fedMapObj.getFedMapping().execute(
+ getTID(), true, federatedSlices1, federatedSlices2,
collectRequests(federatedRequests, getRequest));
+ ec.setMatrixOutput(output.getName(),
FederationUtils.bind(executionResponse,
+ fedMapObj.isFederated(FederationMap.FType.COL)));
+ }
+
+ /**
+ * Collect federated requests into a single array of federated requests.
+ * The federated requests are added in the same order as the parameters
of this method.
+ * @param fedRequests array of federated requests
+ * @param fedRequest1 federated request to occur after array
+ * @return federated requests collected in a single array
+ */
+ private FederatedRequest[] collectRequests(FederatedRequest[]
fedRequests, FederatedRequest fedRequest1){
+ FederatedRequest[] allRequests = new
FederatedRequest[fedRequests.length + 1];
+ for ( int i = 0; i < fedRequests.length; i++ )
+ allRequests[i] = fedRequests[i];
+ allRequests[allRequests.length-1] = fedRequest1;
+ return allRequests;
}
private void processMatrixInput(ExecutionContext ec, MatrixObject mo1,
MatrixObject mo2, MatrixObject mo3) {
@@ -139,13 +219,13 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
if(retAlignedValues._allAligned) {
fr3 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1, input2, input3},
new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
- mo1.getFedMapping().execute(getTID(), fr3);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr3);
}
// 2 fed aligned inputs
else if(retAlignedValues._twoAligned) {
fr3 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1, input2, input3}, retAlignedValues._vars);
fr4 = mo1.getFedMapping().cleanup(getTID(),
retAlignedValues._fr[0].getID());
- mo1.getFedMapping().execute(getTID(), true,
retAlignedValues._fr, fr3, fr4);
+ sendFederatedRequests(ec, mo1, fr3.getID(),
retAlignedValues._fr, fr3, fr4);
}
// 1 fed input or not aligned
else {
@@ -169,11 +249,8 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
fr3 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1, input2, input3}, vars);
fr4 = mo1.getFedMapping().cleanup(getTID(),
fr1[0].getID(), fr2[0].getID());
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2,
fr3, fr4);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr1, fr2,
fr3, fr4);
}
-
- //derive new fed mapping for output
- setOutputFedMapping(ec, mo1, fr3.getID());
}
/**
@@ -225,9 +302,14 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
}
}
+ /**
+ * Set fed mapping of output. The data characteristics are not set.
+ * @param ec execution context
+ * @param fedMapObj federated matrix object from which federated
mapping is derived
+ * @param fedOutputID ID for the fed mapping of output
+ */
private void setOutputFedMapping(ExecutionContext ec, MatrixObject
fedMapObj, long fedOutputID) {
MatrixObject out = ec.getMatrixObject(output);
-
out.getDataCharacteristics().set(fedMapObj.getDataCharacteristics());
out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index 348f157..db07122 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -26,6 +26,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
@@ -71,10 +72,15 @@ public class FederatedVarTest extends AutomatedTestBase {
@Test
public void testVarCP() {
- runAggregateOperationTest(ExecMode.SINGLE_NODE);
+ runAggregateOperationTest(ExecMode.SINGLE_NODE, false);
}
- private void runAggregateOperationTest(ExecMode execMode) {
+ @Test
+ public void testVarCPtoFED() {
+ runAggregateOperationTest(ExecMode.SINGLE_NODE, true);
+ }
+
+ private void runAggregateOperationTest(ExecMode execMode, boolean
federatedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -129,6 +135,7 @@ public class FederatedVarTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
+ OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-explain", "-stats", "100",
"-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
@@ -152,6 +159,7 @@ public class FederatedVarTest extends AutomatedTestBase {
TestUtils.shutdownThreads(t1, t2, t3, t4);
+ OptimizerUtils.FEDERATED_COMPILATION = false;
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
index edbe774..4f86483 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
[email protected]
public class FederatedLmCGTest extends AutomatedTestBase
{
private final static String TEST_NAME = "lmCGFederated";
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e4da423..7134347 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -39,6 +39,8 @@ import java.util.Collection;
public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/privacy/";
private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+ private final static String TEST_NAME_2 =
"FederatedMultiplyPlanningTest2";
+ private final static String TEST_NAME_3 =
"FederatedMultiplyPlanningTest3";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -51,6 +53,8 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
}
@Parameterized.Parameters
@@ -64,7 +68,19 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
@Test
public void federatedMultiplyCP() {
OptimizerUtils.FEDERATED_COMPILATION = true;
- federatedMultiply(Types.ExecMode.SINGLE_NODE);
+ federatedTwoMatricesSingleNodeTest(TEST_NAME);
+ }
+
+ @Test
+ public void federatedRowSum(){
+ OptimizerUtils.FEDERATED_COMPILATION = true;
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+ }
+
+ @Test
+ public void federatedTernarySequence(){
+ OptimizerUtils.FEDERATED_COMPILATION = true;
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
}
private void writeStandardMatrix(String matrixName, long seed){
@@ -75,7 +91,62 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
new
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
- public void federatedMultiply(Types.ExecMode execMode) {
+ public void federatedTwoMatricesSingleNodeTest(String testName){
+ federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+ }
+
+ public void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+
+ getAndLoadTestConfiguration(testName);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // Write input matrices
+ writeStandardMatrix("X1", 42);
+ writeStandardMatrix("X2", 1340);
+ writeStandardMatrix("Y1", 44);
+ writeStandardMatrix("Y2", 21);
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2);
+
+ TestConfiguration config =
availableTestConfigurations.get(testName);
+ loadTestConfiguration(config);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] {"-explain", "-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ OptimizerUtils.FEDERATED_COMPILATION = false;
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + testName + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ heavyHittersContainsString("fed_*", "fed_ba+*");
+
+ TestUtils.shutdownThreads(t1, t2);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+
+ public void federatedThreeMatricesTest(Types.ExecMode execMode, String
testName) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -83,7 +154,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- getAndLoadTestConfiguration(TEST_NAME);
+ getAndLoadTestConfiguration(testName);
String HOME = SCRIPT_DIR + TEST_DIR;
// Write input matrices
@@ -91,17 +162,18 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
writeStandardMatrix("X2", 1340);
writeStandardMatrix("Y1", 44);
writeStandardMatrix("Y2", 21);
+ writeStandardMatrix("W1", 55);
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
- TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ TestConfiguration config =
availableTestConfigurations.get(testName);
loadTestConfiguration(config);
// Run actual dml script with federated matrix
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ fullDMLScriptName = HOME + testName + ".dml";
programArgs = new String[] {"-explain", "-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
@@ -111,7 +183,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
OptimizerUtils.FEDERATED_COMPILATION = false;
// Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ fullDMLScriptName = HOME + testName + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
runTest(true, false, null, -1);
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
new file mode 100644
index 0000000..1b999d0
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0),
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0),
list($r, $c)))
+Z0 = X * Y
+Z1 = rowSums(Z0)
+Z = t(Z1) %*% X
+write(Z, $Z)
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
new file mode 100644
index 0000000..3d92cee
--- /dev/null
+++
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+Z0 = X * Y
+Z1 = rowSums(Z0)
+Z = t(Z1) %*% X
+write(Z, $Z)
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
new file mode 100644
index 0000000..8e39a89
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
+W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
+s = 3.5
+Z0 = W + s * X
+Z1 = 1 - Y * Z0
+Z = sum(Z1)
+write(Z, $Z)
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
new file mode 100644
index 0000000..15ad848
--- /dev/null
+++
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+W = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, pdf='uniform', seed=5)
+s = 3.5
+Z0 = W + s * X
+Z1 = 1 - Y * Z0
+Z = sum(Z1)
+write(Z, $Z)