This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 03fc10328a [SYSTEMDS-3018] Federated Rewriting Fixes
03fc10328a is described below

commit 03fc10328a18fe731d9d2089e25802518cb26d27
Author: sebwrede <[email protected]>
AuthorDate: Fri Jul 22 10:39:00 2022 +0200

    [SYSTEMDS-3018] Federated Rewriting Fixes
    
    Change the repetition estimate update to prevent infinite loops.
    Add memo table size to explain output and "*2" to FED instruction parsing.
    
    Closes #1669.
---
 src/main/java/org/apache/sysds/hops/Hop.java                       | 5 ++++-
 src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java      | 6 +++++-
 .../apache/sysds/runtime/instructions/FEDInstructionParser.java    | 1 +
 src/main/java/org/apache/sysds/utils/Explain.java                  | 7 +++++++
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 4d1dff8f22..3988a6b59f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -94,6 +94,7 @@ public abstract class Hop implements ParseInfo {
        protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
        protected FederatedCost _federatedCost = new FederatedCost();
        protected double repetitions = 1;
+       protected boolean repetitionsUpdated = false;
 
        /**
         * Field defining if prefetch should be activated for operation.
@@ -1556,8 +1557,10 @@ public abstract class Hop implements ParseInfo {
        }
 
        public void updateRepetitionEstimates(double repetitions){
-               if ( !federatedCostInitialized() ){
+               LOG.trace("Updating repetition estimates of " + this.getName() 
+ " to " + repetitions);
+               if ( !federatedCostInitialized() && !repetitionsUpdated ){
                        this.repetitions = repetitions;
+                       this.repetitionsUpdated = true;
                        for ( Hop input : getInput() )
                                input.updateRepetitionEstimates(repetitions);
                }
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index f84aecc5e8..3ecb0b29b9 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -161,10 +161,14 @@ public class MemoTable {
                        .orElseThrow(() -> new DMLRuntimeException("FType not 
found in memo"));
        }
 
+       public int getSize(){
+               return hopRelMemo.size();
+       }
+
        @Override
        public String toString(){
                StringBuilder sb = new StringBuilder();
-               sb.append("Federated MemoTable has 
").append(hopRelMemo.size()).append(" entries with the following values:");
+               sb.append("Federated MemoTable has 
").append(getSize()).append(" entries with the following values:");
                sb.append("\n").append("{").append("\n");
                for (Map.Entry<Long,List<HopRel>> hopEntry : 
hopRelMemo.entrySet()){
                        sb.append("  
").append(hopEntry.getKey()).append(":").append("\n");
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 81d2983da1..f61e86e800 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -73,6 +73,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "/" ,  FEDType.Binary );
                String2FEDInstructionType.put( "1-*", FEDType.Binary); 
//special * case
                String2FEDInstructionType.put( "^2" , FEDType.Binary); 
//special ^ case
+               String2FEDInstructionType.put( "*2" , FEDType.Binary); 
//special * case
                String2FEDInstructionType.put( "max", FEDType.Binary );
                String2FEDInstructionType.put( "==",  FEDType.Binary);
                String2FEDInstructionType.put( "!=",  FEDType.Binary);
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java 
b/src/main/java/org/apache/sysds/utils/Explain.java
index ded46c039a..75740c1c5a 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -125,6 +125,7 @@ public class Explain
                return "# EXPLAIN ("+type.name()+"):\n"
                                + Explain.explainMemoryBudget(counts)+"\n"
                                + Explain.explainDegreeOfParallelism(counts)
+                               + Explain.explainMemoTableSize()
                                + Explain.explain(prog, rtprog, type, counts);
        }
 
@@ -185,6 +186,12 @@ public class Explain
                return sb.toString();
        }
 
+       private static String explainMemoTableSize(){
+               if ( MEMO_TABLE != null )
+                       return "\n# Number of HOPs in Memo = " + 
MEMO_TABLE.getSize();
+               else return "";
+       }
+
        public static String explain(DMLProgram prog, Program rtprog, 
ExplainType type) {
                return explain(prog, rtprog, type, null);
        }

Reply via email to