Repository: systemml
Updated Branches:
  refs/heads/master 26d63806e -> 942696a2f


[SYSTEMML-445] Added rewrites for generating GPU-specific operators

- This commit enables SystemML to reuse existing CPU/Spark optimizations (such 
as codegen) while also exploiting GPU-specific CUDA kernels.
- As an initial step, I have only added the channel_sums and batch_norm2d_test 
patterns. The latter executes the cudnnBatchNormalizationForwardInference 
kernel.
- This commit also includes a GPU+MLContext bugfix for a specific case: an 
output variable generated in an MLContext session with GPU enabled and then 
passed to another MLContext session with GPU disabled was not pulled back to 
the host.

Closes #792.
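
For reference, the two rewritten patterns are documented in the new
RewriteGPUSpecificOps class below. The following minimal sketch (not part of
the patch) reproduces both patterns through the MLContext API; it assumes a
local Spark session, SystemML on the classpath, and a CUDA-capable GPU, and
the DML variable names are illustrative:

import org.apache.spark.sql.SparkSession;
import org.apache.sysml.api.mlcontext.MLContext;
import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;

public class GPURewriteExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("GPURewriteExample").master("local[*]").getOrCreate();
        MLContext ml = new MLContext(spark);
        ml.setGPU(true); // registers RewriteGPUSpecificOps as a dynamic rewrite

        String script =
              "X = rand(rows=64, cols=3*32*32);\n"
            + "gamma = rand(rows=3, cols=1); beta = rand(rows=3, cols=1);\n"
            + "ema_mean = rand(rows=3, cols=1); ema_var = rand(rows=3, cols=1);\n"
            // batch_norm2d_test pattern -> cudnnBatchNormalizationForwardInference
            + "norm = bias_multiply(bias_add(X, -ema_mean), 1/sqrt(ema_var + 1e-3));\n"
            + "out1 = bias_add(bias_multiply(norm, gamma), beta);\n"
            // channel_sums pattern -> single channel_sums GPU operator
            + "out2 = rowSums(matrix(colSums(X), rows=3, cols=32*32));\n"
            + "print(sum(out1));\n"
            + "print(sum(out2));";
        ml.execute(dml(script));
        ml.close();
    }
}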


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/942696a2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/942696a2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/942696a2

Branch: refs/heads/master
Commit: 942696a2fc352f2c4419508a59807b10777e6b10
Parents: 26d6380
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Jun 28 09:42:36 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Jun 28 09:45:24 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/DMLScript.java    |   2 +-
 .../apache/sysml/api/ScriptExecutorUtils.java   |  29 ++-
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  97 +++----
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  39 ++-
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   4 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 257 +++++++++++++++++++
 .../org/apache/sysml/lops/DnnTransform.java     |  35 ++-
 .../context/ExecutionContext.java               |   3 +
 .../instructions/CPInstructionParser.java       |   1 -
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/cp/DnnCPInstruction.java       |  42 ---
 .../instructions/gpu/DnnGPUInstruction.java     |  54 ++++
 .../org/apache/sysml/test/gpu/AppendTest.java   |   2 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |  75 ++++++
 .../apache/sysml/test/gpu/ChannelSumsTest.java  |  60 +++++
 .../functions/tensor/ChannelSumTest.java        | 146 -----------
 .../scripts/functions/tensor/ChannelSumTest.R   |  39 ---
 .../scripts/functions/tensor/ChannelSumTest.dml |  35 ---
 19 files changed, 588 insertions(+), 337 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java 
b/src/main/java/org/apache/sysml/api/DMLScript.java
index c737e92..215d082 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -743,7 +743,7 @@ public class DMLScript
                ExecutionContext ec = null;
                try {
                        ec = ExecutionContextFactory.createContext(rtprog);
-                       ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, 
dmlconf, STATISTICS ? STATISTICS_COUNT : 0);
+                       ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, 
dmlconf, STATISTICS ? STATISTICS_COUNT : 0, null);
                }
                finally {
                        if(ec != null && ec instanceof SparkExecutionContext)

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java 
b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index 2d913b6..0b4c7ab 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysml.api;
 
 import java.util.List;
+import java.util.Set;
 
 import org.apache.sysml.api.mlcontext.ScriptExecutor;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -28,9 +29,12 @@ import org.apache.sysml.hops.codegen.SpoofCompiler;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysml.utils.NativeHelper;
 import org.apache.sysml.utils.Statistics;
 
@@ -50,7 +54,7 @@ public class ScriptExecutorUtils {
                Program prog = se.getRuntimeProgram();
                ExecutionContext ec = se.getExecutionContext();
                DMLConfig config = se.getConfig();
-               executeRuntimeProgram(prog, ec, config, 
statisticsMaxHeavyHitters);
+               executeRuntimeProgram(prog, ec, config, 
statisticsMaxHeavyHitters, se.getScript().getOutputVariables());
        }
 
        /**
@@ -66,8 +70,10 @@ public class ScriptExecutorUtils {
         *            dml configuration
         * @param statisticsMaxHeavyHitters
         *            maximum number of statistics to print
+        * @param outputVariables
+        *            output variables that were registered as part of the MLContext session
         */
-       public static void executeRuntimeProgram(Program rtprog, 
ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters) {
+       public static void executeRuntimeProgram(Program rtprog, 
ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters, 
Set<String> outputVariables) {
                // Whether extra statistics useful for developers and others 
interested
                // in digging into performance problems are recorded and 
displayed
                DMLScript.FINEGRAINED_STATISTICS = DMLScript.STATISTICS && 
dmlconf.getBooleanValue(DMLConfig.EXTRA_FINEGRAINED_STATS);
@@ -103,6 +109,25 @@ public class ScriptExecutorUtils {
                        throw e;
                } finally { // ensure cleanup/shutdown
                        if (DMLScript.USE_ACCELERATOR && 
!ec.getGPUContexts().isEmpty()) {
+                               // -----------------------------------------------------------------
+                               // The code below pulls output variables resident on the GPU back to
+                               // the host. This is required especially when an output variable was
+                               // generated in an MLContext session with GPU enabled and is then
+                               // passed to another MLContext session with GPU disabled.
+                               // This scenario occurs in our GPU test suite (e.g., BatchNormTest).
+                               if(outputVariables != null) {
+                                       for(String outVar : outputVariables) {
+                                               Data data = 
ec.getVariable(outVar);
+                                               if(data != null && data 
instanceof MatrixObject) {
+                                                       for(GPUContext gCtx : 
ec.getGPUContexts()) {
+                                                               GPUObject 
gpuObj = ((MatrixObject)data).getGPUObject(gCtx);
+                                                               if(gpuObj != 
null && gpuObj.isDirty()) {
+                                                                       
gpuObj.acquireHostRead(null);
+                                                               }
+                                                       }
+                                               }
+                                       }
+                               }
+                               // -----------------------------------------------------------------
                                for(GPUContext gCtx : ec.getGPUContexts()) {
                                        gCtx.clearTemporaryMemory();
                                }
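
For context, here is a hypothetical two-session sketch (not part of the patch)
of the scenario that the pull-to-host logic above fixes. It assumes the
standard MLContext API; the scripts are illustrative. The registered output Y
may exist in up-to-date form only on the device, so without this fix the
second, GPU-disabled session would read stale host data:

import org.apache.spark.sql.SparkSession;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;

public class GpuToCpuHandoffExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("GpuToCpuHandoffExample").master("local[*]").getOrCreate();

        MLContext gpuML = new MLContext(spark);
        gpuML.setGPU(true);
        Script s1 = dml("Y = rand(rows=4, cols=4) * 2;").out("Y");
        // With this patch, the dirty GPU-resident copy of the registered
        // output Y is pulled back to the host before GPU cleanup.
        Matrix Y = gpuML.execute(s1).getMatrix("Y");

        MLContext cpuML = new MLContext(spark); // GPU disabled
        Script s2 = dml("Z = sum(Y);").in("Y", Y).out("Z");
        System.out.println(cpuML.execute(s2).getDouble("Z"));
    }
}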

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 1c39787..4e6cf95 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -25,7 +25,6 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Aggregate.OperationTypes;
 import org.apache.sysml.lops.Binary;
-import org.apache.sysml.lops.DnnTransform;
 import org.apache.sysml.lops.Group;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.PartialAggregate;
@@ -112,20 +111,6 @@ public class AggUnaryOp extends MultiThreadedHop
                return false;
        }
        
-       /**
-        * Checks if channels sum rewrite is applicable
-        * 
-        * @return returns true for pattern rowSums(matrix(colSums(X), rows=.., 
cols=..)) else false
-        */
-       private boolean isChannelSumRewriteApplicable() {
-               if( OptimizerUtils.ALLOW_OPERATOR_FUSION && _op == AggOp.SUM && 
_direction == Direction.Row
-                       && getInput().get(0) instanceof ReorgOp && 
((ReorgOp)getInput().get(0)).getOp() == ReOrgOp.RESHAPE) {
-                       Hop input1 = getInput().get(0).getInput().get(0);
-                       return input1 instanceof AggUnaryOp && 
((AggUnaryOp)input1)._op == AggOp.SUM && ((AggUnaryOp)input1)._direction == 
Direction.Col;
-               }
-               return false;
-       }
-       
        @Override
        public Lop constructLops()
        {
@@ -140,58 +125,42 @@ public class AggUnaryOp extends MultiThreadedHop
                        
                        if ( et == ExecType.CP || et == ExecType.GPU ) 
                        {
-                               Lop agg1 = null;
-                               long numChannels = 
isChannelSumRewriteApplicable() ? 
Hop.computeSizeInformation(getInput().get(0).getInput().get(1)) : -1;
-                               if(numChannels > 0 && numChannels < 1000000) {
-                                       // Apply channel sums only if rewrite 
is applicable and if the dimension of C is known at compile time
-                                       // and if numChannels is less than 8 MB.
-                                       ReorgOp in = 
((ReorgOp)getInput().get(0));
-                                       agg1 = new DnnTransform(
-                                                       
in.getInput().get(0).getInput().get(0).constructLops(), 
-                                                       
in.getInput().get(1).constructLops(),
-                                                       
in.getInput().get(2).constructLops(),
-                                                       
DnnTransform.OperationTypes.CHANNEL_SUMS, getDataType(), getValueType(), et, 
-1);
-                                       
agg1.getOutputParameters().setDimensions(numChannels, 1, getRowsInBlock(), 
getColsInBlock(), -1);
-                                       setLineNumbers(agg1);
-                                       setLops(agg1);
+                               Lop agg1 = null; 
+                               if( isTernaryAggregateRewriteApplicable() ) {
+                                       agg1 = 
constructLopsTernaryAggregateRewrite(et);
                                }
-                               else { 
-                                       if( 
isTernaryAggregateRewriteApplicable() ) {
-                                               agg1 = 
constructLopsTernaryAggregateRewrite(et);
-                                       }
-                                       else if( 
isUnaryAggregateOuterCPRewriteApplicable() )
-                                       {
-                                               OperationTypes op = 
HopsAgg2Lops.get(_op);
-                                               DirectionTypes dir = 
HopsDirection2Lops.get(_direction);
-       
-                                               BinaryOp binput = 
(BinaryOp)getInput().get(0);
-                                               agg1 = new UAggOuterChain( 
binput.getInput().get(0).constructLops(), 
-                                                               
binput.getInput().get(1).constructLops(), op, dir, 
-                                                               
HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), 
ExecType.CP);
-                                               
PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), 
input.getRowsInBlock(), input.getColsInBlock(), dir);
-                                       
-                                               if (getDataType() == 
DataType.SCALAR) {
-                                                       UnaryCP unary1 = new 
UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
-                                                               getDataType(), 
getValueType());
-                                                       
unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
-                                                       setLineNumbers(unary1);
-                                                       agg1 = unary1;
-                                               }
-                                       
-                                       }
-                                       else { //general case
-                                               int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-                                               agg1 = new 
PartialAggregate(input.constructLops(), 
-                                                               
HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), 
getDataType(),getValueType(), et, k);
-                                       }
-                                       
-                                       setOutputDimensions(agg1);
-                                       setLineNumbers(agg1);
-                                       setLops(agg1);
-                                       
+                               else if( 
isUnaryAggregateOuterCPRewriteApplicable() )
+                               {
+                                       OperationTypes op = 
HopsAgg2Lops.get(_op);
+                                       DirectionTypes dir = 
HopsDirection2Lops.get(_direction);
+
+                                       BinaryOp binput = 
(BinaryOp)getInput().get(0);
+                                       agg1 = new UAggOuterChain( 
binput.getInput().get(0).constructLops(), 
+                                                       
binput.getInput().get(1).constructLops(), op, dir, 
+                                                       
HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), 
ExecType.CP);
+                                       
PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), 
input.getRowsInBlock(), input.getColsInBlock(), dir);
+                               
                                        if (getDataType() == DataType.SCALAR) {
-                                               
agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), 
getColsInBlock(), getNnz());
+                                               UnaryCP unary1 = new 
UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
+                                                       getDataType(), 
getValueType());
+                                               
unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
+                                               setLineNumbers(unary1);
+                                               agg1 = unary1;
                                        }
+                               
+                               }
+                               else { //general case
+                                       int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+                                       agg1 = new 
PartialAggregate(input.constructLops(), 
+                                                       HopsAgg2Lops.get(_op), 
HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);
+                               }
+                               
+                               setOutputDimensions(agg1);
+                               setLineNumbers(agg1);
+                               setLops(agg1);
+                               
+                               if (getDataType() == DataType.SCALAR) {
+                                       
agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), 
getColsInBlock(), getNnz());
                                }
                        }
                        else if( et == ExecType.MR )

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java 
b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 0a0e50a..8dbbeda 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -31,6 +31,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DnnParameters;
+
 import java.util.ArrayList;
 
 public class DnnOp extends MultiThreadedHop
@@ -133,6 +134,18 @@ public class DnnOp extends MultiThreadedHop
                                }
                                // break;
                        }
+                       case BATCH_NORM2D_TEST:
+                       case CHANNEL_SUMS:
+                       {       
+                               if(et == ExecType.GPU) {
+                                       setLops(constructDnnLops(et, inputs));
+                                       break;
+                               }
+                               else {
+                                       throw new HopsException("Unimplemented 
DnnOp for execution type: " + et.name());
+                               }
+                               // break;
+                       }
                        default: 
                                throw new HopsException("Unsupported lops 
construction for operation type '"+op+"'.");
                }
@@ -158,6 +171,10 @@ public class DnnOp extends MultiThreadedHop
                        case BIASADD:
                        case BIASMULT:
                                return 2;
+                       case BATCH_NORM2D_TEST:
+                               return 6;
+                       case CHANNEL_SUMS:
+                               return 3;
                        default:
                                return 13;
                }
@@ -505,13 +522,20 @@ public class DnnOp extends MultiThreadedHop
                // [numRows, numCols, NNZ] 
                long[] ret = new long[3];
                
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) {
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST) {
                        MatrixCharacteristics[] mc = 
memo.getAllInputStats(getInput());
                        ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
                        ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1;
                        ret[2] = -1;
                        return (ret[0]>=0 && ret[1]>=0) ? ret : null;
                }
+               else if(op == OpOpDnn.CHANNEL_SUMS) {
+                       long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
+                       ret[0] = numChannels;
+                       ret[1] = 1;
+                       ret[2] = -1;
+                       return ret;
+               }
                
                refreshSizeInformation();
                ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz;
@@ -708,13 +732,20 @@ public class DnnOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) {
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST) {
                        Hop input1 = getInput().get(0);
                        setDim1(input1.getDim1());
                        setDim2(input1.getDim2());
                        _nnz = -1; // cannot infer stats
                        return;
                }
+               else if(op == OpOpDnn.CHANNEL_SUMS) {
+                       long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
+                       setDim1(numChannels);
+                       setDim2(1);
+                       _nnz = -1; // cannot infer stats
+                       return;
+               }
                
                // Reset the _cachedParams to avoid incorrect sizes
                _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, _maxNumThreads);
@@ -807,8 +838,8 @@ public class DnnOp extends MultiThreadedHop
         * @return either -1 or value associated with the dimString
         */
        private long getDim(String dimString) {
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT) {
-                       throw new RuntimeException("getDim method should not be 
invoked for bias_add and bias_multiply");
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) {
+                       throw new RuntimeException("getDim method should not be 
invoked for batch_norm2d_test, channel_sums, bias_add and bias_multiply");
                }
                try {
                        parseInput();

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 436d45f..5d357c6 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1099,7 +1099,7 @@ public abstract class Hop implements ParseInfo
        public enum OpOpDnn {
                MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-               BIASADD, BIASMULT
+               BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS
        }
        
        public enum DataGenMethod {
@@ -1172,6 +1172,8 @@ public abstract class Hop implements ParseInfo
                HopsConv2Lops.put(OpOpDnn.BIASMULT, 
org.apache.sysml.lops.DnnTransform.OperationTypes.BIAS_MULTIPLY);
                HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER);
                HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
+               HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, 
org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
+               HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
        }
 
        protected static final HashMap<Hop.Direction, 
org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 2963e9d..3d4eafd 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -24,6 +24,7 @@ import java.util.List;
 
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.CompilerConfig.ConfigType;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop;
@@ -121,6 +122,9 @@ public class ProgramRewriter
                // DYNAMIC REWRITES (which do require size information)
                if( dynamicRewrites )
                {
+                       if ( DMLScript.USE_ACCELERATOR ){
+                               _dagRuleSet.add( new RewriteGPUSpecificOps() ); 
// gpu-specific rewrites
+                       }
                        if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
                                _dagRuleSet.add( new 
RewriteMatrixMultChainOptimization()         ); //dependency: cse
                                _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()    ); //dependency: cse

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
new file mode 100644
index 0000000..987d9cd
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpDnn;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.DnnOp;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+
+/*
+ * This class contains GPU-specific rewrites for the following patterns:
+ * 
+ * 1. batchNormTest:
+ * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+ * hi = bias_add(bias_multiply(norm, gamma), beta)
+ * 
+ * 2. channelSum:
+ * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ */
+public class RewriteGPUSpecificOps extends HopRewriteRule {
+
+       @Override
+       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
+               if( roots == null )
+                       return roots;
+
+               //one pass rewrite-descend (rewrite created pattern)
+               for( Hop h : roots )
+                       rule_GPUKernels( h, false );
+               Hop.resetVisitStatus(roots, true);
+
+               //one pass descend-rewrite (for rollup) 
+               for( Hop h : roots )
+                       rule_GPUKernels( h, true );
+               Hop.resetVisitStatus(roots, true);
+               
+               return roots;
+       }
+
+       @Override
+       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+               if( root == null )
+                       return root;
+               
+               //one pass rewrite-descend (rewrite created pattern)
+               rule_GPUKernels( root, false );
+               
+               root.resetVisitStatus();
+               
+               //one pass descend-rewrite (for rollup) 
+               rule_GPUKernels( root, true );
+               
+               return root;
+       }
+       
+       /**
+        * Apply the GPU-specific operator rewrites (kernel fusion) to the given HOP.
+        * 
+        * @param hop high-level operator
+        * @param descendFirst true to recursively process children first
+        */
+       private void rule_GPUKernels(Hop hop, boolean descendFirst) 
+       {
+               if(hop.isVisited())
+                       return;
+               
+               //recursively process children
+               for( int i=0; i<hop.getInput().size(); i++)
+               {
+                       Hop hi = hop.getInput().get(i);
+                       
+                       //process children recursively first (to allow roll-up)
+                       if( descendFirst )
+                               rule_GPUKernels(hi, descendFirst); //see below
+                       
+                       hi = batchNormTest(hop, hi, i); 
+                       hi = channelSums(hop, hi, i); 
+       
+                       if( !descendFirst )
+                               rule_GPUKernels(hi, descendFirst);
+               }
+
+               hop.setVisited();
+       }
+       
+       private static boolean isBiasAdd(Hop h) {
+               return h instanceof DnnOp && ((DnnOp) h).getOp() == 
OpOpDnn.BIASADD;
+       }
+       
+       private static boolean isBiasMultiply(Hop h) {
+               return h instanceof DnnOp && ((DnnOp) h).getOp() == 
OpOpDnn.BIASMULT;
+       }
+       
+       private static boolean fitsOnGPU(Hop h, double multiplier) {
+               double memEst = multiplier*h.getMemEstimate();
+               return DMLScript.USE_ACCELERATOR && h.dimsKnown() && 
OptimizerUtils.isMemoryBasedOptLevel() &&
+                               memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
+       }
+       
+       private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean 
isFirstSameSizeAsOutput) {
+               return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0);
+       }
+       
+       private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean 
isFirstSameSizeAsOutput, long additionalBytes) {
+               double memEst = additionalBytes;
+               boolean isFirst = true;
+               for(Hop h : inputHops) {
+                       double est = h.getMemEstimate();
+                       if(est == OptimizerUtils.INVALID_SIZE) {
+                               return false;
+                       }
+                       else if(isFirst && isFirstSameSizeAsOutput) {
+                               isFirst = false;
+                               memEst += 2*est;
+                       }
+                       else {
+                               memEst += est;
+                       }
+               }
+               return DMLScript.USE_ACCELERATOR && 
OptimizerUtils.isMemoryBasedOptLevel() &&
+                               memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
+       }
+       
+       private static Hop getFirstInput(Hop h) {
+               if(h == null || h.getInput() == null || h.getInput().size() < 
1) {
+                       throw new RuntimeException("No input available for " + 
h);
+               }
+               return h.getInput().get(0);
+       }
+       
+       private static Hop getSecondInput(Hop h) {
+               if(h == null || h.getInput() == null || h.getInput().size() < 
2) {
+                       throw new RuntimeException("No second input available for " + 
h);
+               }
+               return h.getInput().get(1);
+       }
+       
+       private static boolean isUnaryMinus(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
+                               && 
Hop.computeSizeInformation(h.getInput().get(0)) == 0;
+       }
+       
+       private static boolean isOneDivideBySqrt(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
+                               && h.getInput().get(1) instanceof UnaryOp
+                               && ((UnaryOp)h.getInput().get(1)).getOp() == 
OpOp1.SQRT
+                               && 
Hop.computeSizeInformation(h.getInput().get(0)) == 1;
+       }
+       
+       private static Hop channelSums(Hop parent, Hop hi, int pos) 
+       {
+               if(hi instanceof AggUnaryOp) {
+                       AggUnaryOp hop = (AggUnaryOp) hi;
+                       // output = rowSums(matrix(colSums(x), 
rows=numChannels, cols=imgSize*imgSize))
+                       if(hop.getOp() == AggOp.SUM && hop.getDirection() == 
Direction.Row
+                               && hop.getInput().get(0) instanceof ReorgOp && 
((ReorgOp)hop.getInput().get(0)).getOp() == ReOrgOp.RESHAPE) {
+                               Hop colSumsInput = 
hop.getInput().get(0).getInput().get(0);
+                               if(colSumsInput instanceof AggUnaryOp && 
((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && 
((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) {
+                                       ArrayList<Hop> inHops = new 
ArrayList<Hop>();
+                                       
inHops.add(colSumsInput.getInput().get(0));
+                                       long numChannels = 
Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1));
+                                       long HW = 
Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2));
+                                       if(numChannels > 0 && HW > 0 && 
fitsOnGPU(inHops, false, numChannels*8)) {
+                                               inHops.add(new 
LiteralOp(numChannels));
+                                               inHops.add(new LiteralOp(HW));
+                                               LOG.debug("Applied channelSums 
rewrite.");
+                                               Hop newHop = new 
DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+                                                               
OpOpDnn.CHANNEL_SUMS, inHops);
+                                               return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+                                       }
+                               }
+                       }
+               }
+               return hi;
+       }
+       
+       private static Hop batchNormTest(Hop parent, Hop hi, int pos) 
+       {               
+               // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+               // hi = bias_add(bias_multiply(norm, gamma), beta)
+               // 2x for input and output and 1x for overhead
+               if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && 
fitsOnGPU(hi, 3) ) {  
+                       Hop norm = getFirstInput(getFirstInput(hi));
+                       if(isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
+                                       && 
isUnaryMinus(getSecondInput(getFirstInput(norm)))
+                                       && 
isOneDivideBySqrt(getSecondInput(norm))) {
+                               double eps = 0;
+                               Hop var = 
getFirstInput(getSecondInput(getSecondInput(norm)));
+                               if(var instanceof BinaryOp && ((BinaryOp) 
var).getOp() == OpOp2.PLUS &&
+                                       (getFirstInput(var) instanceof 
LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
+                                       // eps + ema_var
+                                       if(getFirstInput(var) instanceof 
LiteralOp) {
+                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
+                                               var = getSecondInput(var);
+                                       }
+                                       else {
+                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new 
HashMap<>());
+                                               var = getFirstInput(var);
+                                       }
+                               }
+                               // Generate batch norm test op
+                               Hop X = getFirstInput(getFirstInput(norm));
+                               Hop mean = 
getSecondInput(getSecondInput(getFirstInput(norm)));
+                               Hop gamma = getSecondInput(getFirstInput(hi));
+                               Hop beta = getSecondInput(hi);
+                               ArrayList<Hop> inHops = new ArrayList<Hop>();
+                               inHops.add(X);
+                               inHops.add(gamma);
+                               inHops.add(beta);
+                               inHops.add(mean);
+                               inHops.add(var);
+                               inHops.add(new LiteralOp(eps));
+                               if(fitsOnGPU(inHops, true)) {
+                                       LOG.debug("Applied batchNormTest 
rewrite.");
+                                       Hop newHop = new DnnOp(hi.getName(), 
hi.getDataType(), hi.getValueType(),
+                                                       
OpOpDnn.BATCH_NORM2D_TEST, inHops);
+                                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+                               }
+                       }                       
+               }
+               
+               return hi;
+       }
+
+}
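
As a quick sanity check that these rewrites fire at runtime, statistics can be
enabled in the sketch shown after the commit message above. This fragment is a
hedged sketch: the gpu_-prefixed heavy-hitter names are an assumption based on
how GPU instructions are commonly reported in SystemML statistics output:

        MLContext ml = new MLContext(spark); // spark as in the earlier sketch
        ml.setGPU(true);
        ml.setStatistics(true); // print runtime statistics after execute()
        ml.execute(dml(script)); // script containing the two patterns above
        // Expected among the printed heavy hitters (assumed naming):
        //   gpu_batch_norm2d_test and gpu_channel_sums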

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java 
b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 02dcec1..6c61d4a 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -31,7 +31,7 @@ public class DnnTransform extends Lop
                MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
                RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-               BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS
+               BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, 
BATCH_NORM2D_TEST
        }
        
        private OperationTypes operation;
@@ -166,6 +166,9 @@ public class DnnTransform extends Lop
                case CHANNEL_SUMS:
                        return "channel_sums";
                        
+               case BATCH_NORM2D_TEST:
+                       return "batch_norm2d_test";
+                       
                default:
                        throw new 
UnsupportedOperationException(this.printErrorLocation() + "Instruction is not 
defined for Transform operation " + operation);
                                
@@ -242,6 +245,36 @@ public class DnnTransform extends Lop
                
                return sb.toString();
        }
+       
+       public String getInstructions(String input1, String input2, String 
input3, String input4, String input5, String input6, String output) {
+               if(operation == OperationTypes.BATCH_NORM2D_TEST) {
+                       StringBuilder sb = new StringBuilder();
+                       sb.append( getExecType() );
+                       
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getOpcode() );
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(0).prepInputOperand(input1));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(1).prepInputOperand(input2));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(2).prepInputOperand(input3));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(3).prepInputOperand(input4));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(4).prepInputOperand(input5));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(5).prepInputOperand(input6));
+                       //output
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( this.prepOutputOperand(output));
+                       
+                       return sb.toString();
+               }
+               else {
+                       throw new LopsException("The operation is not supported 
with six operands: " + operation.name());
+               }
+       }
 
        public void appendOpcode(StringBuilder sb) {
                sb.append( getExecType() );

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index d87f9d9..77598bf 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -53,7 +53,9 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.MetaDataFormat;
 import org.apache.sysml.runtime.matrix.MetaData;
 import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.matrix.data.Pair;
 import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.utils.GPUStatistics;
@@ -462,6 +464,7 @@ public class ExecutionContext {
                if(mo.getGPUObject(getGPUContext(0)) == null || 
!mo.getGPUObject(getGPUContext(0)).isAllocated()) {
                        throw new DMLRuntimeException("No output is allocated 
on GPU");
                }
+               setMetaData(varName, new 
MetaDataFormat(mo.getMatrixCharacteristics(), OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo));
                mo.getGPUObject(getGPUContext(0)).releaseOutput();
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index fcc27e9..4663671 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -248,7 +248,6 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "conv2d_backward_data"      , 
CPType.Dnn);
                String2CPInstructionType.put( "bias_add"      , CPType.Dnn);
                String2CPInstructionType.put( "bias_multiply"      , 
CPType.Dnn);
-               String2CPInstructionType.put( "channel_sums"      , CPType.Dnn);
                String2CPInstructionType.put( "batch_norm2d",           
CPType.Dnn);
                String2CPInstructionType.put( "batch_norm2d_backward",  
CPType.Dnn);
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 59c7350..1122a24 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -60,6 +60,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "lstm_backward",         
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d",           
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_backward",  
GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "batch_norm2d_test",      
GPUINSTRUCTION_TYPE.Dnn);
                
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  
GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 4532240..e54b430 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -80,13 +80,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                }
        }
        
-       public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, 
CPOperand out, String opcode, String istr, int numThreads, double 
intermediateMemoryBudget) {
-               this(in, in2, in3, out, null, null, null, null, numThreads, 
intermediateMemoryBudget, opcode, istr);
-               if( !opcode.equals("channel_sums") ) {
-                       throw new DMLRuntimeException("Incorrect usage. 
Expected the opcode to be channel_sums, but found " + opcode);
-               }
-       }
-
        private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, 
String istr,
                        ArrayList<CPOperand> stride, ArrayList<CPOperand> 
padding, ArrayList<CPOperand> input_shape,
                        ArrayList<CPOperand> filter_shape, int numThreads, 
double intermediateMemoryBudget) {
@@ -238,14 +231,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        int k = Integer.parseInt(parts[4]);
                        return new DnnCPInstruction(in, in2, out, opcode, str, 
k, Double.parseDouble(parts[5]));
                }
-               else if (opcode.equalsIgnoreCase("channel_sums")) {
-                       InstructionUtils.checkNumFields(parts, 4);
-                       CPOperand in = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = new CPOperand(parts[3]);
-                       CPOperand out = new CPOperand(parts[4]);
-                       return new DnnCPInstruction(in, in2, in3, out, opcode, 
str, -1, 0);
-               }
                else if (opcode.equalsIgnoreCase("batch_norm2d")) {
                        InstructionUtils.checkNumFields(parts, 13);
                        CPOperand in1 = new CPOperand(parts[1]); // image
@@ -358,29 +343,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
        }
        
-       public void processChannelSumsInstruction(ExecutionContext ec) {
-               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
-               int C = (int) ec.getScalarInput(_in2.getName(), 
_in2.getValueType(), _in2.isLiteral()).getLongValue();
-               int HW = (int) ec.getScalarInput(_in3.getName(), 
_in3.getValueType(), _in3.isLiteral()).getLongValue();
-               if(C*HW != input.getNumColumns()) {
-                       throw new DMLRuntimeException("Expected rows*cols" + C 
+ "*" + HW + " to be equal to number of columns of input " + 
input.getNumColumns());
-               }
-               MatrixBlock outputBlock = null;
-               if(input.isEmpty()) {
-                       outputBlock = new MatrixBlock(C, 1, true);
-               }
-               else {
-                       outputBlock = new MatrixBlock(C, 1, 
false).allocateBlock();
-                       LibMatrixDNN.channelSums(input, outputBlock, C, HW);
-               }
-               
-               // release inputs/outputs
-               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
-               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
-       }
-       
-       
-       
        public void processBatchNorm2dInstruction(ExecutionContext ec) {
                MatrixBlock image = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
                MatrixBlock scale = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
@@ -466,10 +428,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        processReluBackwardInstruction(ec);
                        return;
                }
-               else if (instOpcode.equalsIgnoreCase("channel_sums")) {
-                       processChannelSumsInstruction(ec);
-                       return;
-               }
                else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
                        processBatchNorm2dInstruction(ec);
                        return;

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 709be6c..b01b8d8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -151,6 +151,23 @@ public class DnnGPUInstruction extends GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
 
+       public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand in5, CPOperand in6,
+                       CPOperand out, String opcode, String istr, double 
intermediateMemoryBudget) {
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
+               if( !opcode.equals("batch_norm2d_test") ) {
+                       throw new DMLRuntimeException("Incorrect usage. 
Expected the opcode to be batch_norm2d_test, but found " + opcode);
+               }
+               _input1 = in;
+               _input2 = in2;
+               _input3 = in3;
+               _input4 = in4;
+               _input5 = in5;
+               _input6 = in6;
+               _gputype = GPUINSTRUCTION_TYPE.Dnn;
+               _output = out;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
        public static DnnGPUInstruction parseInstruction(String str) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
@@ -323,6 +340,17 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out3 = new CPOperand(parts[9]); // dBias
                        return new DnnGPUInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, out3, null, null, opcode, str, 0);
                }
+               else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
+                       InstructionUtils.checkNumFields(parts, 7);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand in4 = new CPOperand(parts[4]);
+                       CPOperand in5 = new CPOperand(parts[5]);
+                       CPOperand in6 = new CPOperand(parts[6]);
+                       CPOperand out = new CPOperand(parts[7]);
+                       return new DnnGPUInstruction(in, in2, in3, in4, in5, 
in6, out, opcode, str, 0);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a DnnGPUInstruction: " + str);      
                }
@@ -392,6 +420,28 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
+       public void processBatchNorm2dTestInstruction(ExecutionContext ec) 
throws DMLRuntimeException {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               MatrixObject image = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
+               MatrixObject scale = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
+               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
+               MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
+               MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, 
_input5.getName());
+               double epsilon = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getDoubleValue();
+               
+               MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), image.getNumRows(), image.getNumColumns());
+               
LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), 
getExtendedOpcode(), 
+                               image, scale, bias, runningMean, runningVar, 
ret, epsilon);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+       }
+       
        public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) 
throws DMLRuntimeException {
                GPUStatistics.incrementNoOfExecutedGPUInst();
                MatrixObject image = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
@@ -613,6 +663,10 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processBatchNorm2dBackwardInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
+                       processBatchNorm2dTestInstruction(ec);
+                       return;
+               }
                
                GPUStatistics.incrementNoOfExecutedGPUInst();
                                        

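For readers unfamiliar with the cuDNN call added above: in inference ("test") mode, batch normalization applies the per-channel running statistics rather than batch statistics. A minimal per-element sketch of what cudnnBatchNormalizationForwardInference computes, assuming NCHW layout (the helper name and scalar signature are illustrative, not part of this patch):

    // Per-element view of batch_norm2d_test: normalize x with the channel's
    // running statistics, then apply the learned affine transform.
    static double batchNormInference(double x, double scale, double bias,
            double runningMean, double runningVar, double epsilon) {
        return scale * (x - runningMean) / Math.sqrt(runningVar + epsilon) + bias;
    }
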
http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/AppendTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/AppendTest.java b/src/test/java/org/apache/sysml/test/gpu/AppendTest.java
index f359d48..3f2fdd6 100644
--- a/src/test/java/org/apache/sysml/test/gpu/AppendTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/AppendTest.java
@@ -32,7 +32,7 @@ import org.junit.Test;
  */
 public class AppendTest extends GPUTests {
 
-       private final static String TEST_NAME = "BinaryOpTests";
+       private final static String TEST_NAME = "AppendTests";
        private final int seed = 42;
 
        private final int[] rowSizes = new int[] { 1, 64, 2049 };

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
new file mode 100644
index 0000000..83adad4
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests batchnorm rewrite
+ */
+public class BatchNormTest extends GPUTests {
+
+       private final static String TEST_NAME = "BatchNormTests";
+       private final int seed = 42;
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+       @Test
+       public void testBatchNormForwardTest() {
+               testBatchNormForward("test");
+       }
+       
+//     @Test
+//     public void testBatchNormForwardTrain() {
+//             testBatchNormForward("train");
+//     }
+       
+       private void testBatchNormForward(String mode) {
+               int imgSize = 32; 
+               int numChannels = 3;
+               double sparsity = 0.9;
+               String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
+                               + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed));
+               inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
+               inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
+               inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed));
+               inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed));
+               List<String> outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var");
+               List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+               List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+               if(mode.equals("test"))
+                       assertHeavyHitterPresent("gpu_batch_norm2d_test");
+               for(int i = 0; i < outputs.size(); i++) {
+                       assertEqualObjects(outCPU.get(i), outGPU.get(i));
+               }
+       }
+}
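
The heavy-hitter assertion above is restricted to "test" mode because only inference maps onto the cuDNN inference kernel; training additionally maintains exponential moving averages of the batch statistics (the ema_mean_upd/ema_var_upd outputs). A minimal sketch of that bookkeeping, assuming the momentum mu = 0.9 used in the test script (the helper name is illustrative):

    // Train mode also updates the running statistics, which the inference
    // kernel does not do; mu is the EMA momentum (0.9 in the script above).
    static double emaUpdate(double ema, double batchStat, double mu) {
        return mu * ema + (1 - mu) * batchStat;
    }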

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java b/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java
new file mode 100644
index 0000000..64e3061
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/ChannelSumsTest.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests channel sums rewrite
+ */
+public class ChannelSumsTest extends GPUTests {
+
+       private final static String TEST_NAME = "ChannelSumsTest";
+       private final int seed = 42;
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+       @Test
+       public void testChannelSumsTest() {
+               int imgSize = 32; 
+               int numChannels = 10;
+               double sparsity = 0.9;
+               String scriptStr = "output = rowSums(matrix(colSums(x), rows=" + numChannels + ", cols=" + imgSize*imgSize + "));";
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed));
+               List<String> outputs = Arrays.asList("output");
+               List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+               List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+               assertHeavyHitterPresent("gpu_channel_sums");
+               for(int i = 0; i < outputs.size(); i++) {
+                       assertEqualObjects(outCPU.get(i), outGPU.get(i));
+               }
+       }
+}
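
As a usage sketch, the same pattern can be driven through MLContext with GPU enabled, which is how the channel_sums rewrite and its gpu_channel_sums heavy hitter would typically be observed outside the test harness. The SparkSession setup and the xDataFrame input are illustrative placeholders, not part of this patch:

    import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
    import org.apache.spark.sql.SparkSession;
    import org.apache.sysml.api.mlcontext.MLContext;
    import org.apache.sysml.api.mlcontext.Script;

    SparkSession spark = SparkSession.builder().appName("channel-sums").getOrCreate();
    MLContext ml = new MLContext(spark);
    ml.setGPU(true);        // allow GPU-specific rewrites and instructions
    ml.setStatistics(true); // report heavy hitters such as gpu_channel_sums
    // X is an N x (C*H*W) matrix in NCHW row-major layout, e.g. C=3, H=W=32
    Script script = dml("output = rowSums(matrix(colSums(X), rows=3, cols=1024));")
        .in("X", xDataFrame) // xDataFrame: a user-supplied DataFrame of doubles
        .out("output");
    ml.execute(script);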

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
deleted file mode 100644
index 61ca370..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.test.integration.functions.tensor;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Test;
-
-public class ChannelSumTest extends AutomatedTestBase
-{
-       
-       private final static String TEST_NAME = "ChannelSumTest";
-       private final static String TEST_DIR = "functions/tensor/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + PoolTest.class.getSimpleName() + "/";
-       private final static double epsilon=0.0000000001;
-       
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
-                               new String[] {"B"}));
-       }
-       
-       @Test
-       public void testChannelSumDense1() 
-       {
-               int numImg = 10; int imgSize = 9; int numChannels = 5; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
-       }
-       
-       @Test
-       public void testChannelSumDense2() 
-       {
-               int numImg = 2; int imgSize = 5; int numChannels = 3; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
-       }
-       
-       @Test
-       public void testChannelSumDense3() 
-       {
-               int numImg = 9; int imgSize = 4; int numChannels = 11; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
-       }
-       
-       @Test
-       public void testChannelSumDense4() 
-       {
-               int numImg = 7; int imgSize = 8; int numChannels = 12; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
-       }
-       
-       @Test
-       public void testChannelSumSparse1() 
-       {
-               int numImg = 4; int imgSize = 10; int numChannels = 5; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
-       }
-       
-       @Test
-       public void testChannelSumSparse2() 
-       {
-               int numImg = 2; int imgSize = 10; int numChannels = 8; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
-       }
-       
-       @Test
-       public void testChannelSumSparse3() 
-       {
-               int numImg = 4; int imgSize = 10; int numChannels = 11; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
-       }
-       
-       @Test
-       public void testChannelSumSparse4() 
-       {
-               int numImg = 9; int imgSize = 6; int numChannels = 8; 
-               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
-       }
-       
-       public void runChannelSumTest( ExecType et, int imgSize, int numImg, int numChannels, boolean sparse) 
-       {
-               RUNTIME_PLATFORM platformOld = rtplatform;
-               switch( et ){
-                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
-                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
-               }
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == RUNTIME_PLATFORM.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-               
-               try
-               {
-                       String sparseVal = String.valueOf(sparse).toUpperCase();
-                       
-                       TestConfiguration config = getTestConfiguration(TEST_NAME);
-                       loadTestConfiguration(config);
-       
-                       String RI_HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-explain", "hops", "-args", String.valueOf(imgSize), 
-                               String.valueOf(numImg), String.valueOf(numChannels),
-                               output("B"), sparseVal};
-                       
-                       fullRScriptName = RI_HOME + TEST_NAME + ".R";
-                       rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
-                               " " + numChannels + " " + expectedDir() + " " + sparseVal; 
-                       
-                       // run scripts
-                       runTest(true, false, null, -1);
-                       runRScript(true);
-                       
-                       //compare results
-                       HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
-                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
-                       TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "NumPy");
-               }
-               finally {
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-               }
-       }
-       
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/scripts/functions/tensor/ChannelSumTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.R b/src/test/scripts/functions/tensor/ChannelSumTest.R
deleted file mode 100644
index c605074..0000000
--- a/src/test/scripts/functions/tensor/ChannelSumTest.R
+++ /dev/null
@@ -1,39 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-args <- commandArgs(TRUE)
-library("Matrix")
-library("matrixStats") 
-imgSize=as.integer(args[1])
-numImg=as.integer(args[2])
-numChannels=as.integer(args[3])
-
-# Assumption: NCHW image format
-x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
-if(as.logical(args[5])) {
-       zero_mask = (x - 1.5*mean(x)) > 0 
-       x = x * zero_mask
-} else {
-       x = x - mean(x)
-}
-
-output = rowSums(matrix(colSums(x), numChannels, imgSize*imgSize, byrow=TRUE));
-
-writeMM(as(output,"CsparseMatrix"), paste(args[4], "B", sep=""))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/942696a2/src/test/scripts/functions/tensor/ChannelSumTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.dml b/src/test/scripts/functions/tensor/ChannelSumTest.dml
deleted file mode 100644
index 7810a12..0000000
--- a/src/test/scripts/functions/tensor/ChannelSumTest.dml
+++ /dev/null
@@ -1,35 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# 
-#-------------------------------------------------------------
-imgSize=$1
-numImg=$2
-numChannels=$3
-
-# Assumption: NCHW image format
-x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
-if($5) {
-       zero_mask = (x - 1.5*mean(x)) > 0 
-       x = x * zero_mask
-}
-else {
-       x = x - mean(x)
-}
-output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))  # shape (C, 1)
-write(output, $4, format="text")
\ No newline at end of file
