Repository: systemml
Updated Branches:
  refs/heads/master 5aadb4b22 -> a6bca8851


[SYSTEMML-445] Added a rewrite for batch normalization train

- This PR fuses the batch normalization (mode="train") pattern into a single 
FunctionOp. The fusion is performed by the batchNormTrain method in 
RewriteGPUSpecificOps.
- The rewrite is only applied if none of the outputs of the matched pattern are 
persistent writes; the existing outputs are replaced with transient reads. A 
sketch of the targeted pattern is given below.
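
For illustration, a minimal DML-like sketch of the pattern the rewrite matches, 
assembled from the comments in RewriteGPUSpecificOps (variable names such as X, 
gamma, beta, ema_mean, ema_var, mu and eps follow the nn batch_norm2d layer; the 
actual layer script may differ slightly):

  # mini-batch statistics (N rows in X, C channels, Hin x Win spatial dims)
  subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
  mean = rowMeans(subgrp_means)
  subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
  var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
  # exponential moving averages of mean and variance
  ema_mean_upd = mu*ema_mean + (1-mu)*mean
  ema_var_upd = mu*ema_var + (1-mu)*var
  # normalize, then scale and shift
  norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
  out = bias_add(bias_multiply(norm, gamma), beta)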

Closes #800.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a6bca885
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a6bca885
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a6bca885

Branch: refs/heads/master
Commit: a6bca88512f3f542278709713706d256fad2cc17
Parents: 5aadb4b
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Mon Jul 16 14:50:36 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Mon Jul 16 14:52:01 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/FunctionOp.java  |  28 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   1 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 573 ++++++++++++++++++-
 .../org/apache/sysml/lops/FunctionCallCP.java   |  12 +-
 src/main/java/org/apache/sysml/lops/Lop.java    |  13 +
 .../org/apache/sysml/parser/DMLTranslator.java  |   6 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  62 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |  35 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  37 +-
 10 files changed, 701 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java 
b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 64963d9..aedaf81 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -169,17 +169,20 @@ public class FunctionOp extends Hop
                                long outputValues = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
                                return outputVectors+outputValues; 
                        }
-                       else if ( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward")  ) {
+                       else if ( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward") ) {
                                // TODO: To allow for initial version to always 
run on the GPU
                                return 0; 
                        }
-                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d") ) {
+                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
                                return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0) +
                                                
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0) +
                                                
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), 
getOutputs().get(2).getDim2(), 1.0) +
                                                
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), 
getOutputs().get(3).getDim2(), 1.0) + 
                                                
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), 
getOutputs().get(4).getDim2(), 1.0);
                        }
+                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
+                       return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0);
+               }
                        else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
                                return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0) +
                                                
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0) +
@@ -215,7 +218,8 @@ public class FunctionOp extends Hop
                                return 
OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 
getInput().get(0).getDim2(), 1.0) 
                                                + 
3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 
1.0); 
                        }
-                       else if 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+                       else if 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+                                       
getFunctionName().equalsIgnoreCase("batch_norm2d_train") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
                                return 0; 
                        }
                        else if ( getFunctionName().equalsIgnoreCase("lstm") || 
 getFunctionName().equalsIgnoreCase("lstm_backward") ) {
@@ -240,7 +244,8 @@ public class FunctionOp extends Hop
        @Override
        public boolean isGPUEnabled() {
                if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward") ||  
-                       getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) 
+                       getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+                       
getFunctionName().equalsIgnoreCase("batch_norm2d_train") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) 
                        return true;
                else
                        return false;
@@ -283,20 +288,25 @@ public class FunctionOp extends Hop
                checkAndSetForcedPlatform();
                
                if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
+                       boolean isBuiltinFunction = isBuiltinFunction();
                        // check if there is sufficient memory to execute this 
function
-                       if( 
getFunctionName().equalsIgnoreCase("transformencode") ) {
+                       if(isBuiltinFunction && 
getFunctionName().equalsIgnoreCase("transformencode") ) {
                                _etype = ((_etypeForced==ExecType.SPARK 
                                        || (getMemEstimate() >= 
OptimizerUtils.getLocalMemBudget()
                                                && 
OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
                        }
-                       else if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward")) {
+                       else if(isBuiltinFunction && 
(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward"))) {
                                if(!DMLScript.USE_ACCELERATOR)
                                        throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
                                _etype = ExecType.GPU;
                        }
-                       else if( 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+                       else if(isBuiltinFunction && 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
                                _etype = DMLScript.USE_ACCELERATOR ? 
ExecType.GPU : ExecType.CP;
                        }
+                       else if(isBuiltinFunction && 
getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
+                               // Only GPU implementation is supported
+                               _etype = ExecType.GPU;
+                       }
                        else {
                                // Since the memory estimate is only 
conservative, do not throw
                                // exception if the estimated memory is larger 
than the budget
@@ -312,6 +322,10 @@ public class FunctionOp extends Hop
                
                return _etype;
        }
+       
+       private boolean isBuiltinFunction() {
+               return 
getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE);
+       }
 
        @Override
        public void refreshSizeInformation()

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 5d357c6..d8f4424 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1537,6 +1537,7 @@ public abstract class Hop implements ParseInfo
                HopsData2String.put(DataOpTypes.PERSISTENTWRITE, "PWrite");
                HopsData2String.put(DataOpTypes.TRANSIENTWRITE, "TWrite");
                HopsData2String.put(DataOpTypes.TRANSIENTREAD, "TRead");
+               HopsData2String.put(DataOpTypes.FUNCTIONOUTPUT, "FunOut");
        }
 
        public static OpOp2 getOpOp2ForOuterVectorOperation(String op) 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 1c00c6f..b946178 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,47 +20,64 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.FunctionOp.FunctionType;
 import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOpDnn;
 import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.DnnOp;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 
 /*
  * This class contains GPU-specific rewrites for following patterns:
  * 
- * 1. batchNormTest:
+ * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
  * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
  * hi = bias_add(bias_multiply(norm, gamma), beta)
  * 
  * 2. channelSum:
  * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 
+ * 3. batchNormTrain: applied when mode="train" in the batch normalization nn 
layer.
+ * This rewrite is only enabled if none of the outputs are persistent writes, 
as it assumes that 
+ * FunctionOp introduces transient writes. It replaces the 
existing outputs of the matched pattern with transient reads.
+ * 
  */
 public class RewriteGPUSpecificOps extends HopRewriteRule {
 
+       private static int _seq = 1;
+       
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
                if( roots == null )
                        return roots;
 
                //one pass rewrite-descend (rewrite created pattern)
-               for( Hop h : roots )
-                       rule_GPUKernels( h, false );
+               for( int i = 0; i < roots.size(); i++ )
+                       rule_GPUKernels(roots, roots.get(i), false );
                Hop.resetVisitStatus(roots, true);
 
                //one pass descend-rewrite (for rollup) 
-               for( Hop h : roots )
-                       rule_GPUKernels( h, true );
+               for( int i = 0; i < roots.size(); i++ )
+                       rule_GPUKernels(roots, roots.get(i), true );
                Hop.resetVisitStatus(roots, true);
                
                return roots;
@@ -72,12 +89,12 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                        return root;
                
                //one pass rewrite-descend (rewrite created pattern)
-               rule_GPUKernels( root, false );
+               rule_GPUKernels(null, root, false );
                
                root.resetVisitStatus();
                
                //one pass descend-rewrite (for rollup) 
-               rule_GPUKernels( root, true );
+               rule_GPUKernels(null, root, true );
                
                return root;
        }
@@ -85,10 +102,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
        /**
         * Fuse the kernel
         * 
+        * @param roots root operators
         * @param hop high-level operator
         * @param descendFirst true if recursively process children first
         */
-       private void rule_GPUKernels(Hop hop, boolean descendFirst) 
+       private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean 
descendFirst) 
        {
                if(hop.isVisited())
                        return;
@@ -99,13 +117,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                        
                        //process childs recursively first (to allow roll-up)
                        if( descendFirst )
-                               rule_GPUKernels(hi, descendFirst); //see below
+                               rule_GPUKernels(roots, hi, descendFirst); //see 
below
                        
+                       if(roots != null) {
+                               hi = batchNormTrain(roots, hop, hi, i);
+                       }
                        hi = batchNormTest(hop, hi, i); 
                        hi = channelSums(hop, hi, i); 
        
                        if( !descendFirst )
-                               rule_GPUKernels(hi, descendFirst);
+                               rule_GPUKernels(roots, hi, descendFirst);
                }
 
                hop.setVisited();
@@ -149,6 +170,10 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                                memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
        }
        
+       private static boolean hasFirstInput(Hop h) {
+               return !(h == null || h.getInput() == null || 
h.getInput().size() < 1);
+       }
+       
        private static Hop getFirstInput(Hop h) {
                if(h == null || h.getInput() == null || h.getInput().size() < 
1) {
                        throw new RuntimeException("No input available for " + 
h);
@@ -156,13 +181,24 @@ public class RewriteGPUSpecificOps extends HopRewriteRule 
{
                return h.getInput().get(0);
        }
        
+       private static boolean hasSecondInput(Hop h) {
+               return !(h == null || h.getInput() == null || 
h.getInput().size() < 2);
+       }
+       
        private static Hop getSecondInput(Hop h) {
                if(h == null || h.getInput() == null || h.getInput().size() < 
2) {
-                       throw new RuntimeException("No input available for " + 
h);
+                       throw new RuntimeException("Expected atleast two inputs 
for " + h);
                }
                return h.getInput().get(1);
        }
        
+       private static Hop getThirdInput(Hop h) {
+               if(h == null || h.getInput() == null || h.getInput().size() < 
3) {
+                       throw new RuntimeException("Expected atleast three 
inputs for " + h);
+               }
+               return h.getInput().get(2);
+       }
+       
        private static boolean isUnaryMinus(Hop h) {
                return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
                        && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
@@ -200,13 +236,488 @@ public class RewriteGPUSpecificOps extends 
HopRewriteRule {
                return hi;
        }
        
+       private static boolean isRowMeans(Hop h) {
+               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row; 
+       }
+       
+       private static boolean isRowVars(Hop h) {
+               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row; 
+       }
+       
+       private static boolean isRowVars(Hop h, Hop childHop) {
+               return isRowVars(h) && getFirstInput(h) == childHop; 
+       }
+       
+       private static boolean isColMeans(Hop h) {
+               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col; 
+       }
+       
+       private static boolean isColVars(Hop h) {
+               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col; 
+       }
+       
+       private static boolean isReshape(Hop h) {
+               return h instanceof ReorgOp && ((ReorgOp)h).getOp() == 
ReOrgOp.RESHAPE;
+       }
+       
+       private static boolean isReshape(Hop h, long expectedRows, long 
expectedCols) {
+               return h instanceof ReorgOp && ((ReorgOp)h).getOp() == 
ReOrgOp.RESHAPE &&
+                               Hop.computeSizeInformation(getSecondInput(h)) 
== expectedRows && 
+                               Hop.computeSizeInformation(getThirdInput(h)) == 
expectedCols;
+       }
+       
+       private static boolean isBinaryAdd(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS;
+       }
+       
+       private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS 
+                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
+                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) 
== expectedValue;
+       }
+       
+       private static boolean isBinaryMMAdd(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS 
+                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
+       }
+       
+       private static boolean isBinaryMSMult(Hop h, double expectedValue) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
+                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
+                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) 
== expectedValue;
+       }
+       
+       private static boolean isBinarySSMinus(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
+                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
+       }
+       
+       private static boolean isBinarySSDiv(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
+                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
+       }
+       
+       private static boolean isBinarySMDiv(Hop h, double expectedValue) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
+                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX 
+                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) 
== expectedValue;
+       }
+       
+       private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
+               if(hops != null) {
+                       for(Hop h : hops) {
+                               if(h instanceof BinaryOp && 
((BinaryOp)h).getOp() == OpOp2.PLUS)
+                                       return true;
+                       }
+               }
+               return false;
+       }
+       
+       private static boolean isBinaryMSMult(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
+                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR;
+       }
+       
+       private static boolean isBinarySMMult(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
+                               && getSecondInput(h).getDataType() == 
DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
+       }
+       
+       /**
+        * Checks if the "mean" hop is a moving average of mean in batch 
normalization layer.
+        *  
+        * @param mean hop to check against
+        * @param X input data
+        * @return true if the "mean" hop is a moving average of mean in batch 
normalization layer.
+        */
+       private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
+               // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+               // mean = rowMeans(subgrp_means)
+               return isRowMeans(mean) && isReshape(getFirstInput(mean)) && 
isColMeans(getFirstInput(getFirstInput(mean)))
+                               && 
getFirstInput(getFirstInput(getFirstInput(mean))) == X;
+       }
+       
+       /**
+        * Checks for nrow(X) pattern
+        * 
+        * @param expr hop to be matched
+        * @param X input X
+        * @return true if expr is nrow(X) else false
+        */
+       private static boolean isNrowOfX(Hop expr, Hop X) {
+               return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == 
OpOp1.NROW && getFirstInput(expr) == X;
+       }
+       
+       /**
+        * Checks for the colVars(X) * ((N-1)/N) pattern
+        * 
+        * @param expr hop to be matched
+        * @param X input X
+        * @param ignoreCorrectionTerm whether to ignore the correction term 
((N-1)/N).
+        * @return true if expr is colVars(X) * ((N-1)/N) else false
+        */
+       private static boolean isCorrectedColVars(Hop expr, Hop X, boolean 
ignoreCorrectionTerm) {
+               // colVars(X) * ((N-1)/N)
+               if(isColVars(expr) && getFirstInput(expr) == X) {
+                       // Support no correction as well in this rewrite
+                       return true;
+               }
+               else if(X.rowsKnown()) {
+                       return isBinaryMSMult(expr, 
((double)X.getDim1()-1)/X.getDim1()) && 
+                                       isColVars(getFirstInput(expr)) && 
getFirstInput(getFirstInput(expr)) == X;
+               }
+               else if(isBinaryMSMult(expr) && 
+                               isColVars(getFirstInput(expr)) && 
getFirstInput(getFirstInput(expr)) == X) {
+                       if(ignoreCorrectionTerm) {
+                               return true;
+                       }
+                       Hop tmp = getSecondInput(expr);
+                       // ((N-1)/N)
+                       boolean isNMinus1Pattern = isBinarySSDiv(tmp) && 
isBinarySSMinus(getFirstInput(tmp)) &&
+                                       getFirstInput(getFirstInput(tmp)) == 
getSecondInput(tmp) && 
+                                       
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), 
new HashMap<>()) == 1;
+                       boolean ret = isNMinus1Pattern && 
isNrowOfX(getSecondInput(tmp), X);
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("Is the corrected column variance 
pattern for batch_norm_train rewrite when number of rows of X unknown matched:" 
+ ret);
+                       }
+                       return ret;
+               }
+               return false;
+       }
+       
+       /**
+        * Checks if the "var" hop is a moving average of variance in batch 
normalization layer.
+        *  
+        * @param mean previously matched mean hop
+        * @param var the hop to check against
+        * @param X input data hop
+        * @param subgrpMeans mean for subgroup mean
+        * @param ignoreCorrectionTerm whether to incore the correct term  (see 
isCorrectedColVars method in this class)
+        * @return true if the "var" hop is a moving average of variance in 
batch normalization layer.
+        */
+       private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop  X, 
Hop subgrpMeans, boolean ignoreCorrectionTerm) {
+               long numChannels = 
Hop.computeSizeInformation(getSecondInput(getFirstInput(mean)));
+               long HW = 
Hop.computeSizeInformation(getThirdInput(getFirstInput(mean)));
+               // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, 
cols=Hin*Win)
+               // var = rowMeans(subgrp_vars) + 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+               return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && 
isRowMeans(getFirstInput(var)) &&  
+                               // matrix(colVars(X) * ((N-1)/N), rows=C, 
cols=Hin*Win)
+                               isReshape(getFirstInput(getFirstInput(var)), 
numChannels, HW) &&
+                               
isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, 
ignoreCorrectionTerm) &&
+                               // 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+                               isBinaryMSMult(getSecondInput(var), 
((((double)HW)-1)/HW)) && 
+                               isRowVars(getFirstInput(getSecondInput(var)), 
subgrpMeans);
+       }
+       
+       /**
+        * Checks and returns the matched hops for expression ema_mean_upd = 
mu*ema_mean + (1-mu)*mean  
+        * 
+        * @param rhsTimesOp hop representing the BinaryOp of expression 
(1-mu)*mean 
+        * @param mu value of mu
+        * @return an array [ema_mean_upd, ema_mean, mean] if the expression matched, 
else null
+        */
+       private static Hop [] getUpdatedMovingAverageExpressions(Hop 
rhsTimesOp, double mu) {
+               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
+                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
+                       return null;
+               
+               // Check (1-mu)*mean
+               double expectedOneMinusMu = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new 
HashMap<>());
+               Hop plusOp = rhsTimesOp.getParent().get(0); 
+               Hop lhsTimesOp = null;
+               if(plusOp.getInput().get(0) == rhsTimesOp) {
+                       lhsTimesOp = plusOp.getInput().get(1); 
+               }
+               else {
+                       lhsTimesOp = plusOp.getInput().get(0);
+               }
+               
+               if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null 
&& plusOp.getParent().size() == 1 &&  
+                       isBinarySMMult(lhsTimesOp) && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new 
HashMap<>()) == mu) {
+                       return new Hop[] {
+                               plusOp.getParent().get(0),
+                               getSecondInput(lhsTimesOp), 
+                               getSecondInput(rhsTimesOp)
+                       };
+               }
+               return null;
+       }
+       
+       /**
+        * Checks whether exactly one hop in rhsTimesOps matches the expression 
ema_mean_upd = mu*ema_mean + (1-mu)*mean, and returns the matched hops.
+        * 
+        * @param rhsTimesOps list of hops representing the BinaryOp of 
expression (1-mu)*mean 
+        * @param mu value of mu
+        * @return an array [ema_mean_upd, ema_mean, mean] if exactly one expression 
matched, else null
+        */
+       private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> 
rhsTimesOps, double mu) {
+               if(rhsTimesOps == null || rhsTimesOps.size() == 0)
+                       return null;
+               
+               Hop [] ret = null;
+               for(Hop h : rhsTimesOps) {
+                       boolean matched = isUpdatedMovingAverageExpression(h, 
mu);
+                       if(matched && ret != null) {
+                               return null; // Multiple matches, cannot decide 
which one to fuse
+                       }
+                       else if(matched) {
+                               ret = getUpdatedMovingAverageExpressions(h, mu);
+                       }
+               }
+               
+               return ret;
+       }
+       
+       /**
+        * Checks and returns the mu in the expression ema_mean_upd = 
mu*ema_mean + (1-mu)*mean
+        * 
+        * @param rhsTimesOps list of hops representing the BinaryOp of expression 
(1-mu)*mean
+        * @return value of mu if exactly one expression matched, else null 
+        */
+       private static Double 
getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
+               if(rhsTimesOps == null || rhsTimesOps.size() == 0)
+                       return null;
+               
+               Double ret = null; 
+               for(Hop h : rhsTimesOps) {
+                       boolean matched = isUpdatedMovingAverageExpression(h);
+                       if(matched && ret != null) {
+                               return null; // Multiple matches, cannot decide 
which one to fuse
+                       }
+                       else if(matched) {
+                               ret = 
-(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new 
HashMap<>())-1);
+                       }
+               }
+               return ret;
+       }
+       
+       /**
+        * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+        * 
+        * @param rhsTimesOp hop representing the BinaryOp of expression 
(1-mu)*mean
+        * @return true if the expression matched
+        */
+       private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) 
{
+               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
+                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
+                       return false;
+               
+               // Check (1-mu)*mean
+               Hop plusOp = rhsTimesOp.getParent().get(0); 
+               Hop lhsTimesOp = null;
+               if(plusOp.getInput().get(0) == rhsTimesOp) {
+                       lhsTimesOp = plusOp.getInput().get(1); 
+               }
+               else {
+                       lhsTimesOp = plusOp.getInput().get(0);
+               }
+               
+               if(plusOp.getParent() != null && plusOp.getParent().size() == 1 
&& isBinarySMMult(lhsTimesOp)) {
+                       return true;
+               }
+               return false;
+       }
+       
+       // ema_mean_upd = mu*ema_mean + (1-mu)*mean
+       // Returns true if expression matched, else false
+       private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, 
double mu) {
+               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
+                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
+                       return false;
+               
+               // Check (1-mu)*mean
+               double expectedOneMinusMu = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new 
HashMap<>());
+               Hop plusOp = rhsTimesOp.getParent().get(0); 
+               Hop lhsTimesOp = null;
+               if(plusOp.getInput().get(0) == rhsTimesOp) {
+                       lhsTimesOp = plusOp.getInput().get(1); 
+               }
+               else {
+                       lhsTimesOp = plusOp.getInput().get(0);
+               }
+               
+               if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null 
&& plusOp.getParent().size() == 1 &&  
+                       isBinarySMMult(lhsTimesOp) && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new 
HashMap<>()) == mu) {
+                       return true;
+               }
+               return false;
+       }
+       
+       /**
+        * Checks for the expression 1/sqrt(denom)
+        * 
+        * @param denom denominator of the expression to be matched
+        * @return true if the expression 1/sqrt(denom) matched else false
+        */
+       private static boolean isOneBySqrt(Hop denom) {
+               return denom.getParent() != null && denom.getParent().get(0) 
instanceof UnaryOp &&
+                               ((UnaryOp)denom.getParent().get(0)).getOp() == 
OpOp1.SQRT &&
+                               denom.getParent().get(0).getParent() != null && 
denom.getParent().get(0).getParent().size() == 1 &&
+                               
isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1);
+       }
+       
+       /**
+        * Checks for the batch norm (mode="train") pattern using the helper 
isBatchNormTrainMean and isBatchNormTrainVar
+        * and returns a new FunctionOp if matched
+        * 
+        * @param roots root hops of the given statement block
+        * @param parent parent of the input
+        * @param hi input to be matched
+        * @param pos position
+        * @return a new FunctionOp or hi
+        */
+       private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop 
hi, int pos) 
+       {               
+               // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+               // hi = bias_add(bias_multiply(norm, gamma), beta)
+               // 2x for input and output and 1x for overhead
+               // fitsOnGPU(hi, 3)
+               if( hasFirstInput(hi) && isBiasAdd(hi) && 
isBiasMultiply(getFirstInput(hi)) ) { 
+                       Hop norm = getFirstInput(getFirstInput(hi));
+                       if(hasSecondInput(norm) && isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
+                                       && hasSecondInput(getFirstInput(norm)) 
&& isUnaryMinus(getSecondInput(getFirstInput(norm)))
+                                       && 
isOneDivideBySqrt(getSecondInput(norm))) {
+                               double eps = 0;
+                               Hop var = 
getFirstInput(getSecondInput(getSecondInput(norm)));
+                               if(isBinaryAdd(var) && (getFirstInput(var) 
instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
+                                       // eps + ema_var
+                                       if(getFirstInput(var) instanceof 
LiteralOp) {
+                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
+                                               var = getSecondInput(var);
+                                       }
+                                       else {
+                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new 
HashMap<>());
+                                               var = getFirstInput(var);
+                                       }
+                               }
+                               // Generate batch norm test op
+                               Hop X = getFirstInput(getFirstInput(norm));
+                               Hop mean = 
getSecondInput(getSecondInput(getFirstInput(norm)));
+                               
+                               if(hasFirstInput(mean) && 
isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, 
getFirstInput(mean), false) &&
+                                       mean.getParent() != null && 
mean.getParent().size() >= 2 && 
+                                       var.getParent() != null && 
var.getParent().size() == 2) {
+                                       Hop gamma = 
getSecondInput(getFirstInput(hi));
+                                       Hop beta = getSecondInput(hi);
+                                       
+                                       // Always get mu from variance as it 
will have exactly one match of fusion pattern
+                                       Double potentialMu = 
getMuFromUpdatedMovingAverageExpressions(var.getParent());
+                                       if(potentialMu == null)
+                                               return hi;
+                                       double mu = potentialMu;
+                                       
+                                       Hop [] means = 
getUpdatedMovingAverageExpressions(mean.getParent(), mu);
+                                       Hop [] vars = 
getUpdatedMovingAverageExpressions(var.getParent(), mu);
+                                       if(means == null || vars == null)
+                                               return hi;
+                                       
+                                       Hop varPlusEps = null;
+                                       boolean isFirstBinaryAddOp = 
isAnyBinaryAdd(var.getParent().get(0).getParent());
+                                       boolean isSecondBinaryAddOp = 
isAnyBinaryAdd(var.getParent().get(1).getParent());
+                                       if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
+                                               varPlusEps = var.getParent().get(1);
+                                       }
+                                       else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
+                                               varPlusEps = var.getParent().get(0);
+                                       }
+                                       if(varPlusEps != null && 
isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
+                                               
+                                               Hop cache_var = 
varPlusEps.getParent().get(0).getParent().get(0);
+                                               Hop ema_mean_upd = means[0];
+                                               Hop ema_var_upd = vars[0];
+                                               Hop ema_mean = means[1];
+                                               Hop ema_var = vars[1];
+                                               Hop cache_mean = means[2];
+                                               
+                                               
+                                               ArrayList<Hop> inHops = new 
ArrayList<Hop>();
+                                               inHops.add(X);
+                                               inHops.add(gamma);
+                                               inHops.add(beta);
+                                               inHops.add(ema_mean);
+                                               inHops.add(ema_var);
+                                               inHops.add(new LiteralOp(eps));
+                                               inHops.add(new LiteralOp(mu));
+                                               Hop [] oldHops = {hi, 
ema_mean_upd, ema_var_upd, cache_mean, cache_var};
+                                               
+                                               // Since FunctionOp adds 
transient writes explicitly, persistent writes are not supported
+                                               
if(!isAnyPersistentWrite(oldHops)) {
+                                                       LOG.debug("Applied 
batchNormTrain rewrite.");
+                                                       ArrayList<Hop> outputs 
= getMultiOutputHops(roots, oldHops);
+                                                       FunctionOp ret = new 
FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, 
"batch_norm2d_train", 
+                                                                       inHops, 
outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs);
+                                                       
Collections.reverse(roots);
+                                                       roots.add(ret);
+                                                       
Collections.reverse(roots);
+                                                       return ret;
+                                               }
+                                       }
+                                       
+                               }
+                       }                       
+               }
+               
+               return hi;
+       }
+       
+       // ------------------------------------------------------------
+       /**
+        * Checks if any of the given output hops is a persistent write.
+        * 
+        * @param outputHops output hops to check
+        * @return true if any of the hops is a persistent write, else false.
+        */
+       private static boolean isAnyPersistentWrite(Hop [] outputHops) {
+               for(Hop outHop : outputHops) {
+                       if(HopRewriteUtils.isData(outHop, 
DataOpTypes.PERSISTENTWRITE))
+                               return true;
+               }
+               return false;
+       }
+       
+       /**
+        * Returns the output hops for a multi-output FunctionOp to be created 
by the rewrite.
+        * 
+        * @param roots root hops of statement block
+        * @param oldHops old output hops of the pattern
+        * @return new output hops that should be passed to FunctionOp
+        */
+       private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, 
Hop [] oldHops) {
+               ArrayList<Hop> ret = new ArrayList<>();
+               for(int i = 0; i < oldHops.length; i++) {
+                       // Create a transient read as FunctionOp will add a 
transient write.
+                       if(HopRewriteUtils.isData(oldHops[i], 
DataOpTypes.PERSISTENTWRITE))
+                               throw new RuntimeException("Persistent write is 
not supported as output for the given rewrite." + oldHops[i]);
+                       // Generate a new name if the old output was not 
transient write.
+                       String name = HopRewriteUtils.isData(oldHops[i], 
DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++);
+                       DataOp tRead = 
HopRewriteUtils.createTransientRead(name, oldHops[i]);
+                       
HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead);
+                       ret.add(tRead);
+                       // Remove old output from roots to avoid unnecessary 
computation.
+                       if(roots.contains(oldHops[i])) {
+                               roots.remove(oldHops[i]);
+                       }
+               }
+               return ret;
+       }
+       // ------------------------------------------------------------
+       
+       /**
+        * Checks for the batch norm (mode="test") pattern using the helper 
isBatchNormTrainMean and isBatchNormTrainVar
+        * and returns a new DnnOp if matched
+        * 
+        * @param parent parent of the input
+        * @param hi input to be matched
+        * @param pos position
+        * @return a new DnnOp or hi
+        */
        private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
                // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
                // hi = bias_add(bias_multiply(norm, gamma), beta)
                // 2x for input and output and 1x for overhead
-               if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && 
fitsOnGPU(hi, 3) ) {
+               if(hasFirstInput(hi) && isBiasAdd(hi) && 
isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
                        Hop norm = getFirstInput(getFirstInput(hi));
-                       if(isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
+                       if(hasSecondInput(norm) && isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
                                        && 
isUnaryMinus(getSecondInput(getFirstInput(norm)))
                                        && 
isOneDivideBySqrt(getSecondInput(norm))) {
                                double eps = 0;
@@ -226,20 +737,28 @@ public class RewriteGPUSpecificOps extends HopRewriteRule 
{
                                // Generate batch norm test op
                                Hop X = getFirstInput(getFirstInput(norm));
                                Hop mean = 
getSecondInput(getSecondInput(getFirstInput(norm)));
-                               Hop gamma = getSecondInput(getFirstInput(hi));
-                               Hop beta = getSecondInput(hi);
-                               ArrayList<Hop> inHops = new ArrayList<Hop>();
-                               inHops.add(X);
-                               inHops.add(gamma);
-                               inHops.add(beta);
-                               inHops.add(mean);
-                               inHops.add(var);
-                               inHops.add(new LiteralOp(eps));
-                               if(fitsOnGPU(inHops, true)) {
-                                       LOG.debug("Applied batchNormTest 
rewrite.");
-                                       Hop newHop = new DnnOp(hi.getName(), 
hi.getDataType(), hi.getValueType(),
-                                                       
OpOpDnn.BATCH_NORM2D_TEST, inHops);
-                                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+                               
+                               // This guard disallows eager fusion of train 
batch normalization into test batch normalization
+                               boolean potentialForBatchNormTrain = 
!X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, 
var, X, getFirstInput(mean), true);
+                               if(!potentialForBatchNormTrain) {
+                                       Hop gamma = 
getSecondInput(getFirstInput(hi));
+                                       Hop beta = getSecondInput(hi);
+                                       ArrayList<Hop> inHops = new 
ArrayList<Hop>();
+                                       inHops.add(X);
+                                       inHops.add(gamma);
+                                       inHops.add(beta);
+                                       inHops.add(mean);
+                                       inHops.add(var);
+                                       inHops.add(new LiteralOp(eps));
+                                       if(fitsOnGPU(inHops, true)) {
+                                               LOG.debug("Applied 
batchNormTest rewrite.");
+                                               Hop newHop = new 
DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+                                                               
OpOpDnn.BATCH_NORM2D_TEST, inHops);
+                                               return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+                                       }
+                               }
+                               else {
+                                       LOG.debug("Skipping batchNormTest 
rewrite as there is potential for batch normalization train rewrite after 
recompilation.");
                                }
                        }
                }
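
After this rewrite, the matched block behaves like a single multi-return builtin 
call. A hedged DML-like sketch of the resulting operation (the rewrite builds the 
batch_norm2d_train FunctionOp directly on the HOP DAG; the output names below 
mirror the hops collected in oldHops and are illustrative):

  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
      batch_norm2d_train(X, gamma, beta, ema_mean, ema_var, eps, mu)

The five outputs are wired back into the DAG as transient reads, which is why the 
rewrite bails out if any original output was a persistent write.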

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java 
b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
index ac58335..711219a 100644
--- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
+++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
@@ -42,8 +42,16 @@ public class FunctionCallCP extends Lop
                this(inputs, fnamespace, fname, outputs, et);
                if(outputHops != null) {
                        _outputLops = new ArrayList<>();
-                       for(Hop h : outputHops)
-                               _outputLops.add( h.constructLops() );
+                       setLevel();
+                       for(Hop h : outputHops) {
+                               Lop outputLop = h.constructLops();
+                               _outputLops.add( outputLop );
+                               addOutput(outputLop);
+                               // Update the output level if necessary for 
correct instruction ordering
+                               if(outputLop.getLevel() <= getLevel()) {
+                                       outputLop.updateLevel(getLevel()+1);
+                               }
+                       }
                }
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java 
b/src/main/java/org/apache/sysml/lops/Lop.java
index 9e81496..885b8b9 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -345,6 +345,19 @@ public abstract class Lop
                lps.setLevel(inputs);
        }
        
+       protected void updateLevel(int newLevel) {
+               if(newLevel < getLevel()) {
+                       throw new RuntimeException("Decrement the levels not 
supported.");
+               }
+               else if(newLevel > getLevel()) {
+                       lps.setLevel(newLevel);
+                       for(Lop out : outputs) {
+                               if(out.getLevel() < newLevel+1)
+                                       out.updateLevel(newLevel+1);
+                       }
+               }
+       }
+       
        /**
         * Method to get the location property of LOP
         * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index b9e5f9d..7cf7418 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2000,8 +2000,8 @@ public class DMLTranslator
                                String[] outputNames = new 
String[targetList.size()]; 
                                outputNames[0] = 
((DataIdentifier)targetList.get(0)).getName();
                                outputNames[1] = 
((DataIdentifier)targetList.get(1)).getName();
-                               outputs.add(new DataOp(outputNames[0], 
DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
outputNames[0]));
-                               outputs.add(new DataOp(outputNames[1], 
DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
outputNames[1]));
+                               outputs.add(new DataOp(outputNames[0], 
DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
inputs.get(0).getFilename()));
+                               outputs.add(new DataOp(outputNames[1], 
DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
inputs.get(0).getFilename()));
                                
                                currBuiltinOp = new FunctionOp(ftype, 
nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
                                break;
@@ -2233,7 +2233,7 @@ public class DMLTranslator
                        String[] outputNames = new String[targetList.size()]; 
                        for ( int i=0; i < targetList.size(); i++ ) {
                                outputNames[i] = 
((DataIdentifier)targetList.get(i)).getName();
-                               Hop output = new DataOp(outputNames[i], 
DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
outputNames[i]);
+                               Hop output = new DataOp(outputNames[i], 
DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, 
inputs.get(0).getFilename());
                                outputs.add(output);
                        }
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 1122a24..01b10a8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -61,6 +61,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "batch_norm2d",           
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_backward",  
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_test",      
GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "batch_norm2d_train",      
GPUINSTRUCTION_TYPE.Dnn);
                
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  
GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index b01b8d8..a36d0fc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -351,12 +351,28 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out = new CPOperand(parts[7]);
                        return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
                }
+               else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
+                       InstructionUtils.checkNumFields(parts, 12);
+                       CPOperand in1 = new CPOperand(parts[1]); // image
+                       CPOperand in2 = new CPOperand(parts[2]); // gamma
+                       CPOperand in3 = new CPOperand(parts[3]); // beta
+                       CPOperand in4 = new CPOperand(parts[4]); // ema_mean
+                       CPOperand in5 = new CPOperand(parts[5]); // ema_var
+                       CPOperand in6 = new CPOperand(parts[6]); // eps
+                       CPOperand in7 = new CPOperand(parts[7]); // mu
+                       CPOperand out = new CPOperand(parts[8]);  // out
+                       CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
+                       CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
+                       CPOperand out4 = new CPOperand(parts[11]); // cache_mean
+                       CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
+                       return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
                }
        }
 
-       public void processBiasInstruction(String instOpcode, ExecutionContext ec) {
+       private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
                GPUStatistics.incrementNoOfExecutedGPUInst();
                MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
                MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -372,7 +388,7 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
-       public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
+       private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
                GPUStatistics.incrementNoOfExecutedGPUInst();
                MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
                MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -420,7 +436,41 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
-       public void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
+       private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
+               MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
+               MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
+               MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
+               MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
+               
+               double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
+               double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
+               
+               MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
+               MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+               MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+               MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+               MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+               
+               LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
+                       image, scale, bias, runningMean, runningVar, ret, 
+                       retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+       }
+       
+       private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
                GPUStatistics.incrementNoOfExecutedGPUInst();
                MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
                MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -485,7 +535,7 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
-       public void processChannelSumsInstruction(ExecutionContext ec) {
+       private void processChannelSumsInstruction(ExecutionContext ec) {
                GPUStatistics.incrementNoOfExecutedGPUInst();
                MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
                int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
@@ -667,6 +717,10 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processBatchNorm2dTestInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
+                       processBatchNorm2dTrainInstruction(ec);
+                       return;
+               }
                
                GPUStatistics.incrementNoOfExecutedGPUInst();
                                        

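The heavy lifting in processBatchNorm2dTrainInstruction happens inside LibMatrixCuDNN.batchNormalizationForwardTraining. As a reference for the semantics, here is a minimal CPU-side sketch in plain Java (illustrative only, not the actual GPU implementation); the flattened [N, C*H*W] layout and the biased variance are assumptions made for brevity. Note how exponentialAverageFactor = 1 - mu reproduces the DML layer's update ema_new = mu * ema_old + (1 - mu) * batch_stat.

    // Illustrative sketch only (not LibMatrixCuDNN): per-channel batch-norm
    // training statistics for an input flattened to [N, C*H*W], as in the DML layer.
    public class BatchNormTrainSketch {
        public static void main(String[] args) {
            int N = 4, C = 2, HW = 3;
            double eps = 1e-3, mu = 0.9;     // eps and mu as passed to the instruction
            double factor = 1 - mu;          // exponentialAverageFactor in the diff above
            double[][] x = new double[N][C * HW];
            double[] emaMean = new double[C], emaVar = new double[C];
            java.util.Random rnd = new java.util.Random(42);
            for (double[] row : x)
                for (int j = 0; j < row.length; j++)
                    row[j] = rnd.nextDouble();
            for (int c = 0; c < C; c++) {
                // batch mean and (biased) variance over the N*H*W values of channel c
                double sum = 0, sumSq = 0;
                for (int n = 0; n < N; n++)
                    for (int k = 0; k < HW; k++) {
                        double v = x[n][c * HW + k];
                        sum += v;
                        sumSq += v * v;
                    }
                double mean = sum / (N * HW);
                double var = sumSq / (N * HW) - mean * mean;
                // EMA update performed by the train instruction: ema = mu*ema + (1-mu)*batch
                emaMean[c] = mu * emaMean[c] + factor * mean;
                emaVar[c] = mu * emaVar[c] + factor * var;
                // cache_inv_var corresponds to the saved inverse standard deviation
                double invStd = 1.0 / Math.sqrt(var + eps);
                System.out.printf("channel %d: mean=%.4f var=%.4f invStd=%.4f%n", c, mean, var, invStd);
            }
        }
    }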
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index 83adad4..b8bb9b6 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -46,10 +46,10 @@ public class BatchNormTest extends GPUTests {
                testBatchNormForward("test");
        }
        
-//     @Test
-//     public void testBatchNormForwardTrain() {
-//             testBatchNormForward("train");
-//     }
+       @Test
+       public void testBatchNormForwardTrain() {
+               testBatchNormForward("train");
+       }
        
        private void testBatchNormForward(String mode) {
                int imgSize = 32; 
@@ -58,18 +58,29 @@ public class BatchNormTest extends GPUTests {
                String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
                                + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
                HashMap<String, Object> inputs = new HashMap<>();
-               inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed));
-               inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
-               inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
-               inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed));
-               inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed));
+               inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+               inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+               inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+               inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed));
+               inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
                List<String> outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var");
                List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
                List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
-               if(mode.equals("test"))
+               if(mode.equals("test")) {
                        assertHeavyHitterPresent("gpu_batch_norm2d_test");
-               for(int i = 0; i < outputs.size(); i++) {
-                       assertEqualObjects(outCPU.get(i), outGPU.get(i));
+                       for(int i = 0; i < outputs.size(); i++) {
+                               assertEqualObjects(outCPU.get(i), outGPU.get(i));
+                       }
+               }
+               else {
+                       assertHeavyHitterPresent("gpu_batch_norm2d_train");
+                       double [] threshold = new double[outputs.size()];
+                       Arrays.fill(threshold, getTHRESHOLD());
+                       // Handle loss of precision in CuDNN kernel 
+                       threshold[2] = 1e-3;
+                       for(int i = 0; i < outputs.size()-1; i++) {
+                               assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]);
+                       }
                }
        }
 }

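With the train path enabled, the test above compares only the first four outputs and relaxes the tolerance for ema_var_upd; the fifth output is skipped, presumably because the fused GPU kernel caches the inverse variance (cache_inv_var) while the old DML layer returns the variance (cache_var). Other GPU tests can opt into per-output tolerances through the new assertEqualObjects overload along these lines (illustrative snippet, variable names assumed):

    // Illustrative use of the new assertEqualObjects(expected, actual, threshold) overload:
    // relax only the comparison that is known to lose precision, keep the default elsewhere.
    double[] thresholds = new double[outputs.size()];
    java.util.Arrays.fill(thresholds, getTHRESHOLD());
    thresholds[2] = 1e-3; // e.g. a value computed by a lower-precision CuDNN kernel
    for (int i = 0; i < outputs.size(); i++)
        assertEqualObjects(outCPU.get(i), outGPU.get(i), thresholds[i]);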
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index e006fd2..cae2e33 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -212,7 +212,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                return in1;
        }
        
-       private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB) {
+       private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB, double threshold) {
                long rows = expectedMB.getNumRows();
                long cols = expectedMB.getNumColumns();
                boolean matrixNotEqual = false;
@@ -222,7 +222,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                                double actualDouble = actualMB.quickGetValue(i, j);
                                if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
                                        double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
-                                       if(relativeError >= getTHRESHOLD()) {
+                                       if(relativeError >= threshold) {
                                                matrixNotEqual = true;
                                                break;
                                        }
@@ -250,12 +250,13 @@ public abstract class GPUTests extends AutomatedTestBase {
         *
         * @param expected expected matrix
         * @param actual   actual matrix
+        * @param threshold relative threshold
         */
-       private void assertEqualMatrices(Matrix expected, Matrix actual) {
+       private void assertEqualMatrices(Matrix expected, Matrix actual, double threshold) {
                try {
                        // Faster way to compare two matrices
                        MLContext cpuMLC = new MLContext(spark);
-                       String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + getTHRESHOLD() + ");";
+                       String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + threshold + ");";
                        Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch");
                        long num_mismatch = cpuMLC.execute(script).getLong("num_mismatch");
                        cpuMLC.close();
@@ -271,7 +272,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                        Assert.assertEquals(rows, actualMB.getNumRows());
                        Assert.assertEquals(cols, actualMB.getNumColumns());
 
-                       if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB);
+                       if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB, threshold);
                        
                        for (int i = 0; i < rows; i++) {
                                for (int j = 0; j < cols; j++) {
@@ -285,12 +286,12 @@ public abstract class GPUTests extends AutomatedTestBase {
                                                                "Relative 
error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at 
[%d, %d]",
                                                                relativeError, 
getTHRESHOLD(), expectedDouble, actualDouble, i, j);
                                                
if(FLOATING_POINT_PRECISION.equals("double"))
-                                                       
Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD());
+                                                       
Assert.assertTrue(format.toString(), relativeError < threshold);
                                                else
-                                                       
Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD() || 
absoluteError < getTHRESHOLD());
+                                                       
Assert.assertTrue(format.toString(), relativeError < threshold || absoluteError 
< threshold);
                                                format.close();
                                        } else {
-                                               
Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+                                               
Assert.assertEquals(expectedDouble, actualDouble, threshold);
                                        }
                                }
                        }
@@ -349,6 +350,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                // and other side effects.
                synchronized(GPUTests.class) {
                        MLContext gpuMLC = new MLContext(spark);
+                       // gpuMLC.setExplain(true); gpuMLC.setExplainLevel("recompile_runtime");
                        gpuMLC.setConfigProperty("sysml.floating.point.precision", FLOATING_POINT_PRECISION);
                        if(IGNORE_CLEAR_MEMORY_BUG)
                                gpuMLC.setConfigProperty("sysml.gpu.eager.cudaFree", "true");
@@ -366,7 +368,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                        return outputs;
                }
        }
-
+       
        /**
         * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
         *
@@ -374,6 +376,17 @@ public abstract class GPUTests extends AutomatedTestBase {
         * @param actual
         */
        protected void assertEqualObjects(Object expected, Object actual) {
+               assertEqualObjects(expected, actual, getTHRESHOLD());
+       }
+
+       /**
+        * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
+        *
+        * @param expected expected value
+        * @param actual actual value 
+        * @param threshold relative error threshold
+        */
+       protected void assertEqualObjects(Object expected, Object actual, double threshold) {
                Assert.assertEquals(expected.getClass(), actual.getClass());
 
                if (expected instanceof Boolean) {
@@ -384,16 +397,16 @@ public abstract class GPUTests extends AutomatedTestBase {
                        if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
                                double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
                                Assert.assertTrue("Comparing floating point numbers, relative error(" + relativeError
-                                               + ") is more than threshold (" + getTHRESHOLD() + ")", relativeError < getTHRESHOLD());
+                                               + ") is more than threshold (" + threshold + ")", relativeError < threshold);
                        } else {
-                               Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+                               Assert.assertEquals(expectedDouble, actualDouble, threshold);
                        }
                } else if (expected instanceof String) {
                        Assert.assertEquals(expected.toString(), actual.toString());
                } else if (expected instanceof Integer) {
                        Assert.assertEquals(((Integer) expected).intValue(), ((Integer) actual).intValue());
                } else if (expected instanceof Matrix)
-                       assertEqualMatrices((Matrix) expected, (Matrix) actual);
+                       assertEqualMatrices((Matrix) expected, (Matrix) actual, threshold);
                else {
                        Assert.fail("Invalid types for comparison");
                }

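The GPUTests changes thread an explicit relative-error threshold through printMatrixIfNotEqual, assertEqualMatrices, and assertEqualObjects; the original two-argument assertEqualObjects now delegates with getTHRESHOLD(), so existing tests keep their behavior. The per-element rule the harness applies is roughly the following sketch (illustrative only; it ignores the extra absolute-error fallback used when sysml.floating.point.precision is not "double"):

    // Sketch of the element-wise check (not the test harness itself): finite,
    // non-zero expected values are compared by relative error, the rest by absolute error.
    static boolean withinThreshold(double expected, double actual, double threshold) {
        if (expected != 0.0 && !Double.isNaN(expected) && Double.isFinite(expected)) {
            double relativeError = Math.abs((expected - actual) / expected);
            return relativeError < threshold;
        }
        return Math.abs(expected - actual) <= threshold;
    }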