This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b36729701 [SYSTEMDS-3018] Federated Coordinator Privacy Constraint Retrieval
7b36729701 is described below

commit 7b367297012ea3bb9639d47783baeda99f2f7057
Author: sebwrede <[email protected]>
AuthorDate: Tue Jun 28 14:14:22 2022 +0200

    [SYSTEMDS-3018] Federated Coordinator Privacy Constraint Retrieval
    
    This commit will:
    - Include all privacy constraints in remote retrieval
    - Add privacy constraint propagation to all compiled federated planners
    - Add PrivacyConstraintLoader, which loads privacy constraints from
      federated workers and propagates the constraints at the coordinator
    - Add privacy constraints to the Explain output
    - Add FederatedPlannerUtils class
    - Edit hop propagation to throw an exception when the hop type is unknown
      and the hop has a privacy constraint on an input
    - Remove the RewriteFederatedExecution HOP rewrite and its registration
      in ProgramRewriter
    
    Closes #1651.
---
 .../hops/fedplanner/FederatedPlannerCostbased.java |  38 ++-
 .../hops/fedplanner/FederatedPlannerUtils.java     |  67 +++++
 .../hops/fedplanner/PrivacyConstraintLoader.java   | 281 +++++++++++++++++++++
 .../hops/ipa/IPAPassRewriteFederatedPlan.java      |  17 +-
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   7 -
 .../hops/rewrite/RewriteFederatedExecution.java    | 187 --------------
 .../sysds/runtime/privacy/PrivacyConstraint.java   |   7 +-
 .../privacy/propagation/PrivacyPropagator.java     |  44 +++-
 src/main/java/org/apache/sysds/utils/Explain.java  |   3 +
 .../fedplanning/FederatedMultiplyPlanningTest.java |   8 +
 .../FederatedMultiplyPlanningTest11.dml            |  34 +++
 .../FederatedMultiplyPlanningTest11Reference.dml   |  32 +++
 12 files changed, 497 insertions(+), 228 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 1f9abb4c18..3c33d783ab 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -170,7 +170,7 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                                selectFederatedExecutionPlan(sbHop, paramMap);
                                if(sbHop instanceof FunctionOp) {
                                        String funcName = ((FunctionOp) 
sbHop).getFunctionName();
-                                       Map<String, Hop> funcParamMap = 
getParamMap((FunctionOp) sbHop);
+                                       Map<String, Hop> funcParamMap = 
FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
                                        if ( paramMap != null && funcParamMap 
!= null)
                                                funcParamMap.putAll(paramMap);
                                        paramMap = funcParamMap;
@@ -182,22 +182,6 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                return new ArrayList<>(Collections.singletonList(sb));
        }
 
-       /**
-        * Return parameter map containing the mapping from parameter name to 
input hop
-        * for all parameters of the function hop.
-        * @param funcOp hop for which the mapping of parameter names to input 
hops are made
-        * @return parameter map or empty map if function has no parameters
-        */
-       private Map<String,Hop> getParamMap(FunctionOp funcOp){
-               String[] inputNames = funcOp.getInputVariableNames();
-               Map<String,Hop> paramMap = new HashMap<>();
-               if ( inputNames != null ){
-                       for ( int i = 0; i < funcOp.getInput().size(); i++ )
-                               paramMap.put(inputNames[i],funcOp.getInput(i));
-               }
-               return paramMap;
-       }
-
        /**
         * Set final fedouts of all hops starting from terminal hops.
         */
@@ -327,13 +311,21 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                ArrayList<HopRel> hopRels = getFedPlans(currentHop, paramMap);
                // Put NONE HopRel into memo table if no FOUT or LOUT HopRels 
were added
                if(hopRels.isEmpty())
-                       hopRels.add(getNONEHopRel(currentHop));
+                       hopRels.add(getNONEHopRel(currentHop, paramMap));
                addTrace(hopRels);
                hopRelMemo.put(currentHop, hopRels);
        }
 
-       private HopRel getNONEHopRel(Hop currentHop){
-               HopRel noneHopRel = new HopRel(currentHop, 
FederatedOutput.NONE, hopRelMemo);
+       private ArrayList<Hop> getHopInputs(Hop currentHop, Map<String, Hop> 
paramMap){
+               if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) )
+                       return 
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites);
+               else
+                       return currentHop.getInput();
+       }
+
+       private HopRel getNONEHopRel(Hop currentHop, Map<String, Hop> paramMap){
+               ArrayList<Hop> inputs = getHopInputs(currentHop, paramMap);
+               HopRel noneHopRel = new HopRel(currentHop, 
FederatedOutput.NONE, hopRelMemo, inputs);
                FType[] inputFType = 
noneHopRel.getInputDependency().stream().map(HopRel::getFType).toArray(FType[]::new);
                FType outputFType = getFederatedOut(currentHop, inputFType);
                noneHopRel.setFType(outputFType);
@@ -348,9 +340,7 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
         */
        private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop> 
paramMap){
                ArrayList<HopRel> hopRels = new ArrayList<>();
-               ArrayList<Hop> inputHops = currentHop.getInput();
-               if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) )
-                       inputHops = getTransientInputs(currentHop, paramMap);
+               ArrayList<Hop> inputHops = getHopInputs(currentHop, paramMap);
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTWRITE) )
                        transientWrites.put(currentHop.getName(), currentHop);
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.FEDERATED) )
@@ -453,6 +443,8 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
        private void debugLog(Hop currentHop){
                if ( LOG.isDebugEnabled() ){
                        LOG.debug("Visiting HOP: " + currentHop + " Input size: 
" + currentHop.getInput().size());
+                       if (currentHop.getPrivacy() != null)
+                               LOG.debug(currentHop.getPrivacy());
                        int index = 0;
                        for ( Hop hop : currentHop.getInput()){
                                if ( hop == null )
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
new file mode 100644
index 0000000000..45b711a41d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+public class FederatedPlannerUtils {
+       /**
+        * Get transient inputs from either paramMap or transientWrites.
+        * Inputs from paramMap has higher priority than inputs from 
transientWrites.
+        * @param currentHop hop for which inputs are read from maps
+        * @param paramMap of local parameters
+        * @param transientWrites map of transient writes
+        * @return inputs of currentHop
+        */
+       public static ArrayList<Hop> getTransientInputs(Hop currentHop, 
Map<String, Hop> paramMap, Map<String,Hop> transientWrites){
+               Hop tWriteHop = null;
+               if ( paramMap != null)
+                       tWriteHop = paramMap.get(currentHop.getName());
+               if ( tWriteHop == null )
+                       tWriteHop = transientWrites.get(currentHop.getName());
+               if ( tWriteHop == null )
+                       throw new DMLRuntimeException("Transient write not 
found for " + currentHop);
+               else
+                       return new 
ArrayList<>(Collections.singletonList(tWriteHop));
+       }
+
+       /**
+        * Return parameter map containing the mapping from parameter name to 
input hop
+        * for all parameters of the function hop.
+        * @param funcOp hop for which the mapping of parameter names to input 
hops are made
+        * @return parameter map or empty map if function has no parameters
+        */
+       public static Map<String,Hop> getParamMap(FunctionOp funcOp){
+               String[] inputNames = funcOp.getInputVariableNames();
+               Map<String,Hop> paramMap = new HashMap<>();
+               if ( inputNames != null ){
+                       for ( int i = 0; i < funcOp.getInput().size(); i++ )
+                               paramMap.put(inputNames[i],funcOp.getInput(i));
+               }
+               return paramMap;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
new file mode 100644
index 0000000000..82e4316988
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
+import org.apache.sysds.utils.JSONHelper;
+import org.apache.wink.json4j.JSONObject;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Future;
+
+public class PrivacyConstraintLoader {
+
+       private final Map<Long, Hop> memo = new HashMap<>();
+       private final Map<String, Hop> transientWrites = new HashMap<>();
+
+       public void loadConstraints(DMLProgram prog){
+               rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
+       }
+
+       private void rewriteStatementBlocks(DMLProgram prog, 
List<StatementBlock> sbs, Map<String, Hop> paramMap) {
+               sbs.forEach(block -> rewriteStatementBlock(prog, block, 
paramMap));
+       }
+
+       private void rewriteStatementBlock(DMLProgram prog, StatementBlock 
block, Map<String, Hop> paramMap){
+               if(block instanceof WhileStatementBlock)
+                       rewriteWhileStatementBlock(prog, (WhileStatementBlock) 
block, paramMap);
+               else if(block instanceof IfStatementBlock)
+                       rewriteIfStatementBlock(prog, (IfStatementBlock) block, 
paramMap);
+               else if(block instanceof ForStatementBlock) {
+                       // This also includes ParForStatementBlocks
+                       rewriteForStatementBlock(prog, (ForStatementBlock) 
block, paramMap);
+               }
+               else if(block instanceof FunctionStatementBlock)
+                       rewriteFunctionStatementBlock(prog, 
(FunctionStatementBlock) block, paramMap);
+               else {
+                       // StatementBlock type (no subclass)
+                       rewriteDefaultStatementBlock(prog, block, paramMap);
+               }
+       }
+
+       private void rewriteWhileStatementBlock(DMLProgram prog, 
WhileStatementBlock whileSB, Map<String, Hop> paramMap) {
+               Hop whilePredicateHop = whileSB.getPredicateHops();
+               loadPrivacyConstraint(whilePredicateHop, paramMap);
+               for(Statement stm : whileSB.getStatements()) {
+                       WhileStatement whileStm = (WhileStatement) stm;
+                       rewriteStatementBlocks(prog, whileStm.getBody(), 
paramMap);
+               }
+       }
+
+       private void rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock 
ifSB, Map<String, Hop> paramMap) {
+               loadPrivacyConstraint(ifSB.getPredicateHops(), paramMap);
+               for(Statement statement : ifSB.getStatements()) {
+                       IfStatement ifStatement = (IfStatement) statement;
+                       rewriteStatementBlocks(prog, ifStatement.getIfBody(), 
paramMap);
+                       rewriteStatementBlocks(prog, ifStatement.getElseBody(), 
paramMap);
+               }
+       }
+
+       private void rewriteForStatementBlock(DMLProgram prog, 
ForStatementBlock forSB, Map<String, Hop> paramMap) {
+               loadPrivacyConstraint(forSB.getFromHops(), paramMap);
+               loadPrivacyConstraint(forSB.getToHops(), paramMap);
+               loadPrivacyConstraint(forSB.getIncrementHops(), paramMap);
+               for(Statement statement : forSB.getStatements()) {
+                       ForStatement forStatement = ((ForStatement) statement);
+                       rewriteStatementBlocks(prog, forStatement.getBody(), 
paramMap);
+               }
+       }
+
+       private void rewriteFunctionStatementBlock(DMLProgram prog, 
FunctionStatementBlock funcSB, Map<String, Hop> paramMap) {
+               for(Statement statement : funcSB.getStatements()) {
+                       FunctionStatement funcStm = (FunctionStatement) 
statement;
+                       rewriteStatementBlocks(prog, funcStm.getBody(), 
paramMap);
+               }
+       }
+
+       private void rewriteDefaultStatementBlock(DMLProgram prog, 
StatementBlock sb, Map<String, Hop> paramMap) {
+               if(sb.hasHops()) {
+                       for(Hop sbHop : sb.getHops()) {
+                               loadPrivacyConstraint(sbHop, paramMap);
+                               if(sbHop instanceof FunctionOp) {
+                                       String funcName = ((FunctionOp) 
sbHop).getFunctionName();
+                                       Map<String, Hop> funcParamMap = 
FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
+                                       if ( paramMap != null && funcParamMap 
!= null)
+                                               funcParamMap.putAll(paramMap);
+                                       paramMap = funcParamMap;
+                                       FunctionStatementBlock sbFuncBlock = 
prog.getBuiltinFunctionDictionary().getFunction(funcName);
+                                       rewriteStatementBlock(prog, 
sbFuncBlock, paramMap);
+                               }
+                       }
+               }
+       }
+
+       private void loadPrivacyConstraint(Hop root, Map<String, Hop> paramMap){
+               if ( root != null && !memo.containsKey(root.getHopID()) ){
+                       for ( Hop input : root.getInput() ){
+                               loadPrivacyConstraint(input, paramMap);
+                       }
+                       propagatePrivConstraintsLocal(root, paramMap);
+                       memo.put(root.getHopID(), root);
+               }
+       }
+
+       private void propagatePrivConstraintsLocal(Hop currentHop, Map<String, 
Hop> paramMap){
+               if ( currentHop.isFederatedDataOp() )
+                       loadFederatedPrivacyConstraints(currentHop);
+               else if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTWRITE) ){
+                       
currentHop.setPrivacy(currentHop.getInput(0).getPrivacy());
+                       transientWrites.put(currentHop.getName(), currentHop);
+               }
+               else if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) ){
+                       
currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop, 
paramMap, transientWrites).get(0).getPrivacy());
+               } else {
+                       PrivacyPropagator.hopPropagation(currentHop);
+               }
+       }
+
+       /**
+        * Get privacy constraints from federated workers for DataOps.
+        * @hop hop for which privacy constraints are loaded
+        */
+       private static void loadFederatedPrivacyConstraints(Hop hop){
+               try {
+                       PrivacyConstraint.PrivacyLevel constraintLevel = 
hop.getInput(0).getInput().stream().parallel()
+                               .map( in -> ((LiteralOp)in).getStringValue() )
+                               
.map(PrivacyConstraintLoader::sendPrivConstraintRequest)
+                               
.map(PrivacyConstraintLoader::unwrapPrivConstraint)
+                               .map(constraint -> (constraint != null) ? 
constraint.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None)
+                               .reduce(PrivacyConstraint.PrivacyLevel.None, 
(out,in) -> {
+                                       if ( out == 
PrivacyConstraint.PrivacyLevel.Private || in == 
PrivacyConstraint.PrivacyLevel.Private )
+                                               return 
PrivacyConstraint.PrivacyLevel.Private;
+                                       else if ( out == 
PrivacyConstraint.PrivacyLevel.PrivateAggregation || in == 
PrivacyConstraint.PrivacyLevel.PrivateAggregation )
+                                               return 
PrivacyConstraint.PrivacyLevel.PrivateAggregation;
+                                       else
+                                               return out;
+                               });
+                       PrivacyConstraint fedDataPrivConstraint = 
(constraintLevel != PrivacyConstraint.PrivacyLevel.None) ?
+                               new PrivacyConstraint(constraintLevel) : null;
+
+                       hop.setPrivacy(fedDataPrivConstraint);
+               }
+               catch(Exception ex) {
+                       throw new DMLException(ex);
+               }
+       }
+
+       private static Future<FederatedResponse> 
sendPrivConstraintRequest(String address)
+       {
+               try{
+                       String[] parsedAddress = 
InitFEDInstruction.parseURL(address);
+                       String host = parsedAddress[0];
+                       int port = Integer.parseInt(parsedAddress[1]);
+                       PrivacyConstraintRetriever retriever = new 
PrivacyConstraintRetriever(parsedAddress[2]);
+                       FederatedRequest privacyRetrieval =
+                               new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
+                       InetSocketAddress inetAddress = new 
InetSocketAddress(InetAddress.getByName(host), port);
+                       return 
FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval);
+               } catch(UnknownHostException ex){
+                       throw new DMLException(ex);
+               }
+       }
+
+       private static PrivacyConstraint 
unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
+       {
+               try {
+                       FederatedResponse privConstraintResponse = 
privConstraintFuture.get();
+                       return (PrivacyConstraint) 
privConstraintResponse.getData()[0];
+               } catch(Exception ex){
+                       throw new DMLException(ex);
+               }
+       }
+
+       /**
+        * FederatedUDF for retrieving privacy constraint of data stored in 
file name.
+        */
+       public static class PrivacyConstraintRetriever extends FederatedUDF {
+               private static final long serialVersionUID = 
3551741240135587183L;
+               private final String filename;
+
+               public PrivacyConstraintRetriever(String filename){
+                       super(new long[]{});
+                       this.filename = filename;
+               }
+
+               /**
+                * Reads metadata JSON object, parses privacy constraint and 
returns the constraint in FederatedResponse.
+                * @param ec execution context
+                * @param data one or many data objects
+                * @return FederatedResponse with privacy constraint object
+                */
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       PrivacyConstraint privacyConstraint;
+                       FileSystem fs = null;
+                       try {
+                               String mtdname = 
DataExpression.getMTDFileName(filename);
+                               Path path = new Path(mtdname);
+                               fs = IOUtilFunctions.getFileSystem(mtdname);
+                               try(BufferedReader br = new BufferedReader(new 
InputStreamReader(fs.open(path)))) {
+                                       JSONObject metadataObject = 
JSONHelper.parse(br);
+                                       privacyConstraint = 
PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
+                               }
+                       }
+                       catch (DMLPrivacyException | 
FederatedWorkerHandlerException ex){
+                               throw ex;
+                       }
+                       catch (Exception ex) {
+                               String msg = "Exception in reading metadata of: 
" + filename;
+                               throw new DMLRuntimeException(msg);
+                       }
+                       finally {
+                               IOUtilFunctions.closeSilently(fs);
+                       }
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 6be3b9c8ec..e6c683eb38 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -23,6 +23,7 @@ import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
+import org.apache.sysds.hops.fedplanner.PrivacyConstraintLoader;
 import org.apache.sysds.parser.DMLProgram;
 
 /**
@@ -58,16 +59,24 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         */
        @Override
        public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes) {
-               // obtain planner instance according to config
                String splanner = ConfigurationManager.getDMLConfig()
                        .getTextValue(DMLConfig.FEDERATED_PLANNER);
+               loadPrivacyConstraints(prog, splanner);
+               generatePlan(prog, fgraph, fcallSizes, splanner);
+               return false;
+       }
+
+       private void loadPrivacyConstraints(DMLProgram prog, String splanner){
+               if (FederatedPlanner.isCompiled(splanner))
+                       new PrivacyConstraintLoader().loadConstraints(prog);
+       }
+
+       private void generatePlan(DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes, String splanner){
                FederatedPlanner planner = 
FederatedPlanner.isCompiled(splanner) ?
                        FederatedPlanner.valueOf(splanner.toUpperCase()) :
                        FederatedPlanner.COMPILE_COST_BASED;
-               
+
                // run planner rewrite with forced federated exec types
                planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes);
-               
-               return false;
        }
 }
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index faec3504e9..db20ada280 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -27,10 +27,8 @@ import org.apache.log4j.Logger;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
-import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.fedplanner.FTypes;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.ForStatement;
 import org.apache.sysds.parser.ForStatementBlock;
@@ -141,11 +139,6 @@ public class ProgramRewriter
                                _dagRuleSet.add( new 
RewriteAlgebraicSimplificationDynamic()      ); //dependencies: cse
                                _dagRuleSet.add( new 
RewriteAlgebraicSimplificationStatic()       ); //dependencies: cse
                        }
-                       String planner = ConfigurationManager.getDMLConfig()
-                               .getTextValue(DMLConfig.FEDERATED_PLANNER);
-                       if ( OptimizerUtils.FEDERATED_COMPILATION || 
FTypes.FederatedPlanner.isCompiled(planner) ) {
-                               _dagRuleSet.add( new 
RewriteFederatedExecution() );
-                       }
                }
                
                // cleanup after all rewrites applied 
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
deleted file mode 100644
index 822b4b5d95..0000000000
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.rewrite;
-
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.log4j.Logger;
-import org.apache.sysds.api.DMLException;
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
-import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
-import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
-import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.lineage.LineageItem;
-import org.apache.sysds.runtime.privacy.DMLPrivacyException;
-import org.apache.sysds.runtime.privacy.PrivacyConstraint;
-import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
-import org.apache.sysds.utils.JSONHelper;
-import org.apache.wink.json4j.JSONObject;
-
-import javax.net.ssl.SSLException;
-import java.io.BufferedReader;
-import java.io.InputStreamReader;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.UnknownHostException;
-import java.util.ArrayList;
-import java.util.concurrent.Future;
-
-public class RewriteFederatedExecution extends HopRewriteRule {
-       private static final Logger LOG = 
Logger.getLogger(RewriteFederatedExecution.class);
-       
-       @Override
-       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
-               if ( roots != null )
-                       for ( Hop root : roots )
-                               rewriteHopDAG(root, state);
-               return roots;
-       }
-
-       @Override
-       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-               if ( root != null )
-                       visitHop(root);
-               return root;
-       }
-
-       private void visitHop(Hop hop){
-               if (hop.isVisited())
-                       return;
-
-               LOG.debug("RewriteFederatedExecution visitHop + " + hop);
-
-               // Depth first to get to the input
-               for ( Hop input : hop.getInput() )
-                       visitHop(input);
-
-               privacyBasedHopDecisionWithFedCall(hop);
-               hop.setVisited();
-       }
-
-       /**
-        * Get privacy constraints of DataOps from federated worker,
-        * propagate privacy constraints from input to current hop,
-        * and set federated output flag.
-        * @param hop current hop
-        */
-       private static void privacyBasedHopDecisionWithFedCall(Hop hop){
-               loadFederatedPrivacyConstraints(hop);
-               PrivacyPropagator.hopPropagation(hop);
-       }
-
-       /**
-        * Get privacy constraints from federated workers for DataOps.
-        * @hop hop for which privacy constraints are loaded
-        */
-       private static void loadFederatedPrivacyConstraints(Hop hop){
-               if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){
-                       try {
-                               LOG.debug("Load privacy constraints of " + hop);
-                               PrivacyConstraint privConstraint = 
unwrapPrivConstraint(sendPrivConstraintRequest(hop));
-                               LOG.debug("PrivacyConstraint retrieved: " + 
privConstraint);
-                               hop.setPrivacy(privConstraint);
-                       }
-                       catch(Exception e) {
-                               throw new DMLException(e);
-                       }
-               }
-       }
-
-       private static Future<FederatedResponse> sendPrivConstraintRequest(Hop 
hop)
-               throws UnknownHostException, SSLException
-       {
-               String address = ((LiteralOp) 
hop.getInput(0).getInput(0)).getStringValue();
-               String[] parsedAddress = InitFEDInstruction.parseURL(address);
-               String host = parsedAddress[0];
-               int port = Integer.parseInt(parsedAddress[1]);
-               PrivacyConstraintRetriever retriever = new 
PrivacyConstraintRetriever(parsedAddress[2]);
-               FederatedRequest privacyRetrieval =
-                       new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
-               InetSocketAddress inetAddress = new 
InetSocketAddress(InetAddress.getByName(host), port);
-               return FederatedData.executeFederatedOperation(inetAddress, 
privacyRetrieval);
-       }
-
-       private static PrivacyConstraint 
unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
-               throws Exception
-       {
-               FederatedResponse privConstraintResponse = 
privConstraintFuture.get();
-               return (PrivacyConstraint) privConstraintResponse.getData()[0];
-       }
-
-       /**
-        * FederatedUDF for retrieving privacy constraint of data stored in 
file name.
-        */
-       public static class PrivacyConstraintRetriever extends FederatedUDF {
-               private static final long serialVersionUID = 
3551741240135587183L;
-               private final String filename;
-
-               public PrivacyConstraintRetriever(String filename){
-                       super(new long[]{});
-                       this.filename = filename;
-               }
-
-               /**
-                * Reads metadata JSON object, parses privacy constraint and 
returns the constraint in FederatedResponse.
-                * @param ec execution context
-                * @param data one or many data objects
-                * @return FederatedResponse with privacy constraint object
-                */
-               @Override
-               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
-                       PrivacyConstraint privacyConstraint;
-                       FileSystem fs = null;
-                       try {
-                               String mtdname = 
DataExpression.getMTDFileName(filename);
-                               Path path = new Path(mtdname);
-                               fs = IOUtilFunctions.getFileSystem(mtdname);
-                               try(BufferedReader br = new BufferedReader(new 
InputStreamReader(fs.open(path)))) {
-                                       JSONObject metadataObject = 
JSONHelper.parse(br);
-                                       privacyConstraint = 
PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
-                               }
-                       }
-                       catch (DMLPrivacyException | 
FederatedWorkerHandlerException ex){
-                               throw ex;
-                       }
-                       catch (Exception ex) {
-                               String msg = "Exception in reading metadata of: 
" + filename;
-                               throw new DMLRuntimeException(msg);
-                       }
-                       finally {
-                               IOUtilFunctions.closeSilently(fs);
-                       }
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
-               }
-
-               @Override
-               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
-                       return null;
-               }
-       }
-}
diff --git 
a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java 
b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
index 8ea061844a..fc9ba440c8 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
@@ -262,8 +262,15 @@ public class PrivacyConstraint implements Externalizable
 
+       /**
+        * Returns the general privacy level, followed on a new line by the fine-grained
+        * privacy level only when fine-grained constraints are present.
+        */
        @Override
        public String toString(){
-               return "General privacy level: " + privacyLevel + System.getProperty("line.separator")
-                       + "Fine-grained privacy level: " + fineGrainedPrivacy.toString();
+               String constraintString = "General privacy level: " + privacyLevel;
+               if ( fineGrainedPrivacy != null && fineGrainedPrivacy.hasConstraints() )
+                       constraintString += System.lineSeparator()
+                               + "Fine-grained privacy level: " + fineGrainedPrivacy.toString();
+               return constraintString;
        }
 
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
 
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 7e6c0127e5..94834ebc6e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -23,11 +23,18 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 
+import org.apache.sysds.api.DMLException;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataGenOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.UnaryOp;
@@ -168,12 +175,41 @@ public class PrivacyPropagator
         * @param hop which the privacy constraints are propagated to
         */
        public static void hopPropagation(Hop hop){
-               PrivacyConstraint[] inputConstraints = hop.getInput().stream()
+               hopPropagation(hop, hop.getInput());
+       }
+
+       /**
+        * Propagate privacy constraints from input hops to given hop.
+        * @param hop which the privacy constraints are propagated to
+        * @param inputHops inputs to given hop
+        * @throws DMLException if an input hop has a privacy constraint but the type of the given hop is not recognized
+        */
+       public static void hopPropagation(Hop hop, ArrayList<Hop> inputHops){
+               PrivacyConstraint[] inputConstraints = inputHops.stream()
                        .map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new);
-               if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp )
-                       hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate));
+               OperatorType opType = getOpType(hop);
+               // Validate before mutating the hop, so a rejected hop is not left carrying a propagated constraint.
+               if (opType == null && Arrays.stream(inputConstraints).anyMatch(Objects::nonNull))
+                       throw new DMLException("Input has constraint but hop type not recognized by PrivacyPropagator. " +
+                               "Hop is " + hop + " " + hop.getClass());
+               hop.setPrivacy(mergeNary(inputConstraints, opType));
+       }
+
+       /**
+        * Get operator type of given hop.
+        * Returns null if hop type is not known.
+        * @param hop for which operator type is returned
+        * @return operator type of hop or null if hop type is unknown
+        */
+       private static OperatorType getOpType(Hop hop){
+               if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+                       || hop instanceof DataOp || hop instanceof LiteralOp || hop instanceof NaryOp
+                       || hop instanceof DataGenOp || hop instanceof FunctionOp )
+                       return OperatorType.NonAggregate;
                else if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp  || hop instanceof UnaryOp )
-                       hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate));
+                       return OperatorType.Aggregate;
+               else
+                       return null;
        }
 
        /**
@@ -406,7 +442,7 @@ public class PrivacyPropagator
                if (inputOperands != null){
                        for ( CPOperand input : inputOperands ){
                                PrivacyConstraint privacyConstraint = 
getInputPrivacyConstraint(ec, input);
-                               if ( privacyConstraint != null){
+                               if ( privacyConstraint != null && 
privacyConstraint.hasConstraints()){
                                        throw new DMLPrivacyException("Input of 
instruction " + inst + " has privacy constraints activated, but the constraints 
are not propagated during preprocessing of instruction.");
                                }
                        }
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java 
b/src/main/java/org/apache/sysds/utils/Explain.java
index 589f23a845..ded46c039a 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -626,6 +626,9 @@ public class Explain
                        }
                }
 
+               if ( hop.getPrivacy() != null )
+                       sb.append(" 
").append(hop.getPrivacy().getPrivacyLevel().name());
+
                sb.append('\n');
 
                hop.setVisited();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 14c093ebe8..2477bdef85 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -55,6 +55,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        private final static String TEST_NAME_8 = 
"FederatedMultiplyPlanningTest8";
        private final static String TEST_NAME_9 = 
"FederatedMultiplyPlanningTest9";
        private final static String TEST_NAME_10 = 
"FederatedMultiplyPlanningTest10";
+       private final static String TEST_NAME_11 = 
"FederatedMultiplyPlanningTest11";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
        private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, 
"SystemDS-config-cost-based.xml");
 
@@ -77,6 +78,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                addTestConfiguration(TEST_NAME_8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
                addTestConfiguration(TEST_NAME_9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
                addTestConfiguration(TEST_NAME_10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_11, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
        }
 
        @Parameterized.Parameters
@@ -153,6 +155,12 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                federatedTwoMatricesSingleNodeTest(TEST_NAME_10, 
expectedHeavyHitters);
        }
 
+       @Test
+       public void federatedMultiplyPlanningTest11(){
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_11, 
expectedHeavyHitters);
+       }
+
        private void writeStandardMatrix(String matrixName, long seed){
                writeStandardMatrix(matrixName, seed, new 
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
        }
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml
new file mode 100644
index 0000000000..147bf2cd13
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), 
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+              ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), 
list($r, $c)))
+
+i = 0
+while(i < 10){
+    Z0 = X * Y
+    Z = t(Z0) %*% X
+    i=i+1
+}
+
+write(Z, $Z)
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml
new file mode 100644
index 0000000000..187623bbfe
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+
+i = 0
+while(i < 10){
+    Z0 = X * Y
+    Z = t(Z0) %*% X
+    i=i+1
+}
+
+write(Z, $Z)

Reply via email to