Repository: incubator-systemml Updated Branches: refs/heads/master a2c56f904 -> ac04b5708
[SYSTEMML-1631] Extended codegen row template (vector-vector ops)

This patch extends the code generator row template compiler/runtime with the ability to handle vector-vector operations. For example, consider the nn softmax-backward function:

  Y1 = X - rowMaxs(X)
  Y2 = exp(Y1)
  Y3 = Y2 / rowSums(Y2)
  Y4 = Y3 * rowSums(Y3)
  R = Y4 - Y3 * rowSums(Y4)

We are now able to fuse the entire pipeline into a single row aggregate operator. Note that the final minus operation requires the subtraction of two temporary vectors. With this change, the row template has much broader applicability. Hence, we also extended the cost-based plan selector to choose cell templates over row templates if the existing row memo table entries of individual root nodes do not contain any partial aggregates.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ac04b570
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ac04b570
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ac04b570

Branch: refs/heads/master
Commit: ac04b5708d5a554e0b17ba55af3c1283549e6378
Parents: a2c56f9
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri May 26 19:33:50 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri May 26 20:05:15 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java | 10 +-
 .../sysml/hops/codegen/cplan/CNodeBinary.java | 171 +++++++++++------
 .../sysml/hops/codegen/cplan/CNodeTpl.java | 15 ++
 .../hops/codegen/template/CPlanMemoTable.java | 5 +
 .../template/PlanSelectionFuseCostBased.java | 29 +++
 .../hops/codegen/template/TemplateRow.java | 27 ++-
 .../hops/codegen/template/TemplateUtils.java | 14 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java | 19 ++
 .../runtime/codegen/LibSpoofPrimitives.java | 183 ++++++++++++++++++-
 .../runtime/matrix/data/LibMatrixMult.java | 4 +-
.../functions/codegen/RowAggTmplTest.java | 7 +- .../scripts/functions/codegen/rowAggPattern15.R | 3 +- .../functions/codegen/rowAggPattern15.dml | 3 +- 13 files changed, 414 insertions(+), 76 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index aa5e7e3..d87c107 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -182,7 +182,7 @@ public class SpoofCompiler public static void generateCodeFromStatementBlock(StatementBlock current) throws HopsException, DMLRuntimeException - { + { if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)current; @@ -197,7 +197,7 @@ public class SpoofCompiler wsb.setPredicateHops(optimize(wsb.getPredicateHops(), false)); for (StatementBlock sb : wstmt.getBody()) generateCodeFromStatementBlock(sb); - } + } else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; @@ -227,7 +227,7 @@ public class SpoofCompiler public static void generateCodeFromProgramBlock(ProgramBlock current) throws HopsException, DMLRuntimeException, LopsException, IOException - { + { if (current instanceof FunctionProgramBlock) { FunctionProgramBlock fsb = (FunctionProgramBlock)current; @@ -481,7 +481,7 @@ public class SpoofCompiler private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, boolean compileLiterals) throws DMLException - { + { //top-down memoization of processed dag nodes if( memo.contains(hop.getHopID()) || memo.containsHop(hop) ) return; @@ -549,7 +549,7 @@ public class SpoofCompiler 
private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, boolean compileLiterals, HashSet<Long> visited) throws DMLException - { + { //top-down memoization of processed dag nodes if( hop == null || visited.contains(hop.getHopID()) ) return; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 180d352..0c7883a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.Arrays; import org.apache.commons.lang.StringUtils; +import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; @@ -29,14 +30,20 @@ public class CNodeBinary extends CNode { public enum BinType { DOT_PRODUCT, + //vector-scalar-add operations VECT_MULT_ADD, VECT_DIV_ADD, VECT_MINUS_ADD, VECT_PLUS_ADD, VECT_POW_ADD, VECT_MIN_ADD, VECT_MAX_ADD, VECT_EQUAL_ADD, VECT_NOTEQUAL_ADD, VECT_LESS_ADD, VECT_LESSEQUAL_ADD, VECT_GREATER_ADD, VECT_GREATEREQUAL_ADD, + //vector-scalar operations VECT_MULT_SCALAR, VECT_DIV_SCALAR, VECT_MINUS_SCALAR, VECT_PLUS_SCALAR, VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR, VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR, VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR, + //vector-vector operations + VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, VECT_EQUAL, + VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, VECT_GREATEREQUAL, + //scalar-scalar operations MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, LESS, LESSEQUAL, GREATER, GREATEREQUAL, 
EQUAL,NOTEQUAL, MIN, MAX, AND, OR, LOG, LOG_NZ, POW, @@ -50,9 +57,14 @@ public class CNodeBinary extends CNode } public boolean isCommutative() { - return ( this == EQUAL || this == NOTEQUAL - || this == PLUS || this == MULT - || this == MIN || this == MAX ); + boolean ssComm = (this==EQUAL || this==NOTEQUAL + || this==PLUS || this==MULT || this==MIN || this==MAX); + boolean vsComm = (this==VECT_EQUAL_SCALAR || this==VECT_NOTEQUAL_SCALAR + || this==VECT_PLUS_SCALAR || this==VECT_MULT_SCALAR + || this==VECT_MIN_SCALAR || this==VECT_MAX_SCALAR); + boolean vvComm = (this==VECT_EQUAL || this==VECT_NOTEQUAL + || this==VECT_PLUS || this==VECT_MULT || this==VECT_MIN || this==VECT_MAX); + return ssComm || vsComm || vvComm; } public String getTemplate(boolean sparse) { @@ -61,6 +73,7 @@ public class CNodeBinary extends CNode return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" : " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + //vector-scalar-add operations case VECT_MULT_ADD: case VECT_DIV_ADD: case VECT_MINUS_ADD: @@ -79,6 +92,7 @@ public class CNodeBinary extends CNode " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n"; } + //vector-scalar operations case VECT_MULT_SCALAR: case VECT_DIV_SCALAR: case VECT_MINUS_SCALAR: @@ -97,7 +111,26 @@ public class CNodeBinary extends CNode " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; } - /*Can be replaced by function objects*/ + //vector-vector operations + case VECT_MULT: + case VECT_DIV: + case VECT_MINUS: + case VECT_PLUS: + case VECT_MIN: + case VECT_MAX: + case VECT_EQUAL: + case VECT_NOTEQUAL: + case VECT_LESS: + case VECT_LESSEQUAL: + case VECT_GREATER: + case VECT_GREATEREQUAL: { + String vectName = getVectorPrimitiveName(); + return sparse ? 
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + } + + //scalar-scalar operations case MULT: return " double %TMP% = %IN1% * %IN2%;\n"; @@ -152,6 +185,14 @@ public class CNodeBinary extends CNode || this == VECT_LESS_SCALAR || this == VECT_LESSEQUAL_SCALAR || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR; } + public boolean isVectorVectorPrimitive() { + return this == VECT_DIV || this == VECT_MULT + || this == VECT_MINUS || this == VECT_PLUS + || this == VECT_MIN || this == VECT_MAX + || this == VECT_EQUAL || this == VECT_NOTEQUAL + || this == VECT_LESS || this == VECT_LESSEQUAL + || this == VECT_GREATER || this == VECT_GREATEREQUAL; + } public BinType getVectorAddPrimitive() { return BinType.valueOf("VECT_"+getVectorPrimitiveName().toUpperCase()+"_ADD"); } @@ -187,7 +228,7 @@ public class CNodeBinary extends CNode public String codegen(boolean sparse) { if( _generated ) return ""; - + StringBuilder sb = new StringBuilder(); //generate children @@ -195,7 +236,8 @@ public class CNodeBinary extends CNode sb.append(_inputs.get(1).codegen(sparse)); //generate binary operation (use sparse template, if data input) - boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData); + boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData + && !_inputs.get(0).getVarname().startsWith("b")); String var = createVarname(); String tmp = _type.getTemplate(lsparse); tmp = tmp.replaceAll("%TMP%", var); @@ -210,9 +252,9 @@ public class CNodeBinary extends CNode tmp = tmp.replaceAll("%IN"+j+"%", varj ); //replace start position of main input - tmp = tmp.replaceAll("%POS"+j+"%", (!varj.startsWith("b") - && _inputs.get(j-1) instanceof CNodeData - && _inputs.get(j-1).getDataType().isMatrix()) ? 
varj+"i" : "0"); + tmp = tmp.replaceAll("%POS"+j+"%", (_inputs.get(j-1) instanceof CNodeData + && _inputs.get(j-1).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? + varj+"i" : TemplateUtils.isMatrix(_inputs.get(j-1)) ? "rowIndex*len" : "0" : "0"); } sb.append(tmp); @@ -225,50 +267,62 @@ public class CNodeBinary extends CNode @Override public String toString() { switch(_type) { - case DOT_PRODUCT: return "b(dot)"; - case VECT_MULT_ADD: return "b(vma)"; - case VECT_DIV_ADD: return "b(vda)"; - case VECT_MINUS_ADD: return "b(vmia)"; - case VECT_PLUS_ADD: return "b(vpa)"; - case VECT_POW_ADD: return "b(vpowa)"; - case VECT_MIN_ADD: return "b(vmina)"; - case VECT_MAX_ADD: return "b(vmaxa)"; - case VECT_EQUAL_ADD: return "b(veqa)"; - case VECT_NOTEQUAL_ADD: return "b(vneqa)"; - case VECT_LESS_ADD: return "b(vlta)"; - case VECT_LESSEQUAL_ADD: return "b(vltea)"; - case VECT_GREATEREQUAL_ADD: return "b(vgtea)"; - case VECT_GREATER_ADD: return "b(vgta)"; - case VECT_MULT_SCALAR: return "b(vm)"; - case VECT_DIV_SCALAR: return "b(vd)"; - case VECT_MINUS_SCALAR: return "b(vmi)"; - case VECT_PLUS_SCALAR: return "b(vp)"; - case VECT_POW_SCALAR: return "b(vpow)"; - case VECT_MIN_SCALAR: return "b(vmin)"; - case VECT_MAX_SCALAR: return "b(vmax)"; - case VECT_EQUAL_SCALAR: return "b(veq)"; - case VECT_NOTEQUAL_SCALAR: return "b(vneq)"; - case VECT_LESS_SCALAR: return "b(vlt)"; - case VECT_LESSEQUAL_SCALAR: return "b(vlte)"; + case DOT_PRODUCT: return "b(dot)"; + case VECT_MULT_ADD: return "b(vma)"; + case VECT_DIV_ADD: return "b(vda)"; + case VECT_MINUS_ADD: return "b(vmia)"; + case VECT_PLUS_ADD: return "b(vpa)"; + case VECT_POW_ADD: return "b(vpowa)"; + case VECT_MIN_ADD: return "b(vmina)"; + case VECT_MAX_ADD: return "b(vmaxa)"; + case VECT_EQUAL_ADD: return "b(veqa)"; + case VECT_NOTEQUAL_ADD: return "b(vneqa)"; + case VECT_LESS_ADD: return "b(vlta)"; + case VECT_LESSEQUAL_ADD: return "b(vltea)"; + case VECT_GREATEREQUAL_ADD: return "b(vgtea)"; + case VECT_GREATER_ADD: 
return "b(vgta)"; + case VECT_MULT_SCALAR: return "b(vm)"; + case VECT_DIV_SCALAR: return "b(vd)"; + case VECT_MINUS_SCALAR: return "b(vmi)"; + case VECT_PLUS_SCALAR: return "b(vp)"; + case VECT_POW_SCALAR: return "b(vpow)"; + case VECT_MIN_SCALAR: return "b(vmin)"; + case VECT_MAX_SCALAR: return "b(vmax)"; + case VECT_EQUAL_SCALAR: return "b(veq)"; + case VECT_NOTEQUAL_SCALAR: return "b(vneq)"; + case VECT_LESS_SCALAR: return "b(vlt)"; + case VECT_LESSEQUAL_SCALAR: return "b(vlte)"; case VECT_GREATEREQUAL_SCALAR: return "b(vgte)"; - case VECT_GREATER_SCALAR: return "b(vgt)"; - case MULT: return "b(*)"; - case DIV: return "b(/)"; - case PLUS: return "b(+)"; - case MINUS: return "b(-)"; - case POW: return "b(^)"; - case MODULUS: return "b(%%)"; - case INTDIV: return "b(%/%)"; - case LESS: return "b(<)"; - case LESSEQUAL: return "b(<=)"; - case GREATER: return "b(>)"; - case GREATEREQUAL: return "b(>=)"; - case EQUAL: return "b(==)"; - case NOTEQUAL: return "b(!=)"; - case OR: return "b(|)"; - case AND: return "b(&)"; - case MINUS1_MULT: return "b(1-*)"; - case MINUS_NZ: return "b(-nz)"; + case VECT_GREATER_SCALAR: return "b(vgt)"; + case VECT_MULT: return "b(v2m)"; + case VECT_DIV: return "b(v2d)"; + case VECT_MINUS: return "b(v2mi)"; + case VECT_PLUS: return "b(v2p)"; + case VECT_MIN: return "b(v2min)"; + case VECT_MAX: return "b(v2max)"; + case VECT_EQUAL: return "b(v2eq)"; + case VECT_NOTEQUAL: return "b(v2neq)"; + case VECT_LESS: return "b(v2lt)"; + case VECT_LESSEQUAL: return "b(v2lte)"; + case VECT_GREATEREQUAL: return "b(v2gte)"; + case VECT_GREATER: return "b(v2gt)"; + case MULT: return "b(*)"; + case DIV: return "b(/)"; + case PLUS: return "b(+)"; + case MINUS: return "b(-)"; + case POW: return "b(^)"; + case MODULUS: return "b(%%)"; + case INTDIV: return "b(%/%)"; + case LESS: return "b(<)"; + case LESSEQUAL: return "b(<=)"; + case GREATER: return "b(>)"; + case GREATEREQUAL: return "b(>=)"; + case EQUAL: return "b(==)"; + case NOTEQUAL: return "b(!=)"; + 
case OR: return "b(|)"; + case AND: return "b(&)"; + case MINUS1_MULT: return "b(1-*)"; + case MINUS_NZ: return "b(-nz)"; default: return "b("+_type.name().toLowerCase()+")"; } } @@ -309,6 +363,19 @@ public class CNodeBinary extends CNode case VECT_LESSEQUAL_SCALAR: case VECT_GREATER_SCALAR: case VECT_GREATEREQUAL_SCALAR: + + case VECT_DIV: + case VECT_MULT: + case VECT_MINUS: + case VECT_PLUS: + case VECT_MIN: + case VECT_MAX: + case VECT_EQUAL: + case VECT_NOTEQUAL: + case VECT_LESS: + case VECT_LESSEQUAL: + case VECT_GREATER: + case VECT_GREATEREQUAL: _rows = _inputs.get(0)._rows; _cols = _inputs.get(0)._cols; _dataType= DataType.MATRIX; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java index 673ab10..ca474d2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java @@ -206,6 +206,21 @@ public abstract class CNodeTpl extends CNode implements Cloneable } } + public void rReorderCommutativeBinaryOps(CNode node, long mainHopID) { + if( isVisited() ) + return; + for( CNode c : node.getInput() ) + rReorderCommutativeBinaryOps(c, mainHopID); + if( node instanceof CNodeBinary && node.getInput().get(1) instanceof CNodeData + && ((CNodeData)node.getInput().get(1)).getHopID() == mainHopID + && ((CNodeBinary)node).getType().isCommutative() ) { + CNode tmp = node.getInput().get(0); + node.getInput().set(0, node.getInput().get(1)); + node.getInput().set(1, tmp); + } + setVisited(); + } + /** * Checks for duplicates (object ref or varname). 
* http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 75d3475..375f695 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -187,6 +187,11 @@ public class CPlanMemoTable return _plans.get(hopID); } + public List<MemoTableEntry> get(long hopID, TemplateType type) { + return _plans.get(hopID).stream() + .filter(p -> p.type==type).collect(Collectors.toList()); + } + public List<MemoTableEntry> getDistinct(long hopID) { //return distinct entries wrt type and closed attributes return _plans.get(hopID).stream() http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index ed6a080..5769fcb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -443,6 +443,20 @@ public class PlanSelectionFuseCostBased extends PlanSelection private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) { + //prune row aggregates with pure cellwise operations + for( Long hopID : R ) { + MemoTableEntry me = memo.getBest(hopID, TemplateType.RowTpl); + if( me.type == TemplateType.RowTpl && memo.contains(hopID, 
TemplateType.CellTpl) + && rIsRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) { + List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.RowTpl); + memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(blacklist)); + if( LOG.isTraceEnabled() ) { + LOG.trace("Removed row memo table entries w/o aggregation: " + + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); + } + } + } + //if no materialization points, use basic fuse-all w/ partition awareness if( M == null || M.isEmpty() ) { for( Long hopID : R ) @@ -497,6 +511,21 @@ public class PlanSelectionFuseCostBased extends PlanSelection } } + private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + if( visited.contains(current.getHopID()) ) + return true; + + boolean ret = true; + MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); + for(int i=0; i<3; i++) + if( me.isPlanRef(i) ) + ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited); + ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp); + + visited.add(current.getHopID()); + return ret; + } + private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) { //memoization (not via hops because in middle of dag) if( visited.contains(current.getHopID()) ) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 3979aae..ae25ded 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -87,12 
+87,13 @@ public class TemplateRow extends TemplateBase return !isClosed() && ( (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop) && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) - || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) + || HopRewriteUtils.isBinaryMatrixScalarOperation(hop) + || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) ) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) - || (hop instanceof AggBinaryOp && hop.getDim1()>1 + || (hop instanceof AggBinaryOp && hop.getDim1()>1 && hop.getDim2()==1 && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } @@ -100,8 +101,9 @@ public class TemplateRow extends TemplateBase public boolean merge(Hop hop, Hop input) { //merge rowagg tpl with cell tpl if input is a vector return !isClosed() && - ((hop instanceof BinaryOp && input.getDim2()==1 //matrix-scalar/vector-vector ops ) - && TemplateUtils.isOperationSupported(hop)) + ((hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop) + && (input.getDim2()==1 //matrix-scalar/vector-vector ops ) + || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop))) ||(hop instanceof AggBinaryOp && input.getDim2()==1 && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } @@ -140,6 +142,8 @@ public class TemplateRow extends TemplateBase tpl.setRowType(TemplateUtils.getRowType(hop, sinHops.get(0))); tpl.setNumVectorIntermediates(TemplateUtils .countVectorIntermediates(output, new HashSet<Long>())); + tpl.getOutput().resetVisitStatus(); + tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops.get(0).getHopID()); // return cplan instance return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); @@ -199,6 +203,7 @@ public class TemplateRow extends TemplateBase 
inHops.remove(hop.getInput().get(0)); inHops.add(hop.getInput().get(0).getInput().get(0)); + //note: vectorMultAdd applicable to vector-scalar, and vector-vector out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD); inHops2.put("X", hop.getInput().get(0).getInput().get(0)); } @@ -245,10 +250,16 @@ public class TemplateRow extends TemplateBase if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) { if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) { - String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR"; - if( TemplateUtils.isColVector(cdata2) ) - cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); + if( TemplateUtils.isMatrix(cdata1) && TemplateUtils.isMatrix(cdata2) ) { + String opname = "VECT_"+((BinaryOp)hop).getOp().name(); + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); + } + else { + String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR"; + if( TemplateUtils.isColVector(cdata2) ) + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); + } } else throw new RuntimeException("Unsupported binary matrix " http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 89211db..2059e17 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -74,6 +74,11 @@ public class TemplateUtils && hop.getNumRows() == 1 && hop.getNumCols() != 1); } + public static boolean isMatrix(CNode hop) { + return (hop.getDataType() == DataType.MATRIX + && 
hop.getNumRows() != 1 && hop.getNumCols() != 1); + } + public static CNode wrapLookupIfNecessary(CNode node, Hop hop) { CNode ret = node; if( isColVector(node) ) @@ -343,10 +348,11 @@ public class TemplateUtils for( CNode c : node.getInput() ) ret += countVectorIntermediates(c, memo); //compute vector requirements of current node - int cntBin = ((node instanceof CNodeBinary - && ((CNodeBinary)node).getType().isVectorScalarPrimitive()) ? 1 : 0); - int cntUn = ((node instanceof CNodeUnary - && ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0); + int cntBin = (node instanceof CNodeBinary + && (((CNodeBinary)node).getType().isVectorScalarPrimitive() + || ((CNodeBinary)node).getType().isVectorVectorPrimitive())) ? 1 : 0; + int cntUn = (node instanceof CNodeUnary + && ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0; return ret + cntBin + cntUn; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index f7be8b9..b406bb7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import org.apache.commons.lang.ArrayUtils; import org.apache.sysml.api.DMLScript; @@ -794,6 +795,24 @@ public class HopRewriteUtils && hop.getInput().get(1).dimsKnown() && hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1; } + public static boolean isBinaryMatrixMatrixOperationWithSharedInput(Hop hop) { + boolean ret = isBinaryMatrixMatrixOperation(hop); + ret = ret && (rContainsInput(hop.getInput().get(0), 
hop.getInput().get(1), new HashSet<Long>()) + || rContainsInput(hop.getInput().get(1), hop.getInput().get(0), new HashSet<Long>())); + return ret; + } + + private static boolean rContainsInput(Hop current, Hop probe, HashSet<Long> memo) { + if( memo.contains(current.getHopID()) ) + return false; + boolean ret = false; + for( int i=0; i<current.getInput().size() && !ret; i++ ) + ret |= rContainsInput(current.getInput().get(i), probe, memo); + ret |= (current == probe); + memo.add(current.getHopID()); + return ret; + } + public static boolean isBinaryMatrixColVectorOperation(Hop hop) { return hop instanceof BinaryOp && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 7a9adeb..1b4369f 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -63,18 +63,36 @@ public class LibSpoofPrimitives LibMatrixMult.vectMultiplyAdd(bval, a, c, bix, bi, ci, len); } + public static void vectMultAdd(double[] a, double[] b, double[] c, int bi, int ci, int len) { + double[] tmp = vectMultWrite(a, b, 0, bi, len); + LibMatrixMult.vectAdd(tmp, c, 0, ci, len); + } + public static double[] vectMultWrite(double[] a, double bval, int bi, int len) { double[] c = allocVector(len, false); LibMatrixMult.vectMultiplyWrite(bval, a, c, bi, 0, len); return c; } + public static double[] vectMultWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + LibMatrixMult.vectMultiplyWrite(a, b, c, ai, bi, 0, len); + return c; + } + public static 
double[] vectMultWrite(double[] a, double bval, int[] bix, int bi, int len) { double[] c = allocVector(len, true); LibMatrixMult.vectMultiplyAdd(bval, a, c, bix, bi, 0, len); return c; } + public static double[] vectMultWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = a[j] * b[bi+aix[j]]; + return c; + } + public static void vectWrite(double[] a, double[] c, int ci, int len) { System.arraycopy(a, 0, c, ci, len); } @@ -168,6 +186,13 @@ public class LibSpoofPrimitives c[j] = a[ai] / bval; return c; } + + public static double[] vectDivWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = a[ai] / b[bi]; + return c; + } public static double[] vectDivWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -176,6 +201,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectDivWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = a[j] / b[bi+aix[j]]; + return c; + } + //custom vector minus public static void vectMinusAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -195,6 +227,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectMinusWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = a[ai] - b[bi]; + return c; + } + public static double[] vectMinusWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); for( int j = ai; j < ai+len; j++ ) @@ -202,6 +241,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectMinusWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = 
allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = a[j] - b[bi+aix[j]]; + return c; + } + //custom vector plus public static void vectPlusAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -220,6 +266,13 @@ public class LibSpoofPrimitives c[j] = a[ai] + bval; return c; } + + public static double[] vectPlusWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++) + c[j] = a[ai] + b[bi]; + return c; + } public static double[] vectPlusWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -228,6 +281,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectPlusWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = a[j] + b[bi+aix[j]]; + return c; + } + //custom vector pow public static void vectPowAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -246,6 +306,13 @@ public class LibSpoofPrimitives c[j] = Math.pow(a[ai], bval); return c; } + + public static double[] vectPowWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = Math.pow(a[ai], b[bi]); + return c; + } public static double[] vectPowWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -272,6 +339,13 @@ public class LibSpoofPrimitives c[j] = Math.min(a[ai], bval); return c; } + + public static double[] vectMinWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = Math.min(a[ai], b[bi]); + return c; + } public static double[] vectMinWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -280,6 +354,13 @@ public class 
LibSpoofPrimitives return c; } + public static double[] vectMinWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = Math.min(a[j], b[bi+aix[j]]); + return c; + } + //custom vector max public static void vectMaxAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -298,6 +379,13 @@ public class LibSpoofPrimitives c[j] = Math.max(a[ai], bval); return c; } + + public static double[] vectMaxWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = Math.max(a[ai], b[bi]); + return c; + } public static double[] vectMaxWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -305,6 +393,13 @@ public class LibSpoofPrimitives c[aix[j]] = Math.max(a[j], bval); return c; } + + public static double[] vectMaxWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = Math.max(a[j], b[bi+aix[j]]); + return c; + } //custom exp @@ -584,13 +679,27 @@ public class LibSpoofPrimitives c[j] = (a[ai] == bval) ? 1 : 0; return c; } + + public static double[] vectEqualWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] == b[bi]) ? 1 : 0; + return c; + } public static double[] vectEqualWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); for( int j = ai; j < ai+len; j++ ) c[aix[j]] = (a[j] == bval) ? 1 : 0; return c; - } + } + + public static double[] vectEqualWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] == b[bi+aix[j]]) ?
1 : 0; + return c; + } //custom vector not equal @@ -610,6 +719,13 @@ public class LibSpoofPrimitives c[j] = (a[ai] != bval) ? 1 : 0; return c; } + + public static double[] vectNotequalWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] != b[bi]) ? 1 : 0; + return c; + } public static double[] vectNotequalWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -618,6 +734,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectNotequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] != b[bi+aix[j]]) ? 1 : 0; + return c; + } + //custom vector less public static void vectLessAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -636,6 +759,13 @@ public class LibSpoofPrimitives c[j] = (a[ai] < bval) ? 1 : 0; return c; } + + public static double[] vectLessWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] < b[bi]) ? 1 : 0; + return c; + } public static double[] vectLessWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -644,6 +774,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectLessWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] < b[bi+aix[j]]) ? 1 : 0; + return c; + } + //custom vector less equal public static void vectLessequalAdd(double[] a, double bval, double[] c, int ai, int ci, int len) { @@ -662,6 +799,13 @@ public class LibSpoofPrimitives c[j] = (a[ai] <= bval) ?
1 : 0; return c; } + + public static double[] vectLessequalWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] <= b[bi]) ? 1 : 0; + return c; + } public static double[] vectLessequalWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -669,6 +813,13 @@ public class LibSpoofPrimitives c[aix[j]] = (a[j] <= bval) ? 1 : 0; return c; } + + public static double[] vectLessequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] <= b[bi+aix[j]]) ? 1 : 0; + return c; + } //custom vector greater @@ -688,13 +839,27 @@ public class LibSpoofPrimitives c[j] = (a[ai] > bval) ? 1 : 0; return c; } + + public static double[] vectGreaterWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] > b[bi]) ? 1 : 0; + return c; + } public static double[] vectGreaterWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); for( int j = ai; j < ai+len; j++ ) c[aix[j]] = (a[j] > bval) ? 1 : 0; return c; - } + } + + public static double[] vectGreaterWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] > b[bi+aix[j]]) ? 1 : 0; + return c; + } //custom vector greater equal @@ -714,6 +879,13 @@ public class LibSpoofPrimitives c[j] = (a[ai] >= bval) ? 1 : 0; return c; } + + public static double[] vectGreaterequalWrite(double[] a, double[] b, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = 0; j < len; j++, ai++, bi++) + c[j] = (a[ai] >= b[bi]) ?
1 : 0; + return c; + } public static double[] vectGreaterequalWrite(double[] a, double bval, int[] aix, int ai, int len) { double[] c = allocVector(len, true); @@ -722,6 +894,13 @@ public class LibSpoofPrimitives return c; } + public static double[] vectGreaterequalWrite(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double[] c = allocVector(len, false); + for( int j = ai; j < ai+len; j++ ) + c[aix[j]] = (a[j] >= b[bi+aix[j]]) ? 1 : 0; + return c; + } + //complex builtin functions that are not directly generated //(included here in order to reduce the number of imports) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index 983ce53..f7d2d54 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -3146,8 +3146,8 @@ public class LibMatrixMult } } - @SuppressWarnings("unused") - private static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len ) + //note: public for use by codegen for consistency + public static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len ) { final int bn = len%8; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index b7f82a7..362f1dc 100644 ---
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -50,7 +50,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME12 = TEST_NAME+"12"; //Y=(X>=v); R=Y/rowSums(Y) private static final String TEST_NAME13 = TEST_NAME+"13"; //rowSums(X)+rowSums(Y) private static final String TEST_NAME14 = TEST_NAME+"14"; //colSums(max(floor(round(abs(min(sign(X+Y),1)))),7)) - private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn - softmax backward (partially) + private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn - softmax backward private static final String TEST_NAME16 = TEST_NAME+"16"; //Y=X-rowIndexMax(X); R=Y/rowSums(Y) private static final String TEST_DIR = "functions/codegen/"; @@ -344,6 +344,11 @@ public class RowAggTmplTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); Assert.assertTrue(heavyHittersContainsSubString("spoofRA") || heavyHittersContainsSubString("sp_spoofRA")); + + //ensure full aggregates for certain patterns + if( testname.equals(TEST_NAME15) ) + Assert.assertFalse(heavyHittersContainsSubString("uark+")); + } finally { rtplatform = platformOld; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/scripts/functions/codegen/rowAggPattern15.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern15.R b/src/test/scripts/functions/codegen/rowAggPattern15.R index a24679a..3151f0c 100644 --- a/src/test/scripts/functions/codegen/rowAggPattern15.R +++ b/src/test/scripts/functions/codegen/rowAggPattern15.R @@ -30,6 +30,7 @@ X = matrix(seq(1,1500), 150, 10, byrow=TRUE); Y1 = X - rowMaxs(X) Y2 = exp(Y1) Y3 = Y2 / rowSums(Y2) -R = Y3 * rowSums(Y3) +Y4 = Y3 * rowSums(Y3) +R = Y4 - Y3 * rowSums(Y4) writeMM(as(R,
"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ac04b570/src/test/scripts/functions/codegen/rowAggPattern15.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern15.dml b/src/test/scripts/functions/codegen/rowAggPattern15.dml index d51397a..0dd98a5 100644 --- a/src/test/scripts/functions/codegen/rowAggPattern15.dml +++ b/src/test/scripts/functions/codegen/rowAggPattern15.dml @@ -24,6 +24,7 @@ X = matrix(seq(1,1500), rows=150, cols=10); Y1 = X - rowMaxs(X) Y2 = exp(Y1) Y3 = Y2 / rowSums(Y2) -R = Y3 * rowSums(Y3) +Y4 = Y3 * rowSums(Y3) +R = Y4 - Y3 * rowSums(Y4) write(R, $1)