Repository: systemml

Updated Branches:
  refs/heads/master 5aadb4b22 -> a6bca8851
[SYSTEMML-445] Added a rewrite for batch normalization train

- This PR fuses a batch normalization train pattern into a FunctionOp. The method batchNormTrain in RewriteGPUSpecificOps performs the fusing.
- This rewrite is only enabled if none of the outputs are persistent writes. It replaces the existing outputs of the matched pattern with transient reads.

Closes #800.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a6bca885
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a6bca885
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a6bca885

Branch: refs/heads/master
Commit: a6bca88512f3f542278709713706d256fad2cc17
Parents: 5aadb4b
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Mon Jul 16 14:50:36 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Mon Jul 16 14:52:01 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/FunctionOp.java  |  28 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   1 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 573 ++++++++++++++++++-
 .../org/apache/sysml/lops/FunctionCallCP.java   |  12 +-
 src/main/java/org/apache/sysml/lops/Lop.java    |  13 +
 .../org/apache/sysml/parser/DMLTranslator.java  |   6 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  62 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |  35 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  37 +-
 10 files changed, 701 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 64963d9..aedaf81 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -169,17 +169,20 @@ public class FunctionOp extends Hop long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0); return outputVectors+outputValues; } - else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { + else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { // TODO: To allow for initial version to always run on the GPU return 0; } - else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") ) { + else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) { return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) + OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) + OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) + OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) + OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0); } + else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) { + return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0); + } else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) { return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) + OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) + @@ -215,7 +218,8 @@ public class FunctionOp extends Hop return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0) + 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); } - else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { + else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") || + getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) { return 0; } else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { @@ -240,7 +244,8 @@ public class FunctionOp extends Hop @Override public boolean isGPUEnabled() { if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") || - getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) + getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") || + getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) return true; else return false; @@ -283,20 +288,25 @@ public class FunctionOp extends Hop checkAndSetForcedPlatform(); if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) { + boolean isBuiltinFunction = isBuiltinFunction(); // check if there is sufficient memory to execute this function - if( getFunctionName().equalsIgnoreCase("transformencode") ) { + if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("transformencode") ) { _etype = ((_etypeForced==ExecType.SPARK || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget() && OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP); } - else if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) { + else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) { if(!DMLScript.USE_ACCELERATOR) throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU."); _etype = ExecType.GPU; } - else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { + else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) { _etype = DMLScript.USE_ACCELERATOR ? 
ExecType.GPU : ExecType.CP; } + else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) { + // Only GPU implementation is supported + _etype = ExecType.GPU; + } else { // Since the memory estimate is only conservative, do not throw // exception if the estimated memory is larger than the budget @@ -312,6 +322,10 @@ public class FunctionOp extends Hop return _etype; } + + private boolean isBuiltinFunction() { + return getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE); + } @Override public void refreshSizeInformation() http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 5d357c6..d8f4424 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1537,6 +1537,7 @@ public abstract class Hop implements ParseInfo HopsData2String.put(DataOpTypes.PERSISTENTWRITE, "PWrite"); HopsData2String.put(DataOpTypes.TRANSIENTWRITE, "TWrite"); HopsData2String.put(DataOpTypes.TRANSIENTREAD, "TRead"); + HopsData2String.put(DataOpTypes.FUNCTIONOUTPUT, "FunOut"); } public static OpOp2 getOpOp2ForOuterVectorOperation(String op) http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index 1c00c6f..b946178 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -20,47 +20,64 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.FunctionOp.FunctionType; import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOpDnn; import org.apache.sysml.hops.Hop.ReOrgOp; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.DnnOp; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; /* * This class contains GPU-specific rewrites for following patterns: * - * 1. batchNormTest: + * 1. batchNormTest: applied when mode="test" in batch normalization nn layer. * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) * hi = bias_add(bias_multiply(norm, gamma), beta) * * 2. channelSum: * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) + * + * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer. 
+ * This rewrite is only enabled if none of the outputs are persistent writes as it assumes that + * FunctionOp will introduce transient writes. This rewrite replaces the existing outputs of the matched pattern with transient reads. + * */ public class RewriteGPUSpecificOps extends HopRewriteRule { + private static int _seq = 1; + @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { if( roots == null ) return roots; //one pass rewrite-descend (rewrite created pattern) - for( Hop h : roots ) - rule_GPUKernels( h, false ); + for( int i = 0; i < roots.size(); i++ ) + rule_GPUKernels(roots, roots.get(i), false ); Hop.resetVisitStatus(roots, true); //one pass descend-rewrite (for rollup) - for( Hop h : roots ) - rule_GPUKernels( h, true ); + for( int i = 0; i < roots.size(); i++ ) + rule_GPUKernels(roots, roots.get(i), true ); Hop.resetVisitStatus(roots, true); return roots; @@ -72,12 +89,12 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { return root; //one pass rewrite-descend (rewrite created pattern) - rule_GPUKernels( root, false ); + rule_GPUKernels(null, root, false ); root.resetVisitStatus(); //one pass descend-rewrite (for rollup) - rule_GPUKernels( root, true ); + rule_GPUKernels(null, root, true ); return root; } @@ -85,10 +102,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { /** * Fuse the kernel * + * @param roots root operators * @param hop high-level operator * @param descendFirst true if recursively process children first */ - private void rule_GPUKernels(Hop hop, boolean descendFirst) + private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst) { if(hop.isVisited()) return; @@ -99,13 +117,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { //process childs recursively first (to allow roll-up) if( descendFirst ) - rule_GPUKernels(hi, descendFirst); //see below + rule_GPUKernels(roots, hi, descendFirst); //see below + if(roots != null) { + hi = batchNormTrain(roots, hop, hi, i); + } hi = batchNormTest(hop, hi, i); hi = channelSums(hop, hi, i); if( !descendFirst ) - rule_GPUKernels(hi, descendFirst); + rule_GPUKernels(roots, hi, descendFirst); } hop.setVisited(); @@ -149,6 +170,10 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); } + private static boolean hasFirstInput(Hop h) { + return !(h == null || h.getInput() == null || h.getInput().size() < 1); + } + private static Hop getFirstInput(Hop h) { if(h == null || h.getInput() == null || h.getInput().size() < 1) { throw new RuntimeException("No input available for " + h); @@ -156,13 +181,24 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { return h.getInput().get(0); } + private static boolean hasSecondInput(Hop h) { + return !(h == null || h.getInput() == null || h.getInput().size() < 2); + } + private static Hop getSecondInput(Hop h) { if(h == null || h.getInput() == null || h.getInput().size() < 2) { - throw new RuntimeException("No input available for " + h); + throw new RuntimeException("Expected at least two inputs for " + h); } return h.getInput().get(1); } + private static Hop getThirdInput(Hop h) { + if(h == null || h.getInput() == null || h.getInput().size() < 3) { + throw new RuntimeException("Expected at least three inputs for " + h); + } + return h.getInput().get(2); + } + private static boolean isUnaryMinus(Hop h) { return HopRewriteUtils.isBinary(h, OpOp2.MINUS) && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0); @@ -200,13 +236,488 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { return hi; } + private static boolean isRowMeans(Hop h) { + return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row; + } + + private static boolean isRowVars(Hop h) { + return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row; + } + + private static boolean isRowVars(Hop h, Hop childHop) { + return isRowVars(h) && getFirstInput(h) == childHop; + } + + private static boolean isColMeans(Hop h) { + return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col; + } + + private static boolean isColVars(Hop h) { + return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col; + } + + private static boolean isReshape(Hop h) { + return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE; + } + + private static boolean isReshape(Hop h, long expectedRows, long expectedCols) { + return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE && + Hop.computeSizeInformation(getSecondInput(h)) == expectedRows && + Hop.computeSizeInformation(getThirdInput(h)) == expectedCols; + } + + private static boolean isBinaryAdd(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS; + } + + private static boolean isBinaryMSAdd(Hop h, double expectedValue) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS + && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR + && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue; + } + + private static boolean isBinaryMMAdd(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS + && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX; + } + + private static boolean isBinaryMSMult(Hop h, double expectedValue) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT + && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR + && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue; + } + + private static boolean isBinarySSMinus(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS + && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR; + } + + private static boolean isBinarySSDiv(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV + && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR; + } + + private static boolean isBinarySMDiv(Hop h, double expectedValue) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV + && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX + && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) == expectedValue; + } + + private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) { + if(hops != null) { + for(Hop h : hops) { + if(h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS) + return true; + } + } + return false; + } + + private static boolean isBinaryMSMult(Hop 
h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT + && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR; + } + + private static boolean isBinarySMMult(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT + && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR; + } + + /** + * Checks if the "mean" hop is a moving average of mean in batch normalization layer. + * + * @param mean hop to check against + * @param X input data + * @return true if the "mean" hop is a moving average of mean in batch normalization layer. + */ + private static boolean isBatchNormTrainMean(Hop mean, Hop X) { + // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) + // mean = rowMeans(subgrp_means) + return isRowMeans(mean) && isReshape(getFirstInput(mean)) && isColMeans(getFirstInput(getFirstInput(mean))) + && getFirstInput(getFirstInput(getFirstInput(mean))) == X; + } + + /** + * Checks for the nrow(X) pattern + * + * @param expr hop to be matched + * @param X input X + * @return true if expr is nrow(X) else false + */ + private static boolean isNrowOfX(Hop expr, Hop X) { + return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == OpOp1.NROW && getFirstInput(expr) == X; + } + + /** + * Checks for the colVars(X) * ((N-1)/N) pattern + * + * @param expr hop to be matched + * @param X input X + * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N). + * @return true if expr is colVars(X) * ((N-1)/N) else false + */ + private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) { + // colVars(X) * ((N-1)/N) + if(isColVars(expr) && getFirstInput(expr) == X) { + // Support no correction as well in this rewrite + return true; + } + else if(X.rowsKnown()) { + return isBinaryMSMult(expr, ((double)X.getDim1()-1)/X.getDim1()) && + isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X; + } + else if(isBinaryMSMult(expr) && + isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X) { + if(ignoreCorrectionTerm) { + return true; + } + Hop tmp = getSecondInput(expr); + // ((N-1)/N) + boolean isNMinus1Pattern = isBinarySSDiv(tmp) && isBinarySSMinus(getFirstInput(tmp)) && + getFirstInput(getFirstInput(tmp)) == getSecondInput(tmp) && + OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), new HashMap<>()) == 1; + boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(tmp), X); + if(LOG.isDebugEnabled()) { + LOG.debug("Is the corrected column variance pattern for batch_norm_train rewrite when number of rows of X unknown matched: " + ret); + } + return ret; + } + return false; + } + + /** + * Checks if the "var" hop is a moving average of variance in batch normalization layer. + * + * @param mean previously matched mean hop + * @param var the hop to check against + * @param X input data hop + * @param subgrpMeans hop for the subgroup means + * @param ignoreCorrectionTerm whether to ignore the correction term (see the isCorrectedColVars method in this class) + * @return true if the "var" hop is a moving average of variance in batch normalization layer. 
+ */ + private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop X, Hop subgrpMeans, boolean ignoreCorrectionTerm) { + long numChannels = Hop.computeSizeInformation(getSecondInput(getFirstInput(mean))); + long HW = Hop.computeSizeInformation(getThirdInput(getFirstInput(mean))); + // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) + // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) + return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && isRowMeans(getFirstInput(var)) && + // matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) + isReshape(getFirstInput(getFirstInput(var)), numChannels, HW) && + isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, ignoreCorrectionTerm) && + // rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) + isBinaryMSMult(getSecondInput(var), ((((double)HW)-1)/HW)) && + isRowVars(getFirstInput(getSecondInput(var)), subgrpMeans); + } + + /** + * Checks and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean + * + * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean + * @param mu value of mu + * @return an array [ema_mean_upd, ema_mean, mean] if the expression matched, else null + */ + private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) { + if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || + !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) + return null; + + // Check (1-mu)*mean + double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>()); + Hop plusOp = rhsTimesOp.getParent().get(0); + Hop lhsTimesOp = null; + if(plusOp.getInput().get(0) == rhsTimesOp) { + lhsTimesOp = plusOp.getInput().get(1); + } + else { + lhsTimesOp = plusOp.getInput().get(0); + } + + if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 && + isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) { + return new Hop[] { + plusOp.getParent().get(0), + getSecondInput(lhsTimesOp), + getSecondInput(rhsTimesOp) + }; + } + return null; + } + + /** + * Checks that exactly one hop in rhsTimesOps matches, and returns the matched hops for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean + * + * @param rhsTimesOps array list of hops, each a candidate BinaryOp for the expression (1-mu)*mean + * @param mu value of mu + * @return an array [ema_mean_upd, ema_mean, mean] if exactly one expression matched, else null + */ + private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) { + if(rhsTimesOps == null || rhsTimesOps.size() == 0) + return null; + + Hop [] ret = null; + for(Hop h : rhsTimesOps) { + boolean matched = isUpdatedMovingAverageExpression(h, mu); + if(matched && ret != null) { + return null; // Multiple matches, cannot decide which one to fuse + } + else if(matched) { + ret = getUpdatedMovingAverageExpressions(h, mu); + } + } + + return ret; + } + + /** + * Checks and returns the mu in the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean + * + * @param rhsTimesOps array list of hops, each a candidate BinaryOp for the expression (1-mu)*mean + * @return value of mu if exactly one expression matched, else null + */ + private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) { + if(rhsTimesOps == null || rhsTimesOps.size() == 0) + return null; + + Double ret = null; + for(Hop h : rhsTimesOps) { + boolean matched = 
isUpdatedMovingAverageExpression(h); + if(matched && ret != null) { + return null; // Multiple matches, cannot decide which one to fuse + } + else if(matched) { + ret = -(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())-1); + } + } + return ret; + } + + /** + * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean + * + * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean + * @return true if the expression matched + */ + private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) { + if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || + !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) + return false; + + // Check (1-mu)*mean + Hop plusOp = rhsTimesOp.getParent().get(0); + Hop lhsTimesOp = null; + if(plusOp.getInput().get(0) == rhsTimesOp) { + lhsTimesOp = plusOp.getInput().get(1); + } + else { + lhsTimesOp = plusOp.getInput().get(0); + } + + if(plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp)) { + return true; + } + return false; + } + + // ema_mean_upd = mu*ema_mean + (1-mu)*mean + // Returns true if the expression matched, else false + private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) { + if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || + !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) + return false; + + // Check (1-mu)*mean + double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>()); + Hop plusOp = rhsTimesOp.getParent().get(0); + Hop lhsTimesOp = null; + if(plusOp.getInput().get(0) == rhsTimesOp) { + lhsTimesOp = plusOp.getInput().get(1); + } + else { + lhsTimesOp = plusOp.getInput().get(0); + } + + if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 && + isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) { + return true; + } + return false; + } + + /** + * Checks for the expression 1/sqrt(denom) + * + * @param denom denominator of the expression to be matched + * @return true if the expression 1/sqrt(denom) matched else false + */ + private static boolean isOneBySqrt(Hop denom) { + return denom.getParent() != null && denom.getParent().get(0) instanceof UnaryOp && + ((UnaryOp)denom.getParent().get(0)).getOp() == OpOp1.SQRT && + denom.getParent().get(0).getParent() != null && denom.getParent().get(0).getParent().size() == 1 && + isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1); + } + + /** + * Checks for the batch norm (mode="train") pattern using the helpers isBatchNormTrainMean and isBatchNormTrainVar + * and returns a new FunctionOp if matched + * + * @param roots root hops of the given statement block + * @param parent parent of the input + * @param hi input to be matched + * @param pos position + * @return a new FunctionOp or hi + */ + private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop hi, int pos) + { + // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) + // hi = bias_add(bias_multiply(norm, gamma), beta) + // 2x for input and output and 1x for overhead + // fitsOnGPU(hi, 3) + if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) { + Hop norm = getFirstInput(getFirstInput(hi)); + if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
+ && hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm))) + && isOneDivideBySqrt(getSecondInput(norm))) { + double eps = 0; + Hop var = getFirstInput(getSecondInput(getSecondInput(norm))); + if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) { + // eps + ema_var + if(getFirstInput(var) instanceof LiteralOp) { + eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>()); + var = getSecondInput(var); + } + else { + eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>()); + var = getFirstInput(var); + } + } + // Generate batch norm test op + Hop X = getFirstInput(getFirstInput(norm)); + Hop mean = getSecondInput(getSecondInput(getFirstInput(norm))); + + if(hasFirstInput(mean) && isBatchNormTrainMean(mean, X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) && + mean.getParent() != null && mean.getParent().size() >= 2 && + var.getParent() != null && var.getParent().size() == 2) { + Hop gamma = getSecondInput(getFirstInput(hi)); + Hop beta = getSecondInput(hi); + + // Always get mu from variance as it will have exactly one match of the fusion pattern + Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent()); + if(potentialMu == null) + return hi; + double mu = potentialMu; + + Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu); + Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu); + if(means == null || vars == null) + return hi; + + Hop varPlusEps = null; + boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent()); + boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent()); + if(isFirstBinaryAddOp && !isSecondBinaryAddOp) { + varPlusEps = var.getParent().get(1); + } + else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) { + varPlusEps = var.getParent().get(0); + } + if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) { + + Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0); + Hop ema_mean_upd = means[0]; + Hop ema_var_upd = vars[0]; + Hop ema_mean = means[1]; + Hop ema_var = vars[1]; + Hop cache_mean = means[2]; + + + ArrayList<Hop> inHops = new ArrayList<Hop>(); + inHops.add(X); + inHops.add(gamma); + inHops.add(beta); + inHops.add(ema_mean); + inHops.add(ema_var); + inHops.add(new LiteralOp(eps)); + inHops.add(new LiteralOp(mu)); + Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var}; + + // Since FunctionOp adds transient writes explicitly, persistent writes are not supported + if(!isAnyPersistentWrite(oldHops)) { + LOG.debug("Applied batchNormTrain rewrite."); + ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops); + FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train", + inHops, outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs); + Collections.reverse(roots); + roots.add(ret); + Collections.reverse(roots); + return ret; + } + } + + } + } + } + + return hi; + } + + // ------------------------------------------------------------ + /** + * Checks if any of the given output hops is a persistent write. + * + * @param outputHops output hops to check + * @return true if any of the hops is a persistent write, else false. 
+ */ + private static boolean isAnyPersistentWrite(Hop [] outputHops) { + for(Hop outHop : outputHops) { + if(HopRewriteUtils.isData(outHop, DataOpTypes.PERSISTENTWRITE)) + return true; + } + return false; + } + + /** + * Returns the output hops for the multi-output FunctionOp to be created by the rewrite. + * + * @param roots root hops of statement block + * @param oldHops old output hops of the pattern + * @return new output hops that should be passed to FunctionOp + */ + private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) { + ArrayList<Hop> ret = new ArrayList<>(); + for(int i = 0; i < oldHops.length; i++) { + // Create a transient read as FunctionOp will add a transient write. + if(HopRewriteUtils.isData(oldHops[i], DataOpTypes.PERSISTENTWRITE)) + throw new RuntimeException("Persistent write is not supported as output for the given rewrite: " + oldHops[i]); + // Generate a new name if the old output was not a transient write. + String name = HopRewriteUtils.isData(oldHops[i], DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++); + DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHops[i]); + HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead); + ret.add(tRead); + // Remove old output from roots to avoid unnecessary computation. + if(roots.contains(oldHops[i])) { + roots.remove(oldHops[i]); + } + } + return ret; + } + // ------------------------------------------------------------ + + /** + * Checks for the batch norm (mode="test") pattern and returns a new DnnOp if matched. + * The helpers isBatchNormTrainMean and isBatchNormTrainVar are used only to guard against + * eagerly fusing a pattern that may still match batch norm train. + * + * @param parent parent of the input + * @param hi input to be matched + * @param pos position + * @return a new DnnOp or hi + */ private static Hop batchNormTest(Hop parent, Hop hi, int pos) { // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) // hi = bias_add(bias_multiply(norm, gamma), beta) // 2x for input and output and 1x for overhead - if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { + if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { Hop norm = getFirstInput(getFirstInput(hi)); - if(isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) + if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm))) && isOneDivideBySqrt(getSecondInput(norm))) { double eps = 0; @@ -226,20 +737,28 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { // Generate batch norm test op Hop X = getFirstInput(getFirstInput(norm)); Hop mean = getSecondInput(getSecondInput(getFirstInput(norm))); - Hop gamma = getSecondInput(getFirstInput(hi)); - Hop beta = getSecondInput(hi); - ArrayList<Hop> inHops = new ArrayList<Hop>(); - inHops.add(X); - inHops.add(gamma); - inHops.add(beta); - inHops.add(mean); - inHops.add(var); - inHops.add(new LiteralOp(eps)); - if(fitsOnGPU(inHops, true)) { - LOG.debug("Applied batchNormTest rewrite."); - Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), - OpOpDnn.BATCH_NORM2D_TEST, inHops); - return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + + // This guard disallows eager fusion of train batch normalization into test batch normalization + boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean, X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true); + if(!potentialForBatchNormTrain) { + Hop gamma = 
getSecondInput(getFirstInput(hi)); + Hop beta = getSecondInput(hi); + ArrayList<Hop> inHops = new ArrayList<Hop>(); + inHops.add(X); + inHops.add(gamma); + inHops.add(beta); + inHops.add(mean); + inHops.add(var); + inHops.add(new LiteralOp(eps)); + if(fitsOnGPU(inHops, true)) { + LOG.debug("Applied batchNormTest rewrite."); + Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), + OpOpDnn.BATCH_NORM2D_TEST, inHops); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + } + } + else { + LOG.debug("Skipping batchNormTest rewrite as there is potential for a batch normalization train rewrite after recompilation."); } } }

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
index ac58335..711219a 100644
--- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
+++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
@@ -42,8 +42,16 @@ public class FunctionCallCP extends Lop this(inputs, fnamespace, fname, outputs, et); if(outputHops != null) { _outputLops = new ArrayList<>(); - for(Hop h : outputHops) - _outputLops.add( h.constructLops() ); + setLevel(); + for(Hop h : outputHops) { + Lop outputLop = h.constructLops(); + _outputLops.add( outputLop ); + addOutput(outputLop); + // Update the output level if necessary for correct instruction ordering + if(outputLop.getLevel() <= getLevel()) { + outputLop.updateLevel(getLevel()+1); + } + } } }

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 9e81496..885b8b9 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -345,6 +345,19 @@ public abstract class Lop lps.setLevel(inputs); } + protected void updateLevel(int newLevel) { + if(newLevel < getLevel()) { + throw new RuntimeException("Decrementing the level is not supported."); + } + else if(newLevel > getLevel()) { + lps.setLevel(newLevel); + for(Lop out : outputs) { + if(out.getLevel() < newLevel+1) + out.updateLevel(newLevel+1); + } + } + } + /** * Method to get the location property of LOP *

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index b9e5f9d..7cf7418 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2000,8 +2000,8 @@ public class DMLTranslator String[] outputNames = new String[targetList.size()]; outputNames[0] = ((DataIdentifier)targetList.get(0)).getName(); outputNames[1] = ((DataIdentifier)targetList.get(1)).getName(); - outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0])); - outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1])); + outputs.add(new DataOp(outputNames[0], DataType.MATRIX, 
ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename())); + outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename())); currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs); break; @@ -2233,7 +2233,7 @@ public class DMLTranslator String[] outputNames = new String[targetList.size()]; for ( int i=0; i < targetList.size(); i++ ) { outputNames[i] = ((DataIdentifier)targetList.get(i)).getName(); - Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[i]); + Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename()); outputs.add(output); } http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 1122a24..01b10a8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -61,6 +61,7 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "batch_norm2d_train", GPUINSTRUCTION_TYPE.Dnn); // Matrix Multiply Operators String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary); http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java index b01b8d8..a36d0fc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java @@ -351,12 +351,28 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out = new CPOperand(parts[7]); return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0); } + else if (opcode.equalsIgnoreCase("batch_norm2d_train")) { + InstructionUtils.checkNumFields(parts, 12); + CPOperand in1 = new CPOperand(parts[1]); // image + CPOperand in2 = new CPOperand(parts[2]); // gamma + CPOperand in3 = new CPOperand(parts[3]); // beta + CPOperand in4 = new CPOperand(parts[4]); // ema_mean + CPOperand in5 = new CPOperand(parts[5]); // ema_var + CPOperand in6 = new CPOperand(parts[6]); // eps + CPOperand in7 = new CPOperand(parts[7]); // mu + CPOperand out = new CPOperand(parts[8]); // out + CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd + CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd + CPOperand out4 = new CPOperand(parts[11]); // cache_mean + CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var + return new DnnGPUInstruction(in1, in2, in3, 
in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0); + } else { throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str); } } - public void processBiasInstruction(String instOpcode, ExecutionContext ec) { + private void processBiasInstruction(String instOpcode, ExecutionContext ec) { GPUStatistics.incrementNoOfExecutedGPUInst(); MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); @@ -372,7 +388,7 @@ public class DnnGPUInstruction extends GPUInstruction { ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } - public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException { + private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException { GPUStatistics.incrementNoOfExecutedGPUInst(); MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); @@ -420,7 +436,41 @@ public class DnnGPUInstruction extends GPUInstruction { ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } - public void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException { + private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); + MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); + + double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue(); + double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue(); + + MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); + MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); + MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); + MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); + MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); + + LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), + image, scale, bias, runningMean, runningVar, ret, + retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + 
ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output4.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output5.getName()); + } + + private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException { GPUStatistics.incrementNoOfExecutedGPUInst(); MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); @@ -485,7 +535,7 @@ public class DnnGPUInstruction extends GPUInstruction { ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } - public void processChannelSumsInstruction(ExecutionContext ec) { + private void processChannelSumsInstruction(ExecutionContext ec) { GPUStatistics.incrementNoOfExecutedGPUInst(); MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue(); @@ -667,6 +717,10 @@ public class DnnGPUInstruction extends GPUInstruction { processBatchNorm2dTestInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) { + processBatchNorm2dTrainInstruction(ec); + return; + } GPUStatistics.incrementNoOfExecutedGPUInst(); http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java index 83adad4..b8bb9b6 100644 --- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java +++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java @@ -46,10 +46,10 @@ public class BatchNormTest extends GPUTests { testBatchNormForward("test"); } -// @Test -// public void testBatchNormForwardTrain() { -// testBatchNormForward("train"); -// } + @Test + public void testBatchNormForwardTrain() { + testBatchNormForward("train"); + } private void testBatchNormForward(String mode) { int imgSize = 32; @@ -58,18 +58,29 @@ public class BatchNormTest extends GPUTests { String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n " + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)"; HashMap<String, Object> inputs = new HashMap<>(); - inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed)); - inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed)); - inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed)); - inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed)); - inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed)); + inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed)); + inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); + inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); + inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed)); + inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); List<String> 
outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var"); List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); - if(mode.equals("test")) + if(mode.equals("test")) { assertHeavyHitterPresent("gpu_batch_norm2d_test"); - for(int i = 0; i < outputs.size(); i++) { - assertEqualObjects(outCPU.get(i), outGPU.get(i)); + for(int i = 0; i < outputs.size(); i++) { + assertEqualObjects(outCPU.get(i), outGPU.get(i)); + } + } + else { + assertHeavyHitterPresent("gpu_batch_norm2d_train"); + double [] threshold = new double[outputs.size()]; + Arrays.fill(threshold, getTHRESHOLD()); + // Handle loss of precision in CuDNN kernel + threshold[2] = 1e-3; + for(int i = 0; i < outputs.size()-1; i++) { + assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]); + } } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/GPUTests.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java index e006fd2..cae2e33 100644 --- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java +++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java @@ -212,7 +212,7 @@ public abstract class GPUTests extends AutomatedTestBase { return in1; } - private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB) { + private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB, double threshold) { long rows = expectedMB.getNumRows(); long cols = expectedMB.getNumColumns(); boolean matrixNotEqual = false; @@ -222,7 +222,7 @@ public abstract class GPUTests extends AutomatedTestBase { double actualDouble = actualMB.quickGetValue(i, j); if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) { double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble); - if(relativeError >= getTHRESHOLD()) { + if(relativeError >= threshold) { matrixNotEqual = true; break; } @@ -250,12 +250,13 @@ public abstract class GPUTests extends AutomatedTestBase { * * @param expected expected matrix * @param actual actual matrix + * @param threshold relative threshold */ - private void assertEqualMatrices(Matrix expected, Matrix actual) { + private void assertEqualMatrices(Matrix expected, Matrix actual, double threshold) { try { // Faster way to compare two matrices MLContext cpuMLC = new MLContext(spark); - String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + getTHRESHOLD() + ");"; + String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + threshold + ");"; Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch"); long num_mismatch = cpuMLC.execute(script).getLong("num_mismatch"); cpuMLC.close(); @@ -271,7 +272,7 @@ public abstract class GPUTests extends AutomatedTestBase { Assert.assertEquals(rows, actualMB.getNumRows()); Assert.assertEquals(cols, actualMB.getNumColumns()); - if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB); + if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB, threshold); for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { @@ -285,12 +286,12 @@ public abstract class GPUTests extends AutomatedTestBase { "Relative error(%f) is more than threshold (%f). 
Expected = %f, Actual = %f, differed at [%d, %d]", relativeError, getTHRESHOLD(), expectedDouble, actualDouble, i, j); if(FLOATING_POINT_PRECISION.equals("double")) - Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD()); + Assert.assertTrue(format.toString(), relativeError < threshold); else - Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD() || absoluteError < getTHRESHOLD()); + Assert.assertTrue(format.toString(), relativeError < threshold || absoluteError < threshold); format.close(); } else { - Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD()); + Assert.assertEquals(expectedDouble, actualDouble, threshold); } } } @@ -349,6 +350,7 @@ public abstract class GPUTests extends AutomatedTestBase { // and other side effects. synchronized(GPUTests.class) { MLContext gpuMLC = new MLContext(spark); + // gpuMLC.setExplain(true); gpuMLC.setExplainLevel("recompile_runtime"); gpuMLC.setConfigProperty("sysml.floating.point.precision", FLOATING_POINT_PRECISION); if(IGNORE_CLEAR_MEMORY_BUG) gpuMLC.setConfigProperty("sysml.gpu.eager.cudaFree", "true"); @@ -366,7 +368,7 @@ public abstract class GPUTests extends AutomatedTestBase { return outputs; } } - + /** * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix * @@ -374,6 +376,17 @@ public abstract class GPUTests extends AutomatedTestBase { * @param actual */ protected void assertEqualObjects(Object expected, Object actual) { + assertEqualObjects(expected, actual, getTHRESHOLD()); + } + + /** + * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix + * + * @param expected expected value + * @param actual actual value + * @param threshold relative error threshold + */ + protected void assertEqualObjects(Object expected, Object actual, double threshold) { Assert.assertEquals(expected.getClass(), actual.getClass()); if (expected instanceof Boolean) { @@ -384,16 +397,16 @@ public abstract class GPUTests extends AutomatedTestBase { if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) { double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble); Assert.assertTrue("Comparing floating point numbers, relative error(" + relativeError - + ") is more than threshold (" + getTHRESHOLD() + ")", relativeError < getTHRESHOLD()); + + ") is more than threshold (" + threshold + ")", relativeError < threshold); } else { - Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD()); + Assert.assertEquals(expectedDouble, actualDouble, threshold); } } else if (expected instanceof String) { Assert.assertEquals(expected.toString(), actual.toString()); } else if (expected instanceof Integer) { Assert.assertEquals(((Integer) expected).intValue(), ((Integer) actual).intValue()); } else if (expected instanceof Matrix) - assertEqualMatrices((Matrix) expected, (Matrix) actual); + assertEqualMatrices((Matrix) expected, (Matrix) actual, threshold); else { Assert.fail("Invalid types for comparison"); }
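
For reference, the DML pattern that the new batchNormTrain rewrite matches is sketched below. This is a paraphrase of the expressions quoted in the comments of RewriteGPUSpecificOps, not code from this commit; C, Hin, Win, mu, eps and the ema_* variables are assumed to come from the surrounding batch_norm2d layer code (mode="train"), as exercised by the test script above.

    N = nrow(X)
    # per-channel subgroup statistics
    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
    mean = rowMeans(subgrp_means)                                       # -> cache_mean
    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
    var = rowMeans(subgrp_vars) + rowVars(subgrp_means) * (((Hin*Win)-1)/(Hin*Win))   # -> cache_var
    # moving averages of mean and variance
    ema_mean_upd = mu*ema_mean + (1-mu)*mean
    ema_var_upd = mu*ema_var + (1-mu)*var
    # normalize, then scale and shift
    norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
    out = bias_add(bias_multiply(norm, gamma), beta)

When none of the five outputs is a persistent write, the rewrite collapses all of the above into a single GPU FunctionOp whose inputs and outputs correspond to [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_train(X, gamma, beta, ema_mean, ema_var, eps, mu); the bracketed form illustrates the generated operator's signature and is not user-facing DML syntax.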