Repository: systemml
Updated Branches:
  refs/heads/master 1b3dff06b -> 85e3a9631


New rewrite rule for chains of element-wise multiplications.

Placed the rewrite rule after Common Subexpression Elimination, on which it depends.
Added a helper method, replaceHop, to HopRewriteUtils.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d578838
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d578838
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d578838

Branch: refs/heads/master
Commit: 7d578838cc291a1adb6229bae01f7c9428b6f858
Parents: c434208
Author: Dylan Hutchison <dhutc...@cs.washington.edu>
Authored: Thu Jun 8 18:17:36 2017 -0700
Committer: Dylan Hutchison <dhutc...@cs.washington.edu>
Committed: Sun Jun 18 17:43:13 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  17 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   1 +
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 186 +++++++++++++++++++
 .../org/apache/sysml/parser/Expression.java     |   1 +
 4 files changed, 204 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index cf6081b..4d23cb9 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -241,7 +241,22 @@ public class HopRewriteUtils
                parent.getInput().add( pos, child );
                child.getParent().add( parent );
        }
-       
+
+       /**
+        * Substitute {@code replacement} for {@code old} in the Hop DAG.
+        * Each parent of {@code old} is rewired so that it references {@code replacement}
+        * instead; when {@code old} has no parents there is nothing to rewire.
+        * @param old the Hop being replaced
+        * @param replacement the Hop taking its place
+        * @return {@code replacement}, so that callers can use it as the new (root) Hop
+        */
+       public static Hop replaceHop(final Hop old, final Hop replacement) {
+               final boolean hasParents = !old.getParent().isEmpty();
+               if (hasParents) {
+                       rewireAllParentChildReferences(old, replacement);
+               }
+               return replacement;
+       }
+
+
        public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) 
{
                ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent());
                for( Hop lparent : parents )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 0e65f3f..8573dd7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,6 +96,7 @@ public class ProgramRewriter
                        _dagRuleSet.add(     new 
RewriteRemoveUnnecessaryCasts()             );         
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
                                _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     );
+                       _dagRuleSet.add( new RewriteEMult()                     
             ); //dependency: cse
                        if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
                                _dagRuleSet.add( new RewriteConstantFolding()   
                 ); //dependency: cse
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
new file mode 100644
index 0000000..47c32a9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteEMult extends HopRewriteRule {
+       /**
+        * Orders hops by data type (in DataType declaration order) and breaks ties by hop ID.
+        * The tie-break is essential: a TreeMap treats comparator-equal keys as the same key,
+        * so comparing by data type alone would merge distinct leaves of the same data type
+        * (e.g. the matrices A and B) and silently drop factors from the product.
+        */
+       private static final Comparator<Hop> compareByDataType = (a, b) -> {
+               final int byType = a.getDataType().compareTo(b.getDataType());
+               return byType != 0 ? byType : Long.compare(a.getHopID(), b.getHopID());
+       };
+
+       @Override
+       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+               if( roots == null )
+                       return null;
+
+               for( int i=0; i<roots.size(); i++ ) {
+                       Hop h = roots.get(i);
+                       roots.set(i, rule_RewriteEMult(h));
+               }
+               return roots;
+       }
+
+       @Override
+       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+               if( root == null )
+                       return null;
+               return rule_RewriteEMult(root);
+       }
+
+       private static boolean isBinaryMult(final Hop hop) {
+               return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+       }
+
+       /**
+        * Apply the rewrite to the sub-DAG rooted at {@code root}.
+        * @param root the hop to rewrite
+        * @return the hop that takes root's place: a newly built replacement if root heads an
+        *         optimizable e-wise multiply chain, otherwise root itself
+        */
+       private static Hop rule_RewriteEMult(final Hop root) {
+               if (root.isVisited())
+                       return root;
+               root.setVisited();
+
+               // 1. Find immediate subtree of EMults.
+               if (isBinaryMult(root)) {
+                       final ArrayList<Hop> rootInputs = root.getInput();
+                       final Hop left = rootInputs.get(0), right = rootInputs.get(1);
+                       final Set<BinaryOp> emults = new HashSet<>();
+                       final Multiset<Hop> leaves = HashMultiset.create();
+                       findEMultsAndLeaves((BinaryOp)root, emults, leaves);
+                       // 2. Ensure it is profitable to do a rewrite.
+                       if (isOptimizable(leaves)) {
+                               // 3. Check for foreign parents.
+                               // A foreign parent is a parent of some EMult that is not in the set.
+                               // Foreign parents destroy correctness of this rewrite.
+                               final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+                                               (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+                               if (okay) {
+                                       // 4. Construct replacement EMults for the leaves
+                                       final Hop replacement = constructReplacement(emults, leaves);
+
+                                       // 5. Replace root with replacement
+                                       final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+                                       // 6. The interior EMults were discarded, but the leaves' own inputs
+                                       // may still contain further EMult chains. Leaves are never EMults,
+                                       // so each call returns its leaf unchanged and only rewrites beneath it.
+                                       for (final Hop leaf : leaves.elementSet())
+                                               rule_RewriteEMult(leaf);
+                                       return newRoot;
+                               }
+                       }
+               }
+
+               // This rewrite is not applicable to the current root.
+               // Try the root's children.
+               recurseInputs(root);
+               return root;
+       }
+
+       /** Apply the rewrite to every input of parent, storing any replacement back into its input slot. */
+       private static void recurseInputs(final Hop parent) {
+               final ArrayList<Hop> inputs = parent.getInput();
+               for (int i = 0; i < inputs.size(); i++) {
+                       final Hop input = inputs.get(i);
+                       final Hop newInput = rule_RewriteEMult(input);
+                       inputs.set(i, newInput);
+               }
+       }
+
+       /**
+        * Build a left-deep EMult tree over the distinct leaves, raising each leaf to the power of
+        * its multiplicity.
+        * @param emults the EMults being replaced; they are unlinked from the leaves' parent lists
+        * @param leaves the leaves of the chain, counted by multiplicity (must be non-empty)
+        * @return the root of the replacement tree
+        */
+       private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+               // Sort by data type (ties broken by hop ID so distinct leaves are never merged)
+               final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+               for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+                       final Hop h = entry.getElement();
+                       // Unlink only the EMult parents, which we are throwing away.
+                       // Any other parent still consumes this leaf and must keep its link.
+                       h.getParent().removeIf(emults::contains);
+                       sorted.put(h, entry.getCount());
+               }
+               // sorted contains all leaves, sorted by data type, stripped of their EMult parents
+
+               // Construct left-deep EMult tree
+               final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+               Hop first = constructPower(iterator.next());
+               while (iterator.hasNext()) {
+                       final Hop second = constructPower(iterator.next());
+                       first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+               }
+               return first;
+       }
+
+       /** Return the leaf itself for a count of 1, otherwise a new element-wise POW hop {@code leaf ^ count}. */
+       private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+               final Hop hop = entry.getKey();
+               final int cnt = entry.getValue();
+               assert(cnt >= 1);
+               if (cnt == 1)
+                       return hop;
+               return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+       }
+
+       /**
+        * Check that child, and recursively every EMult beneath it, has no parent outside emults.
+        * A child with a single parent is accepted, because that parent is the EMult it was reached from.
+        * @return true if no foreign parent was found
+        */
+       private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+               final ArrayList<Hop> parents = child.getParent();
+               if (parents.size() > 1)
+                       for (final Hop parent : parents)
+                               //noinspection SuspiciousMethodCalls
+                               if (!emults.contains(parent))
+                                       return false;
+               // child does not have foreign parents
+
+               final ArrayList<Hop> inputs = child.getInput();
+               final Hop left = inputs.get(0), right = inputs.get(1);
+               return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+                               (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+       }
+
+       /**
+        * Collect every BinaryOp MULT in the immediate subtree starting at root into emults, and
+        * count each non-MULT operand (leaf) into leaves. A shared sub-chain is visited once per
+        * reference, so its leaves are counted with their true multiplicity in the product.
+        */
+       private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+               // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+               emults.add(root);
+
+               final ArrayList<Hop> inputs = root.getInput();
+               final Hop left = inputs.get(0), right = inputs.get(1);
+
+               if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+               else leaves.add(left);
+
+               if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+               else leaves.add(right);
+       }
+
+       /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
+       private static boolean isOptimizable(final Multiset<Hop> set) {
+               for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
+                       if (hopEntry.getCount() > 1)
+                               return true;
+               }
+               return false;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java
index 9ee3fba..b944e29 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,6 +162,7 @@ public abstract class Expression
         * Data types (matrix, scalar, frame, object, unknown).
         */
        public enum DataType {
+               // Careful: the order of these enums is significant! See RewriteEMult.compareByDataType
                MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
                
                public boolean isMatrix() {

Reply via email to