This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new d522183 [SYSTEMDS-3018] Compilation of federated ops under privacy
constraints
d522183 is described below
commit d522183249bd792fffa7e00b833213eb172f1fbf
Author: sebwrede <[email protected]>
AuthorDate: Sun Jul 4 23:06:54 2021 +0200
[SYSTEMDS-3018] Compilation of federated ops under privacy constraints
Closes #1313.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 1 -
.../java/org/apache/sysds/hops/AggUnaryOp.java | 1 -
src/main/java/org/apache/sysds/hops/BinaryOp.java | 1 -
src/main/java/org/apache/sysds/hops/DataOp.java | 4 +
src/main/java/org/apache/sysds/hops/Hop.java | 35 +-
src/main/java/org/apache/sysds/hops/ReorgOp.java | 5 +-
src/main/java/org/apache/sysds/hops/TernaryOp.java | 2 -
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +
.../RewriteAlgebraicSimplificationDynamic.java | 2 +-
.../hops/rewrite/RewriteFederatedExecution.java | 197 ++++++++
.../runtime/instructions/FEDInstructionParser.java | 6 +
.../runtime/instructions/cp/SqlCPInstruction.java | 9 +
.../fed/AggregateBinaryFEDInstruction.java | 74 ++-
.../runtime/privacy/propagation/OperatorType.java | 47 ++
.../privacy/propagation/PrivacyPropagator.java | 503 +++++++++------------
.../fedplanning/FederatedMultiplyPlanningTest.java | 8 +-
16 files changed, 552 insertions(+), 346 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index a17430e..9b3356f 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -102,7 +102,6 @@ public class AggBinaryOp extends MultiThreadedHop
outerOp = outOp;
getInput().add(0, in1);
getInput().add(1, in2);
- updateETFed();
in1.getParent().add(this);
in2.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 4840f97..9c18f49 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -59,7 +59,6 @@ public class AggUnaryOp extends MultiThreadedHop
_direction = idx;
getInput().add(0, inp);
inp.getParent().add(this);
- updateETFed();
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index f114bc0..8826f92 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -97,7 +97,6 @@ public class BinaryOp extends MultiThreadedHop
op = o;
getInput().add(0, inp1);
getInput().add(1, inp2);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java
b/src/main/java/org/apache/sysds/hops/DataOp.java
index 52d424e..03dfd08 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -355,6 +355,10 @@ public class DataOp extends Hop {
return( _op == OpOpData.PERSISTENTREAD || _op ==
OpOpData.PERSISTENTWRITE );
}
+	/**
+	 * Tells whether this data operation is of type OpOpData.FEDERATED.
+	 * @return true for a federated data op, false otherwise
+	 */
+	public boolean isFederatedData() {
+		return OpOpData.FEDERATED == _op;
+	}
+
@Override
public String getOpString() {
String s = new String("");
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index e01ffa1..0ef6d96 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -86,9 +86,9 @@ public abstract class Hop implements ParseInfo {
protected ExecType _etypeForced = null; //exec type forced via platform
or external optimizer
/**
- * Boolean defining if the output of the operation should be federated.
- * If it is true, the output should be kept at federated sites.
- * If it is false, the output should be retrieved by the coordinator.
+ * Field defining if the output of the operation should be federated.
+ * If it is fout, the output should be kept at federated sites.
+ * If it is lout, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
@@ -173,6 +173,14 @@ public abstract class Hop implements ParseInfo {
{
return _etype;
}
+
+	/**
+	 * Overwrites the execution type of this hop.
+	 * @param execType execution type to assign
+	 */
+	public void setExecType(ExecType execType){
+		this._etype = execType;
+	}
+
+	/**
+	 * Overwrites the federated output flag of this hop.
+	 * @param federatedOutput federated output flag to assign
+	 */
+	public void setFederatedOutput(FederatedOutput federatedOutput){
+		this._federatedOutput = federatedOutput;
+	}
public void resetExecType()
{
@@ -770,25 +778,12 @@ public abstract class Hop implements ParseInfo {
/**
* Update the execution type if input is federated and federated
compilation is activated.
* Federated compilation is activated in OptimizerUtils.
+ * This method only has an effect if FEDERATED_COMPILATION is activated.
+ * NOTE(review): the body below sets the exec type to FED when this hop's own
+ * federated output flag is FOUT or LOUT; it no longer inspects the inputs,
+ * contrary to the first sentence above.
+ * NOTE(review): the flag is set via setFederatedOutput, e.g. by RewriteFederatedExecution,
+ * which ProgramRewriter only registers when FEDERATED_COMPILATION is enabled;
+ * confirm no other code sets the flag, otherwise the previous sentence does not hold.
+ * NOTE(review): this change removes the updateETFed() calls from the hop constructors;
+ * confirm the method is still invoked after the federated output flag has been set.
*/
protected void updateETFed(){
- if ( inputIsFED() )
+ if ( _federatedOutput == FederatedOutput.FOUT ||
_federatedOutput == FederatedOutput.LOUT )
_etype = ExecType.FED;
}
-
- /**
- * Returns true if any input has federated ExecType.
- * This method can only return true if FedDecision is activated.
- * @return true if any input has federated ExecType
- */
- protected boolean inputIsFED(){
- if ( !OptimizerUtils.FEDERATED_COMPILATION )
- return false;
- for ( Hop input : _input )
- if ( input.isFederated() || input.isFederatedOutput() )
- return true;
- return false;
- }
public boolean isFederated(){
return getExecType() == ExecType.FED;
@@ -798,6 +793,10 @@ public abstract class Hop implements ParseInfo {
return _federatedOutput == FederatedOutput.FOUT;
}
+	/**
+	 * Checks whether at least one input hop reports a federated output
+	 * through hasFederatedOutput().
+	 * @return true as soon as one such input is found, false otherwise
+	 */
+	public boolean someInputFederated(){
+		for ( Hop in : getInput() )
+			if ( in.hasFederatedOutput() )
+				return true;
+		return false;
+	}
+
public ArrayList<Hop> getParent() {
return _parent;
}
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 89aeb03..d2dcebf 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -61,8 +61,7 @@ public class ReorgOp extends MultiThreadedHop
_op = o;
getInput().add(0, inp);
inp.getParent().add(this);
- updateETFed();
-
+
//compute unknown dims and nnz
refreshSizeInformation();
}
@@ -78,8 +77,6 @@ public class ReorgOp extends MultiThreadedHop
in.getParent().add(this);
}
- updateETFed();
-
//compute unknown dims and nnz
refreshSizeInformation();
}
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index f254d0d..7d9fca6 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -80,7 +80,6 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(0, inp1);
getInput().add(1, inp2);
getInput().add(2, inp3);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
@@ -98,7 +97,6 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(3, inp4);
getInput().add(4, inp5);
getInput().add(5, inp6);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 467d476..2e3edb0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -138,6 +138,9 @@ public class ProgramRewriter
_dagRuleSet.add( new
RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new
RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
}
+ if ( OptimizerUtils.FEDERATED_COMPILATION ) {
+ _dagRuleSet.add( new
RewriteFederatedExecution() );
+ }
}
// cleanup after all rewrites applied
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 71a7240..269050c 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -139,7 +139,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
- Hop hi = hop.getInput().get(i);
+ Hop hi = hop.getInput(i);
//process childs recursively first (to allow roll-up)
if( descendFirst )
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
new file mode 100644
index 0000000..29cda4a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
+import org.apache.sysds.utils.JSONHelper;
+import org.apache.wink.json4j.JSONObject;
+
+import javax.net.ssl.SSLException;
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.concurrent.Future;
+
+/**
+ * Hop DAG rewrite that decides, per hop, whether its output is kept at the federated
+ * workers (FOUT) or retrieved by the coordinator (LOUT). The decision is based on
+ * privacy constraints, which are loaded from the federated workers for federated DataOps
+ * and propagated through the DAG.
+ */
+public class RewriteFederatedExecution extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if ( roots == null )
+			return null;
+		for ( Hop root : roots )
+			visitHop(root);
+		// leave the DAG with cleared visit flags for subsequent traversals
+		Hop.resetVisitStatus(roots);
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root == null )
+			return null;
+		visitHop(root);
+		// leave the DAG with cleared visit flags for subsequent traversals
+		root.resetVisitStatus();
+		return root;
+	}
+
+	/**
+	 * Processes the DAG below the given hop depth-first, so that all inputs are decided
+	 * before the hop itself. Hops already marked as visited are skipped, hence shared
+	 * sub-DAGs are processed only once.
+	 * @param hop hop to process
+	 */
+	private void visitHop(Hop hop){
+		if (hop.isVisited())
+			return;
+
+		// Depth first to get to the input
+		for ( Hop input : hop.getInput() )
+			visitHop(input);
+
+		privacyBasedHopDecisionWithFedCall(hop);
+		hop.setVisited();
+	}
+
+	/**
+	 * Propagates privacy constraints from the inputs to the given hop and sets its federated
+	 * output flag: FOUT if the hop has privacy constraints, otherwise LOUT if some input is
+	 * federated. In all other cases the flag is left unchanged.
+	 * @param hop hop for which the decision is made
+	 */
+	private static void privacyBasedHopDecision(Hop hop){
+		PrivacyPropagator.hopPropagation(hop);
+		PrivacyConstraint privacyConstraint = hop.getPrivacy();
+		if ( privacyConstraint != null && privacyConstraint.hasConstraints() )
+			hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+		// NOTE(review): a federated DataOp without privacy constraints presumably gets neither
+		// flag here, because its inputs (the address/range hops) are not federated - confirm intended
+		else if ( hop.someInputFederated() )
+			hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+	}
+
+	/**
+	 * Get privacy constraints of DataOps from federated worker,
+	 * propagate privacy constraints from input to current hop,
+	 * and set federated output flag.
+	 * @param hop current hop
+	 */
+	private static void privacyBasedHopDecisionWithFedCall(Hop hop){
+		loadFederatedPrivacyConstraints(hop);
+		privacyBasedHopDecision(hop);
+	}
+
+	/**
+	 * Get privacy constraints from the federated workers for federated DataOps
+	 * that do not have a privacy constraint yet. The constraints of all federated
+	 * addresses of the DataOp are retrieved and merged.
+	 * @param hop hop for which privacy constraints are loaded
+	 * @throws DMLException if the privacy constraints cannot be retrieved
+	 */
+	private static void loadFederatedPrivacyConstraints(Hop hop){
+		if ( isFederatedDataOp(hop) && hop.getPrivacy() == null){
+			try {
+				hop.setPrivacy(retrievePrivConstraint(hop));
+			}
+			catch(Exception e) {
+				// keep the interrupt status if waiting for a worker response was interrupted
+				if ( e instanceof InterruptedException )
+					Thread.currentThread().interrupt();
+				throw new DMLException(e.getMessage(), e);
+			}
+		}
+	}
+
+	/**
+	 * Sends a privacy constraint request to every federated address of the given federated
+	 * DataOp and merges the returned constraints with PrivacyPropagator.mergeBinary.
+	 * All requests are sent before the first response is awaited.
+	 * @param hop federated DataOp whose first input holds the federated addresses
+	 * @return merged privacy constraint (may be null)
+	 * @throws Exception if a request cannot be sent or a response cannot be retrieved
+	 */
+	private static PrivacyConstraint retrievePrivConstraint(Hop hop) throws Exception {
+		ArrayList<Future<FederatedResponse>> privConstraintFutures = new ArrayList<>();
+		for ( Hop addressHop : hop.getInput(0).getInput() )
+			privConstraintFutures.add(sendPrivConstraintRequest(addressHop));
+		PrivacyConstraint mergedConstraint = null;
+		for ( Future<FederatedResponse> privConstraintFuture : privConstraintFutures )
+			mergedConstraint = PrivacyPropagator.mergeBinary(
+				mergedConstraint, unwrapPrivConstraint(privConstraintFuture));
+		return mergedConstraint;
+	}
+
+	/**
+	 * Sends an EXEC_UDF request running a PrivacyConstraintRetriever to the federated
+	 * worker named by the given address hop.
+	 * @param addressHop literal hop holding one federated address, parsed with
+	 *                   InitFEDInstruction.parseURL into host, port and file name
+	 * @return future of the worker response carrying the privacy constraint
+	 * @throws UnknownHostException if the host cannot be resolved
+	 * @throws SSLException if the secure connection cannot be set up
+	 */
+	private static Future<FederatedResponse> sendPrivConstraintRequest(Hop addressHop)
+		throws UnknownHostException, SSLException
+	{
+		String address = ((LiteralOp) addressHop).getStringValue();
+		String[] parsedAddress = InitFEDInstruction.parseURL(address);
+		String host = parsedAddress[0];
+		int port = Integer.parseInt(parsedAddress[1]);
+		PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]);
+		FederatedRequest privacyRetrieval =
+			new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
+		InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port);
+		return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval);
+	}
+
+	/**
+	 * Waits for the given response and extracts the privacy constraint from its first data element.
+	 * @param privConstraintFuture future of the worker response
+	 * @return privacy constraint returned by the worker (may be null)
+	 * @throws Exception if waiting fails or the response data cannot be read
+	 */
+	private static PrivacyConstraint unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
+		throws Exception
+	{
+		FederatedResponse privConstraintResponse = privConstraintFuture.get();
+		return (PrivacyConstraint) privConstraintResponse.getData()[0];
+	}
+
+	/**
+	 * @param hop hop to check
+	 * @return true if the hop is a DataOp of federated type
+	 */
+	private static boolean isFederatedDataOp(Hop hop){
+		return hop instanceof DataOp && ((DataOp) hop).isFederatedData();
+	}
+
+	/**
+	 * FederatedUDF for retrieving privacy constraint of data stored in file name.
+	 */
+	public static class PrivacyConstraintRetriever extends FederatedUDF {
+		private static final long serialVersionUID = 3551741240135587183L;
+		private final String filename;
+
+		public PrivacyConstraintRetriever(String filename){
+			super(new long[]{});
+			this.filename = filename;
+		}
+
+		/**
+		 * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse.
+		 * @param ec execution context
+		 * @param data one or many data objects
+		 * @return FederatedResponse with privacy constraint object
+		 * @throws DMLRuntimeException (with the original cause) if the metadata cannot be read
+		 */
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			PrivacyConstraint privacyConstraint;
+			FileSystem fs = null;
+			try {
+				String mtdname = DataExpression.getMTDFileName(filename);
+				Path path = new Path(mtdname);
+				fs = IOUtilFunctions.getFileSystem(mtdname);
+				try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
+					JSONObject metadataObject = JSONHelper.parse(br);
+					privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
+				}
+			}
+			catch (DMLPrivacyException | FederatedWorkerHandlerException ex){
+				throw ex;
+			}
+			catch (Exception ex) {
+				String msg = "Exception in reading metadata of: " + filename;
+				throw new DMLRuntimeException(msg, ex);
+			}
+			finally {
+				IOUtilFunctions.closeSilently(fs);
+			}
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 34db155..bae38a2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -19,9 +19,11 @@
package org.apache.sysds.runtime.instructions;
+import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
@@ -66,6 +68,8 @@ public class FEDInstructionParser extends InstructionParser
// Ternary Instruction Opcodes
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
String2FEDInstructionType.put( "-*" , FEDType.Ternary);
+
+ String2FEDInstructionType.put(Append.OPCODE, FEDType.Append);
}
public static FEDInstruction parseSingleInstruction (String str ) {
@@ -98,6 +102,8 @@ public class FEDInstructionParser extends InstructionParser
return
TernaryFEDInstruction.parseInstruction(str);
case Reorg:
return
ReorgFEDInstruction.parseInstruction(str);
+ case Append:
+ return
AppendFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid
FEDERATED Instruction Type: " + fedtype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
index add9bb7..e4453bc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
@@ -136,4 +136,13 @@ public class SqlCPInstruction extends CPInstruction {
public CPOperand getOutput(){
return _output;
}
+
+	/**
+	 * Collects the input operands of this SQL instruction.
+	 * The order is: connection, user, password, query.
+	 * @return a freshly allocated array with the four input operands
+	 */
+	public CPOperand[] getInputs(){
+		CPOperand[] inputs = {_conn, _user, _pass, _query};
+		return inputs;
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 10dd7c6..535e12d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -145,24 +146,43 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
}
//#2 vector - federated matrix multiplication
else if (mo2.isFederated(FType.ROW)) {// VM + MM
- //construct commands: broadcast rhs, fed mv, retrieve
results
- FederatedRequest[] fr1 =
mo2.getFedMapping().broadcastSliced(mo1, true);
- FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2},
- new long[]{fr1[0].getID(),
mo2.getFedMapping().getID()}, true);
- if ( _fedOut.isForcedFederated() ){
- // Partial aggregates (set fedmapping to the
partial aggs)
- FederatedRequest fr3 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID());
- mo2.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
- setPartialOutput(mo2.getFedMapping(), mo1, mo2,
fr2.getID(), ec);
+			// aligned inputs: compute the partial products at the workers without broadcasting mo1
+			if ( mo1.isFederated(FType.COL) && isAggBinaryFedAligned(mo1,mo2) ){
+				FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+					new CPOperand[]{input1, input2},
+					new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
+				if ( _fedOut.isForcedFederated() ){
+					// Partial aggregates (set fedmapping to the partial aggs)
+					mo2.getFedMapping().execute(getTID(), true, fr2);
+					setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+				}
+				else {
+					FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+					// remove the partial results from the workers once they are retrieved
+					// (same pattern as the broadcast branch below)
+					FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr2.getID());
+					//execute federated operations and aggregate
+					Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr2, fr3, fr4);
+					MatrixBlock ret = FederationUtils.aggAdd(tmp);
+					ec.setMatrixOutput(output.getName(), ret);
+				}
}
else {
- FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- MatrixBlock ret = FederationUtils.aggAdd(tmp);
- ec.setMatrixOutput(output.getName(), ret);
+ //construct commands: broadcast rhs, fed mv,
retrieve results
+ FederatedRequest[] fr1 =
mo2.getFedMapping().broadcastSliced(mo1, true);
+ FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{fr1[0].getID(),
mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping
to the partial aggs)
+ FederatedRequest fr3 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ mo2.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
+ setPartialOutput(mo2.getFedMapping(),
mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ //execute federated operations and
aggregate
+ Future<FederatedResponse>[] tmp =
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret =
FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(),
ret);
+ }
}
}
//#3 col-federated matrix vector multiplication
@@ -195,6 +215,28 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
}
/**
+	 * Checks alignment of dimensions for the federated aggregate binary processing without broadcast.
+	 * The inputs are aligned if both federation maps have the same number of ranges and, for every
+	 * index i, the begin and end column of the i-th range of mo1 equal the begin and end row of the
+	 * i-th range of mo2. In that case the federated aggregate binary instruction can be processed
+	 * without broadcasting.
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @return true if the two inputs are aligned for aggregate binary processing without broadcasting
+	 */
+	private static boolean isAggBinaryFedAligned(MatrixObject mo1, MatrixObject mo2){
+		FederatedRange[] mo1FederatedRanges = mo1.getFedMapping().getFederatedRanges();
+		FederatedRange[] mo2FederatedRanges = mo2.getFedMapping().getFederatedRanges();
+		// differing partition counts can never be aligned; checking this first also
+		// prevents indexing past the end of the shorter range array below
+		if ( mo1FederatedRanges.length != mo2FederatedRanges.length )
+			return false;
+		// NOTE(review): ranges are matched by position only; this assumes the i-th ranges of
+		// both federation maps reside at the same federated worker - confirm
+		for ( int i = 0; i < mo1FederatedRanges.length; i++ ){
+			FederatedRange mo1FedRange = mo1FederatedRanges[i];
+			FederatedRange mo2FedRange = mo2FederatedRanges[i];
+
+			if ( mo1FedRange.getBeginDims()[1] != mo2FedRange.getBeginDims()[0]
+				|| mo1FedRange.getEndDims()[1] != mo2FedRange.getEndDims()[0])
+				return false;
+		}
+		return true;
+	}
+
+ /**
* Sets the output with a federated mapping of overlapping partial
aggregates.
* @param federationMap federated map from which the federated metadata
is retrieved
* @param mo1 matrix object with number of rows used to set the number
of rows of the output
diff --git
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
index 18a94b1..ebd9adf 100644
---
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
+++
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
@@ -19,7 +19,54 @@
package org.apache.sysds.runtime.privacy.propagation;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
public enum OperatorType {
Aggregate,
NonAggregate;
+
+	/**
+	 * Classifies an MMChainCPInstruction by the dimensions of its first input:
+	 * a first input with exactly one row and one column yields NonAggregate,
+	 * any other shape yields Aggregate.
+	 * @param inst MMChainCPInstruction to classify
+	 * @param ec execution context providing the data characteristics of the input
+	 * @return operator type of the instruction
+	 */
+	public static OperatorType getAggregationType(MMChainCPInstruction inst, ExecutionContext ec){
+		String firstInput = inst.getInputs()[0].getName();
+		DataCharacteristics dc = ec.getDataCharacteristics(firstInput);
+		boolean isSingleCell = dc.getRows() == 1 && dc.getCols() == 1;
+		return isSingleCell ? NonAggregate : Aggregate;
+	}
+
+	/**
+	 * Classifies an MMTSJCPInstruction by its MMTSJType and the dimensions of its first input:
+	 * for type LEFT the row count is inspected, for every other type the column count.
+	 * An inspected dimension of one yields NonAggregate, anything else Aggregate.
+	 * @param inst MMTSJCPInstruction to classify
+	 * @param ec execution context providing the data characteristics of the input
+	 * @return operator type of the instruction
+	 */
+	public static OperatorType getAggregationType(MMTSJCPInstruction inst, ExecutionContext ec){
+		DataCharacteristics dc = ec.getDataCharacteristics(inst.getInputs()[0].getName());
+		boolean isLeft = inst.getMMTSJType() == MMTSJ.MMTSJType.LEFT;
+		long inspectedDim = isLeft ? dc.getRows() : dc.getCols();
+		return (inspectedDim == 1) ? OperatorType.NonAggregate : OperatorType.Aggregate;
+	}
+
+	/**
+	 * Classifies an AggregateBinaryCPInstruction by the dimensions of its first input:
+	 * with a transposed left input the row count is inspected, otherwise the column count.
+	 * An inspected dimension of one yields NonAggregate, anything else Aggregate.
+	 * @param inst AggregateBinaryCPInstruction to classify
+	 * @param ec execution context providing the data characteristics of the input
+	 * @return operator type of the instruction
+	 */
+	public static OperatorType getAggregationType(AggregateBinaryCPInstruction inst, ExecutionContext ec){
+		DataCharacteristics dc = ec.getDataCharacteristics(inst.input1.getName());
+		long inspectedDim = inst.transposeLeft ? dc.getRows() : dc.getCols();
+		return (inspectedDim == 1) ? OperatorType.NonAggregate : OperatorType.Aggregate;
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 71e1d46..945df73 100644
---
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -19,8 +19,18 @@
package org.apache.sysds.runtime.privacy.propagation;
-import java.util.*;
-
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -39,15 +49,39 @@ import org.apache.wink.json4j.JSONObject;
*/
public class PrivacyPropagator
{
+	/**
+	 * Attaches the privacy constraint found in the given metadata to the given data object.
+	 * If the metadata yields no privacy constraint, the data object is left unchanged.
+	 * @param cd data object that receives the privacy constraint
+	 * @param mtd metadata object to read the privacy constraint from
+	 * @return the given data object
+	 * @throws JSONException if reading the metadata fails
+	 */
public static Data parseAndSetPrivacyConstraint(Data cd, JSONObject mtd)
throws JSONException
{
+		PrivacyConstraint parsedConstraint = parseAndReturnPrivacyConstraint(mtd);
+		if ( parsedConstraint == null )
+			return cd;
+		cd.setPrivacyConstraints(parsedConstraint);
+		return cd;
+	}
+
+ /**
+ * Parses the privacy constraint of the given metadata object
+ * or returns null if no privacy constraint is set in the metadata.
+ * Null is also returned if the privacy entry is present but its value is null.
+ * @param mtd metadata
+ * @return privacy constraint parsed from metadata object
+ * @throws JSONException during parsing of metadata
+ * @throws IllegalArgumentException (from PrivacyLevel.valueOf) if the stored privacy level
+ * is not the name of a PrivacyLevel constant; it is not caught here
+ */
+ public static PrivacyConstraint
parseAndReturnPrivacyConstraint(JSONObject mtd)
+ throws JSONException
+ {
if ( mtd.containsKey(DataExpression.PRIVACY) ) {
String privacyLevel =
mtd.getString(DataExpression.PRIVACY);
if ( privacyLevel != null )
- cd.setPrivacyConstraints(new
PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel)));
+ return new
PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel));
}
- return cd;
+ return null;
}
private static boolean anyInputHasLevel(PrivacyLevel[] inputLevels,
PrivacyLevel targetLevel){
@@ -92,7 +126,13 @@ public class PrivacyPropagator
return PrivacyLevel.None;
}
- public static PrivacyConstraint mergeNary(PrivacyConstraint[]
privacyConstraints, OperatorType operatorType){
+ /**
+ * Merges the given privacy constraints with the core propagation using
the given operator type.
+ * @param privacyConstraints array of privacy constraints to merge
+ * @param operatorType type of operation to use when merging with the
core propagation
+ * @return merged privacy constraint
+ */
+ private static PrivacyConstraint mergeNary(PrivacyConstraint[]
privacyConstraints, OperatorType operatorType){
PrivacyLevel[] privacyLevels = Arrays.stream(privacyConstraints)
.map(constraint -> {
if (constraint != null)
@@ -104,21 +144,17 @@ public class PrivacyPropagator
return new PrivacyConstraint(outputPrivacyLevel);
}
+ /**
+ * Merges the input privacy constraints using the core propagation with
NonAggregate operator type.
+ * @param privacyConstraint1 first privacy constraint
+ * @param privacyConstraint2 second privacy constraint
+ * @return merged privacy constraint
+ */
public static PrivacyConstraint mergeBinary(PrivacyConstraint
privacyConstraint1, PrivacyConstraint privacyConstraint2) {
if (privacyConstraint1 != null && privacyConstraint2 != null){
- PrivacyLevel privacyLevel1 =
privacyConstraint1.getPrivacyLevel();
- PrivacyLevel privacyLevel2 =
privacyConstraint2.getPrivacyLevel();
-
- // One of the inputs are private, hence the output must
be private.
- if (privacyLevel1 == PrivacyLevel.Private ||
privacyLevel2 == PrivacyLevel.Private)
- return new
PrivacyConstraint(PrivacyLevel.Private);
- // One of the inputs are private with aggregation
allowed, but none of the inputs are completely private,
- // hence the output must be private with aggregation.
- else if (privacyLevel1 ==
PrivacyLevel.PrivateAggregation || privacyLevel2 ==
PrivacyLevel.PrivateAggregation)
- return new
PrivacyConstraint(PrivacyLevel.PrivateAggregation);
- // Both inputs have privacy level "None", hence the
privacy constraint can be removed.
- else
- return null;
+ PrivacyLevel[] privacyLevels = new PrivacyLevel[]{
+
privacyConstraint1.getPrivacyLevel(),privacyConstraint2.getPrivacyLevel()};
+ return new
PrivacyConstraint(corePropagation(privacyLevels, OperatorType.NonAggregate));
}
else if (privacyConstraint1 != null)
return privacyConstraint1;
@@ -127,12 +163,36 @@ public class PrivacyPropagator
return null;
}
- public static PrivacyConstraint mergeNary(PrivacyConstraint[]
privacyConstraints){
- PrivacyConstraint mergedPrivacyConstraint =
privacyConstraints[0];
- for ( int i = 1; i < privacyConstraints.length; i++ ){
- mergedPrivacyConstraint =
mergeBinary(mergedPrivacyConstraint, privacyConstraints[i]);
+ /**
+ * Propagate privacy constraints from input hops to given hop.
+ * @param hop which the privacy constraints are propagated to
+ */
+ public static void hopPropagation(Hop hop){
+ PrivacyConstraint[] inputConstraints = hop.getInput().stream()
+ .map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new);
+ if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop
instanceof ReorgOp )
+ hop.setPrivacy(mergeNary(inputConstraints,
OperatorType.NonAggregate));
+ else if ( hop instanceof AggBinaryOp || hop instanceof
AggUnaryOp || hop instanceof UnaryOp )
+ hop.setPrivacy(mergeNary(inputConstraints,
OperatorType.Aggregate));
+ }
+
+ /**
+ * Propagate privacy constraints to output variables
+ * based on privacy constraint of CPOperand output in instruction
+ * which has been set during privacy propagation preprocessing.
+ * @param inst instruction for which privacy constraints are propagated
+ * @param ec execution context
+ */
+ public static void postProcessInstruction(Instruction inst,
ExecutionContext ec){
+ // if inst has output
+ List<CPOperand> instOutputs = getOutputOperands(inst);
+ if (!instOutputs.isEmpty()){
+ for ( CPOperand output : instOutputs ){
+ PrivacyConstraint outputPrivacyConstraint =
output.getPrivacyConstraint();
+ if (
PrivacyUtils.someConstraintSetUnary(outputPrivacyConstraint) )
+ setOutputPrivacyConstraint(ec,
outputPrivacyConstraint, output.getName());
+ }
}
- return mergedPrivacyConstraint;
}
/**
@@ -145,127 +205,88 @@ public class PrivacyPropagator
public static Instruction preprocessInstruction(Instruction inst,
ExecutionContext ec){
switch ( inst.getType() ){
case CONTROL_PROGRAM:
- return preprocessCPInstructionFineGrained(
(CPInstruction) inst, ec );
+ return preprocessCPInstruction( (CPInstruction)
inst, ec );
case BREAKPOINT:
case SPARK:
case GPU:
case FEDERATED:
return inst;
default:
- throwExceptionIfPrivacyActivated(inst);
- return inst;
+ return throwExceptionIfInputOrInstPrivacy(inst,
ec);
}
}
- public static Instruction
preprocessCPInstructionFineGrained(CPInstruction inst, ExecutionContext ec){
- switch ( inst.getCPInstructionType() ){
- case AggregateBinary:
- if ( inst instanceof
AggregateBinaryCPInstruction ){
- // This can only be a matrix
multiplication and it does not count as an aggregation in terms of privacy.
- return
preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction)inst, ec);
- } else if ( inst instanceof
CovarianceCPInstruction ){
- return
preprocessCovarianceCPInstruction((CovarianceCPInstruction)inst, ec);
- } else preprocessInstructionSimple(inst, ec);
- case AggregateTernary:
- //TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec);
- case AggregateUnary:
- // Assumption: aggregates in one or several
dimensions, number of dimensions may change, only certain slices of the data
may be aggregated upon, elements do not change position
- return
preprocessAggregateUnaryCPInstruction((AggregateUnaryCPInstruction)inst, ec);
- case Append:
- return
preprocessAppendCPInstruction((AppendCPInstruction) inst, ec);
+ private static Instruction preprocessCPInstruction(CPInstruction inst,
ExecutionContext ec){
+ switch(inst.getCPInstructionType()){
case Binary:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessBinaryCPInstruction((BinaryCPInstruction) inst, ec);
case Builtin:
case BuiltinNary:
- //TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessBuiltinNary((BuiltinNaryCPInstruction) inst, ec);
- /*case CentralMoment:
- break;
- case Compression:
- break;
- case Covariance:
- break;
- case Ctable:
- break;
- case Dnn:
- break;
- */
case FCall:
- //TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessExternal((FunctionCallCPInstruction) inst, ec);
- /*
- case MMChain:
- break;
- case MMTSJ:
- break;
- case MatrixIndexing:
- break;*/
- case MultiReturnBuiltin:
- case MultiReturnParameterizedBuiltin:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessMultiReturn((ComputationCPInstruction)inst, ec);
- /*case PMMJ:
- break;*/
case ParameterizedBuiltin:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessParameterizedBuiltin((ParameterizedBuiltinCPInstruction) inst, ec);
- /*case Partition:
- break;
- case QPick:
- break;
- case QSort:
- break;*/
case Quaternary:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessQuaternary((QuaternaryCPInstruction) inst, ec);
- /*case Rand:
- break;*/
case Reorg:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec);
- /*case Reshape:
- break;
- case SpoofFused:
- break;
- case Sql:
- break;
- case StringInit:
- break;*/
case Ternary:
- // TODO: Support propagation of fine-grained
privacy constraints
- return
preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec);
- /*case UaggOuterChain:
- break;*/
case Unary:
- // Assumption: No aggregation, elements do not
change position, no change of dimensions
- return
preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec);
+ case MultiReturnBuiltin:
+ case MultiReturnParameterizedBuiltin:
+ case MatrixIndexing:
+ return mergePrivacyConstraintsFromInput( inst,
ec, OperatorType.NonAggregate );
+ case AggregateTernary:
+ case AggregateUnary:
+ return mergePrivacyConstraintsFromInput(inst,
ec, OperatorType.Aggregate);
+ case Append:
+ return
preprocessAppendCPInstruction((AppendCPInstruction) inst, ec);
+ case AggregateBinary:
+ if ( inst instanceof
AggregateBinaryCPInstruction )
+ return
preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction)inst, ec);
+ else return
throwExceptionIfInputOrInstPrivacy(inst, ec);
+ case MMTSJ:
+ OperatorType mmtsjOpType =
OperatorType.getAggregationType((MMTSJCPInstruction) inst, ec);
+ return mergePrivacyConstraintsFromInput(inst,
ec, mmtsjOpType);
+ case MMChain:
+ OperatorType mmChainOpType =
OperatorType.getAggregationType((MMChainCPInstruction) inst, ec);
+ return mergePrivacyConstraintsFromInput(inst,
ec, mmChainOpType);
case Variable:
return
preprocessVariableCPInstruction((VariableCPInstruction) inst, ec);
default:
- return preprocessInstructionSimple(inst, ec);
-
+ return throwExceptionIfInputOrInstPrivacy(inst,
ec);
}
}
- /**
- * Throw exception if privacy constraint activated for instruction or
for input to instruction.
- * @param inst covariance instruction
- * @param ec execution context
- * @return input instruction if privacy constraints are not activated
- */
- private static Instruction
preprocessCovarianceCPInstruction(CovarianceCPInstruction inst,
ExecutionContext ec){
- throwExceptionIfPrivacyActivated(inst);
- for ( CPOperand input : inst.getInputs() ){
- PrivacyConstraint privacyConstraint =
getInputPrivacyConstraint(ec, input);
- if ( privacyConstraint != null){
- throw new DMLPrivacyException("Input of
instruction " + inst + " has privacy constraints activated, but the constraints
are not propagated during preprocessing of instruction.");
- }
+ private static Instruction
preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext
ec){
+ switch ( inst.getVariableOpcode() ) {
+ case CopyVariable:
+ case MoveVariable:
+ case RemoveVariableAndFile:
+ case CastAsMatrixVariable:
+ case CastAsFrameVariable:
+ case Write:
+ case SetFileName:
+ case CastAsScalarVariable:
+ case CastAsDoubleVariable:
+ case CastAsIntegerVariable:
+ case CastAsBooleanVariable:
+ return propagateFirstInputPrivacy(inst, ec);
+ case CreateVariable:
+ return propagateSecondInputPrivacy(inst, ec);
+ case AssignVariable:
+ case RemoveVariable:
+ return mergePrivacyConstraintsFromInput( inst,
ec, OperatorType.NonAggregate );
+ case Read:
+ // Adds scalar object to variable map, hence
input (data type and filename) privacy should not be propagated
+ return inst;
+ default:
+ return throwExceptionIfInputOrInstPrivacy(inst,
ec);
}
- return inst;
}
+ /**
+ * Propagates fine-grained constraints if input has fine-grained
constraints,
+ * otherwise it propagates general constraints.
+ * @param inst aggregate binary instruction for which constraints are
propagated
+ * @param ec execution context
+ * @return instruction with merged privacy constraints propagated to it
and output CPOperand
+ */
private static Instruction
preprocessAggregateBinaryCPInstruction(AggregateBinaryCPInstruction inst,
ExecutionContext ec){
PrivacyConstraint[] privacyConstraints =
getInputPrivacyConstraints(ec, inst.getInputs());
if ( PrivacyUtils.someConstraintSetBinary(privacyConstraints) ){
@@ -279,7 +300,7 @@ public class PrivacyPropagator
ec.releaseMatrixInput(inst.input1.getName(),
inst.input2.getName());
}
else {
- mergedPrivacyConstraint =
mergeNary(privacyConstraints, OperatorType.Aggregate);
+ mergedPrivacyConstraint =
mergeNary(privacyConstraints, OperatorType.getAggregationType(inst, ec));
inst.setPrivacyConstraint(mergedPrivacyConstraint);
}
inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
@@ -287,7 +308,13 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction
preprocessAppendCPInstruction(AppendCPInstruction inst, ExecutionContext ec){
+ /**
+ * Propagates input privacy constraints using general and fine-grained
constraints depending on the AppendType.
+ * @param inst append instruction for which constraints are propagated
+ * @param ec execution context
+ * @return instruction with merged privacy constraints propagated to it
and output CPOperand
+ */
+ private static Instruction
preprocessAppendCPInstruction(AppendCPInstruction inst, ExecutionContext ec){
PrivacyConstraint[] privacyConstraints =
getInputPrivacyConstraints(ec, inst.getInputs());
if ( PrivacyUtils.someConstraintSetBinary(privacyConstraints) ){
if ( inst.getAppendType() ==
AppendCPInstruction.AppendType.STRING ){
@@ -327,92 +354,35 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction
preprocessBinaryCPInstruction(BinaryCPInstruction inst, ExecutionContext ec){
- PrivacyConstraint privacyConstraint1 =
getInputPrivacyConstraint(ec, inst.input1);
- PrivacyConstraint privacyConstraint2 =
getInputPrivacyConstraint(ec, inst.input2);
- if ( privacyConstraint1 != null || privacyConstraint2 != null) {
- PrivacyConstraint mergedPrivacyConstraint =
mergeBinary(privacyConstraint1, privacyConstraint2);
- inst.setPrivacyConstraint(mergedPrivacyConstraint);
-
inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
- }
- return inst;
- }
-
/**
- * Propagate privacy constraint to output if any of the elements are
private.
- * Privacy constraint is always propagated to instruction.
- * @param inst aggregate instruction
+ * Propagates privacy constraints from input to instruction and output
CPOperand based on given operator type.
+ * The propagation is done through the core propagation.
+ * @param inst instruction for which privacy is propagated
* @param ec execution context
- * @return updated instruction with propagated privacy constraints
+ * @param operatorType defining whether the instruction is aggregating
the input
+ * @return instruction with the merged privacy constraint propagated to
it and output CPOperand
*/
- private static Instruction
preprocessAggregateUnaryCPInstruction(AggregateUnaryCPInstruction inst,
ExecutionContext ec){
- PrivacyConstraint privacyConstraint =
getInputPrivacyConstraint(ec, inst.input1);
- if ( privacyConstraint != null ) {
- inst.setPrivacyConstraint(privacyConstraint);
- if ( inst.output != null){
- //Only propagate to output if any of the
elements are private.
- //It is an aggregation, hence the constraint
can be removed in case of any other privacy level.
- if(privacyConstraint.hasPrivateElements())
- inst.output.setPrivacyConstraint(new
PrivacyConstraint(PrivacyLevel.Private));
- }
- }
- return inst;
+ private static Instruction mergePrivacyConstraintsFromInput(Instruction
inst, ExecutionContext ec,
+ OperatorType operatorType){
+ return mergePrivacyConstraintsFromInput(inst, ec,
getInputOperands(inst), getOutputOperands(inst), operatorType);
}
/**
- * Throw exception if privacy constraints are activated or return
instruction if privacy is not activated
- * @param inst instruction
+ * Propagates privacy constraints from input to instruction and output
CPOperand based on given operator type.
+ * The propagation is done through the core propagation.
+ * @param inst instruction for which privacy is propagated
* @param ec execution context
- * @return instruction
+ * @param inputs to instruction
+ * @param outputs of instruction
+ * @param operatorType defining whether the instruction is aggregating
the input
+ * @return instruction with the merged privacy constraint propagated to
it and output CPOperand
*/
- public static Instruction preprocessInstructionSimple(Instruction inst,
ExecutionContext ec){
- throwExceptionIfPrivacyActivated(inst);
- return inst;
- }
-
-
- public static Instruction preprocessExternal(FunctionCallCPInstruction
inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst,
- ec,
- inst.getInputs(),
- inst.getBoundOutputParamNames().toArray(new String[0])
- );
- }
-
- public static Instruction
preprocessMultiReturn(ComputationCPInstruction inst, ExecutionContext ec){
- List<CPOperand> outputs = getOutputOperands(inst);
- return mergePrivacyConstraintsFromInput(inst, ec,
inst.getInputs(), outputs);
- }
-
- public static Instruction
preprocessParameterizedBuiltin(ParameterizedBuiltinCPInstruction inst,
ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(inst, ec,
inst.getInputs(), inst.getOutput() );
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction
inst, ExecutionContext ec, CPOperand[] inputs, String[] outputNames){
+ private static Instruction mergePrivacyConstraintsFromInput(Instruction
inst, ExecutionContext ec,
+ CPOperand[] inputs, List<CPOperand> outputs, OperatorType
operatorType){
if ( inputs != null && inputs.length > 0 ){
PrivacyConstraint[] privacyConstraints =
getInputPrivacyConstraints(ec, inputs);
if ( privacyConstraints != null ){
- PrivacyConstraint mergedPrivacyConstraint =
mergeNary(privacyConstraints);
-
inst.setPrivacyConstraint(mergedPrivacyConstraint);
- if ( outputNames != null ){
- for (String outputName : outputNames)
- setOutputPrivacyConstraint(ec,
mergedPrivacyConstraint, outputName);
- }
- }
- }
- return inst;
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction
inst, ExecutionContext ec, CPOperand[] inputs, CPOperand output){
- return mergePrivacyConstraintsFromInput(inst, ec, inputs,
getSingletonList(output));
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction
inst, ExecutionContext ec, CPOperand[] inputs, List<CPOperand> outputs){
- if ( inputs != null && inputs.length > 0 ){
- PrivacyConstraint[] privacyConstraints =
getInputPrivacyConstraints(ec, inputs);
- if ( privacyConstraints != null ){
- PrivacyConstraint mergedPrivacyConstraint =
mergeNary(privacyConstraints);
+ PrivacyConstraint mergedPrivacyConstraint =
mergeNary(privacyConstraints, operatorType);
inst.setPrivacyConstraint(mergedPrivacyConstraint);
for ( CPOperand output : outputs ){
if ( output != null ) {
@@ -424,54 +394,24 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction
preprocessBuiltinNary(BuiltinNaryCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(inst, ec,
inst.getInputs(), inst.getOutput() );
- }
-
- public static Instruction preprocessQuaternary(QuaternaryCPInstruction
inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst,
- ec,
- new CPOperand[]
{inst.input1,inst.input2,inst.input3,inst.getInput4()},
- inst.output
- );
- }
-
- public static Instruction
preprocessTernaryCPInstruction(ComputationCPInstruction inst, ExecutionContext
ec){
- return mergePrivacyConstraintsFromInput(inst, ec,
inst.getInputs(), inst.output);
- }
-
- public static Instruction
preprocessUnaryCPInstruction(UnaryCPInstruction inst, ExecutionContext ec){
- return propagateInputPrivacy(inst, ec, inst.input1,
inst.output);
- }
-
- public static Instruction
preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext
ec){
- switch ( inst.getVariableOpcode() ) {
- case CreateVariable:
- return propagateSecondInputPrivacy(inst, ec);
- case AssignVariable:
- return propagateInputPrivacy(inst, ec,
inst.getInput1(), inst.getInput2());
- case CopyVariable:
- case MoveVariable:
- case RemoveVariableAndFile:
- case CastAsMatrixVariable:
- case CastAsFrameVariable:
- case Write:
- case SetFileName:
- return propagateFirstInputPrivacy(inst, ec);
- case RemoveVariable:
- return propagateAllInputPrivacy(inst, ec);
- case CastAsScalarVariable:
- case CastAsDoubleVariable:
- case CastAsIntegerVariable:
- case CastAsBooleanVariable:
- return
propagateCastAsScalarVariablePrivacy(inst, ec);
- case Read:
- return inst;
- default:
- throwExceptionIfPrivacyActivated(inst);
- return inst;
+ /**
+ * Throw exception if privacy constraint activated for instruction or
for input to instruction.
+ * @param inst covariance instruction
+ * @param ec execution context
+ * @return input instruction if privacy constraints are not activated
+ */
+ private static Instruction
throwExceptionIfInputOrInstPrivacy(Instruction inst, ExecutionContext ec){
+ throwExceptionIfPrivacyActivated(inst);
+ CPOperand[] inputOperands = getInputOperands(inst);
+ if (inputOperands != null){
+ for ( CPOperand input : inputOperands ){
+ PrivacyConstraint privacyConstraint =
getInputPrivacyConstraint(ec, input);
+ if ( privacyConstraint != null){
+ throw new DMLPrivacyException("Input of
instruction " + inst + " has privacy constraints activated, but the constraints
are not propagated during preprocessing of instruction.");
+ }
+ }
}
+ return inst;
}
private static void throwExceptionIfPrivacyActivated(Instruction inst){
@@ -481,28 +421,6 @@ public class PrivacyPropagator
}
/**
- * Propagate privacy from first input.
- * @param inst Instruction
- * @param ec execution context
- * @return instruction with or without privacy constraints
- */
- private static Instruction
propagateCastAsScalarVariablePrivacy(VariableCPInstruction inst,
ExecutionContext ec){
- inst = (VariableCPInstruction) propagateFirstInputPrivacy(inst,
ec);
- return inst;
- }
-
- /**
- * Propagate privacy constraints from all inputs if privacy constraints
are set.
- * @param inst instruction
- * @param ec execution context
- * @return instruction with or without privacy constraints
- */
- private static Instruction
propagateAllInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst, ec, inst.getInputs().toArray(new CPOperand[0]),
inst.getOutput());
- }
-
- /**
* Propagate privacy constraint to instruction and output of instruction
* if data of first input is CacheableData and
* privacy constraint is activated.
@@ -561,7 +479,12 @@ public class PrivacyPropagator
return null;
}
-
+ /**
+ * Returns input privacy constraints as array or returns null if no
privacy constraints are found in the inputs.
+ * @param ec execution context
+ * @param inputs from which privacy constraints are retrieved
+ * @return array of privacy constraints from inputs
+ */
private static PrivacyConstraint[]
getInputPrivacyConstraints(ExecutionContext ec, CPOperand[] inputs){
if ( inputs != null && inputs.length > 0){
boolean privacyFound = false;
@@ -595,41 +518,29 @@ public class PrivacyPropagator
}
/**
- * Propagate privacy constraints to output variables
- * based on privacy constraint of CPOperand output in instruction
- * which has been set during privacy propagation preprocessing.
- * @param inst instruction for which privacy constraints are propagated
- * @param ec execution context
+ * Returns input CPOperands of instruction or returns null if
instruction type is not supported by this method.
+ * @param inst instruction from which the inputs are retrieved
+ * @return array of input CPOperands or null
*/
- public static void postProcessInstruction(Instruction inst,
ExecutionContext ec){
- // if inst has output
- List<CPOperand> instOutputs = getOutputOperands(inst);
- if (!instOutputs.isEmpty()){
- for ( CPOperand output : instOutputs ){
- PrivacyConstraint outputPrivacyConstraint =
output.getPrivacyConstraint();
- if (
PrivacyUtils.someConstraintSetUnary(outputPrivacyConstraint) )
- setOutputPrivacyConstraint(ec,
outputPrivacyConstraint, output.getName());
- }
- }
- }
-
- @SuppressWarnings("unused")
- private static String[] getOutputVariableName(Instruction inst){
- String[] instructionOutputNames = null;
- // The order of the following statements is important
- if ( inst instanceof
MultiReturnParameterizedBuiltinCPInstruction )
- instructionOutputNames =
((MultiReturnParameterizedBuiltinCPInstruction) inst).getOutputNames();
- else if ( inst instanceof MultiReturnBuiltinCPInstruction )
- instructionOutputNames =
((MultiReturnBuiltinCPInstruction) inst).getOutputNames();
- else if ( inst instanceof ComputationCPInstruction )
- instructionOutputNames = new
String[]{((ComputationCPInstruction) inst).getOutputVariableName()};
- else if ( inst instanceof VariableCPInstruction )
- instructionOutputNames = new
String[]{((VariableCPInstruction) inst).getOutputVariableName()};
- else if ( inst instanceof SqlCPInstruction )
- instructionOutputNames = new
String[]{((SqlCPInstruction) inst).getOutputVariableName()};
- return instructionOutputNames;
+ private static CPOperand[] getInputOperands(Instruction inst){
+ if ( inst instanceof ComputationCPInstruction )
+ return ((ComputationCPInstruction)inst).getInputs();
+ if ( inst instanceof BuiltinNaryCPInstruction )
+ return ((BuiltinNaryCPInstruction)inst).getInputs();
+ if ( inst instanceof FunctionCallCPInstruction )
+ return ((FunctionCallCPInstruction)inst).getInputs();
+ if ( inst instanceof SqlCPInstruction )
+ return ((SqlCPInstruction)inst).getInputs();
+ else return null;
}
+ /**
+ * Returns a list of output CPOperands of instruction or an empty list
if the instruction has no outputs.
+ * Note that this method needs to be extended as new instruction types
are added, otherwise it will
+ * return an empty list for instructions that may have outputs.
+ * @param inst instruction from which the outputs are retrieved
+ * @return list of outputs
+ */
private static List<CPOperand> getOutputOperands(Instruction inst){
// The order of the following statements is important
if ( inst instanceof
MultiReturnParameterizedBuiltinCPInstruction )
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 79fe54f..342a26b 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.privacy.fedplanning;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -81,7 +80,6 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
}
@Test
- @Ignore
public void federatedRowSum(){
federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
}
@@ -98,14 +96,12 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
}
@Test
- @Ignore
public void federatedAggregateBinaryColFedSequence(){
cols = rows;
federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
}
@Test
- @Ignore
public void federatedAggregateBinarySequence2(){
federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
}
@@ -147,8 +143,8 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
if ( testName.equals(TEST_NAME_5) ){
writeColStandardMatrix("X1", 42);
writeColStandardMatrix("X2", 1340);
- writeColStandardMatrix("Y1", 44);
- writeColStandardMatrix("Y2", 21);
+ writeColStandardMatrix("Y1", 44, null);
+ writeColStandardMatrix("Y2", 21, null);
}
else if ( testName.equals(TEST_NAME_6) ){
writeColStandardMatrix("X1", 42);