Repository: incubator-systemml Updated Branches: refs/heads/master 10f5dd920 -> 0bf27c2a8
[SYSTEMML-1439] Improved codegen candidate exploration algorithm This patch improves the existing code generation candidate exploration algorithm by adding merge considerations after opening templates, in order to support patterns like t(X) %*% (v - abs(y)), where the opening condition of row aggregates (here, t(X)%*%) is also the closing condition, but we still want to fuse all amenable cellwise operations of the other input. Furthermore, this also includes the following minor enhancements: * Support for ternary hops in row aggregate templates * Extended support for binary vector operations with row lookups in row aggregate templates * Handle empty sparse/dense inputs to row aggregate operations * Simplified pruning of redundant partial subplans Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0bf27c2a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0bf27c2a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0bf27c2a Branch: refs/heads/master Commit: 0bf27c2a8c747f4c2d7cbbfc8af5da02e2081dfa Parents: 10f5dd9 Author: Matthias Boehm <mboe...@gmail.com> Authored: Wed Mar 29 21:22:35 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Wed Mar 29 21:39:38 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 30 ++++++----- .../sysml/hops/codegen/cplan/CNodeUnary.java | 24 ++++----- .../hops/codegen/template/CPlanMemoTable.java | 23 ++++---- .../hops/codegen/template/TemplateBase.java | 6 ++- .../hops/codegen/template/TemplateCell.java | 4 ++ .../codegen/template/TemplateOuterProduct.java | 4 ++ .../hops/codegen/template/TemplateRowAgg.java | 57 +++++++++++++++----- .../hops/codegen/template/TemplateUtils.java | 17 ++++-- .../runtime/codegen/SpoofRowAggregate.java | 6 +++ .../functions/codegen/AlgorithmL2SVM.java | 6 ++- .../functions/codegen/RowAggTmplTest.java | 27 
+++++++--- .../scripts/functions/codegen/rowAggPattern9.R | 32 +++++++++++ .../functions/codegen/rowAggPattern9.dml | 29 ++++++++++ 13 files changed, 198 insertions(+), 67 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 187a9ca..071b03b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -345,8 +345,10 @@ public class SpoofCompiler //open initial operator plans, if possible for( TemplateBase tpl : TemplateUtils.TEMPLATES ) - if( tpl.open(hop) ) - memo.add(hop, tpl.getType()); + if( tpl.open(hop) ) { + MemoTableEntrySet P = new MemoTableEntrySet(tpl.getType(), false); + memo.addAll(hop, enumPlans(hop, -1, P, tpl, memo)); + } //fuse and merge operator plans for( Hop c : hop.getInput() ) { @@ -356,16 +358,7 @@ public class SpoofCompiler if( tpl.fuse(hop, c) ) { int pos = hop.getInput().indexOf(c); MemoTableEntrySet P = new MemoTableEntrySet(tpl.getType(), pos, c.getHopID(), tpl.isClosed()); - for(int k=0; k<hop.getInput().size(); k++) - if( k != pos ) { - Hop input2 = hop.getInput().get(k); - if( memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed - && memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && tpl.merge(hop, input2) ) - P.crossProduct(k, -1L, input2.getHopID()); - else - P.crossProduct(k, -1L); - } - memo.addAll(hop, P); + memo.addAll(hop, enumPlans(hop, pos, P, tpl, memo)); } } } @@ -392,6 +385,19 @@ public class SpoofCompiler memo.addHop(hop); } + private static MemoTableEntrySet enumPlans(Hop hop, int pos, MemoTableEntrySet P, 
TemplateBase tpl, CPlanMemoTable memo) { + for(int k=0; k<hop.getInput().size(); k++) + if( k != pos ) { + Hop input2 = hop.getInput().get(k); + if( memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed + && memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && tpl.merge(hop, input2) ) + P.crossProduct(k, -1L, input2.getHopID()); + else + P.crossProduct(k, -1L); + } + return P; + } + private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, boolean compileLiterals) throws DMLException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 5fcb7bc..4cf25cd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -124,26 +124,22 @@ public class CNodeUnary extends CNode //generate children sb.append(_inputs.get(0).codegen(sparse)); - //generate binary operation + //generate unary operation String var = createVarname(); String tmp = _type.getTemplate(sparse); tmp = tmp.replaceAll("%TMP%", var); String varj = _inputs.get(0).getVarname(); - if( sparse && !tmp.contains("%IN1%") ) { - tmp = tmp.replaceAll("%IN1v%", varj+"vals"); - tmp = tmp.replaceAll("%IN1i%", varj+"ix"); - } - else - tmp = tmp.replaceAll("%IN1%", varj ); - if(varj.startsWith("b") ) //i.e. 
b.get(index) - { - tmp = tmp.replaceAll("%POS1%", "bi"); - tmp = tmp.replaceAll("%POS2%", "bi"); - } - tmp = tmp.replaceAll("%POS1%", varj+"i"); - tmp = tmp.replaceAll("%POS2%", varj+"i"); + //replace sparse and dense inputs + tmp = tmp.replaceAll("%IN1v%", varj+"vals"); + tmp = tmp.replaceAll("%IN1i%", varj+"ix"); + tmp = tmp.replaceAll("%IN1%", varj ); + + //replace start position of main input + String spos = !varj.startsWith("b") ? varj+"i" : "0"; + tmp = tmp.replaceAll("%POS1%", spos); + tmp = tmp.replaceAll("%POS2%", spos); sb.append(tmp); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index cbafc2a..1b920cd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Map.Entry; import java.util.stream.Collectors; -import org.apache.commons.collections.CollectionUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.Hop; @@ -104,23 +103,20 @@ public class CPlanMemoTable .distinct().collect(Collectors.toList())); } - @SuppressWarnings("unchecked") public void pruneRedundant(long hopID) { if( !contains(hopID) ) return; //prune redundant plans (i.e., equivalent) - HashSet<MemoTableEntry> set = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> list = _plans.get(hopID); - for( MemoTableEntry me : list ) - set.add(me); + setDistinct(hopID, _plans.get(hopID)); //prune dominated plans (e.g., opened plan subsumed //by fused plan if single consumer of input) - ArrayList<MemoTableEntry> rmList = new 
ArrayList<MemoTableEntry>(); + HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>(); + List<MemoTableEntry> list = _plans.get(hopID); Hop hop = _hopRefs.get(hopID); - for( MemoTableEntry e1 : set ) - for( MemoTableEntry e2 : set ) + for( MemoTableEntry e1 : list ) + for( MemoTableEntry e2 : list ) if( e1 != e2 && e1.subsumes(e2) ) { //check that childs don't have multiple consumers boolean rmSafe = true; @@ -131,9 +127,8 @@ public class CPlanMemoTable rmList.add(e2); } - //update current entry list - list.clear(); - list.addAll(CollectionUtils.subtract(set, rmList)); + //update current entry list, by removing rmList + remove(hop, rmList); } public void pruneSuboptimal(ArrayList<Hop> roots) { @@ -301,6 +296,10 @@ public class CPlanMemoTable { public ArrayList<MemoTableEntry> plans = new ArrayList<MemoTableEntry>(); + public MemoTableEntrySet(TemplateType type, boolean close) { + plans.add(new MemoTableEntry(type, -1, -1, -1, close)); + } + public MemoTableEntrySet(TemplateType type, int pos, long hopID, boolean close) { plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, (pos==1)?hopID:-1, (pos==2)?hopID:-1)); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index 1c5ea56..9d2466e 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -46,8 +46,12 @@ public abstract class TemplateBase protected boolean _closed = false; protected TemplateBase(TemplateType type) { + this(type, false); + } + + protected TemplateBase(TemplateType type, boolean closed) { _type = type; - _closed = false; + _closed = closed; } public TemplateType 
getType() { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 46e265c..87ec899 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -60,6 +60,10 @@ public class TemplateCell extends TemplateBase public TemplateCell() { super(TemplateType.CellTpl); } + + public TemplateCell(boolean closed) { + super(TemplateType.CellTpl, closed); + } @Override public boolean open(Hop hop) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java index 5918e5e..a1d2174 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java @@ -51,6 +51,10 @@ public class TemplateOuterProduct extends TemplateBase { public TemplateOuterProduct() { super(TemplateType.OuterProdTpl); } + + public TemplateOuterProduct(boolean closed) { + super(TemplateType.OuterProdTpl, closed); + } @Override public boolean open(Hop hop) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java index b384f52..f8f1508 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java @@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; @@ -60,31 +61,36 @@ public class TemplateRowAgg extends TemplateBase super(TemplateType.RowAggTpl); } + public TemplateRowAgg(boolean closed) { + super(TemplateType.RowAggTpl, closed); + } + @Override public boolean open(Hop hop) { - //any unary or binary aggregate operation with a vector output, but exclude binary aggregate - //with transposed input to avoid counter-productive fusion - return ( ((hop instanceof AggBinaryOp && hop.getInput().get(1).getDim1()>1 - && !HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) - || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM )) - && ( (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1) ) ); + return (hop instanceof AggBinaryOp && hop.getDim2()==1 + && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol + && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1); } @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && ( (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) - || HopRewriteUtils.isBinaryMatrixScalarOperation(hop))) + || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) || (hop instanceof UnaryOp && TemplateCell.isValidOperation(hop)) - || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()==Direction.Col) - || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol) + || (hop instanceof AggBinaryOp && hop.getDim1()>1 + && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } @Override public boolean merge(Hop hop, Hop input) { //merge rowagg tpl with cell tpl if input is a vector return !isClosed() && - (hop instanceof BinaryOp && input.getDim2()==1 ); //matrix-scalar/vector-vector ops ) + ((hop instanceof BinaryOp && input.getDim2()==1) //matrix-scalar/vector-vector ops ) + ||(hop instanceof AggBinaryOp && input.getDim2()==1 + && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } @Override @@ -208,6 +214,8 @@ public class TemplateRowAgg extends TemplateBase { if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) { String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR"; + if( TemplateUtils.isColVector(cdata2) ) + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); } else @@ -217,15 +225,36 @@ public class TemplateRowAgg extends TemplateBase else //one input is a vector/scalar other is a scalar { String primitiveOpName = ((BinaryOp)hop).getOp().toString(); - if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) { + if( TemplateUtils.isColVector(cdata1) ) cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - } - if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) { + if( TemplateUtils.isColVector(cdata2) ) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - } out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } } + else if(hop instanceof TernaryOp) + { + TernaryOp top = (TernaryOp) hop; + CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID()); + CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); + CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID()); + + //cdata1 is vector + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); + else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + + //cdata3 is vector + if( TemplateUtils.isColVector(cdata3) ) + cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_R); + else if( cdata3 instanceof CNodeData && hop.getInput().get(2).getDataType().isMatrix() ) + cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_RC); + + //construct ternary cnode, primitive operation derived from OpOp3 + out = new CNodeTernary(cdata1, cdata2, cdata3, + TernaryType.valueOf(top.getOp().toString())); + } else if( hop instanceof IndexingOp ) { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index adab46c..3f5fed9 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -182,11 +182,20 @@ public class TemplateUtils public static TemplateBase createTemplate(TemplateType type, boolean closed) { TemplateBase tpl = null; switch( type ) { - case CellTpl: tpl = new TemplateCell(); break; - case RowAggTpl: tpl = new TemplateRowAgg(); break; - case OuterProdTpl: tpl = new TemplateOuterProduct(); break; + case CellTpl: tpl = new TemplateCell(closed); break; + case RowAggTpl: tpl = new TemplateRowAgg(closed); break; + case OuterProdTpl: tpl = 
new TemplateOuterProduct(closed); break; + } + return tpl; + } + + public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) { + TemplateBase[] tpl = null; + switch( type ) { + case CellTpl: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRowAgg(closed)}; break; + case RowAggTpl: tpl = new TemplateBase[]{new TemplateRowAgg(closed)}; break; + case OuterProdTpl: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; } - tpl._closed = closed; return tpl; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowAggregate.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowAggregate.java index 6df9b67..500c91f 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowAggregate.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowAggregate.java @@ -129,6 +129,9 @@ public abstract class SpoofRowAggregate extends SpoofOperator private void executeDense(double[] a, double[][] b, double[] scalars, double[] c, int n, int rl, int ru) { + if( a == null ) + return; + for( int i=rl, aix=rl*n; i<ru; i++, aix+=n ) { //call generated method genexecRowDense( a, aix, b, scalars, c, n, i ); @@ -137,6 +140,9 @@ public abstract class SpoofRowAggregate extends SpoofOperator private void executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int n, int rl, int ru) { + if( sblock == null ) + return; + for( int i=rl; i<ru; i++ ) { if( !sblock.isEmpty(i) ) { double[] avals = sblock.values(i); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java index 025f065..7481022 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java @@ -44,8 +44,9 @@ public class AlgorithmL2SVM extends AutomatedTestBase private final static double eps = 1e-5; private final static int rows = 1468; - private final static int cols = 1007; - + private final static int cols1 = 1007; + private final static int cols2 = 987; + private final static double sparsity1 = 0.7; //dense private final static double sparsity2 = 0.1; //sparse @@ -133,6 +134,7 @@ public class AlgorithmL2SVM extends AutomatedTestBase OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; //generate actual datasets + int cols = (instType==ExecType.SPARK) ? cols2 : cols1; double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 714); writeInputMatrixWithMTD("X", X, true); double[][] y = TestUtils.round(getRandomMatrix(rows, 1, 0, 1, 1.0, 136)); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index c2ae38b..13e0283 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -36,14 +36,15 @@ import org.apache.sysml.test.utils.TestUtils; public class RowAggTmplTest extends AutomatedTestBase { private static final String TEST_NAME = "rowAggPattern"; - 
private static final String TEST_NAME1 = TEST_NAME+"1"; - private static final String TEST_NAME2 = TEST_NAME+"2"; - private static final String TEST_NAME3 = TEST_NAME+"3"; - private static final String TEST_NAME4 = TEST_NAME+"4"; - private static final String TEST_NAME5 = TEST_NAME+"5"; - private static final String TEST_NAME6 = TEST_NAME+"6"; - private static final String TEST_NAME7 = TEST_NAME+"7"; + private static final String TEST_NAME1 = TEST_NAME+"1"; //t(X)%*%(X%*%(lamda*v)) + private static final String TEST_NAME2 = TEST_NAME+"2"; //t(X)%*%(lamda*(X%*%v)) + private static final String TEST_NAME3 = TEST_NAME+"3"; //t(X)%*%(z+(2-(w*(X%*%v)))) + private static final String TEST_NAME4 = TEST_NAME+"4"; //colSums(X/rowSums(X)) + private static final String TEST_NAME5 = TEST_NAME+"5"; //t(X)%*%((P*(1-P))*(X%*%v)); + private static final String TEST_NAME6 = TEST_NAME+"6"; //t(X)%*%((P[,1]*(1-P[,1]))*(X%*%v)); + private static final String TEST_NAME7 = TEST_NAME+"7"; //t(X)%*%(X%*%v-y); sum((X%*%v-y)^2); private static final String TEST_NAME8 = TEST_NAME+"8"; //colSums((X/rowSums(X))>0.7) + private static final String TEST_NAME9 = TEST_NAME+"9"; //t(X) %*% (v - abs(y)) private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -55,7 +56,7 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=8; i++) + for(int i=1; i<=9; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -99,6 +100,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME8, true, ExecType.CP ); } + @Test + public void testCodegenRowAggRewrite9() { + testCodegenIntegration( TEST_NAME9, true, ExecType.CP ); + } + @Test public void testCodegenRowAgg1() { testCodegenIntegration( TEST_NAME1, 
false, ExecType.CP ); @@ -139,6 +145,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME8, false, ExecType.CP ); } + @Test + public void testCodegenRowAgg9() { + testCodegenIntegration( TEST_NAME9, false, ExecType.CP ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/test/scripts/functions/codegen/rowAggPattern9.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern9.R b/src/test/scripts/functions/codegen/rowAggPattern9.R new file mode 100644 index 0000000..aea3028 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern9.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(1,15), 5, 3); +v = seq(1,5); +y = abs(seq(1,5)); + +R = t(X) %*% (v - y); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0bf27c2a/src/test/scripts/functions/codegen/rowAggPattern9.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern9.dml b/src/test/scripts/functions/codegen/rowAggPattern9.dml new file mode 100644 index 0000000..4e24c8f --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern9.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,15), rows=5, cols=3); +v = seq(1,5); +y = abs(seq(1,5)); + +R = t(X) %*% (v - y); + +write(R, $1) +