Repository: systemml
Updated Branches:
  refs/heads/master 948943d17 -> ddcb9e019


[SYSTEMML-1691,1692] New IPA passes: literal replacement and rewrites

This patch introduces two new passes for inter-procedural analysis
(IPA): (1) literal propagation and replacement into functions, and (2)
static rewrites. Both passes are applied in each of the requested
number of IPA iterations. The internal abstraction for function call
summaries has been extended accordingly. The new literal propagation
and replacement works at the granularity of individual function
parameters and propagates any literals that are consistent across all
calls of a function, independent of the remaining inputs. Together with
the additional rewrite pass, this enables rewrites such as constant
folding and the subsequent removal of branches, which can significantly
reduce the program size and the number of distributed operations.

For example, for GLM poisson.log, this change reduced the size of the
initial runtime program from 2132/153 to 1164/45 local/distributed
instructions, which makes the program much easier to debug and profile.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e6639c7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e6639c7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e6639c7

Branch: refs/heads/master
Commit: 1e6639c754961f51bb53754c1fa8b6dce404294a
Parents: 948943d
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Jun 15 20:01:12 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Jun 16 10:01:57 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallSizeInfo.java    |  90 ++++++++++-
 .../java/org/apache/sysml/hops/ipa/IPAPass.java |   3 +-
 .../hops/ipa/IPAPassApplyStaticHopRewrites.java |  53 +++++++
 .../ipa/IPAPassFlagFunctionsRecompileOnce.java  |   2 +-
 .../ipa/IPAPassPropagateReplaceLiterals.java    | 155 +++++++++++++++++++
 .../ipa/IPAPassRemoveConstantBinaryOps.java     |   2 +-
 .../IPAPassRemoveUnnecessaryCheckpoints.java    |   2 +-
 .../hops/ipa/IPAPassRemoveUnusedFunctions.java  |   2 +-
 .../sysml/hops/ipa/InterProceduralAnalysis.java | 104 +++++++------
 .../org/apache/sysml/parser/DMLTranslator.java  |  12 +-
 .../java/org/apache/sysml/utils/Explain.java    |   5 +-
 ...antFoldingScalarVariablePropagationTest.java |  17 +-
 12 files changed, 360 insertions(+), 87 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
index 20054a2..402e780 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Set;
 
 import org.apache.sysml.hops.FunctionOp;
@@ -52,10 +53,16 @@ public class FunctionCallSizeInfo
        //to subsequent statement blocks and functions)
        private final Set<String> _fcandUnary;
        
-       //indicators for which function arguments it is safe to propagate nnz
+       //indicators for which function arguments of valid functions it 
+       //is safe to propagate the number of non-zeros 
        //(mapping from function keys to set of function input HopIDs)
        private final Map<String, Set<Long>> _fcandSafeNNZ;
        
+       //indicators which literal function arguments can be safely 
+       //propagated into and replaced in the respective functions 
+       //(mapping from function keys to set of function input positions)
+       private final Map<String, Set<Integer>> _fSafeLiterals;
+       
        /**
         * Constructs the function call summary for all functions
         * reachable from the main program. 
@@ -84,6 +91,7 @@ public class FunctionCallSizeInfo
                _fcand = new HashSet<String>();
                _fcandUnary = new HashSet<String>();
                _fcandSafeNNZ =  new HashMap<String, Set<Long>>();
+               _fSafeLiterals = new HashMap<String, Set<Integer>>();
                
                constructFunctionCallSizeInfo();
        }
@@ -169,17 +177,44 @@ public class FunctionCallSizeInfo
         * 
         * @param fkey function key
         * @param inputHopID hop ID of the input
-        * @return true if nnz can safely be propageted
+        * @return true if nnz can safely be propagated
         */
        public boolean isSafeNnz(String fkey, long inputHopID) {
                return _fcandSafeNNZ.containsKey(fkey)
                        && _fcandSafeNNZ.get(fkey).contains(inputHopID);
        }
        
+       /**
+        * Indicates if the given function has at least one input
+        * that allows for safe literal propagation and replacement,
+        * i.e., all function calls have consistent literal inputs.
+        * 
+        * @param fkey function key
+        * @return true if a literal can be safely propagated
+        */
+       public boolean hasSafeLiterals(String fkey) {
+               return _fSafeLiterals.containsKey(fkey)
+                       && !_fSafeLiterals.get(fkey).isEmpty();
+       }
+       
+       /**
+        * Indicates if the given function input allows for safe
+        * literal propagation and replacement, i.e., all function calls
+        * have consistent literal inputs.
+        * 
+        * @param fkey function key
+        * @param pos function input position
+        * @return true if literal that can be safely propagated
+        */
+       public boolean isSafeLiteral(String fkey, int pos) {
+               return _fSafeLiterals.containsKey(fkey)
+                       && _fSafeLiterals.get(fkey).contains(pos);
+       }
+       
        private void constructFunctionCallSizeInfo() 
                throws HopsException 
        {
-               //determine function candidates by evaluating all function calls
+               //step 1: determine function candidates by evaluating all 
function calls
                for( String fkey : _fgraph.getReachableFunctions() ) {
                        List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
                
@@ -215,7 +250,8 @@ public class FunctionCallSizeInfo
                        }
                }
                
-               //determine safe nnz propagation per input
+               //step 2: determine safe nnz propagation per input
+               //(considered for valid functions only)
                for( String fkey : _fcand ) {
                        FunctionOp first = 
_fgraph.getFunctionCalls(fkey).get(0);
                        HashSet<Long> tmp = new HashSet<Long>();
@@ -227,13 +263,38 @@ public class FunctionCallSizeInfo
                        }
                        _fcandSafeNNZ.put(fkey, tmp);
                }
+               
+               //step 3: determine safe literal replacement per function input
+               //(considered for all functions)
+               for( String fkey : _fgraph.getReachableFunctions() ) {
+                       List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
+                       FunctionOp first = flist.get(0);
+                       //initialize w/ all literals of first call
+                       HashSet<Integer> tmp = new HashSet<Integer>();
+                       for( int j=0; j<first.getInput().size(); j++ )
+                               if( first.getInput().get(j) instanceof 
LiteralOp )
+                                       tmp.add(j);
+                       //check consistency across all function calls
+                       for( int i=1; i<flist.size(); i++ ) {
+                               FunctionOp other = flist.get(i);
+                               for( int j=0; j<first.getInput().size(); j++ ) 
+                                       if( tmp.contains(j) ) {
+                                               Hop h1 = 
first.getInput().get(j);
+                                               Hop h2 = 
other.getInput().get(j);
+                                               if( !(h2 instanceof LiteralOp 
&& HopRewriteUtils
+                                                       
.isEqualValue((LiteralOp)h1, (LiteralOp)h2)) )
+                                                       tmp.remove(j);
+                                       }
+                       }
+                       _fSafeLiterals.put(fkey, tmp);
+               }
        }
        
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
                
-               sb.append("Valid Functions for Propagation: \n");
+               sb.append("Valid functions for propagation: \n");
                for( String fkey : getValidFunctions() ) {
                        sb.append("--");
                        sb.append(fkey);
@@ -247,7 +308,7 @@ public class FunctionCallSizeInfo
                }
                
                if( !getInvalidFunctions().isEmpty() ) {
-                       sb.append("Invaid Functions for Propagation: \n");
+                       sb.append("Invaid functions for propagation: \n");
                        for( String fkey : getInvalidFunctions() ) {
                                sb.append("--");
                                sb.append(fkey);
@@ -258,7 +319,7 @@ public class FunctionCallSizeInfo
                }
                
                if( !getDimsPreservingFunctions().isEmpty() ) {
-                       sb.append("Dims-Preserving Functions: \n");
+                       sb.append("Dimensions-preserving functions: \n");
                        for( String fkey : getDimsPreservingFunctions() ) {
                                sb.append("--");
                                sb.append(fkey);
@@ -268,6 +329,21 @@ public class FunctionCallSizeInfo
                        }
                }
                
+               sb.append("Valid scalars for propagation: \n");
+               for( Entry<String, Set<Integer>> e : _fSafeLiterals.entrySet() 
) {
+                       sb.append("--");
+                       sb.append(e.getKey());
+                       sb.append(": ");
+                       for( Integer pos : e.getValue() ) {
+                               sb.append(pos);
+                               sb.append(":");
+                               sb.append(_fgraph.getFunctionCalls(e.getKey())
+                                       .get(0).getInput().get(pos).getName());
+                               sb.append(" ");
+                       }
+                       sb.append("\n");
+               }
+               
                return sb.toString();
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
index cfd9df7..ced407e 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
@@ -46,8 +46,9 @@ public abstract class IPAPass
         * 
         * @param prog dml program
         * @param fgraph function call graph
+        * @param fcallSizes function call size infos
         * @throws HopsException
         */
-       public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph ) 
+       public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) 
                throws HopsException;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
new file mode 100644
index 0000000..f436658
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.rewrite.ProgramRewriter;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.LanguageException;
+
+/**
+ * This rewrite applies static hop dag and statement block
+ * rewrites such as constant folding and branch removal
+ * in order to simplify statistic propagation.
+ * 
+ */
+public class IPAPassApplyStaticHopRewrites extends IPAPass
+{
+       @Override
+       public boolean isApplicable() {
+               return InterProceduralAnalysis.APPLY_STATIC_REWRITES;
+       }
+       
+       @Override
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+               throws HopsException
+       {
+               try {
+                       ProgramRewriter rewriter = new ProgramRewriter(true, 
false);
+                       rewriter.rewriteProgramHopDAGs(prog);
+               } 
+               catch (LanguageException ex) {
+                       throw new HopsException(ex);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
index ee072e4..82f4681 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
@@ -48,7 +48,7 @@ public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) 
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
                throws HopsException
        {
                try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
new file mode 100644
index 0000000..57647ff
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.FunctionOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.ForStatement;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.FunctionStatement;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.IfStatement;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatement;
+import org.apache.sysml.parser.WhileStatementBlock;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
+
+/**
+ * This rewrite propagates and replaces literals into functions
+ * in order to enable subsequent rewrites such as branch removal.
+ * 
+ */
+public class IPAPassPropagateReplaceLiterals extends IPAPass
+{
+       @Override
+       public boolean isApplicable() {
+               return InterProceduralAnalysis.PROPAGATE_SCALAR_LITERALS;
+       }
+       
+       @Override
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+               throws HopsException
+       {
+               for( String fkey : fgraph.getReachableFunctions() ) {
+                       FunctionOp first = fgraph.getFunctionCalls(fkey).get(0);
+                       
+                       //propagate and replace amenable literals into function
+                       if( fcallSizes.hasSafeLiterals(fkey) ) {
+                               FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(fkey);
+                               FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
+                               ArrayList<DataIdentifier> finputs = 
fstmt.getInputParams();
+                               
+                               //populate call vars with amenable literals
+                               LocalVariableMap callVars = new 
LocalVariableMap();
+                               for( int j=0; j<finputs.size(); j++ )
+                                       if( fcallSizes.isSafeLiteral(fkey, j) ) 
{
+                                               LiteralOp lit = (LiteralOp) 
first.getInput().get(j);
+                                               
callVars.put(finputs.get(j).getName(), ScalarObjectFactory
+                                                               
.createScalarObject(lit.getValueType(), lit));
+                                       }
+                               
+                               //propagate and replace literals
+                               for( StatementBlock sb : fstmt.getBody() )
+                                       rReplaceLiterals(sb, callVars);
+                       }
+               }
+       }
+       
+       private void rReplaceLiterals(StatementBlock sb, LocalVariableMap 
constants) 
+               throws HopsException 
+       {
+               //remove updated literals
+               for( String varname : sb.variablesUpdated().getVariableNames() )
+                       if( constants.keySet().contains(varname) )
+                               constants.remove(varname);
+               
+               //propagate and replace literals
+               if (sb instanceof WhileStatementBlock) {
+                       WhileStatementBlock wsb = (WhileStatementBlock) sb;
+                       WhileStatement ws = (WhileStatement)sb.getStatement(0);
+                       replaceLiterals(wsb.getPredicateHops(), constants);
+                       for (StatementBlock current : ws.getBody())
+                               rReplaceLiterals(current, constants);
+               } 
+               else if (sb instanceof IfStatementBlock) {
+                       IfStatementBlock isb = (IfStatementBlock) sb;
+                       IfStatement ifs = (IfStatement) sb.getStatement(0);
+                       replaceLiterals(isb.getPredicateHops(), constants);
+                       for (StatementBlock current : ifs.getIfBody())
+                               rReplaceLiterals(current, constants);
+                       for (StatementBlock current : ifs.getElseBody())
+                               rReplaceLiterals(current, constants);
+               } 
+               else if (sb instanceof ForStatementBlock) {
+                       ForStatementBlock fsb = (ForStatementBlock) sb;
+                       ForStatement fs = (ForStatement)sb.getStatement(0);
+                       replaceLiterals(fsb.getFromHops(), constants);
+                       replaceLiterals(fsb.getToHops(), constants);
+                       replaceLiterals(fsb.getIncrementHops(), constants);
+                       for (StatementBlock current : fs.getBody())
+                               rReplaceLiterals(current, constants);
+               }
+               else {
+                       replaceLiterals(sb.get_hops(), constants);
+               }
+       }
+       
+       private void replaceLiterals(ArrayList<Hop> roots, LocalVariableMap 
constants) 
+               throws HopsException 
+       {
+               if( roots == null )
+                       return;
+               
+               try {
+                       Hop.resetVisitStatus(roots);
+                       for( Hop root : roots )
+                               Recompiler.rReplaceLiterals(root, constants, 
true);
+                       Hop.resetVisitStatus(roots);
+               }
+               catch(Exception ex) {
+                       throw new HopsException(ex);
+               }
+       }
+       
+       private void replaceLiterals(Hop root, LocalVariableMap constants) 
+               throws HopsException 
+       {
+               if( root == null )
+                       return;
+               
+               try {
+                       root.resetVisitStatus();
+                       Recompiler.rReplaceLiterals(root, constants, true);     
+                       root.resetVisitStatus();
+               }
+               catch(Exception ex) {
+                       throw new HopsException(ex);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index c71ed45..1a433a3 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -57,7 +57,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) 
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
                throws HopsException
        {
                //approach: scan over top-level program (guaranteed to be 
unconditional),

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index 20c47da..664ec2a 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -56,7 +56,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends 
IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) 
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
                throws HopsException
        {
                //remove unnecessary checkpoint before update 

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
index 3424a52..9d41ca6 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
@@ -43,7 +43,7 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) 
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
                throws HopsException
        {
                try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 1d997ed..7d371ac 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -94,6 +94,8 @@ public class InterProceduralAnalysis
        protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; 
//remove unnecessary checkpoints (unconditionally overwritten intermediates) 
        protected static final boolean REMOVE_CONSTANT_BINARY_OPS     = true; 
//remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) 
        protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; 
//propagate scalar variables into functions that are called once
+       protected static final boolean PROPAGATE_SCALAR_LITERALS      = true; 
//propagate and replace scalar literals into functions
+       protected static final boolean APPLY_STATIC_REWRITES          = true; 
//apply static hop dag and statement block rewrites
        
        static {
                // for internal debugging only
@@ -132,6 +134,8 @@ public class InterProceduralAnalysis
                _passes.add(new IPAPassFlagFunctionsRecompileOnce());
                _passes.add(new IPAPassRemoveUnnecessaryCheckpoints());
                _passes.add(new IPAPassRemoveConstantBinaryOps());
+               _passes.add(new IPAPassPropagateReplaceLiterals());
+               _passes.add(new IPAPassApplyStaticHopRewrites());
        }
        
        public InterProceduralAnalysis(StatementBlock sb) {
@@ -145,39 +149,64 @@ public class InterProceduralAnalysis
        }
        
        /**
-        * Public interface to perform IPA over a given DML program.
+        * Main interface to perform IPA over a given DML program.
         * 
-        * @param dmlp the dml program
-        * @throws HopsException if HopsException occurs
+        * @throws HopsException in case of compilation errors
         */
-       public void analyzeProgram() 
-               throws HopsException
+       public void analyzeProgram() throws HopsException {
+               analyzeProgram(1); //single run
+       }
+       
+       /**
+        * Main interface to perform IPA over a given DML program.
+        * 
+        * @param repetitions number of IPA rounds 
+        * @throws HopsException in case of compilation errors
+        */
+       public void analyzeProgram(int repetitions) 
+               throws HopsException    
        {
-               //step 1: intra- and inter-procedural 
-               if( INTRA_PROCEDURAL_ANALYSIS ) {
+               //sanity check for valid number of repetitions
+               if( repetitions <= 0 )
+                       throw new HopsException("Invalid number of IPA 
repetitions: " + repetitions);
+               
+               //perform number of requested IPA iterations
+               for( int i=0; i<repetitions; i++ ) {
+                       if( LOG.isDebugEnabled() )
+                               LOG.debug("IPA: start IPA iteration " + (i+1) + 
"/" + repetitions +".");
+                       
                        //get function call size infos to obtain candidates for 
statistics propagation
                        FunctionCallSizeInfo fcallSizes = new 
FunctionCallSizeInfo(_fgraph);
                        if( LOG.isDebugEnabled() )
                                LOG.debug("IPA: Initial FunctionCallSummary: 
\n" + fcallSizes);
                        
-                       //get unary dimension-preserving non-candidate functions
-                       for( String tmp : fcallSizes.getInvalidFunctions() )
-                               if( 
isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) )
-                                       
fcallSizes.addDimsPreservingFunction(tmp);
-                       if( LOG.isDebugEnabled() )
-                               LOG.debug("IPA: Extended FunctionCallSummary: 
\n" + fcallSizes);
+                       //step 1: intra- and inter-procedural 
+                       if( INTRA_PROCEDURAL_ANALYSIS ) {
+                               //get unary dimension-preserving non-candidate 
functions
+                               for( String tmp : 
fcallSizes.getInvalidFunctions() )
+                                       if( 
isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) )
+                                               
fcallSizes.addDimsPreservingFunction(tmp);
+                               if( LOG.isDebugEnabled() )
+                                       LOG.debug("IPA: Extended 
FunctionCallSummary: \n" + fcallSizes);
+                               
+                               //propagate statistics and scalars into 
functions and across DAGs
+                               //(callVars used to chain outputs/inputs of 
multiple functions calls)
+                               LocalVariableMap callVars = new 
LocalVariableMap();
+                               for ( StatementBlock sb : 
_prog.getStatementBlocks() ) //propagate stats into candidates
+                                       propagateStatisticsAcrossBlock( sb, 
callVars, fcallSizes, new HashSet<String>() );
+                       }
                        
-                       //propagate statistics and scalars into functions and 
across DAGs
-                       //(callVars used to chain outputs/inputs of multiple 
functions calls)
-                       LocalVariableMap callVars = new LocalVariableMap();
-                       for ( StatementBlock sb : _prog.getStatementBlocks() ) 
//propagate stats into candidates
-                               propagateStatisticsAcrossBlock( sb, callVars, 
fcallSizes, new HashSet<String>() );
+                       //step 2: apply additional IPA passes
+                       for( IPAPass pass : _passes )
+                               if( pass.isApplicable() )
+                                       pass.rewriteProgram(_prog, _fgraph, 
fcallSizes);
                }
                
-               //step 2: apply additional IPA passes
-               for( IPAPass pass : _passes )
-                       if( pass.isApplicable() )
-                               pass.rewriteProgram(_prog, _fgraph);
+               //cleanup pass: remove unused functions
+               FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
+               IPAPass rmFuns = new IPAPassRemoveUnusedFunctions();
+               if( rmFuns.isApplicable() )
+                       rmFuns.rewriteProgram(_prog, graph2, null);
        }
        
        public Set<String> analyzeSubProgram() 
@@ -240,19 +269,6 @@ public class InterProceduralAnalysis
        // INTRA-PROCEDURE ANALYSIS
        //////  
 
-       /**
-        * Perform intra-procedural analysis (IPA) by propagating statistics
-        * across statement blocks.
-        *
-        * @param sb  DML statement blocks.
-        * @param fcand  Function candidates.
-        * @param callVars  Map of variables eligible for propagation.
-        * @param fcandSafeNNZ  Function candidate safe non-zeros.
-        * @param unaryFcands  Unary function candidates.
-        * @param fnStack  Function stack to determine current scope.
-        * @throws HopsException  If a HopsException occurs.
-        * @throws ParseException  If a ParseException occurs.
-        */
        private void propagateStatisticsAcrossBlock( StatementBlock sb, 
LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack 
)
                throws HopsException
        {
@@ -421,16 +437,12 @@ public class InterProceduralAnalysis
         *
         * @param prog  The DML program.
         * @param roots List of HOP DAG root notes for propagation.
-        * @param fcand  Function candidates.
-        * @param callVars  Calling program's map of variables eligible for
-        *                     propagation.
-        * @param fcandSafeNNZ  Function candidate safe non-zeros.
-        * @param unaryFcands  Unary function candidates.
+        * @param callVars  Calling program's map of variables eligible for 
propagation.
+        * @param fcallSizes function call summary
         * @param fnStack  Function stack to determine current scope.
         * @throws HopsException  If a HopsException occurs.
-        * @throws ParseException  If a ParseException occurs.
         */
-       private void propagateStatisticsIntoFunctions(DMLProgram prog, 
ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo 
fcallSizes, Set<String> fnStack )
+       private void propagateStatisticsIntoFunctions(DMLProgram prog, 
ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo 
fcallSizes, Set<String> fnStack)
                        throws HopsException
        {
                for( Hop root : roots )
@@ -443,14 +455,10 @@ public class InterProceduralAnalysis
         *
         * @param prog  The DML program.
         * @param hop HOP to propagate statistics into.
-        * @param fcand  Function candidates.
-        * @param callVars  Calling program's map of variables eligible for
-        *                     propagation.
-        * @param fcandSafeNNZ  Function candidate safe non-zeros.
-        * @param unaryFcands  Unary function candidates.
+        * @param callVars  Calling program's map of variables eligible for 
propagation.
+        * @param fcallSizes function call summary
         * @param fnStack  Function stack to determine current scope.
         * @throws HopsException  If a HopsException occurs.
-        * @throws ParseException  If a ParseException occurs.
         */
        private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, 
LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack 
) 
                throws HopsException

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 47446f6..42ab12e 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -266,18 +266,8 @@ public class DMLTranslator
                //propagate size information from main into functions (but 
conservatively)
                if( OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS ) {
                        InterProceduralAnalysis ipa = new 
InterProceduralAnalysis(dmlp);
-                       ipa.analyzeProgram();
+                       
ipa.analyzeProgram(OptimizerUtils.ALLOW_IPA_SECOND_CHANCE ? 2 : 1);
                        resetHopsDAGVisitStatus(dmlp);
-                       if (OptimizerUtils.ALLOW_IPA_SECOND_CHANCE) {
-                               // SECOND CHANCE:
-                               // Rerun static rewrites + IPA to allow for 
further improvements, such as making use
-                               // of constant folding (static rewrite) after 
scalar -> literal replacement (IPA),
-                               // and then further scalar -> literal 
replacement (IPA).
-                               rewriter.rewriteProgramHopDAGs(dmlp);
-                               resetHopsDAGVisitStatus(dmlp);
-                               ipa.analyzeProgram();
-                               resetHopsDAGVisitStatus(dmlp);
-                       }
                }
 
                //apply hop rewrites (dynamic rewrites, after IPA)

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java 
b/src/main/java/org/apache/sysml/utils/Explain.java
index af7102b..5cf0548 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -229,11 +229,12 @@ public class Explain
                                for (String fname : 
prog.getFunctionStatementBlocks(namespace).keySet()) {
                                        FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(namespace, fname);
                                        FunctionStatement fstmt = 
(FunctionStatement) fsb.getStatement(0);
+                                       String fkey = 
DMLProgram.constructFunctionKey(namespace, fname);
                                        
                                        if (fstmt instanceof 
ExternalFunctionStatement)
-                                               sb.append("----EXTERNAL 
FUNCTION " + namespace + "::" + fname + "\n");
+                                               sb.append("----EXTERNAL 
FUNCTION " + fkey + "\n");
                                        else {
-                                               sb.append("----FUNCTION " + 
namespace + "::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n");
+                                               sb.append("----FUNCTION " + 
fkey + " [recompile="+fsb.isRecompileOnce()+"]\n");
                                                for (StatementBlock current : 
fstmt.getBody())
                                                        
sb.append(explainStatementBlock(current, 3));
                                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
index a73fe5b..e4ff6c3 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
@@ -110,20 +110,9 @@ public class 
IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
                        runTest(true, false, null, -1);
 
                        // Check for correct number of compiled & executed 
Spark jobs
-                       if (IPA_SECOND_CHANCE) {
-                               // No distributed instructions 
compiled/executed with second chance enabled
-                               checkNumCompiledSparkInst(0);
-                               checkNumExecutedSparkInst(0);
-                       } else {
-                               // without second chance enabled, distributed 
jobs will be compiled/executed
-                               if (testname == TEST_NAME1) {
-                                       checkNumCompiledSparkInst(2);
-                                       checkNumExecutedSparkInst(1);
-                               } else {  //if (testname == TEST_NAME2) {
-                                       checkNumCompiledSparkInst(1);
-                                       checkNumExecutedSparkInst(0);
-                               }
-                       }
+                       // (MB: originally, this required a second chance, but 
not anymore)
+                       checkNumCompiledSparkInst(0);
+                       checkNumExecutedSparkInst(0);
                }
                finally {
                        // Reset

Reply via email to