Repository: systemml
Updated Branches:
  refs/heads/master 81419ae6a -> 0f36780a8


http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index e736a1c..d620de9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -19,8 +19,8 @@
 package org.apache.sysml.runtime.instructions.gpu;
 
 import java.util.ArrayList;
-
 import jcuda.Pointer;
+import jcuda.jcudnn.JCudnn;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
@@ -32,7 +32,6 @@ import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DnnUtils;
@@ -57,12 +56,14 @@ public class DnnGPUInstruction extends GPUInstruction {
        private ArrayList<CPOperand> _stride = new ArrayList<>();
        private ArrayList<CPOperand> _padding = new ArrayList<>();
        private double _intermediateMemoryBudget = 0;
+       private GPUContext gCtx;
+       private String instName;
        
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) {
+		if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward") || opcode.equals("inv_var") )) {
 			throw new DMLRuntimeException(
-					"Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found "
+					"Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward or inv_var, but found "
 							+ opcode);
 		}
                _input1 = in1;
@@ -112,8 +113,8 @@ public class DnnGPUInstruction extends GPUInstruction {
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if( !opcode.equals("channel_sums") ) {
-			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+		if( !(opcode.equals("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema") ) ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums or reshape_colmeans or update_ema, but found " + opcode);
 		}
                _input1 = in1;
                _input2 = in2;
@@ -126,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if( !opcode.equals("update_nesterov_x") ) {
+		if( !( opcode.equals("update_nesterov_x")) ) {
 			throw new DMLRuntimeException("Incorrect opcode: " + opcode);
 		}
                _input1 = in1;
@@ -182,6 +183,22 @@ public class DnnGPUInstruction extends GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
+	public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, 
+			CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		if( !(opcode.equals("update_ema_var") || opcode.equals("batch_norm2d_bwd_dx")) ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be update_ema_var or batch_norm2d_bwd_dx, but found " + opcode);
+		}
+		_input1 = in;
+		_input2 = in2;
+		_input3 = in3;
+		_input4 = in4;
+		_input5 = in5;
+		_gputype = GPUINSTRUCTION_TYPE.Dnn;
+		_output = out;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
+	
        public static DnnGPUInstruction parseInstruction(String str) {
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
@@ -297,14 +314,15 @@ public class DnnGPUInstruction extends GPUInstruction {
 			return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
 					padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
 		}
-		else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")  ) {
+		else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") 
+				|| opcode.equalsIgnoreCase("inv_var") ) {
 			InstructionUtils.checkNumFields(parts, 4);
 			CPOperand in1 = new CPOperand(parts[1]);
 			CPOperand in2 = new CPOperand(parts[2]);
 			CPOperand out = new CPOperand(parts[3]);
 			return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
 		}
-		else if (opcode.equalsIgnoreCase("channel_sums")) {
+		else if (opcode.equalsIgnoreCase("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema")) {
                        InstructionUtils.checkNumFields(parts, 4);
                        CPOperand in = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
@@ -333,7 +351,7 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out2 = new CPOperand(parts[8]);
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
 		}
-		else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
+		else if (opcode.equalsIgnoreCase("lstm_backward")) {
                        InstructionUtils.checkNumFields(parts, 13);
                        CPOperand in1 = new CPOperand(parts[1]); // image
                        CPOperand in2 = new CPOperand(parts[2]); // scale
@@ -350,19 +368,6 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
 		}
-		else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
-			InstructionUtils.checkNumFields(parts, 9);
-			CPOperand in1 = new CPOperand(parts[1]); // image
-			CPOperand in2 = new CPOperand(parts[2]); // dout
-			CPOperand in3 = new CPOperand(parts[3]); // scale
-			CPOperand in4 = new CPOperand(parts[4]); // epsilon
-			CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
-			CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
-			CPOperand out = new CPOperand(parts[7]);  // dX
-			CPOperand out2 = new CPOperand(parts[8]); // dScale
-			CPOperand out3 = new CPOperand(parts[9]); // dBias
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
-		}
                else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
                        InstructionUtils.checkNumFields(parts, 7);
                        CPOperand in = new CPOperand(parts[1]);
@@ -374,21 +379,25 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out = new CPOperand(parts[7]);
 			return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
                }
-		else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
-			InstructionUtils.checkNumFields(parts, 12);
-			CPOperand in1 = new CPOperand(parts[1]); // image
-			CPOperand in2 = new CPOperand(parts[2]); // gamma
-			CPOperand in3 = new CPOperand(parts[3]); // beta
-			CPOperand in4 = new CPOperand(parts[4]); // ema_mean
-			CPOperand in5 = new CPOperand(parts[5]); // ema_var
-			CPOperand in6 = new CPOperand(parts[6]); // eps
-			CPOperand in7 = new CPOperand(parts[7]); // mu
-			CPOperand out = new CPOperand(parts[8]);  // out
-			CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
-			CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
-			CPOperand out4 = new CPOperand(parts[11]); // cache_mean
-			CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+		else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+			InstructionUtils.checkNumFields(parts, 6);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand in5 = new CPOperand(parts[5]);
+			CPOperand out = new CPOperand(parts[6]);
+			return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
+		}
+		else if (opcode.equalsIgnoreCase("update_ema_var")) {
+			InstructionUtils.checkNumFields(parts, 6);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand in5 = new CPOperand(parts[5]);
+			CPOperand out = new CPOperand(parts[6]);
+			return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
 		}
                else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
@@ -396,211 +405,185 @@ public class DnnGPUInstruction extends GPUInstruction {
 	}
 
 	private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-		
-		if(instOpcode.equalsIgnoreCase("bias_add"))
-			LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
-		else if(instOpcode.equalsIgnoreCase("bias_multiply"))
-			LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-	}
-	
-	private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		
-		String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue();
-		double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		
-		if(phase.equalsIgnoreCase("train")) {
-			double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue();
-			MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-			MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-			MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-			MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-			LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
-				image, scale, bias, runningMean, runningVar, ret, 
-				retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-			ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
-		}
-		else if(phase.equalsIgnoreCase("test")) {
-			LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), 
-					image, scale, bias, runningMean, runningVar, ret, epsilon);
-			ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
-		}
-		else {
-			throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("input", _input1).add("bias", _input2);
+			
+			MatrixObject input = fetcher.getInputMatrixObject("input");
+			MatrixObject bias = fetcher.getInputMatrixObject("bias");
+			MatrixObject out = fetcher.getOutputMatrixObject(input.getNumRows(), input.getNumColumns());
+			
+			if(instOpcode.equalsIgnoreCase("bias_add"))
+				LibMatrixCUDA.biasAdd(gCtx, instName, input, bias, out);
+			else if(instOpcode.equalsIgnoreCase("bias_multiply"))
+				LibMatrixCUDA.biasMultiply(gCtx, instName, input, bias, out);
 		}
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
-	private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		
-		double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
-		double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-		MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-		MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-		MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-		
-		LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
-			image, scale, bias, runningMean, runningVar, ret, 
-			retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+	private void processInverseVarianceInstruction(String instOpcode, ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("eps", _input2);
+			
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+			
+			// invVar(X, C, eps, size);
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("invVar", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+					fetcher.getInputPointer("X"), fetcher.getOutputPointer(rows, cols),
+					fetcher.getDouble("eps"), rows*cols);
+		}
 	}
        
 	private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), 
-				image, scale, bias, runningMean, runningVar, ret, epsilon);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("image", _input1).add("scale", _input2).add("bias", _input3)
+			.add("runningMean", _input4).add("runningVar", _input5).addScalar("epsilon", _input6);
+			
+			double epsilon = fetcher.getDouble("epsilon");
+			if(epsilon < JCudnn.CUDNN_BN_MIN_EPSILON) {
+				throw new DMLRuntimeException("The epsilon (" + epsilon + ") cannot be less than CUDNN_BN_MIN_EPSILON=(" + JCudnn.CUDNN_BN_MIN_EPSILON + ")");
+			}
+			
+			MatrixObject image = fetcher.getInputMatrixObject("image");
+			LibMatrixCuDNN.batchNormalizationForwardInference(gCtx, instName, 
+				image, fetcher.getInputMatrixObject("scale"), fetcher.getInputMatrixObject("bias"), 
+				fetcher.getInputMatrixObject("runningMean"), fetcher.getInputMatrixObject("runningVar"), 
+				fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon);
+		}
 	}
        
-	public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
-		MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName());
-		
-		MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns());
-		MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns());
-		
-		LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image, 
-				dout, scale, dX, dScale, dBias,
-				epsilon, resultSaveMean, resultSaveInvVariance);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input6.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+	private void processBatchNorm2dBackwardDxInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).add("dout", _input2).add("gamma", _input3)
+			.add("resultSaveMean", _input4).add("resultSaveInvVariance", _input5);
+			
+			// #define CUDNN_BN_MIN_EPSILON 1e-5 // Minimum epsilon allowed to be used in the Batch Normalization formula
+			double epsilon = 1e-4; 
+			MatrixObject image = fetcher.getInputMatrixObject("X");
+			LibMatrixCuDNN.batchNormalizationBackwardDX(gCtx, instName, image, 
+					fetcher.getInputMatrixObject("dout"), fetcher.getInputMatrixObject("gamma"), 
+					fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon, fetcher.getInputMatrixObject("resultSaveMean"), 
+					fetcher.getInputMatrixObject("resultSaveInvVariance")); 
+		}
 	}
-
+       
+       
        // (X > 0) * dout
        public void processReLUBackwardInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-		
-		LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out);
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).add("dout", _input2);
+			MatrixObject X = fetcher.getInputMatrixObject("X");
+			LibMatrixCUDA.reluBackward(gCtx, instName, X, 
+					fetcher.getInputMatrixObject("dout"), fetcher.getOutputMatrixObject(X.getNumRows(), X.getNumColumns()));
+		}
 	}
        
        private void processChannelSumsInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
-		int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
-		if(C*HW != input.getNumColumns()) {
-			throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			fetcher.validateDimensions("X", -1, C*HW);
+			LibMatrixCUDA.channelSums(gCtx, instName, 
+					fetcher.getInputMatrixObject("X"), 
+					fetcher.getOutputMatrixObject(C, 1), C, HW);
+		}
+	}
+	
+	private void processEMAInstruction(ExecutionContext ec) {
+		// "ema_mean", "mean", "mu"
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("ema_mean", _input1).add("mean", _input2).addScalar("mu", _input3);
+			double mu = fetcher.getDouble("mu");
+			
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("ema_mean"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("ema_mean"));
+			
+			fetcher.validateDimensions("mean", rows, cols);
+			
+			// aXplusbY(X, Y, C, a, b, size);
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbY", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+					fetcher.getInputPointer("ema_mean"), fetcher.getInputPointer("mean"), 
+					fetcher.getOutputPointer(rows, cols),
+					mu, (1-mu), rows*cols);
+		}
+	}
+	
+	private void processReshapeColMeansInstruction(ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			fetcher.validateDimensions("X", -1, C*HW);
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+			// output = matrix(colMeans(X), rows=C, cols=Hin*Win)
+			LibMatrixCUDA.colMeans(gCtx, instName,  
+					fetcher.getInputPointer("X"), 
+					fetcher.getOutputPointer(C, HW), rows, cols);
 		}
-		MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
-		
-		LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
        
+	private void processUpdateEMAVarInstruction(ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			// "subgrp_means", "X", "C", "HW", "varConst1"
+			fetcher.add("subgrp_means", _input1).add("X", _input2).addScalar("C", _input3)
+				.addScalar("HW", _input4).addScalar("varConst1", _input5);
+			
+			// subgrp_vars = matrix(colVars(X) * varConst1, rows=C, cols=Hin*Win)
+			// var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+			// --->
+			// subgrp_vars = matrix(colVars(X), rows=C, cols=HW)
+			// var = rowMeans(subgrp_vars)*varConst1 + rowVars(subgrp_means)*((HW-1)/HW)
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			double varConst1 = fetcher.getDouble("varConst1");
+			fetcher.validateDimensions("subgrp_means", C, HW);
+			fetcher.validateDimensions("X", -1, C*HW);
+			
+			Pointer subgrp_vars = gCtx.allocate(instName, C*HW*LibMatrixCUDA.sizeOfDataType);
+			// subgrp_vars <- colVars(X)
+			LibMatrixCUDA.colVars(gCtx, instName, fetcher.getInputPointer("X"), subgrp_vars, 
+					LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")), C*HW);
+			
+			// tmp1 <- rowMeans(subgrp_vars)
+			Pointer tmp1 = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType);
+			LibMatrixCUDA.rowMeans(gCtx, instName, subgrp_vars, tmp1, C, HW);
+			gCtx.cudaFreeHelper(instName, subgrp_vars, gCtx.EAGER_CUDA_FREE);
+			
+			// out <- rowVars(subgrp_means)
+			Pointer out = fetcher.getOutputPointer(C, 1);
+			LibMatrixCUDA.rowVars(gCtx, instName, fetcher.getInputPointer("subgrp_means"), out, C, HW);
+			
+			// var = tmp1*varConst1 + out*((HW-1)/HW)
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbC", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(C),
+					tmp1, out,
+					varConst1, (((double)HW-1)/HW), C);
+			gCtx.cudaFreeHelper(instName, tmp1, gCtx.EAGER_CUDA_FREE);
+		}
+	}
+       
+       
+       
        private void processNesterovUpdateInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
-		int rows = LibMatrixCUDA.toInt(input.getNumRows());
-		int cols = LibMatrixCUDA.toInt(input.getNumColumns());
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
-		
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instName = getExtendedOpcode();
-		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
-				ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
-				LibMatrixCUDA.getDensePointer(gCtx, input, instName), 
-				LibMatrixCUDA.getDensePointer(gCtx, v, instName),
-				LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
-				mu, 
-				LibMatrixCUDA.getDensePointer(gCtx, out, instName),
-				rows*cols);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("input", _input1).add("v", _input2).add("v_prev", _input3)
+			.addScalar("mu", _input4);
+			MatrixObject input = fetcher.getInputMatrixObject("input");
+			int rows = LibMatrixCUDA.toInt(input.getNumRows());
+			int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+			
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+					fetcher.getInputPointer("input"), 
+					fetcher.getInputPointer("v"),
+					fetcher.getInputPointer("v_prev"),
+					fetcher.getDouble("mu"), 
+					fetcher.getOutputPointer(rows, cols),
+					rows*cols);
+		}
 	}
        
        private static int toInt(long num) throws DMLRuntimeException {
@@ -610,32 +593,18 @@ public class DnnGPUInstruction extends GPUInstruction {
                return (int)num;
        }
        
-//	private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException {
-//		GPUContext gCtx = ec.getGPUContext(0);
-//		String instructionName = getExtendedOpcode();
-//		long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns();
-//		Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType);
-//		jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice);
-//		// LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX);
-//		return tX;
-//	}
-	
 	private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instructionName = getExtendedOpcode();
-		
 		MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
 		int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
-		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
 		
 		MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
 		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
 		long numRowsW = W.getNumRows();
 		int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures 
-		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-		Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+		Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
 				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
 				sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -644,20 +613,20 @@ public class DnnGPUInstruction extends GPUInstruction {
                
                
 		MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); 
+		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
 		int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
 		long numColsX = X.getNumColumns();
 		int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-		Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+		Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
 				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
 				xPointer, cudnnInput, N, D, T*D, N*T*D);
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
 		
-		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
+		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
 		boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
 		
-		// LibMatrixCuDNN.lstm(ec, gCtx, instructionName, 
+		// LibMatrixCuDNN.lstm(ec, gCtx, instName, 
 				// cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
 				// String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName,  // input
 				// String dxName, String dwName, String dbName, String dhxName, String dcxName,         // output
@@ -668,12 +637,12 @@ public class DnnGPUInstruction extends GPUInstruction {
 		String dcxName = _output5.getName();
 		String doutName = _input7.getName();
 		String dcyName = _input8.getName();
-		LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, 
+		LibMatrixCuDNN.lstmBackward(ec, gCtx, instName, 
 				cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName,  // input
 				dxName, dwName, dbName, dhxName, dcxName, // output 
 				return_sequences, N, M, D, T);
-		gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-		gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
                
                // release inputs/outputs
                ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -686,21 +655,17 @@ public class DnnGPUInstruction extends GPUInstruction {
                // weight W:(D+M+2, 4M) 
 		// previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N)
 		// out: (N, T*M) or (N, M) ==> (T, M, N)
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instructionName = getExtendedOpcode();
-		
 		MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
 		int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
-		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
 		
 		MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
 		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
 		long numRowsW = W.getNumRows();
 		int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures 
-		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-		Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+		Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
 				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
 				sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -711,21 +676,21 @@ public class DnnGPUInstruction extends GPUInstruction {
                
 		// Because the matrices are released immediately, the output for transpose need not be taken into account
 		MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); 
+		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
 		int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
 		long numColsX = X.getNumColumns();
 		int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-		Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+		Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
 				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
 				xPointer, cudnnInput, N, D, T*D, N*T*D);
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
 		
-		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); 
+		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
 		
-		LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
-		gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-		gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+		LibMatrixCuDNN.lstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
+		gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
                
                // release inputs/outputs
                ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -736,10 +701,17 @@ public class DnnGPUInstruction extends GPUInstruction {
        
        @Override
        public void processInstruction(ExecutionContext ec) {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               gCtx = ec.getGPUContext(0);
+               instName = getExtendedOpcode();
 		if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) {
                        processBiasInstruction(instOpcode, ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("inv_var")) {
+                       processInverseVarianceInstruction(instOpcode, ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("relu_backward")) {
                        processReLUBackwardInstruction(ec);
                        return;
@@ -748,10 +720,22 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processChannelSumsInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("update_ema")) {
+                       processEMAInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("reshape_colmeans")) {
+                       processReshapeColMeansInstruction(ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
                        processNesterovUpdateInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
+                       processUpdateEMAVarInstruction(ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("lstm")) {
                        processLstmInstruction(ec);
                        return;
@@ -760,24 +744,14 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processLstmBackwardInstruction(ec);
                        return;
                }
-               else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
-                       processBatchNorm2dInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
-                       processBatchNorm2dBackwardInstruction(ec);
-                       return;
-               }
                else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
                        processBatchNorm2dTestInstruction(ec);
                        return;
                }
-               else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
-                       processBatchNorm2dTrainInstruction(ec);
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+                       processBatchNorm2dBackwardDxInstruction(ec);
                        return;
                }
-               
-               GPUStatistics.incrementNoOfExecutedGPUInst();
                                        
                int pad_h = getScalarInput(ec, _padding, 0);
                int pad_w = getScalarInput(ec, _padding, 1);
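
A note on the new "invVar" kernel launched in processInverseVarianceInstruction above: the kernel source is not part of this commit, so the following CPU sketch only illustrates the assumed semantics — an elementwise inverse standard deviation 1/sqrt(x + eps), the quantity batch normalization caches as inv_var:

    // CPU reference sketch for the assumed "invVar" contract; the actual CUDA
    // kernel lives elsewhere in the repository and may differ in detail.
    static void invVarReference(double[] X, double[] out, double eps, int size) {
        for (int i = 0; i < size; i++)
            out[i] = 1.0 / Math.sqrt(X[i] + eps); // elementwise 1/sqrt(x + eps)
    }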
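
Likewise, the "aXplusbY" launch in processEMAInstruction passes a = mu and b = (1 - mu), so update_ema computes the exponential moving average ema = mu*ema_mean + (1-mu)*mean. A CPU sketch of that contract, following the in-diff comment aXplusbY(X, Y, C, a, b, size):

    // CPU reference sketch for "aXplusbY": C[i] = a*X[i] + b*Y[i]; with a = mu
    // and b = 1 - mu this is exactly the EMA update performed by update_ema.
    static void aXplusbYReference(double[] X, double[] Y, double[] C, double a, double b, int size) {
        for (int i = 0; i < size; i++)
            C[i] = a * X[i] + b * Y[i];
    }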
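
Finally, the "aXplusbC" launch at the end of processUpdateEMAVarInstruction fuses the two scaled terms of the variance rewrite spelled out in its comments (var = rowMeans(subgrp_vars)*varConst1 + rowVars(subgrp_means)*((HW-1)/HW)). The comment "var = tmp1*varConst1 + out*((HW-1)/HW)" suggests the kernel updates its second operand in place; a CPU sketch under that assumption:

    // CPU reference sketch for "aXplusbC", assuming an in-place update of C:
    // C[i] = a*X[i] + b*C[i], with a = varConst1 and b = (HW-1)/HW.
    static void aXplusbCReference(double[] X, double[] C, double a, double b, int size) {
        for (int i = 0; i < size; i++)
            C[i] = a * X[i] + b * C[i];
    }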

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
new file mode 100644
index 0000000..8fcaec3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.gpu;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
+public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
+       ExecutionContext _ec; GPUContext _gCtx; String _instName;
+       HashMap<String, CPOperand> _inputMatrices = new HashMap<>();
+       HashMap<String, MatrixObject> _inputMatrixObjects = new HashMap<>();
+       HashMap<String, CPOperand> _inputScalars = new HashMap<>();
+       CPOperand _output;
+       public GPUDenseInputPointerFetcher(ExecutionContext ec, GPUContext gCtx, String instName, CPOperand output) {
+               _ec = ec;
+               _gCtx = gCtx;
+               _instName = instName;
+               _output = output;
+       }
+       public GPUDenseInputPointerFetcher add(String var, CPOperand in) {
+               _inputMatrices.put(var, in);
+               return this;
+       }
+       public GPUDenseInputPointerFetcher addScalar(String var, CPOperand in) {
+               _inputScalars.put(var, in);
+               return this;
+       }
+       public double getDouble(String var) {
+               CPOperand in = _inputScalars.get(var);
+               return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getDoubleValue();
+       }
+       public long getLong(String var) {
+               CPOperand in = _inputScalars.get(var);
+               return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue();
+       }
+       public int getInteger(String var) {
+               CPOperand in = _inputScalars.get(var);
+               return LibMatrixCUDA.toInt(_ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue());
+       }
+       public Pointer getInputPointer(String var) {
+               return LibMatrixCUDA.getDensePointer(_gCtx, getInputMatrixObject(var), _instName);
+       }
+       public long getInputNumRows(String var) {
+               return getInputMatrixObject(var).getNumRows();
+       }
+       public long getInputNumColumns(String var) {
+               return getInputMatrixObject(var).getNumColumns();
+       }
+       public MatrixObject getOutputMatrixObject(long numRows, long numCols) {
+               boolean isFinegrainedStats = ConfigurationManager.isFinegrainedStatistics();
+               long t0 = isFinegrainedStats ? System.nanoTime() : 0;
+               Pair<MatrixObject, Boolean> mb = _ec.getDenseMatrixOutputForGPUInstruction(_output.getName(), numRows, numCols);
+               if (isFinegrainedStats && mb.getValue())
+                       GPUStatistics.maintainCPMiscTimes(_instName, GPUInstruction.MISC_TIMER_ALLOCATE_DENSE_OUTPUT, System.nanoTime() - t0);
+               return mb.getKey();
+       }
+       public Pointer getOutputPointer(long numRows, long numCols) {
+               return LibMatrixCUDA.getDensePointer(_gCtx, getOutputMatrixObject(numRows, numCols), _instName);
+       }
+       public MatrixObject getInputMatrixObject(String var) {
+               CPOperand in = _inputMatrices.get(var);
+               if(!_inputMatrixObjects.containsKey(var)) {
+                       _inputMatrixObjects.put(var, _ec.getMatrixInputForGPUInstruction(in.getName(), _instName));
+               }
+               return _inputMatrixObjects.get(var);
+       }
+       public void validateDimensions(String var, long numRows, long numCols) {
+               MatrixObject mo = getInputMatrixObject(var);
+               if(numRows > 0 && mo.getNumRows() != numRows) {
+                       throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows());
+               }
+               else if(numCols > 0 && mo.getNumColumns() != numCols) {
+                       throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns());
+               }
+       }
+       @Override
+       public void close() {
+               for(CPOperand in : _inputMatrices.values()) {
+                       _ec.releaseMatrixInputForGPUInstruction(in.getName());
+               }
+               _ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+       }
+}
\ No newline at end of file

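As a usage illustration (not part of this commit): an instruction might drive the new fetcher as sketched below. The operand fields (_input1, _input2, _input3) and the variable names "X", "Y", "eps" are assumed for the sketch; try-with-resources relies on close() to release every pinned input plus the output.

    // Hypothetical usage sketch; operand names and shapes are illustrative only.
    try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
        fetcher.add("X", _input1).add("Y", _input2).addScalar("eps", _input3);
        fetcher.validateDimensions("Y", fetcher.getInputNumRows("X"), 1);
        double eps = fetcher.getDouble("eps");
        Pointer x = fetcher.getInputPointer("X"); // pins X on the GPU as a dense pointer
        Pointer y = fetcher.getInputPointer("Y");
        Pointer out = fetcher.getOutputPointer(fetcher.getInputNumRows("X"), fetcher.getInputNumColumns("X"));
        // ... launch the CUDA kernel over x, y and eps, writing into out ...
    } // close() releases all matrix inputs and the output
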
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 2e43b99..e01c71a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -242,7 +242,7 @@ public class GPUMemoryManager {
         * @return allocated pointer
         */
        public Pointer malloc(String opcode, long size) {
-               if(size < 0) {
+               if(size <= 0) {
                        throw new DMLRuntimeException("Cannot allocate memory 
of size " + byteCountToDisplaySize(size));
                }
                if(DEBUG_MEMORY_LEAK) {

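One behavioral note on the hunk above: a request of exactly zero bytes previously passed the size check and reached the allocator; with the tightened guard it now fails fast. A hedged sketch of the new behavior (manager construction elided, opcode string illustrative):

    GPUMemoryManager mgr = ...;           // an initialized memory manager
    Pointer p = mgr.malloc("tsmm", 1024); // fine: positive size
    mgr.malloc("tsmm", 0);                // now throws DMLRuntimeException
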
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index d02a875..f3f8434 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -208,7 +208,7 @@ public class LibMatrixCUDA {
                return gCtx.getCusparseHandle();
        }
 
-       protected static cublasHandle getCublasHandle(GPUContext gCtx) {
+       public static cublasHandle getCublasHandle(GPUContext gCtx) {
                return gCtx.getCublasHandle();
        }
 
@@ -302,7 +302,7 @@ public class LibMatrixCUDA {
                }
        }
        
-       protected static Pointer dataTypePointerTo(double value) {
+       public static Pointer dataTypePointerTo(double value) {
                if(value == 1) {
                        return one();
                }
@@ -313,7 +313,6 @@ public class LibMatrixCUDA {
                        return _dataTypePointerTo(value);
                }
        }
-       
 
        /**
         * This method computes the backpropagation errors for previous layer of relu operation
@@ -753,11 +752,11 @@ public class LibMatrixCUDA {
                                break;
                        }
                        case REDUCTION_COL: {
-                               reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+                               rowMeans(gCtx, instName, in, out, rlen, clen);
                                break;
                        }
                        case REDUCTION_ROW: {
-                               reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+                               colMeans(gCtx, instName, in, out, rlen, clen);
                                break;
                        }
                        default:
@@ -818,13 +817,14 @@ public class LibMatrixCUDA {
                        break;
                }
                case OP_VARIANCE : {
-                       // Temporary GPU array for
-                       Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
-                       Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+
 
                        switch(reductionDirection) {
 
                        case REDUCTION_ALL: {
+                               // Temporary GPU arrays for the squared deviations from the mean
+                               Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+                               Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
                                double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                                double mean = result / size;
 
@@ -837,50 +837,21 @@ public class LibMatrixCUDA {
                                double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
                                double variance = result2 / (size - 1);
                                ec.setScalarOutput(output, new DoubleObject(variance));
-
+                               gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+                               gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
                                break;
                        }
                        case REDUCTION_COL: {
-                               reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
-                               // Subtract the row-wise mean from every element in the matrix
-                               BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
-                               matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
-
-                               squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
-                               Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
-                               reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
-
-                               ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
-                               matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
-
-                               gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
-
+                               rowVars(gCtx, instName, in, out, rlen, clen);
                                break;
                        }
                        case REDUCTION_ROW: {
-                               reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
-                               // Subtract the column-wise mean from every element in the matrix
-                               BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
-                               matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
-
-                               squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
-                               Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
-                               reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
-
-                               ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
-                               matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
-
-                               gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
-
+                               colVars(gCtx, instName, in, out, rlen, clen);
                                break;
                        }
                        default:
                                throw new DMLRuntimeException("Internal Error - 
Unsupported reduction direction for variance");
                        }
-                       gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
-                       gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
                        break;
                }
                case OP_MAXINDEX : {
@@ -904,6 +875,59 @@ public class LibMatrixCUDA {
                default : throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
                }
        }
+
+       public static void rowMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+               LibMatrixCUDA.reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+       }
+
+       public static void colMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+               reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+       }
+
+       public static void colVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+               int size = rlen * clen;
+               Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+               Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+               reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+               // Subtract the column-wise mean from every element in the matrix
+               BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+               matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
+
+               squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+               Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
+               reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
+
+               ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
+               matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
+
+               gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+       }
+
+       public static void rowVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+               int size = rlen * clen;
+               Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+               Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+
+               reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+               // Subtract the row-wise mean from every element in the matrix
+               BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+               matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
+
+               squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+               Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
+               reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
+
+               ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
+               matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
+
+               gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+       }
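For reference, the extracted rowVars/colVars helpers compute the unbiased sample variance through the kernel sequence above (mean, subtract, square, sum, divide):

    \mathrm{var}(x) \;=\; \frac{1}{n-1} \sum_{i=1}^{n} \left(x_i - \bar{x}\right)^2, \qquad \bar{x} \;=\; \frac{1}{n} \sum_{i=1}^{n} x_i

with n = clen for each row in rowVars and n = rlen for each column in colVars, matching the divisors clen - 1 and rlen - 1 in the code.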
 
        /**
         * Helper method to square a matrix in GPU memory
@@ -970,7 +994,7 @@ public class LibMatrixCUDA {
         * @param rows            number of rows in input matrix
         * @param cols            number of columns in input matrix
         */
-       private static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+       public static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : reduceRow for " + kernelFunction + ", GPUContext=" + gCtx);
                }
@@ -997,7 +1021,7 @@ public class LibMatrixCUDA {
         * @param rows            number of rows in input matrix
         * @param cols            number of columns in input matrix
         */
-       private static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+       public static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : reduceCol for " + kernelFunction + ", GPUContext=" + gCtx);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index d3b5984..e7955e1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -1108,23 +1108,29 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                                runningMeanPtr, runningVarPtr, epsilon));
        }
        
+       private static void validateDimensions(MatrixObject mo, long expectedRows, long expectedCols) {
+               if(mo.getNumRows() != expectedRows || mo.getNumColumns() != expectedCols) {
+                       throw new DMLRuntimeException("Incorrect dimensions for the input matrix object. Expected [" + expectedRows + ", " + expectedCols + "], but found [" + mo.getNumRows() + ", " + mo.getNumColumns() + "].");
+               }
+       }
+
        /**
-        * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
+        * This method computes the backpropagation errors for the image of a batch normalization layer
         * @param gCtx   a valid {@link GPUContext}
         * @param instName name of the instruction
         * @param image input image
         * @param dout input errors of shape C, H, W
         * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
         * @param dX (output) backpropagation errors for previous layer
-        * @param dScale backpropagation error for scale
-        * @param dBias backpropagation error for bias
         * @param epsilon epsilon value used in the batch normalization formula
         * @param resultSaveMean (input) running mean accumulated during training phase: shape [1, C, 1, 1]
         * @param resultSaveInvVariance (input) running variance accumulated during training phase: shape [1, C, 1, 1]
         * @throws DMLRuntimeException if error occurs
         */
-       public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-                       MatrixObject scale, MatrixObject dX, MatrixObject dScale, MatrixObject dBias,
+       public static void batchNormalizationBackwardDX(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
+                       MatrixObject scale, MatrixObject dX,
                        double epsilon, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
@@ -1133,7 +1139,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                int N = toInt(image.getNumRows());
                int C = toInt(scale.getNumRows());
                long CHW = image.getNumColumns();
-
+
+               validateDimensions(scale, C, 1);
+               validateDimensions(dX, N, CHW);
+               validateDimensions(dout, N, CHW);
+               validateDimensions(resultSaveMean, C, 1);
+               validateDimensions(resultSaveInvVariance, C, 1);
+
                // Allocate descriptors
                cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
                                new MatrixObject[] {image, dout}, new MatrixObject[] {dX});
@@ -1144,18 +1156,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName);
                Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
                Pointer dXPtr = getDensePointerForCuDNN(gCtx, dX, instName);
-               Pointer dScalePtr = getDensePointerForCuDNN(gCtx, dScale, instName);
-               Pointer dBiasPtr = getDensePointerForCuDNN(gCtx, dBias, instName);
-
+               // cuDNN always produces dScale and dBias along with dX, so allocate scratch buffers for them and free them eagerly below
+               Pointer dScalePtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType);
+               Pointer dBiasPtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType);
                Pointer resultSaveMeanPtr = getDensePointerForCuDNN(gCtx, resultSaveMean, instName);
                Pointer resultSaveInvVariancePtr = getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName);

-
-               // ignoring resultSaveMean and resultSaveVariance as it requires state management
-               checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx),
+               cudnnBatchNormalizationBackward(getCudnnHandle(gCtx),
                                jcuda.jcudnn.cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL, one(), zero(), one(), zero(),
                                nCHWDescriptor, imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, dXPtr,
-                               scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr));
+                               scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr);
+               gCtx.cudaFreeHelper(instName, dScalePtr, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, dBiasPtr, gCtx.EAGER_CUDA_FREE);
        }
        
        private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {

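A note on the dX-only backward above: cudnnBatchNormalizationBackward unconditionally writes dScale and dBias, which is why scratch buffers are allocated and freed around the call. The scale and bias gradients themselves are simple channel-wise reductions, so (our reading of the split, not stated in the patch) the script layer can produce them without cuDNN:

    \frac{\partial L}{\partial \gamma_c} \;=\; \sum_{n,h,w} \frac{\partial L}{\partial y_{n,c,h,w}} \, \hat{x}_{n,c,h,w}, \qquad \frac{\partial L}{\partial \beta_c} \;=\; \sum_{n,h,w} \frac{\partial L}{\partial y_{n,c,h,w}}
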
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index d96feac..d7e7b24 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -55,8 +55,8 @@ public class BatchNormTest extends GPUTests {
                int imgSize = 32; 
                int numChannels = 3;
                double sparsity = 0.9;
-               String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
-                               + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
+               String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n "
+                               + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
                HashMap<String, Object> inputs = new HashMap<>();
                inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
                inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
@@ -68,19 +68,40 @@ public class BatchNormTest extends GPUTests {
                List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, 
outputs);
                if(mode.equals("test")) {
                        assertHeavyHitterPresent("gpu_batch_norm2d_test");
-                       for(int i = 0; i < outputs.size(); i++) {
-                               assertEqualObjects(outCPU.get(i), 
outGPU.get(i));
-                       }
                }
                else {
-                       //assertHeavyHitterPresent("gpu_batch_norm2d_train");
-                       double [] threshold = new double[outputs.size()];
-                       Arrays.fill(threshold, getTHRESHOLD());
-                       // Handle loss of precision in CuDNN kernel 
-                       threshold[2] = 1e-3;
-                       for(int i = 0; i < outputs.size()-1; i++) {
-                               assertEqualObjects(outCPU.get(i), 
outGPU.get(i), threshold[i]);
-                       }
+                       assertHeavyHitterPresent("gpu_batch_norm2d_test");
+                       assertHeavyHitterPresent("gpu_reshape_colmeans");
+                       assertHeavyHitterPresent("gpu_update_ema_var");
                }
+               assertEqualObjects(outCPU.get(0), outGPU.get(0));
+               assertEqualObjects(outCPU.get(1), outGPU.get(1));
+               assertEqualObjects(outCPU.get(2), outGPU.get(2));
+               assertEqualObjects(outCPU.get(3), outGPU.get(3));
+               assertEqualObjects(outCPU.get(4), outGPU.get(4));
+       }
+       
+       @Test
+       public void testBatchNormBackward() {
+               int imgSize = 32; 
+               int numChannels = 3;
+               double sparsity = 0.9;
+               String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as 
batch_norm2d;\n "
+                               + "[output, ema_mean_upd, ema_var_upd, 
cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels 
+ ", " + imgSize + ", " + imgSize + ", \"train\", ema_mean, ema_var, 0.9, 
1e-3);\n"
+                               + "[dX, dgamma, dbeta] = 
batch_norm2d::backward(dout, cache_mean, cache_var, x, gamma, " + numChannels + 
", " + imgSize + ", " + imgSize + ", 1e-3);";
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, 32, 
numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+               inputs.put("dout", generateInputMatrix(spark, 32, 
numChannels*imgSize*imgSize, 1, 5, sparsity, seed));
+               inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 
0, 2, sparsity, seed));
+               inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 
0, 2, sparsity, seed));
+               inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 
1, 3, 7, sparsity, seed));
+               inputs.put("ema_var", generateInputMatrix(spark, numChannels, 
1, 0, 2, sparsity, seed));
+               List<String> outputs = Arrays.asList("dX", "dgamma", "dbeta");
+               List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, 
outputs);
+               List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, 
outputs);
+               assertHeavyHitterPresent("gpu_batch_norm2d_bwd_dx");
+               assertEqualObjects(outCPU.get(0), outGPU.get(0), 1e-6);
+               assertEqualObjects(outCPU.get(1), outGPU.get(1));
+               assertEqualObjects(outCPU.get(2), outGPU.get(2));
        }
 }
