This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new 013ca82 [SYSTEMDS-361] New privacy constraint meta data (compiler/runtime) 013ca82 is described below commit 013ca8224c23b1d9f63e254162a56fb78bf74c96 Author: sebwrede <seb.wr...@hotmail.com> AuthorDate: Fri Apr 24 20:00:13 2020 +0200 [SYSTEMDS-361] New privacy constraint meta data (compiler/runtime) Closes #895. --- docs/Tasks.txt | 7 + .../java/org/apache/sysds/hops/AggBinaryOp.java | 3 +- src/main/java/org/apache/sysds/hops/DataOp.java | 1 + src/main/java/org/apache/sysds/hops/Hop.java | 21 ++- src/main/java/org/apache/sysds/hops/LiteralOp.java | 1 + src/main/java/org/apache/sysds/lops/DataGen.java | 4 +- src/main/java/org/apache/sysds/lops/Lop.java | 18 +++ .../java/org/apache/sysds/lops/compile/Dag.java | 34 +++- .../org/apache/sysds/parser/BinaryExpression.java | 26 ++-- .../org/apache/sysds/parser/DMLTranslator.java | 3 + .../org/apache/sysds/parser/DataExpression.java | 135 +++++++--------- .../java/org/apache/sysds/parser/Identifier.java | 15 ++ .../controlprogram/caching/CacheableData.java | 16 +- .../sysds/runtime/instructions/Instruction.java | 12 ++ .../instructions/cp/VariableCPInstruction.java | 2 + .../org/apache/sysds/runtime/io/MatrixReader.java | 4 +- .../sysds/runtime/privacy/PrivacyConstraint.java | 42 +++++ .../sysds/runtime/privacy/PrivacyPropagator.java | 38 +++++ .../org/apache/sysds/runtime/util/HDFSTool.java | 38 ++++- .../org/apache/sysds/test/AutomatedTestBase.java | 43 +++++- src/test/java/org/apache/sysds/test/TestUtils.java | 129 ++++++---------- .../test/functions/data/misc/WriteMMTest.java | 2 +- .../MatrixMultiplicationPropagationTest.java | 171 +++++++++++++++++++++ .../MatrixMultiplicationPropagationTest.dml | 27 ++++ 24 files changed, 591 insertions(+), 201 deletions(-) diff --git a/docs/Tasks.txt b/docs/Tasks.txt index 2283d57..d1e30c0 100644 --- a/docs/Tasks.txt +++ b/docs/Tasks.txt @@ -260,5 +260,12 @@ SYSTEMDS-340 Compiler Assisted Lineage Caching and 
Reuse SYSTEMDS-350 Data Cleaning Framework * 351 New builtin function for error correction by schema OK +SYSTEMDS-360 Privacy/Data Exchange Constraints + * 361 Initial privacy meta data (compiler/runtime) OK + * 362 Runtime privacy propagation + * 363 Compile-time privacy propagation + * 364 Error handling violated privacy constraints + * 365 Extended privacy/data exchange constraints + Others: * Break append instruction to cbind and rbind diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index b456cc8..a04d267 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -627,7 +627,8 @@ public class AggBinaryOp extends MultiThreadedHop setOutputDimensions(matmultCP); } - setLineNumbers( matmultCP ); + setLineNumbers(matmultCP); + setPrivacy(matmultCP); setLops(matmultCP); } diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 7a22727..99cf91e 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -311,6 +311,7 @@ public class DataOp extends Hop } setLineNumbers(l); + setPrivacy(l); setLops(l); //add reblock/checkpoint lops if necessary diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index ba0dd03..79a251f 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.util.UtilFunctions; import java.util.ArrayList; @@ -72,6 +73,7 @@ public abstract 
class Hop implements ParseInfo protected ValueType _valueType; protected boolean _visited = false; protected DataCharacteristics _dc = new MatrixCharacteristics(); + protected PrivacyConstraint _privacyConstraint = new PrivacyConstraint(); protected UpdateType _updateType = UpdateType.COPY; protected ArrayList<Hop> _parent = new ArrayList<>(); @@ -317,9 +319,10 @@ public abstract class Hop implements ParseInfo throw new HopsException(ex); } - setOutputDimensions( reblock ); - setLineNumbers( reblock ); - setLops( reblock ); + setOutputDimensions(reblock); + setLineNumbers(reblock); + setPrivacy(reblock); + setLops(reblock); } } @@ -764,6 +767,14 @@ public abstract class Hop implements ParseInfo return _dc.getNonZeros(); } + public void setPrivacy(PrivacyConstraint privacy){ + _privacyConstraint = privacy; + } + + public PrivacyConstraint getPrivacy(){ + return _privacyConstraint; + } + public void setUpdateType(UpdateType update){ _updateType = update; } @@ -1385,6 +1396,10 @@ public abstract class Hop implements ParseInfo protected void setLineNumbers(Lop lop) { lop.setAllPositions(getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn()); } + + protected void setPrivacy(Lop lop) { + lop.setPrivacyConstraint(getPrivacy()); + } /** * Set parse information. 
diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java b/src/main/java/org/apache/sysds/hops/LiteralOp.java index a7151de..61a7acb 100644 --- a/src/main/java/org/apache/sysds/hops/LiteralOp.java +++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java @@ -112,6 +112,7 @@ public class LiteralOp extends Hop l.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(l); + setPrivacy(l); setLops(l); } catch(LopsException e) { diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java index 27a634c..ddc1a8a 100644 --- a/src/main/java/org/apache/sysds/lops/DataGen.java +++ b/src/main/java/org/apache/sysds/lops/DataGen.java @@ -127,8 +127,8 @@ public class DataGen extends Lop //sanity checks if ( _op != OpOpDG.RAND ) throw new LopsException("Invalid instruction generation for data generation method " + _op); - if( getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.length - 2 && // tensor - getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.length - 1 ) { // matrix + if( getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.size() - 2 && // tensor + getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.size() - 1 ) { // matrix throw new LopsException(printErrorLocation() + "Invalid number of operands (" + getInputs().size() + ") for a Rand operation"); } diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 8bb7e1a..76f0caa 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.lops.LopProperties.ExecType; import org.apache.sysds.lops.compile.Dag; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; @@ -112,6 +113,11 @@ public abstract 
class Lop */ protected ArrayList<Lop> inputs; protected ArrayList<Lop> outputs; + + /** + * Privacy Constraint + */ + protected PrivacyConstraint privacyConstraint; /** * refers to #lops whose input is equal to the output produced by this lop. @@ -273,6 +279,18 @@ public abstract class Lop public void addOutput(Lop op) { outputs.add(op); } + + /** + * Method to set privacy constraint of Lop. + * @param privacy privacy constraint instance + */ + public void setPrivacyConstraint(PrivacyConstraint privacy){ + privacyConstraint = privacy; + } + + public PrivacyConstraint getPrivacyConstraint(){ + return privacyConstraint; + } public void setConsumerCount(int cc) { consumerCount = cc; diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index 6acfb87..65cbb99 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -352,6 +352,9 @@ public class Dag<N extends Lop> String inst_string = n.getInstructions(); CPInstruction currInstr = CPInstructionParser.parseSingleInstruction(inst_string); currInstr.setLocation(n); + // TODO find a more direct way of communicating the privacy constraints + // (visible to runtime explain); This change should apply to all occurrences. 
+ currInstr.setPrivacyConstraint(n); insts.add(currInstr); } catch (DMLRuntimeException e) { throw new LopsException(n.printErrorLocation() + "error generating instructions from input variables in Dag -- \n", e); @@ -406,7 +409,10 @@ public class Dag<N extends Lop> if (locationInfo != null) currInstr.setLocation(locationInfo); else + { currInstr.setLocation(node); + currInstr.setPrivacyConstraint(node); + } inst.add(currInstr); excludeRemoveInstruction(label, deleteInst); @@ -593,12 +599,21 @@ public class Dag<N extends Lop> throw new LopsException("Error parsing the instruction:" + inst_string); } if (node._beginLine != 0) + { currInstr.setLocation(node); + currInstr.setPrivacyConstraint(node); + } else if ( !node.getOutputs().isEmpty() ) + { currInstr.setLocation(node.getOutputs().get(0)); + currInstr.setPrivacyConstraint(node.getOutputs().get(0)); + } else if ( !node.getInputs().isEmpty() ) + { currInstr.setLocation(node.getInputs().get(0)); - + currInstr.setPrivacyConstraint(node.getInputs().get(0)); + } + inst.add(currInstr); } catch (Exception e) { throw new LopsException(node.printErrorLocation() + "Problem generating simple inst - " @@ -785,6 +800,7 @@ public class Dag<N extends Lop> Instruction currInstr = VariableCPInstruction.prepareRemoveInstruction(oparams.getLabel()); currInstr.setLocation(node); + currInstr.setPrivacyConstraint(node); out.addLastInstruction(currInstr); } @@ -806,6 +822,7 @@ public class Dag<N extends Lop> oparams.getUpdateType()); createvarInst.setLocation(node); + createvarInst.setPrivacyConstraint(node); out.addPreInstruction(createvarInst); @@ -813,6 +830,7 @@ public class Dag<N extends Lop> Instruction currInstr = VariableCPInstruction.prepareRemoveInstruction(oparams.getLabel()); currInstr.setLocation(node); + currInstr.setPrivacyConstraint(node); out.addLastInstruction(currInstr); } @@ -832,10 +850,14 @@ public class Dag<N extends Lop> new MatrixCharacteristics(fnOutParams.getNumRows(), fnOutParams.getNumCols(), 
(int)fnOutParams.getBlocksize(), fnOutParams.getNnz()), oparams.getUpdateType()); - if (node._beginLine != 0) + if (node._beginLine != 0){ createvarInst.setLocation(node); - else + createvarInst.setPrivacyConstraint(node); + } + else { createvarInst.setLocation(fnOut); + createvarInst.setPrivacyConstraint(fnOut); + } out.addPreInstruction(createvarInst); } @@ -985,8 +1007,10 @@ public class Dag<N extends Lop> Instruction currInstr = (node.getExecType() == ExecType.SPARK) ? SPInstructionParser.parseSingleInstruction(io_inst) : CPInstructionParser.parseSingleInstruction(io_inst); - currInstr.setLocation((!node.getInputs().isEmpty() - && node.getInputs().get(0)._beginLine != 0) ? node.getInputs().get(0) : node); + Lop useNode = (!node.getInputs().isEmpty() + && node.getInputs().get(0)._beginLine != 0) ? node.getInputs().get(0) : node; + currInstr.setLocation(useNode); + currInstr.setPrivacyConstraint(useNode); out.addLastInstruction(currInstr); } diff --git a/src/main/java/org/apache/sysds/parser/BinaryExpression.java b/src/main/java/org/apache/sysds/parser/BinaryExpression.java index 86a4558..6c177e2 100644 --- a/src/main/java/org/apache/sysds/parser/BinaryExpression.java +++ b/src/main/java/org/apache/sysds/parser/BinaryExpression.java @@ -23,6 +23,7 @@ import java.util.HashMap; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.privacy.PrivacyPropagator; public class BinaryExpression extends Expression @@ -126,25 +127,28 @@ public class BinaryExpression extends Expression output.setValueType(resultVT); checkAndSetDimensions(output, conditional); - if (this.getOpCode() == Expression.BinaryOp.MATMULT) { - if ((this.getLeft().getOutput().getDataType() != DataType.MATRIX) || (this.getRight().getOutput().getDataType() != DataType.MATRIX)) { + if (getOpCode() == Expression.BinaryOp.MATMULT) { + if ((getLeft().getOutput().getDataType() != DataType.MATRIX) || (getRight().getOutput().getDataType() 
!= DataType.MATRIX)) { // remove exception for now // throw new LanguageException( // "Matrix multiplication not supported for scalars", // LanguageException.LanguageErrorCodes.INVALID_PARAMETERS); } - if (this.getLeft().getOutput().getDim2() != -1 - && this.getRight().getOutput().getDim1() != -1 - && this.getLeft().getOutput().getDim2() != this.getRight() - .getOutput().getDim1()) + if (getLeft().getOutput().getDim2() != -1 && getRight().getOutput().getDim1() != -1 + && getLeft().getOutput().getDim2() != getRight().getOutput().getDim1()) { - raiseValidateError("invalid dimensions for matrix multiplication (k1="+this.getLeft().getOutput().getDim2()+", k2="+this.getRight().getOutput().getDim1()+")", - conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS); + raiseValidateError("invalid dimensions for matrix multiplication (k1=" + +getLeft().getOutput().getDim2()+", k2="+getRight().getOutput().getDim1()+")", + conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS); } - output.setDimensions(this.getLeft().getOutput().getDim1(), this - .getRight().getOutput().getDim2()); + output.setDimensions(getLeft().getOutput().getDim1(), + getRight().getOutput().getDim2()); } + // Set privacy of output + output.setPrivacy(PrivacyPropagator.MergeBinary( + getLeft().getOutput().getPrivacy(), getRight().getOutput().getPrivacy())); + this.setOutput(output); } @@ -199,7 +203,6 @@ public class BinaryExpression extends Expression } return "(" + leftString + " " + _opcode.toString() + " " + rightString + ")"; - } @Override @@ -223,6 +226,5 @@ public class BinaryExpression extends Expression || (op == BinaryOp.MULT) || (op == BinaryOp.DIV) || (op == BinaryOp.MODULUS) || (op == BinaryOp.INTDIV) || (op == BinaryOp.POW); - } } diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index e61c928..f1f64c1 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2082,6 +2082,8 @@ public class DMLTranslator setIdentifierParams(currBuiltinOp, source.getOutput()); if( source.getOpCode()==DataExpression.DataOp.READ ) ((DataOp)currBuiltinOp).setInputBlocksize(target.getBlocksize()); + else if ( source.getOpCode() == DataExpression.DataOp.WRITE ) + ((DataOp)currBuiltinOp).setPrivacy(hops.get(target.getName()).getPrivacy()); currBuiltinOp.setParseInfo(source); return currBuiltinOp; @@ -2747,6 +2749,7 @@ public class DMLTranslator if( id.getNnz()>= 0 ) h.setNnz(id.getNnz()); h.setBlocksize(id.getBlocksize()); + h.setPrivacy(id.getPrivacy()); } private boolean prepareReadAfterWrite( DMLProgram prog, HashMap<String, DataIdentifier> pWrites ) { diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java index 1b7ddc4..baa2b48 100644 --- a/src/main/java/org/apache/sysds/parser/DataExpression.java +++ b/src/main/java/org/apache/sysds/parser/DataExpression.java @@ -45,6 +45,8 @@ import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; import java.util.Map.Entry; @@ -96,6 +98,8 @@ public class DataExpression extends DataIdentifier public static final String SCHEMAPARAM = "schema"; public static final String CREATEDPARAM = "created"; + public static final String PRIVACY = "privacy"; + // Parameter names relevant to reading/writing delimited/csv files public static final String DELIM_DELIMITER = "sep"; public static final String DELIM_HAS_HEADER_ROW = "header"; @@ -107,31 +111,34 @@ public class DataExpression extends DataIdentifier public static final String DELIM_SPARSE = "sparse"; // applicable only for write - public static final String[] RAND_VALID_PARAM_NAMES = - {RAND_ROWS, RAND_COLS, RAND_DIMS, RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, RAND_PDF, RAND_LAMBDA}; + public static final Set<String> 
RAND_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(RAND_ROWS, RAND_COLS, RAND_DIMS, + RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, RAND_PDF, RAND_LAMBDA)); - public static final String[] RESHAPE_VALID_PARAM_NAMES = - { RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS}; + public static final Set<String> RESHAPE_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS)); - public static final String[] SQL_VALID_PARAM_NAMES = {SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY}; + public static final Set<String> SQL_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY)); - public static final String[] FEDERATED_VALID_PARAM_NAMES = {FED_ADDRESSES, FED_RANGES}; + public static final Set<String> FEDERATED_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(FED_ADDRESSES, FED_RANGES)); // Valid parameter names in a metadata file - public static final String[] READ_VALID_MTD_PARAM_NAMES = - { IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM, FORMAT_TYPE, - ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM, DESCRIPTIONPARAM, - AUTHORPARAM, CREATEDPARAM, + public static final Set<String> READ_VALID_MTD_PARAM_NAMES =new HashSet<>( + Arrays.asList(IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM, + FORMAT_TYPE, ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, DATATYPEPARAM, + VALUETYPEPARAM, SCHEMAPARAM, DESCRIPTIONPARAM, AUTHORPARAM, CREATEDPARAM, // Parameters related to delimited/csv files. 
- DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS - }; + DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS, + // Parameters related to privacy + PRIVACY)); - public static final String[] READ_VALID_PARAM_NAMES = - { IO_FILENAME, READROWPARAM, READCOLPARAM, FORMAT_TYPE, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM, - ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNNZPARAM, + public static final Set<String> READ_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(IO_FILENAME, READROWPARAM, READCOLPARAM, FORMAT_TYPE, DATATYPEPARAM, + VALUETYPEPARAM, SCHEMAPARAM, ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNNZPARAM, // Parameters related to delimited/csv files. - DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS - }; + DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS)); /* Default Values for delimited (CSV/LIBSVM) files */ public static final String DEFAULT_DELIM_DELIMITER = ","; @@ -210,11 +217,8 @@ public class DataExpression extends DataIdentifier return null; } // verify parameter names for read function - boolean isValidName = false; - for (String paramName : READ_VALID_PARAM_NAMES){ - if (paramName.equals(currName)) - isValidName = true; - } + boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName); + if (!isValidName){ errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName); return null; @@ -466,15 +470,7 @@ public class DataExpression extends DataIdentifier return; } // check name is valid - boolean found = false; - if (paramName != null ){ - for (String name : RAND_VALID_PARAM_NAMES){ - if (name.equals(paramName)) { - found = true; - break; - } - } - } + boolean found = RAND_VALID_PARAM_NAMES.contains(paramName); if (!found){ raiseValidateError("unexpected parameter \"" + paramName + "\". 
Legal parameters for Rand statement are " @@ -500,10 +496,7 @@ public class DataExpression extends DataIdentifier public void addMatrixExprParam(String paramName, Expression paramValue) { // check name is valid - boolean found = false; - if (paramName != null ){ - found = Arrays.stream(RESHAPE_VALID_PARAM_NAMES).anyMatch((name) -> name.equals(paramName)); - } + boolean found = RESHAPE_VALID_PARAM_NAMES.contains(paramName); if (!found){ raiseValidateError("unexpected parameter \"" + paramName + @@ -529,10 +522,7 @@ public class DataExpression extends DataIdentifier public void addTensorExprParam(String paramName, Expression paramValue) { // check name is valid - boolean found = false; - if (paramName != null ){ - found = Arrays.asList(RESHAPE_VALID_PARAM_NAMES).contains(paramName); - } + boolean found = RESHAPE_VALID_PARAM_NAMES.contains(paramName); if (!found){ raiseValidateError("unexpected parameter \"" + paramName + "\". Legal parameters for tensor statement are " @@ -558,10 +548,7 @@ public class DataExpression extends DataIdentifier public void addSqlExprParam(String paramName, Expression paramValue) { // check name is valid - boolean found = false; - if (paramName != null ){ - found = Arrays.asList(SQL_VALID_PARAM_NAMES).contains(paramName); - } + boolean found = SQL_VALID_PARAM_NAMES.contains(paramName); if (!found){ raiseValidateError("unexpected parameter \"" + paramName + "\". Legal parameters for sql statement are " @@ -578,8 +565,7 @@ public class DataExpression extends DataIdentifier public void addFederatedExprParam(String paramName, Expression paramValue) { // check name is valid - boolean found = (paramName != null ) && - Arrays.asList(FEDERATED_VALID_PARAM_NAMES).contains(paramName); + boolean found = FEDERATED_VALID_PARAM_NAMES.contains(paramName); if (!found) raiseValidateError("unexpected parameter \"" + paramName + "\". 
Legal parameters for federated statement are " @@ -988,17 +974,11 @@ public class DataExpression extends DataIdentifier || key.equals(READNNZPARAM) || key.equals(DATATYPEPARAM) || key.equals(VALUETYPEPARAM) || key.equals(SCHEMAPARAM)) ) { - String msg = "Only parameters allowed are: " + IO_FILENAME + "," - + SCHEMAPARAM + "," - + DELIM_HAS_HEADER_ROW + "," - + DELIM_DELIMITER + "," - + DELIM_FILL + "," - + DELIM_FILL_VALUE + "," - + READROWPARAM + "," - + READCOLPARAM; - + String msg = "Only parameters allowed are: " + Arrays.toString(new String[] { + IO_FILENAME, SCHEMAPARAM, DELIM_HAS_HEADER_ROW, DELIM_DELIMITER, + DELIM_FILL, DELIM_FILL_VALUE, READROWPARAM, READCOLPARAM}); raiseValidateError("Invalid parameter " + key + " in read statement: " + - toString() + ". " + msg, conditional, LanguageErrorCodes.INVALID_PARAMETERS); + toString() + ". " + msg, conditional, LanguageErrorCodes.INVALID_PARAMETERS); } } } @@ -1087,18 +1067,25 @@ public class DataExpression extends DataIdentifier isMatrix = true; // set data type - getOutput().setDataType(isMatrix ? DataType.MATRIX : DataType.FRAME); - - // set number non-zeros - Expression ennz = this.getVarParam("nnz"); - long nnz = -1; - if( ennz != null ) - { - nnz = Long.valueOf(ennz.toString()); - getOutput().setNnz(nnz); - } - - // Following dimension checks must be done when data type = MATRIX_DATA_TYPE + getOutput().setDataType(isMatrix ? 
DataType.MATRIX : DataType.FRAME); + + // set number non-zeros + Expression ennz = getVarParam("nnz"); + long nnz = -1; + if( ennz != null ) { + nnz = Long.valueOf(ennz.toString()); + getOutput().setNnz(nnz); + } + + // set privacy + Expression eprivacy = getVarParam("privacy"); + boolean privacy = false; + if ( eprivacy != null ) { + privacy = Boolean.valueOf(eprivacy.toString()); + getOutput().setPrivacy(privacy); + } + + // Following dimension checks must be done when data type = MATRIX_DATA_TYPE // initialize size of target data identifier to UNKNOWN getOutput().setDimensions(-1, -1); @@ -1919,13 +1906,10 @@ public class DataExpression extends DataIdentifier } } - private void validateParams(boolean conditional, String[] validParamNames, String legalMessage) { + private void validateParams(boolean conditional, Set<String> validParamNames, String legalMessage) { for( String key : _varParams.keySet() ) { - boolean found = false; - for (String name : validParamNames) { - found |= name.equals(key); - } + boolean found = validParamNames.contains(key); if( !found ) { raiseValidateError("unexpected parameter \"" + key + "\". 
" + legalMessage, conditional); @@ -2061,11 +2045,7 @@ public class DataExpression extends DataIdentifier Object key = e.getKey(); Object val = e.getValue(); - boolean isValidName = false; - for (String paramName : READ_VALID_MTD_PARAM_NAMES){ - if (paramName.equals(key)) - isValidName = true; - } + boolean isValidName = READ_VALID_MTD_PARAM_NAMES.contains(key); if (!isValidName){ //wrong parameters always rejected raiseValidateError("MTD file " + mtdFileName + " contains invalid parameter name: " + key, false); @@ -2091,6 +2071,7 @@ public class DataExpression extends DataIdentifier if ( key.toString().equalsIgnoreCase(DELIM_HAS_HEADER_ROW) || key.toString().equalsIgnoreCase(DELIM_FILL) || key.toString().equalsIgnoreCase(DELIM_SPARSE) + || key.toString().equalsIgnoreCase(PRIVACY) ) { // parse these parameters as boolean values BooleanIdentifier boolId = null; diff --git a/src/main/java/org/apache/sysds/parser/Identifier.java b/src/main/java/org/apache/sysds/parser/Identifier.java index 402069e..39da340 100644 --- a/src/main/java/org/apache/sysds/parser/Identifier.java +++ b/src/main/java/org/apache/sysds/parser/Identifier.java @@ -24,6 +24,7 @@ import java.util.HashMap; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; public abstract class Identifier extends Expression { @@ -34,6 +35,7 @@ public abstract class Identifier extends Expression protected int _blocksize; protected long _nnz; protected FormatType _formatType; + protected PrivacyConstraint _privacy; public Identifier() { _dim1 = -1; @@ -62,6 +64,7 @@ public abstract class Identifier extends Expression _blocksize = i.getBlocksize(); _nnz = i.getNnz(); _formatType = i.getFormatType(); + _privacy = i.getPrivacy(); } public void setDimensionValueProperties(Identifier i) { @@ -99,6 +102,14 @@ public abstract class Identifier extends 
Expression public void setNnz(long nnzs){ _nnz = nnzs; } + + public void setPrivacy(boolean privacy){ + _privacy = new PrivacyConstraint(privacy); + } + + public void setPrivacy(PrivacyConstraint privacyConstraint){ + _privacy = privacyConstraint; + } public long getDim1(){ return _dim1; @@ -131,6 +142,10 @@ public abstract class Identifier extends Expression public long getNnz(){ return _nnz; } + + public PrivacyConstraint getPrivacy(){ + return _privacy; + } @Override public void validateExpression(HashMap<String,DataIdentifier> ids, HashMap<String,ConstIdentifier> constVars, boolean conditional) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index a27318a..32b6162 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.runtime.util.LocalFileUtils; import org.apache.sysds.utils.Statistics; @@ -161,6 +162,11 @@ public abstract class CacheableData<T extends CacheBlock> extends Data * must get the OutputInfo that matches with InputInfo stored inside _mtd. */ protected MetaData _metaData = null; + + /** + * Object holding all privacy constraints associated with the cacheable data. + */ + protected PrivacyConstraint _privacyConstraint = null; /** The name of HDFS file in which the data is backed up. 
*/ protected String _hdfsFileName = null; // file name and path @@ -305,6 +311,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data public void removeMetaData() { _metaData = null; } + + public void setPrivacyConstraints(PrivacyConstraint pc) { + _privacyConstraint = pc; + } + + public PrivacyConstraint getPrivacyConstraint() { + return _privacyConstraint; + } public DataCharacteristics getDataCharacteristics() { return _metaData.getDataCharacteristics(); @@ -930,7 +944,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data //write the actual meta data file HDFSTool.writeMetaDataFile (filePathAndName + ".mtd", valueType, - getSchema(), dataType, dc, oinfo, formatProperties); + getSchema(), dataType, dc, oinfo, formatProperties, _privacyConstraint); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java index c3adaeb..adcae38 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java @@ -27,6 +27,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.lops.Lop; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; public abstract class Instruction { @@ -69,6 +70,9 @@ public abstract class Instruction protected int endLine = -1; protected int beginCol = -1; protected int endCol = -1; + + //privacy meta data + protected PrivacyConstraint privacyConstraint = null; public String getFilename() { return filename; @@ -129,6 +133,14 @@ public abstract class Instruction this.endCol = oldInst.endCol; } } + + public void setPrivacyConstraint(Lop lop){ + privacyConstraint = lop.getPrivacyConstraint(); + } + + public PrivacyConstraint getPrivacyConstraint(){ + return privacyConstraint; + } /** * Getter 
for instruction line number diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 9176335..456b999 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -527,6 +527,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace //clone meta data because it is updated on copy-on-write, otherwise there //is potential for hidden side effects between variables. obj.setMetaData((MetaData)metadata.clone()); + obj.setPrivacyConstraints(getPrivacyConstraint()); obj.setFileFormatProperties(_formatProperties); obj.setMarkForLinCache(true); obj.enableCleanup(!getInput1().getName() @@ -895,6 +896,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace else { // Default behavior MatrixObject mo = ec.getMatrixObject(getInput1().getName()); + mo.setPrivacyConstraints(getPrivacyConstraint()); mo.exportData(fname, outFmt, _formatProperties); } } diff --git a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java index 192a477..893d665 100644 --- a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java +++ b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java @@ -58,8 +58,8 @@ public abstract class MatrixReader public abstract MatrixBlock readMatrixFromHDFS( String fname, long rlen, long clen, int blen, long estnnz ) throws IOException, DMLRuntimeException; - public abstract MatrixBlock readMatrixFromInputStream( InputStream is, long rlen, long clen, int blen, long estnnz ) - throws IOException, DMLRuntimeException; + public abstract MatrixBlock readMatrixFromInputStream( InputStream is, long rlen, long clen, int blen, long estnnz) + throws IOException, DMLRuntimeException; /** * NOTE: 
mallocDense controls if the output matrix blocks is fully allocated, this can be redundant diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java new file mode 100644 index 0000000..2b32636 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.privacy; + +/** + * PrivacyConstraint holds all privacy constraints for data in the system at compile time and runtime. 
+ */ +public class PrivacyConstraint +{ + protected boolean _privacy = false; + + public PrivacyConstraint(){} + + public PrivacyConstraint(boolean privacy) { + _privacy = privacy; + } + + public void setPrivacy(boolean privacy){ + _privacy = privacy; + } + + public boolean getPrivacy(){ + return _privacy; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java new file mode 100644 index 0000000..2070c99 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.privacy; + +/** + * Class with static methods merging privacy constraints of operands + * in expressions to generate the privacy constraints of the output. 
+ */ +public class PrivacyPropagator { + + public static PrivacyConstraint MergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) { + if (privacyConstraint1 != null && privacyConstraint2 != null) + return new PrivacyConstraint( + privacyConstraint1.getPrivacy() || privacyConstraint2.getPrivacy()); + else if (privacyConstraint1 != null) + return privacyConstraint1; + else if (privacyConstraint2 != null) + return privacyConstraint2; + return null; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index 643c509..bb37873 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.matrix.data.InputInfo; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.OutputInfo; import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import java.io.BufferedReader; import java.io.BufferedWriter; @@ -341,10 +342,20 @@ public class HDFSTool throws IOException { writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, outinfo); } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics mc, OutputInfo outinfo, PrivacyConstraint privacyConstraint) + throws IOException { + writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, outinfo, null, privacyConstraint); + } public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, OutputInfo outinfo) throws IOException { - writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, null); + writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, (PrivacyConstraint) null); + } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, 
DataType dt, DataCharacteristics mc, OutputInfo outinfo, PrivacyConstraint privacyConstraint) + throws IOException { + writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, null, privacyConstraint); } public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, OutputInfo outinfo, FileFormatProperties formatProperties) @@ -356,10 +367,17 @@ public class HDFSTool OutputInfo outinfo, FileFormatProperties formatProperties) throws IOException { + writeMetaDataFile(mtdfile, vt, schema, dt, dc, outinfo, formatProperties, null); + } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, + OutputInfo outinfo, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) + throws IOException + { Path path = new Path(mtdfile); FileSystem fs = IOUtilFunctions.getFileSystem(path); try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) { - String mtd = metaDataToString(vt, schema, dt, dc, outinfo, formatProperties); + String mtd = metaDataToString(vt, schema, dt, dc, outinfo, formatProperties, privacyConstraint); br.write(mtd); } catch (Exception e) { throw new IOException("Error creating and writing metadata JSON file", e); @@ -369,10 +387,16 @@ public class HDFSTool public static void writeScalarMetaDataFile(String mtdfile, ValueType vt) throws IOException { + writeScalarMetaDataFile(mtdfile, vt, null); + } + + public static void writeScalarMetaDataFile(String mtdfile, ValueType vt, PrivacyConstraint privacyConstraint) + throws IOException + { Path path = new Path(mtdfile); FileSystem fs = IOUtilFunctions.getFileSystem(path); try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) { - String mtd = metaDataToString(vt, null, DataType.SCALAR, null, OutputInfo.TextCellOutputInfo, null); + String mtd = metaDataToString(vt, null, DataType.SCALAR, null, OutputInfo.TextCellOutputInfo, null, 
privacyConstraint); br.write(mtd); } catch (Exception e) { @@ -381,7 +405,7 @@ public class HDFSTool } public static String metaDataToString(ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, - OutputInfo outinfo, FileFormatProperties formatProperties) throws JSONException, DMLRuntimeException + OutputInfo outinfo, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) throws JSONException, DMLRuntimeException { OrderedJSONObject mtd = new OrderedJSONObject(); // maintain order in output file @@ -427,6 +451,12 @@ public class HDFSTool } } + //add privacy constraints + if ( privacyConstraint != null ){ + mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacy()); + } + + //add username and time String userName = System.getProperty("user.name"); if (StringUtils.isNotEmpty(userName)) { mtd.put(DataExpression.AUTHORPARAM, userName); diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c3d224f..5217d0c 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -37,6 +37,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.SparkSession.Builder; +import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; import org.junit.After; import org.junit.Assert; @@ -61,6 +62,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.OutputInfo; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.utils.ParameterBuilder; @@ -439,17 +441,32 @@ public 
abstract class AutomatedTestBase { protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, long nnz, boolean bIncludeR) { MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz); - return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc); + return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null); + } + + protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, + MatrixCharacteristics mc) { + return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null); + } + + protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, PrivacyConstraint privacyConstraint) { + return writeInputMatrixWithMTD(name, matrix, false, null, privacyConstraint); + } + + protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, PrivacyConstraint privacyConstraint) { + MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length, + OptimizerUtils.DEFAULT_BLOCKSIZE, -1); + return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, privacyConstraint); } protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, - MatrixCharacteristics mc) { + MatrixCharacteristics mc, PrivacyConstraint privacyConstraint) { writeInputMatrix(name, matrix, bIncludeR); // write metadata file try { String completeMTDPath = baseDirectory + INPUT_DIR + name + ".mtd"; - HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, OutputInfo.stringToOutputInfo("textcell")); + HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, OutputInfo.stringToOutputInfo("textcell"), privacyConstraint); } catch(IOException e) { e.printStackTrace(); @@ -678,8 +695,7 @@ public abstract class AutomatedTestBase { public static MatrixCharacteristics readDMLMetaDataFile(String fileName) { try { - String fname = baseDirectory + OUTPUT_DIR + fileName + ".mtd"; - 
JSONObject meta = new DataExpression().readMetadataFile(fname, false); + JSONObject meta = getMetaDataJSON(fileName); long rlen = Long.parseLong(meta.get(DataExpression.READROWPARAM).toString()); long clen = Long.parseLong(meta.get(DataExpression.READCOLPARAM).toString()); return new MatrixCharacteristics(rlen, clen, -1, -1); @@ -689,10 +705,23 @@ public abstract class AutomatedTestBase { } } + public static JSONObject getMetaDataJSON(String fileName) { + return getMetaDataJSON(fileName, OUTPUT_DIR); + } + + public static JSONObject getMetaDataJSON(String fileName, String outputDir) { + String fname = baseDirectory + outputDir + fileName + ".mtd"; + return new DataExpression().readMetadataFile(fname, false); + } + + public static String readDMLMetaDataValue(String fileName, String outputDir, String key) throws JSONException { + JSONObject meta = getMetaDataJSON(fileName, outputDir); + return meta.get(key).toString(); + } + public static ValueType readDMLMetaDataValueType(String fileName) { try { - String fname = baseDirectory + OUTPUT_DIR + fileName + ".mtd"; - JSONObject meta = new DataExpression().readMetadataFile(fname, false); + JSONObject meta = getMetaDataJSON(fileName); return ValueType.fromExternalString(meta.get(DataExpression.VALUETYPEPARAM).toString()); } catch(Exception ex) { diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index a843552..a23f7b6 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -118,16 +118,7 @@ public class TestUtils Path compareFile = new Path(expectedFile); FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); FSDataInputStream fsin = fs.open(compareFile); - try( BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin)) ) { - String line; - while ((line = compareIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = 
Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - expectedValues.put(new CellIndex(i, j), v); - } - } + readValuesFromFileStream(fsin, expectedValues); HashMap<CellIndex, Double> actualValues = new HashMap<>(); @@ -135,16 +126,7 @@ public class TestUtils for (FileStatus file : outFiles) { FSDataInputStream fsout = fs.open(file.getPath()); - try( BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout)) ) { - String line = null; - while ((line = outIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - actualValues.put(new CellIndex(i, j), v); - } - } + readValuesFromFileStream(fsout, actualValues); } ArrayList<Double> e_list = new ArrayList<>(); @@ -208,13 +190,7 @@ public class TestUtils line = compareIn.readLine(); expRcn = line.split(" "); - while ((line = compareIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - expectedValues.put(new CellIndex(i, j), v); - } + readValuesFromFileStreamAndPut(compareIn, expectedValues); } HashMap<CellIndex, Double> actualValues = new HashMap<>(); @@ -238,14 +214,8 @@ public class TestUtils else if (Integer.parseInt(expRcn[2]) != Integer.parseInt(rcn[2])) { System.out.println(" Nnz mismatch: expected " + Integer.parseInt(expRcn[2]) + ", actual " + Integer.parseInt(rcn[2])); } - - while ((line = outIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - actualValues.put(new CellIndex(i, j), v); - } + + readValuesFromFileStreamAndPut(outIn, 
actualValues); } @@ -270,6 +240,38 @@ public class TestUtils } /** + * Read doubles from the input stream and put them into the given hashmap of values. + * @param inputStream input stream of doubles with related indices + * @param values hashmap of values (initially empty) + * @throws IOException + */ + public static void readValuesFromFileStream(FSDataInputStream inputStream, HashMap<CellIndex, Double> values) + throws IOException + { + try( BufferedReader inReader = new BufferedReader(new InputStreamReader(inputStream)) ) { + readValuesFromFileStreamAndPut(inReader, values); + } + } + + /** + * Read values from file stream and put into hashmap + * @param inReader BufferedReader to read values from + * @param values hashmap where values are put + */ + public static void readValuesFromFileStreamAndPut(BufferedReader inReader, HashMap<CellIndex, Double> values) + throws IOException + { + String line = null; + while ((line = inReader.readLine()) != null) { + StringTokenizer st = new StringTokenizer(line, " "); + int i = Integer.parseInt(st.nextToken()); + int j = Integer.parseInt(st.nextToken()); + double v = Double.parseDouble(st.nextToken()); + values.put(new CellIndex(i, j), v); + } + } + + /** * <p> * Compares the expected values calculated in Java by testcase and which are * in the normal filesystem, with those calculated by SystemDS located in @@ -289,37 +291,17 @@ public class TestUtils Path outDirectory = new Path(actualDir); Path compareFile = new Path(expectedFile); FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); - FSDataInputStream fsin = fs.open(compareFile); + FSDataInputStream fsin = fs.open(compareFile); HashMap<CellIndex, Double> expectedValues = new HashMap<>(); - - try( BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin)) ) { - String line; - while ((line = compareIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = 
Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - expectedValues.put(new CellIndex(i, j), v); - } - } + readValuesFromFileStream(fsin, expectedValues); HashMap<CellIndex, Double> actualValues = new HashMap<>(); - FileStatus[] outFiles = fs.listStatus(outDirectory); for (FileStatus file : outFiles) { FSDataInputStream fsout = fs.open(file.getPath()); - try( BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout)) ) { - String line = null; - while ((line = outIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - actualValues.put(new CellIndex(i, j), v); - } - } + readValuesFromFileStream(fsout, actualValues); } int countErrors = 0; @@ -378,20 +360,11 @@ public class TestUtils { Path outDirectory = new Path(filePath); FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); - String line; FileStatus[] outFiles = fs.listStatus(outDirectory); for (FileStatus file : outFiles) { FSDataInputStream outIn = fs.open(file.getPath()); - try(BufferedReader reader = new BufferedReader(new InputStreamReader(outIn)) ) { - while ((line = reader.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - expectedValues.put(new CellIndex(i,j), v); - } - } + readValuesFromFileStream(outIn, expectedValues); } } catch (IOException e) { @@ -1036,33 +1009,17 @@ public class TestUtils HashMap<CellIndex, Double> expectedValues = new HashMap<>(); HashMap<CellIndex, Double> actualValues = new HashMap<>(); try(BufferedReader compareIn = new BufferedReader(new FileReader(rFile))) { - String line; // skip both R header lines compareIn.readLine(); compareIn.readLine(); - while ((line = compareIn.readLine()) != null) { - 
StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - expectedValues.put(new CellIndex(i, j), v); - } + readValuesFromFileStreamAndPut(compareIn, expectedValues); } FileStatus[] outFiles = fs.listStatus(outDirectory); for (FileStatus file : outFiles) { FSDataInputStream fsout = fs.open(file.getPath()); - try(BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout))) { - String line = null; - while ((line = outIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - double v = Double.parseDouble(st.nextToken()); - actualValues.put(new CellIndex(i, j), v); - } - } + readValuesFromFileStream(fsout, actualValues); } int countErrors = 0; diff --git a/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java b/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java index 37fc799..e670eab 100644 --- a/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java @@ -119,7 +119,7 @@ public class WriteMMTest extends AutomatedTestBase input("A"), Integer.toString(rows), Integer.toString(cols), output("B") }; //generate actual dataset - double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, System.currentTimeMillis()); + double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, System.currentTimeMillis()); writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows,cols, 1000, 1000)); writeExpectedMatrixMarket("B", A); diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java new file mode 100644 index 0000000..a16355a --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.privacy; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import org.apache.wink.json4j.JSONException; +import org.junit.Test; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class MatrixMultiplicationPropagationTest extends AutomatedTestBase { + + private static final String TEST_DIR = "functions/privacy/"; + private final static String TEST_CLASS_DIR = TEST_DIR + MatrixMultiplicationPropagationTest.class.getSimpleName() + "/"; + private final int m = 20; + private final int n = 20; + private final int k = 20; + + @Override + public void setUp() { + addTestConfiguration("MatrixMultiplicationPropagationTest", + new TestConfiguration(TEST_CLASS_DIR, "MatrixMultiplicationPropagationTest", new 
String[]{"c"})); + } + + @Test + public void testMatrixMultiplicationPropagation() throws JSONException { + matrixMultiplicationPropagation(true, true); + } + + @Test + public void testMatrixMultiplicationPropagationFalse() throws JSONException { + matrixMultiplicationPropagation(false, true); + } + + @Test + public void testMatrixMultiplicationPropagationSecondOperand() throws JSONException { + matrixMultiplicationPropagation(true, false); + } + + @Test + public void testMatrixMultiplicationPropagationSecondOperandFalse() throws JSONException { + matrixMultiplicationPropagation(false, false); + } + + private void matrixMultiplicationPropagation(boolean privacy, boolean privateFirstOperand) throws JSONException { + + TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest"); + loadTestConfiguration(config); + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml"; + programArgs = new String[]{"-nvargs", + "a=" + input("a"), "b=" + input("b"), "c=" + output("c"), + "m=" + m, "n=" + n, "k=" + k}; + + double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1); + double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1); + double[][] c = TestUtils.performMatrixMultiplication(a, b); + + PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacy); + MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k); + + if ( privateFirstOperand ) { + writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint); + writeInputMatrix("b", b); + } + else { + writeInputMatrix("a", a); + writeInputMatrixWithMTD("b", b, false, dataCharacteristics, privacyConstraint); + } + + writeExpectedMatrix("c", c); + + runTest(true,false,null,-1); + + // Check that the output data is correct + compareResults(1e-9); + + // Check that the output metadata is correct + String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); + assertEquals(String.valueOf(privacy), 
actualPrivacyValue); + } + + @Test + public void testMatrixMultiplicationNoPropagation() { + matrixMultiplicationNoPropagation(); + } + + private void matrixMultiplicationNoPropagation() { + TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest"); + loadTestConfiguration(config); + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml"; + programArgs = new String[]{ "-nvargs", + "a=" + input("a"), "b=" + input("b"), "c=" + output("c"), + "m=" + m, "n=" + n, "k=" + k}; + + double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1); + double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1); + double[][] c = TestUtils.performMatrixMultiplication(a, b); + + + writeInputMatrix("a", a); + writeInputMatrix("b", b); + writeExpectedMatrix("c", c); + + runTest(true,false,null,-1); + + // Check that the output data is correct + compareResults(1e-9); + + // Check that a JSONException is thrown + // because no privacy metadata should be written to c + boolean JSONExceptionThrown = false; + try{ + readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY); + } catch (JSONException e){ + JSONExceptionThrown = true; + } catch (Exception e){ + fail("Exception occured, but JSONException was expected. 
The exception thrown is: " + e.getMessage()); + e.printStackTrace(); + } + assert(JSONExceptionThrown); + } + + @Test + public void testMatrixMultiplicationPrivacyInputTrue() throws JSONException { + testMatrixMultiplicationPrivacyInput(true); + } + + @Test + public void testMatrixMultiplicationPrivacyInputFalse() throws JSONException { + testMatrixMultiplicationPrivacyInput(false); + } + + private void testMatrixMultiplicationPrivacyInput(boolean privacy) throws JSONException { + TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest"); + loadTestConfiguration(config); + + double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1); + + PrivacyConstraint privacyConstraint = new PrivacyConstraint(); + privacyConstraint.setPrivacy(privacy); + MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k); + + writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint); + + String actualPrivacyValue = readDMLMetaDataValue("a", INPUT_DIR, DataExpression.PRIVACY); + assertEquals(String.valueOf(privacy), actualPrivacyValue); + } +} diff --git a/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml b/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml new file mode 100644 index 0000000..9705cef --- /dev/null +++ b/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# junit test class: org.apache.sysds.test.functions.privacy.MatrixMultiplicationPropagationTest.java + +A = read($a, rows=$m, cols=$n, format="text"); +B = read($b, rows=$n, cols=$k, format="text"); +C = A %*% B; +write(C, $c, format="text"); \ No newline at end of file