Repository: systemml Updated Branches: refs/heads/master c170374e7 -> ca5581fcc
[SYSTEMML-1443] Codegen constraint handling for distributed row ops For distributed rowwise fused operators, the cost-based codegen plan selector has to explicitly handle conditional blocksize constraints of ncol(X) <= blocksize to guarantee that entire rows are available. These constraints are conditional on the selected Spark execution type, which in turn depends on the total size of operator inputs and output (and thus fusion decisions). The cost-based plan selector now applies a best-effort pre-filtering of invalid partial row plans. Additionally, any remaining invalid plans are pruned during cplan cleanup, which guarantees valid runtime plans for all selection policies. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca5581fc Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca5581fc Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca5581fc Branch: refs/heads/master Commit: ca5581fccd70a6ae974e29a9d11e6d4aafe971e4 Parents: c170374 Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Aug 10 00:54:46 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Aug 10 00:54:46 2017 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/OptimizerUtils.java | 7 ++ .../sysml/hops/codegen/SpoofCompiler.java | 87 ++++++++------ .../opt/PlanSelectionFuseCostBasedV2.java | 116 ++++++++++++------- .../hops/codegen/template/CPlanMemoTable.java | 17 ++- .../instructions/spark/SpoofSPInstruction.java | 7 +- .../functions/codegen/AlgorithmGLM.java | 60 ++++++++++ .../functions/codegen/AlgorithmLinregCG.java | 16 ++- 7 files changed, 227 insertions(+), 83 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/OptimizerUtils.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index 7f07cfc..a0a36d5 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -19,6 +19,7 @@ package org.apache.sysml.hops; +import java.util.Arrays; import java.util.HashMap; import org.apache.commons.logging.Log; @@ -769,6 +770,12 @@ public class OptimizerUtils return bsize; } + public static double getTotalMemEstimate(Hop[] in, Hop out) { + return Arrays.stream(in) + .mapToDouble(h -> h.getOutputMemEstimate()).sum() + + out.getOutputMemEstimate(); + } + /** * Indicates if the given indexing range is block aligned, i.e., it does not require * global aggregation of blocks. http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index d5c9618..49a1686 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -35,6 +35,7 @@ import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; @@ -354,13 +355,26 @@ public class SpoofCompiler //context-sensitive literal replacement (only integers during recompile) boolean compileLiterals = (PLAN_CACHE_POLICY==PlanCachePolicy.CONSTANT) || !recompile; - //construct codegen plans - HashMap<Long, 
Pair<Hop[],CNodeTpl>> cplans = constructCPlans(roots, compileLiterals); + //candidate exploration of valid partial fusion plans + CPlanMemoTable memo = new CPlanMemoTable(); + for( Hop hop : roots ) + rExploreCPlans(hop, memo, compileLiterals); + + //candidate selection of optimal fusion plan + memo.pruneSuboptimal(roots); + + //construct actual cplan representations + //note: we do not use the hop visit status due to jumps over fused operators which would + //corrupt subsequent resets, leaving partial hops dags in visited status + HashMap<Long, Pair<Hop[],CNodeTpl>> cplans = new LinkedHashMap<>(); + HashSet<Long> visited = new HashSet<Long>(); + for( Hop hop : roots ) + rConstructCPlans(hop, memo, cplans, compileLiterals, visited); //cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping, //remove empty templates with single cnodedata input, remove spurious lookups, //perform common subexpression elimination) - cplans = cleanupCPlans(cplans); + cplans = cleanupCPlans(memo, cplans); //explain before modification if( LOG.isTraceEnabled() && !cplans.isEmpty() ) { //existing cplans @@ -476,27 +490,6 @@ public class SpoofCompiler //////////////////// // Codegen plan construction - - private static HashMap<Long, Pair<Hop[],CNodeTpl>> constructCPlans(ArrayList<Hop> roots, boolean compileLiterals) throws DMLException - { - //explore cplan candidates - CPlanMemoTable memo = new CPlanMemoTable(); - for( Hop hop : roots ) - rExploreCPlans(hop, memo, compileLiterals); - - //select optimal cplan candidates - memo.pruneSuboptimal(roots); - - //construct actual cplan representations - //note: we do not use the hop visit status due to jumps over fused operators which would - //corrupt subsequent resets, leaving partial hops dags in visited status - LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>(); - HashSet<Long> visited = new HashSet<Long>(); - for( Hop hop : roots ) - rConstructCPlans(hop, memo, ret, 
compileLiterals, visited); - - return ret; - } private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, boolean compileLiterals) throws DMLException @@ -664,9 +657,10 @@ public class SpoofCompiler * during incremental construction. This is important as it avoids unnecessary * redundant computation. * + * @param memo memoization table * @param cplans set of cplans */ - private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) + private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(CPlanMemoTable memo, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) { HashMap<Long, Pair<Hop[],CNodeTpl>> cplans2 = new HashMap<Long, Pair<Hop[],CNodeTpl>>(); CPlanCSERewriter cse = new CPlanCSERewriter(); @@ -711,24 +705,51 @@ public class SpoofCompiler else rFindAndRemoveLookup(tpl.getOutput(), in1); - //remove invalid row templates (e.g., due to partial unknowns) - if( tpl instanceof CNodeRow && (in1.getNumCols() == 1 - || (((CNodeRow)tpl).getRowType()==RowType.NO_AGG - && tpl.getOutput().getDataType().isScalar())) ) - cplans2.remove(e.getKey()); + //remove invalid row templates (e.g., unsatisfied blocksize constraint) + if( tpl instanceof CNodeRow ) { + //check for invalid row cplan over column vector + if(in1.getNumCols() == 1 || (((CNodeRow)tpl).getRowType()==RowType.NO_AGG + && tpl.getOutput().getDataType().isScalar()) ) { + cplans2.remove(e.getKey()); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed invalid row cplan w/o agg on column vector."); + } + else if( OptimizerUtils.isSparkExecutionMode() ) { + boolean isSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK + || OptimizerUtils.getTotalMemEstimate(inHops, memo.getHopRefs().get(e.getKey())) + > OptimizerUtils.getLocalMemBudget(); + boolean invalidNcol = false; + for( Hop in : inHops ) + invalidNcol |= (in.getDataType().isMatrix() + && in.getDim2() > in.getColsInBlock()); + if( isSpark && invalidNcol ) { + cplans2.remove(e.getKey()); + if( 
LOG.isTraceEnabled() ) + LOG.trace("Removed invalid row cplan w/ ncol>ncolpb."); + } + } + } //remove cplan w/ single op and w/o agg if( (tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG && TemplateUtils.hasSingleOperation(tpl) ) || (tpl instanceof CNodeRow && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG - || ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1) + || ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1 + || ((CNodeRow)tpl).getRowType()==RowType.ROW_AGG ) && TemplateUtils.hasSingleOperation(tpl)) || TemplateUtils.hasNoOperation(tpl) ) + { cplans2.remove(e.getKey()); - + if( LOG.isTraceEnabled() ) + LOG.trace("Removed cplan with single operation."); + } + //remove cplan if empty - if( tpl.getOutput() instanceof CNodeData ) + if( tpl.getOutput() instanceof CNodeData ) { cplans2.remove(e.getKey()); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed empty cplan."); + } //rename inputs (for codegen and plan caching) tpl.renameInputs(); http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 717a059..e66c9c3 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -36,6 +36,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; @@ -44,6 +45,7 @@ import 
org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; @@ -122,37 +124,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private void selectPlans(CPlanMemoTable memo, PlanPartition part) { - //prune row aggregates with pure cellwise operations - for( Long hopID : part.getRoots() ) { - MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); - if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL) - && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { - List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); - memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist)); - if( LOG.isTraceEnabled() ) { - LOG.trace("Removed row memo table entries w/o aggregation: " - + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); - } - } - } - - //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of - //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), - //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern. 
- for( Long hopID : part.getPartition() ) { - if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) { - List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER); - MemoTableEntry me1 = entries.get(0); - MemoTableEntry me2 = entries.get(1); - MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); - if( rmEntry != null ) { - memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry)); - memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex())); - if( LOG.isTraceEnabled() ) - LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); - } - } - } + //prune special case patterns and invalid plans (e.g., blocksize) + pruneInvalidAndSpecialCasePlans(memo, part); //if no materialization points, use basic fuse-all w/ partition awareness if( part.getMatPointsExt() == null || part.getMatPointsExt().length==0 ) { @@ -163,8 +136,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection else { //obtain hop compute costs per cell once HashMap<Long, Double> computeCosts = new HashMap<Long, Double>(); - for( Long hopID : part.getRoots() ) - rGetComputeCosts(memo.getHopRefs().get(hopID), part.getPartition(), computeCosts); + for( Long hopID : part.getPartition() ) + getComputeCosts(memo.getHopRefs().get(hopID), computeCosts); //prepare pruning helpers and prune memo table w/ determined mat points StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), @@ -595,7 +568,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection boolean ret = true; MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); for(int i=0; i<3; i++) - if( me.isPlanRef(i) ) + if( me!=null && me.isPlanRef(i) ) ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited); ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp); @@ -603,6 +576,69 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection return ret; } + 
private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) + { + //prune invalid row entries w/ violated blocksize constraint + if( OptimizerUtils.isSparkExecutionMode() ) { + for( Long hopID : part.getPartition() ) { + if( !memo.contains(hopID, TemplateType.ROW) ) + continue; + Hop hop = memo.getHopRefs().get(hopID); + boolean isSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK + || OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop) + > OptimizerUtils.getLocalMemBudget(); + boolean validNcol = true; + for( Hop in : hop.getInput() ) + validNcol &= in.getDataType().isScalar() + || (in.getDim2() <= in.getColsInBlock()) + || (hop instanceof AggBinaryOp && in.getDim1() <= in.getRowsInBlock() + && HopRewriteUtils.isTransposeOperation(in)); + if( isSpark && !validNcol ) { + List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); + memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist)); + if( !memo.contains(hopID) ) + memo.removeAllRefTo(hopID); + if( LOG.isTraceEnabled() ) { + LOG.trace("Removed row memo table entries w/ violated blocksize constraint ("+hopID+"): " + + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); + } + } + } + } + + //prune row aggregates with pure cellwise operations + for( Long hopID : part.getPartition() ) { + MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); + if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL) + && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { + List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); + memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist)); + if( LOG.isTraceEnabled() ) { + LOG.trace("Removed row memo table entries w/o aggregation: " + + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); + } + } + } + + //prune suboptimal outer product plans that are dominated by outer 
product plans w/ same number of + //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), + //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern. + for( Long hopID : part.getPartition() ) { + if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) { + List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER); + MemoTableEntry me1 = entries.get(0); + MemoTableEntry me2 = entries.get(1); + MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); + if( rmEntry != null ) { + memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry)); + memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex())); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); + } + } + } + } + private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) { @@ -751,7 +787,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //open template if necessary, including memoization //under awareness of current plan choice MemoTableEntry best = null; - boolean opened = false; + boolean opened = (currentType == null); if( memo.contains(current.getHopID()) ) { //note: this is the inner loop of plan enumeration and hence, we do not //use streams, lambda expressions, etc to avoid unnecessary overhead @@ -836,16 +872,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection return costs; } - private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) + private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) { - if( computeCosts.containsKey(current.getHopID()) - || !partition.contains(current.getHopID()) ) - return; - - //recursively process children - for( Hop c : current.getInput() ) - 
rGetComputeCosts(c, partition, computeCosts); - //get costs for given hop double costs = 1; if( current instanceof UnaryOp ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 4078060..4adec25 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -77,7 +77,8 @@ public class CPlanMemoTable } public boolean contains(long hopID) { - return _plans.containsKey(hopID); + return _plans.containsKey(hopID) + && !_plans.get(hopID).isEmpty(); } public boolean contains(long hopID, TemplateType type) { @@ -151,6 +152,17 @@ public class CPlanMemoTable .removeIf(p -> blackList.contains(p)); } + public void removeAllRefTo(long hopID) { + //recursive removal of references + for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) { + if( !e.getValue().isEmpty() ) { + e.getValue().removeIf(p -> p.hasPlanRefTo(hopID)); + if( e.getValue().isEmpty() ) + removeAllRefTo(e.getKey()); + } + } + } + public void setDistinct(long hopID, List<MemoTableEntry> plans) { _plans.put(hopID, plans.stream() .distinct().collect(Collectors.toList())); @@ -354,6 +366,9 @@ public class CPlanMemoTable public boolean hasPlanRef() { return isPlanRef(0) || isPlanRef(1) || isPlanRef(2); } + public boolean hasPlanRefTo(long hopID) { + return (input1==hopID || input2==hopID || input3==hopID); + } public int countPlanRefs() { return ((input1 >= 0) ? 1 : 0) + ((input2 >= 0) ? 
1 : 0) http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index eae5560..90e2184 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -203,8 +203,13 @@ public class SpoofSPInstruction extends SPInstruction } } else if( _class.getSuperclass() == SpoofRowwise.class ) { //row aggregate operator + if( mcIn.getCols() > mcIn.getColsPerBlock() ) { + throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol=" + + mcIn.getCols()+", ncolpb="+mcIn.getColsPerBlock()+"."); + } SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class); - RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars, (int)mcIn.getCols()); + RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), + _classBytes, bcMatrices, scalars, (int)mcIn.getCols()); out = in.mapPartitionsToPair(fmmc, op.getRowType()==RowType.ROW_AGG || op.getRowType() == RowType.NO_AGG); http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java index 803ec93..a48c84c 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java @@ -123,6 
+123,66 @@ public class AlgorithmGLM extends AutomatedTestBase runGLMTest(GLMType.BINOMIAL_PROBIT, false, true, ExecType.CP); } + @Test + public void testGLMPoissonDenseRewritesSP() { + runGLMTest(GLMType.POISSON_LOG, true, false, ExecType.SPARK); + } + + @Test + public void testGLMPoissonSparseRewritesSP() { + runGLMTest(GLMType.POISSON_LOG, true, true, ExecType.SPARK); + } + + @Test + public void testGLMPoissonDenseSP() { + runGLMTest(GLMType.POISSON_LOG, false, false, ExecType.SPARK); + } + + @Test + public void testGLMPoissonSparseSP() { + runGLMTest(GLMType.POISSON_LOG, false, true, ExecType.SPARK); + } + + @Test + public void testGLMGammaDenseRewritesSP() { + runGLMTest(GLMType.GAMMA_LOG, true, false, ExecType.SPARK); + } + + @Test + public void testGLMGammaSparseRewritesSP() { + runGLMTest(GLMType.GAMMA_LOG, true, true, ExecType.SPARK); + } + + @Test + public void testGLMGammaDenseSP() { + runGLMTest(GLMType.GAMMA_LOG, false, false, ExecType.SPARK); + } + + @Test + public void testGLMGammaSparseSP() { + runGLMTest(GLMType.GAMMA_LOG, false, true, ExecType.SPARK); + } + + @Test + public void testGLMBinomialDenseRewritesSP() { + runGLMTest(GLMType.BINOMIAL_PROBIT, true, false, ExecType.SPARK); + } + + @Test + public void testGLMBinomialSparseRewritesSP() { + runGLMTest(GLMType.BINOMIAL_PROBIT, true, true, ExecType.SPARK); + } + + @Test + public void testGLMBinomialDenseSP() { + runGLMTest(GLMType.BINOMIAL_PROBIT, false, false, ExecType.SPARK); + } + + @Test + public void testGLMBinomialSparseSP() { + runGLMTest(GLMType.BINOMIAL_PROBIT, false, true, ExecType.SPARK); + } + private void runGLMTest( GLMType type, boolean rewrites, boolean sparse, ExecType instType) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java index 729699f..80e4b9f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java @@ -79,17 +79,25 @@ public class AlgorithmLinregCG extends AutomatedTestBase runLinregCGTest(TEST_NAME1, false, true, ExecType.CP); } - /* + @Test + public void testLinregCGDenseRewritesSP() { + runLinregCGTest(TEST_NAME1, true, false, ExecType.SPARK); + } + + @Test + public void testLinregCGSparseRewritesSP() { + runLinregCGTest(TEST_NAME1, true, true, ExecType.SPARK); + } + @Test public void testLinregCGDenseSP() { - runGDFOTest(TEST_NAME1, false, ExecType.SPARK); + runLinregCGTest(TEST_NAME1, false, false, ExecType.SPARK); } @Test public void testLinregCGSparseSP() { - runGDFOTest(TEST_NAME1, true, ExecType.SPARK); + runLinregCGTest(TEST_NAME1, false, true, ExecType.SPARK); } - */ private void runLinregCGTest( String testname, boolean rewrites, boolean sparse, ExecType instType) {