This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cb61dc7 [SYSTEMDS-3018] Federated Planning Operator Support
cb61dc7 is described below
commit cb61dc74526655fb7bf9c6db4b857b6f6a070230
Author: sebwrede <[email protected]>
AuthorDate: Tue Mar 8 12:31:01 2022 +0100
[SYSTEMDS-3018] Federated Planning Operator Support
Closes #1557.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 2 +
src/main/java/org/apache/sysds/hops/Hop.java | 11 +-
.../org/apache/sysds/lops/TernaryAggregate.java | 8 +-
.../runtime/instructions/FEDInstructionParser.java | 5 +
.../fed/AggregateTernaryFEDInstruction.java | 18 ++-
.../fedplanning/FederatedL2SVMPlanningTest.java | 144 +++++++++++++++++++++
.../fedplanning/FederatedMultiplyPlanningTest.java | 78 ++++++-----
.../fedplanning/FederatedL2SVMPlanningTest.dml | 132 +++++++++++++++++++
.../FederatedL2SVMPlanningTestReference.dml | 131 +++++++++++++++++++
.../fedplanning/FederatedMultiplyPlanningTest9.dml | 35 +++++
.../FederatedMultiplyPlanningTest9Reference.dml | 33 +++++
11 files changed, 554 insertions(+), 43 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 078c754..31309c1 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -668,11 +668,13 @@ public class AggBinaryOp extends MultiThreadedHop {
new Transform(lY, ReOrgOp.TRANS, getDataType(),
getValueType(), ExecType.CP, k);
tY.getOutputParameters().setDimensions(Y.getDim2(),
Y.getDim1(), getBlocksize(), Y.getNnz());
setLineNumbers(tY);
+ updateLopFedOut(tY);
//matrix mult
Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(),
getValueType(), ExecType.CP, k);
mult.getOutputParameters().setDimensions(Y.getDim2(),
X.getDim2(), getBlocksize(), getNnz());
setLineNumbers(mult);
+ updateLopFedOut(mult);
//result transpose (dimensions set outside)
Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(),
getValueType(), ExecType.CP, k);
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 003492f..f91fad9 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -375,10 +375,19 @@ public abstract class Hop implements ParseInfo {
public boolean requiresLineageCaching() {
return _requiresLineageCaching;
}
+
+ public void updateLopFedOut(Lop lop){
+ updateLopFedOut(lop, getExecType(), _federatedOutput);
+ }
+
+ public void updateLopFedOut(Lop lop, ExecType execType, FederatedOutput
fedOut){
+ if ( execType == ExecType.FED )
+ lop.setFederatedOutput(fedOut);
+ }
public void constructAndSetLopsDataFlowProperties() {
//propagate federated output configuration to lops
- if( isFederated() )
+ if( isFederated() || getLops().getExecType() == ExecType.FED )
getLops().setFederatedOutput(_federatedOutput);
if ( prefetchActivated() )
getLops().activatePrefetch();
diff --git a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
index 6058c63..65773c0 100644
--- a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
@@ -82,9 +82,13 @@ public class TernaryAggregate extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append( prepOutputOperand(output));
- if( getExecType() == ExecType.CP ) {
+ if( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) {
sb.append( OPERAND_DELIMITOR );
- sb.append( _numThreads );
+ sb.append( _numThreads );
+ if ( getExecType() == ExecType.FED ){
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( _fedOutput.name() );
+ }
}
return sb.toString();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 8000da7..11ea4e0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions;
import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
@@ -42,6 +43,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "fedinit" , FEDType.Init );
String2FEDInstructionType.put( "tsmm" , FEDType.Tsmm );
String2FEDInstructionType.put( "ba+*" ,
FEDType.AggregateBinary );
+ String2FEDInstructionType.put( "tak+*" ,
FEDType.AggregateTernary);
String2FEDInstructionType.put( "uak+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uark+" ,
FEDType.AggregateUnary );
@@ -59,6 +61,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "*" , FEDType.Binary );
String2FEDInstructionType.put( "/" , FEDType.Binary );
String2FEDInstructionType.put( "1-*" , FEDType.Binary);
//special * case
+ String2FEDInstructionType.put( "max" , FEDType.Binary );
// Reorg Instruction Opcodes (repositioning of existing values)
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
@@ -106,6 +109,8 @@ public class FEDInstructionParser extends InstructionParser
return
ReorgFEDInstruction.parseInstruction(str);
case Append:
return
AppendFEDInstruction.parseInstruction(str);
+ case AggregateTernary:
+ return
AggregateTernaryFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid
FEDERATED Instruction Type: " + fedtype );
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 1b0c0bf..cfe0baf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -42,8 +42,8 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
// private static final Log LOG =
LogFactory.getLog(AggregateTernaryFEDInstruction.class.getName());
private AggregateTernaryFEDInstruction(Operator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
- String opcode, String istr) {
- super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode,
istr);
+ String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode,
istr, fedOut);
}
public static AggregateTernaryFEDInstruction parseInstruction(String
str) {
@@ -51,16 +51,19 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
String opcode = parts[0];
if(opcode.equalsIgnoreCase("tak+*") ||
opcode.equalsIgnoreCase("tack+*")) {
- InstructionUtils.checkNumFields(parts, 5);
+ InstructionUtils.checkNumFields(parts, 5, 6);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand out = new CPOperand(parts[4]);
int numThreads = Integer.parseInt(parts[5]);
+ FederatedOutput fedOut = FederatedOutput.NONE;
+ if ( parts.length == 7 )
+ fedOut = FederatedOutput.valueOf(parts[6]);
AggregateTernaryOperator op =
InstructionUtils.parseAggregateTernaryOperator(opcode, numThreads);
- return new AggregateTernaryFEDInstruction(op, in1, in2,
in3, out, opcode, str);
+ return new AggregateTernaryFEDInstruction(op, in1, in2,
in3, out, opcode, str, fedOut);
}
else {
throw new
DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown
opcode " + opcode);
@@ -77,7 +80,8 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
&&
mo2.getFedMapping().isAligned(mo3.getFedMapping(), mo1.isFederated(FType.ROW) ?
AlignType.ROW : AlignType.COL)) {
FederatedRequest fr1 =
FederationUtils.callInstruction(getInstructionString(), output,
new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+ new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()},
+ true);
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
Future<FederatedResponse>[] response =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
@@ -96,7 +100,7 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(ec.getScalarInput(input3));
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1.getID()});
+ new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1.getID()}, true);
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
@@ -121,7 +125,7 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
FederatedRequest[] fr2 =
mo1.getFedMapping().broadcastSliced(mo2, false);
FederatedRequest fr3 =
FederationUtils.callInstruction(getInstructionString(), output,
new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(),
fr2[0].getID(), fr1[0].getID()});
+ new long[] {mo1.getFedMapping().getID(),
fr2[0].getID(), fr1[0].getID()}, true);
FederatedRequest fr4 = new
FederatedRequest(RequestType.GET_VAR, fr3.getID());
Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4);
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000..75fc236
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/privacy/fedplanning/";
+ private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ public final int rows = 100;
+ public final int cols = 10;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ public void runL2SVMTest(){
+ String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
+ loadAndRunTest(expectedHeavyHitters);
+ }
+
+ private void writeInputMatrices(){
+ writeStandardRowFedMatrix("X1", 65, null);
+ writeStandardRowFedMatrix("X2", 75, null);
+ writeBinaryVector("Y", 44, null);
+
+ }
+
+ private void writeBinaryVector(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+ double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+ for(int i = 0; i < rows; i++)
+ matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1,
blocksize, rows);
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+ }
+
+ private void writeStandardMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+ writeStandardMatrix(matrixName, seed, rows, privacyConstraint);
+ }
+
+ private void writeStandardMatrix(String matrixName, long seed, int
numRows, PrivacyConstraint privacyConstraint){
+ double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1,
seed);
+ writeStandardMatrix(matrixName, numRows, privacyConstraint,
matrix);
+ }
+
+ private void writeStandardMatrix(String matrixName, int numRows,
PrivacyConstraint privacyConstraint, double[][] matrix){
+ MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
cols, blocksize, (long) numRows * cols);
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+ }
+
+ private void writeStandardRowFedMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+ int halfRows = rows/2;
+ writeStandardMatrix(matrixName, seed, halfRows,
privacyConstraint);
+ }
+
+ private void loadAndRunTest(String[] expectedHeavyHitters){
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+
+ Thread t1 = null, t2 = null;
+
+ try {
+ OptimizerUtils.FEDERATED_COMPILATION = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ writeInputMatrices();
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] { "-stats", "-explain",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y=" + input("Y"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ OptimizerUtils.FEDERATED_COMPILATION = false;
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"),
+ "Y=" + input("Y"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are
missing: "
+ +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+ }
+ finally {
+ OptimizerUtils.FEDERATED_COMPILATION = false;
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 1e59b86..6bc993e 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -49,6 +49,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
private final static String TEST_NAME_6 =
"FederatedMultiplyPlanningTest6";
private final static String TEST_NAME_7 =
"FederatedMultiplyPlanningTest7";
private final static String TEST_NAME_8 =
"FederatedMultiplyPlanningTest8";
+ private final static String TEST_NAME_9 =
"FederatedMultiplyPlanningTest9";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -68,6 +69,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME_6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_7, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_8, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
+ addTestConfiguration(TEST_NAME_9, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
}
@Parameterized.Parameters
@@ -128,6 +130,12 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
federatedTwoMatricesSingleNodeTest(TEST_NAME_8,
expectedHeavyHitters);
}
+ @Test
+ public void federatedMultiplyPlanningTest9(){
+ String[] expectedHeavyHitters = new String[]{"fed_+*",
"fed_1-*", "fed_fedinit", "fed_tak+*", "fed_max"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_9,
expectedHeavyHitters);
+ }
+
private void writeStandardMatrix(String matrixName, long seed){
writeStandardMatrix(matrixName, seed, new
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
@@ -201,52 +209,56 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
}
private void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName, String[] expectedHeavyHitters) {
- OptimizerUtils.FEDERATED_COMPILATION = true;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
+ Thread t1 = null, t2 = null;
- getAndLoadTestConfiguration(testName);
- String HOME = SCRIPT_DIR + TEST_DIR;
-
- writeInputMatrices(testName);
-
- int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ try{
+ OptimizerUtils.FEDERATED_COMPILATION = true;
- // Run actual dml script with federated matrix
- fullDMLScriptName = HOME + testName + ".dml";
- programArgs = new String[] {"-stats", "-explain", "-nvargs",
"X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "X2=" + TestUtils.federatedAddress(port2, input("X2")),
- "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
- "Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
- rewriteRealProgramArgs(testName, port1, port2);
- runTest(true, false, null, -1);
+ getAndLoadTestConfiguration(testName);
+ String HOME = SCRIPT_DIR + TEST_DIR;
- OptimizerUtils.FEDERATED_COMPILATION = false;
+ writeInputMatrices(testName);
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + testName + "Reference.dml";
- programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
- "Y2=" + input("Y2"), "Z=" + expected("Z")};
- rewriteReferenceProgramArgs(testName);
- runTest(true, false, null, -1);
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2);
- // compare via files
- compareResults(1e-9);
- if (!heavyHittersContainsAllString(expectedHeavyHitters))
- fail("The following expected heavy hitters are missing:
"
- +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] {"-stats", "-explain",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ rewriteRealProgramArgs(testName, port1, port2);
+ runTest(true, false, null, -1);
- TestUtils.shutdownThreads(t1, t2);
+ OptimizerUtils.FEDERATED_COMPILATION = false;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + testName + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ rewriteReferenceProgramArgs(testName);
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are
missing: "
+ +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+ } finally {
+ OptimizerUtils.FEDERATED_COMPILATION = false;
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
}
private void rewriteRealProgramArgs(String testName, int port1, int
port2){
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml
new file mode 100644
index 0000000..13a63bb
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.dml
@@ -0,0 +1,132 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+ maxii = 20
+ verbose = FALSE
+ columnId = -1
+ Y = read($Y)
+ X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+ intercept = FALSE
+ epsilon = 1e-12
+ lambda = 1
+ maxIterations = 100
+
+ #check input parameter assertions
+ if(nrow(X) < 2)
+ stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a
binary class classifier without at least 2 rows")
+ if(epsilon < 0)
+ stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be
non-negative")
+ if(lambda < 0)
+ stop("L2SVM: Stopping due to invalid argument: Regularization constant
(reg) must be non-negative")
+ if(maxIterations < 1)
+ stop("L2SVM: Stopping due to invalid argument: Maximum iterations should
be a positive integer")
+ if(ncol(Y) < 1)
+ stop("L2SVM: Stopping due to invalid multiple label columns, maybe use
MSVM instead?")
+
+ #check input lables and transform into -1/1
+ check_min = min(Y)
+ check_max = max(Y)
+
+ num_min = sum(Y == check_min)
+ num_max = sum(Y == check_max)
+
+ # TODO make this a stop condition for l2svm instead of just printing.
+ if(num_min + num_max != nrow(Y))
+ print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
+
+ # Scale inputs to -1 for negative, and 1 for positive classification
+ if(check_min != -1 | check_max != +1)
+ Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max -
check_min)
+
+ # If column_id is -1 then we assume that the fundamental algorithm is MSVM,
+ # Therefore don't print message.
+ if(verbose & columnId == -1)
+ print('Running L2-SVM ')
+
+ num_samples = nrow(X)
+ num_classes = ncol(Y)
+
+ # Add Bias
+ num_rows_in_w = ncol(X)
+ if (intercept) {
+ ones = matrix(1, rows=num_samples, cols=1)
+ X = cbind(X, ones);
+ num_rows_in_w += 1
+ }
+
+ w = matrix(0, rows=num_rows_in_w, cols=1)
+
+ g_old = t(X) %*% Y
+ s = g_old
+
+ Xw = matrix(0, rows=nrow(X), cols=1)
+
+ iter = 0
+ continue = TRUE
+ while(continue & iter < maxIterations) {
+ # minimizing primal obj along direction s
+ step_sz = 0
+ Xd = X %*% s
+ wd = lambda * sum(w * s)
+ dd = lambda * sum(s * s)
+ continue1 = TRUE
+ iiter = 0
+ while(continue1 & iiter < maxii){
+ tmp_Xw = Xw + step_sz*Xd
+ out = 1 - Y * (tmp_Xw)
+ sv = (out > 0)
+ out = out * sv
+ g = wd + step_sz*dd - sum(out * Y * Xd)
+ h = dd + sum(Xd * sv * Xd)
+ step_sz = step_sz - g/h
+ continue1 = (g*g/h >= epsilon)
+ iiter = iiter + 1
+ }
+
+ #update weights
+ w = w + step_sz*s
+ Xw = Xw + step_sz*Xd
+
+ out = 1 - Y * Xw
+ sv = (out > 0)
+ out = sv * out
+ obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
+ g_new = t(X) %*% (out * Y) - lambda * w
+
+ if(verbose) {
+ colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+ print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +"
--- "+ colstr + " Obj:" + obj)
+ }
+
+ tmp = sum(s * g_old)
+ continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
+ #non-linear CG step
+ be = sum(g_new * g_new)/sum(g_old * g_old)
+ s = be * s + g_new
+ g_old = g_new
+
+ iter = iter + 1
+ }
+ model = w
+ write(model, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml
new file mode 100644
index 0000000..d92b3dc
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMPlanningTestReference.dml
@@ -0,0 +1,131 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+ maxii = 20
+ verbose = FALSE
+ columnId = -1
+ Y = read($Y)
+ X = rbind(read($X1), read($X2))
+ intercept = FALSE
+ epsilon = 1e-12
+ lambda = 1
+ maxIterations = 100
+
+ #check input parameter assertions
+ if(nrow(X) < 2)
+ stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a
binary class classifier without at least 2 rows")
+ if(epsilon < 0)
+ stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be
non-negative")
+ if(lambda < 0)
+ stop("L2SVM: Stopping due to invalid argument: Regularization constant
(reg) must be non-negative")
+ if(maxIterations < 1)
+ stop("L2SVM: Stopping due to invalid argument: Maximum iterations should
be a positive integer")
+ if(ncol(Y) < 1)
+ stop("L2SVM: Stopping due to invalid multiple label columns, maybe use
MSVM instead?")
+
+ #check input lables and transform into -1/1
+ check_min = min(Y)
+ check_max = max(Y)
+
+ num_min = sum(Y == check_min)
+ num_max = sum(Y == check_max)
+
+ # TODO make this a stop condition for l2svm instead of just printing.
+ if(num_min + num_max != nrow(Y))
+ print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
+
+ # Scale inputs to -1 for negative, and 1 for positive classification
+ if(check_min != -1 | check_max != +1)
+ Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max -
check_min)
+
+ # If column_id is -1 then we assume that the fundamental algorithm is MSVM,
+ # Therefore don't print message.
+ if(verbose & columnId == -1)
+ print('Running L2-SVM ')
+
+ num_samples = nrow(X)
+ num_classes = ncol(Y)
+
+ # Add Bias
+ num_rows_in_w = ncol(X)
+ if (intercept) {
+ ones = matrix(1, rows=num_samples, cols=1)
+ X = cbind(X, ones);
+ num_rows_in_w += 1
+ }
+
+ w = matrix(0, rows=num_rows_in_w, cols=1)
+
+ g_old = t(X) %*% Y
+ s = g_old
+
+ Xw = matrix(0, rows=nrow(X), cols=1)
+
+ iter = 0
+ continue = TRUE
+ while(continue & iter < maxIterations) {
+ # minimizing primal obj along direction s
+ step_sz = 0
+ Xd = X %*% s
+ wd = lambda * sum(w * s)
+ dd = lambda * sum(s * s)
+ continue1 = TRUE
+ iiter = 0
+ while(continue1 & iiter < maxii){
+ tmp_Xw = Xw + step_sz*Xd
+ out = 1 - Y * (tmp_Xw)
+ sv = (out > 0)
+ out = out * sv
+ g = wd + step_sz*dd - sum(out * Y * Xd)
+ h = dd + sum(Xd * sv * Xd)
+ step_sz = step_sz - g/h
+ continue1 = (g*g/h >= epsilon)
+ iiter = iiter + 1
+ }
+
+ #update weights
+ w = w + step_sz*s
+ Xw = Xw + step_sz*Xd
+
+ out = 1 - Y * Xw
+ sv = (out > 0)
+ out = sv * out
+ obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
+ g_new = t(X) %*% (out * Y) - lambda * w
+
+ if(verbose) {
+ colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+ print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +"
--- "+ colstr + " Obj:" + obj)
+ }
+
+ tmp = sum(s * g_old)
+ continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
+ #non-linear CG step
+ be = sum(g_new * g_new)/sum(g_old * g_old)
+ s = be * s + g_new
+ g_old = g_new
+
+ iter = iter + 1
+ }
+ model = w
+ write(model, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml
new file mode 100644
index 0000000..cafd93d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0),
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0),
list($r, $c)))
+W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
+step_sz = 4
+s = t(X) %*% Y
+Xd = X %*% s
+Z0 = W + step_sz * X
+Z1 = 1 - Y * Z0
+Z2 = (Z1 > 0)
+Z3 = Z1 * Z2
+Z = sum(Z3 * Y * Xd)
+write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml
new file mode 100644
index 0000000..6e6575f
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest9Reference.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+W = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, pdf='uniform', seed=5)
+step_sz = 4
+s = t(X) %*% Y
+Xd = X %*% s
+Z0 = W + step_sz * X
+Z1 = 1 - Y * Z0
+Z2 = (Z1 > 0)
+Z3 = Z1 * Z2
+Z = sum(Z3 * Y * Xd)
+write(Z, $Z)