Repository: systemml
Updated Branches:
  refs/heads/master 26d63806e -> 942696a2f
[SYSTEMML-445] Added rewrites for generating GPU-specific operators

- This commit enables SystemML to reuse existing CPU/Spark optimizations (such as codegen) and also to exploit GPU-specific CUDA kernels.
- As an initial step, I have only added the channel_sums and batch_norm2d_test patterns. The latter will execute the cudnnBatchNormalizationForwardInference kernel. (Minimal DML sketches of both patterns appear further below.)
- This commit also adds a GPU+MLContext-related bugfix for a specific case: an output variable generated as part of an MLContext session with GPU enabled is passed to another MLContext session with GPU disabled.

Closes #792.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/942696a2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/942696a2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/942696a2

Branch: refs/heads/master
Commit: 942696a2fc352f2c4419508a59807b10777e6b10
Parents: 26d6380
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Jun 28 09:42:36 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Jun 28 09:45:24 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/DMLScript.java    |   2 +-
 .../apache/sysml/api/ScriptExecutorUtils.java   |  29 ++-
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  97 +++----
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  39 ++-
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   4 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 257 +++++++++++++++++++
 .../org/apache/sysml/lops/DnnTransform.java     |  35 ++-
 .../context/ExecutionContext.java               |   3 +
 .../instructions/CPInstructionParser.java       |   1 -
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/cp/DnnCPInstruction.java       |  42 ---
 .../instructions/gpu/DnnGPUInstruction.java     |  54 ++++
 .../org/apache/sysml/test/gpu/AppendTest.java   |   2 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |  75 ++++++
 .../apache/sysml/test/gpu/ChannelSumsTest.java  |  60 +++++
 .../functions/tensor/ChannelSumTest.java        | 146 -----------
 .../scripts/functions/tensor/ChannelSumTest.R   |  39 ---
 .../scripts/functions/tensor/ChannelSumTest.dml |  35 ---
 19 files changed, 588 insertions(+), 337 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/api/DMLScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index c737e92..215d082 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -743,7 +743,7 @@ public class DMLScript ExecutionContext ec = null; try { ec = ExecutionContextFactory.createContext(rtprog); - ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, dmlconf, STATISTICS ? STATISTICS_COUNT : 0); + ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, dmlconf, STATISTICS ?
STATISTICS_COUNT : 0, null); } finally { if(ec != null && ec instanceof SparkExecutionContext) http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java index 2d913b6..0b4c7ab 100644 --- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java +++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java @@ -20,6 +20,7 @@ package org.apache.sysml.api; import java.util.List; +import java.util.Set; import org.apache.sysml.api.mlcontext.ScriptExecutor; import org.apache.sysml.conf.ConfigurationManager; @@ -28,9 +29,12 @@ import org.apache.sysml.hops.codegen.SpoofCompiler; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; +import org.apache.sysml.runtime.instructions.gpu.context.GPUObject; import org.apache.sysml.utils.NativeHelper; import org.apache.sysml.utils.Statistics; @@ -50,7 +54,7 @@ public class ScriptExecutorUtils { Program prog = se.getRuntimeProgram(); ExecutionContext ec = se.getExecutionContext(); DMLConfig config = se.getConfig(); - executeRuntimeProgram(prog, ec, config, statisticsMaxHeavyHitters); + executeRuntimeProgram(prog, ec, config, statisticsMaxHeavyHitters, se.getScript().getOutputVariables()); } /** @@ -66,8 +70,10 @@ * dml configuration * @param statisticsMaxHeavyHitters * maximum number of statistics to print + * @param outputVariables + * output variables that were registered as part of the MLContext session */ - public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters) { + public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters, Set<String> outputVariables) { // Whether extra statistics useful for developers and others interested // in digging into performance problems are recorded and displayed DMLScript.FINEGRAINED_STATISTICS = DMLScript.STATISTICS && dmlconf.getBooleanValue(DMLConfig.EXTRA_FINEGRAINED_STATS); @@ -103,6 +109,25 @@ public class ScriptExecutorUtils { throw e; } finally { // ensure cleanup/shutdown if (DMLScript.USE_ACCELERATOR && !ec.getGPUContexts().isEmpty()) { + // ----------------------------------------------------------------- + // The code below pulls output variables resident on the GPU back to the host. This is required especially when + // the output variable was generated as part of an MLContext session with GPU enabled + // and was passed to another MLContext session with GPU disabled. + // This scenario occurs in our GPU test suite (e.g., BatchNormTest).
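+ // Without this device-to-host transfer, a subsequent GPU-disabled session reading these outputs would observe a stale or missing host copy.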
+ if(outputVariables != null) { + for(String outVar : outputVariables) { + Data data = ec.getVariable(outVar); + if(data != null && data instanceof MatrixObject) { + for(GPUContext gCtx : ec.getGPUContexts()) { + GPUObject gpuObj = ((MatrixObject)data).getGPUObject(gCtx); + if(gpuObj != null && gpuObj.isDirty()) { + gpuObj.acquireHostRead(null); + } + } + } + } + } + // ----------------------------------------------------------------- for(GPUContext gCtx : ec.getGPUContexts()) { gCtx.clearTemporaryMemory(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 1c39787..4e6cf95 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -25,7 +25,6 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.lops.Aggregate; import org.apache.sysml.lops.Aggregate.OperationTypes; import org.apache.sysml.lops.Binary; -import org.apache.sysml.lops.DnnTransform; import org.apache.sysml.lops.Group; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.PartialAggregate; @@ -112,20 +111,6 @@ public class AggUnaryOp extends MultiThreadedHop return false; } - /** - * Checks if channels sum rewrite is applicable - * - * @return returns true for pattern rowSums(matrix(colSums(X), rows=.., cols=..)) else false - */ - private boolean isChannelSumRewriteApplicable() { - if( OptimizerUtils.ALLOW_OPERATOR_FUSION && _op == AggOp.SUM && _direction == Direction.Row - && getInput().get(0) instanceof ReorgOp && ((ReorgOp)getInput().get(0)).getOp() == ReOrgOp.RESHAPE) { - Hop input1 = getInput().get(0).getInput().get(0); - return input1 instanceof AggUnaryOp && ((AggUnaryOp)input1)._op == AggOp.SUM && ((AggUnaryOp)input1)._direction == Direction.Col; - } - return false; - } - @Override public Lop constructLops() { @@ -140,58 +125,42 @@ public class AggUnaryOp extends MultiThreadedHop if ( et == ExecType.CP || et == ExecType.GPU ) { - Lop agg1 = null; - long numChannels = isChannelSumRewriteApplicable() ? Hop.computeSizeInformation(getInput().get(0).getInput().get(1)) : -1; - if(numChannels > 0 && numChannels < 1000000) { - // Apply channel sums only if rewrite is applicable and if the dimension of C is known at compile time - // and if numChannels is less than 8 MB. 
- ReorgOp in = ((ReorgOp)getInput().get(0)); - agg1 = new DnnTransform( - in.getInput().get(0).getInput().get(0).constructLops(), - in.getInput().get(1).constructLops(), - in.getInput().get(2).constructLops(), - DnnTransform.OperationTypes.CHANNEL_SUMS, getDataType(), getValueType(), et, -1); - agg1.getOutputParameters().setDimensions(numChannels, 1, getRowsInBlock(), getColsInBlock(), -1); - setLineNumbers(agg1); - setLops(agg1); + Lop agg1 = null; + if( isTernaryAggregateRewriteApplicable() ) { + agg1 = constructLopsTernaryAggregateRewrite(et); } - else { - if( isTernaryAggregateRewriteApplicable() ) { - agg1 = constructLopsTernaryAggregateRewrite(et); - } - else if( isUnaryAggregateOuterCPRewriteApplicable() ) - { - OperationTypes op = HopsAgg2Lops.get(_op); - DirectionTypes dir = HopsDirection2Lops.get(_direction); - - BinaryOp binput = (BinaryOp)getInput().get(0); - agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), - binput.getInput().get(1).constructLops(), op, dir, - HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP); - PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); - - if (getDataType() == DataType.SCALAR) { - UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), - getDataType(), getValueType()); - unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); - setLineNumbers(unary1); - agg1 = unary1; - } - - } - else { //general case - int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - agg1 = new PartialAggregate(input.constructLops(), - HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k); - } - - setOutputDimensions(agg1); - setLineNumbers(agg1); - setLops(agg1); - + else if( isUnaryAggregateOuterCPRewriteApplicable() ) + { + OperationTypes op = HopsAgg2Lops.get(_op); + DirectionTypes dir = HopsDirection2Lops.get(_direction); + + BinaryOp binput = (BinaryOp)getInput().get(0); + agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), + binput.getInput().get(1).constructLops(), op, dir, + HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP); + PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); + if (getDataType() == DataType.SCALAR) { - agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz()); + UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), + getDataType(), getValueType()); + unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); + setLineNumbers(unary1); + agg1 = unary1; } + + } + else { //general case + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); + agg1 = new PartialAggregate(input.constructLops(), + HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k); + } + + setOutputDimensions(agg1); + setLineNumbers(agg1); + setLops(agg1); + + if (getDataType() == DataType.SCALAR) { + agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz()); } } else if( et == ExecType.MR ) http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/DnnOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java index 0a0e50a..8dbbeda 100644 --- 
a/src/main/java/org/apache/sysml/hops/DnnOp.java +++ b/src/main/java/org/apache/sysml/hops/DnnOp.java @@ -31,6 +31,7 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DnnParameters; + import java.util.ArrayList; public class DnnOp extends MultiThreadedHop @@ -133,6 +134,18 @@ public class DnnOp extends MultiThreadedHop } // break; } + case BATCH_NORM2D_TEST: + case CHANNEL_SUMS: + { + if(et == ExecType.GPU) { + setLops(constructDnnLops(et, inputs)); + break; + } + else { + throw new HopsException("Unimplemented DnnOp for execution type: " + et.name()); + } + // break; + } default: throw new HopsException("Unsupported lops construction for operation type '"+op+"'."); } @@ -158,6 +171,10 @@ public class DnnOp extends MultiThreadedHop case BIASADD: case BIASMULT: return 2; + case BATCH_NORM2D_TEST: + return 6; + case CHANNEL_SUMS: + return 3; default: return 13; } @@ -505,13 +522,20 @@ public class DnnOp extends MultiThreadedHop // [numRows, numCols, NNZ] long[] ret = new long[3]; - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) { + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) { MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1; ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1; ret[2] = -1; return (ret[0]>=0 && ret[1]>=0) ? ret : null; } + else if(op == OpOpDnn.CHANNEL_SUMS) { + long numChannels = Hop.computeSizeInformation(getInput().get(1)); + ret[0] = numChannels; + ret[1] = 1; + ret[2] = -1; + return ret; + } refreshSizeInformation(); ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz; @@ -708,13 +732,20 @@ public class DnnOp extends MultiThreadedHop @Override public void refreshSizeInformation() { - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) { + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) { Hop input1 = getInput().get(0); setDim1(input1.getDim1()); setDim2(input1.getDim2()); _nnz = -1; // cannot infer stats return; } + else if(op == OpOpDnn.CHANNEL_SUMS) { + long numChannels = Hop.computeSizeInformation(getInput().get(1)); + setDim1(numChannels); + setDim2(1); + _nnz = -1; // cannot infer stats + return; + } // Reset the _cachedParams to avoid incorrect sizes _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads); @@ -807,8 +838,8 @@ public class DnnOp extends MultiThreadedHop * @return either -1 or value associated with the dimString */ private long getDim(String dimString) { - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) { - throw new RuntimeException("getDim method should not be invoked for bias_add and bias_multiply"); + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) { + throw new RuntimeException("getDim method should not be invoked for batch_norm_test, channel_sums, bias_add and bias_multiply"); } try { parseInput(); http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 436d45f..5d357c6 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ 
-1099,7 +1099,7 @@ public abstract class Hop implements ParseInfo public enum OpOpDnn { MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, - BIASADD, BIASMULT + BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS } public enum DataGenMethod { @@ -1172,6 +1172,8 @@ public abstract class Hop implements ParseInfo HopsConv2Lops.put(OpOpDnn.BIASMULT, org.apache.sysml.lops.DnnTransform.OperationTypes.BIAS_MULTIPLY); HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER); HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA); + HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST); + HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS); } protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops; http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 2963e9d..3d4eafd 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -24,6 +24,7 @@ import java.util.List; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.CompilerConfig.ConfigType; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; @@ -121,6 +122,9 @@ public class ProgramRewriter // DYNAMIC REWRITES (which do require size information) if( dynamicRewrites ) { + if ( DMLScript.USE_ACCELERATOR ){ + _dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites + } if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) { _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java new file mode 100644 index 0000000..987d9cd --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOpDnn; +import org.apache.sysml.hops.Hop.ReOrgOp; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.DnnOp; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; + +/* + * This class contains GPU-specific rewrites for the following patterns: + * + * 1. batchNormTest: + * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) + * hi = bias_add(bias_multiply(norm, gamma), beta) + * + * 2. channelSum: + * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) + */ +public class RewriteGPUSpecificOps extends HopRewriteRule { + + @Override + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { + if( roots == null ) + return roots; + + //one pass rewrite-descend (rewrite created pattern) + for( Hop h : roots ) + rule_GPUKernels( h, false ); + Hop.resetVisitStatus(roots, true); + + //one pass descend-rewrite (for rollup) + for( Hop h : roots ) + rule_GPUKernels( h, true ); + Hop.resetVisitStatus(roots, true); + + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if( root == null ) + return root; + + //one pass rewrite-descend (rewrite created pattern) + rule_GPUKernels( root, false ); + + root.resetVisitStatus(); + + //one pass descend-rewrite (for rollup) + rule_GPUKernels( root, true ); + + return root; + } + + /** + * Fuses the supported patterns (batchNormTest, channelSums) into GPU-specific kernels + * + * @param hop high-level operator + * @param descendFirst true to recursively process children first + */ + private void rule_GPUKernels(Hop hop, boolean descendFirst) + { + if(hop.isVisited()) + return; + + //recursively process children + for( int i=0; i<hop.getInput().size(); i++) + { + Hop hi = hop.getInput().get(i); + + //process children recursively first (to allow roll-up) + if( descendFirst ) + rule_GPUKernels(hi, descendFirst); //see below + + hi = batchNormTest(hop, hi, i); + hi = channelSums(hop, hi, i); + + if( !descendFirst ) + rule_GPUKernels(hi, descendFirst); + } + + hop.setVisited(); + } + + private static boolean isBiasAdd(Hop h) { + return h instanceof DnnOp && ((DnnOp) h).getOp() == OpOpDnn.BIASADD; + } + + private static boolean isBiasMultiply(Hop h) { + return h instanceof DnnOp && ((DnnOp) h).getOp() == OpOpDnn.BIASMULT; + } + + private static boolean fitsOnGPU(Hop h, double multiplier) { + double memEst = multiplier*h.getMemEstimate(); + return DMLScript.USE_ACCELERATOR && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() && + memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); + } + + private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput) { + return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0); + } + + private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput, long additionalBytes) { +
double memEst = additionalBytes; + boolean isFirst = true; + for(Hop h : inputHops) { + double est = h.getMemEstimate(); + if(est == OptimizerUtils.INVALID_SIZE) { + return false; + } + else if(isFirst && isFirstSameSizeAsOutput) { + isFirst = false; + memEst += 2*est; + } + else { + memEst += est; + } + } + return DMLScript.USE_ACCELERATOR && OptimizerUtils.isMemoryBasedOptLevel() && + memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); + } + + private static Hop getFirstInput(Hop h) { + if(h == null || h.getInput() == null || h.getInput().size() < 1) { + throw new RuntimeException("No input available for " + h); + } + return h.getInput().get(0); + } + + private static Hop getSecondInput(Hop h) { + if(h == null || h.getInput() == null || h.getInput().size() < 2) { + throw new RuntimeException("No input available for " + h); + } + return h.getInput().get(1); + } + + private static boolean isUnaryMinus(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS + && Hop.computeSizeInformation(h.getInput().get(0)) == 0; + } + + private static boolean isOneDivideBySqrt(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV + && h.getInput().get(1) instanceof UnaryOp + && ((UnaryOp)h.getInput().get(1)).getOp() == OpOp1.SQRT + && Hop.computeSizeInformation(h.getInput().get(0)) == 1; + } + + private static Hop channelSums(Hop parent, Hop hi, int pos) + { + if(hi instanceof AggUnaryOp) { + AggUnaryOp hop = (AggUnaryOp) hi; + // output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) + if(hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row + && hop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)hop.getInput().get(0)).getOp() == ReOrgOp.RESHAPE) { + Hop colSumsInput = hop.getInput().get(0).getInput().get(0); + if(colSumsInput instanceof AggUnaryOp && ((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && ((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) { + ArrayList<Hop> inHops = new ArrayList<Hop>(); + inHops.add(colSumsInput.getInput().get(0)); + long numChannels = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1)); + long HW = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2)); + if(numChannels > 0 && HW > 0 && fitsOnGPU(inHops, false, numChannels*8)) { + inHops.add(new LiteralOp(numChannels)); + inHops.add(new LiteralOp(HW)); + LOG.debug("Applied channelSums rewrite."); + Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), + OpOpDnn.CHANNEL_SUMS, inHops); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + } + } + } + } + return hi; + } + + private static Hop batchNormTest(Hop parent, Hop hi, int pos) + { + // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) + // hi = bias_add(bias_multiply(norm, gamma), beta) + // 2x for input and output and 1x for overhead + if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { + Hop norm = getFirstInput(getFirstInput(hi)); + if(isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) + && isUnaryMinus(getSecondInput(getFirstInput(norm))) + && isOneDivideBySqrt(getSecondInput(norm))) { + double eps = 0; + Hop var = getFirstInput(getSecondInput(getSecondInput(norm))); + if(var instanceof BinaryOp && ((BinaryOp) var).getOp() == OpOp2.PLUS && + (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) { + // eps + ema_var + if(getFirstInput(var) instanceof LiteralOp) { + eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>()); + var = getSecondInput(var); + } + else { + eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>()); + var = getFirstInput(var); + } + } + // Generate batch norm test op + Hop X = getFirstInput(getFirstInput(norm)); + Hop mean = getSecondInput(getSecondInput(getFirstInput(norm))); + Hop gamma = getSecondInput(getFirstInput(hi)); + Hop beta = getSecondInput(hi); + ArrayList<Hop> inHops = new ArrayList<Hop>(); + inHops.add(X); + inHops.add(gamma); + inHops.add(beta); + inHops.add(mean); + inHops.add(var); + inHops.add(new LiteralOp(eps)); + if(fitsOnGPU(inHops, true)) { + LOG.debug("Applied batchNormTest rewrite."); + Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), + OpOpDnn.BATCH_NORM2D_TEST, inHops); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + } + } + } + + return hi; + } + +} http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/lops/DnnTransform.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java index 02dcec1..6c61d4a 100644 --- a/src/main/java/org/apache/sysml/lops/DnnTransform.java +++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java @@ -31,7 +31,7 @@ public class DnnTransform extends Lop MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD, RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, - BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS + BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST } private OperationTypes operation; @@ -166,6 +166,9 @@ public class DnnTransform extends Lop case CHANNEL_SUMS: return "channel_sums"; + case BATCH_NORM2D_TEST: + return "batch_norm2d_test"; + default: throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation); @@ -242,6 +245,36 @@ public class DnnTransform extends Lop return sb.toString(); } + + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String output) { + if(operation == OperationTypes.BATCH_NORM2D_TEST) { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( getOpcode() ); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(0).prepInputOperand(input1)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(1).prepInputOperand(input2)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(2).prepInputOperand(input3)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(3).prepInputOperand(input4)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(4).prepInputOperand(input5)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(5).prepInputOperand(input6)); + //output + sb.append( OPERAND_DELIMITOR ); + sb.append( this.prepOutputOperand(output)); + + return sb.toString(); + } + else { + throw new LopsException("The operation is not supported with six operands:" + operation.name()); + } + } public void appendOpcode(StringBuilder sb) { sb.append( getExecType() ); http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index d87f9d9..77598bf 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -53,7 +53,9 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MetaDataFormat; import org.apache.sysml.runtime.matrix.MetaData; import org.apache.sysml.runtime.matrix.data.FrameBlock; +import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.matrix.data.Pair; import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.utils.GPUStatistics; @@ -462,6 +464,7 @@ public class ExecutionContext { if(mo.getGPUObject(getGPUContext(0)) == null || !mo.getGPUObject(getGPUContext(0)).isAllocated()) { throw new DMLRuntimeException("No output is allocated on GPU"); } + setMetaData(varName, new MetaDataFormat(mo.getMatrixCharacteristics(), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); mo.getGPUObject(getGPUContext(0)).releaseOutput(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index fcc27e9..4663671 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -248,7 +248,6 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "conv2d_backward_data" , CPType.Dnn); String2CPInstructionType.put( "bias_add" , CPType.Dnn); String2CPInstructionType.put( "bias_multiply" , CPType.Dnn); - String2CPInstructionType.put( "channel_sums" , CPType.Dnn); String2CPInstructionType.put( "batch_norm2d", CPType.Dnn); String2CPInstructionType.put( "batch_norm2d_backward", CPType.Dnn); http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 59c7350..1122a24 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -60,6 +60,7 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "lstm_backward", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn); // Matrix Multiply Operators String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary); 
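----------------------------------------------------------------------
For reference, a minimal DML sketch of the channel-sums pattern that RewriteGPUSpecificOps matches (the sizes N, C, Hin, Win are illustrative; per the guards above, the rewrite only fires with the GPU backend enabled, C and Hin*Win known at compile time, and the inputs within the GPU memory budget):

N = 64; C = 3; Hin = 32; Win = 32
X = rand(rows=N, cols=C*Hin*Win)
# rowSums(matrix(colSums(X), rows=C, cols=Hin*Win)) is compiled into a single
# fused channel_sums GPU instruction instead of colSums + reshape + rowSums
out = rowSums(matrix(colSums(X), rows=C, cols=Hin*Win))
print(sum(out))
----------------------------------------------------------------------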
http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java index 4532240..e54b430 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java @@ -80,13 +80,6 @@ public class DnnCPInstruction extends UnaryCPInstruction { } } - public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) { - this(in, in2, in3, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr); - if( !opcode.equals("channel_sums") ) { - throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode); - } - } - private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) { @@ -238,14 +231,6 @@ public class DnnCPInstruction extends UnaryCPInstruction { int k = Integer.parseInt(parts[4]); return new DnnCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5])); } - else if (opcode.equalsIgnoreCase("channel_sums")) { - InstructionUtils.checkNumFields(parts, 4); - CPOperand in = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - CPOperand out = new CPOperand(parts[4]); - return new DnnCPInstruction(in, in2, in3, out, opcode, str, -1, 0); - } else if (opcode.equalsIgnoreCase("batch_norm2d")) { InstructionUtils.checkNumFields(parts, 13); CPOperand in1 = new CPOperand(parts[1]); // image @@ -358,29 +343,6 @@ public class DnnCPInstruction extends UnaryCPInstruction { ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); } - public void processChannelSumsInstruction(ExecutionContext ec) { - MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); - int C = (int) ec.getScalarInput(_in2.getName(), _in2.getValueType(), _in2.isLiteral()).getLongValue(); - int HW = (int) ec.getScalarInput(_in3.getName(), _in3.getValueType(), _in3.isLiteral()).getLongValue(); - if(C*HW != input.getNumColumns()) { - throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns()); - } - MatrixBlock outputBlock = null; - if(input.isEmpty()) { - outputBlock = new MatrixBlock(C, 1, true); - } - else { - outputBlock = new MatrixBlock(C, 1, false).allocateBlock(); - LibMatrixDNN.channelSums(input, outputBlock, C, HW); - } - - // release inputs/outputs - ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); - ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); - } - - - public void processBatchNorm2dInstruction(ExecutionContext ec) { MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); MatrixBlock scale = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); @@ -466,10 +428,6 @@ public class DnnCPInstruction extends UnaryCPInstruction { processReluBackwardInstruction(ec); return; } - else if 
(instOpcode.equalsIgnoreCase("channel_sums")) { - processChannelSumsInstruction(ec); - return; - } else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { processBatchNorm2dInstruction(ec); return; http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java index 709be6c..b01b8d8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java @@ -151,6 +151,23 @@ public class DnnGPUInstruction extends GPUInstruction { _intermediateMemoryBudget = intermediateMemoryBudget; } + public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, + CPOperand out, String opcode, String istr, double intermediateMemoryBudget) { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + if( !opcode.equals("batch_norm2d_test") ) { + throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be batch_norm2d_test, but found " + opcode); + } + _input1 = in; + _input2 = in2; + _input3 = in3; + _input4 = in4; + _input5 = in5; + _input6 = in6; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + public static DnnGPUInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; @@ -323,6 +340,17 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out3 = new CPOperand(parts[9]); // dBias return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); } + else if (opcode.equalsIgnoreCase("batch_norm2d_test")) { + InstructionUtils.checkNumFields(parts, 7); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand in5 = new CPOperand(parts[5]); + CPOperand in6 = new CPOperand(parts[6]); + CPOperand out = new CPOperand(parts[7]); + return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0); + } else { throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str); } @@ -392,6 +420,28 @@ public class DnnGPUInstruction extends GPUInstruction { ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } + public void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); + MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); + double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue(); + + MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); + 
LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), + image, scale, bias, runningMean, runningVar, ret, epsilon); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { GPUStatistics.incrementNoOfExecutedGPUInst(); MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); @@ -613,6 +663,10 @@ public class DnnGPUInstruction extends GPUInstruction { processBatchNorm2dBackwardInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) { + processBatchNorm2dTestInstruction(ec); + return; + } GPUStatistics.incrementNoOfExecutedGPUInst(); http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/AppendTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/AppendTest.java b/src/test/java/org/apache/sysml/test/gpu/AppendTest.java index f359d48..3f2fdd6 100644 --- a/src/test/java/org/apache/sysml/test/gpu/AppendTest.java +++ b/src/test/java/org/apache/sysml/test/gpu/AppendTest.java @@ -32,7 +32,7 @@ import org.junit.Test; */ public class AppendTest extends GPUTests { - private final static String TEST_NAME = "BinaryOpTests"; + private final static String TEST_NAME = "AppendTests"; private final int seed = 42; private final int[] rowSizes = new int[] { 1, 64, 2049 }; http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java new file mode 100644 index 0000000..83adad4 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.gpu; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Test; + +/** + * Tests batchnorm rewrite + */ +public class BatchNormTest extends GPUTests { + + private final static String TEST_NAME = "BatchNormTests"; + private final int seed = 42; + + @Override + public void setUp() { + super.setUp(); + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testBatchNormForwardTest() { + testBatchNormForward("test"); + } + +// @Test +// public void testBatchNormForwardTrain() { +// testBatchNormForward("train"); +// } + + private void testBatchNormForward(String mode) { + int imgSize = 32; + int numChannels = 3; + double sparsity = 0.9; + String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n " + + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)"; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed)); + inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed)); + inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed)); + inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed)); + inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed)); + List<String> outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + if(mode.equals("test")) + assertHeavyHitterPresent("gpu_batch_norm2d_test"); + for(int i = 0; i < outputs.size(); i++) { + assertEqualObjects(outCPU.get(i), outGPU.get(i)); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java b/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java new file mode 100644 index 0000000..64e3061 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.gpu; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Test; + +/** + * Tests channel sums rewrite + */ +public class ChannelSumsTest extends GPUTests { + + private final static String TEST_NAME = "ChannelSumsTest"; + private final int seed = 42; + + @Override + public void setUp() { + super.setUp(); + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testChannelSumsTest() { + int imgSize = 32; + int numChannels = 10; + double sparsity = 0.9; + String scriptStr = "output = rowSums(matrix(colSums(x), rows=" + numChannels + ", cols=" + imgSize*imgSize + "));"; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed)); + List<String> outputs = Arrays.asList("output"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + assertHeavyHitterPresent("gpu_channel_sums"); + for(int i = 0; i < outputs.size(); i++) { + assertEqualObjects(outCPU.get(i), outGPU.get(i)); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java deleted file mode 100644 index 61ca370..0000000 --- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.sysml.test.integration.functions.tensor; - -import java.util.HashMap; - -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.lops.LopProperties.ExecType; -import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.apache.sysml.test.integration.TestConfiguration; -import org.apache.sysml.test.utils.TestUtils; -import org.junit.Test; - -public class ChannelSumTest extends AutomatedTestBase -{ - - private final static String TEST_NAME = "ChannelSumTest"; - private final static String TEST_DIR = "functions/tensor/"; - private final static String TEST_CLASS_DIR = TEST_DIR + PoolTest.class.getSimpleName() + "/"; - private final static double epsilon=0.0000000001; - - @Override - public void setUp() { - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, - new String[] {"B"})); - } - - @Test - public void testChannelSumDense1() - { - int numImg = 10; int imgSize = 9; int numChannels = 5; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false); - } - - @Test - public void testChannelSumDense2() - { - int numImg = 2; int imgSize = 5; int numChannels = 3; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false); - } - - @Test - public void testChannelSumDense3() - { - int numImg = 9; int imgSize = 4; int numChannels = 11; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false); - } - - @Test - public void testChannelSumDense4() - { - int numImg = 7; int imgSize = 8; int numChannels = 12; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false); - } - - @Test - public void testChannelSumSparse1() - { - int numImg = 4; int imgSize = 10; int numChannels = 5; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true); - } - - @Test - public void testChannelSumSparse2() - { - int numImg = 2; int imgSize = 10; int numChannels = 8; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true); - } - - @Test - public void testChannelSumSparse3() - { - int numImg = 4; int imgSize = 10; int numChannels = 11; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true); - } - - @Test - public void testChannelSumSparse4() - { - int numImg = 9; int imgSize = 6; int numChannels = 8; - runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true); - } - - public void runChannelSumTest( ExecType et, int imgSize, int numImg, int numChannels, boolean sparse) - { - RUNTIME_PLATFORM platformOld = rtplatform; - switch( et ){ - case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; - case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; - } - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == RUNTIME_PLATFORM.SPARK ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - - try - { - String sparseVal = String.valueOf(sparse).toUpperCase(); - - TestConfiguration config = getTestConfiguration(TEST_NAME); - loadTestConfiguration(config); - - String RI_HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = RI_HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-explain", "hops", "-args", String.valueOf(imgSize), - String.valueOf(numImg), String.valueOf(numChannels), - output("B"), sparseVal}; - - fullRScriptName = RI_HOME + TEST_NAME + ".R"; - rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + - " " + numChannels + " " + expectedDir() + " " + 
sparseVal; - - // run scripts - runTest(true, false, null, -1); - runRScript(true); - - //compare results - HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B"); - HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B"); - TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "NumPy"); - } - finally { - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } - } - -} http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/scripts/functions/tensor/ChannelSumTest.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.R b/src/test/scripts/functions/tensor/ChannelSumTest.R deleted file mode 100644 index c605074..0000000 --- a/src/test/scripts/functions/tensor/ChannelSumTest.R +++ /dev/null @@ -1,39 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -args <- commandArgs(TRUE) -library("Matrix") -library("matrixStats") -imgSize=as.integer(args[1]) -numImg=as.integer(args[2]) -numChannels=as.integer(args[3]) - -# Assumption: NCHW image format -x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE) -if(as.logical(args[5])) { - zero_mask = (x - 1.5*mean(x)) > 0 - x = x * zero_mask -} else { - x = x - mean(x) -} - -output = rowSums(matrix(colSums(x), numChannels, imgSize*imgSize, byrow=TRUE)); - -writeMM(as(output,"CsparseMatrix"), paste(args[4], "B", sep="")) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/scripts/functions/tensor/ChannelSumTest.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.dml b/src/test/scripts/functions/tensor/ChannelSumTest.dml deleted file mode 100644 index 7810a12..0000000 --- a/src/test/scripts/functions/tensor/ChannelSumTest.dml +++ /dev/null @@ -1,35 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -imgSize=$1 -numImg=$2 -numChannels=$3 - -# Assumption: NCHW image format -x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize) -if($5) { - zero_mask = (x - 1.5*mean(x)) > 0 - x = x * zero_mask -} -else { - x = x - mean(x) -} -output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) # shape (C, 1) -write(output, $4, format="text") \ No newline at end of file
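----------------------------------------------------------------------
For reference, a minimal DML sketch of the batch-norm inference pattern in the bias_add/bias_multiply form documented in RewriteGPUSpecificOps (names and sizes are illustrative, and eps is assumed to fold into a literal so the rewrite can extract it):

N = 32; C = 3; Hin = 32; Win = 32
eps = 1e-3
X = rand(rows=N, cols=C*Hin*Win)
gamma = rand(rows=C, cols=1); beta = rand(rows=C, cols=1)
ema_mean = rand(rows=C, cols=1); ema_var = rand(rows=C, cols=1)
# With the GPU backend enabled and the size guards satisfied, the two
# statements below fuse into one batch_norm2d_test GPU instruction backed by
# cudnnBatchNormalizationForwardInference; otherwise they execute as generic
# bias_add/bias_multiply operators.
norm = bias_multiply(bias_add(X, -ema_mean), 1/sqrt(ema_var + eps))
out = bias_add(bias_multiply(norm, gamma), beta)
print(sum(out))
----------------------------------------------------------------------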