This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new dfc57dc  [SYSTEMDS-2604] Compile federated instructions and federated output flag
dfc57dc is described below

commit dfc57dc4310d1dcc816cd7c2b5e7b1393b9abb78
Author: sebwrede <[email protected]>
AuthorDate: Mon Mar 22 19:34:59 2021 +0100

    [SYSTEMDS-2604] Compile federated instructions and federated output flag
    
    This commit adds a federated output flag to some operations and adds tests
    for sequences of federated instructions with privacy constraints.
    The supported operations are AggUnary, Binary, and Ternary. Closes #1229.
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |   5 +-
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |   3 +-
 src/main/java/org/apache/sysds/hops/TernaryOp.java |   4 +
 src/main/java/org/apache/sysds/lops/Binary.java    |  14 ++-
 .../org/apache/sysds/lops/PartialAggregate.java    |  41 +++----
 src/main/java/org/apache/sysds/lops/Ternary.java   |  34 +++---
 .../runtime/instructions/FEDInstructionParser.java |  10 ++
 .../runtime/instructions/InstructionUtils.java     |  75 +++++++++++--
 .../fed/AggregateUnaryFEDInstruction.java          |  65 +++++++++--
 .../runtime/instructions/fed/FEDInstruction.java   |   2 +-
 .../instructions/fed/TernaryFEDInstruction.java    | 124 +++++++++++++++++----
 .../federated/algorithms/FederatedVarTest.java     |  12 +-
 .../test/functions/privacy/FederatedLmCGTest.java  |   1 +
 .../fedplanning/FederatedMultiplyPlanningTest.java |  84 +++++++++++++-
 .../privacy/FederatedMultiplyPlanningTest2.dml     |  29 +++++
 .../FederatedMultiplyPlanningTest2Reference.dml    |  27 +++++
 .../privacy/FederatedMultiplyPlanningTest3.dml     |  31 ++++++
 .../FederatedMultiplyPlanningTest3Reference.dml    |  29 +++++
 18 files changed, 492 insertions(+), 98 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index f842502..5d54535 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -59,6 +59,7 @@ public class AggUnaryOp extends MultiThreadedHop
                _direction = idx;
                getInput().add(0, inp);
                inp.getParent().add(this);
+               updateETFed();
        }
 
        @Override
@@ -151,13 +152,15 @@ public class AggUnaryOp extends MultiThreadedHop
                                        agg1 = new 
PartialAggregate(input.constructLops(),
                                                        _op, _direction, 
getDataType(),getValueType(), et, k);
                                }
-                               
+
                                setOutputDimensions(agg1);
                                setLineNumbers(agg1);
                                setLops(agg1);
                                
                                if (getDataType() == DataType.SCALAR) {
                                        
agg1.getOutputParameters().setDimensions(1, 1, getBlocksize(), getNnz());
+                               } else {
+                                       setFederatedOutput(agg1);
                                }
                        }
                        else if( et == ExecType.SPARK )
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 36cb051..897f707 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -224,6 +224,8 @@ public class BinaryOp extends MultiThreadedHop
                                constructLopsBinaryDefault();
                }
 
+               setFederatedOutput(getLops());
+
                //add reblock/checkpoint lops if necessary
                constructAndSetLopsDataFlowProperties();
 
@@ -464,7 +466,6 @@ public class BinaryOp extends MultiThreadedHop
                                                op, getDataType(), 
getValueType(), et,
                                                
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
 
-                               setFederatedOutput(binary);
                                setOutputDimensions(binary);
                                setLineNumbers(binary);
                                setLops(binary);
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java 
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 6f5a55b..3a8d02b 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -80,6 +80,7 @@ public class TernaryOp extends MultiThreadedHop
                getInput().add(0, inp1);
                getInput().add(1, inp2);
                getInput().add(2, inp3);
+               updateETFed();
                inp1.getParent().add(this);
                inp2.getParent().add(this);
                inp3.getParent().add(this);
@@ -97,6 +98,7 @@ public class TernaryOp extends MultiThreadedHop
                getInput().add(3, inp4);
                getInput().add(4, inp5);
                getInput().add(5, inp6);
+               updateETFed();
                inp1.getParent().add(this);
                inp2.getParent().add(this);
                inp3.getParent().add(this);
@@ -193,6 +195,8 @@ public class TernaryOp extends MultiThreadedHop
                catch(LopsException e) {
                        throw new HopsException(this.printErrorLocation() + 
"error constructing Lops for TernaryOp Hop " , e);
                }
+
+               setFederatedOutput(getLops());
                
                //add reblock/checkpoint lops if necessary
                constructAndSetLopsDataFlowProperties();
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java 
b/src/main/java/org/apache/sysds/lops/Binary.java
index 84bd033..5ba77bb 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -81,17 +81,19 @@ public class Binary extends Lop
 
        @Override
        public String getInstructions(String input1, String input2, String 
output) {
-               String baseInstruction = InstructionUtils.concatOperands(
+               InstructionUtils.concatBaseOperands(
                        getExecType().name(), getOpcode(),
                        getInputs().get(0).prepInputOperand(input1),
                        getInputs().get(1).prepInputOperand(input2),
                        prepOutputOperand(output)
                );
 
-               if( getExecType() == ExecType.CP || (!federatedOutput && 
getExecType() == ExecType.FED) )
-                       return InstructionUtils.concatOperands(baseInstruction, 
String.valueOf(_numThreads));
-               else if ( getExecType() == ExecType.FED )
-                       return InstructionUtils.concatOperands(baseInstruction, 
String.valueOf(_numThreads), String.valueOf(federatedOutput));
-               else return baseInstruction;
+               if ( getExecType() == ExecType.CP || getExecType() == 
ExecType.FED){
+                       
InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
+                       if ( federatedOutput )
+                               
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
+               }
+
+               return InstructionUtils.getInstructionString();
        }
 }
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index c291782..118a804 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -27,7 +27,7 @@ import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.hops.HopsException;
 import org.apache.sysds.lops.LopProperties.ExecType;
-
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 
 /**
  * Lop to perform a partial aggregation. It was introduced to do some initial
@@ -217,33 +217,22 @@ public class PartialAggregate extends Lop
        @Override
        public String getInstructions(String input1, String output) 
        {
-               StringBuilder sb = new StringBuilder();
-               sb.append( getExecType() );
-               
-               sb.append( OPERAND_DELIMITOR );
-               sb.append( getOpcode() );
-               
-               sb.append( OPERAND_DELIMITOR );
-               sb.append( getInputs().get(0).prepInputOperand(input1) );
-               
-               sb.append( OPERAND_DELIMITOR );
-               sb.append( prepOutputOperand(output) );
-               
-               //exec-type specific attributes
-               sb.append( OPERAND_DELIMITOR );
-               if( getExecType() == ExecType.SPARK )
-                       sb.append( _aggtype );
-               else if( getExecType() == ExecType.CP || getExecType() == 
ExecType.FED ) {
-                       sb.append(_numThreads);
+               InstructionUtils.concatBaseOperands(
+                       getExecType().name(),
+                       getOpcode(),
+                       getInputs().get(0).prepInputOperand(input1),
+                       prepOutputOperand(output));
 
-                       //number of outputs, valid for fed instruction
-                       if(getOpcode().equalsIgnoreCase("uarimin") || 
getOpcode().equalsIgnoreCase("uarimax")) {
-                               sb.append(OPERAND_DELIMITOR);
-                               sb.append("1");
-                       }
+               if ( getExecType() == ExecType.SPARK )
+                       
InstructionUtils.concatAdditionalOperand(_aggtype.toString());
+               else if ( getExecType() == ExecType.CP || getExecType() == 
ExecType.FED ){
+                       
InstructionUtils.concatAdditionalOperand(Integer.toString(_numThreads));
+                       if ( getOpcode().equalsIgnoreCase("uarimin") || 
getOpcode().equalsIgnoreCase("uarimax") )
+                               InstructionUtils.concatAdditionalOperand("1");
+                       if ( getExecType() == ExecType.FED && operation != 
AggOp.VAR )
+                               
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
                }
-               
-               return sb.toString();
+               return InstructionUtils.getInstructionString();
        }
 
        public static String getOpcode(AggOp op, Direction dir)
diff --git a/src/main/java/org/apache/sysds/lops/Ternary.java 
b/src/main/java/org/apache/sysds/lops/Ternary.java
index a1a2d53..a6ad9d2 100644
--- a/src/main/java/org/apache/sysds/lops/Ternary.java
+++ b/src/main/java/org/apache/sysds/lops/Ternary.java
@@ -25,7 +25,7 @@ import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.common.Types.ValueType;
-
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 
 /**
  * Lop to perform Sum of a matrix with another matrix multiplied by Scalar.
@@ -59,25 +59,19 @@ public class Ternary extends Lop
        
        @Override
        public String getInstructions(String input1, String input2, String 
input3, String output)  {
-               StringBuilder sb = new StringBuilder();
-               sb.append( getExecType() );
-               sb.append( OPERAND_DELIMITOR );
-               sb.append( _op.toString() );
-               
-               //process three operands and output
-               String[] inputs = new String[]{input1, input2, input3};
-               for( int i=0; i<3; i++ ) {
-                       sb.append( OPERAND_DELIMITOR );
-                       sb.append( 
getInputs().get(i).prepInputOperand(inputs[i]) );
-               }
-               sb.append( OPERAND_DELIMITOR );
-               sb.append( prepOutputOperand(output) );
-               
-               if( getExecType() == ExecType.CP && getDataType().isMatrix() ) {
-                       sb.append( OPERAND_DELIMITOR );
-                       sb.append( _numThreads );
+               InstructionUtils.concatOperands(
+                       getExecType().name(),
+                       _op.toString(),
+                       getInputs().get(0).prepInputOperand(input1),
+                       getInputs().get(1).prepInputOperand(input2),
+                       getInputs().get(2).prepInputOperand(input3),
+                       prepOutputOperand(output)
+               );
+               if( (getExecType() == ExecType.CP || getExecType() == 
ExecType.FED ) && getDataType().isMatrix() ){
+                       
InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
+                       if ( federatedOutput )
+                               
InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
                }
-               
-               return sb.toString();
+               return InstructionUtils.getInstructionString();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index e6f430d..34db155 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -27,6 +27,7 @@ import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.TernaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.TsmmFEDInstruction;
 
 import java.util.HashMap;
@@ -46,18 +47,25 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "uasqk+"  , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uarsqk+" , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uacsqk+" , 
FEDType.AggregateUnary );
+               String2FEDInstructionType.put( "uavar"   , 
FEDType.AggregateUnary);
+               String2FEDInstructionType.put( "uarvar"  , 
FEDType.AggregateUnary);
+               String2FEDInstructionType.put( "uacvar"  , 
FEDType.AggregateUnary);
 
                // Arithmetic Instruction Opcodes
                String2FEDInstructionType.put( "+" , FEDType.Binary );
                String2FEDInstructionType.put( "-" , FEDType.Binary );
                String2FEDInstructionType.put( "*" , FEDType.Binary );
                String2FEDInstructionType.put( "/" , FEDType.Binary );
+               String2FEDInstructionType.put( "1-*" , FEDType.Binary); 
//special * case
 
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
                String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
                String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
 
+               // Ternary Instruction Opcodes
+               String2FEDInstructionType.put( "+*" , FEDType.Ternary);
+               String2FEDInstructionType.put( "-*" , FEDType.Ternary);
        }
 
        public static FEDInstruction parseSingleInstruction (String str ) {
@@ -86,6 +94,8 @@ public class FEDInstructionParser extends InstructionParser
                                return TsmmFEDInstruction.parseInstruction(str);
                        case Binary:
                                return 
BinaryFEDInstruction.parseInstruction(str);
+                       case Ternary:
+                               return 
TernaryFEDInstruction.parseInstruction(str);
                        case Reorg:
                                return 
ReorgFEDInstruction.parseInstruction(str);
                        default:
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index fb60d6b..7d386ec 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1033,16 +1033,78 @@ public class InstructionUtils
                parts[parts.length-1] = newName;
                return concatOperands(parts);
        }
-       
+
+       /**
+        * Concat the inputs as operands to generate the instruction string.
+        * The inputs are separated by the operand delimiter and appended
+        * using a ThreadLocal StringBuilder.
+        * @param inputs operand inputs given as strings
+        * @return the instruction string with the given inputs concatenated
+        */
        public static String concatOperands(String... inputs) {
-               return concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+               concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+               return _strBuilders.get().toString();
        }
-       
+
+       /**
+        * Concat the input parts with the value type delimiter.
+        * @param inputs input operand parts as strings
+        * @return concatenated input parts
+        */
        public static String concatOperandParts(String... inputs) {
-               return concatOperandsWithDelim(Instruction.VALUETYPE_PREFIX, 
inputs);
+               concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX, 
inputs);
+               return _strBuilders.get().toString();
        }
-       
-       private static String concatOperandsWithDelim(String delim, String... 
inputs) {
+
+       /**
+        * Concat the inputs as operands to generate the base instruction 
string.
+        * The base instruction string can subsequently be extended with the
+        * concatAdditional methods. The concatenation will be done using a
+        * ThreadLocal StringBuilder, so the concatenation is local to the 
thread.
+        * When all additional operands have been appended, the complete 
instruction
+        * string can be retrieved by calling the getInstructionString method.
+        * @param inputs operand inputs given as strings
+        */
+       public static void concatBaseOperands(String... inputs){
+               concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+       }
+
+       /**
+        * Concat input as an additional operand to the current thread-local 
base instruction string.
+        * @param input operand input given as string
+        */
+       public static void concatAdditionalOperand(String input){
+               StringBuilder sb = _strBuilders.get();
+               sb.append(Lop.OPERAND_DELIMITOR);
+               sb.append(input);
+       }
+
+       /**
+        * Concat inputs as additional operands to the current thread-local 
base instruction string.
+        * @param inputs operand inputs given as strings
+        */
+       public static void concatAdditionalOperands(String... inputs){
+               concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
+       }
+
+       /**
+        * Returns the current thread-local instruction string.
+        * This instruction string is built using the concat methods.
+        * @return instruction string
+        */
+       public static String getInstructionString(){
+               return _strBuilders.get().toString();
+       }
+
+       private static void concatOperandsWithDelim(String delim, String... 
inputs){
+               StringBuilder sb = _strBuilders.get();
+               for( int i=0; i<inputs.length; i++ ) {
+                       sb.append(delim);
+                       sb.append(inputs[i]);
+               }
+       }
+
+       private static void concatBaseOperandsWithDelim(String delim, String... 
inputs){
                StringBuilder sb = _strBuilders.get();
                sb.setLength(0); //reuse allocated space
                for( int i=0; i<inputs.length-1; i++ ) {
@@ -1050,7 +1112,6 @@ public class InstructionUtils
                        sb.append(delim);
                }
                sb.append(inputs[inputs.length-1]);
-               return sb.toString();
        }
        
        public static String concatStrings(String... inputs) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 5745ccd..da68d07 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.concurrent.Future;
 
 import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -35,11 +36,17 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
-       
+
        private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop,
-               CPOperand in, CPOperand out, String opcode, String istr)
+               CPOperand in, CPOperand out, String opcode, String istr, 
boolean federatedOutput)
        {
-               super(FEDType.AggregateUnary, auop, in, out, opcode, istr);
+               super(FEDType.AggregateUnary, auop, in, out, opcode, istr, 
federatedOutput);
+       }
+
+       protected AggregateUnaryFEDInstruction(Operator op,
+               CPOperand in1, CPOperand in2, CPOperand out, String opcode, 
String istr, boolean federatedOutput)
+       {
+               super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr, 
federatedOutput);
        }
 
        protected AggregateUnaryFEDInstruction(Operator op,
@@ -68,7 +75,13 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
 
                if(InstructionUtils.getExecType(str) == ExecType.SPARK)
                        str = InstructionUtils.replaceOperand(str, 4, "-1");
-               return new AggregateUnaryFEDInstruction(aggun, in1, out, 
opcode, str);
+
+               boolean federatedOutput = false;
+               if ( parts.length > 6 )
+                       federatedOutput = Boolean.parseBoolean(parts[5]);
+               else if ( parts.length == 5 && !parts[4].equals("uarimin") && 
!parts[4].equals("uarimax") )
+                       federatedOutput = Boolean.parseBoolean(parts[4]);
+               return new AggregateUnaryFEDInstruction(aggun, in1, out, 
opcode, str, federatedOutput);
        }
        
        @Override
@@ -90,20 +103,58 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                        instString = 
InstructionUtils.replaceOperand(instString, 5, "2");
 
                //create federated commands for aggregation
+
+               if ( _federatedOutput )
+                       processFederatedOutput(map, in, ec);
+               else
+                       processGetOutput(map, aop, ec, in);
+       }
+
+       /**
+        * Sends federated request with instruction without retrieving the 
result from the workers.
+        * @param map federation map of the input
+        * @param in input matrix object
+        * @param ec execution context
+        */
+       private void processFederatedOutput(FederationMap map, MatrixObject in, 
ExecutionContext ec){
+               if ( output.isScalar() )
+                       throw new DMLRuntimeException("Output of FED 
instruction, " + output.toString()
+                               + ", is a scalar and the output is set to be 
federated. Scalars cannot be federated. ");
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                       new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()}, _federatedOutput);
+               map.execute(getTID(), fr1);
+
+               // derive new fed mapping for output
+               MatrixObject out = ec.getMatrixObject(output);
+               
out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+       }
+
+       /**
+        * Sends federated request with instruction and retrieves the result 
from the workers.
+        * @param map federation map of input
+        * @param aggUOptr aggregate unary operator of the instruction
+        * @param ec execution context
+        * @param in input matrix object
+        */
+       private void processGetOutput(FederationMap map, AggregateUnaryOperator 
aggUOptr, ExecutionContext ec, MatrixObject in){
                FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
                        new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()});
                FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
                FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
-               
+
                //execute federated commands and cleanups
                Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, 
fr2, fr3);
                if( output.isScalar() )
-                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, tmp, map));
+                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aggUOptr, tmp, map));
                else
-                       ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, tmp, map));
+                       ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aggUOptr, tmp, map));
        }
 
        private void processVar(ExecutionContext ec){
+               if ( _federatedOutput ){
+                       throw new DMLRuntimeException("Output of " + toString() 
+ " should not be federated "
+                               + "since the instruction requires consolidation 
of partial results to be computed.");
+               }
                AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
                MatrixObject in = ec.getMatrixObject(input1);
                FederationMap map = in.getFedMapping();
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index c75c798..a91cb0c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -38,6 +38,7 @@ public abstract class FEDInstruction extends Instruction {
                MMChain,
                MatrixIndexing,
                Ternary,
+               Tsmm,
                ParameterizedBuiltin,
                Quaternary,
                QSort,
@@ -45,7 +46,6 @@ public abstract class FEDInstruction extends Instruction {
                Reorg,
                Reshape,
                SpoofFused,
-               Tsmm,
                Unary
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 02791b9..5b52dbe 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -20,12 +20,15 @@
 package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.Objects;
+import java.util.concurrent.Future;
 
 import com.sun.tools.javac.util.List;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -33,8 +36,9 @@ import 
org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 
 public class TernaryFEDInstruction extends ComputationFEDInstruction {
 
-       private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
-               super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, 
opcode, str);
+       private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
+               String opcode, String str, boolean federatedOutput) {
+               super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, 
opcode, str, federatedOutput);
        }
 
        public static TernaryFEDInstruction parseInstruction(String str) {
@@ -44,8 +48,10 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                CPOperand operand2 = new CPOperand(parts[2]);
                CPOperand operand3 = new CPOperand(parts[3]);
                CPOperand outOperand = new CPOperand(parts[4]);
-               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode);
-               return new TernaryFEDInstruction(op, operand1, operand2, 
operand3, outOperand, opcode, str);
+               int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 
1;
+               boolean federatedOutput = parts.length > 6 && 
parts[6].equals("true");
+               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode, numThreads);
+               return new TernaryFEDInstruction(op, operand1, operand2, 
operand3, outOperand, opcode, str, federatedOutput);
        }
 
        @Override
@@ -87,9 +93,7 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
 
        private void processMatrixScalarInput(ExecutionContext ec, MatrixObject 
mo1, CPOperand in) {
                FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {in}, new 
long[] {mo1.getFedMapping().getID()});
-               mo1.getFedMapping().execute(getTID(), true, fr1);
-
-               setOutputFedMapping(ec, mo1, fr1.getID());
+               sendFederatedRequests(ec, mo1, fr1.getID(), fr1);
        }
 
        private void process2MatrixScalarInput(ExecutionContext ec, 
MatrixObject mo1, MatrixObject mo2, CPOperand in1, CPOperand in2) {
@@ -112,19 +116,95 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                        varNewIn = new long[]{fr1[0].getID(), 
mo1.getFedMapping().getID()};
                }
                FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
-               FederatedRequest fr3;
 
                // 2 aligned inputs
                if(fr1 == null) {
-                       mo1.getFedMapping().execute(getTID(), true, fr2);
+                       sendFederatedRequests(ec, mo1, fr2.getID(), fr2);
                } else {
                        if(cleanupIn) {
-                               fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID());
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
-                       } else
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+                               sendFederatedRequests(ec, mo1, fr2.getID(), 
fr1, fr2, fr3);
+                       }
+                       else
+                               sendFederatedRequests(ec, mo1, fr2.getID(), 
fr1, fr2);
+               }
+       }
+
+       /**
+        * Send federated requests and retrieve output if federated output flag 
is set.
+        * @param ec execution context
+        * @param fedMapObj matrix object with federated mapping where 
federated requests are sent to.
+        * @param fedOutputID ID of federated output
+        * @param federatedRequests federated requests for processing 
instruction
+        */
+       private void sendFederatedRequests(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID,
+               FederatedRequest... federatedRequests){
+               sendFederatedRequests(ec, fedMapObj, fedOutputID, null, null, 
federatedRequests);
+       }
+
+       /**
+        * Send federated requests and retrieve output if federated output flag 
is set.
+        * @param ec execution context
+        * @param fedMapObj matrix object with federated mapping where 
federated requests are sent to.
+        * @param fedOutputID ID of federated output
+        * @param federatedSlices federated requests for broadcasting slices 
before processing instruction
+        * @param federatedRequests federated requests for processing 
instruction
+        */
+       private void sendFederatedRequests(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID,
+               FederatedRequest[] federatedSlices, FederatedRequest... 
federatedRequests){
+               sendFederatedRequests(ec, fedMapObj, fedOutputID, 
federatedSlices, null, federatedRequests);
+       }
+
+       /**
+        * Send federated requests and retrieve output if federated output flag 
is set.
+        * @param ec execution context
+        * @param fedMapObj matrix object with federated mapping where 
federated requests are sent to.
+        * @param fedOutputID ID of federated output
+        * @param federatedSlices1 federated requests for broadcasting slices 
before processing instruction
+        * @param federatedSlices2 federated requests for broadcasting slices 
before processing instruction
+        * @param federatedRequests federated requests for processing 
instruction
+        */
+       private void sendFederatedRequests(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID,
+               FederatedRequest[] federatedSlices1, FederatedRequest[] 
federatedSlices2, FederatedRequest... federatedRequests){
+               if ( _federatedOutput ){
+                       fedMapObj.getFedMapping().execute(getTID(), true, 
federatedSlices1, federatedSlices2, federatedRequests);
+                       setOutputFedMapping(ec, fedMapObj, fedOutputID);
+               } else {
+                       processAndRetrieve(ec, fedMapObj, fedOutputID, 
federatedSlices1, federatedSlices2, federatedRequests);
                }
-               setOutputFedMapping(ec, mo1, fr2.getID());
+       }
+
+       /**
+        * Process instruction and get output from federated workers.
+        * @param ec execution context
+        * @param fedMapObj matrix object with federated mapping where 
federated requests are sent to.
+        * @param fedOutputID ID of federated output
+        * @param federatedSlices1 federated requests for broadcasting slices 
before processing instruction
+        * @param federatedSlices2 federated requests for broadcasting slices 
before processing instruction
+        * @param federatedRequests federated requests for processing 
instruction
+        */
+       private void processAndRetrieve(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID,
+               FederatedRequest[] federatedSlices1, FederatedRequest[] 
federatedSlices2, FederatedRequest... federatedRequests){
+               FederatedRequest getRequest = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fedOutputID);
+               Future<FederatedResponse>[] executionResponse = 
fedMapObj.getFedMapping().execute(
+                       getTID(), true, federatedSlices1, federatedSlices2, 
collectRequests(federatedRequests, getRequest));
+               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(executionResponse,
+                       fedMapObj.isFederated(FederationMap.FType.COL)));
+       }
+
+       /**
+        * Collect federated requests into a single array of federated requests.
+        * The federated requests are added in the same order as the parameters 
of this method.
+        * @param fedRequests array of federated requests
+        * @param fedRequest1 federated request to occur after array
+        * @return federated requests collected in a single array
+        */
+       private FederatedRequest[] collectRequests(FederatedRequest[] 
fedRequests, FederatedRequest fedRequest1){
+               FederatedRequest[] allRequests = new 
FederatedRequest[fedRequests.length + 1];
+               for ( int i = 0; i < fedRequests.length; i++ )
+                       allRequests[i] = fedRequests[i];
+               allRequests[allRequests.length-1] = fedRequest1;
+               return allRequests;
        }
 
        private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3) {
@@ -139,13 +219,13 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                if(retAlignedValues._allAligned) {
                        fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3},
                                new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
-                       mo1.getFedMapping().execute(getTID(), fr3);
+                       sendFederatedRequests(ec, mo1, fr3.getID(), fr3);
                }
                // 2 fed aligned inputs
                else if(retAlignedValues._twoAligned) {
                        fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3}, retAlignedValues._vars);
                        fr4 = mo1.getFedMapping().cleanup(getTID(), 
retAlignedValues._fr[0].getID());
-                       mo1.getFedMapping().execute(getTID(), true, 
retAlignedValues._fr, fr3, fr4);
+                       sendFederatedRequests(ec, mo1, fr3.getID(), 
retAlignedValues._fr, fr3, fr4);
                }
                // 1 fed input or not aligned
                else {
@@ -169,11 +249,8 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
 
                        fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3}, vars);
                        fr4 = mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID(), fr2[0].getID());
-                       mo1.getFedMapping().execute(getTID(), true, fr1, fr2, 
fr3, fr4);
+                       sendFederatedRequests(ec, mo1, fr3.getID(), fr1, fr2, 
fr3, fr4);
                }
-
-               //derive new fed mapping for output
-               setOutputFedMapping(ec, mo1, fr3.getID());
        }
 
        /**
@@ -225,9 +302,14 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                }
        }
 
+       /**
+        * Set fed mapping of output. The data characteristics are not set.
+        * @param ec execution context
+        * @param fedMapObj federated matrix object from which federated 
mapping is derived
+        * @param fedOutputID ID for the fed mapping of output
+        */
        private void setOutputFedMapping(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID) {
                MatrixObject out = ec.getMatrixObject(output);
-               
out.getDataCharacteristics().set(fedMapObj.getDataCharacteristics());
                
out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index 348f157..db07122 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -26,6 +26,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -71,10 +72,15 @@ public class FederatedVarTest extends AutomatedTestBase {
 
        @Test
        public void testVarCP() {
-               runAggregateOperationTest(ExecMode.SINGLE_NODE);
+               runAggregateOperationTest(ExecMode.SINGLE_NODE, false);
        }
 
-       private void runAggregateOperationTest(ExecMode execMode) {
+       @Test
+       public void testVarCPtoFED() {
+               runAggregateOperationTest(ExecMode.SINGLE_NODE, true);
+       }
+
+       private void runAggregateOperationTest(ExecMode execMode, boolean 
federatedCompilation) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -129,6 +135,7 @@ public class FederatedVarTest extends AutomatedTestBase {
 
                // Run actual dml script with federated matrix
 
+               OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-explain", "-stats", "100", 
"-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
@@ -152,6 +159,7 @@ public class FederatedVarTest extends AutomatedTestBase {
 
                TestUtils.shutdownThreads(t1, t2, t3, t4);
 
+               OptimizerUtils.FEDERATED_COMPILATION = false;
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
index edbe774..4f86483 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
[email protected]
 public class FederatedLmCGTest extends AutomatedTestBase
 {
        private final static String TEST_NAME = "lmCGFederated";
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e4da423..7134347 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -39,6 +39,8 @@ import java.util.Collection;
 public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
        private final static String TEST_DIR = "functions/privacy/";
        private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+       private final static String TEST_NAME_2 = 
"FederatedMultiplyPlanningTest2";
+       private final static String TEST_NAME_3 = 
"FederatedMultiplyPlanningTest3";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
@@ -51,6 +53,8 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
        }
 
        @Parameterized.Parameters
@@ -64,7 +68,19 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        @Test
        public void federatedMultiplyCP() {
                OptimizerUtils.FEDERATED_COMPILATION = true;
-               federatedMultiply(Types.ExecMode.SINGLE_NODE);
+               federatedTwoMatricesSingleNodeTest(TEST_NAME);
+       }
+
+       @Test
+       public void federatedRowSum(){
+               OptimizerUtils.FEDERATED_COMPILATION = true;
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+       }
+
+       @Test
+       public void federatedTernarySequence(){
+               OptimizerUtils.FEDERATED_COMPILATION = true;
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
        }
 
        private void writeStandardMatrix(String matrixName, long seed){
@@ -75,7 +91,62 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                        new 
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
        }
 
-       public void federatedMultiply(Types.ExecMode execMode) {
+       public void federatedTwoMatricesSingleNodeTest(String testName){
+               federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+       }
+
+       public void federatedTwoMatricesTest(Types.ExecMode execMode, String 
testName) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+
+               getAndLoadTestConfiguration(testName);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // Write input matrices
+               writeStandardMatrix("X1", 42);
+               writeStandardMatrix("X2", 1340);
+               writeStandardMatrix("Y1", 44);
+               writeStandardMatrix("Y2", 21);
+
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2);
+
+               TestConfiguration config = 
availableTestConfigurations.get(testName);
+               loadTestConfiguration(config);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + testName + ".dml";
+               programArgs = new String[] {"-explain", "-nvargs", "X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                       "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                       "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), 
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               runTest(true, false, null, -1);
+
+               OptimizerUtils.FEDERATED_COMPILATION = false;
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + testName + "Reference.dml";
+               programArgs = new String[] {"-nvargs", "X1=" + input("X1"), 
"X2=" + input("X2"), "Y1=" + input("Y1"),
+                       "Y2=" + input("Y2"), "Z=" + expected("Z")};
+               runTest(true, false, null, -1);
+
+               // compare via files
+               compareResults(1e-9);
+               heavyHittersContainsString("fed_*", "fed_ba+*");
+
+               TestUtils.shutdownThreads(t1, t2);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+
+       public void federatedThreeMatricesTest(Types.ExecMode execMode, String 
testName) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
                rtplatform = execMode;
@@ -83,7 +154,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                }
 
-               getAndLoadTestConfiguration(TEST_NAME);
+               getAndLoadTestConfiguration(testName);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
                // Write input matrices
@@ -91,17 +162,18 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                writeStandardMatrix("X2", 1340);
                writeStandardMatrix("Y1", 44);
                writeStandardMatrix("Y2", 21);
+               writeStandardMatrix("W1", 55);
 
                int port1 = getRandomAvailablePort();
                int port2 = getRandomAvailablePort();
                Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
                Thread t2 = startLocalFedWorkerThread(port2);
 
-               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               TestConfiguration config = 
availableTestConfigurations.get(testName);
                loadTestConfiguration(config);
 
                // Run actual dml script with federated matrix
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               fullDMLScriptName = HOME + testName + ".dml";
                programArgs = new String[] {"-explain", "-nvargs", "X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
                        "X2=" + TestUtils.federatedAddress(port2, input("X2")),
                        "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
@@ -111,7 +183,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                OptimizerUtils.FEDERATED_COMPILATION = false;
 
                // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               fullDMLScriptName = HOME + testName + "Reference.dml";
                programArgs = new String[] {"-nvargs", "X1=" + input("X1"), 
"X2=" + input("X2"), "Y1=" + input("Y1"),
                        "Y2=" + input("Y2"), "Z=" + expected("Z")};
                runTest(true, false, null, -1);
diff --git 
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
new file mode 100644
index 0000000..1b999d0
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), 
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+              ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), 
list($r, $c)))
+Z0 = X * Y
+Z1 = rowSums(Z0)
+Z = t(Z1) %*% X
+write(Z, $Z)
diff --git 
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
new file mode 100644
index 0000000..3d92cee
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+Z0 = X * Y
+Z1 = rowSums(Z0)
+Z = t(Z1) %*% X
+write(Z, $Z)
diff --git 
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
new file mode 100644
index 0000000..8e39a89
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
+W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
+s = 3.5
+Z0 = W + s * X
+Z1 = 1 - Y * Z0
+Z = sum(Z1)
+write(Z, $Z)
diff --git 
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
new file mode 100644
index 0000000..15ad848
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+W = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, pdf='uniform', seed=5)
+s = 3.5
+Z0 = W + s * X
+Z1 = 1 - Y * Z0
+Z = sum(Z1)
+write(Z, $Z)

Reply via email to