Repository: incubator-systemml
Updated Branches:
  refs/heads/master a2c56f904 -> ac04b5708


[SYSTEMML-1631] Extended codegen row template (vector-vector ops)

This patch extends the code generator row template compiler/runtime with
the ability to handle vector-vector operations. For example, consider
the nn softmax-backward function

Y1 = X - rowMaxs(X) 
Y2 = exp(Y1)
Y3 = Y2 / rowSums(Y2)
Y4 = Y3 * rowSums(Y3)
R = Y4 - Y3 * rowSums(Y4)

We are now able to fuse the entire pipeline into a single row aggregate
operator. Note that the final minus operation requires the subtraction
of two temporary vectors. With this change, the row template has much
broader applicability. Hence, we also added an extension of the
cost-based plan selector to choose cell templates over row templates, if
the existing row memo table entries of individual root nodes do not
contain any partial aggregates.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ac04b570
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ac04b570
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ac04b570

Branch: refs/heads/master
Commit: ac04b5708d5a554e0b17ba55af3c1283549e6378
Parents: a2c56f9
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri May 26 19:33:50 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri May 26 20:05:15 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  10 +-
 .../sysml/hops/codegen/cplan/CNodeBinary.java   | 171 +++++++++++------
 .../sysml/hops/codegen/cplan/CNodeTpl.java      |  15 ++
 .../hops/codegen/template/CPlanMemoTable.java   |   5 +
 .../template/PlanSelectionFuseCostBased.java    |  29 +++
 .../hops/codegen/template/TemplateRow.java      |  27 ++-
 .../hops/codegen/template/TemplateUtils.java    |  14 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  19 ++
 .../runtime/codegen/LibSpoofPrimitives.java     | 183 ++++++++++++++++++-
 .../runtime/matrix/data/LibMatrixMult.java      |   4 +-
 .../functions/codegen/RowAggTmplTest.java       |   7 +-
 .../scripts/functions/codegen/rowAggPattern15.R |   3 +-
 .../functions/codegen/rowAggPattern15.dml       |   3 +-
 13 files changed, 414 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index aa5e7e3..d87c107 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -182,7 +182,7 @@ public class SpoofCompiler
        
        public static void generateCodeFromStatementBlock(StatementBlock 
current)
                throws HopsException, DMLRuntimeException
-       {               
+       {
                if (current instanceof FunctionStatementBlock)
                {
                        FunctionStatementBlock fsb = 
(FunctionStatementBlock)current;
@@ -197,7 +197,7 @@ public class SpoofCompiler
                        wsb.setPredicateHops(optimize(wsb.getPredicateHops(), 
false));
                        for (StatementBlock sb : wstmt.getBody())
                                generateCodeFromStatementBlock(sb);
-               }       
+               }
                else if (current instanceof IfStatementBlock)
                {
                        IfStatementBlock isb = (IfStatementBlock) current;
@@ -227,7 +227,7 @@ public class SpoofCompiler
        
        public static void generateCodeFromProgramBlock(ProgramBlock current)
                throws HopsException, DMLRuntimeException, LopsException, 
IOException
-       {               
+       {
                if (current instanceof FunctionProgramBlock)
                {
                        FunctionProgramBlock fsb = 
(FunctionProgramBlock)current;
@@ -481,7 +481,7 @@ public class SpoofCompiler
        
        private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, 
boolean compileLiterals) 
                throws DMLException
-       {               
+       {
                //top-down memoization of processed dag nodes
                if( memo.contains(hop.getHopID()) || memo.containsHop(hop) )
                        return;
@@ -549,7 +549,7 @@ public class SpoofCompiler
        
        private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, 
HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, boolean compileLiterals, 
HashSet<Long> visited) 
                throws DMLException
-       {               
+       {
                //top-down memoization of processed dag nodes
                if( hop == null || visited.contains(hop.getHopID()) )
                        return;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 180d352..0c7883a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen.cplan;
 import java.util.Arrays;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
 
 
@@ -29,14 +30,20 @@ public class CNodeBinary extends CNode
 {
        public enum BinType {
                DOT_PRODUCT,
+               //vector-scalar-add operations
                VECT_MULT_ADD, VECT_DIV_ADD, VECT_MINUS_ADD, VECT_PLUS_ADD,
                VECT_POW_ADD, VECT_MIN_ADD, VECT_MAX_ADD,
                VECT_EQUAL_ADD, VECT_NOTEQUAL_ADD, VECT_LESS_ADD, 
                VECT_LESSEQUAL_ADD, VECT_GREATER_ADD, VECT_GREATEREQUAL_ADD,
+               //vector-scalar operations
                VECT_MULT_SCALAR, VECT_DIV_SCALAR, VECT_MINUS_SCALAR, 
VECT_PLUS_SCALAR,
                VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR,
                VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR, 
                VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, 
VECT_GREATEREQUAL_SCALAR,
+               //vector-vector operations
+               VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, 
VECT_EQUAL, 
+               VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, 
VECT_GREATEREQUAL,
+               //scalar-scalar operations
                MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, 
                LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL,
                MIN, MAX, AND, OR, LOG, LOG_NZ, POW,
@@ -50,9 +57,14 @@ public class CNodeBinary extends CNode
                }
                
                public boolean isCommutative() {
-                       return ( this == EQUAL || this == NOTEQUAL 
-                               || this == PLUS || this == MULT 
-                               || this == MIN || this == MAX );
+                       boolean ssComm = (this==EQUAL || this==NOTEQUAL 
+                               || this==PLUS || this==MULT || this==MIN || 
this==MAX);
+                       boolean vsComm = (this==VECT_EQUAL_SCALAR || 
this==VECT_NOTEQUAL_SCALAR 
+                                       || this==VECT_PLUS_SCALAR || 
this==VECT_MULT_SCALAR 
+                                       || this==VECT_MIN_SCALAR || 
this==VECT_MAX_SCALAR);
+                       boolean vvComm = (this==VECT_EQUAL || 
this==VECT_NOTEQUAL 
+                                       || this==VECT_PLUS || this==VECT_MULT 
|| this==VECT_MIN || this==VECT_MAX);
+                       return ssComm || vsComm || vvComm;
                }
                
                public String getTemplate(boolean sparse) {
@@ -61,6 +73,7 @@ public class CNodeBinary extends CNode
                                        return sparse ? "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" 
:
                                                                        "    
double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
                                
+                               //vector-scalar-add operations
                                case VECT_MULT_ADD:
                                case VECT_DIV_ADD:
                                case VECT_MINUS_ADD:
@@ -79,6 +92,7 @@ public class CNodeBinary extends CNode
                                                                        "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, 
%LEN%);\n";
                                }
                                
+                               //vector-scalar operations
                                case VECT_MULT_SCALAR:
                                case VECT_DIV_SCALAR:
                                case VECT_MINUS_SCALAR:
@@ -97,7 +111,26 @@ public class CNodeBinary extends CNode
                                                                        "    
double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, 
%LEN%);\n";
                                }
                                
-                               /*Can be replaced by function objects*/
+                               //vector-vector operations
+                               case VECT_MULT:
+                               case VECT_DIV:
+                               case VECT_MINUS:
+                               case VECT_PLUS:
+                               case VECT_MIN:
+                               case VECT_MAX:  
+                               case VECT_EQUAL:
+                               case VECT_NOTEQUAL:
+                               case VECT_LESS:
+                               case VECT_LESSEQUAL:
+                               case VECT_GREATER:
+                               case VECT_GREATEREQUAL: {
+                                       String vectName = 
getVectorPrimitiveName();
+                                       return sparse ? 
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, 
%LEN%);\n" : 
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
+                               }
+                               
+                               //scalar-scalar operations
                                case MULT:
                                        return "    double %TMP% = %IN1% * 
%IN2%;\n";
                                
@@ -152,6 +185,14 @@ public class CNodeBinary extends CNode
                                || this == VECT_LESS_SCALAR || this == 
VECT_LESSEQUAL_SCALAR
                                || this == VECT_GREATER_SCALAR || this == 
VECT_GREATEREQUAL_SCALAR;
                }
+               public boolean isVectorVectorPrimitive() {
+                       return this == VECT_DIV || this == VECT_MULT 
+                               || this == VECT_MINUS || this == VECT_PLUS
+                               || this == VECT_MIN || this == VECT_MAX
+                               || this == VECT_EQUAL || this == VECT_NOTEQUAL
+                               || this == VECT_LESS || this == VECT_LESSEQUAL
+                               || this == VECT_GREATER || this == 
VECT_GREATEREQUAL;
+               }
                public BinType getVectorAddPrimitive() {
                        return 
BinType.valueOf("VECT_"+getVectorPrimitiveName().toUpperCase()+"_ADD");
                }
@@ -187,7 +228,7 @@ public class CNodeBinary extends CNode
        public String codegen(boolean sparse) {
                if( _generated )
                        return "";
-                       
+               
                StringBuilder sb = new StringBuilder();
                
                //generate children
@@ -195,7 +236,8 @@ public class CNodeBinary extends CNode
                sb.append(_inputs.get(1).codegen(sparse));
                
                //generate binary operation (use sparse template, if data input)
-               boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData);
+               boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData 
+                       && !_inputs.get(0).getVarname().startsWith("b"));
                String var = createVarname();
                String tmp = _type.getTemplate(lsparse);
                tmp = tmp.replaceAll("%TMP%", var);
@@ -210,9 +252,9 @@ public class CNodeBinary extends CNode
                        tmp = tmp.replaceAll("%IN"+j+"%", varj );
                        
                        //replace start position of main input
-                       tmp = tmp.replaceAll("%POS"+j+"%", 
(!varj.startsWith("b") 
-                               && _inputs.get(j-1) instanceof CNodeData 
-                               && _inputs.get(j-1).getDataType().isMatrix()) ? 
varj+"i" : "0");
+                       tmp = tmp.replaceAll("%POS"+j+"%", (_inputs.get(j-1) 
instanceof CNodeData 
+                               && _inputs.get(j-1).getDataType().isMatrix()) ? 
(!varj.startsWith("b")) ? 
+                               varj+"i" : 
TemplateUtils.isMatrix(_inputs.get(j-1)) ? "rowIndex*len" : "0" : "0");
                }
                sb.append(tmp);
                
@@ -225,50 +267,62 @@ public class CNodeBinary extends CNode
        @Override
        public String toString() {
                switch(_type) {
-                       case DOT_PRODUCT: return "b(dot)";
-                       case VECT_MULT_ADD: return "b(vma)";
-                       case VECT_DIV_ADD: return "b(vda)";
-                       case VECT_MINUS_ADD: return "b(vmia)";
-                       case VECT_PLUS_ADD: return "b(vpa)";
-                       case VECT_POW_ADD: return "b(vpowa)";
-                       case VECT_MIN_ADD: return "b(vmina)";
-                       case VECT_MAX_ADD: return "b(vmaxa)";
-                       case VECT_EQUAL_ADD: return "b(veqa)";
-                       case VECT_NOTEQUAL_ADD: return "b(vneqa)";
-                       case VECT_LESS_ADD: return "b(vlta)";
-                       case VECT_LESSEQUAL_ADD: return "b(vltea)";
-                       case VECT_GREATEREQUAL_ADD: return "b(vgtea)";
-                       case VECT_GREATER_ADD: return "b(vgta)";
-                       case VECT_MULT_SCALAR:  return "b(vm)";
-                       case VECT_DIV_SCALAR:  return "b(vd)";
-                       case VECT_MINUS_SCALAR:  return "b(vmi)";
-                       case VECT_PLUS_SCALAR: return "b(vp)";
-                       case VECT_POW_SCALAR: return "b(vpow)";
-                       case VECT_MIN_SCALAR: return "b(vmin)";
-                       case VECT_MAX_SCALAR: return "b(vmax)";
-                       case VECT_EQUAL_SCALAR: return "b(veq)";
-                       case VECT_NOTEQUAL_SCALAR: return "b(vneq)";
-                       case VECT_LESS_SCALAR: return "b(vlt)";
-                       case VECT_LESSEQUAL_SCALAR: return "b(vlte)";
+                       case DOT_PRODUCT:              return "b(dot)";
+                       case VECT_MULT_ADD:            return "b(vma)";
+                       case VECT_DIV_ADD:             return "b(vda)";
+                       case VECT_MINUS_ADD:           return "b(vmia)";
+                       case VECT_PLUS_ADD:            return "b(vpa)";
+                       case VECT_POW_ADD:             return "b(vpowa)";
+                       case VECT_MIN_ADD:             return "b(vmina)";
+                       case VECT_MAX_ADD:             return "b(vmaxa)";
+                       case VECT_EQUAL_ADD:           return "b(veqa)";
+                       case VECT_NOTEQUAL_ADD:        return "b(vneqa)";
+                       case VECT_LESS_ADD:            return "b(vlta)";
+                       case VECT_LESSEQUAL_ADD:       return "b(vltea)";
+                       case VECT_GREATEREQUAL_ADD:    return "b(vgtea)";
+                       case VECT_GREATER_ADD:         return "b(vgta)";
+                       case VECT_MULT_SCALAR:         return "b(vm)";
+                       case VECT_DIV_SCALAR:          return "b(vd)";
+                       case VECT_MINUS_SCALAR:        return "b(vmi)";
+                       case VECT_PLUS_SCALAR:         return "b(vp)";
+                       case VECT_POW_SCALAR:          return "b(vpow)";
+                       case VECT_MIN_SCALAR:          return "b(vmin)";
+                       case VECT_MAX_SCALAR:          return "b(vmax)";
+                       case VECT_EQUAL_SCALAR:        return "b(veq)";
+                       case VECT_NOTEQUAL_SCALAR:     return "b(vneq)";
+                       case VECT_LESS_SCALAR:         return "b(vlt)";
+                       case VECT_LESSEQUAL_SCALAR:    return "b(vlte)";
                        case VECT_GREATEREQUAL_SCALAR: return "b(vgte)";
-                       case VECT_GREATER_SCALAR: return "b(vgt)";
-                       case MULT: return "b(*)";
-                       case DIV: return "b(/)";
-                       case PLUS: return "b(+)";
-                       case MINUS: return "b(-)";
-                       case POW: return "b(^)";
-                       case MODULUS: return "b(%%)";
-                       case INTDIV: return "b(%/%)";
-                       case LESS: return "b(<)";
-                       case LESSEQUAL: return "b(<=)";
-                       case GREATER: return "b(>)";
-                       case GREATEREQUAL: return "b(>=)";
-                       case EQUAL: return "b(==)";
-                       case NOTEQUAL: return "b(!=)";
-                       case OR: return "b(|)";
-                       case AND: return "b(&)";
-                       case MINUS1_MULT: return "b(1-*)";
-                       case MINUS_NZ: return "b(-nz)";
+                       case VECT_GREATER_SCALAR:      return "b(vgt)";
+                       case VECT_MULT:                return "b(v2m)";
+                       case VECT_DIV:                 return "b(v2d)";
+                       case VECT_MINUS:               return "b(v2mi)";
+                       case VECT_PLUS:                return "b(v2p)";
+                       case VECT_MIN:                 return "b(v2min)";
+                       case VECT_MAX:                 return "b(v2max)";
+                       case VECT_EQUAL:               return "b(v2eq)";
+                       case VECT_NOTEQUAL:            return "b(v2neq)";
+                       case VECT_LESS:                return "b(v2lt)";
+                       case VECT_LESSEQUAL:           return "b(v2lte)";
+                       case VECT_GREATEREQUAL:        return "b(v2gte)";
+                       case VECT_GREATER:             return "b(v2gt)";
+                       case MULT:                     return "b(*)";
+                       case DIV:                      return "b(/)";
+                       case PLUS:                     return "b(+)";
+                       case MINUS:                    return "b(-)";
+                       case POW:                      return "b(^)";
+                       case MODULUS:                  return "b(%%)";
+                       case INTDIV:                   return "b(%/%)";
+                       case LESS:                     return "b(<)";
+                       case LESSEQUAL:                return "b(<=)";
+                       case GREATER:                  return "b(>)";
+                       case GREATEREQUAL:             return "b(>=)";
+                       case EQUAL:                    return "b(==)";
+                       case NOTEQUAL:                 return "b(!=)";
+                       case OR:                       return "b(|)";
+                       case AND:                      return "b(&)";
+                       case MINUS1_MULT:              return "b(1-*)";
+                       case MINUS_NZ:                 return "b(-nz)";
                        default: return "b("+_type.name().toLowerCase()+")";
                }
        }
@@ -309,6 +363,19 @@ public class CNodeBinary extends CNode
                        case VECT_LESSEQUAL_SCALAR: 
                        case VECT_GREATER_SCALAR: 
                        case VECT_GREATEREQUAL_SCALAR:
+                       
+                       case VECT_DIV:  
+                       case VECT_MULT:
+                       case VECT_MINUS:
+                       case VECT_PLUS:
+                       case VECT_MIN:
+                       case VECT_MAX:
+                       case VECT_EQUAL: 
+                       case VECT_NOTEQUAL: 
+                       case VECT_LESS: 
+                       case VECT_LESSEQUAL: 
+                       case VECT_GREATER: 
+                       case VECT_GREATEREQUAL: 
                                _rows = _inputs.get(0)._rows;
                                _cols = _inputs.get(0)._cols;
                                _dataType= DataType.MATRIX;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
index 673ab10..ca474d2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
@@ -206,6 +206,21 @@ public abstract class CNodeTpl extends CNode implements 
Cloneable
                }
        }
        
+       public void rReorderCommutativeBinaryOps(CNode node, long mainHopID) {
+               if( isVisited() )
+                       return;
+               for( CNode c : node.getInput() )
+                       rReorderCommutativeBinaryOps(c, mainHopID);
+               if( node instanceof CNodeBinary && node.getInput().get(1) 
instanceof CNodeData
+                       && ((CNodeData)node.getInput().get(1)).getHopID() == 
mainHopID
+                       && ((CNodeBinary)node).getType().isCommutative() ) {
+                       CNode tmp = node.getInput().get(0);
+                       node.getInput().set(0, node.getInput().get(1));
+                       node.getInput().set(1, tmp);
+               }
+               setVisited();
+       }
+       
        /**
         * Checks for duplicates (object ref or varname).
         * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 75d3475..375f695 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -187,6 +187,11 @@ public class CPlanMemoTable
                return _plans.get(hopID);
        }
        
+       public List<MemoTableEntry> get(long hopID, TemplateType type) {
+               return _plans.get(hopID).stream()
+                       .filter(p -> p.type==type).collect(Collectors.toList());
+       }
+       
        public List<MemoTableEntry> getDistinct(long hopID) {
                //return distinct entries wrt type and closed attributes
                return _plans.get(hopID).stream()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index ed6a080..5769fcb 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -443,6 +443,20 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
        
        private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, 
HashSet<Long> R, ArrayList<Long> M) 
        {
+               //prune row aggregates with pure cellwise operations
+               for( Long hopID : R ) {
+                       MemoTableEntry me = memo.getBest(hopID, 
TemplateType.RowTpl);
+                       if( me.type == TemplateType.RowTpl && 
memo.contains(hopID, TemplateType.CellTpl)
+                               && rIsRowTemplateWithoutAgg(memo, 
memo._hopRefs.get(hopID), new HashSet<Long>())) {
+                               List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.RowTpl); 
+                               memo.remove(memo._hopRefs.get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
+                               if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed row memo table 
entries w/o aggregation: "
+                                               + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                               }
+                       }
+               }
+               
                //if no materialization points, use basic fuse-all w/ partition 
awareness
                if( M == null || M.isEmpty() ) {
                        for( Long hopID : R )
@@ -497,6 +511,21 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                }
        }
        
+       private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, 
Hop current, HashSet<Long> visited) {
+               if( visited.contains(current.getHopID()) )
+                       return true;
+               
+               boolean ret = true;
+               MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.RowTpl);
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, 
current.getInput().get(i), visited);
+               ret &= !(current instanceof AggUnaryOp || current instanceof 
AggBinaryOp);
+               
+               visited.add(current.getHopID());
+               return ret;
+       }
+       
        private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, 
boolean[] plan) {
                //memoization (not via hops because in middle of dag)
                if( visited.contains(current.getHopID()) )

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 3979aae..ae25ded 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -87,12 +87,13 @@ public class TemplateRow extends TemplateBase
                return !isClosed() && 
                        (  (hop instanceof BinaryOp && 
TemplateUtils.isOperationSupported(hop) 
                                && 
(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
-                                       || 
HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) 
+                                       || 
HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
+                                       || 
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) ) 
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                        && TemplateCell.isValidOperation(hop))  
        
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
                                && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
-                       || (hop instanceof AggBinaryOp && hop.getDim1()>1 
+                       || (hop instanceof AggBinaryOp && hop.getDim1()>1 && 
hop.getDim2()==1
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
        }
 
@@ -100,8 +101,9 @@ public class TemplateRow extends TemplateBase
        public boolean merge(Hop hop, Hop input) {
                //merge rowagg tpl with cell tpl if input is a vector
                return !isClosed() &&
-                       ((hop instanceof BinaryOp && input.getDim2()==1 
//matrix-scalar/vector-vector ops )
-                               && TemplateUtils.isOperationSupported(hop))
+                       ((hop instanceof BinaryOp && 
TemplateUtils.isOperationSupported(hop)
+                               && (input.getDim2()==1 
//matrix-scalar/vector-vector ops )
+                                       || 
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)))
                         ||(hop instanceof AggBinaryOp && input.getDim2()==1
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
        }
@@ -140,6 +142,8 @@ public class TemplateRow extends TemplateBase
                tpl.setRowType(TemplateUtils.getRowType(hop, sinHops.get(0)));
                tpl.setNumVectorIntermediates(TemplateUtils
                        .countVectorIntermediates(output, new HashSet<Long>()));
+               tpl.getOutput().resetVisitStatus();
+               tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), 
sinHops.get(0).getHopID());
                
                // return cplan instance
                return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), 
tpl);
@@ -199,6 +203,7 @@ public class TemplateRow extends TemplateBase
                                inHops.remove(hop.getInput().get(0)); 
                                
inHops.add(hop.getInput().get(0).getInput().get(0));
                                
+                               //note: vectorMultAdd applicable to 
vector-scalar, and vector-vector
                                out = new CNodeBinary(cdata1, cdata2, 
BinType.VECT_MULT_ADD);
                                inHops2.put("X", 
hop.getInput().get(0).getInput().get(0));
                        }
@@ -245,10 +250,16 @@ public class TemplateRow extends TemplateBase
                        if(hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1 )
                        {
                                if( HopRewriteUtils.isBinary(hop, 
SUPPORTED_VECT_BINARY) ) {
-                                       String opname = 
"VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR";
-                                       if( TemplateUtils.isColVector(cdata2) )
-                                               cdata2 = new CNodeUnary(cdata2, 
UnaryType.LOOKUP_R);
-                                       out = new CNodeBinary(cdata1, cdata2, 
BinType.valueOf(opname));
+                                       if( TemplateUtils.isMatrix(cdata1) && 
TemplateUtils.isMatrix(cdata2) ) {
+                                               String opname = 
"VECT_"+((BinaryOp)hop).getOp().name();
+                                               out = new CNodeBinary(cdata1, 
cdata2, BinType.valueOf(opname));
+                                       }
+                                       else {
+                                               String opname = 
"VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR";
+                                               if( 
TemplateUtils.isColVector(cdata2) )
+                                                       cdata2 = new 
CNodeUnary(cdata2, UnaryType.LOOKUP_R);
+                                               out = new CNodeBinary(cdata1, 
cdata2, BinType.valueOf(opname));
+                                       }
                                }
                                else 
                                        throw new RuntimeException("Unsupported 
binary matrix "

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 89211db..2059e17 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -74,6 +74,11 @@ public class TemplateUtils
                        && hop.getNumRows() == 1 && hop.getNumCols() != 1);
        }
        
+       public static boolean isMatrix(CNode hop) {
+               //true only for matrix-typed nodes where neither dimension equals 1
+               if( hop.getDataType() != DataType.MATRIX )
+                       return false;
+               return hop.getNumRows() != 1 && hop.getNumCols() != 1;
+       }
+       
        public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
                CNode ret = node;
                if( isColVector(node) )
@@ -343,10 +348,11 @@ public class TemplateUtils
                for( CNode c : node.getInput() )
                        ret += countVectorIntermediates(c, memo);
                //compute vector requirements of current node
-               int cntBin = ((node instanceof CNodeBinary 
-                       && 
((CNodeBinary)node).getType().isVectorScalarPrimitive()) ? 1 : 0);
-               int cntUn = ((node instanceof CNodeUnary
-                               && 
((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0);
+               int cntBin = (node instanceof CNodeBinary 
+                       && 
(((CNodeBinary)node).getType().isVectorScalarPrimitive() 
+                       || 
((CNodeBinary)node).getType().isVectorVectorPrimitive())) ? 1 : 0;
+               int cntUn = (node instanceof CNodeUnary
+                               && 
((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0;
                return ret + cntBin + cntUn;
        }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index f7be8b9..b406bb7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.sysml.api.DMLScript;
@@ -794,6 +795,24 @@ public class HopRewriteUtils
                        && hop.getInput().get(1).dimsKnown() && 
hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1;
        }
        
+       public static boolean isBinaryMatrixMatrixOperationWithSharedInput(Hop hop) {
+               //only binary matrix-matrix operations qualify
+               if( !isBinaryMatrixMatrixOperation(hop) )
+                       return false;
+               //shared input: one operand is reachable from (or identical to) the other
+               Hop left = hop.getInput().get(0);
+               Hop right = hop.getInput().get(1);
+               return rContainsInput(left, right, new HashSet<Long>())
+                       || rContainsInput(right, left, new HashSet<Long>());
+       }
+       
+       private static boolean rContainsInput(Hop current, Hop probe, HashSet<Long> memo) {
+               //recursive search for probe (by reference) in the DAG rooted at current;
+               //memo holds hop IDs of nodes already explored without finding probe
+               //found the probe itself: no need to descend into its inputs
+               if( current == probe )
+                       return true;
+               if( memo.contains(current.getHopID()) )
+                       return false;
+               for( Hop in : current.getInput() )
+                       if( rContainsInput(in, probe, memo) )
+                               return true;
+               //mark as explored only after all inputs have been searched
+               memo.add(current.getHopID());
+               return false;
+       }
+       
        public static boolean isBinaryMatrixColVectorOperation(Hop hop) {
                return hop instanceof BinaryOp 
                        && hop.getInput().get(0).getDataType().isMatrix() && 
hop.getInput().get(1).getDataType().isMatrix()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index 7a9adeb..1b4369f 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -63,18 +63,36 @@ public class LibSpoofPrimitives
                LibMatrixMult.vectMultiplyAdd(bval, a, c, bix, bi, ci, len);
        }
        
+       public static void vectMultAdd(double[] a, double[] b, double[] c, int bi, int ci, int len) {
+               //vector-vector variant: multiply a (read from offset 0) with b (read from bi)
+               //into a temporary, then hand the temporary to vectAdd with target c at ci
+               double[] prod = vectMultWrite(a, b, 0, bi, len);
+               LibMatrixMult.vectAdd(prod, c, 0, ci, len);
+       }
+       
        public static double[] vectMultWrite(double[] a, double bval, int bi, 
int len) {
                double[] c = allocVector(len, false);
                LibMatrixMult.vectMultiplyWrite(bval, a, c, bi, 0, len);
                return c;
        }
        
+       public static double[] vectMultWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //dense vector-vector multiply, delegated to LibMatrixMult with output offset 0
+               double[] out = allocVector(len, false);
+               LibMatrixMult.vectMultiplyWrite(a, b, out, ai, bi, 0, len);
+               return out;
+       }
+       
        public static double[] vectMultWrite(double[] a, double bval, int[] 
bix, int bi, int len) {
                double[] c = allocVector(len, true);
                LibMatrixMult.vectMultiplyAdd(bval, a, c, bix, bi, 0, len);
                return c;
        }
        
+       public static double[] vectMultWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters a[j] * b[bi+aix[j]] to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = a[j] * b[bi+aix[j]];
+               return c;
+       }
+       
        public static void vectWrite(double[] a, double[] c, int ci, int len) {
                System.arraycopy(a, 0, c, ci, len);
        }
@@ -168,6 +186,13 @@ public class LibSpoofPrimitives
                        c[j] = a[ai] / bval;
                return c;
        }
+       
+       public static double[] vectDivWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise a[ai+k] / b[bi+k] for k in [0,len), written to the
+               //vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = a[ai+k] / b[bi+k];
+               return out;
+       }
 
        public static double[] vectDivWrite(double[] a, double bval, int[] aix, 
int ai, int len) {
                double[] c = allocVector(len, true);
@@ -176,6 +201,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectDivWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters a[j] / b[bi+aix[j]] to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = a[j] / b[bi+aix[j]];
+               return c;
+       }
+       
        //custom vector minus
        
        public static void vectMinusAdd(double[] a, double bval, double[] c, 
int ai, int ci, int len) {
@@ -195,6 +227,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectMinusWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise a[ai+k] - b[bi+k] for k in [0,len), written to the
+               //vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = a[ai+k] - b[bi+k];
+               return out;
+       }
+       
        public static double[] vectMinusWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
                for( int j = ai; j < ai+len; j++ )
@@ -202,6 +241,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectMinusWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters a[j] - b[bi+aix[j]] to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = a[j] - b[bi+aix[j]];
+               return c;
+       }
+       
        //custom vector plus
        
        public static void vectPlusAdd(double[] a, double bval, double[] c, int 
ai, int ci, int len) {
@@ -220,6 +266,13 @@ public class LibSpoofPrimitives
                        c[j] = a[ai] + bval;
                return c;
        }
+       
+       public static double[] vectPlusWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise a[ai..ai+len) + b[bi..bi+len), written to the vector
+               //obtained from allocVector
+               double[] c = allocVector(len, false);
+               //both offsets advance together; without bi++ every output cell
+               //would add the same element b[bi]
+               for( int j = 0; j < len; j++, ai++, bi++)
+                       c[j] = a[ai] + b[bi];
+               return c;
+       }
 
        public static double[] vectPlusWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
@@ -228,6 +281,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectPlusWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters a[j] + b[bi+aix[j]] to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = a[j] + b[bi+aix[j]];
+               return c;
+       }
+       
        //custom vector pow
        
        public static void vectPowAdd(double[] a, double bval, double[] c, int 
ai, int ci, int len) {
@@ -246,6 +306,13 @@ public class LibSpoofPrimitives
                        c[j] = Math.pow(a[ai], bval);
                return c;
        }
+       
+       public static double[] vectPowWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise Math.pow(a[ai+k], b[bi+k]) for k in [0,len), written to
+               //the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = Math.pow(a[ai+k], b[bi+k]);
+               return out;
+       }
 
        public static double[] vectPowWrite(double[] a, double bval, int[] aix, 
int ai, int len) {
                double[] c = allocVector(len, true);
@@ -272,6 +339,13 @@ public class LibSpoofPrimitives
                        c[j] = Math.min(a[ai], bval);
                return c;
        }
+       
+       public static double[] vectMinWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise Math.min(a[ai+k], b[bi+k]) for k in [0,len), written to
+               //the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = Math.min(a[ai+k], b[bi+k]);
+               return out;
+       }
 
        public static double[] vectMinWrite(double[] a, double bval, int[] aix, 
int ai, int len) {
                double[] c = allocVector(len, true);
@@ -280,6 +354,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectMinWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters Math.min(a[j], b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = Math.min(a[j], b[bi+aix[j]]);
+               return c;
+       }
+       
        //custom vector max
        
        public static void vectMaxAdd(double[] a, double bval, double[] c, int 
ai, int ci, int len) {
@@ -298,6 +379,13 @@ public class LibSpoofPrimitives
                        c[j] = Math.max(a[ai], bval);
                return c;
        }
+       
+       public static double[] vectMaxWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise Math.max(a[ai+k], b[bi+k]) for k in [0,len), written to
+               //the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = Math.max(a[ai+k], b[bi+k]);
+               return out;
+       }
 
        public static double[] vectMaxWrite(double[] a, double bval, int[] aix, 
int ai, int len) {
                double[] c = allocVector(len, true);
@@ -305,6 +393,13 @@ public class LibSpoofPrimitives
                        c[aix[j]] = Math.max(a[j], bval);
                return c;
        }
+       
+       public static double[] vectMaxWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters Math.max(a[j], b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = Math.max(a[j], b[bi+aix[j]]);
+               return c;
+       }
 
        //custom exp
        
@@ -584,13 +679,27 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] == bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectEqualWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] == b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] == b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectEqualWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
                for( int j = ai; j < ai+len; j++ )
                        c[aix[j]] = (a[j] == bval) ? 1 : 0;
                return c;
-       }       
+       }
+       
+       public static double[] vectEqualWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] == b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] == b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
        
        //custom vector not equal
        
@@ -610,6 +719,13 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] != bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectNotequalWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] != b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] != b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectNotequalWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
@@ -618,6 +734,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectNotequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] != b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] != b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
+       
        //custom vector less
        
        public static void vectLessAdd(double[] a, double bval, double[] c, int 
ai, int ci, int len) {
@@ -636,6 +759,13 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] < bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectLessWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] < b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] < b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectLessWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
@@ -644,6 +774,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectLessWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] < b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] < b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
+       
        //custom vector less equal
        
        public static void vectLessequalAdd(double[] a, double bval, double[] 
c, int ai, int ci, int len) {
@@ -662,6 +799,13 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] <= bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectLessequalWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] <= b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] <= b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectLessequalWrite(double[] a, double bval, 
int[] aix, int ai, int len) {
                double[] c = allocVector(len, true);
@@ -669,6 +813,13 @@ public class LibSpoofPrimitives
                        c[aix[j]] = (a[j] <= bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectLessequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] <= b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] <= b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
 
        //custom vector greater
        
@@ -688,13 +839,27 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] > bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectGreaterWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] > b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] > b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectGreaterWrite(double[] a, double bval, int[] 
aix, int ai, int len) {
                double[] c = allocVector(len, true);
                for( int j = ai; j < ai+len; j++ )
                        c[aix[j]] = (a[j] > bval) ? 1 : 0;
                return c;
-       }       
+       }
+       
+       public static double[] vectGreaterWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] > b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] > b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
        
        //custom vector greater equal
        
@@ -714,6 +879,13 @@ public class LibSpoofPrimitives
                        c[j] = (a[ai] >= bval) ? 1 : 0;
                return c;
        }
+       
+       public static double[] vectGreaterequalWrite(double[] a, double[] b, int ai, int bi, int len) {
+               //element-wise indicator (1 if a[ai+k] >= b[bi+k], else 0) for k in [0,len),
+               //written to the vector obtained from allocVector
+               double[] out = allocVector(len, false);
+               for( int k = 0; k < len; k++ )
+                       out[k] = (a[ai+k] >= b[bi+k]) ? 1 : 0;
+               return out;
+       }
 
        public static double[] vectGreaterequalWrite(double[] a, double bval, 
int[] aix, int ai, int len) {
                double[] c = allocVector(len, true);
@@ -722,6 +894,13 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       public static double[] vectGreaterequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+               //scatters the indicator (a[j] >= b[bi+aix[j]]) to position aix[j] for j in [ai,ai+len);
+               //positions not listed in aix are never written here, so pass true as the
+               //second allocVector argument like the scalar sparse variants do
+               //(presumably the reset flag - confirm)
+               //NOTE(review): len bounds both the allocation and the entries read - confirm
+               double[] c = allocVector(len, true);
+               for( int j = ai; j < ai+len; j++ )
+                       c[aix[j]] = (a[j] >= b[bi+aix[j]]) ? 1 : 0;
+               return c;
+       }
+       
        //complex builtin functions that are not directly generated
        //(included here in order to reduce the number of imports)
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 983ce53..f7d2d54 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -3146,8 +3146,8 @@ public class LibMatrixMult
                }
        }
 
-       @SuppressWarnings("unused")
-       private static void vectMultiplyWrite( double[] a, double[] b, double[] 
c, int ai, int bi, int ci, final int len )
+       //note: public for use by codegen for consistency
+       public static void vectMultiplyWrite( double[] a, double[] b, double[] 
c, int ai, int bi, int ci, final int len )
        {
                final int bn = len%8;
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index b7f82a7..362f1dc 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -50,7 +50,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME12 = TEST_NAME+"12"; //Y=(X>=v); 
R=Y/rowSums(Y)
        private static final String TEST_NAME13 = TEST_NAME+"13"; 
//rowSums(X)+rowSums(Y)
        private static final String TEST_NAME14 = TEST_NAME+"14"; 
//colSums(max(floor(round(abs(min(sign(X+Y),1)))),7))
-       private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn 
- softmax backward (partially)
+       private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn 
- softmax backward
        private static final String TEST_NAME16 = TEST_NAME+"16"; 
//Y=X-rowIndexMax(X); R=Y/rowSums(Y)
        
        private static final String TEST_DIR = "functions/codegen/";
@@ -344,6 +344,11 @@ public class RowAggTmplTest extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                        
Assert.assertTrue(heavyHittersContainsSubString("spoofRA") 
                                        || 
heavyHittersContainsSubString("sp_spoofRA"));
+                       
+                       //ensure full aggregates for certain patterns
+                       if( testname.equals(TEST_NAME15) )
+                               
Assert.assertTrue(!heavyHittersContainsSubString("uark+"));
+                               
                }
                finally {
                        rtplatform = platformOld;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/scripts/functions/codegen/rowAggPattern15.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern15.R 
b/src/test/scripts/functions/codegen/rowAggPattern15.R
index a24679a..3151f0c 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern15.R
+++ b/src/test/scripts/functions/codegen/rowAggPattern15.R
@@ -30,6 +30,7 @@ X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
 Y1 = X - rowMaxs(X) 
 Y2 = exp(Y1)
 Y3 = Y2 / rowSums(Y2)
-R = Y3 * rowSums(Y3)
+Y4 = Y3 * rowSums(Y3)
+R = Y4 - Y3 * rowSums(Y4)
 
 writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/scripts/functions/codegen/rowAggPattern15.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern15.dml 
b/src/test/scripts/functions/codegen/rowAggPattern15.dml
index d51397a..0dd98a5 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern15.dml
+++ b/src/test/scripts/functions/codegen/rowAggPattern15.dml
@@ -24,6 +24,7 @@ X = matrix(seq(1,1500), rows=150, cols=10);
 Y1 = X - rowMaxs(X) 
 Y2 = exp(Y1)
 Y3 = Y2 / rowSums(Y2)
-R = Y3 * rowSums(Y3)
+Y4 = Y3 * rowSums(Y3)
+R = Y4 - Y3 * rowSums(Y4)
 
 write(R, $1)


Reply via email to