This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cb61dc7  [SYSTEMDS-3018] Federated Planning Operator Support
cb61dc7 is described below

commit cb61dc74526655fb7bf9c6db4b857b6f6a070230
Author: sebwrede <[email protected]>
AuthorDate: Tue Mar 8 12:31:01 2022 +0100

    [SYSTEMDS-3018] Federated Planning Operator Support
    
    Closes #1557.
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |   2 +
 src/main/java/org/apache/sysds/hops/Hop.java       |  11 +-
 .../org/apache/sysds/lops/TernaryAggregate.java    |   8 +-
 .../runtime/instructions/FEDInstructionParser.java |   5 +
 .../fed/AggregateTernaryFEDInstruction.java        |  18 ++-
 .../fedplanning/FederatedL2SVMPlanningTest.java    | 144 +++++++++++++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest.java |  78 ++++++-----
 .../fedplanning/FederatedL2SVMPlanningTest.dml     | 132 +++++++++++++++++++
 .../FederatedL2SVMPlanningTestReference.dml        | 131 +++++++++++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest9.dml |  35 +++++
 .../FederatedMultiplyPlanningTest9Reference.dml    |  33 +++++
 11 files changed, 554 insertions(+), 43 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 078c754..31309c1 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -668,11 +668,13 @@ public class AggBinaryOp extends MultiThreadedHop {
                                new Transform(lY, ReOrgOp.TRANS, getDataType(), 
getValueType(), ExecType.CP, k);
                tY.getOutputParameters().setDimensions(Y.getDim2(), 
Y.getDim1(), getBlocksize(), Y.getNnz());
                setLineNumbers(tY);
+               updateLopFedOut(tY);
                
                //matrix mult
                Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(), 
getValueType(), ExecType.CP, k);
                mult.getOutputParameters().setDimensions(Y.getDim2(), 
X.getDim2(), getBlocksize(), getNnz());
                setLineNumbers(mult);
+               updateLopFedOut(mult);
                
                //result transpose (dimensions set outside)
                Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), 
getValueType(), ExecType.CP, k);
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 003492f..f91fad9 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -375,10 +375,19 @@ public abstract class Hop implements ParseInfo {
        public boolean requiresLineageCaching() {
                return _requiresLineageCaching;
        }
+
+       public void updateLopFedOut(Lop lop){
+               updateLopFedOut(lop, getExecType(), _federatedOutput);
+       }
+
+       public void updateLopFedOut(Lop lop, ExecType execType, FederatedOutput 
fedOut){
+               if ( execType == ExecType.FED )
+                       lop.setFederatedOutput(fedOut);
+       }
        
        public void constructAndSetLopsDataFlowProperties() {
                //propagate federated output configuration to lops
-               if( isFederated() )
+               if( isFederated() || getLops().getExecType() == ExecType.FED )
                        getLops().setFederatedOutput(_federatedOutput);
                if ( prefetchActivated() )
                        getLops().activatePrefetch();
diff --git a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java 
b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
index 6058c63..65773c0 100644
--- a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
@@ -82,9 +82,13 @@ public class TernaryAggregate extends Lop
                sb.append( OPERAND_DELIMITOR );
                sb.append( prepOutputOperand(output));
                
-               if( getExecType() == ExecType.CP ) {
+               if( getExecType() == ExecType.CP || getExecType() == 
ExecType.FED ) {
                        sb.append( OPERAND_DELIMITOR );
-                       sb.append( _numThreads );       
+                       sb.append( _numThreads );
+                       if ( getExecType() == ExecType.FED ){
+                               sb.append( OPERAND_DELIMITOR );
+                               sb.append( _fedOutput.name() );
+                       }
                }
                
                return sb.toString();
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 8000da7..11ea4e0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions;
 import org.apache.sysds.lops.Append;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
@@ -42,6 +43,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "fedinit"  , FEDType.Init );
                String2FEDInstructionType.put( "tsmm"     , FEDType.Tsmm );
                String2FEDInstructionType.put( "ba+*"     , 
FEDType.AggregateBinary );
+               String2FEDInstructionType.put( "tak+*"    , 
FEDType.AggregateTernary);
 
                String2FEDInstructionType.put( "uak+"    , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uark+"   , 
FEDType.AggregateUnary );
@@ -59,6 +61,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "*" , FEDType.Binary );
                String2FEDInstructionType.put( "/" , FEDType.Binary );
                String2FEDInstructionType.put( "1-*" , FEDType.Binary); 
//special * case
+               String2FEDInstructionType.put( "max" , FEDType.Binary );
 
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
@@ -106,6 +109,8 @@ public class FEDInstructionParser extends InstructionParser
                                return 
ReorgFEDInstruction.parseInstruction(str);
                        case Append:
                                return 
AppendFEDInstruction.parseInstruction(str);
+                       case AggregateTernary:
+                               return 
AggregateTernaryFEDInstruction.parseInstruction(str);
                        default:
                                throw new DMLRuntimeException("Invalid 
FEDERATED Instruction Type: " + fedtype );
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 1b0c0bf..cfe0baf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -42,8 +42,8 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
        // private static final Log LOG = 
LogFactory.getLog(AggregateTernaryFEDInstruction.class.getName());
 
        private AggregateTernaryFEDInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
-               String opcode, String istr) {
-               super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, 
istr);
+               String opcode, String istr, FederatedOutput fedOut) {
+               super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, 
istr, fedOut);
        }
 
        public static AggregateTernaryFEDInstruction parseInstruction(String 
str) {
@@ -51,16 +51,19 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                String opcode = parts[0];
 
                if(opcode.equalsIgnoreCase("tak+*") || 
opcode.equalsIgnoreCase("tack+*")) {
-                       InstructionUtils.checkNumFields(parts, 5);
+                       InstructionUtils.checkNumFields(parts, 5, 6);
 
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
                        CPOperand in3 = new CPOperand(parts[3]);
                        CPOperand out = new CPOperand(parts[4]);
                        int numThreads = Integer.parseInt(parts[5]);
+                       FederatedOutput fedOut = FederatedOutput.NONE;
+                       if ( parts.length == 7 )
+                               fedOut = FederatedOutput.valueOf(parts[6]);
 
                        AggregateTernaryOperator op = 
InstructionUtils.parseAggregateTernaryOperator(opcode, numThreads);
-                       return new AggregateTernaryFEDInstruction(op, in1, in2, 
in3, out, opcode, str);
+                       return new AggregateTernaryFEDInstruction(op, in1, in2, 
in3, out, opcode, str, fedOut);
                }
                else {
                        throw new 
DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown 
opcode " + opcode);
@@ -77,7 +80,8 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                                && 
mo2.getFedMapping().isAligned(mo3.getFedMapping(), mo1.isFederated(FType.ROW) ? 
AlignType.ROW : AlignType.COL)) {
                        FederatedRequest fr1 = 
FederationUtils.callInstruction(getInstructionString(), output,
                                new CPOperand[] {input1, input2, input3},
-                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()},
+                               true);
                        FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
                        FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                        Future<FederatedResponse>[] response = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
@@ -96,7 +100,7 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(ec.getScalarInput(input3));
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[] {input1, input2, input3},
-                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1.getID()});
+                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1.getID()}, true);
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
                        FederatedRequest fr4 = 
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
                        Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
@@ -121,7 +125,7 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                        FederatedRequest[] fr2 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                        FederatedRequest fr3 = 
FederationUtils.callInstruction(getInstructionString(), output,
                                new CPOperand[] {input1, input2, input3},
-                               new long[] {mo1.getFedMapping().getID(), 
fr2[0].getID(), fr1[0].getID()});
+                               new long[] {mo1.getFedMapping().getID(), 
fr2[0].getID(), fr1[0].getID()}, true);
                        FederatedRequest fr4 = new 
FederatedRequest(RequestType.GET_VAR, fr3.getID());
                        Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4);
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000..75fc236
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/privacy/fedplanning/";
+       private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       public final int rows = 100;
+       public final int cols = 10;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Test
+       public void runL2SVMTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
+               loadAndRunTest(expectedHeavyHitters);
+       }
+
+       private void writeInputMatrices(){
+               writeStandardRowFedMatrix("X1", 65, null);
+               writeStandardRowFedMatrix("X2", 75, null);
+               writeBinaryVector("Y", 44, null);
+
+       }
+
+       private void writeBinaryVector(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+               double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+               for(int i = 0; i < rows; i++)
+                       matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+               MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, 
blocksize, rows);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+       }
+
+       private void writeStandardMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+               writeStandardMatrix(matrixName, seed, rows, privacyConstraint);
+       }
+
+       private void writeStandardMatrix(String matrixName, long seed, int 
numRows, PrivacyConstraint privacyConstraint){
+               double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, 
seed);
+               writeStandardMatrix(matrixName, numRows, privacyConstraint, 
matrix);
+       }
+
+       private void writeStandardMatrix(String matrixName, int numRows, 
PrivacyConstraint privacyConstraint, double[][] matrix){
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
cols, blocksize, (long) numRows * cols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+       }
+
+       private void writeStandardRowFedMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+               int halfRows = rows/2;
+               writeStandardMatrix(matrixName, seed, halfRows, 
privacyConstraint);
+       }
+
+       private void loadAndRunTest(String[] expectedHeavyHitters){
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               Thread t1 = null, t2 = null;
+
+               try {
+                       OptimizerUtils.FEDERATED_COMPILATION = true;
+
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices();
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] { "-stats", "-explain", 
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
+                       runTest(true, false, null, -1);
+
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
+                               "Y=" + input("Y"), "Z=" + expected("Z")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are 
missing: "
+                                       + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               }
+               finally {
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 1e59b86..6bc993e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -49,6 +49,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        private final static String TEST_NAME_6 = 
"FederatedMultiplyPlanningTest6";
        private final static String TEST_NAME_7 = 
"FederatedMultiplyPlanningTest7";
        private final static String TEST_NAME_8 = 
"FederatedMultiplyPlanningTest8";
+       private final static String TEST_NAME_9 = 
"FederatedMultiplyPlanningTest9";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
@@ -68,6 +69,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                addTestConfiguration(TEST_NAME_6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
                addTestConfiguration(TEST_NAME_7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
                addTestConfiguration(TEST_NAME_8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
+               addTestConfiguration(TEST_NAME_9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
        }
 
        @Parameterized.Parameters
@@ -128,6 +130,12 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                federatedTwoMatricesSingleNodeTest(TEST_NAME_8, 
expectedHeavyHitters);
        }
 
+       @Test
+       public void federatedMultiplyPlanningTest9(){
+               String[] expectedHeavyHitters = new String[]{"fed_+*", 
"fed_1-*", "fed_fedinit", "fed_tak+*", "fed_max"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_9, 
expectedHeavyHitters);
+       }
+
        private void writeStandardMatrix(String matrixName, long seed){
                writeStandardMatrix(matrixName, seed, new 
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
        }
@@ -201,52 +209,56 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        }
 
        private void federatedTwoMatricesTest(Types.ExecMode execMode, String 
testName, String[] expectedHeavyHitters) {
-               OptimizerUtils.FEDERATED_COMPILATION = true;
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
                rtplatform = execMode;
                if(rtplatform == Types.ExecMode.SPARK) {
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                }
+               Thread t1 = null, t2 = null;
 
-               getAndLoadTestConfiguration(testName);
-               String HOME = SCRIPT_DIR + TEST_DIR;
-
-               writeInputMatrices(testName);
-
-               int port1 = getRandomAvailablePort();
-               int port2 = getRandomAvailablePort();
-               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
-               Thread t2 = startLocalFedWorkerThread(port2);
+               try{
+                       OptimizerUtils.FEDERATED_COMPILATION = true;
 
-               // Run actual dml script with federated matrix
-               fullDMLScriptName = HOME + testName + ".dml";
-               programArgs = new String[] {"-stats", "-explain", "-nvargs", 
"X1=" + TestUtils.federatedAddress(port1, input("X1")),
-                       "X2=" + TestUtils.federatedAddress(port2, input("X2")),
-                       "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
-                       "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), 
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
-               rewriteRealProgramArgs(testName, port1, port2);
-               runTest(true, false, null, -1);
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
 
-               OptimizerUtils.FEDERATED_COMPILATION = false;
+                       writeInputMatrices(testName);
 
-               // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + testName + "Reference.dml";
-               programArgs = new String[] {"-nvargs", "X1=" + input("X1"), 
"X2=" + input("X2"), "Y1=" + input("Y1"),
-                       "Y2=" + input("Y2"), "Z=" + expected("Z")};
-               rewriteReferenceProgramArgs(testName);
-               runTest(true, false, null, -1);
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
 
-               // compare via files
-               compareResults(1e-9);
-               if (!heavyHittersContainsAllString(expectedHeavyHitters))
-                       fail("The following expected heavy hitters are missing: 
"
-                               + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] {"-stats", "-explain", 
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
+                               "Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+                       rewriteRealProgramArgs(testName, port1, port2);
+                       runTest(true, false, null, -1);
 
-               TestUtils.shutdownThreads(t1, t2);
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
 
-               rtplatform = platformOld;
-               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "Z=" + expected("Z")};
+                       rewriteReferenceProgramArgs(testName);
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are 
missing: "
+                                       + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               } finally {
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
        }
 
        private void rewriteRealProgramArgs(String testName, int port1, int 
port2){
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml
new file mode 100644
index 0000000..13a63bb
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml
@@ -0,0 +1,132 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+  maxii = 20
+  verbose = FALSE
+  columnId = -1
+  Y = read($Y)
+  X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+  intercept = FALSE
+  epsilon = 1e-12
+  lambda = 1
+  maxIterations = 100
+
+  #check input parameter assertions
+  if(nrow(X) < 2)
+    stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a 
binary class classifier without at least 2 rows")
+  if(epsilon < 0)
+    stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be 
non-negative")
+  if(lambda < 0)
+    stop("L2SVM: Stopping due to invalid argument: Regularization constant 
(reg) must be non-negative")
+  if(maxIterations < 1)
+    stop("L2SVM: Stopping due to invalid argument: Maximum iterations should 
be a positive integer")
+  if(ncol(Y) < 1)
+    stop("L2SVM: Stopping due to invalid multiple label columns, maybe use 
MSVM instead?")
+
+  #check input lables and transform into -1/1
+  check_min = min(Y)
+  check_max = max(Y)
+
+  num_min = sum(Y == check_min)
+  num_max = sum(Y == check_max)
+
+  # TODO make this a stop condition for l2svm instead of just printing.
+  if(num_min + num_max != nrow(Y))
+    print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
+
+  # Scale inputs to -1 for negative, and 1 for positive classification
+  if(check_min != -1 | check_max != +1)
+    Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - 
check_min)
+
+  # If column_id is -1 then we assume that the fundamental algorithm is MSVM,
+  # Therefore don't print message.
+  if(verbose & columnId == -1)
+    print('Running L2-SVM ')
+
+  num_samples = nrow(X)
+  num_classes = ncol(Y)
+
+  # Add Bias
+  num_rows_in_w = ncol(X)
+  if (intercept) {
+    ones  = matrix(1, rows=num_samples, cols=1)
+    X = cbind(X, ones);
+    num_rows_in_w += 1
+  }
+
+  w = matrix(0, rows=num_rows_in_w, cols=1)
+
+  g_old = t(X) %*% Y
+  s = g_old
+
+  Xw = matrix(0, rows=nrow(X), cols=1)
+
+  iter = 0
+  continue = TRUE
+  while(continue & iter < maxIterations)  {
+    # minimizing primal obj along direction s
+    step_sz = 0
+    Xd = X %*% s
+    wd = lambda * sum(w * s)
+    dd = lambda * sum(s * s)
+    continue1 = TRUE
+    iiter = 0
+    while(continue1 & iiter < maxii){
+      tmp_Xw = Xw + step_sz*Xd
+      out = 1 - Y * (tmp_Xw)
+      sv = (out > 0)
+      out = out * sv
+      g = wd + step_sz*dd - sum(out * Y * Xd)
+      h = dd + sum(Xd * sv * Xd)
+      step_sz = step_sz - g/h
+      continue1 = (g*g/h >= epsilon)
+      iiter = iiter + 1
+    }
+
+    #update weights
+    w = w + step_sz*s
+    Xw = Xw + step_sz*Xd
+
+    out = 1 - Y * Xw
+    sv = (out > 0)
+    out = sv * out
+    obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
+    g_new = t(X) %*% (out * Y) - lambda * w
+
+    if(verbose) {
+      colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+      print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +" 
--- "+ colstr + " Obj:" + obj)
+    }
+
+    tmp = sum(s * g_old)
+    continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
+    #non-linear CG step
+    be = sum(g_new * g_new)/sum(g_old * g_old)
+    s = be * s + g_new
+    g_old = g_new
+
+    iter = iter + 1
+  }
+  model = w
+  write(model, $Z)
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml
new file mode 100644
index 0000000..d92b3dc
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml
@@ -0,0 +1,131 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+  maxii = 20
+  verbose = FALSE
+  columnId = -1
+  Y = read($Y)
+  X = rbind(read($X1), read($X2))
+  intercept = FALSE
+  epsilon = 1e-12
+  lambda = 1
+  maxIterations = 100
+
+  # check input parameter assertions
+  if(nrow(X) < 2)
+    stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a binary class classifier without at least 2 rows")
+  if(epsilon < 0)
+    stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be non-negative")
+  if(lambda < 0)
+    stop("L2SVM: Stopping due to invalid argument: Regularization constant (reg) must be non-negative")
+  if(maxIterations < 1)
+    stop("L2SVM: Stopping due to invalid argument: Maximum iterations should be a positive integer")
+  if(ncol(Y) != 1)  # L2SVM needs exactly one label column (also rejects zero columns)
+    stop("L2SVM: Stopping due to invalid multiple label columns, maybe use MSVM instead?")
+
+  # check input labels and transform into -1/1
+  check_min = min(Y)
+  check_max = max(Y)
+
+  num_min = sum(Y == check_min)
+  num_max = sum(Y == check_max)
+
+  # TODO make this a stop condition for l2svm instead of just printing.
+  if(num_min + num_max != nrow(Y))
+    print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
+
+  # Scale inputs to -1 for negative, and 1 for positive classification
+  if(check_min != -1 | check_max != +1)
+    Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - check_min)
+
+  # The start message is only printed when columnId == -1 (standalone use); any other
+  # columnId presumably means a caller such as MSVM runs one column per call.
+  if(verbose & columnId == -1)
+    print('Running L2-SVM ')
+
+  num_samples = nrow(X)
+  num_classes = ncol(Y)
+
+  # Add bias column of ones when an intercept is requested
+  num_rows_in_w = ncol(X)
+  if (intercept) {
+    ones = matrix(1, rows=num_samples, cols=1)
+    X = cbind(X, ones);
+    num_rows_in_w += 1
+  }
+
+  w = matrix(0, rows=num_rows_in_w, cols=1)
+
+  g_old = t(X) %*% Y  # negative gradient at w = 0; also the first search direction
+  s = g_old
+
+  Xw = matrix(0, rows=nrow(X), cols=1)  # running X %*% w, updated incrementally below
+
+  iter = 0
+  continue = TRUE
+  while(continue & iter < maxIterations)  {
+    # minimize the primal objective along direction s (Newton iterations on step_sz)
+    step_sz = 0
+    Xd = X %*% s
+    wd = lambda * sum(w * s)
+    dd = lambda * sum(s * s)
+    continue1 = TRUE
+    iiter = 0
+    while(continue1 & iiter < maxii){
+      tmp_Xw = Xw + step_sz*Xd
+      out = 1 - Y * (tmp_Xw)
+      sv = (out > 0)  # margin-violating rows
+      out = out * sv
+      g = wd + step_sz*dd - sum(out * Y * Xd)  # derivative of the objective w.r.t. step_sz
+      h = dd + sum(Xd * sv * Xd)  # curvature term for the Newton update
+      step_sz = step_sz - g/h
+      continue1 = (g*g/h >= epsilon)
+      iiter = iiter + 1
+    }
+
+    # update weights
+    w = w + step_sz*s
+    Xw = Xw + step_sz*Xd
+
+    out = 1 - Y * Xw
+    sv = (out > 0)
+    out = sv * out
+    obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
+    g_new = t(X) %*% (out * Y) - lambda * w
+
+    if(verbose) {
+      colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+      print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +" --- "+ colstr + " Obj:" + obj)
+    }
+
+    tmp = sum(s * g_old)
+    continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
+    # non-linear CG step (Fletcher-Reeves coefficient)
+    be = sum(g_new * g_new)/sum(g_old * g_old)
+    s = be * s + g_new
+    g_old = g_new
+
+    iter = iter + 1
+  }
+  model = w
+  write(model, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml
new file mode 100644
index 0000000..cafd93d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),  # two row partitions, split at $r / 2
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),  # same row partitioning as X
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
+step_sz = 4
+s = t(X) %*% Y
+Xd = X %*% s
+Z0 = W + step_sz * X
+Z1 = 1 - Y * Z0
+Z2 = (Z1 > 0)
+Z3 = Z1 * Z2
+Z = sum(Z3 * Y * Xd)  # sum of a three-way elementwise product (presumably the ternary-aggregate pattern under test)
+write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml
new file mode 100644
index 0000000..6e6575f
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local (non-federated) reference for FederatedMultiplyPlanningTest9.dml.
+Xloc = rbind(read($X1), read($X2))
+Yloc = rbind(read($Y1), read($Y2))
+stepSize = 4
+Wrnd = rand(rows=nrow(Xloc), cols=ncol(Xloc), min=0, max=1, pdf='uniform', seed=5)
+sDir = t(Xloc) %*% Yloc
+XsDir = Xloc %*% sDir
+violation = 1 - Yloc * (Wrnd + stepSize * Xloc)
+mask = (violation > 0)
+clipped = violation * mask
+total = sum(clipped * Yloc * XsDir)
+write(total, $Z)

Reply via email to