This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new 98cb93d [SYSTEMDS-362] Federated runtime propagation of privacy constraints 98cb93d is described below commit 98cb93d0dbbc2f1fabc3796bbc21aca1874eed5f Author: sebwrede <seb.wr...@hotmail.com> AuthorDate: Sun May 31 19:41:45 2020 +0200 [SYSTEMDS-362] Federated runtime propagation of privacy constraints * Runtime propagation of privacy constraints * Privacy level as Enum with three levels: Private, PrivateAggregate, and None * Privacy handling in FederatedWorkerHandler preventing private data from being included in federated response * Test of privacy handling of different federated request types * Test of different privacy levels and combinations for Federated L2SVM Closes #919. --- src/main/java/org/apache/sysds/hops/Hop.java | 2 +- .../org/apache/sysds/parser/BinaryExpression.java | 2 +- .../org/apache/sysds/parser/DataExpression.java | 11 +- .../java/org/apache/sysds/parser/Identifier.java | 5 +- .../federated/FederatedWorkerHandler.java | 7 + .../sysds/runtime/instructions/Instruction.java | 4 + .../instructions/cp/BuiltinNaryCPInstruction.java | 8 + .../runtime/instructions/cp/CPInstruction.java | 3 + .../instructions/cp/QuaternaryCPInstruction.java | 3 + .../instructions/cp/VariableCPInstruction.java | 518 ++++++++++++--------- .../runtime/instructions/fed/FEDInstruction.java | 5 +- .../instructions/spark/ReblockSPInstruction.java | 2 +- ...acyConstraint.java => DMLPrivacyException.java} | 38 +- .../sysds/runtime/privacy/PrivacyConstraint.java | 30 +- .../sysds/runtime/privacy/PrivacyMonitor.java | 96 ++++ .../sysds/runtime/privacy/PrivacyPropagator.java | 315 ++++++++++++- .../org/apache/sysds/runtime/util/HDFSTool.java | 7 +- .../test/functions/privacy/FederatedL2SVMTest.java | 384 +++++++++++++++ .../privacy/FederatedWorkerHandlerTest.java | 339 ++++++++++++++ .../MatrixMultiplicationPropagationTest.java | 53 ++- .../privacy/MatrixRuntimePropagationTest.java | 123 +++++ 
.../privacy/MatrixRuntimePropagationTest.dml | 28 ++ 22 files changed, 1695 insertions(+), 288 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index f0ef363..24aade1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -73,7 +73,7 @@ public abstract class Hop implements ParseInfo protected ValueType _valueType; protected boolean _visited = false; protected DataCharacteristics _dc = new MatrixCharacteristics(); - protected PrivacyConstraint _privacyConstraint = new PrivacyConstraint(); + protected PrivacyConstraint _privacyConstraint = null; protected UpdateType _updateType = UpdateType.COPY; protected ArrayList<Hop> _parent = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/parser/BinaryExpression.java b/src/main/java/org/apache/sysds/parser/BinaryExpression.java index 6c177e2..acccb66 100644 --- a/src/main/java/org/apache/sysds/parser/BinaryExpression.java +++ b/src/main/java/org/apache/sysds/parser/BinaryExpression.java @@ -146,7 +146,7 @@ public class BinaryExpression extends Expression } // Set privacy of output - output.setPrivacy(PrivacyPropagator.MergeBinary( + output.setPrivacy(PrivacyPropagator.mergeBinary( getLeft().getOutput().getPrivacy(), getRight().getOutput().getPrivacy())); this.setOutput(output); diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java index c94532d..779f788 100644 --- a/src/main/java/org/apache/sysds/parser/DataExpression.java +++ b/src/main/java/org/apache/sysds/parser/DataExpression.java @@ -37,6 +37,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.io.FileFormatPropertiesMM; import org.apache.sysds.runtime.io.IOUtilFunctions; +import 
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.JSONHelper; @@ -1097,10 +1098,8 @@ public class DataExpression extends DataIdentifier // set privacy Expression eprivacy = getVarParam("privacy"); - boolean privacy = false; - if ( eprivacy != null ) { - privacy = Boolean.valueOf(eprivacy.toString()); - getOutput().setPrivacy(privacy); + if ( eprivacy != null ){ + getOutput().setPrivacy(PrivacyLevel.valueOf(eprivacy.toString())); } // Following dimension checks must be done when data type = MATRIX_DATA_TYPE @@ -2074,7 +2073,6 @@ public class DataExpression extends DataIdentifier if ( key.toString().equalsIgnoreCase(DELIM_HAS_HEADER_ROW) || key.toString().equalsIgnoreCase(DELIM_FILL) || key.toString().equalsIgnoreCase(DELIM_SPARSE) - || key.toString().equalsIgnoreCase(PRIVACY) ) { // parse these parameters as boolean values BooleanIdentifier boolId = null; @@ -2096,7 +2094,8 @@ public class DataExpression extends DataIdentifier removeVarParam(key.toString()); addVarParam(key.toString(), doubleId); } - else if (key.toString().equalsIgnoreCase(DELIM_NA_STRINGS)) { + else if (key.toString().equalsIgnoreCase(DELIM_NA_STRINGS) + || key.toString().equalsIgnoreCase(PRIVACY)) { String naStrings = null; if ( val instanceof String) { naStrings = val.toString(); diff --git a/src/main/java/org/apache/sysds/parser/Identifier.java b/src/main/java/org/apache/sysds/parser/Identifier.java index 36d93f2..3ea3252 100644 --- a/src/main/java/org/apache/sysds/parser/Identifier.java +++ b/src/main/java/org/apache/sysds/parser/Identifier.java @@ -26,6 +26,7 @@ import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import 
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; public abstract class Identifier extends Expression { @@ -104,8 +105,8 @@ public abstract class Identifier extends Expression _nnz = nnzs; } - public void setPrivacy(boolean privacy){ - _privacy = new PrivacyConstraint(privacy); + public void setPrivacy(PrivacyLevel privacyLevel){ + _privacy = new PrivacyConstraint(privacyLevel); } public void setPrivacy(PrivacyConstraint privacyConstraint){ diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index 279685b..bba731c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -51,6 +51,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.runtime.privacy.PrivacyMonitor; +import org.apache.sysds.runtime.privacy.PrivacyPropagator; import org.apache.sysds.utils.JSONHelper; import org.apache.wink.json4j.JSONObject; @@ -149,6 +151,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { return new FederatedResponse(FederatedResponse.Type.ERROR, "Could not parse metadata file"); mc.setRows(mtd.getLong(DataExpression.READROWPARAM)); mc.setCols(mtd.getLong(DataExpression.READCOLPARAM)); + cd = PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd); fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE)); } } @@ -181,6 +184,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private FederatedResponse executeMatVecMult(long varID, MatrixBlock vector, boolean isMatVecMult) { MatrixObject 
matTo = (MatrixObject) _vars.get(varID); + matTo = PrivacyMonitor.handlePrivacy(matTo); MatrixBlock matBlock1 = matTo.acquireReadAndRelease(); // TODO other datatypes AggregateBinaryOperator ab_op = new AggregateBinaryOperator( @@ -199,6 +203,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private FederatedResponse getVariableData(long varID) { Data dataObject = _vars.get(varID); + dataObject = PrivacyMonitor.handlePrivacy(dataObject); switch (dataObject.getDataType()) { case TENSOR: return new FederatedResponse(FederatedResponse.Type.SUCCESS, @@ -233,6 +238,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { + dataObject.getDataType().name()); } MatrixObject matrixObject = (MatrixObject) dataObject; + matrixObject = PrivacyMonitor.handlePrivacy(matrixObject); MatrixBlock matrixBlock = matrixObject.acquireRead(); // create matrix for calculation with correction MatrixCharacteristics mc = new MatrixCharacteristics(); @@ -270,6 +276,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private FederatedResponse executeScalarOperation(long varID, ScalarOperator operator) { Data dataObject = _vars.get(varID); + dataObject = PrivacyMonitor.handlePrivacy(dataObject); if (dataObject.getDataType() != Types.DataType.MATRIX) { return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: ScalarOperator dont support " diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java index adcae38..db867ef 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java @@ -138,6 +138,10 @@ public abstract class Instruction privacyConstraint = lop.getPrivacyConstraint(); } + public void setPrivacyConstraint(PrivacyConstraint pc){ + privacyConstraint = pc; + } + public PrivacyConstraint 
getPrivacyConstraint(){ return privacyConstraint; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java index 1fa9d2b..3111042 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java @@ -48,6 +48,14 @@ public abstract class BuiltinNaryCPInstruction extends CPInstruction this.inputs = inputs; } + public CPOperand[] getInputs(){ + return inputs; + } + + public CPOperand getOutput(){ + return output; + } + public static BuiltinNaryCPInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index 1e60eea..82aaa7d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.instructions.CPInstructionParser; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.privacy.PrivacyPropagator; public abstract class CPInstruction extends Instruction { @@ -95,6 +96,8 @@ public abstract class CPInstruction extends Instruction //robustness federated instructions (runtime assignment) tmp = FEDInstructionUtils.checkAndReplaceCP(tmp, ec); + + tmp = PrivacyPropagator.preprocessInstruction(tmp, ec); return tmp; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java index de30062..7f8ec4d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java @@ -96,6 +96,9 @@ public class QuaternaryCPInstruction extends ComputationCPInstruction { throw new DMLRuntimeException("Unexpected opcode in QuaternaryCPInstruction: " + inst); } + public CPOperand getInput4() { + return input4; + } @Override public void processInstruction(ExecutionContext ec) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 21b3f637..f40abb0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -58,6 +58,10 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.meta.TensorCharacteristics; +import org.apache.sysds.runtime.privacy.DMLPrivacyException; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.runtime.util.ProgramConverter; @@ -289,6 +293,10 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace return ret; } + public CPOperand getOutput(){ + return output; + } + private static int getArity(VariableOperationCode op) { switch(op) { case Write: @@ -512,71 +520,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace switch ( opcode ) { 
case CreateVariable: - //PRE: for robustness we cleanup existing variables, because a setVariable - //would cause a buffer pool memory leak as these objects would never be removed - if(ec.containsVariable(getInput1())) - processRemoveVariableInstruction(ec, getInput1().getName()); - - if ( getInput1().getDataType() == DataType.MATRIX ) { - //create new variable for symbol table and cache - //(existing objects gets cleared through rmvar instructions) - String fname = getInput2().getName(); - // check if unique filename needs to be generated - if( Boolean.parseBoolean(getInput3().getName()) ) - fname = getUniqueFileName(fname); - MatrixObject obj = new MatrixObject(getInput1().getValueType(), fname); - //clone meta data because it is updated on copy-on-write, otherwise there - //is potential for hidden side effects between variables. - obj.setMetaData((MetaData)metadata.clone()); - obj.setPrivacyConstraints(getPrivacyConstraint()); - obj.setFileFormatProperties(_formatProperties); - obj.setMarkForLinCache(true); - obj.enableCleanup(!getInput1().getName() - .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX)); - ec.setVariable(getInput1().getName(), obj); - - obj.setUpdateType(_updateType); - if(DMLScript.STATISTICS && _updateType.isInPlace()) - Statistics.incrementTotalUIPVar(); - } - else if( getInput1().getDataType() == DataType.TENSOR ) { - //create new variable for symbol table and cache - //(existing objects gets cleared through rmvar instructions) - String fname = getInput2().getName(); - // check if unique filename needs to be generated - if( Boolean.parseBoolean(getInput3().getName()) ) - fname = getUniqueFileName(fname); - CacheableData<?> obj = new TensorObject(getInput1().getValueType(), fname); - //clone meta data because it is updated on copy-on-write, otherwise there - //is potential for hidden side effects between variables. 
- obj.setMetaData((MetaData)metadata.clone()); - obj.setFileFormatProperties(_formatProperties); - obj.enableCleanup(!getInput1().getName() - .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX)); - ec.setVariable(getInput1().getName(), obj); - - // TODO update - } - else if( getInput1().getDataType() == DataType.FRAME ) { - String fname = getInput2().getName(); - if( Boolean.parseBoolean(getInput3().getName()) ) - fname = getUniqueFileName(fname); - FrameObject fobj = new FrameObject(fname); - fobj.setMetaData((MetaData)metadata.clone()); - fobj.setFileFormatProperties(_formatProperties); - if( _schema != null ) - fobj.setSchema(_schema); //after metadata - fobj.enableCleanup(!getInput1().getName() - .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX)); - ec.setVariable(getInput1().getName(), fobj); - } - else if ( getInput1().getDataType() == DataType.SCALAR ){ - //created variable not called for scalars - ec.setScalarOutput(getInput1().getName(), null); - } - else { - throw new DMLRuntimeException("Unexpected data type: " + getInput1().getDataType()); - } + processCreateVariableInstruction(ec); break; case AssignVariable: @@ -598,51 +542,241 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace break; case RemoveVariableAndFile: - // Remove the variable from HashMap _variables, and possibly delete the data on disk. 
- boolean del = ( (BooleanObject) ec.getScalarInput(getInput2().getName(), getInput2().getValueType(), true) ).getBooleanValue(); - MatrixObject m = (MatrixObject) ec.removeVariable(getInput1().getName()); + processRemoveVariableAndFileInstruction(ec); + break; + + case CastAsScalarVariable: //castAsScalarVariable + processCastAsScalarVariableInstruction(ec); + break; + + case CastAsMatrixVariable: + processCastAsMatrixVariableInstruction(ec); + break; + + case CastAsFrameVariable: + processCastAsFrameVariableInstruction(ec); + break; + + case CastAsDoubleVariable: + ScalarObject scalarDoubleInput = ec.getScalarInput(getInput1()); + ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToDouble(scalarDoubleInput)); + break; + + case CastAsIntegerVariable: + ScalarObject scalarLongInput = ec.getScalarInput(getInput1()); + ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToLong(scalarLongInput)); + break; + + case CastAsBooleanVariable: + ScalarObject scalarBooleanInput = ec.getScalarInput(getInput1()); + ec.setScalarOutput(output.getName(), new BooleanObject(scalarBooleanInput.getBooleanValue())); + break; - if ( !del ) { - // HDFS file should be retailed after clearData(), - // therefore data must be exported if dirty flag is set - if ( m.isDirty() ) - m.exportData(); + case Read: + processReadInstruction(ec); + break; + + case Write: + processWriteInstruction(ec); + break; + + case SetFileName: + processSetFileNameInstruction(ec); + break; + + default: + throw new DMLRuntimeException("Unknown opcode: " + opcode ); + } + } + + /** + * Handler for processInstruction "CreateVariable" case + * @param ec execution context of the instruction + */ + private void processCreateVariableInstruction(ExecutionContext ec){ + //PRE: for robustness we cleanup existing variables, because a setVariable + //would cause a buffer pool memory leak as these objects would never be removed + if(ec.containsVariable(getInput1())) + processRemoveVariableInstruction(ec, 
getInput1().getName()); + + switch(getInput1().getDataType()) { + case MATRIX: { + String fname = createUniqueFilename(); + MatrixObject obj = new MatrixObject(getInput1().getValueType(), fname); + setCacheableDataFields(obj); + obj.setUpdateType(_updateType); + obj.setMarkForLinCache(true); + ec.setVariable(getInput1().getName(), obj); + if(DMLScript.STATISTICS && _updateType.isInPlace()) + Statistics.incrementTotalUIPVar(); + break; } - else { - //throw new DMLRuntimeException("rmfilevar w/ true is not expected! " + instString); - //cleanDataOnHDFS(pb, input1.getName()); - cleanDataOnHDFS( m ); + case TENSOR: { + String fname = createUniqueFilename(); + TensorObject obj = new TensorObject(getInput1().getValueType(), fname); + setCacheableDataFields(obj); + ec.setVariable(getInput1().getName(), obj); + break; } + case FRAME: { + String fname = createUniqueFilename(); + FrameObject fobj = new FrameObject(fname); + setCacheableDataFields(fobj); + if( _schema != null ) + fobj.setSchema(_schema); //after metadata + ec.setVariable(getInput1().getName(), fobj); + break; + } + case SCALAR: { + //created variable not called for scalars + ec.setScalarOutput(getInput1().getName(), null); + break; + } + default: + throw new DMLRuntimeException("Unexpected data type: " + getInput1().getDataType()); + } + } + + private String createUniqueFilename(){ + //create new variable for symbol table and cache + //(existing objects gets cleared through rmvar instructions) + String fname = getInput2().getName(); + // check if unique filename needs to be generated + if( Boolean.parseBoolean(getInput3().getName()) ) { + fname = getUniqueFileName(fname); + } + return fname; + } + + private void setCacheableDataFields(CacheableData<?> obj){ + //clone meta data because it is updated on copy-on-write, otherwise there + //is potential for hidden side effects between variables. 
+ obj.setMetaData((MetaData)metadata.clone()); + obj.setPrivacyConstraints(getPrivacyConstraint()); + obj.enableCleanup(!getInput1().getName() + .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX)); + obj.setFileFormatProperties(_formatProperties); + } + + /** + * Handler for mvvar instructions. + * Example: mvvar <srcvar> <destFile> <format> + * Move the file pointed by srcvar to destFile. + * Currently, applicable only when format=binaryblock. + * + * @param ec execution context + */ + @SuppressWarnings("rawtypes") + private void processMoveInstruction(ExecutionContext ec) { + + if ( getInput3() == null ) { + // example: mvvar tempA A - // check if in-memory object can be cleaned up - if ( !ec.getVariables().hasReferences(m) ) { - // no other variable in the symbol table points to the same Data object as that of input1.getName() + // get source variable + Data srcData = ec.getVariable(getInput1().getName()); + + if ( srcData == null ) { + throw new DMLRuntimeException("Unexpected error: could not find a data object " + + "for variable name:" + getInput1().getName() + ", while processing instruction "); + } + + if( getInput2().getDataType().isMatrix() || getInput2().getDataType().isFrame() ) { + // remove existing variable bound to target name + Data tgt = ec.removeVariable(getInput2().getName()); - //remove matrix object from cache - m.clearData(); + //cleanup matrix data on fs/hdfs (if necessary) + if( tgt != null ) + ec.cleanupDataObject(tgt); } - - break; - case CastAsScalarVariable: //castAsScalarVariable - if( getInput1().getDataType().isFrame() ) { - FrameBlock fBlock = ec.getFrameInput(getInput1().getName()); - if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 ) - throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar."); - Object value = fBlock.get(0,0); - ec.releaseFrameInput(getInput1().getName()); - 
ec.setScalarOutput(output.getName(), - ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value)); + // do the actual move + ec.setVariable(getInput2().getName(), srcData); + ec.removeVariable(getInput1().getName()); + } + else { + // example instruction: mvvar <srcVar> <destFile> <format> + if ( ec.getVariable(getInput1().getName()) == null ) + throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString()); + + Data object = ec.getVariable(getInput1().getName()); + + if ( getInput3().getName().equalsIgnoreCase("binaryblock") ) { + boolean success = false; + success = ((CacheableData)object).moveData(getInput2().getName(), getInput3().getName()); + if (!success) { + throw new DMLRuntimeException("Failed to move var " + getInput1().getName() + " to file " + getInput2().getName() + "."); + } } - else if( getInput1().getDataType().isMatrix() ) { + else + if(object instanceof MatrixObject) + throw new DMLRuntimeException("Unexpected formats while copying: from matrix blocks [" + + ((MatrixObject)object).getBlocksize() + "] to " + getInput3().getName()); + else if (object instanceof FrameObject) + throw new DMLRuntimeException("Unexpected formats while copying: from fram object [" + + ((FrameObject)object).getNumColumns() + "," + ((FrameObject)object).getNumColumns() + "] to " + getInput3().getName()); + } + } + + /** + * Handler for RemoveVariableAndFile instruction + * + * @param ec execution context + */ + private void processRemoveVariableAndFileInstruction(ExecutionContext ec){ + // Remove the variable from HashMap _variables, and possibly delete the data on disk. 
+ boolean del = ( (BooleanObject) ec.getScalarInput(getInput2().getName(), getInput2().getValueType(), true) ).getBooleanValue(); + MatrixObject m = (MatrixObject) ec.removeVariable(getInput1().getName()); + + if ( !del ) { + // HDFS file should be retailed after clearData(), + // therefore data must be exported if dirty flag is set + if ( m.isDirty() ) + m.exportData(); + } + else { + //throw new DMLRuntimeException("rmfilevar w/ true is not expected! " + instString); + //cleanDataOnHDFS(pb, input1.getName()); + cleanDataOnHDFS( m ); + } + + // check if in-memory object can be cleaned up + if ( !ec.getVariables().hasReferences(m) ) { + // no other variable in the symbol table points to the same Data object as that of input1.getName() + + //remove matrix object from cache + m.clearData(); + } + } + + /** + * Process CastAsScalarVariable instruction. + * @param ec execution context + */ + private void processCastAsScalarVariableInstruction(ExecutionContext ec){ + //TODO: Create privacy constraints for ScalarObject so that the privacy constraints can be propagated to scalars as well. 
+ PrivacyMonitor.handlePrivacyScalarOutput(getInput1(), ec); + + switch( getInput1().getDataType() ) { + case MATRIX: { MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName()); if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 ) throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar."); double value = mBlock.getValue(0,0); ec.releaseMatrixInput(getInput1().getName()); ec.setScalarOutput(output.getName(), new DoubleObject(value)); + break; + } + case FRAME: { + FrameBlock fBlock = ec.getFrameInput(getInput1().getName()); + if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 ) + throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar."); + Object value = fBlock.get(0,0); + ec.releaseFrameInput(getInput1().getName()); + ec.setScalarOutput(output.getName(), + ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value)); + break; } - else if( getInput1().getDataType().isTensor() ) { + case TENSOR: { TensorBlock tBlock = ec.getTensorInput(getInput1().getName()); if (tBlock.getNumDims() != 2 || tBlock.getNumRows() != 1 || tBlock.getNumColumns() != 1) throw new DMLRuntimeException("Dimension mismatch - unable to cast tensor '" + getInput1().getName() + "' to scalar."); @@ -650,31 +784,41 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace ec.setScalarOutput(output.getName(), ScalarObjectFactory .createScalarObject(vt, tBlock.get(new int[] {0, 0}))); ec.releaseTensorInput(getInput1().getName()); + break; } - else if( getInput1().getDataType().isList() ) { + case LIST: { //TODO handling of cleanup status, potentially new object ListObject list = (ListObject)ec.getVariable(getInput1().getName()); ec.setVariable(output.getName(), list.slice(0)); + break; } - else { + default: 
throw new DMLRuntimeException("Unsupported data type " + "in as.scalar(): "+getInput1().getDataType().name()); - } - break; - case CastAsMatrixVariable:{ - if( getInput1().getDataType().isFrame() ) { + } + } + + /** + * Handler for CastAsMatrixVariable instruction + * @param ec execution context + */ + private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { + switch( getInput1().getDataType() ) { + case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); MatrixBlock out = DataConverter.convertToMatrixBlock(fin); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); + break; } - else if( getInput1().getDataType().isScalar() ) { + case SCALAR: { ScalarObject scalarInput = ec.getScalarInput( getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral()); MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue()); ec.setMatrixOutput(output.getName(), out); + break; } - else if( getInput1().getDataType().isList() ) { + case LIST: { //TODO handling of cleanup status, potentially new object ListObject list = (ListObject)ec.getVariable(getInput1().getName()); if( list.getLength() > 1 ) { @@ -696,47 +840,40 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace ec.setVariable(output.getName(), tmp); } } + break; } - else { + default: throw new DMLRuntimeException("Unsupported data type " + "in as.matrix(): "+getInput1().getDataType().name()); - } - break; } - case CastAsFrameVariable:{ - FrameBlock out = null; - if( getInput1().getDataType()==DataType.SCALAR ) { - ScalarObject scalarInput = ec.getScalarInput(getInput1()); - out = new FrameBlock(1, getInput1().getValueType()); - out.ensureAllocatedColumns(1); - out.set(0, 0, scalarInput.getStringValue()); - } - else { //DataType.FRAME - MatrixBlock min = ec.getMatrixInput(getInput1().getName()); - out = DataConverter.convertToFrameBlock(min); - ec.releaseMatrixInput(getInput1().getName()); - } - 
ec.setFrameOutput(output.getName(), out); - break; - } - case CastAsDoubleVariable:{ - ScalarObject in = ec.getScalarInput(getInput1()); - ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToDouble(in)); - break; - } - case CastAsIntegerVariable:{ - ScalarObject in = ec.getScalarInput(getInput1()); - ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToLong(in)); - break; - } - case CastAsBooleanVariable:{ + } + + /** + * Handler for CastAsFrameVariable instruction + * @param ec execution context + */ + private void processCastAsFrameVariableInstruction(ExecutionContext ec){ + FrameBlock out = null; + if( getInput1().getDataType()==DataType.SCALAR ) { ScalarObject scalarInput = ec.getScalarInput(getInput1()); - ec.setScalarOutput(output.getName(), new BooleanObject(scalarInput.getBooleanValue())); - break; + out = new FrameBlock(1, getInput1().getValueType()); + out.ensureAllocatedColumns(1); + out.set(0, 0, scalarInput.getStringValue()); } - - case Read: - ScalarObject res = null; + else { //DataType.FRAME + MatrixBlock min = ec.getMatrixInput(getInput1().getName()); + out = DataConverter.convertToFrameBlock(min); + ec.releaseMatrixInput(getInput1().getName()); + } + ec.setFrameOutput(output.getName(), out); + } + + /** + * Handler for Read instruction + * @param ec execution context + */ + private void processReadInstruction(ExecutionContext ec){ + ScalarObject res = null; try { switch(getInput1().getValueType()) { case FP64: @@ -759,89 +896,6 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace throw new DMLRuntimeException(e); } ec.setScalarOutput(getInput1().getName(), res); - - break; - - case Write: - processWriteInstruction(ec); - break; - - case SetFileName: - Data data = ec.getVariable(getInput1().getName()); - if ( data.getDataType() == DataType.MATRIX ) { - if ( getInput3().getName().equalsIgnoreCase("remote") ) { - ((MatrixObject)data).setFileName(getInput2().getName()); - } - else { - throw new 
DMLRuntimeException("Invalid location (" + getInput3().getName() + ") in SetFileName instruction: " + instString); - } - } else{ - throw new DMLRuntimeException("Invalid data type (" + getInput1().getDataType() + ") in SetFileName instruction: " + instString); - } - break; - - default: - throw new DMLRuntimeException("Unknown opcode: " + opcode ); - } - } - - /** - * Handler for mvvar instructions. - * Example: mvvar <srcvar> <destFile> <format> - * Move the file pointed by srcvar to destFile. - * Currently, applicable only when format=binaryblock. - * - * @param ec execution context - */ - @SuppressWarnings("rawtypes") - private void processMoveInstruction(ExecutionContext ec) { - - if ( getInput3() == null ) { - // example: mvvar tempA A - - // get source variable - Data srcData = ec.getVariable(getInput1().getName()); - - if ( srcData == null ) { - throw new DMLRuntimeException("Unexpected error: could not find a data object " - + "for variable name:" + getInput1().getName() + ", while processing instruction "); - } - - if( getInput2().getDataType().isMatrix() || getInput2().getDataType().isFrame() ) { - // remove existing variable bound to target name - Data tgt = ec.removeVariable(getInput2().getName()); - - //cleanup matrix data on fs/hdfs (if necessary) - if( tgt != null ) - ec.cleanupDataObject(tgt); - } - - // do the actual move - ec.setVariable(getInput2().getName(), srcData); - ec.removeVariable(getInput1().getName()); - } - else { - // example instruction: mvvar <srcVar> <destFile> <format> - if ( ec.getVariable(getInput1().getName()) == null ) - throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString()); - - Object object = ec.getVariable(getInput1().getName()); - - if ( getInput3().getName().equalsIgnoreCase("binaryblock") ) { - boolean success = false; - success = ((CacheableData)object).moveData(getInput2().getName(), 
getInput3().getName()); - if (!success) { - throw new DMLRuntimeException("Failed to move var " + getInput1().getName() + " to file " + getInput2().getName() + "."); - } - } - else - if(object instanceof MatrixObject) - throw new DMLRuntimeException("Unexpected formats while copying: from matrix blocks [" - + ((MatrixObject)object).getBlocksize() + "] to " + getInput3().getName()); - else if (object instanceof FrameObject) - throw new DMLRuntimeException("Unexpected formats while copying: from fram object [" - + ((FrameObject)object).getNumColumns() + "," + ((FrameObject)object).getNumColumns() + "] to " + getInput3().getName()); - } } /** @@ -898,20 +952,38 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace else { // Default behavior MatrixObject mo = ec.getMatrixObject(getInput1().getName()); - mo.setPrivacyConstraints(getPrivacyConstraint()); mo.exportData(fname, fmtStr, _formatProperties); } + // Set privacy constraint of write instruction to the same as that of the input + setPrivacyConstraint(ec.getMatrixObject(getInput1().getName()).getPrivacyConstraint()); } else if( getInput1().getDataType() == DataType.FRAME ) { FrameObject mo = ec.getFrameObject(getInput1().getName()); mo.exportData(fname, fmtStr, _formatProperties); + setPrivacyConstraint(mo.getPrivacyConstraint()); } else if( getInput1().getDataType() == DataType.TENSOR ) { // TODO write tensor TensorObject to = ec.getTensorObject(getInput1().getName()); + setPrivacyConstraint(to.getPrivacyConstraint()); to.exportData(fname, fmtStr, _formatProperties); } } + + /** + * Handler for SetFileName instruction + * @param ec execution context + */ + private void processSetFileNameInstruction(ExecutionContext ec){ + Data data = ec.getVariable(getInput1().getName()); + if ( data.getDataType() == DataType.MATRIX ) { + if ( getInput3().getName().equalsIgnoreCase("remote") ) + ((MatrixObject)data).setFileName(getInput2().getName()); + else + throw new DMLRuntimeException("Invalid 
location (" + getInput3().getName() + ") in SetFileName instruction: " + instString); + } else + throw new DMLRuntimeException("Invalid data type (" + getInput1().getDataType() + ") in SetFileName instruction: " + instString); + } /** * Remove variable instruction externalized as a static function in order to allow various @@ -956,7 +1028,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace else { mo.exportData(fname, outFmt, _formatProperties); } - HDFSTool.writeMetaDataFile (fname + ".mtd", mo.getValueType(), dc, FileFormat.CSV, _formatProperties); + HDFSTool.writeMetaDataFile (fname + ".mtd", mo.getValueType(), dc, FileFormat.CSV, _formatProperties, mo.getPrivacyConstraint()); } catch (IOException e) { throw new DMLRuntimeException(e); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java index 9000200..fc064eb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java @@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.privacy.PrivacyPropagator; public abstract class FEDInstruction extends Instruction { @@ -58,6 +59,8 @@ public abstract class FEDInstruction extends Instruction { @Override public Instruction preprocessInstruction(ExecutionContext ec) { - return super.preprocessInstruction(ec); + Instruction tmp = super.preprocessInstruction(ec); + tmp = PrivacyPropagator.preprocessInstruction(tmp, ec); + return tmp; } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java index 99602ae..cf0d162 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java @@ -88,7 +88,7 @@ public class ReblockSPInstruction extends UnarySPInstruction { DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros()); - //get the source format form the meta data + //get the source format from the meta data MetaDataFormat iimd = (MetaDataFormat) obj.getMetaData(); if(iimd == null) throw new DMLRuntimeException("Error: Metadata not found"); diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java similarity index 57% copy from src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java copy to src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java index 2b32636..7e77b04 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java @@ -19,24 +19,30 @@ package org.apache.sysds.runtime.privacy; +import org.apache.sysds.runtime.DMLRuntimeException; + /** - * PrivacyConstraint holds all privacy constraints for data in the system at compile time and runtime. + * This exception should be thrown to flag DML runtime errors related to the violation of privacy constraints. 
*/ -public class PrivacyConstraint +public class DMLPrivacyException extends DMLRuntimeException { - protected boolean _privacy = false; - - public PrivacyConstraint(){} - - public PrivacyConstraint(boolean privacy) { - _privacy = privacy; - } + private static final long serialVersionUID = 1L; - public void setPrivacy(boolean privacy){ - _privacy = privacy; - } + //prevent string concatenation of classname w/ stop message + private DMLPrivacyException(Exception e) { + super(e); + } - public boolean getPrivacy(){ - return _privacy; - } -} \ No newline at end of file + private DMLPrivacyException(String string, Exception ex){ + super(string,ex); + } + + /** + * This is the only valid constructor for DMLPrivacyException. + * + * @param msg message + */ + public DMLPrivacyException(String msg) { + super(msg); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java index 2b32636..45b12be 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java @@ -24,19 +24,25 @@ package org.apache.sysds.runtime.privacy; */ public class PrivacyConstraint { - protected boolean _privacy = false; + public enum PrivacyLevel { + None, // No data exchange constraints. Data can be shared with anyone. + Private, // Data cannot leave the origin. + PrivateAggregation // Only aggregations of the data can leave the origin. 
+ } - public PrivacyConstraint(){} + protected PrivacyLevel privacyLevel = PrivacyLevel.None; - public PrivacyConstraint(boolean privacy) { - _privacy = privacy; - } + public PrivacyConstraint(){} - public void setPrivacy(boolean privacy){ - _privacy = privacy; - } + public PrivacyConstraint(PrivacyLevel privacyLevel) { + setPrivacyLevel(privacyLevel); + } - public boolean getPrivacy(){ - return _privacy; - } -} \ No newline at end of file + public void setPrivacyLevel(PrivacyLevel privacyLevel){ + this.privacyLevel = privacyLevel; + } + + public PrivacyLevel getPrivacyLevel(){ + return privacyLevel; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java new file mode 100644 index 0000000..118a153 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.privacy; + +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; + +public class PrivacyMonitor +{ + //TODO maybe maintain a log of checked constraints for transfers + // in order to provide 'privacy explanations' similar to our stats + + /** + * Throws DMLPrivacyException if data object is CacheableData and privacy constraint is set to private or private aggregation. + * @param dataObject input data object + * @return data object or data object with privacy constraint removed in case the privacy level was none. + */ + public static Data handlePrivacy(Data dataObject){ + if ( dataObject instanceof CacheableData<?> ){ + PrivacyConstraint privacyConstraint = ((CacheableData<?>)dataObject).getPrivacyConstraint(); + if (privacyConstraint != null){ + PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel(); + switch(privacyLevel){ + case None: + ((CacheableData<?>)dataObject).setPrivacyConstraints(null); + break; + case Private: + case PrivateAggregation: + throw new DMLPrivacyException("Cannot share variable, since the privacy constraint of the requested variable is set to " + privacyLevel.name()); + default: + throw new DMLPrivacyException("Privacy level " + privacyLevel.name() + " of variable not recognized"); + } + } + } + return dataObject; + } + + /** + * Throws DMLPrivacyException if privacy constraint of matrix object has level privacy. + * @param matrixObject input matrix object + * @return matrix object or matrix object with privacy constraint removed in case the privacy level was none. 
+ */ + public static MatrixObject handlePrivacy(MatrixObject matrixObject){ + PrivacyConstraint privacyConstraint = matrixObject.getPrivacyConstraint(); + if (privacyConstraint != null){ + PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel(); + switch(privacyLevel){ + case None: + matrixObject.setPrivacyConstraints(null); + break; + case Private: + throw new DMLPrivacyException("Cannot share variable, since the privacy constraint of the requested variable is set to " + privacyLevel.name()); + case PrivateAggregation: + break; + default: + throw new DMLPrivacyException("Privacy level " + privacyLevel.name() + " of variable not recognized"); + } + } + return matrixObject; + } + + /** + * Throw DMLPrivacyException if privacy is activated for the input variable + * @param input variable for which the privacy constraint is checked + */ + public static void handlePrivacyScalarOutput(CPOperand input, ExecutionContext ec) { + Data data = ec.getCacheableData(input); + if ( data != null && (data instanceof CacheableData<?>)){ + PrivacyConstraint privacyConstraintIn = ((CacheableData<?>) data).getPrivacyConstraint(); + if ( privacyConstraintIn != null && (privacyConstraintIn.getPrivacyLevel() == PrivacyLevel.Private) ){ + throw new DMLPrivacyException("Privacy constraint cannot be propagated to scalar for input " + input.getName()); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java index 2070c99..323330a 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java @@ -19,20 +19,325 @@ package org.apache.sysds.runtime.privacy; +import java.util.function.Function; + +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.CPInstruction; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; + /** * Class with static methods merging privacy constraints of operands * in expressions to generate the privacy constraints of the output. 
*/ -public class PrivacyPropagator { +public class PrivacyPropagator +{ + public static CacheableData<?> parseAndSetPrivacyConstraint(CacheableData<?> cd, JSONObject mtd) + throws JSONException + { + if ( mtd.containsKey(DataExpression.PRIVACY) ) { + String privacyLevel = mtd.getString(DataExpression.PRIVACY); + if ( privacyLevel != null ) + cd.setPrivacyConstraints(new PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel))); + } + return cd; + } + + public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) { + if (privacyConstraint1 != null && privacyConstraint2 != null){ + PrivacyLevel privacyLevel1 = privacyConstraint1.getPrivacyLevel(); + PrivacyLevel privacyLevel2 = privacyConstraint2.getPrivacyLevel(); - public static PrivacyConstraint MergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) { - if (privacyConstraint1 != null && privacyConstraint2 != null) - return new PrivacyConstraint( - privacyConstraint1.getPrivacy() || privacyConstraint2.getPrivacy()); + // One of the inputs are private, hence the output must be private. + if (privacyLevel1 == PrivacyLevel.Private || privacyLevel2 == PrivacyLevel.Private) + return new PrivacyConstraint(PrivacyLevel.Private); + // One of the inputs are private with aggregation allowed, but none of the inputs are completely private, + // hence the output must be private with aggregation. + else if (privacyLevel1 == PrivacyLevel.PrivateAggregation || privacyLevel2 == PrivacyLevel.PrivateAggregation) + return new PrivacyConstraint(PrivacyLevel.PrivateAggregation); + // Both inputs have privacy level "None", hence the privacy constraint can be removed. 
+ else + return null; + } else if (privacyConstraint1 != null) return privacyConstraint1; else if (privacyConstraint2 != null) return privacyConstraint2; return null; } + + public static PrivacyConstraint mergeTernary(PrivacyConstraint[] privacyConstraints){ + return mergeBinary(mergeBinary(privacyConstraints[0], privacyConstraints[1]), privacyConstraints[2]); + } + + public static PrivacyConstraint mergeQuaternary(PrivacyConstraint[] privacyConstraints){ + return mergeBinary(mergeTernary(privacyConstraints), privacyConstraints[3]); + } + + public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints){ + PrivacyConstraint mergedPrivacyConstraint = privacyConstraints[0]; + for ( int i = 1; i < privacyConstraints.length; i++ ){ + mergedPrivacyConstraint = mergeBinary(mergedPrivacyConstraint, privacyConstraints[i]); + } + return mergedPrivacyConstraint; + } + + public static Instruction preprocessInstruction(Instruction inst, ExecutionContext ec){ + switch ( inst.getType() ){ + case CONTROL_PROGRAM: + return preprocessCPInstruction( (CPInstruction) inst, ec ); + case BREAKPOINT: + case SPARK: + case GPU: + case FEDERATED: + return inst; + default: + throwExceptionIfPrivacyActivated(inst, ec); + return inst; + } + } + + public static Instruction preprocessCPInstruction(CPInstruction inst, ExecutionContext ec){ + switch ( inst.getCPInstructionType() ) + { + case Variable: + return preprocessVariableCPInstruction((VariableCPInstruction) inst, ec); + case AggregateUnary: + case Reorg: + case Unary: + return preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec); + case AggregateBinary: + case Append: + case Binary: + return preprocessBinaryCPInstruction((BinaryCPInstruction) inst, ec); + case AggregateTernary: + case Ternary: + return preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec); + case Quaternary: + return preprocessQuaternary((QuaternaryCPInstruction) inst, ec); + case BuiltinNary: + case Builtin: + return 
preprocessBuiltinNary((BuiltinNaryCPInstruction) inst, ec); + case Ctable: + case MultiReturnParameterizedBuiltin: + case MultiReturnBuiltin: + case ParameterizedBuiltin: + default: + return preprocessInstructionSimple(inst, ec); + } + } + + /** + * Throw exception if privacy constraints are activated or return instruction if privacy is not activated + * @param inst instruction + * @param ec execution context + * @return instruction + */ + public static Instruction preprocessInstructionSimple(Instruction inst, ExecutionContext ec){ + throwExceptionIfPrivacyActivated(inst, ec); + return inst; + } + + public static Instruction preprocessBuiltinNary(BuiltinNaryCPInstruction inst, ExecutionContext ec){ + if (inst.getInputs() == null) return inst; + PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inst.getInputs()); + PrivacyConstraint mergedPrivacyConstraint = mergeNary(privacyConstraints); + inst.setPrivacyConstraint(mergedPrivacyConstraint); + setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.getOutput()); + return inst; + } + + public static Instruction preprocessQuaternary(QuaternaryCPInstruction inst, ExecutionContext ec){ + PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, + new CPOperand[] {inst.input1,inst.input2,inst.input3,inst.getInput4()}); + PrivacyConstraint mergedPrivacyConstraint = mergeQuaternary(privacyConstraints); + inst.setPrivacyConstraint(mergedPrivacyConstraint); + setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.output); + return inst; + } + + public static Instruction preprocessTernaryCPInstruction(ComputationCPInstruction inst, ExecutionContext ec){ + PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, new CPOperand[]{inst.input1, inst.input2, inst.input3}); + PrivacyConstraint mergedPrivacyConstraint = mergeTernary(privacyConstraints); + inst.setPrivacyConstraint(mergedPrivacyConstraint); + setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, 
inst.output); + return inst; + } + + public static Instruction preprocessNaryInstruction(CPInstruction inst, ExecutionContext ec, CPOperand[] inputs, CPOperand output, Function<PrivacyConstraint[], PrivacyConstraint> mergeFunction){ + PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inputs); + PrivacyConstraint mergedPrivacyConstraint = mergeFunction.apply(privacyConstraints); + inst.setPrivacyConstraint(mergedPrivacyConstraint); + setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, output); + return inst; + + } + + public static Instruction preprocessBinaryCPInstruction(BinaryCPInstruction inst, ExecutionContext ec){ + PrivacyConstraint privacyConstraint1 = getInputPrivacyConstraint(ec, inst.input1); + PrivacyConstraint privacyConstraint2 = getInputPrivacyConstraint(ec, inst.input2); + if ( privacyConstraint1 != null || privacyConstraint2 != null) + { + PrivacyConstraint mergedPrivacyConstraint = mergeBinary(privacyConstraint1, privacyConstraint2); + inst.setPrivacyConstraint(mergedPrivacyConstraint); + setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.output); + } + return inst; + } + + public static Instruction preprocessUnaryCPInstruction(UnaryCPInstruction inst, ExecutionContext ec){ + return propagateInputPrivacy(inst, ec, inst.input1, inst.output); + } + + public static Instruction preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext ec){ + switch ( inst.getVariableOpcode() ) + { + case CreateVariable: + return propagateSecondInputPrivacy(inst, ec); + case AssignVariable: + //Assigns scalar, hence it does not have privacy activated + return inst; + case CopyVariable: + case MoveVariable: + return propagateFirstInputPrivacy(inst, ec); + case RemoveVariable: + return propagateAllInputPrivacy(inst, ec); + case RemoveVariableAndFile: + return propagateFirstInputPrivacy(inst, ec); + case CastAsScalarVariable: + return propagateCastAsScalarVariablePrivacy(inst, ec); + case CastAsMatrixVariable: + case 
CastAsFrameVariable: + return propagateFirstInputPrivacy(inst, ec); + case CastAsDoubleVariable: + case CastAsIntegerVariable: + case CastAsBooleanVariable: + return propagateCastAsScalarVariablePrivacy(inst, ec); + case Read: + return inst; + case Write: + return propagateFirstInputPrivacy(inst, ec); + case SetFileName: + return propagateFirstInputPrivacy(inst, ec); + default: + throwExceptionIfPrivacyActivated(inst, ec); + return inst; + } + } + + private static void throwExceptionIfPrivacyActivated(Instruction inst, ExecutionContext ec){ + if ( inst.getPrivacyConstraint() != null && inst.getPrivacyConstraint().getPrivacyLevel().equals(PrivacyLevel.Private) ) { + throw new DMLPrivacyException("Instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction."); + } + } + + /** + * Propagate privacy from first input and throw exception if privacy is activated. + * @param inst Instruction + * @param ec execution context + * @return instruction with or without privacy constraints + */ + private static Instruction propagateCastAsScalarVariablePrivacy(VariableCPInstruction inst, ExecutionContext ec){ + inst = (VariableCPInstruction) propagateFirstInputPrivacy(inst, ec); + return preprocessInstructionSimple(inst, ec); + } + + /** + * Propagate privacy constraints from all inputs if privacy constraints are set. 
+ * @param inst instruction + * @param ec execution context + * @return instruction with or without privacy constraints + */ + private static Instruction propagateAllInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){ + //TODO: Propagate the most restricting constraints instead of just the latest activated constraint + for ( CPOperand input : inst.getInputs() ) + inst = (VariableCPInstruction) propagateInputPrivacy(inst, ec, input, inst.getOutput()); + return inst; + } + + /** + * Propagate privacy constraint to instruction and output of instruction + * if data of first input is CacheableData and + * privacy constraint is activated. + * @param inst VariableCPInstruction + * @param ec execution context + * @return instruction with or without privacy constraints + */ + private static Instruction propagateFirstInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){ + return propagateInputPrivacy(inst, ec, inst.getInput1(), inst.getOutput()); + } + + /** + * Propagate privacy constraint to instruction and output of instruction + * if data of second input is CacheableData and + * privacy constraint is activated. 
+ * @param inst VariableCPInstruction + * @param ec execution context + * @return instruction with or without privacy constraints + */ + private static Instruction propagateSecondInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){ + return propagateInputPrivacy(inst, ec, inst.getInput2(), inst.getOutput()); + } + + /** + * Propagate privacy constraint to instruction and output of instruction + * if data of the specified variable is CacheableData + * and privacy constraint is activated + * @param inst instruction + * @param ec execution context + * @param inputOperand input from which the privacy constraint is found + * @param outputOperand output which the privacy constraint is propagated to + * @return instruction with or without privacy constraints + */ + private static Instruction propagateInputPrivacy(Instruction inst, ExecutionContext ec, CPOperand inputOperand, CPOperand outputOperand){ + PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, inputOperand); + if ( privacyConstraint != null ) { + inst.setPrivacyConstraint(privacyConstraint); + if ( outputOperand != null) + setOutputPrivacyConstraint(ec, privacyConstraint, outputOperand); + } + return inst; + } + + private static PrivacyConstraint getInputPrivacyConstraint(ExecutionContext ec, CPOperand input){ + if ( input != null && input.getName() != null){ + Data dd = ec.getVariable(input.getName()); + if ( dd != null && dd instanceof CacheableData) + return ((CacheableData<?>) dd).getPrivacyConstraint(); + } + return null; + } + + + private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext ec, CPOperand[] inputs){ + PrivacyConstraint[] privacyConstraints = new PrivacyConstraint[inputs.length]; + for ( int i = 0; i < inputs.length; i++ ){ + privacyConstraints[i] = getInputPrivacyConstraint(ec, inputs[i]); + } + return privacyConstraints; + + } + + private static void setOutputPrivacyConstraint(ExecutionContext ec, PrivacyConstraint privacyConstraint, CPOperand 
output){ + Data dd = ec.getVariable(output.getName()); + if ( dd != null ){ + if ( dd instanceof CacheableData ){ + ((CacheableData<?>) dd).setPrivacyConstraints(privacyConstraint); + ec.setVariable(output.getName(), dd); + } + else throw new DMLPrivacyException("Privacy constraint of " + output + " cannot be set since it is not an instance of CacheableData"); + } + } } \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index 8ea5ff5..8b1e42e 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -361,6 +361,11 @@ public class HDFSTool throws IOException { writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties); } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) + throws IOException { + writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties, privacyConstraint); + } public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties) @@ -452,7 +457,7 @@ public class HDFSTool //add privacy constraints if ( privacyConstraint != null ){ - mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacy()); + mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacyLevel().name()); } //add username and time diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java new file mode 100644 index 0000000..c93b660 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.privacy; + +import org.junit.Test; +import org.apache.sysds.api.DMLException; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.wink.json4j.JSONException; + +import java.util.HashMap; +import java.util.Map; + +@net.jcip.annotations.NotThreadSafe +public class FederatedL2SVMTest extends AutomatedTestBase { + + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_NAME = "FederatedL2SVMTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + private int rows = 100; + private int cols = 10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new 
String[] {"Z"})); + } + + // PrivateAggregation Single Input + + @Test + public void federatedL2SVMCPPrivateAggregationX1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + @Test + public void federatedL2SVMCPPrivateAggregationX2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + @Test + public void federatedL2SVMCPPrivateAggregationY() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + // Private Single Input + + @Test + public void federatedL2SVMCPPrivateFederatedX1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedX2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedY() throws 
JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private); + } + + // Setting Privacy of Matrix (Throws Exception) + + @Test + public void federatedL2SVMCPPrivateMatrixX1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null); + } + + @Test + public void federatedL2SVMCPPrivateMatrixX2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null); + } + + @Test + public void federatedL2SVMCPPrivateMatrixY() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null); + } + + @Test + public void federatedL2SVMCPPrivateFederatedAndMatrixX1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedAndMatrixX2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + 
privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedAndMatrixY() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null); + } + + // Privacy Level Private Combinations + + @Test + public void federatedL2SVMCPPrivateFederatedX1X2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedX1Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedX2Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, 
DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateFederatedX1X2Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + // Privacy Level PrivateAggregation Combinations + @Test + public void federatedL2SVMCPPrivateAggregationFederatedX1X2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + @Test + public void federatedL2SVMCPPrivateAggregationFederatedX1Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + @Test + public void federatedL2SVMCPPrivateAggregationFederatedX2Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + @Test 
+ public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation); + } + + // Privacy Level Combinations + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + 
} + + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private); + } + + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX2() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private); + } + + @Test + public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1() throws JSONException { + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + // Require Federated Workers to return matrix + + @Test + public void federatedL2SVMCPPrivateAggregationX1Exception() throws JSONException { + rows = 1000; cols = 1; + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateAggregationX2Exception() throws JSONException { + rows = 1000; cols = 1; + Map<String, 
PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateX1Exception() throws JSONException { + rows = 1000; cols = 1; + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + @Test + public void federatedL2SVMCPPrivateX2Exception() throws JSONException { + rows = 1000; cols = 1; + Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>(); + privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private)); + federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class); + } + + private void federatedL2SVMNoException(Types.ExecMode execMode, Map<String, + PrivacyConstraint> privacyConstraintsFederated, Map<String, PrivacyConstraint> privacyConstraintsMatrix, + PrivacyLevel expectedPrivacyLevel) + throws JSONException + { + federatedL2SVM(execMode, privacyConstraintsFederated, privacyConstraintsMatrix, expectedPrivacyLevel, false, null, false, null); + } + + private void federatedL2SVM(Types.ExecMode execMode, Map<String, PrivacyConstraint> privacyConstraintsFederated, + Map<String, PrivacyConstraint> privacyConstraintsMatrix, PrivacyLevel expectedPrivacyLevel, + boolean exception1, Class<?> expectedException1, boolean exception2, Class<?> expectedException2 ) + throws JSONException + { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + 
DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + Thread t1 = null, t2 = null; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int halfRows = rows / 2; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233); + for(int i = 0; i < rows; i++) + Y[i][0] = (Y[i][0] > 0) ? 1 : -1; + + // Write privacy constraints of normal matrix + if ( privacyConstraintsMatrix != null ){ + writeInputMatrixWithMTD("MX1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsMatrix.get("X1")); + writeInputMatrixWithMTD("MX2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsMatrix.get("X2")); + writeInputMatrixWithMTD("MY", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows), privacyConstraintsMatrix.get("Y")); + } else { + writeInputMatrixWithMTD("MX1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("MX2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("MY", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows)); + } + + // Write privacy constraints of federated matrix + if ( privacyConstraintsFederated != null ){ + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsFederated.get("X1")); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsFederated.get("X2")); + writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows), privacyConstraintsFederated.get("Y")); + } else { + writeInputMatrixWithMTD("X1", X1, 
false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows)); + } + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorker(port1); + t2 = startLocalFedWorker(port2); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-args", input("MX1"), input("MX2"), input("MY"), expected("Z")}; + runTest(true, exception1, expectedException1, -1); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-args", "\"localhost:" + port1 + "/" + input("X1") + "\"", + "\"localhost:" + port2 + "/" + input("X2") + "\"", Integer.toString(rows), Integer.toString(cols), + Integer.toString(halfRows), input("Y"), output("Z")}; + runTest(true, exception2, expectedException2, -1); + + if ( !(exception1 || exception2) ) { + compareResults(1e-9); + } + } + finally { + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java new file mode 100644 index 0000000..f74e3a9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.privacy; + +import java.util.Arrays; + +import org.apache.sysds.api.DMLException; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.apache.sysds.common.Types; +import static java.lang.Thread.sleep; + +public class FederatedWorkerHandlerTest extends AutomatedTestBase { + + private static final String TEST_DIR = "functions/federated/"; + private static final String TEST_DIR_SCALAR = TEST_DIR + "matrix_scalar/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWorkerHandlerTest.class.getSimpleName() + "/"; + private final static String TEST_CLASS_DIR_SCALAR = TEST_DIR_SCALAR + FederatedWorkerHandlerTest.class.getSimpleName() + "/"; + private static final String TEST_PROG_SCALAR_ADDITION_MATRIX = "ScalarAdditionFederatedMatrix"; + private final static String AGGREGATION_TEST_NAME = "FederatedSumTest"; + private final static String TRANSFER_TEST_NAME = 
"FederatedRCBindTest"; + private final static String MATVECMULT_TEST_NAME = "FederatedMultiplyTest"; + private static final String FEDERATED_WORKER_HOST = "localhost"; + private static final int FEDERATED_WORKER_PORT = 1222; + + private final static int blocksize = 1024; + private int rows = 10; + private int cols = 10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration("scalar", new TestConfiguration(TEST_CLASS_DIR_SCALAR, TEST_PROG_SCALAR_ADDITION_MATRIX, new String [] {"R"})); + addTestConfiguration("aggregation", new TestConfiguration(TEST_CLASS_DIR, AGGREGATION_TEST_NAME, new String[] {"S.scalar", "R", "C"})); + addTestConfiguration("transfer", new TestConfiguration(TEST_CLASS_DIR, TRANSFER_TEST_NAME, new String[] {"R", "C"})); + addTestConfiguration("matvecmult", new TestConfiguration(TEST_CLASS_DIR, MATVECMULT_TEST_NAME, new String[] {"Z"})); + } + + @Test + public void scalarPrivateTest(){ + scalarTest(PrivacyLevel.Private, DMLException.class); + } + + @Test + public void scalarPrivateAggregationTest(){ + scalarTest(PrivacyLevel.PrivateAggregation, DMLException.class); + } + + @Test + public void scalarNonePrivateTest(){ + scalarTest(PrivacyLevel.None, null); + } + + private void scalarTest(PrivacyLevel privacyLevel, Class<?> expectedException){ + getAndLoadTestConfiguration("scalar"); + + double[][] m = getRandomMatrix(this.rows, this.cols, -1, 1, 1.0, 1); + + PrivacyConstraint pc = new PrivacyConstraint(privacyLevel); + writeInputMatrixWithMTD("M", m, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), pc); + + int s = TestUtils.getRandomInt(); + double[][] r = new double[rows][cols]; + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + r[i][j] = m[i][j] + s; + } + } + if (expectedException == null) + writeExpectedMatrix("R", r); + + runGenericScalarTest(TEST_PROG_SCALAR_ADDITION_MATRIX, s, expectedException); + } + + + private void runGenericScalarTest(String dmlFile, 
int s, Class<?> expectedException) + { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + + Thread t = null; + try { + // we need the reference file to not be written to hdfs, so we get the correct format + rtplatform = Types.ExecMode.SINGLE_NODE; + if (rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)}; + t = new Thread(() -> runTest(true, false, null, -1)); + t.start(); + sleep(FED_WORKER_WAIT); + fullDMLScriptName = SCRIPT_DIR + TEST_DIR_SCALAR + dmlFile + ".dml"; + programArgs = new String[]{"-args", + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, FEDERATED_WORKER_PORT, input("M")), + Integer.toString(rows), Integer.toString(cols), + Integer.toString(s), + output("R")}; + boolean exceptionExpected = (expectedException != null); + runTest(true, exceptionExpected, expectedException, -1); + + if ( !exceptionExpected ) + compareResults(); + } catch (InterruptedException e) { + e.printStackTrace(); + assert (false); + } finally { + rtplatform = platformOld; + TestUtils.shutdownThread(t); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + @Test + public void aggregatePrivateTest() { + federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class); + } + + @Test + public void aggregatePrivateAggregationTest() { + federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, null); + } + + @Test + public void aggregateNonePrivateTest() { + federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null); + } + + public void federatedSum(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + + Thread t = null; + + getAndLoadTestConfiguration("aggregation"); + String HOME = SCRIPT_DIR + TEST_DIR; + + 
double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1); + writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), new PrivacyConstraint(privacyLevel)); + int port = getRandomAvailablePort(); + t = startLocalFedWorker(port); + + // we need the reference file to not be written to hdfs, so we get the correct format + rtplatform = Types.ExecMode.SINGLE_NODE; + // Run reference dml script with normal matrix for Row/Col sum + fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")}; + runTest(true, false, null, -1); + + // write expected sum + double sum = 0; + for(double[] doubles : A) { + sum += Arrays.stream(doubles).sum(); + } + sum *= 2; + + if ( expectedException == null ) + writeExpectedScalar("S", sum); + + // reference file should not be written to hdfs, so we set platform here + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get("aggregation"); + loadTestConfiguration(config); + fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml"; + programArgs = new String[] {"-args", "\"localhost:" + port + "/" + input("A") + "\"", Integer.toString(rows), + Integer.toString(cols), Integer.toString(rows * 2), output("S"), output("R"), output("C")}; + + runTest(true, (expectedException != null), expectedException, -1); + + // compare all sums via files + if ( expectedException == null ) + compareResults(1e-11); + + TestUtils.shutdownThread(t); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + + @Test + public void transferPrivateTest() { + federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class); + } + + @Test + public void transferPrivateAggregationTest() { + federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, 
DMLException.class); + } + + @Test + public void transferNonePrivateTest() { + federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null); + } + + public void federatedRCBind(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + + Thread t = null; + + getAndLoadTestConfiguration("transfer"); + String HOME = SCRIPT_DIR + TEST_DIR; + + double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1); + writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), new PrivacyConstraint(privacyLevel)); + + int port = getRandomAvailablePort(); + t = startLocalFedWorker(port); + + // we need the reference file to not be written to hdfs, so we get the correct format + rtplatform = Types.ExecMode.SINGLE_NODE; + // Run reference dml script with normal matrix for Row/Col sum + fullDMLScriptName = HOME + TRANSFER_TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-args", input("A"), expected("R"), expected("C")}; + runTest(true, false, null, -1); + + // reference file should not be written to hdfs, so we set platform here + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get("transfer"); + loadTestConfiguration(config); + fullDMLScriptName = HOME + TRANSFER_TEST_NAME + ".dml"; + programArgs = new String[] {"-args", "\"localhost:" + port + "/" + input("A") + "\"", Integer.toString(rows), + Integer.toString(cols), output("R"), output("C")}; + + runTest(true, (expectedException != null), expectedException, -1); + + // compare all sums via files + if ( expectedException == null ) + compareResults(1e-11); + + TestUtils.shutdownThread(t); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + + @Test + public void matVecMultPrivateTest() { + 
federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class); + } + + @Test + public void matVecMultPrivateAggregationTest() { + federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, DMLException.class); + } + + @Test + public void matVecMultNonePrivateTest() { + federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null); + } + + public void federatedMultiply(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + Thread t1, t2; + + getAndLoadTestConfiguration("matvecmult"); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int halfRows = rows / 2; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + // And another two matrices handled by a single federated worker + double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44); + double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21); + + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), new PrivacyConstraint(privacyLevel)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols)); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorker(port1); + t2 = startLocalFedWorker(port2); + + TestConfiguration config = 
availableTestConfigurations.get("matvecmult"); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + MATVECMULT_TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), + "Y2=" + input("Y2"), "Z=" + expected("Z")}; + runTest(true, false, null, -1); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + MATVECMULT_TEST_NAME + ".dml"; + programArgs = new String[] {"-nvargs", + "X1=" + TestUtils.federatedAddress("localhost", port1, input("X1")), + "X2=" + TestUtils.federatedAddress("localhost", port2, input("X2")), + "Y1=" + TestUtils.federatedAddress("localhost", port1, input("Y1")), + "Y2=" + TestUtils.federatedAddress("localhost", port2, input("Y2")), "r=" + rows, "c=" + cols, + "hr=" + halfRows, "Z=" + output("Z")}; + runTest(true, (expectedException != null), expectedException, -1); + + // compare via files + if (expectedException == null) + compareResults(1e-9); + + TestUtils.shutdownThreads(t1, t2); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java index a16355a..0715b0a 100644 --- a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.apache.sysds.parser.DataExpression; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; import org.apache.sysds.test.AutomatedTestBase; import 
org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; @@ -46,26 +47,36 @@ public class MatrixMultiplicationPropagationTest extends AutomatedTestBase { } @Test - public void testMatrixMultiplicationPropagation() throws JSONException { - matrixMultiplicationPropagation(true, true); + public void testMatrixMultiplicationPropagationPrivate() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.Private, true); } @Test - public void testMatrixMultiplicationPropagationFalse() throws JSONException { - matrixMultiplicationPropagation(false, true); + public void testMatrixMultiplicationPropagationNone() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.None, true); } @Test - public void testMatrixMultiplicationPropagationSecondOperand() throws JSONException { - matrixMultiplicationPropagation(true, false); + public void testMatrixMultiplicationPropagationPrivateAggregation() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.PrivateAggregation, true); } @Test - public void testMatrixMultiplicationPropagationSecondOperandFalse() throws JSONException { - matrixMultiplicationPropagation(false, false); + public void testMatrixMultiplicationPropagationSecondOperandPrivate() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.Private, false); } - private void matrixMultiplicationPropagation(boolean privacy, boolean privateFirstOperand) throws JSONException { + @Test + public void testMatrixMultiplicationPropagationSecondOperandNone() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.None, false); + } + + @Test + public void testMatrixMultiplicationPropagationSecondOperandPrivateAggregation() throws JSONException { + matrixMultiplicationPropagation(PrivacyLevel.PrivateAggregation, false); + } + + private void matrixMultiplicationPropagation(PrivacyLevel privacyLevel, boolean privateFirstOperand) throws JSONException { TestConfiguration config = 
availableTestConfigurations.get("MatrixMultiplicationPropagationTest"); loadTestConfiguration(config); @@ -78,7 +89,7 @@ public class MatrixMultiplicationPropagationTest extends AutomatedTestBase { double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1); double[][] c = TestUtils.performMatrixMultiplication(a, b); - PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacy); + PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacyLevel); MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k); if ( privateFirstOperand ) { @@ -99,7 +110,7 @@ public class MatrixMultiplicationPropagationTest extends AutomatedTestBase { // Check that the output metadata is correct String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); - assertEquals(String.valueOf(privacy), actualPrivacyValue); + assertEquals(String.valueOf(privacyLevel), actualPrivacyValue); } @Test @@ -144,28 +155,32 @@ public class MatrixMultiplicationPropagationTest extends AutomatedTestBase { } @Test - public void testMatrixMultiplicationPrivacyInputTrue() throws JSONException { - testMatrixMultiplicationPrivacyInput(true); + public void testMatrixMultiplicationPrivacyInputPrivate() throws JSONException { + testMatrixMultiplicationPrivacyInput(PrivacyLevel.Private); + } + + @Test + public void testMatrixMultiplicationPrivacyInputNone() throws JSONException { + testMatrixMultiplicationPrivacyInput(PrivacyLevel.None); } @Test - public void testMatrixMultiplicationPrivacyInputFalse() throws JSONException { - testMatrixMultiplicationPrivacyInput(false); + public void testMatrixMultiplicationPrivacyInputPrivateAggregation() throws JSONException { + testMatrixMultiplicationPrivacyInput(PrivacyLevel.PrivateAggregation); } - private void testMatrixMultiplicationPrivacyInput(boolean privacy) throws JSONException { + private void testMatrixMultiplicationPrivacyInput(PrivacyLevel privacyLevel) throws JSONException { TestConfiguration config = 
availableTestConfigurations.get("MatrixMultiplicationPropagationTest"); loadTestConfiguration(config); double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1); - PrivacyConstraint privacyConstraint = new PrivacyConstraint(); - privacyConstraint.setPrivacy(privacy); + PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacyLevel); MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k); writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint); String actualPrivacyValue = readDMLMetaDataValue("a", INPUT_DIR, DataExpression.PRIVACY); - assertEquals(String.valueOf(privacy), actualPrivacyValue); + assertEquals(String.valueOf(privacyLevel), actualPrivacyValue); } } diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java new file mode 100644 index 0000000..a72ea32 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.privacy; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.wink.json4j.JSONException; +import org.junit.Test; + +public class MatrixRuntimePropagationTest extends AutomatedTestBase +{ + private static final String TEST_DIR = "functions/privacy/"; + private final static String TEST_CLASS_DIR = TEST_DIR + MatrixMultiplicationPropagationTest.class.getSimpleName() + "/"; + private final int m = 20; + private final int n = 20; + private final int k = 20; + + @Override + public void setUp() { + addTestConfiguration("MatrixRuntimePropagationTest", + new TestConfiguration(TEST_CLASS_DIR, "MatrixRuntimePropagationTest", new String[]{"c"})); + } + + @Test + public void testRuntimePropagationPrivate() throws JSONException { + conditionalPropagation(PrivacyLevel.Private); + } + + @Test + public void testRuntimePropagationNone() throws JSONException { + conditionalPropagation(PrivacyLevel.None); + } + + @Test + public void testRuntimePropagationPrivateAggregation() throws JSONException { + conditionalPropagation(PrivacyLevel.PrivateAggregation); + } + + private void conditionalPropagation(PrivacyLevel privacyLevel) throws JSONException { + + TestConfiguration config = availableTestConfigurations.get("MatrixRuntimePropagationTest"); + loadTestConfiguration(config); + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml"; + + double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1); + double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1); + double sum; + + PrivacyConstraint 
privacyConstraint = new PrivacyConstraint(privacyLevel); + MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k); + + writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint); + writeInputMatrix("b", b); + if ( privacyLevel.equals(PrivacyLevel.Private) || privacyLevel.equals(PrivacyLevel.PrivateAggregation) ){ + writeExpectedMatrix("c", a); + sum = TestUtils.sum(a, m, n) + 1; + } else { + writeExpectedMatrix("c", b); + sum = TestUtils.sum(a, m, n) - 1; + } + + programArgs = new String[]{"-nvargs", + "a=" + input("a"), "b=" + input("b"), "c=" + output("c"), + "m=" + m, "n=" + n, "k=" + k, "s=" + sum }; + + runTest(true,false,null,-1); + + // Check that the output data is correct + compareResults(1e-9); + + // Check that the output metadata is correct + if ( privacyLevel.equals(PrivacyLevel.Private) ) { + String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); + assertEquals(PrivacyLevel.Private.name(), actualPrivacyValue); + } + else if ( privacyLevel.equals(PrivacyLevel.PrivateAggregation) ){ + String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); + assertEquals(PrivacyLevel.PrivateAggregation.name(), actualPrivacyValue); + } + else { + // Check that a JSONException is thrown + // or that privacy level is set to none + // because no privacy metadata should be written to c + // except if the privacy written is set to private + boolean JSONExceptionThrown = false; + String actualPrivacyValue = null; + try{ + actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); + } catch (JSONException e){ + JSONExceptionThrown = true; + } catch (Exception e){ + fail("Exception occured, but JSONException was expected. 
The exception thrown is: " + e.getMessage()); + e.printStackTrace(); + } + assert(JSONExceptionThrown || (PrivacyLevel.None.name().equals(actualPrivacyValue))); + } + } +} diff --git a/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml b/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml new file mode 100644 index 0000000..b51cbf3 --- /dev/null +++ b/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($a, rows=$m, cols=$n, format="text"); +B = read($b, rows=$n, cols=$k, format="text"); +if ( sum(A) < $s){ + write(A, $c, format="text"); +} else { + write(B, $c, format="text"); +} \ No newline at end of file