This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new aba1707852 [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
aba1707852 is described below

commit aba1707852d546a9c46ada3f185824df063494bd
Author: sebwrede <[email protected]>
AuthorDate: Wed May 11 16:08:53 2022 +0200

    [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
    
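    Applying this commit will:
    1) Force the ExecType chosen by the federated planner (FED when an input
       dependency has federated output) and keep other ExecType decisions
       from overriding a forced FED ExecType
    2) Add FedOut alternatives to the HOPS explain output
    3) Register uamax and uamin as federated AggregateUnary instructions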
    
    Closes #1612.
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |  3 ++
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  4 +--
 src/main/java/org/apache/sysds/hops/Hop.java       | 18 ------------
 .../java/org/apache/sysds/hops/cost/HopRel.java    | 32 ++++++++++++++++------
 .../hops/fedplanner/FederatedPlannerCostbased.java | 14 +++++++++-
 .../apache/sysds/hops/fedplanner/MemoTable.java    | 15 ++++++++++
 .../runtime/instructions/FEDInstructionParser.java |  3 ++
 src/main/java/org/apache/sysds/utils/Explain.java  | 22 +++++++++++++++
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  3 +-
 9 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index c461b69bac..23439b182e 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -608,6 +608,9 @@ public class AggUnaryOp extends MultiThreadedHop
                ExecType et_input = input1.optFindExecType();
                // Because ternary aggregate are not supported on GPU
                et_input = et_input == ExecType.GPU ? ExecType.CP :  et_input;
+               // If forced ExecType is FED, it means that the federated 
planner updated the ExecType and
+               // execution may fail if ExecType is not FED
+               et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED 
: et_input;
                
                return new TernaryAggregate(in1, in2, in3, AggOp.SUM, 
                        OpOp2.MULT, _direction, getDataType(), ValueType.FP64, 
et_input, k);
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 791c3bdfbd..2346eeebfe 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -755,11 +755,9 @@ public class BinaryOp extends MultiThreadedHop {
                        checkAndSetInvalidCPDimsAndSize();
                }
 
-               updateETFed();
-                       
                //spark-specific decision refinement (execute unary scalar w/ 
spark input and 
                //single parent also in spark because it's likely cheap and 
reduces intermediates)
-               if( transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP
+               if( transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP && _etypeForced != ExecType.FED
                        && getDataType().isMatrix() && (dt1.isScalar() || 
dt2.isScalar()) 
                        && supportsMatrixScalarOperations()                     
     //scalar operations
                        && !(getInput().get(dt1.isScalar()?1:0) instanceof 
DataOp)   //input is not checkpoint
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 4ce9a4b90f..e1e4fcc8d4 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -909,24 +909,6 @@ public abstract class Hop implements ParseInfo {
                return et;
        }
 
-       /**
-        * Update the execution type if input is federated.
-        * This method only has an effect if FEDERATED_COMPILATION is activated.
-        * Federated compilation is activated in OptimizerUtils.
-        */
-       public void updateETFed() {
-               boolean localOut = hasLocalOutput();
-               boolean fedIn = getInput().stream().anyMatch(
-                       in -> in.hasFederatedOutput() && 
!(in.prefetchActivated() && localOut));
-               if( isFederatedDataOp() || fedIn ){
-                       setForcedExecType(ExecType.FED);
-                       //TODO: Temporary solution where _etype is set directly
-                       // since forcedExecType for BinaryOp may be overwritten
-                       // if updateETFed is not called from optFindExecType.
-                       _etype = ExecType.FED;
-               }
-       }
-
        /**
         * Checks if ExecType is federated.
         * @return true if ExecType is federated
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java 
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index 89a0f7cb50..70785950ca 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.hops.cost;
 
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.fedplanner.FTypes;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
@@ -43,8 +44,9 @@ import java.util.stream.Collectors;
 public class HopRel {
        protected final Hop hopRef;
        protected final FEDInstruction.FederatedOutput fedOut;
+       protected ExecType execType;
        protected FTypes.FType fType;
-       protected final FederatedCost cost;
+       protected FederatedCost cost;
        protected final Set<Long> costPointerSet = new HashSet<>();
        protected List<Hop> inputHops;
        protected List<HopRel> inputDependency = new ArrayList<>();
@@ -70,6 +72,13 @@ public class HopRel {
                this(associatedHop, fedOut, null, hopRelMemo, inputs);
        }
 
+       private HopRel(Hop associatedHop, FEDInstruction.FederatedOutput 
fedOut, FType fType, List<Hop> inputs){
+               hopRef = associatedHop;
+               this.fedOut = fedOut;
+               this.fType = fType;
+               this.inputHops = inputs;
+       }
+
        /**
         * Constructs a HopRel with input dependency and cost estimate based on 
entries in hopRelMemo.
         * @param associatedHop hop associated with this HopRel
@@ -79,21 +88,17 @@ public class HopRel {
         * @param inputs hop inputs which input dependencies and cost is based 
on
         */
        public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
FType fType, MemoTable hopRelMemo, ArrayList<Hop> inputs){
-               hopRef = associatedHop;
-               this.fedOut = fedOut;
-               this.fType = fType;
-               this.inputHops = inputs;
+               this(associatedHop, fedOut, fType, inputs);
                setInputDependency(hopRelMemo);
                cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+               setExecType();
        }
 
        public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
FType fType, MemoTable hopRelMemo, List<Hop> inputs, List<FType> 
inputDependency){
-               hopRef = associatedHop;
-               this.fedOut = fedOut;
-               this.inputHops = inputs;
-               this.fType = fType;
+               this(associatedHop, fedOut, fType, inputs);
                setInputFTypeDependency(inputs, inputDependency, hopRelMemo);
                cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+               setExecType();
        }
 
        private void setInputFTypeDependency(List<Hop> inputs, List<FType> 
inputDependency, MemoTable hopRelMemo){
@@ -103,6 +108,11 @@ public class HopRel {
                validateInputDependency();
        }
 
+       private void setExecType(){
+               if ( 
inputDependency.stream().anyMatch(HopRel::hasFederatedOutput) )
+                       execType = ExecType.FED;
+       }
+
        /**
         * Adds hopID to set of hops pointing to this HopRel.
         * By storing the hopID it can later be determined if the cost
@@ -154,6 +164,10 @@ public class HopRel {
                this.fType = fType;
        }
 
+       public ExecType getExecType(){
+               return execType;
+       }
+
        /**
         * Returns FOUT HopRel for given hop found in hopRelMemo or returns 
null if HopRel not found.
         * @param hop to look for in hopRelMemo
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index a809d2bafd..e9a25206f8 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -30,6 +30,7 @@ import java.util.Set;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.hops.DataOp;
@@ -53,6 +54,8 @@ import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.Explain.ExplainType;
 
 public class FederatedPlannerCostbased extends AFederatedPlanner {
        private static final Log LOG = 
LogFactory.getLog(FederatedPlannerCostbased.class.getName());
@@ -77,6 +80,7 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                prog.updateRepetitionEstimates();
                rewriteStatementBlocks(prog, prog.getStatementBlocks());
                setFinalFedouts();
+               updateExplain();
        }
        
        /**
@@ -215,7 +219,6 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                        updateFederatedOutput(root, rootHopRel);
                        visitInputDependency(rootHopRel);
                }
-               root.updateETFed();
        }
 
        /**
@@ -238,6 +241,7 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
        private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
                root.setFederatedOutput(updateHopRel.getFederatedOutput());
                root.setFederatedCost(updateHopRel.getCostObject());
+               root.setForcedExecType(updateHopRel.getExecType());
                forceFixedFedOut(root);
                LOG.trace("Updated fedOut to " + 
updateHopRel.getFederatedOutput() + " for hop "
                        + root.getHopID() + " opcode: " + root.getOpString());
@@ -394,6 +398,14 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                }
        }
 
+       /**
+        * Add hopRelMemo to Explain class to get explain info related to 
federated enumeration.
+        */
+       private void updateExplain(){
+               if (DMLScript.EXPLAIN == ExplainType.HOPS)
+                       Explain.setMemo(hopRelMemo);
+       }
+
        /**
         * Write HOP visit to debug log if debug is activated.
         * @param currentHop hop written to log
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index 6b9da0f400..5b399bd499 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -23,6 +23,7 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.cost.HopRel;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 
 import java.util.Comparator;
 import java.util.HashMap;
@@ -46,6 +47,20 @@ public class MemoTable {
         */
        private final static Map<Long, List<HopRel>> hopRelMemo = new 
HashMap<>();
 
+       /**
+        * Get list of strings representing the different
+        * hopRel federated outputs related to root hop.
+        * @param root for which HopRel fedouts are found
+        * @return federated output values as strings
+        */
+       public List<String> getFedOutAlternatives(Hop root){
+               if ( !containsHop(root) )
+                       return null;
+               else return hopRelMemo.get(root.getHopID()).stream()
+                       .map(HopRel::getFederatedOutput)
+                       
.map(FEDInstruction.FederatedOutput::name).collect(Collectors.toList());
+       }
+
        /**
         * Get the HopRel with minimum cost for given root hop
         * @param root hop for which minimum cost HopRel is found
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 2fde0a0fbc..58ab43daba 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions;
 
 import org.apache.sysds.lops.Append;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
 import 
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
@@ -52,6 +53,8 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "uak+"    , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uark+"   , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uack+"   , 
FEDType.AggregateUnary );
+               String2FEDInstructionType.put( "uamax"   , 
FEDType.AggregateUnary );
+               String2FEDInstructionType.put( "uamin"   , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uasqk+"  , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uarsqk+" , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uacsqk+" , 
FEDType.AggregateUnary );
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java 
b/src/main/java/org/apache/sysds/utils/Explain.java
index ba6fb7150e..c8e5902511 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -35,6 +35,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysds.hops.fedplanner.MemoTable;
 import org.apache.sysds.hops.ipa.FunctionCallGraph;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.parser.DMLProgram;
@@ -78,6 +79,9 @@ public class Explain
        private static final boolean SHOW_DATA_DEPENDENCIES     = true;
        private static final boolean SHOW_DATA_FLOW_PROPERTIES  = true;
 
+       //federated execution plan alternatives
+       private static MemoTable MEMO_TABLE;
+
        //different explain levels
        public enum ExplainType {
                NONE,     // explain disabled
@@ -101,6 +105,14 @@ public class Explain
                public int numChkpts = 0;
        }
 
+       /**
+        * Store memo table for adding additional explain info regarding hops.
+        * @param memoTable to store in Explain
+        */
+       public static void setMemo(MemoTable memoTable){
+               MEMO_TABLE = memoTable;
+       }
+
        //////////////
        // public explain interface
 
@@ -600,6 +612,16 @@ public class Explain
                if (hop.getExecType() != null)
                        sb.append(", " + hop.getExecType());
 
+               if ( MEMO_TABLE != null && MEMO_TABLE.containsHop(hop) ){
+                       List<String> fedAlts = 
MEMO_TABLE.getFedOutAlternatives(hop);
+                       if ( fedAlts != null ){
+                               sb.append(" [ ");
+                               for ( String fedAlt : fedAlts )
+                                       sb.append(fedAlt).append(" ");
+                               sb.append("]");
+                       }
+               }
+
                sb.append('\n');
 
                hop.setVisited();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index e9ab6b6ad0..1ba9966773 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -138,7 +138,8 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
 
                        // Run actual dml script with federated matrix
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] { "-stats", "-explain", 
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                       programArgs = new String[] { "-stats", "-explain", 
"hops", "-nvargs",
+                               "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                                "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
                                "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
                        runTest(true, false, null, -1);

Reply via email to