Repository: systemml Updated Branches: refs/heads/master 81419ae6a -> 0f36780a8
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java index e736a1c..d620de9 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java @@ -19,8 +19,8 @@ package org.apache.sysml.runtime.instructions.gpu; import java.util.ArrayList; - import jcuda.Pointer; +import jcuda.jcudnn.JCudnn; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; @@ -32,7 +32,6 @@ import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.util.DnnUtils; @@ -57,12 +56,14 @@ public class DnnGPUInstruction extends GPUInstruction { private ArrayList<CPOperand> _stride = new ArrayList<>(); private ArrayList<CPOperand> _padding = new ArrayList<>(); private double _intermediateMemoryBudget = 0; + private GPUContext gCtx; + private String instName; public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) { super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) { + if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward") || opcode.equals("inv_var") )) { throw new DMLRuntimeException( - "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward or inv_var, but found " + opcode); } _input1 = in1; @@ -112,8 +113,8 @@ public class DnnGPUInstruction extends GPUInstruction { public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException { super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if( !opcode.equals("channel_sums") ) { - throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode); + if( !(opcode.equals("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema") ) ) { + throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums or reshape_colmeans or update_ema, but found " + opcode); } _input1 = in1; _input2 = in2; @@ -126,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction { public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException { super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if( !opcode.equals("update_nesterov_x") ) { + if( !( opcode.equals("update_nesterov_x")) ) { throw new DMLRuntimeException("Incorrect opcode: " + opcode); } _input1 = in1; @@ -182,6 +183,22 @@ public class DnnGPUInstruction extends GPUInstruction { _intermediateMemoryBudget = intermediateMemoryBudget; } + public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, + CPOperand out, String opcode, String istr, double intermediateMemoryBudget) { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + if( !(opcode.equals("update_ema_var") || opcode.equals("batch_norm2d_bwd_dx")) ) { + throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be update_ema_var or batch_norm2d_bwd_dx, but found " + opcode); + } + _input1 = in; + _input2 = in2; + _input3 = in3; + _input4 = in4; + _input5 = in5; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + public static DnnGPUInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; @@ -297,14 +314,15 @@ public class DnnGPUInstruction extends GPUInstruction { return new DnnGPUInstruction(in1, null, out, opcode, str, stride, padding, input_shape, filter_shape, Double.parseDouble(parts[15])); } - else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) { + else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") + || opcode.equalsIgnoreCase("inv_var") ) { InstructionUtils.checkNumFields(parts, 4); CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); CPOperand out = new CPOperand(parts[3]); return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4])); } - else if (opcode.equalsIgnoreCase("channel_sums")) { + else if (opcode.equalsIgnoreCase("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema")) { InstructionUtils.checkNumFields(parts, 4); CPOperand in = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); @@ -333,7 +351,7 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out2 = new CPOperand(parts[8]); return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0); } - else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) { + else if (opcode.equalsIgnoreCase("lstm_backward")) { InstructionUtils.checkNumFields(parts, 13); CPOperand in1 = new CPOperand(parts[1]); // image CPOperand in2 = new CPOperand(parts[2]); // scale @@ -350,19 +368,6 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0); } - else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) { - InstructionUtils.checkNumFields(parts, 9); - CPOperand in1 = new CPOperand(parts[1]); // image - CPOperand in2 = new CPOperand(parts[2]); // dout - CPOperand in3 = new CPOperand(parts[3]); // scale - CPOperand in4 = new CPOperand(parts[4]); // epsilon - CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean - CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance - CPOperand out = new CPOperand(parts[7]); // dX - CPOperand out2 = new CPOperand(parts[8]); // dScale - CPOperand out3 = new CPOperand(parts[9]); // dBias - return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); - } else if (opcode.equalsIgnoreCase("batch_norm2d_test")) { InstructionUtils.checkNumFields(parts, 7); CPOperand in = new CPOperand(parts[1]); @@ -374,21 +379,25 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out = new CPOperand(parts[7]); return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0); } - else if (opcode.equalsIgnoreCase("batch_norm2d_train")) { - InstructionUtils.checkNumFields(parts, 12); - CPOperand in1 = new CPOperand(parts[1]); // image - CPOperand in2 = new CPOperand(parts[2]); // gamma - CPOperand in3 = new CPOperand(parts[3]); // beta - CPOperand in4 = new CPOperand(parts[4]); // ema_mean - CPOperand in5 = new CPOperand(parts[5]); // ema_var - CPOperand in6 = new CPOperand(parts[6]); // eps - CPOperand in7 = new CPOperand(parts[7]); // mu - CPOperand out = new CPOperand(parts[8]); // out - CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd - CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd - CPOperand out4 = new CPOperand(parts[11]); // cache_mean - CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var - return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0); + else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) { + InstructionUtils.checkNumFields(parts, 6); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand in5 = new CPOperand(parts[5]); + CPOperand out = new CPOperand(parts[6]); + return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0); + } + else if (opcode.equalsIgnoreCase("update_ema_var")) { + InstructionUtils.checkNumFields(parts, 6); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand in5 = new CPOperand(parts[5]); + CPOperand out = new CPOperand(parts[6]); + return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0); } else { throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str); @@ -396,211 +405,185 @@ public class DnnGPUInstruction extends GPUInstruction { } private void processBiasInstruction(String instOpcode, ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); - - if(instOpcode.equalsIgnoreCase("bias_add")) - LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); - else if(instOpcode.equalsIgnoreCase("bias_multiply")) - LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); - MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); - - String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue(); - double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue(); - - MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - - if(phase.equalsIgnoreCase("train")) { - double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue(); - MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, - retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output4.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output5.getName()); - } - else if(phase.equalsIgnoreCase("test")) { - LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, epsilon); - ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); - } - else { - throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase); + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("input", _input1).add("bias", _input2); + + MatrixObject input = fetcher.getInputMatrixObject("input"); + MatrixObject bias = fetcher.getInputMatrixObject("bias"); + MatrixObject out = fetcher.getOutputMatrixObject(input.getNumRows(), input.getNumColumns()); + + if(instOpcode.equalsIgnoreCase("bias_add")) + LibMatrixCUDA.biasAdd(gCtx, instName, input, bias, out); + else if(instOpcode.equalsIgnoreCase("bias_multiply")) + LibMatrixCUDA.biasMultiply(gCtx, instName, input, bias, out); } - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } - private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); - MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); - - double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue(); - double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue(); - - MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - - LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, - retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output4.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output5.getName()); + private void processInverseVarianceInstruction(String instOpcode, ExecutionContext ec) { + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("X", _input1).addScalar("eps", _input2); + + int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")); + int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X")); + + // invVar(X, C, eps, size); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("invVar", + ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols), + fetcher.getInputPointer("X"), fetcher.getOutputPointer(rows, cols), + fetcher.getDouble("eps"), rows*cols); + } } private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); - MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); - double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue(); - - MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, epsilon); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("image", _input1).add("scale", _input2).add("bias", _input3) + .add("runningMean", _input4).add("runningVar", _input5).addScalar("epsilon", _input6); + + double epsilon = fetcher.getDouble("epsilon"); + if(epsilon < JCudnn.CUDNN_BN_MIN_EPSILON) { + throw new DMLRuntimeException("The epsilon (" + epsilon + ") cannot be less than CUDNN_BN_MIN_EPSILON=(" + JCudnn.CUDNN_BN_MIN_EPSILON + ")"); + } + + MatrixObject image = fetcher.getInputMatrixObject("image"); + LibMatrixCuDNN.batchNormalizationForwardInference(gCtx, instName, + image, fetcher.getInputMatrixObject("scale"), fetcher.getInputMatrixObject("bias"), + fetcher.getInputMatrixObject("runningMean"), fetcher.getInputMatrixObject("runningVar"), + fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon); + } } - public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName()); - double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue(); - MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName()); - MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName()); - - MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns()); - MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns()); - - LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image, - dout, scale, dX, dScale, dBias, - epsilon, resultSaveMean, resultSaveInvVariance); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixInputForGPUInstruction(_input6.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); + private void processBatchNorm2dBackwardDxInstruction(ExecutionContext ec) throws DMLRuntimeException { + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("X", _input1).add("dout", _input2).add("gamma", _input3) + .add("resultSaveMean", _input4).add("resultSaveInvVariance", _input5); + + // #define CUDNN_BN_MIN_EPSILON 1e-5 // Minimum epsilon allowed to be used in the Batch Normalization formula + double epsilon = 1e-4; + MatrixObject image = fetcher.getInputMatrixObject("X"); + LibMatrixCuDNN.batchNormalizationBackwardDX(gCtx, instName, image, + fetcher.getInputMatrixObject("dout"), fetcher.getInputMatrixObject("gamma"), + fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon, fetcher.getInputMatrixObject("resultSaveMean"), + fetcher.getInputMatrixObject("resultSaveInvVariance")); + } } - + + // (X > 0) * dout public void processReLUBackwardInstruction(ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); - - LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out); - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("X", _input1).add("dout", _input2); + MatrixObject X = fetcher.getInputMatrixObject("X"); + LibMatrixCUDA.reluBackward(gCtx, instName, X, + fetcher.getInputMatrixObject("dout"), fetcher.getOutputMatrixObject(X.getNumRows(), X.getNumColumns())); + } } private void processChannelSumsInstruction(ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue(); - int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue(); - if(C*HW != input.getNumColumns()) { - throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns()); + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3); + int C = fetcher.getInteger("C"); + int HW = fetcher.getInteger("HW"); + fetcher.validateDimensions("X", -1, C*HW); + LibMatrixCUDA.channelSums(gCtx, instName, + fetcher.getInputMatrixObject("X"), + fetcher.getOutputMatrixObject(C, 1), C, HW); + } + } + + private void processEMAInstruction(ExecutionContext ec) { + // "ema_mean", "mean", "mu" + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("ema_mean", _input1).add("mean", _input2).addScalar("mu", _input3); + double mu = fetcher.getDouble("mu"); + + int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("ema_mean")); + int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("ema_mean")); + + fetcher.validateDimensions("mean", rows, cols); + + // aXplusbY(X, Y, C, a, b, size); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbY", + ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols), + fetcher.getInputPointer("ema_mean"), fetcher.getInputPointer("mean"), + fetcher.getOutputPointer(rows, cols), + mu, (1-mu), rows*cols); + } + } + + private void processReshapeColMeansInstruction(ExecutionContext ec) { + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3); + int C = fetcher.getInteger("C"); + int HW = fetcher.getInteger("HW"); + fetcher.validateDimensions("X", -1, C*HW); + int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")); + int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X")); + // output = matrix(colMeans(X), rows=C, cols=Hin*Win) + LibMatrixCUDA.colMeans(gCtx, instName, + fetcher.getInputPointer("X"), + fetcher.getOutputPointer(C, HW), rows, cols); } - MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1); - - LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } + private void processUpdateEMAVarInstruction(ExecutionContext ec) { + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + // "subgrp_means", "X", "C", "HW", "varConst1" + fetcher.add("subgrp_means", _input1).add("X", _input2).addScalar("C", _input3) + .addScalar("HW", _input4).addScalar("varConst1", _input5); + + // subgrp_vars = matrix(colVars(X) * varConst1, rows=C, cols=Hin*Win) + // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) + // ---> + // subgrp_vars = matrix(colVars(X), rows=C, cols=HW) + // var = rowMeans(subgrp_vars)*varConst1 + rowVars(subgrp_means)*((HW-1)/HW) + int C = fetcher.getInteger("C"); + int HW = fetcher.getInteger("HW"); + double varConst1 = fetcher.getDouble("varConst1"); + fetcher.validateDimensions("subgrp_means", C, HW); + fetcher.validateDimensions("X", -1, C*HW); + + Pointer subgrp_vars = gCtx.allocate(instName, C*HW*LibMatrixCUDA.sizeOfDataType); + // subgrp_vars <- colVars(X) + LibMatrixCUDA.colVars(gCtx, instName, fetcher.getInputPointer("X"), subgrp_vars, + LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")), C*HW); + + // tmp1 <- rowMeans(subgrp_vars) + Pointer tmp1 = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.rowMeans(gCtx, instName, subgrp_vars, tmp1, C, HW); + gCtx.cudaFreeHelper(instName, subgrp_vars, gCtx.EAGER_CUDA_FREE); + + // out <- rowVars(subgrp_means) + Pointer out = fetcher.getOutputPointer(C, 1); + LibMatrixCUDA.rowVars(gCtx, instName, fetcher.getInputPointer("subgrp_means"), out, C, HW); + + // var = tmp1*varConst1 + out*((HW-1)/HW) + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbC", + ExecutionConfig.getConfigForSimpleVectorOperations(C), + tmp1, out, + varConst1, (((double)HW-1)/HW), C); + gCtx.cudaFreeHelper(instName, tmp1, gCtx.EAGER_CUDA_FREE); + } + } + + + private void processNesterovUpdateInstruction(ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName()); - double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue(); - int rows = LibMatrixCUDA.toInt(input.getNumRows()); - int cols = LibMatrixCUDA.toInt(input.getNumColumns()); - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols); - - GPUContext gCtx = ec.getGPUContext(0); - String instName = getExtendedOpcode(); - LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", - ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)), - LibMatrixCUDA.getDensePointer(gCtx, input, instName), - LibMatrixCUDA.getDensePointer(gCtx, v, instName), - LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName), - mu, - LibMatrixCUDA.getDensePointer(gCtx, out, instName), - rows*cols); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("input", _input1).add("v", _input2).add("v_prev", _input3) + .addScalar("mu", _input4); + MatrixObject input = fetcher.getInputMatrixObject("input"); + int rows = LibMatrixCUDA.toInt(input.getNumRows()); + int cols = LibMatrixCUDA.toInt(input.getNumColumns()); + + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", + ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)), + fetcher.getInputPointer("input"), + fetcher.getInputPointer("v"), + fetcher.getInputPointer("v_prev"), + fetcher.getDouble("mu"), + fetcher.getOutputPointer(rows, cols), + rows*cols); + } } private static int toInt(long num) throws DMLRuntimeException { @@ -610,32 +593,18 @@ public class DnnGPUInstruction extends GPUInstruction { return (int)num; } -// private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException { -// GPUContext gCtx = ec.getGPUContext(0); -// String instructionName = getExtendedOpcode(); -// long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns(); -// Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType); -// jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice); -// // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX); -// return tX; -// } - private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - GPUContext gCtx = ec.getGPUContext(0); - String instructionName = getExtendedOpcode(); - MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) - Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); + Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instName); MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); long numRowsW = W.getNumRows(); int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures - Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); - Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); - Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M); + Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M); + Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); @@ -644,20 +613,20 @@ public class DnnGPUInstruction extends GPUInstruction { MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); - Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); + Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) long numColsX = X.getNumColumns(); int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength - Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); + Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), xPointer, cudnnInput, N, D, T*D, N*T*D); ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); + Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); - // LibMatrixCuDNN.lstm(ec, gCtx, instructionName, + // LibMatrixCuDNN.lstm(ec, gCtx, instName, // cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); // String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input // String dxName, String dwName, String dbName, String dhxName, String dcxName, // output @@ -668,12 +637,12 @@ public class DnnGPUInstruction extends GPUInstruction { String dcxName = _output5.getName(); String doutName = _input7.getName(); String dcyName = _input8.getName(); - LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, + LibMatrixCuDNN.lstmBackward(ec, gCtx, instName, cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input dxName, dwName, dbName, dhxName, dcxName, // output return_sequences, N, M, D, T); - gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE); - gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE); // release inputs/outputs ec.releaseMatrixInputForGPUInstruction(_input4.getName()); @@ -686,21 +655,17 @@ public class DnnGPUInstruction extends GPUInstruction { // weight W:(D+M+2, 4M) // previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N) // out: (N, T*M) or (N, M) ==> (T, M, N) - GPUStatistics.incrementNoOfExecutedGPUInst(); - GPUContext gCtx = ec.getGPUContext(0); - String instructionName = getExtendedOpcode(); - MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) - Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); + Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instName); MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); long numRowsW = W.getNumRows(); int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures - Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); - Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); - Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M); + Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M); + Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); @@ -711,21 +676,21 @@ public class DnnGPUInstruction extends GPUInstruction { // Beause the matrices are released immediately, the output for transpose need not be taken into account MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); - Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); + Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) long numColsX = X.getNumColumns(); int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength - Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); + Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), xPointer, cudnnInput, N, D, T*D, N*T*D); ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); + Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); - LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); - gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE); - gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE); + LibMatrixCuDNN.lstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); + gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE); // release inputs/outputs ec.releaseMatrixInputForGPUInstruction(_input4.getName()); @@ -736,10 +701,17 @@ public class DnnGPUInstruction extends GPUInstruction { @Override public void processInstruction(ExecutionContext ec) { + GPUStatistics.incrementNoOfExecutedGPUInst(); + gCtx = ec.getGPUContext(0); + instName = getExtendedOpcode(); if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) { processBiasInstruction(instOpcode, ec); return; } + else if (instOpcode.equalsIgnoreCase("inv_var")) { + processInverseVarianceInstruction(instOpcode, ec); + return; + } else if (instOpcode.equalsIgnoreCase("relu_backward")) { processReLUBackwardInstruction(ec); return; @@ -748,10 +720,22 @@ public class DnnGPUInstruction extends GPUInstruction { processChannelSumsInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("update_ema")) { + processEMAInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("reshape_colmeans")) { + processReshapeColMeansInstruction(ec); + return; + } else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) { processNesterovUpdateInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("update_ema_var")) { + processUpdateEMAVarInstruction(ec); + return; + } else if (instOpcode.equalsIgnoreCase("lstm")) { processLstmInstruction(ec); return; @@ -760,24 +744,14 @@ public class DnnGPUInstruction extends GPUInstruction { processLstmBackwardInstruction(ec); return; } - else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { - processBatchNorm2dInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) { - processBatchNorm2dBackwardInstruction(ec); - return; - } else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) { processBatchNorm2dTestInstruction(ec); return; } - else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) { - processBatchNorm2dTrainInstruction(ec); + else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) { + processBatchNorm2dBackwardDxInstruction(ec); return; } - - GPUStatistics.incrementNoOfExecutedGPUInst(); int pad_h = getScalarInput(ec, _padding, 0); int pad_w = getScalarInput(ec, _padding, 1); http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java new file mode 100644 index 0000000..8fcaec3 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysml.runtime.instructions.gpu; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; +import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.utils.GPUStatistics; + +import jcuda.Pointer; + +public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable { + ExecutionContext _ec; GPUContext _gCtx; String _instName; + HashMap<String, CPOperand> _inputMatrices = new HashMap<>(); + HashMap<String, MatrixObject> _inputMatrixObjects = new HashMap<>(); + HashMap<String, CPOperand> _inputScalars = new HashMap<>(); + CPOperand _output; + public GPUDenseInputPointerFetcher(ExecutionContext ec, GPUContext gCtx, String instName, CPOperand output) { + _ec = ec; + _gCtx = gCtx; + _instName = instName; + _output = output; + } + public GPUDenseInputPointerFetcher add(String var, CPOperand in) { + _inputMatrices.put(var, in); + return this; + } + public GPUDenseInputPointerFetcher addScalar(String var, CPOperand in) { + _inputScalars.put(var, in); + return this; + } + public double getDouble(String var) { + CPOperand in = _inputScalars.get(var); + return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getDoubleValue(); + } + public long getLong(String var) { + CPOperand in = _inputScalars.get(var); + return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue(); + } + public int getInteger(String var) { + CPOperand in = _inputScalars.get(var); + return LibMatrixCUDA.toInt(_ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue()); + } + public Pointer getInputPointer(String var) { + return LibMatrixCUDA.getDensePointer(_gCtx, getInputMatrixObject(var), _instName); + } + public long getInputNumRows(String var) { + return getInputMatrixObject(var).getNumRows(); + } + public long getInputNumColumns(String var) { + return getInputMatrixObject(var).getNumColumns(); + } + public MatrixObject getOutputMatrixObject(long numRows, long numCols) { + boolean isFinegrainedStats = ConfigurationManager.isFinegrainedStatistics(); + long t0 = isFinegrainedStats ? System.nanoTime() : 0; + Pair<MatrixObject, Boolean> mb = _ec.getDenseMatrixOutputForGPUInstruction(_output.getName(), numRows, numCols); + if (isFinegrainedStats && mb.getValue()) GPUStatistics.maintainCPMiscTimes(_instName, + GPUInstruction.MISC_TIMER_ALLOCATE_DENSE_OUTPUT, System.nanoTime() - t0); + return mb.getKey(); + } + public Pointer getOutputPointer(long numRows, long numCols) { + return LibMatrixCUDA.getDensePointer(_gCtx, getOutputMatrixObject(numRows, numCols), _instName); + } + public MatrixObject getInputMatrixObject(String var) { + CPOperand in = _inputMatrices.get(var); + if(!_inputMatrixObjects.containsKey(var)) { + _inputMatrixObjects.put(var, _ec.getMatrixInputForGPUInstruction(in.getName(), _instName)); + } + return _inputMatrixObjects.get(var); + } + public void validateDimensions(String var, long numRows, long numCols) { + MatrixObject mo = getInputMatrixObject(var); + if(numRows > 0 && mo.getNumRows() != numRows) { + throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows()); + } + else if(numCols > 0 && mo.getNumColumns() != numCols) { + throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns()); + } + } + @Override + public void close() { + for(CPOperand in : _inputMatrices.values()) { + _ec.releaseMatrixInputForGPUInstruction(in.getName()); + } + _ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java index 2e43b99..e01c71a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java @@ -242,7 +242,7 @@ public class GPUMemoryManager { * @return allocated pointer */ public Pointer malloc(String opcode, long size) { - if(size < 0) { + if(size <= 0) { throw new DMLRuntimeException("Cannot allocate memory of size " + byteCountToDisplaySize(size)); } if(DEBUG_MEMORY_LEAK) { http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index d02a875..f3f8434 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -208,7 +208,7 @@ public class LibMatrixCUDA { return gCtx.getCusparseHandle(); } - protected static cublasHandle getCublasHandle(GPUContext gCtx) { + public static cublasHandle getCublasHandle(GPUContext gCtx) { return gCtx.getCublasHandle(); } @@ -302,7 +302,7 @@ public class LibMatrixCUDA { } } - protected static Pointer dataTypePointerTo(double value) { + public static Pointer dataTypePointerTo(double value) { if(value == 1) { return one(); } @@ -313,7 +313,6 @@ public class LibMatrixCUDA { return _dataTypePointerTo(value); } } - /** * This method computes the backpropagation errors for previous layer of relu operation @@ -753,11 +752,11 @@ public class LibMatrixCUDA { break; } case REDUCTION_COL: { - reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen); + rowMeans(gCtx, instName, in, out, rlen, clen); break; } case REDUCTION_ROW: { - reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen); + colMeans(gCtx, instName, in, out, rlen, clen); break; } default: @@ -818,13 +817,14 @@ public class LibMatrixCUDA { break; } case OP_VARIANCE : { - // Temporary GPU array for - Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType); - Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType); + switch(reductionDirection) { case REDUCTION_ALL: { + // Temporary GPU array for + Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType); + Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType); double result = reduceAll(gCtx, instName, "reduce_sum", in, size); double mean = result / size; @@ -837,50 +837,21 @@ public class LibMatrixCUDA { double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size); double variance = result2 / (size - 1); ec.setScalarOutput(output, new DoubleObject(variance)); - + gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE); break; } case REDUCTION_COL: { - reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen); - // Subtract the row-wise mean from every element in the matrix - BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject()); - matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp); - - squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen); - - Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType); - reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen); - - ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1); - matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp); - - gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE); - + rowVars(gCtx, instName, in, out, rlen, clen); break; } case REDUCTION_ROW: { - reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen); - // Subtract the columns-wise mean from every element in the matrix - BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject()); - matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp); - - squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen); - - Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType); - reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen); - - ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1); - matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp); - - gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE); - + colVars(gCtx, instName, in, out, rlen, clen); break; } default: throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance"); } - gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); - gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE); break; } case OP_MAXINDEX : { @@ -904,6 +875,59 @@ public class LibMatrixCUDA { default : throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!"); } } + + public static void rowMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) { + LibMatrixCUDA.reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen); + } + + public static void colMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) { + reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen); + } + + public static void colVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) { + int size = rlen * clen; + Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType); + Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType); + reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen); + // Subtract the columns-wise mean from every element in the matrix + BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject()); + matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp); + + squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen); + + Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType); + reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen); + + ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1); + matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp); + + gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE); + } + + public static void rowVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) { + int size = rlen * clen; + Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType); + Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType); + + reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen); + // Subtract the row-wise mean from every element in the matrix + BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject()); + matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp); + + squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen); + + Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType); + reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen); + + ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1); + matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp); + + gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE); + } /** * Helper method to square a matrix in GPU memory @@ -970,7 +994,7 @@ public class LibMatrixCUDA { * @param rows number of rows in input matrix * @param cols number of columns in input matrix */ - private static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) { + public static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) { if(LOG.isTraceEnabled()) { LOG.trace("GPU : reduceRow for " + kernelFunction + ", GPUContext=" + gCtx); } @@ -997,7 +1021,7 @@ public class LibMatrixCUDA { * @param rows number of rows in input matrix * @param cols number of columns in input matrix */ - private static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) { + public static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) { if(LOG.isTraceEnabled()) { LOG.trace("GPU : reduceCol for " + kernelFunction + ", GPUContext=" + gCtx); } http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java index d3b5984..e7955e1 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java @@ -1108,23 +1108,29 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { runningMeanPtr, runningVarPtr, epsilon)); } + private static void validateDimensions(MatrixObject mo, long expectedRows, long expectedCols) { + if(mo.getNumRows() != expectedRows || mo.getNumColumns() != expectedCols) { + throw new DMLRuntimeException("Incorrect dimensions for the input matrix object. Expected [" + expectedRows + ", " + expectedCols+ "], but found " + + "[" + mo.getNumRows() + ", " + mo.getNumColumns() + "]."); + } + } + /** - * This method computes the backpropagation errors for image, scale and bias of batch normalization layer + * This method computes the backpropagation errors for image of batch normalization layer * @param gCtx a valid {@link GPUContext} * @param instName name of the instruction * @param image input image * @param dout input errors of shape C, H, W * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1] * @param dX (output) backpropagation errors for previous layer - * @param dScale backpropagation error for scale - * @param dBias backpropagation error for bias * @param epsilon epsilon value used in the batch normalization formula * @param resultSaveMean (input) running mean accumulated during training phase: shape [1, C, 1, 1] * @param resultSaveInvVariance (input) running variance accumulated during training phase: shape [1, C, 1, 1] * @throws DMLRuntimeException if error occurs */ - public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, - MatrixObject scale, MatrixObject dX, MatrixObject dScale, MatrixObject dBias, + public static void batchNormalizationBackwardDX(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, + MatrixObject scale, MatrixObject dX, + // MatrixObject dScale, MatrixObject dBias, double epsilon, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException { if(LOG.isTraceEnabled()) { LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx); @@ -1133,7 +1139,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { int N = toInt(image.getNumRows()); int C = toInt(scale.getNumRows()); long CHW = image.getNumColumns(); - + + validateDimensions(scale, C, 1); + validateDimensions(dX, N, CHW); + validateDimensions(dout, N, CHW); + validateDimensions(resultSaveMean, C, 1); + validateDimensions(resultSaveInvVariance, C, 1); + // Allocate descriptors cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW, new MatrixObject[] {image, dout}, new MatrixObject[] {dX}); @@ -1144,18 +1156,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName); Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName); Pointer dXPtr = getDensePointerForCuDNN(gCtx, dX, instName); - Pointer dScalePtr = getDensePointerForCuDNN(gCtx, dScale, instName); - Pointer dBiasPtr = getDensePointerForCuDNN(gCtx, dBias, instName); - + Pointer dScalePtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); // getDensePointerForCuDNN(gCtx, dScale, instName); + Pointer dBiasPtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); //getDensePointerForCuDNN(gCtx, dBias, instName); Pointer resultSaveMeanPtr = getDensePointerForCuDNN(gCtx, resultSaveMean, instName); Pointer resultSaveInvVariancePtr = getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName); - - // ignoring resultSaveMean and resultSaveVariance as it requires state management - checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), + cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), jcuda.jcudnn.cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL, one(), zero(), one(), zero(), nCHWDescriptor, imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, dXPtr, - scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr)); + scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr); + gCtx.cudaFreeHelper(instName, dScalePtr, gCtx.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instName, dBiasPtr, gCtx.EAGER_CUDA_FREE); } private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException { http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java index d96feac..d7e7b24 100644 --- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java +++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java @@ -55,8 +55,8 @@ public class BatchNormTest extends GPUTests { int imgSize = 32; int numChannels = 3; double sparsity = 0.9; - String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n " - + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)"; + String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n " + + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)"; HashMap<String, Object> inputs = new HashMap<>(); inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed)); inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); @@ -68,19 +68,40 @@ public class BatchNormTest extends GPUTests { List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); if(mode.equals("test")) { assertHeavyHitterPresent("gpu_batch_norm2d_test"); - for(int i = 0; i < outputs.size(); i++) { - assertEqualObjects(outCPU.get(i), outGPU.get(i)); - } } else { - //assertHeavyHitterPresent("gpu_batch_norm2d_train"); - double [] threshold = new double[outputs.size()]; - Arrays.fill(threshold, getTHRESHOLD()); - // Handle loss of precision in CuDNN kernel - threshold[2] = 1e-3; - for(int i = 0; i < outputs.size()-1; i++) { - assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]); - } + assertHeavyHitterPresent("gpu_batch_norm2d_test"); + assertHeavyHitterPresent("gpu_reshape_colmeans"); + assertHeavyHitterPresent("gpu_update_ema_var"); } + assertEqualObjects(outCPU.get(0), outGPU.get(0)); + assertEqualObjects(outCPU.get(1), outGPU.get(1)); + assertEqualObjects(outCPU.get(2), outGPU.get(2)); + assertEqualObjects(outCPU.get(3), outGPU.get(3)); + assertEqualObjects(outCPU.get(4), outGPU.get(4)); + } + + @Test + public void testBatchNormBackward() { + int imgSize = 32; + int numChannels = 3; + double sparsity = 0.9; + String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n " + + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"train\", ema_mean, ema_var, 0.9, 1e-3);\n" + + "[dX, dgamma, dbeta] = batch_norm2d::backward(dout, cache_mean, cache_var, x, gamma, " + numChannels + ", " + imgSize + ", " + imgSize + ", 1e-3);"; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed)); + inputs.put("dout", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 1, 5, sparsity, seed)); + inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); + inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); + inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed)); + inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed)); + List<String> outputs = Arrays.asList("dX", "dgamma", "dbeta"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + assertHeavyHitterPresent("gpu_batch_norm2d_bwd_dx"); + assertEqualObjects(outCPU.get(0), outGPU.get(0), 1e-6); + assertEqualObjects(outCPU.get(1), outGPU.get(1)); + assertEqualObjects(outCPU.get(2), outGPU.get(2)); } }