http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
new file mode 100644
index 0000000..2fa0de7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -0,0 +1,1100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
+import org.apache.sysml.hops.codegen.template.TemplateRow;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
+import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * This cost-based plan selection algorithm chooses fused operators
+ * based on the DAG structure and resulting overall costs. This includes
+ * holistic decisions on 
+ * <ul>
+ *   <li>Materialization points per consumer</li>
+ *   <li>Sparsity exploitation and operator ordering</li>
+ *   <li>Decisions on overlapping template types</li>
+ *   <li>Decisions on multi-aggregates with shared reads</li>
+ *   <li>Constraints (e.g., memory budgets and block sizes)</li>  
+ * </ul>
+ * 
+ */
+public class PlanSelectionFuseCostBasedV2 extends PlanSelection
+{      
+       private static final Log LOG = 
LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());
+       
+       //common bandwidth characteristics, with a conservative write bandwidth 
in order 
+       //to cover result allocation, write into main memory, and potential 
evictions
+       private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024;  
//2GB/s
+       private static final double READ_BANDWIDTH = 32d*1024*1024*1024;  
//32GB/s
+       private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 
//2GFLOPs/core
+               * InfrastructureAnalyzer.getLocalParallelism();
+       
+       //sparsity estimate for unknown sparsity to prefer sparse-safe fusion 
plans
+       private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
+       
+       //optimizer configuration
+       private static final boolean USE_COST_PRUNING = true;
+       private static final boolean USE_STRUCTURAL_PRUNING = true;
+       
+       private static final IDSequence COST_ID = new IDSequence();
+       private static final TemplateRow ROW_TPL = new TemplateRow();
+       private static final BasicPlanComparator BASE_COMPARE = new 
BasicPlanComparator();
+       private final TypedPlanComparator _typedCompare = new 
TypedPlanComparator();
+       
+       @Override
+       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
+       {
+               //step 1: analyze connected partitions (nodes, roots, mat 
points)
+               Collection<PlanPartition> parts = 
PlanAnalyzer.analyzePlanPartitions(memo, roots, true);
+               
+               //step 2: optimize individual plan partitions
+               for( PlanPartition part : parts ) {
+                       //create composite templates (within the partition)
+                       createAndAddMultiAggPlans(memo, part.getPartition(), 
part.getRoots());
+                       
+                       //plan enumeration and plan selection
+                       selectPlans(memo, part);
+               }
+               
+               //step 3: add composite templates (across partitions)
+               createAndAddMultiAggPlans(memo, roots);
+               
+               //take all distinct best plans
+               for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
+                       memo.setDistinct(e.getKey(), e.getValue());
+       }
+       
+       private void selectPlans(CPlanMemoTable memo, PlanPartition part) 
+       {
+               //prune row aggregates with pure cellwise operations
+               for( Long hopID : part.getRoots() ) {
+                       MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
+                       if( me.type == TemplateType.ROW && memo.contains(hopID, 
TemplateType.CELL)
+                               && isRowTemplateWithoutAgg(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+                               List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW); 
+                               memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
+                               if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed row memo table 
entries w/o aggregation: "
+                                               + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                               }
+                       }
+               }
+               
+               //prune suboptimal outer product plans that are dominated by 
outer product plans w/ same number of 
+               //references but better fusion properties (e.g., for the 
patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), 
+               //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this 
would unnecessarily destroy a fusion pattern.
+               for( Long hopID : part.getPartition() ) {
+                       if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) 
{
+                               List<MemoTableEntry> entries = memo.get(hopID, 
TemplateType.OUTER);
+                               MemoTableEntry me1 = entries.get(0);
+                               MemoTableEntry me2 = entries.get(1);
+                               MemoTableEntry rmEntry = 
TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
+                               if( rmEntry != null ) {
+                                       
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
+                                       
memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed dominated 
outer product memo table entry: " + rmEntry);
+                               }
+                       }
+               }
+               
+               //if no materialization points, use basic fuse-all w/ partition 
awareness
+               if( part.getMatPointsExt() == null || 
part.getMatPointsExt().length==0 ) {
+                       for( Long hopID : part.getRoots() )
+                               rSelectPlansFuseAll(memo, 
+                                       memo.getHopRefs().get(hopID), null, 
part.getPartition());
+               }
+               else {
+                       //obtain hop compute costs per cell once
+                       HashMap<Long, Double> computeCosts = new HashMap<Long, 
Double>();
+                       for( Long hopID : part.getRoots() )
+                               rGetComputeCosts(memo.getHopRefs().get(hopID), 
part.getPartition(), computeCosts);
+                       
+                       //prepare pruning helpers and prune memo table w/ 
determined mat points
+                       StaticCosts costs = new StaticCosts(computeCosts, 
getComputeCost(computeCosts, memo), 
+                               getReadCost(part, memo), 
getWriteCost(part.getRoots(), memo));
+                       ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new 
ReachabilityGraph(part, memo) : null;
+                       if( USE_STRUCTURAL_PRUNING ) {
+                               
part.setMatPointsExt(rgraph.getSortedSearchSpace());
+                               for( Long hopID : part.getPartition() )
+                                       memo.pruneRedundant(hopID, true, 
part.getMatPointsExt());
+                       }
+                       
+                       //enumerate and cost plans, returns optional plan
+                       boolean[] bestPlan = enumPlans(memo, part, costs, 
rgraph, 
+                                       part.getMatPointsExt(), 0, 
Double.MAX_VALUE);
+                       
+                       //prune memo table wrt best plan and select plans
+                       HashSet<Long> visited = new HashSet<Long>();
+                       for( Long hopID : part.getRoots() )
+                               rPruneSuboptimalPlans(memo, 
memo.getHopRefs().get(hopID), 
+                                       visited, part, part.getMatPointsExt(), 
bestPlan);
+                       HashSet<Long> visited2 = new HashSet<Long>();
+                       for( Long hopID : part.getRoots() )
+                               rPruneInvalidPlans(memo, 
memo.getHopRefs().get(hopID), 
+                                       visited2, part, bestPlan);
+                       
+                       for( Long hopID : part.getRoots() )
+                               rSelectPlansFuseAll(memo, 
+                                       memo.getHopRefs().get(hopID), null, 
part.getPartition());
+               }
+       }
+       
+       /**
+        * Core plan enumeration algorithm, invoked recursively for 
conditionally independent
+        * subproblems. This algorithm fully explores the exponential search 
space of 2^m,
+        * where m is the number of interesting materialization points. We 
iterate over
+        * a linearized search space without every instantiating the search 
tree. Furthermore,
+        * in order to reduce the enumeration overhead, we apply two 
high-impact pruning
+        * techniques (1) pruning by evolving lower/upper cost bounds, and (2) 
pruning by
+        * conditional structural properties (so-called cutsets of interesting 
points). 
+        * 
+        * @param memo memoization table of partial fusion plans
+        * @param part connected component (partition) of partial fusion plans 
with all necessary meta data
+        * @param costs summary of static costs (e.g., partition reads, writes, 
and compute costs per operator)
+        * @param rgraph reachability graph of interesting materialization 
points
+        * @param matPoints sorted materialization points (defined the search 
space)
+        * @param off offset for recursive invocation, indicating the fixed 
plan part
+        * @param bestC currently known best plan costs (used of upper bound)
+        * @return optimal assignment of materialization points
+        */
+       private static boolean[] enumPlans(CPlanMemoTable memo, PlanPartition 
part, StaticCosts costs, 
+               ReachabilityGraph rgraph, InterestingPoint[] matPoints, int 
off, double bestC) 
+       {
+               //scan linearized search space, w/ skips for branch and bound 
pruning
+               //and structural pruning (where we solve conditionally 
independent problems)
+               //bestC is monotonically non-increasing and serves as the upper 
bound
+               long len = (long)Math.pow(2, matPoints.length-off);
+               boolean[] bestPlan = null;
+               int numEvalPlans = 0;
+               
+               for( long i=0; i<len; i++ ) {
+                       //construct assignment
+                       boolean[] plan = createAssignment(matPoints.length-off, 
off, i);
+                       long pskip = 0; //skip after costing
+                       
+                       //skip plans with structural pruning
+                       if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && 
rgraph.isCutSet(plan) ) {
+                               //compute skip (which also acts as boundary for 
subproblems)
+                               pskip = rgraph.getNumSkipPlans(plan);
+                               
+                               //start increment rgraph get subproblems
+                               SubProblem[] prob = rgraph.getSubproblems(plan);
+                               
+                               //solve subproblems independently and combine 
into best plan
+                               for( int j=0; j<prob.length; j++ ) {
+                                       boolean[] bestTmp = enumPlans(memo, 
part, 
+                                               costs, null, prob[j].freeMat, 
prob[j].offset, bestC);
+                                       LibSpoofPrimitives.vectWrite(bestTmp, 
plan, prob[j].freePos);
+                               }
+                               
+                               //note: the overall plan costs are evaluated in 
full, which reused
+                               //the default code path; hence we postpone the 
skip after costing
+                       }
+                       //skip plans with branch and bound pruning (cost)
+                       else if( USE_COST_PRUNING ) {
+                               double lbC = Math.max(costs._read, 
costs._compute) + costs._write 
+                                       + getMaterializationCost(part, 
matPoints, memo, plan);
+                               if( lbC >= bestC ) {
+                                       long skip = getNumSkipPlans(plan);
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Enum: Skip "+skip+" 
plans (by cost).");
+                                       i += skip - 1;
+                                       continue;
+                               }
+                       }
+                       
+                       //cost assignment on hops
+                       double C = getPlanCost(memo, part, matPoints, plan, 
costs._computeCosts);
+                       numEvalPlans ++;
+                       if( LOG.isTraceEnabled() )
+                               LOG.trace("Enum: "+Arrays.toString(plan)+" -> 
"+C);
+                       
+                       //cost comparisons
+                       if( bestPlan == null || C < bestC ) {
+                               bestC = C;
+                               bestPlan = plan;
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Enum: Found new best plan.");
+                       }
+                       
+                       //post skipping
+                       i += pskip;
+                       if( pskip !=0 && LOG.isTraceEnabled() )
+                               LOG.trace("Enum: Skip "+pskip+" plans (by 
structure).");
+               }
+               
+               if( DMLScript.STATISTICS )
+                       Statistics.incrementCodegenFPlanCompile(numEvalPlans);
+               if( LOG.isTraceEnabled() )
+                       LOG.trace("Enum: Optimal plan: 
"+Arrays.toString(bestPlan));
+               
+               //copy best plan w/o fixed offset plan
+               return Arrays.copyOfRange(bestPlan, off, bestPlan.length);
+       }
+       
+       private static boolean[] createAssignment(int len, int off, long pos) {
+               boolean[] ret = new boolean[off+len];
+               Arrays.fill(ret, 0, off, true);
+               long tmp = pos;
+               for( int i=0; i<len; i++ ) {
+                       ret[off+i] = (tmp >= Math.pow(2, len-i-1));
+                       tmp %= Math.pow(2, len-i-1);
+               }
+               return ret;     
+       }
+       
+       private static long getNumSkipPlans(boolean[] plan) {
+               int pos = ArrayUtils.lastIndexOf(plan, true);
+               return (long) Math.pow(2, plan.length-pos-1);
+       }
+       
+       private static double getMaterializationCost(PlanPartition part, 
InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) {
+               double costs = 0;
+               //currently active materialization points
+               HashSet<Long> matTargets = new HashSet<>();
+               for( int i=0; i<plan.length; i++ ) {
+                       long hopID = M[i].getToHopID();
+                       if( plan[i] && !matTargets.contains(hopID) ) {
+                               matTargets.add(hopID);
+                               Hop hop = memo.getHopRefs().get(hopID);
+                               long size = getSize(hop);
+                               costs += size * 8 / WRITE_BANDWIDTH + 
+                                               size * 8 / READ_BANDWIDTH;
+                       }
+               }
+               //points with non-partition consumers
+               for( Long hopID : part.getExtConsumed() )
+                       if( !matTargets.contains(hopID) ) {
+                               matTargets.add(hopID);
+                               Hop hop = memo.getHopRefs().get(hopID);
+                               costs += getSize(hop) * 8 / WRITE_BANDWIDTH;
+                       }
+               
+               return costs;
+       }
+       
+       private static double getReadCost(PlanPartition part, CPlanMemoTable 
memo) {
+               double costs = 0;
+               //get partition input reads (at least read once)
+               for( Long hopID : part.getInputs() ) {
+                       Hop hop = memo.getHopRefs().get(hopID);
+                       costs += getSize(hop) * 8 / READ_BANDWIDTH;
+               }
+               return costs;
+       }
+       
+       private static double getWriteCost(Collection<Long> R, CPlanMemoTable 
memo) {
+               double costs = 0;
+               for( Long hopID : R ) {
+                       Hop hop = memo.getHopRefs().get(hopID);
+                       costs += getSize(hop) * 8 / WRITE_BANDWIDTH;
+               }
+               return costs;
+       }
+       
+       private static double getComputeCost(HashMap<Long, Double> 
computeCosts, CPlanMemoTable memo) {
+               double costs = 0;
+               for( Entry<Long,Double> e : computeCosts.entrySet() ) {
+                       Hop mainInput = memo.getHopRefs()
+                               .get(e.getKey()).getInput().get(0);
+                       costs += getSize(mainInput) * e.getValue() / 
COMPUTE_BANDWIDTH;
+               }
+               return costs;
+       }
+       
+       private static long getSize(Hop hop) {
+               return Math.max(hop.getDim1(),1) 
+                       * Math.max(hop.getDim2(),1);
+       }
+       
+       //within-partition multi-agg templates
+       private static void createAndAddMultiAggPlans(CPlanMemoTable memo, 
HashSet<Long> partition, HashSet<Long> R)
+       {
+               //create index of plans that reference full aggregates to avoid 
circular dependencies
+               HashSet<Long> refHops = new HashSet<Long>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo.getPlans().entrySet() )
+                       if( !e.getValue().isEmpty() ) {
+                               Hop hop = memo.getHopRefs().get(e.getKey());
+                               for( Hop c : hop.getInput() )
+                                       refHops.add(c.getHopID());
+                       }
+               
+               //find all full aggregations (the fact that they are in the 
same partition guarantees 
+               //that they also have common subexpressions, also full 
aggregations are by def root nodes)
+               ArrayList<Long> fullAggs = new ArrayList<Long>();
+               for( Long hopID : R ) {
+                       Hop root = memo.getHopRefs().get(hopID);
+                       if( !refHops.contains(hopID) && 
isMultiAggregateRoot(root) )
+                               fullAggs.add(hopID);
+               }
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Found within-partition ua(RC) aggregations: 
" +
+                               Arrays.toString(fullAggs.toArray(new Long[0])));
+               }
+               
+               //construct and add multiagg template plans (w/ max 3 
aggregations)
+               for( int i=0; i<fullAggs.size(); i+=3 ) {
+                       int ito = Math.min(i+3, fullAggs.size());
+                       if( ito-i >= 2 ) {
+                               MemoTableEntry me = new 
MemoTableEntry(TemplateType.MAGG,
+                                       fullAggs.get(i), fullAggs.get(i+1), 
((ito-i)==3)?fullAggs.get(i+2):-1, ito-i);
+                               if( isValidMultiAggregate(memo, me) ) {
+                                       for( int j=i; j<ito; j++ ) {
+                                               
memo.add(memo.getHopRefs().get(fullAggs.get(j)), me);
+                                               if( LOG.isTraceEnabled() )
+                                                       LOG.trace("Added 
multiagg plan: "+fullAggs.get(j)+" "+me);
+                                       }
+                               }
+                               else if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed invalid multiagg 
plan: "+me);
+                               }
+                       }
+               }
+       }
+       
+       //across-partition multi-agg templates with shared reads
+       private void createAndAddMultiAggPlans(CPlanMemoTable memo, 
ArrayList<Hop> roots)
+       {
+               //collect full aggregations as initial set of candidates
+               HashSet<Long> fullAggs = new HashSet<Long>();
+               Hop.resetVisitStatus(roots);
+               for( Hop hop : roots )
+                       rCollectFullAggregates(hop, fullAggs);
+               Hop.resetVisitStatus(roots);
+
+               //remove operators with assigned multi-agg plans
+               fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));
+       
+               //check applicability for further analysis
+               if( fullAggs.size() <= 1 )
+                       return;
+       
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Found across-partition ua(RC) aggregations: 
" +
+                               Arrays.toString(fullAggs.toArray(new Long[0])));
+               }
+               
+               //collect information for all candidates 
+               //(subsumed aggregations, and inputs to fused operators) 
+               List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>();
+               for( Long hopID : fullAggs ) {
+                       Hop aggHop = memo.getHopRefs().get(hopID);
+                       AggregateInfo tmp = new AggregateInfo(aggHop);
+                       for( int i=0; i<aggHop.getInput().size(); i++ ) {
+                               Hop c = 
HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? 
+                                       
aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
+                               rExtractAggregateInfo(memo, c, tmp, 
TemplateType.CELL);
+                       }
+                       if( tmp._fusedInputs.isEmpty() ) {
+                               if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
+                                       
tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
+                               }
+                               else    
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+                       }
+                       aggInfos.add(tmp);      
+               }
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Extracted across-partition ua(RC) 
aggregation info: ");
+                       for( AggregateInfo info : aggInfos )
+                               LOG.trace(info);
+               }
+               
+               //sort aggregations by num dependencies to simplify merging
+               //clusters of aggregations with parallel dependencies
+               aggInfos = aggInfos.stream()
+                       .sorted(Comparator.comparing(a -> a._inputAggs.size()))
+                       .collect(Collectors.toList());
+               
+               //greedy grouping of multi-agg candidates
+               boolean converged = false;
+               while( !converged ) {
+                       AggregateInfo merged = null;
+                       for( int i=0; i<aggInfos.size(); i++ ) {
+                               AggregateInfo current = aggInfos.get(i);
+                               for( int j=i+1; j<aggInfos.size(); j++ ) {
+                                       AggregateInfo that = aggInfos.get(j);
+                                       if( current.isMergable(that) ) {
+                                               merged = current.merge(that);
+                                               aggInfos.remove(j); j--;
+                                       }
+                               }
+                       }
+                       converged = (merged == null);
+               }
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Merged across-partition ua(RC) aggregation 
info: ");
+                       for( AggregateInfo info : aggInfos )
+                               LOG.trace(info);
+               }
+               
+               //construct and add multiagg template plans (w/ max 3 
aggregations)
+               for( AggregateInfo info : aggInfos ) {
+                       if( info._aggregates.size()<=1 )
+                               continue;
+                       Long[] aggs = info._aggregates.keySet().toArray(new 
Long[0]);
+                       MemoTableEntry me = new 
MemoTableEntry(TemplateType.MAGG,
+                               aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, 
aggs.length);
+                       for( int i=0; i<aggs.length; i++ ) {
+                               memo.add(memo.getHopRefs().get(aggs[i]), me);
+                               addBestPlan(aggs[i], me);
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Added multiagg* plan: 
"+aggs[i]+" "+me);
+                               
+                       }
+               }
+       }
+       
+       private static boolean isMultiAggregateRoot(Hop root) {
+               return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, 
AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) 
+                               && 
((AggUnaryOp)root).getDirection()==Direction.RowCol)
+                       || (root instanceof AggBinaryOp && root.getDim1()==1 && 
root.getDim2()==1
+                               && 
HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
+       }
+       
+       private static boolean isValidMultiAggregate(CPlanMemoTable memo, 
MemoTableEntry me) {
+               //ensure input consistent sizes (otherwise potential for 
incorrect results)
+               boolean ret = true;
+               Hop refSize = 
memo.getHopRefs().get(me.input1).getInput().get(0);
+               for( int i=1; ret && i<3; i++ ) {
+                       if( me.isPlanRef(i) )
+                               ret &= HopRewriteUtils.isEqualSize(refSize, 
+                                       
memo.getHopRefs().get(me.input(i)).getInput().get(0));
+               }
+               
+               //ensure that aggregates are independent of each other, i.e.,
+               //they to not have potentially transitive parent child 
references
+               for( int i=0; ret && i<3; i++ ) 
+                       if( me.isPlanRef(i) ) {
+                               HashSet<Long> probe = new HashSet<Long>();
+                               for( int j=0; j<3; j++ )
+                                       if( i != j )
+                                               probe.add(me.input(j));
+                               ret &= 
rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe);
+                       }
+               return ret;
+       }
+       
+       private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> 
probe) {
+               boolean ret = true;
+               for( Hop c : current.getInput() )
+                       ret &= rCheckMultiAggregate(c, probe);
+               ret &= !probe.contains(current.getHopID());
+               return ret;
+       }
+       
+       private static void rCollectFullAggregates(Hop current, HashSet<Long> 
aggs) {
+               if( current.isVisited() )
+                       return;
+               
+               //collect all applicable full aggregations per read
+               if( isMultiAggregateRoot(current) )
+                       aggs.add(current.getHopID());
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rCollectFullAggregates(c, aggs);
+               
+               current.setVisited();
+       }
+       
+       private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop 
current, AggregateInfo aggInfo, TemplateType type) {
+               //collect input aggregates (dependents)
+               if( isMultiAggregateRoot(current) )
+                       aggInfo.addInputAggregate(current.getHopID());
+               
+               //recursively process children
+               MemoTableEntry me = (type!=null) ? 
memo.getBest(current.getHopID()) : null;
+               for( int i=0; i<current.getInput().size(); i++ ) {
+                       Hop c = current.getInput().get(i);
+                       if( me != null && me.isPlanRef(i) )
+                               rExtractAggregateInfo(memo, c, aggInfo, type);
+                       else {
+                               if( type != null && c.getDataType().isMatrix()  
) //add fused input
+                                       aggInfo.addFusedInput(c.getHopID());
+                               rExtractAggregateInfo(memo, c, aggInfo, null);
+                       }
+               }
+       }
+       
+       private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited) {
+               //consider all aggregations other than root operation
+               MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
+               boolean ret = true;
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, 
+                                       current.getInput().get(i), visited);
+               return ret;
+       }
+       
+       private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, 
Hop current, HashSet<Long> visited) {
+               if( visited.contains(current.getHopID()) )
+                       return true;
+               
+               boolean ret = true;
+               MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, 
current.getInput().get(i), visited);
+               ret &= !(current instanceof AggUnaryOp || current instanceof 
AggBinaryOp);
+               
+               visited.add(current.getHopID());
+               return ret;
+       }
+       
+       private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited, 
+               PlanPartition part, InterestingPoint[] matPoints, boolean[] 
plan) 
+       {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //remove memo table entries if necessary
+               long hopID = current.getHopID();
+               if( part.getPartition().contains(hopID) && memo.contains(hopID) 
) {
+                       Iterator<MemoTableEntry> iter = 
memo.get(hopID).iterator();
+                       while( iter.hasNext() ) {
+                               MemoTableEntry me = iter.next();
+                               if( !hasNoRefToMatPoint(hopID, me, matPoints, 
plan) && me.type!=TemplateType.OUTER ) {
+                                       iter.remove();
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed memo table 
entry: "+me);
+                               }
+                       }
+               }
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rPruneSuboptimalPlans(memo, c, visited, part, 
matPoints, plan);
+               
+               visited.add(current.getHopID());
+       }
+       
+       private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rPruneInvalidPlans(memo, c, visited, part, plan);
+               
+               //find invalid row aggregate leaf nodes (see TemplateRow.open) 
w/o matrix inputs, 
+               //i.e., plans that become invalid after the previous pruning 
step
+               long hopID = current.getHopID();
+               if( part.getPartition().contains(hopID) && memo.contains(hopID, 
TemplateType.ROW) ) {
+                       for( MemoTableEntry me : memo.get(hopID) ) {
+                               if( me.type==TemplateType.ROW ) {
+                                       //convert leaf node with pure vector 
inputs
+                                       if( !me.hasPlanRef() && 
!TemplateUtils.hasMatrixInput(current) ) {
+                                               me.type = TemplateType.CELL;
+                                               if( LOG.isTraceEnabled() )
+                                                       LOG.trace("Converted 
leaf memo table entry from row to cell: "+me);
+                                       }
+                                       
+                                       //convert inner node without row 
template input
+                                       if( me.hasPlanRef() && 
!ROW_TPL.open(current) ) {
+                                               boolean hasRowInput = false;
+                                               for( int i=0; i<3; i++ )
+                                                       if( me.isPlanRef(i) )
+                                                               hasRowInput |= 
memo.contains(me.input(i), TemplateType.ROW);
+                                               if( !hasRowInput ) {
+                                                       me.type = 
TemplateType.CELL;
+                                                       if( 
LOG.isTraceEnabled() )
+                                                               
LOG.trace("Converted inner memo table entry from row to cell: "+me);    
+                                               }
+                                       }
+                                       
+                               }
+                       }
+               }
+               
+               visited.add(current.getHopID());                
+       }
+       
+       private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, 
TemplateType currentType, HashSet<Long> partition) 
+       {       
+               if( isVisited(current.getHopID(), currentType) 
+                       || !partition.contains(current.getHopID()) )
+                       return;
+               
+               //step 1: prune subsumed plans of same type
+               if( memo.contains(current.getHopID()) ) {
+                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
+                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
+                       for( MemoTableEntry e1 : hopP )
+                               for( MemoTableEntry e2 : hopP )
+                                       if( e1 != e2 && e1.subsumes(e2) )
+                                               rmSet.add(e2);
+                       memo.remove(current, rmSet);
+               }
+               
+               //step 2: select plan for current path
+               MemoTableEntry best = null;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .min(BASE_COMPARE).orElse(null);
+                       }
+                       else {
+                               _typedCompare.setType(currentType);
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
+                                       .min(_typedCompare).orElse(null);
+                       }
+                       addBestPlan(current.getHopID(), best);
+               }
+               
+               //step 3: recursively process children
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
+                       rSelectPlansFuseAll(memo, current.getInput().get(i), 
pref, partition);
+               }
+               
+               setVisited(current.getHopID(), currentType);
+       }
+       
+       /////////////////////////////////////////////////////////
+       // Cost model fused operators w/ materialization points
+       //////////
+       
+       private static double getPlanCost(CPlanMemoTable memo, PlanPartition 
part, 
+               InterestingPoint[] matPoints,boolean[] plan, HashMap<Long, 
Double> computeCosts) 
+       {
+               //high level heuristic: every hop or fused operator has the 
following cost: 
+               //WRITE + max(COMPUTE, READ), where WRITE costs are given by 
the output size, 
+               //READ costs by the input sizes, and COMPUTE by operation 
specific FLOP
+               //counts times number of cells of main input, disregarding 
sparsity for now.
+               
+               HashSet<VisitMarkCost> visited = new HashSet<>();
+               double costs = 0;
+               for( Long hopID : part.getRoots() ) {
+                       costs += rGetPlanCosts(memo, 
memo.getHopRefs().get(hopID), 
+                               visited, part, matPoints, plan, computeCosts, 
null, null);
+               }
+               return costs;
+       }
+       
	/**
	 * Recursively computes the costs of the sub-DAG rooted at {@code current}
	 * under the given plan (assignment over the materialization points).
	 * <p>
	 * If no fused operator is open ({@code currentType == null}), a new
	 * operator is opened at this hop with a fresh cost vector. Inside an open
	 * operator, plan-referenced inputs are accumulated into the same cost
	 * vector. All other inputs are costed as separate operators, and their
	 * sizes are added as reads of the current operator.
	 *
	 * @param memo memo table of candidate plans
	 * @param current hop to cost
	 * @param visited visit tags (hop ID, cost vector ID) already accounted for
	 * @param part plan partition under consideration
	 * @param matPoints interesting points of the partition
	 * @param plan assignment over {@code matPoints}
	 * @param computeCosts per-hop compute costs keyed by hop ID
	 * @param costsCurrent cost vector of the enclosing fused operator, or null
	 * @param currentType template type of the enclosing fused operator, or null
	 * @return costs attributed to this call; 0 if the visit tag was already
	 *         seen or if {@code current} is a non-first multi-agg root
	 * @throws RuntimeException if the resulting costs are negative, NaN, or infinite
	 */
	private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<VisitMarkCost> visited, 
		PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, 
		CostVector costsCurrent, TemplateType currentType) 
	{
		//memoization per hop id and cost vector to account for redundant
		//computation without double counting materialized results or compute
		//costs of complex operation DAGs within a single fused operator
		//(standalone operators and all multi-agg roots share cost vector id 0)
		VisitMarkCost tag = new VisitMarkCost(current.getHopID(), 
			(costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID);
		if( visited.contains(tag) )
			return 0;
		visited.add(tag);
		
		//open template if necessary, including memoization
		//under awareness of current plan choice
		MemoTableEntry best = null;
		boolean opened = false;
		if( memo.contains(current.getHopID()) ) {
			//note: this is the inner loop of plan enumeration and hence, we do not 
			//use streams, lambda expressions, etc to avoid unnecessary overhead
			long hopID = current.getHopID();
			if( currentType == null ) {
				//no open operator: pick the smallest (BasicPlanComparator) valid plan
				//without references to materialized points, and open a new operator
				for( MemoTableEntry me : memo.get(hopID) )
					best = isValid(me, current) 
						&& hasNoRefToMatPoint(hopID, me, matPoints, plan) 
						&& BasicPlanComparator.icompare(me, best)<0 ? me : best;
				opened = true;
			}
			else {
				//inside an open operator: restrict to plans of the current type or CELL
				for( MemoTableEntry me : memo.get(hopID) )
					best = (me.type == currentType || me.type==TemplateType.CELL)
						&& hasNoRefToMatPoint(hopID, me, matPoints, plan) 
						&& TypedPlanComparator.icompare(me, best, currentType)<0 ? me : best;
			}
		}
		
		//create new cost vector if opened, initialized with write costs
		//NOTE(review): costVect is null if current has no memo entry and no cost
		//vector was passed in; the compute-cost update below then throws an NPE
		//for partition hops. Presumably partition hops always have memo entries - confirm.
		CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
		double costs = 0;
		
		//add other roots for multi-agg template to account for shared costs
		if( opened && best != null && best.type == TemplateType.MAGG ) {
			//account costs to first multi-agg root 
			if( best.input1 == current.getHopID() )
				for( int i=1; i<3; i++ ) {
					if( !best.isPlanRef(i) ) continue;
					costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, 
						part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG);
				}
			//skip other multi-agg roots
			else
				return 0;
		}
		
		//add compute costs of current operator to costs vector
		if( part.getPartition().contains(current.getHopID()) )
			costVect.computeCosts += computeCosts.get(current.getHopID());
		
		//process children recursively
		for( int i=0; i< current.getInput().size(); i++ ) {
			Hop c = current.getInput().get(i);
			if( best!=null && best.isPlanRef(i) )
				//fused input: accumulate into the same cost vector
				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type);
			else if( best!=null && isImplicitlyFused(current, i, best.type) )
				//transpose input of a row-template matrix multiply: read the transpose's
				//input (with the transpose's size) instead of costing the transpose
				costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
			else { //include children and I/O costs
				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null);
				if( costVect != null && c.getDataType().isMatrix() )
					costVect.addInputSize(c.getHopID(), getSize(c));
			}
		}
		
		//add costs for opened fused operator
		if( part.getPartition().contains(current.getHopID()) ) {
			if( opened ) {
				if( LOG.isTraceEnabled() ) {
					String type = (best !=null) ? best.type.name() : "HOP";
					LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect);
				}
				//write + max(read, compute); the factor 8 presumably converts cells
				//to bytes of double values - confirm against getSize
				double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write
					+ Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH,
					costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
				//sparsity correction for outer-product template (and sparse-safe cell)
				if( best != null && best.type == TemplateType.OUTER ) {
					//NOTE(review): getMaxInputSizeHopID() returns -1 if no input size is
					//positive, in which case driver would be null here
					Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
					tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
				}
				costs += tmpCosts;
			}
			//add costs for non-partition read in the middle of fused operator
			//(re-enters current as a standalone operator; the tag with cost
			//vector id 0 prevents counting it more than once)
			else if( part.getExtConsumed().contains(current.getHopID()) ) {
				costs += rGetPlanCosts(memo, current, visited,
					part, matPoints, plan, computeCosts, null, null);
			}
		}
		
		//sanity check non-negative costs
		if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
			throw new RuntimeException("Wrong cost estimate: "+costs);
		
		return costs;
	}
+       
+       private static void rGetComputeCosts(Hop current, HashSet<Long> 
partition, HashMap<Long, Double> computeCosts) 
+       {
+               if( computeCosts.containsKey(current.getHopID()) 
+                       || !partition.contains(current.getHopID()) )
+                       return;
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rGetComputeCosts(c, partition, computeCosts);
+               
+               //get costs for given hop
+               double costs = 1;
+               if( current instanceof UnaryOp ) {
+                       switch( ((UnaryOp)current).getOp() ) {
+                               case ABS:   
+                               case ROUND:
+                               case CEIL:
+                               case FLOOR:
+                               case SIGN:
+                               case SELP:    costs = 1; break; 
+                               case SPROP:
+                               case SQRT:    costs = 2; break;
+                               case EXP:     costs = 18; break;
+                               case SIGMOID: costs = 21; break;
+                               case LOG:    
+                               case LOG_NZ:  costs = 32; break;
+                               case NCOL:
+                               case NROW:
+                               case PRINT:
+                               case CAST_AS_BOOLEAN:
+                               case CAST_AS_DOUBLE:
+                               case CAST_AS_INT:
+                               case CAST_AS_MATRIX:
+                               case CAST_AS_SCALAR: costs = 1; break;
+                               case SIN:     costs = 18; break;
+                               case COS:     costs = 22; break;
+                               case TAN:     costs = 42; break;
+                               case ASIN:    costs = 93; break;
+                               case ACOS:    costs = 103; break;
+                               case ATAN:    costs = 40; break;
+                               case CUMSUM:
+                               case CUMMIN:
+                               case CUMMAX:
+                               case CUMPROD: costs = 1; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((UnaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof BinaryOp ) {
+                       switch( ((BinaryOp)current).getOp() ) {
+                               case MULT: 
+                               case PLUS:
+                               case MINUS:
+                               case MIN:
+                               case MAX: 
+                               case AND:
+                               case OR:
+                               case EQUAL:
+                               case NOTEQUAL:
+                               case LESS:
+                               case LESSEQUAL:
+                               case GREATER:
+                               case GREATEREQUAL: 
+                               case CBIND:
+                               case RBIND:   costs = 1; break;
+                               case INTDIV:  costs = 6; break;
+                               case MODULUS: costs = 8; break;
+                               case DIV:     costs = 22; break;
+                               case LOG:
+                               case LOG_NZ:  costs = 32; break;
+                               case POW:     costs = 
(HopRewriteUtils.isLiteralOfValue(
+                                               current.getInput().get(1), 2) ? 
1 : 16); break;
+                               case MINUS_NZ:
+                               case MINUS1_MULT: costs = 2; break;
+                               case CENTRALMOMENT:
+                                       int type = (int) 
(current.getInput().get(1) instanceof LiteralOp ? 
+                                               
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 1; break; 
//count
+                                               case 1: costs = 8; break; //mean
+                                               case 2: costs = 16; break; //cm2
+                                               case 3: costs = 31; break; //cm3
+                                               case 4: costs = 51; break; //cm4
+                                               case 5: costs = 16; break; 
//variance
+                                       }
+                                       break;
+                               case COVARIANCE: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((BinaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof TernaryOp ) {
+                       switch( ((TernaryOp)current).getOp() ) {
+                               case PLUS_MULT: 
+                               case MINUS_MULT: costs = 2; break;
+                               case CTABLE:     costs = 3; break;
+                               case CENTRALMOMENT:
+                                       int type = (int) 
(current.getInput().get(1) instanceof LiteralOp ? 
+                                               
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 2; break; 
//count
+                                               case 1: costs = 9; break; //mean
+                                               case 2: costs = 17; break; //cm2
+                                               case 3: costs = 32; break; //cm3
+                                               case 4: costs = 52; break; //cm4
+                                               case 5: costs = 17; break; 
//variance
+                                       }
+                                       break;
+                               case COVARIANCE: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((TernaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof ParameterizedBuiltinOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof IndexingOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof ReorgOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof AggBinaryOp ) {
+                       //outer product template
+                       if( HopRewriteUtils.isOuterProductLikeMM(current) )
+                               costs = 2 * current.getInput().get(0).getDim2();
+                       //row template w/ matrix-vector or matrix-matrix
+                       else
+                               costs = 2 * current .getDim2();
+               }
+               else if( current instanceof AggUnaryOp) {
+                       switch(((AggUnaryOp)current).getOp()) {
+                       case SUM:    costs = 4; break; 
+                       case SUM_SQ: costs = 5; break;
+                       case MIN:
+                       case MAX:    costs = 1; break;
+                       default:
+                               LOG.warn("Cost model not "
+                                       + "implemented yet for: 
"+((AggUnaryOp)current).getOp());                       
+                       }
+               }
+               
+               computeCosts.put(current.getHopID(), costs);
+       }
+       
+       private static boolean hasNoRefToMatPoint(long hopID, 
+                       MemoTableEntry me, InterestingPoint[] M, boolean[] 
plan) {
+               return !InterestingPoint.isMatPoint(M, hopID, me, plan);
+       }
+       
+       private static boolean isImplicitlyFused(Hop hop, int index, 
TemplateType type) {
+               return type == TemplateType.ROW
+                       && HopRewriteUtils.isMatrixMultiply(hop) && index==0 
+                       && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); 
+       }
+       
+       private static class CostVector {
+               public final long ID;
+               public final double outSize; 
+               public double computeCosts = 0;
+               public final HashMap<Long, Double> inSizes = new HashMap<Long, 
Double>();
+               
+               public CostVector(double outputSize) {
+                       ID = COST_ID.getNextID();
+                       outSize = outputSize;
+               }
+               public void addInputSize(long hopID, double inputSize) {
+                       //ensures that input sizes are not double counted
+                       inSizes.put(hopID, inputSize);
+               }
+               public double getSumInputSizes() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> d.doubleValue()).sum();
+               }
+               public double getMaxInputSize() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> 
d.doubleValue()).max().orElse(0);
+               }
+               public long getMaxInputSizeHopID() {
+                       long id = -1; double max = 0;
+                       for( Entry<Long,Double> e : inSizes.entrySet() )
+                               if( max < e.getValue() ) {
+                                       id = e.getKey();
+                                       max = e.getValue();
+                               }
+                       return id;
+               }
+               @Override
+               public String toString() {
+                       return "["+outSize+", "+computeCosts+", {"
+                               +Arrays.toString(inSizes.keySet().toArray(new 
Long[0]))+", "
+                               +Arrays.toString(inSizes.values().toArray(new 
Double[0]))+"}]";
+               }
+       }
+       
+       private static class StaticCosts {
+               public final HashMap<Long, Double> _computeCosts;
+               public final double _compute;
+               public final double _read;
+               public final double _write;
+               
+               public StaticCosts(HashMap<Long,Double> allComputeCosts, double 
computeCost, double readCost, double writeCost) {
+                       _computeCosts = allComputeCosts;
+                       _compute = computeCost;
+                       _read = readCost;
+                       _write = writeCost;
+               }
+       }
+       
+       private static class AggregateInfo {
+               public final HashMap<Long,Hop> _aggregates;
+               public final HashSet<Long> _inputAggs = new HashSet<Long>();
+               public final HashSet<Long> _fusedInputs = new HashSet<Long>();
+               public AggregateInfo(Hop aggregate) {
+                       _aggregates = new HashMap<Long, Hop>();
+                       _aggregates.put(aggregate.getHopID(), aggregate);
+               }
+               public void addInputAggregate(long hopID) {
+                       _inputAggs.add(hopID);
+               }
+               public void addFusedInput(long hopID) {
+                       _fusedInputs.add(hopID);
+               }
+               public boolean isMergable(AggregateInfo that) {
+                       //check independence
+                       boolean ret = _aggregates.size()<3 
+                               && 
_aggregates.size()+that._aggregates.size()<=3;
+                       for( Long hopID : that._aggregates.keySet() )
+                               ret &= !_inputAggs.contains(hopID);
+                       for( Long hopID : _aggregates.keySet() )
+                               ret &= !that._inputAggs.contains(hopID);
+                       //check partial shared reads
+                       ret &= !CollectionUtils.intersection(
+                               _fusedInputs, that._fusedInputs).isEmpty();
+                       //check consistent sizes (result correctness)
+                       Hop in1 = _aggregates.values().iterator().next();
+                       Hop in2 = that._aggregates.values().iterator().next();
+                       return ret && HopRewriteUtils.isEqualSize(
+                               
in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
+                               
in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
+               }
+               public AggregateInfo merge(AggregateInfo that) {
+                       _aggregates.putAll(that._aggregates);
+                       _inputAggs.addAll(that._inputAggs);
+                       _fusedInputs.addAll(that._fusedInputs);
+                       return this;
+               }
+               @Override
+               public String toString() {
+                       return 
"["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": "
+                               +"{"+Arrays.toString(_inputAggs.toArray(new 
Long[0]))+"}," 
+                               +"{"+Arrays.toString(_fusedInputs.toArray(new 
Long[0]))+"}]"; 
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
new file mode 100644
index 0000000..759a903
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Map.Entry;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+
+/**
+ * Plan selection heuristic that fuses operators without introducing any
+ * redundant computation. Compared to the fuse-all heuristic, this may
+ * lead to more materialized intermediates.
+ * <p>
+ * NOTE: This heuristic is essentially the same as FuseAll, except that
+ * before a plan is chosen for a hop, all of its memo table entries that
+ * reference an input with multiple consumers (parents) are removed.
+ */
+public class PlanSelectionFuseNoRedundancy extends PlanSelection
+{
+       @Override
+       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
+               //recursive pruning and plan collection, starting at every root
+               for( Hop root : roots )
+                       rSelectPlans(memo, root, null);
+
+               //install all distinct best plans into the memo table
+               for( Entry<Long, List<MemoTableEntry>> bestPlan : getBestPlans().entrySet() ) {
+                       Long hopID = bestPlan.getKey();
+                       memo.setDistinct(hopID, bestPlan.getValue());
+               }
+       }
+
+       /**
+        * Prunes the memo table entries of the given hop, records its best plan
+        * (w.r.t. the template type requested by the consumer), and recurses
+        * into the hop's inputs.
+        * 
+        * @param memo memo table of partial fusion plans
+        * @param current hop to process
+        * @param currentType template type preferred by the consumer, or null
+        */
+       private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType)
+       {
+               long hopID = current.getHopID();
+               if( isVisited(hopID, currentType) )
+                       return;
+
+               //step 0: drop plans that reference inputs with multiple consumers
+               if( memo.contains(hopID) ) {
+                       HashSet<MemoTableEntry> toRemove = new HashSet<MemoTableEntry>();
+                       for( MemoTableEntry me : memo.get(hopID) )
+                               if( refersToSharedInput(me, current) )
+                                       toRemove.add(me);
+                       memo.remove(current, toRemove);
+               }
+
+               //step 1: drop plans subsumed by another plan of the same hop
+               if( memo.contains(hopID) ) {
+                       List<MemoTableEntry> plans = memo.get(hopID);
+                       HashSet<MemoTableEntry> toRemove = new HashSet<MemoTableEntry>();
+                       for( MemoTableEntry cand : plans )
+                               for( MemoTableEntry other : plans )
+                                       if( cand != other && cand.subsumes(other) )
+                                               toRemove.add(other);
+                       memo.remove(current, toRemove);
+               }
+
+               //step 2: choose the plan for the current path
+               MemoTableEntry best = null;
+               if( memo.contains(hopID) ) {
+                       best = selectBestPlan(memo, current, currentType);
+                       addBestPlan(hopID, best);
+               }
+
+               //step 3: recurse into all inputs; inputs referenced by the chosen
+               //plan inherit its template type as preference
+               List<Hop> inputs = current.getInput();
+               for( int i=0; i<inputs.size(); i++ ) {
+                       boolean fused = best != null && best.isPlanRef(i);
+                       rSelectPlans(memo, inputs.get(i), fused ? best.type : null);
+               }
+
+               setVisited(hopID, currentType);
+       }
+
+       /**
+        * Indicates if the given entry has a plan reference (positions 0-2) to
+        * an input of the given hop that has more than one parent.
+        */
+       private static boolean refersToSharedInput(MemoTableEntry me, Hop hop) {
+               boolean shared = false;
+               for( int i=0; i<3; i++ )
+                       shared |= me.isPlanRef(i) && hop.getInput().get(i).getParent().size()>1;
+               return shared;
+       }
+
+       /**
+        * Selects the best memo table entry of the given hop, or null if none
+        * qualifies. Without a preferred type, the best valid entry according to
+        * the basic plan comparator is chosen; otherwise only entries of the
+        * preferred type or CELL qualify, and the entry with the smallest score
+        * wins (the preferred type lowers the score by 4, every plan reference
+        * lowers it by 1).
+        */
+       private MemoTableEntry selectBestPlan(CPlanMemoTable memo, Hop current, TemplateType currentType) {
+               List<MemoTableEntry> plans = memo.get(current.getHopID());
+               if( currentType == null ) {
+                       return plans.stream()
+                               .filter(p -> isValid(p, current))
+                               .min(new BasicPlanComparator()).orElse(null);
+               }
+               return plans.stream()
+                       .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+                       .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                       .orElse(null);
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
new file mode 100644
index 0000000..de1ed92
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.hops.codegen.opt.PlanSelection.VisitMarkCost;
+
+/**
+ * Reachability graph over the interesting points (materialization points)
+ * of a plan partition, used to find cut sets that split the partition's
+ * search space into two disjoint sub problems.
+ * <p>
+ * The constructor creates one node link per interesting point plus an
+ * artificial root, groups interesting points with identical inputs into
+ * candidate cut sets, scores the valid cut sets, and linearizes the search
+ * space such that the points of the selected cut sets come first (ordered
+ * by score), followed by all remaining interesting points.
+ */
+public class ReachabilityGraph
+{
+       private HashMap<Pair<Long,Long>,NodeLink> _matPoints = null;
+       private NodeLink _root = null;
+
+       private InterestingPoint[] _searchSpace;
+       private CutSet[] _cutSets;
+
+       /**
+        * Builds the reachability graph for the given plan partition and derives
+        * its cut sets and linearized search space.
+        * 
+        * @param part plan partition (roots and interesting points)
+        * @param memo memo table, used to resolve root hop IDs to hops
+        * @throws RuntimeException if the linearized search space does not have
+        *         the same number of points as the partition
+        */
+       public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
+               //create repository of materialization points
+               _matPoints = new HashMap<>();
+               for( InterestingPoint p : part.getMatPointsExt() )
+                       _matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p));
+
+               //create reachability graph
+               _root = new NodeLink(null);
+               HashSet<VisitMarkCost> visited = new HashSet<>();
+               for( Long hopID : part.getRoots() ) {
+                       Hop rootHop = memo.getHopRefs().get(hopID);
+                       addInputNodeLinks(rootHop, _root, part, memo, visited);
+               }
+
+               //create candidate cutsets (sorting makes nodes with identical
+               //inputs adjacent to each other)
+               List<NodeLink> tmpCS = _matPoints.values().stream()
+                       .filter(p -> p._inputs.size() > 0 && p._p != null)
+                       .sorted().collect(Collectors.toList());
+
+               //short-cut for partitions without cutsets
+               if( tmpCS.isEmpty() ) {
+                       _cutSets = new CutSet[0];
+                       _searchSpace = part.getMatPointsExt();
+                       return;
+               }
+
+               //create composite cutsets of nodes with identical inputs
+               ArrayList<ArrayList<NodeLink>> candCS = new ArrayList<>();
+               ArrayList<NodeLink> current = new ArrayList<>();
+               for( NodeLink node : tmpCS ) {
+                       if( current.isEmpty() )
+                               current.add(node);
+                       else if( current.get(0).hasEqualInputs(node) )
+                               current.add(node);
+                       else {
+                               candCS.add(current);
+                               current = new ArrayList<>();
+                               current.add(node);
+                       }
+               }
+               if( !current.isEmpty() )
+                       candCS.add(current);
+
+               //evaluate cutsets (single, and duplicate pairs)
+               ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
+               ArrayList<Pair<CutSet,Double>> cutSets = evaluateCutSets(candCS, remain);
+               if( !remain.isEmpty() && remain.size() < 5 ) {
+                       //second chance: for pairs for remaining candidates
+                       ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>();
+                       for( int i=0; i<remain.size()-1; i++)
+                               for( int j=i+1; j<remain.size(); j++) {
+                                       ArrayList<NodeLink> tmp = new ArrayList<>();
+                                       tmp.addAll(remain.get(i));
+                                       tmp.addAll(remain.get(j));
+                                       candCS2.add(tmp);
+                               }
+                       //rejected pairs are not considered any further, so collect
+                       //them separately instead of appending them to remain
+                       ArrayList<ArrayList<NodeLink>> remain2 = new ArrayList<>();
+                       ArrayList<Pair<CutSet,Double>> cutSets2 = evaluateCutSets(candCS2, remain2);
+                       //ensure constructed cutsets are disjoint
+                       HashSet<InterestingPoint> testDisjoint = new HashSet<>();
+                       for( Pair<CutSet,Double> cs : cutSets2 ) {
+                               if( !CollectionUtils.containsAny(testDisjoint, Arrays.asList(cs.getLeft().cut)) ) {
+                                       cutSets.add(cs);
+                                       CollectionUtils.addAll(testDisjoint, cs.getLeft().cut);
+                               }
+                       }
+               }
+
+               //sort and linearize search space according to scores
+               _cutSets = cutSets.stream()
+                               .sorted(Comparator.comparing(p -> p.getRight()))
+                               .map(p -> p.getLeft()).toArray(CutSet[]::new);
+
+               //probe maps every interesting point to its position (index) in the
+               //linearized search space; the recorded index must be taken before
+               //appending the point, consistent with the positions of updatePos
+               HashMap<InterestingPoint, Integer> probe = new HashMap<>();
+               ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
+               for( CutSet cs : _cutSets ) {
+                       cs.updatePos(lsearchSpace.size());
+                       cs.updatePartitions(probe);
+                       for( InterestingPoint p : cs.cut ) {
+                               probe.put(p, lsearchSpace.size());
+                               lsearchSpace.add(p);
+                       }
+               }
+               for( InterestingPoint p : part.getMatPointsExt() )
+                       if( !probe.containsKey(p) ) {
+                               probe.put(p, lsearchSpace.size());
+                               lsearchSpace.add(p);
+                       }
+               _searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);
+
+               //materialize partition indices
+               for( CutSet cs : _cutSets ) {
+                       cs.updatePartitionIndexes(probe);
+                       cs.finalizePartition();
+               }
+
+               //final sanity check of interesting points
+               if( _searchSpace.length != part.getMatPointsExt().length )
+                       throw new RuntimeException("Corrupt linearized search space: " +
+                               _searchSpace.length+" vs "+part.getMatPointsExt().length);
+       }
+
+       /**
+        * @return interesting points in the linearized search order
+        */
+       public InterestingPoint[] getSortedSearchSpace() {
+               return _searchSpace;
+       }
+
+       /**
+        * Indicates if the given plan (indexed by search space position) sets
+        * all cut points of at least one cut set.
+        */
+       public boolean isCutSet(boolean[] plan) {
+               for( CutSet cs : _cutSets )
+                       if( isCutSet(cs, plan) )
+                               return true;
+               return false;
+       }
+
+       /**
+        * Indicates if the given plan sets all cut points of the given cut set.
+        */
+       public boolean isCutSet(CutSet cs, boolean[] plan) {
+               boolean ret = true;
+               for(int i=0; i<cs.posCut.length && ret; i++)
+                       ret &= plan[cs.posCut[i]];
+               return ret;
+       }
+
+       /**
+        * Returns the first cut set whose cut points are all set in the plan.
+        * 
+        * @throws RuntimeException if no such cut set exists
+        */
+       public CutSet getCutSet(boolean[] plan) {
+               for( CutSet cs : _cutSets )
+                       if( isCutSet(cs, plan) )
+                               return cs;
+               throw new RuntimeException("No valid cut set found.");
+       }
+
+       /**
+        * Returns 2^(plan.length - pos - 1), where pos is the position of the
+        * last cut point of the first cut set fully set in the plan.
+        * 
+        * @throws RuntimeException if the plan sets no complete cut set
+        */
+       public long getNumSkipPlans(boolean[] plan) {
+               for( CutSet cs : _cutSets )
+                       if( isCutSet(cs, plan) ) {
+                               int pos = cs.posCut[cs.posCut.length-1];
+                               return (long) Math.pow(2, plan.length-pos-1);
+                       }
+               throw new RuntimeException("Failed to compute "
+                       + "number of skip plans for plan without cutset.");
+       }
+
+       /**
+        * Creates the left and right sub problems of the first cut set fully set
+        * in the plan; both use the number of cut points as offset.
+        */
+       public SubProblem[] getSubproblems(boolean[] plan) {
+               CutSet cs = getCutSet(plan);
+               return new SubProblem[] {
+                               new SubProblem(cs.cut.length, cs.posLeft, cs.left),
+                               new SubProblem(cs.cut.length, cs.posRight, cs.right)};
+       }
+
+       @Override
+       public String toString() {
+               return "ReachabilityGraph("+_matPoints.size()+"):\n"
+                       + _root.explain(new HashSet<>());
+       }
+
+       /**
+        * Recursively links node links below the given parent: an input edge that
+        * is an interesting point gets its node link added as input of parent (and
+        * becomes the parent of the recursion); otherwise, the traversal continues
+        * with the same parent. Each (hop, parent) pair is processed once.
+        */
+       private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part,
+               CPlanMemoTable memo, HashSet<VisitMarkCost> visited)
+       {
+               if( visited.contains(new VisitMarkCost(current.getHopID(), parent._ID)) )
+                       return;
+
+               //process children
+               for( Hop in : current.getInput() ) {
+                       if( InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID()) ) {
+                               NodeLink tmp = _matPoints.get(Pair.of(current.getHopID(), in.getHopID()));
+                               parent.addInput(tmp);
+                               addInputNodeLinks(in, tmp, part, memo, visited);
+                       }
+                       else
+                               addInputNodeLinks(in, parent, part, memo, visited);
+               }
+
+               visited.add(new VisitMarkCost(current.getHopID(), parent._ID));
+       }
+
+       /**
+        * Collects all transitive inputs of the given node that are reachable
+        * without passing through a node contained in probe.
+        */
+       private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) {
+               for( NodeLink c : current._inputs )
+                       if( !probe.contains(c) ) {
+                               rCollectInputs(c, probe, inputs);
+                               inputs.add(c);
+                       }
+       }
+
+       /**
+        * Evaluates the given cut set candidates. A candidate is valid if the
+        * nodes reachable from the root (above the cut) and the nodes reachable
+        * from the cut (below it) are both non-empty and disjoint. Valid candidates
+        * are returned with their score (smaller is better); invalid candidates
+        * are appended to remain.
+        */
+       private ArrayList<Pair<CutSet,Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS, ArrayList<ArrayList<NodeLink>> remain) {
+               ArrayList<Pair<CutSet,Double>> cutSets = new ArrayList<>();
+
+               for( ArrayList<NodeLink> cand : candCS ) {
+                       HashSet<NodeLink> probe = new HashSet<>(cand);
+
+                       //determine subproblems for cutset candidates
+                       HashSet<NodeLink> part1 = new HashSet<>();
+                       rCollectInputs(_root, probe, part1);
+                       HashSet<NodeLink> part2 = new HashSet<>();
+                       for( NodeLink rNode : cand )
+                               rCollectInputs(rNode, probe, part2);
+
+                       //select, score and create cutsets
+                       if( !CollectionUtils.containsAny(part1, part2)
+                               && !part1.isEmpty() && !part2.isEmpty()) {
+                               //score cutsets (smaller is better)
+                               double base = Math.pow(2, _matPoints.size());
+                               double numComb = Math.pow(2, cand.size());
+                               double score = (numComb-1)/numComb * base
+                                       + 1/numComb * Math.pow(2, part1.size())
+                                       + 1/numComb * Math.pow(2, part2.size());
+
+                               //construct cutset
+                               cutSets.add(Pair.of(new CutSet(
+                                       cand.stream().map(p->p._p).toArray(InterestingPoint[]::new),
+                                       part1.stream().map(p->p._p).toArray(InterestingPoint[]::new),
+                                       part2.stream().map(p->p._p).toArray(InterestingPoint[]::new)), score));
+                       }
+                       else {
+                               remain.add(cand);
+                       }
+               }
+
+               return cutSets;
+       }
+
+       /**
+        * Sub problem of a cut set: the free interesting points of one side and
+        * their positions in the linearized search space.
+        */
+       public static class SubProblem {
+               public int offset;
+               public int[] freePos;
+               public InterestingPoint[] freeMat;
+
+               public SubProblem(int off, int[] pos, InterestingPoint[] mat) {
+                       offset = off;
+                       freePos = pos;
+                       freeMat = mat;
+               }
+       }
+
+       /**
+        * Cut set of interesting points with the points of the two partitions
+        * (left: reachable from the root, right: reachable from the cut) and
+        * their positions in the linearized search space.
+        */
+       public static class CutSet {
+               public InterestingPoint[] cut;
+               public InterestingPoint[] left;
+               public InterestingPoint[] right;
+               public int[] posCut;
+               public int[] posLeft;
+               public int[] posRight;
+
+               public CutSet(InterestingPoint[] cutPoints,
+                               InterestingPoint[] l, InterestingPoint[] r) {
+                       cut = cutPoints;
+                       left = l;
+                       right = r;
+               }
+
+               /** Assigns consecutive positions, starting at index, to the cut points. */
+               public void updatePos(int index) {
+                       posCut = new int[cut.length];
+                       for(int i=0; i<posCut.length; i++)
+                               posCut[i] = index + i;
+               }
+
+               /** Removes all points contained in the blacklist from both partitions. */
+               public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) {
+                       left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p))
+                               .toArray(InterestingPoint[]::new);
+                       right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p))
+                               .toArray(InterestingPoint[]::new);
+               }
+
+               /** Resolves the search space positions of both partitions via probe. */
+               public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) {
+                       posLeft = new int[left.length];
+                       for(int i=0; i<left.length; i++)
+                               posLeft[i] = probe.get(left[i]);
+                       posRight = new int[right.length];
+                       for(int i=0; i<right.length; i++)
+                               posRight[i] = probe.get(right[i]);
+               }
+
+               /** Prepends the cut points to both partitions. */
+               public void finalizePartition() {
+                       left = (InterestingPoint[]) ArrayUtils.addAll(cut, left);
+                       right = (InterestingPoint[]) ArrayUtils.addAll(cut, right);
+               }
+
+               @Override
+               public String toString() {
+                       return "Cut : "+Arrays.toString(cut);
+               }
+       }
+
+       /**
+        * Node of the reachability graph, i.e., an interesting point (or null for
+        * the root) and the node links directly below it. Node links use identity
+        * equality (inherited from Object), as required by the hash sets used
+        * during cut set evaluation; structural comparison of inputs is available
+        * via hasEqualInputs.
+        */
+       private static class NodeLink implements Comparable<NodeLink>
+       {
+               private static final IDSequence _seqID = new IDSequence();
+
+               private ArrayList<NodeLink> _inputs = new ArrayList<>();
+               private long _ID;
+               private InterestingPoint _p;
+
+               public NodeLink(InterestingPoint p) {
+                       _ID = _seqID.getNextID();
+                       _p = p;
+               }
+
+               public void addInput(NodeLink in) {
+                       _inputs.add(in);
+               }
+
+               /**
+                * Indicates if both nodes have the same list of inputs, compared by
+                * node ID and position. This is deliberately not an override of
+                * equals: a structural equals would require a structural hashCode,
+                * which would merge distinct nodes (e.g., all leaves) in hash sets.
+                */
+               public boolean hasEqualInputs(NodeLink that) {
+                       boolean ret = (_inputs.size() == that._inputs.size());
+                       for( int i=0; i<_inputs.size() && ret; i++ )
+                               ret &= (_inputs.get(i)._ID == that._inputs.get(i)._ID);
+                       return ret;
+               }
+
+               /**
+                * Orders nodes by decreasing number of inputs, then by input IDs, so
+                * that nodes with equal inputs become adjacent after sorting. This
+                * ordering is consistent with hasEqualInputs, not with equals.
+                */
+               @Override
+               public int compareTo(NodeLink that) {
+                       if( _inputs.size() > that._inputs.size() )
+                               return -1;
+                       else if( _inputs.size() < that._inputs.size() )
+                               return 1;
+                       for( int i=0; i<_inputs.size(); i++ ) {
+                               int comp = Long.compare(_inputs.get(i)._ID,
+                                       that._inputs.get(i)._ID);
+                               if( comp != 0 )
+                                       return comp;
+                       }
+                       return 0;
+               }
+
+               @Override
+               public String toString() {
+                       StringBuilder inputs = new StringBuilder();
+                       for(NodeLink in : _inputs) {
+                               if( inputs.length() > 0 )
+                                       inputs.append(",");
+                               inputs.append(in._ID);
+                       }
+                       return _ID+" ("+inputs.toString()+") "+((_p!=null)?_p:"null");
+               }
+
+               /**
+                * Explains the subgraph rooted at this node (inputs first), skipping
+                * nodes already contained in visited.
+                */
+               private String explain(HashSet<Long> visited) {
+                       if( visited.contains(_ID) )
+                               return "";
+                       //add children
+                       StringBuilder sb = new StringBuilder();
+                       StringBuilder inputs = new StringBuilder();
+                       for(NodeLink in : _inputs) {
+                               String tmp = in.explain(visited);
+                               if( !tmp.isEmpty() )
+                                       sb.append(tmp + "\n");
+                               if( inputs.length() > 0 )
+                                       inputs.append(",");
+                               inputs.append(in._ID);
+                       }
+                       //add node itself
+                       sb.append(_ID+" ("+inputs+") "+((_p!=null)?_p:"null"));
+                       visited.add(_ID);
+
+                       return sb.toString();
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index edbcdf9..4078060 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.template;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -36,6 +37,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.codegen.SpoofCompiler;
+import org.apache.sysml.hops.codegen.opt.InterestingPoint;
+import org.apache.sysml.hops.codegen.opt.PlanSelection;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -53,6 +56,18 @@ public class CPlanMemoTable
                _plansBlacklist = new HashSet<Long>();
        }
        
+       /**
+        * Returns the internal map from hop ID to its memo table entries. This is
+        * not a copy; modifications are visible to this memo table.
+        * 
+        * @return map of hop ID to memo table entries
+        */
+       public HashMap<Long, List<MemoTableEntry>> getPlans() {
+               return this._plans;
+       }
+       
+       /**
+        * Returns the internal set of blacklisted hop IDs (not a copy).
+        * 
+        * @return set of hop IDs whose plans are blacklisted
+        */
+       public HashSet<Long> getPlansBlacklisted() {
+               return this._plansBlacklist;
+       }
+       
+       /**
+        * Returns the internal map from hop ID to hop (not a copy), as populated
+        * by addHop.
+        * 
+        * @return map of hop ID to hop
+        */
+       public HashMap<Long, Hop> getHopRefs() {
+               return this._hopRefs;
+       }
+       
        /**
         * Registers the given hop in the hop reference map, keyed by its hop ID.
         * 
         * @param hop hop to register
         */
        public void addHop(Hop hop) {
                _hopRefs.put(hop.getHopID(), hop);
        }
@@ -78,6 +93,14 @@ public class CPlanMemoTable
                        .anyMatch(p -> (!checkClose||!p.closed) && 
probe.contains(p.type));
        }
        
+       /**
+        * Indicates if the given hop has a memo table entry whose template type is
+        * not in the given collection.
+        * 
+        * @param hopID hop ID
+        * @param types template types to exclude
+        * @param checkChildRefs if true, only entries with plan references qualify
+        * @param excludeCell if true, CELL entries do not qualify
+        * @return true if at least one qualifying entry exists
+        */
+       public boolean containsNotIn(long hopID, Collection<TemplateType> types,
+               boolean checkChildRefs, boolean excludeCell) {
+               if( !contains(hopID) )
+                       return false;
+               for( MemoTableEntry me : get(hopID) ) {
+                       if( (!checkChildRefs || me.hasPlanRef())
+                               && (!excludeCell || me.type!=TemplateType.CELL)
+                               && !types.contains(me.type) )
+                               return true;
+               }
+               return false;
+       }
+       
        /**
         * Returns the number of memo table entries of the given hop.
         * NOTE(review): get(hopID) may return null for unknown hop IDs (getBest
         * checks for null), in which case this throws a NullPointerException;
         * callers presumably check contains(hopID) first - confirm.
         */
        public int countEntries(long hopID) {
                return get(hopID).size();
        }
@@ -85,7 +108,7 @@ public class CPlanMemoTable
        /**
         * Returns the number of memo table entries of the given hop with the
         * given template type.
         * NOTE(review): like countEntries(long), this presumably requires that
         * the hop has memo table entries (get may return null) - confirm.
         */
        public int countEntries(long hopID, TemplateType type) {
                return (int) get(hopID).stream()
                        .filter(p -> p.type==type).count();
-       } 
+       }
        
        public boolean containsTopLevel(long hopID) {
                return !_plansBlacklist.contains(hopID)
@@ -133,7 +156,7 @@ public class CPlanMemoTable
                        .distinct().collect(Collectors.toList()));
        }
 
-       public void pruneRedundant(long hopID) {
+       public void pruneRedundant(long hopID, boolean pruneDominated, 
InterestingPoint[] matPoints) {
                if( !contains(hopID) )
                        return;
                
@@ -146,7 +169,7 @@ public class CPlanMemoTable
                //prune dominated plans (e.g., opened plan subsumed by fused 
plan 
                //if single consumer of input; however this only applies to 
fusion
                //heuristic that only consider materialization points)
-               if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+               if( pruneDominated ) {
                        HashSet<MemoTableEntry> rmList = new 
HashSet<MemoTableEntry>();
                        List<MemoTableEntry> list = _plans.get(hopID);
                        Hop hop = _hopRefs.get(hopID);
@@ -155,9 +178,12 @@ public class CPlanMemoTable
                                        if( e1 != e2 && e1.subsumes(e2) ) {
                                                //check that childs don't have 
multiple consumers
                                                boolean rmSafe = true; 
-                                               for( int i=0; i<=2; i++ )
+                                               for( int i=0; i<=2; i++ ) {
                                                        rmSafe &= 
(e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
-                                                               
hop.getInput().get(i).getParent().size()==1 : true;
+                                                               
(matPoints!=null && !InterestingPoint.isMatPoint(
+                                                                       
matPoints, hopID, e1.input(i)))
+                                                               || 
hop.getInput().get(i).getParent().size()==1 : true;
+                                               }
                                                if( rmSafe )
                                                        rmList.add(e2);
                                        }
@@ -194,12 +220,14 @@ public class CPlanMemoTable
                //prune dominated plans (e.g., plan referenced by other plan 
and this
                //other plan is single consumer) by marking it as blacklisted 
because
                //the chain of entries is still required for cplan construction
-               for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
-                       for( MemoTableEntry me : e.getValue() ) {
-                               for( int i=0; i<=2; i++ )
-                                       if( me.isPlanRef(i) && 
_hopRefs.get(me.input(i)).getParent().size()==1 )
-                                               
_plansBlacklist.add(me.input(i));
-                       }
+               if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+                       for( Entry<Long, List<MemoTableEntry>> e : 
_plans.entrySet() )
+                               for( MemoTableEntry me : e.getValue() ) {
+                                       for( int i=0; i<=2; i++ )
+                                               if( me.isPlanRef(i) && 
_hopRefs.get(me.input(i)).getParent().size()==1 )
+                                                       
_plansBlacklist.add(me.input(i));
+                               }
+               }
                
                //core plan selection
                PlanSelection selector = SpoofCompiler.createPlanSelector();
@@ -232,6 +260,16 @@ public class CPlanMemoTable
                        .distinct().collect(Collectors.toList());
        }
        
+       /**
+        * Returns the distinct template types of all memo table entries of the
+        * given hop that have a plan reference at the given input position, in
+        * order of first occurrence.
+        * 
+        * @param hopID hop ID
+        * @param refAt input position of the plan reference
+        * @return distinct template types, or an immutable empty list if the hop
+        *         has no memo table entries
+        */
+       public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt) {
+               if( !contains(hopID) )
+                       return Collections.emptyList();
+               List<TemplateType> types = new ArrayList<>();
+               for( MemoTableEntry me : _plans.get(hopID) )
+                       if( me.isPlanRef(refAt) && !types.contains(me.type) )
+                               types.add(me.type);
+               return types;
+       }
+       
        public MemoTableEntry getBest(long hopID) {
                List<MemoTableEntry> tmp = get(hopID);
                if( tmp == null || tmp.isEmpty() )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
deleted file mode 100644
index f8a12fd..0000000
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-import org.apache.sysml.hops.rewrite.HopRewriteUtils;
-import org.apache.sysml.runtime.util.UtilFunctions;
-
-public abstract class PlanSelection 
-{
-       private final HashMap<Long, List<MemoTableEntry>> _bestPlans = 
-                       new HashMap<Long, List<MemoTableEntry>>();
-       private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
-       
-       /**
-        * Given a HOP DAG G, and a set of partial fusions plans P, find the 
set of optimal, 
-        * non-conflicting fusion plans P' that applied to G minimizes costs C 
with
-        * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a 
set of 
-        * constraints such as memory budgets and block size restrictions per 
fused operator.
-        * 
-        * @param memo partial fusion plans P
-        * @param roots entry points of HOP DAG G
-        */
-       public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> 
roots);    
-       
-       /**
-        * Determines if the given partial fusion plan is valid.
-        * 
-        * @param me memo table entry
-        * @param hop current hop
-        * @return true if entry is valid as top-level plan
-        */
-       protected static boolean isValid(MemoTableEntry me, Hop hop) {
-               return (me.type == TemplateType.OuterProdTpl 
-                               && (me.closed || 
HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)))
-                       || (me.type == TemplateType.RowTpl)     
-                       || (me.type == TemplateType.CellTpl)
-                       || (me.type == TemplateType.MultiAggTpl);
-       }
-       
-       protected void addBestPlan(long hopID, MemoTableEntry me) {
-               if( me == null ) return;
-               if( !_bestPlans.containsKey(hopID) )
-                       _bestPlans.put(hopID, new ArrayList<MemoTableEntry>());
-               _bestPlans.get(hopID).add(me);
-       }
-       
-       protected HashMap<Long, List<MemoTableEntry>> getBestPlans() {
-               return _bestPlans;
-       }
-       
-       protected boolean isVisited(long hopID, TemplateType type) {
-               return _visited.contains(new VisitMark(hopID, type));
-       }
-       
-       protected void setVisited(long hopID, TemplateType type) {
-               _visited.add(new VisitMark(hopID, type));
-       }
-       
-       /**
-        * Basic plan comparator to compare memo table entries with regard to
-        * a pre-defined template preference order and the number of references.
-        */
-       protected static class BasicPlanComparator implements 
Comparator<MemoTableEntry> {
-               @Override
-               public int compare(MemoTableEntry o1, MemoTableEntry o2) {
-                       //for different types, select preferred type
-                       if( o1.type != o2.type )
-                               return Integer.compare(o1.type.getRank(), 
o2.type.getRank());
-                       
-                       //for same type, prefer plan with more refs
-                       return Integer.compare(
-                               3-o1.countPlanRefs(), 3-o2.countPlanRefs());
-               }
-       }
-       
-       private static class VisitMark {
-               private final long _hopID;
-               private final TemplateType _type;
-               
-               public VisitMark(long hopID, TemplateType type) {
-                       _hopID = hopID;
-                       _type = type;
-               }
-               @Override
-               public int hashCode() {
-                       return UtilFunctions.longHashCode(
-                               _hopID, (_type!=null)?_type.hashCode():0);
-               }
-               @Override 
-               public boolean equals(Object o) {
-                       return (o instanceof VisitMark
-                               && _hopID == ((VisitMark)o)._hopID
-                               && _type == ((VisitMark)o)._type);
-               }
-       }
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
deleted file mode 100644
index a455302..0000000
--- 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.Map.Entry;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-
-/**
- * This plan selection heuristic aims for maximal fusion, which
- * potentially leads to overlapping fused operators and thus,
- * redundant computation but with a minimal number of materialized
- * intermediate results.
- * 
- */
-public class PlanSelectionFuseAll extends PlanSelection
-{      
-       @Override
-       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
-               //pruning and collection pass
-               for( Hop hop : roots )
-                       rSelectPlans(memo, hop, null);
-               
-               //take all distinct best plans
-               for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
-                       memo.setDistinct(e.getKey(), e.getValue());
-       }
-       
-       private void rSelectPlans(CPlanMemoTable memo, Hop current, 
TemplateType currentType) 
-       {       
-               if( isVisited(current.getHopID(), currentType) )
-                       return;
-               
-               //step 1: prune subsumed plans of same type
-               if( memo.contains(current.getHopID()) ) {
-                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
-                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
-                       for( MemoTableEntry e1 : hopP )
-                               for( MemoTableEntry e2 : hopP )
-                                       if( e1 != e2 && e1.subsumes(e2) )
-                                               rmSet.add(e2);
-                       memo.remove(current, rmSet);
-               }
-               
-               //step 2: select plan for current path
-               MemoTableEntry best = null;
-               if( memo.contains(current.getHopID()) ) {
-                       if( currentType == null ) {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> isValid(p, current))
-                                       .min(new 
BasicPlanComparator()).orElse(null);
-                       }
-                       else {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CellTpl)
-                                       .min(Comparator.comparing(p -> 
7-((p.type==currentType)?4:0)-p.countPlanRefs()))
-                                       .orElse(null);
-                       }
-                       addBestPlan(current.getHopID(), best);
-               }
-               
-               //step 3: recursively process children
-               for( int i=0; i< current.getInput().size(); i++ ) {
-                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
-                       rSelectPlans(memo, current.getInput().get(i), pref);
-               }
-               
-               setVisited(current.getHopID(), currentType);
-       }       
-}

Reply via email to