Repository: incubator-systemml

Updated Branches:
  refs/heads/master c1c7d3341 -> c335cd403
[SYSTEMML-1340][SYSTEMML-1341] Implemented conv2d_bias_add and relu_maxpooling instructions for GPU

Closes #425.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c335cd40
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c335cd40
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c335cd40

Branch: refs/heads/master
Commit: c335cd403e6961460ba21234d4dbcb79ae22925f
Parents: c1c7d33
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Fri Mar 10 21:16:26 2017 -0800
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Fri Mar 10 21:16:26 2017 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java   |   6 +-
 .../instructions/GPUInstructionParser.java     |   2 +
 .../gpu/ConvolutionGPUInstruction.java         |  64 ++++++-
 .../runtime/matrix/data/LibMatrixCUDA.java     | 167 +++++++++++++------
 4 files changed, 178 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c335cd40/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index c32f227..943ff96 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -179,11 +179,11 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
    ArrayList<Hop> inputs1 = inputs;
    int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
    OperationTypes lopOp = HopsConv2Lops.get(op);
-   if(op == ConvOp.MAX_POOLING && (et == ExecType.CP || et == ExecType.SPARK) && isInputReLU(inputs.get(0))) {
+   if(op == ConvOp.MAX_POOLING && isInputReLU(inputs.get(0))) {
       in = inputs.get(0).getInput().get(0).constructLops();
       lopOp = OperationTypes.RELU_MAX_POOLING;
    }
-   else if(op == ConvOp.BIAS_ADD && (et == ExecType.CP || et == ExecType.SPARK) && isInputConv2d(inputs.get(0))) {
+   else if(op == ConvOp.BIAS_ADD && isInputConv2d(inputs.get(0))) {
      lopOp = OperationTypes.DIRECT_CONV2D_BIAS_ADD;
      // the first lop is image
@@ -320,7 +320,7 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
    if( _etypeForced != null ) {
-      _etype = _etypeForced;
+      _etype = findGPUExecTypeByMemEstimate(_etypeForced);
    }
    else
    {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c335cd40/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 23b5328..366015f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -41,6 +41,8 @@ public class GPUInstructionParser extends InstructionParser
    // Neural Network Operators
    String2GPUInstructionType.put( "relu_backward",          GPUINSTRUCTION_TYPE.Convolution);
    String2GPUInstructionType.put( "conv2d",                 GPUINSTRUCTION_TYPE.Convolution);
+   String2GPUInstructionType.put( "relu_maxpooling",        GPUINSTRUCTION_TYPE.Convolution);
+   String2GPUInstructionType.put( "conv2d_bias_add",        GPUINSTRUCTION_TYPE.Convolution);
    String2GPUInstructionType.put( "conv2d_backward_filter", GPUINSTRUCTION_TYPE.Convolution);
    String2GPUInstructionType.put( "conv2d_backward_data",   GPUINSTRUCTION_TYPE.Convolution);
    String2GPUInstructionType.put( "maxpooling",             GPUINSTRUCTION_TYPE.Convolution);
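The hop-level change above routes max_pool over a relu input (and bias_add over a conv2d input) to the fused RELU_MAX_POOLING / DIRECT_CONV2D_BIAS_ADD lops on every backend, and the parser change registers the two new GPU opcodes. A minimal sketch of this kind of pattern-based fusion, with hypothetical types rather than SystemML's actual Hop/Lop classes:

    // Illustrative only: fuse max_pool(relu(X)) into one relu_maxpooling operator,
    // mirroring the constructLops() rewrite above (hypothetical types, not SystemML API).
    enum OpType { RELU, MAX_POOLING, RELU_MAX_POOLING }

    final class Node {
        final OpType type;
        final Node input;   // single-input chain, for brevity
        Node(OpType type, Node input) { this.type = type; this.input = input; }
    }

    final class Fuser {
        // If the pooling input is a ReLU, drop the ReLU node and feed the fused
        // operator directly from the ReLU's own input.
        static Node lower(Node n) {
            if( n.type == OpType.MAX_POOLING && n.input != null && n.input.type == OpType.RELU )
                return new Node(OpType.RELU_MAX_POOLING, n.input.input);
            return n;
        }
    }

The point of the rewrite is that the runtime can then emit a single fused instruction (relu_maxpooling or conv2d_bias_add) instead of two separate ones.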

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c335cd40/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index cb8c729..daf3c58 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -35,6 +35,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
    private CPOperand _input1;
    private CPOperand _input2;
+   private CPOperand _input3;
    private CPOperand _output;
    private ArrayList<CPOperand> _input_shape;
    private ArrayList<CPOperand> _filter_shape;
@@ -52,6 +53,15 @@ public class ConvolutionGPUInstruction extends GPUInstruction
       _output = out;
    }
+   public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
+         String istr, ArrayList<CPOperand> stride,
+         ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
+         ArrayList<CPOperand> filter_shape)
+   {
+      this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape);
+      _input3 = in3;
+   }
+
    public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode,
         String istr, ArrayList<CPOperand> stride,
         ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
@@ -104,7 +114,34 @@ public class ConvolutionGPUInstruction extends GPUInstruction
       return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride,
            padding, input_shape, filter_shape);
      }
-     else if (opcode.equalsIgnoreCase("maxpooling")) {
+     else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
+        InstructionUtils.checkNumFields(parts, 16);
+        CPOperand in1 = new CPOperand(parts[1]);
+        CPOperand in2 = new CPOperand(parts[2]);
+        CPOperand in3 = new CPOperand(parts[3]);
+        CPOperand out = new CPOperand(parts[16]);
+
+        ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+        ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+        ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+        ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+        stride.add(new CPOperand(parts[4]));
+        stride.add(new CPOperand(parts[5]));
+        padding.add(new CPOperand(parts[6]));
+        padding.add(new CPOperand(parts[7]));
+        input_shape.add(new CPOperand(parts[8]));
+        input_shape.add(new CPOperand(parts[9]));
+        input_shape.add(new CPOperand(parts[10]));
+        input_shape.add(new CPOperand(parts[11]));
+        filter_shape.add(new CPOperand(parts[12]));
+        filter_shape.add(new CPOperand(parts[13]));
+        filter_shape.add(new CPOperand(parts[14]));
+        filter_shape.add(new CPOperand(parts[15]));
+
+        return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str, stride,
+             padding, input_shape, filter_shape);
+     }
+     else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
        InstructionUtils.checkNumFields(parts, 14);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[14]);
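The new conv2d_bias_add branch above expects 16 operand fields after the opcode: image, bias and filter first, the output last, with stride, padding, input shape (N, C, H, W) and filter shape (K, C, R, S) in between. A small sketch of that positional layout on a simplified, hypothetical instruction string (the real string is produced by the SystemML compiler and also carries per-operand value-type annotations):

    // Positions follow parseInstruction() above:
    // [1]=image, [2]=bias, [3]=filter, [4..5]=stride, [6..7]=padding,
    // [8..11]=input shape N,C,H,W, [12..15]=filter shape K,C,R,S, [16]=output.
    public class Conv2dBiasAddLayout {
        public static void main(String[] args) {
            String inst = "conv2d_bias_add X B F 1 1 0 0 64 3 32 32 16 3 5 5 OUT"; // hypothetical
            String[] parts = inst.split(" ");
            System.out.println("inputs  = " + parts[1] + ", " + parts[2] + ", " + parts[3]);
            System.out.println("stride  = (" + parts[4] + ", " + parts[5] + ")");
            System.out.println("padding = (" + parts[6] + ", " + parts[7] + ")");
            System.out.println("input shape  N,C,H,W = " + String.join(",", parts[8], parts[9], parts[10], parts[11]));
            System.out.println("filter shape K,C,R,S = " + String.join(",", parts[12], parts[13], parts[14], parts[15]));
            System.out.println("output  = " + parts[16]);
        }
    }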
@@ -216,6 +253,21 @@ public class ConvolutionGPUInstruction extends GPUInstruction
        LibMatrixCUDA.conv2d(getExtendedOpcode(), image, filter, out, N, C, H, W,
             K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
     }
+    else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
+       MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
+       MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
+       MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input3.getName());
+
+       if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
+          throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
+       if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
+          throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
+
+       ec.setMetaData(_output.getName(), N, K * P * Q);
+       MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName());
+       LibMatrixCUDA.conv2dBiasAdd(getExtendedOpcode(), image, bias, filter, out, N, C, H, W,
+            K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+    }
     else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
        MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -248,7 +300,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction
        LibMatrixCUDA.conv2dBackwardData(getExtendedOpcode(), filter, dout, out, N, C, H, W,
             K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
     }
-    else if (instOpcode.equalsIgnoreCase("maxpooling")) {
+    else if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
        MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
        if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
@@ -257,8 +309,12 @@ public class ConvolutionGPUInstruction extends GPUInstruction
        ec.setMetaData(_output.getName(), N, C * P * Q);
        MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName());
-       LibMatrixCUDA.maxpooling(getExtendedOpcode(), image, out, N, C, H, W,
+       if(instOpcode.equalsIgnoreCase("maxpooling"))
+          LibMatrixCUDA.maxpooling(getExtendedOpcode(), image, out, N, C, H, W,
             K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+       else
+          LibMatrixCUDA.reluMaxpooling(getExtendedOpcode(), image, out, N, C, H, W,
+            K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
     }
     else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
        MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
@@ -281,7 +337,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction
     // release inputs/outputs
     ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-    if (!instOpcode.equalsIgnoreCase("maxpooling"))
+    if (!( instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) )
       ec.releaseMatrixInputForGPUInstruction(_input2.getName());
     ec.releaseMatrixOutputForGPUInstruction(_output.getName());
  }
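For the fused instruction the output is allocated as N x (K*P*Q), with P and Q the usual output height and width of a strided, zero-padded convolution. A quick sanity check of those sizes using the standard formula (the concrete numbers below are made up for illustration):

    public class ConvOutputShape {
        // standard output-size formula for convolution/pooling windows
        static int outDim(int in, int filter, int pad, int stride) {
            return (in + 2 * pad - filter) / stride + 1;
        }
        public static void main(String[] args) {
            int N = 64, C = 3, H = 32, W = 32;       // image:  N x (C*H*W)
            int K = 16, R = 5, S = 5;                // filter: K x (C*R*S)
            int pad_h = 2, pad_w = 2, stride_h = 1, stride_w = 1;
            int P = outDim(H, R, pad_h, stride_h);   // 32
            int Q = outDim(W, S, pad_w, stride_w);   // 32
            System.out.println("conv2d_bias_add output: " + N + " x " + (K * P * Q));
        }
    }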

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c335cd40/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index bcbf3f3..4ff9849 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -206,6 +206,13 @@ public class LibMatrixCUDA {
   private static int CONVOLUTION_PREFERENCE = cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+  public static void conv2dBiasAdd(String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W,
+        int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
+           throws DMLRuntimeException {
+     conv2d(instName, image, filter, outputBlock, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+     biasAdd(instName, outputBlock, bias, outputBlock);
+  }
+
   public static void conv2d(String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W,
        int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
        throws DMLRuntimeException {
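conv2dBiasAdd composes the two existing GPU kernels: it runs conv2d into the output block and then adds the bias in place via biasAdd. In terms of data layout, a minimal plain-Java sketch of the bias step, assuming (as in the SystemML nn layers) one bias value per output channel and each output row laid out as K x P x Q:

    public class BiasAddSketch {
        // out: N x (K*P*Q), row-major; bias: length K (one value per output channel)
        static void addBiasInPlace(double[][] out, double[] bias, int K, int P, int Q) {
            for( double[] row : out )
                for( int k = 0; k < K; k++ )
                    for( int i = k * P * Q; i < (k + 1) * P * Q; i++ )
                        row[i] += bias[k];   // channel k occupies the slice [k*P*Q, (k+1)*P*Q)
        }
    }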
@@ -623,7 +630,7 @@ public class LibMatrixCUDA {
       if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
     }
   }
-
+
   /**
    * performs maxpooling on GPU by exploiting cudnnPoolingForward(...)
    * @param instName the invoking instruction's name for record {@link Statistics}.
    * @param image image as matrix object
    * @param outputBlock output matrix
    * @param N batch size
    * @param C number of channels
    * @param H height of image
    * @param W width of image
    * @param K number of filters
    * @param R height of filter
    * @param S width of filter
    * @param pad_h vertical padding
    * @param pad_w horizontal padding
    * @param stride_h horizontal stride
    * @param stride_w vertical stride
    * @param P (H - R + 1 + 2*pad_h)/stride_h
    * @param Q (W - S + 1 + 2*pad_w)/stride_w
    * @throws DMLRuntimeException if DMLRuntimeException occurs
    */
   public static void maxpooling(String instName, MatrixObject image,
-        MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
-        int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-        int Q) throws DMLRuntimeException {
+        MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+        int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+        int Q) throws DMLRuntimeException {
     if(isInSparseFormat(image)) {
       ((JCudaObject)image.getGPUObject()).sparseToDense(instName);
     }
+    Pointer x = ((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr;
+    performMaxpooling(instName, x, outputBlock, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+  }
+
+  /**
+   * performs relu followed by maxpooling on GPU by exploiting cudnnPoolingForward(...)
+   * @param instName the invoking instruction's name for record {@link Statistics}.
+   * @param image image as matrix object
+   * @param outputBlock output matrix
+   * @param N batch size
+   * @param C number of channels
+   * @param H height of image
+   * @param W width of image
+   * @param K number of filters
+   * @param R height of filter
+   * @param S width of filter
+   * @param pad_h vertical padding
+   * @param pad_w horizontal padding
+   * @param stride_h horizontal stride
+   * @param stride_w vertical stride
+   * @param P (H - R + 1 + 2*pad_h)/stride_h
+   * @param Q (W - S + 1 + 2*pad_w)/stride_w
+   * @throws DMLRuntimeException if DMLRuntimeException occurs
+   */
+  public static void reluMaxpooling(String instName, MatrixObject image,
+        MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+        int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+        int Q) throws DMLRuntimeException {
+    if(isInSparseFormat(image)) {
+      ((JCudaObject)image.getGPUObject()).sparseToDense(instName);
+    }
+    Pointer x = ((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr;
+    MatrixObject temp = new MatrixObject(image);
+    temp.getGPUObject().acquireDeviceModifyDense();
+    Pointer y = ((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr;
+    performReLU(instName, x, y, N, C, H, W);
+    performMaxpooling(instName, y, outputBlock, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+    ((JCudaObject)temp.getGPUObject()).clearData(); // deallocate the temporary data
+  }
+
+  private static void performMaxpooling(String instName, Pointer x,
+        MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+        int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+        int Q) throws DMLRuntimeException {
+
     Pointer alpha = null;
     Pointer beta = null;
     cudnnTensorDescriptor xDesc = null;
@@ -666,7 +718,6 @@ public class LibMatrixCUDA {
       poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
       // Allocate data
-      Pointer x = ((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr;
       Pointer y = ((JCudaObject)outputBlock.getGPUObject()).jcudaDenseMatrixPtr;
       alpha = pointerTo(1.0);
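The new reluMaxpooling applies the ReLU into a temporary device buffer and then runs the shared pooling path (performMaxpooling) on that buffer. In terms of what is computed, for a single image and channel with stride 1 and no padding the fused operator is simply (plain-Java sketch, not the GPU code):

    public class ReluMaxPoolSketch {
        // y[p][q] = max over an R x S window of max(0, x[..]), i.e. relu followed by max-pooling
        static double[][] reluMaxPool(double[][] x, int R, int S) {
            int P = x.length - R + 1, Q = x[0].length - S + 1;
            double[][] y = new double[P][Q];
            for( int p = 0; p < P; p++ )
                for( int q = 0; q < Q; q++ ) {
                    double m = 0;   // ReLU output never goes below 0, so 0 is a safe initial max
                    for( int r = 0; r < R; r++ )
                        for( int s = 0; s < S; s++ )
                            m = Math.max(m, Math.max(0, x[p + r][q + s]));
                    y[p][q] = m;
                }
            return y;
        }
    }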
@@ -807,62 +858,29 @@ public class LibMatrixCUDA {
       if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
     }
   }
-
-
-  /**
-   * Performs the relu operation on the GPU.
-   * @param ec currently active {@link ExecutionContext}
-   * @param instName the invoking instruction's name for record {@link Statistics}.
-   * @param in input matrix
-   * @param outputName name of the output matrix
-   * @throws DMLRuntimeException if an error occurs
-   */
-  public static void relu(ExecutionContext ec, String instName, MatrixObject in, String outputName) throws DMLRuntimeException {
-    if(isInSparseFormat(in)) {
-      // TODO: FIXME: Implement sparse relu kernel
-      ((JCudaObject)in.getGPUObject()).sparseToDense(instName);
-    }
-
+
+  private static void performCuDNNReLU(String instName, Pointer srcData, Pointer dstData, long N, long C, long H, long W) {
     cudnnTensorDescriptor srcTensorDesc = null;
     cudnnTensorDescriptor dstTensorDesc = null;
     Pointer alpha = null;
     Pointer beta = null;
-
+    long t0=0;
     try {
       alpha = pointerTo(1.0f);
       beta = pointerTo(0.0f);
-      long N = in.getNumRows();
-      long H = in.getNumColumns();
-      long W = 1;
-      Pointer srcData = ((JCudaObject)in.getGPUObject()).jcudaDenseMatrixPtr;
-
-      MatrixObject output = ec.getMatrixObject(outputName);
-      getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
-      Pointer dstData = ((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;
-
-      long t0=0;
-      if(N*H*W >= numDoublesIn2GB) {
-        // Invokes relu(double* A, double* ret, int rlen, int clen)
-        if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-        kernels.launchKernel("relu",
-          ExecutionConfig.getConfigForSimpleMatrixOperations((int)N, (int) (H*W)),
-          srcData, dstData, (int)N, (int) H*W);
-        if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - t0);
-      }
-      else {
-        // Allocate descriptors
-        srcTensorDesc = allocateTensorDescriptor((int)N, 1, (int)H, (int)W);
-        dstTensorDesc = allocateTensorDescriptor((int)N, 1, (int)H, (int)W);
-        cudnnActivationDescriptor activationDescriptor = new cudnnActivationDescriptor();
-        cudnnCreateActivationDescriptor(activationDescriptor);
-        double dummy = -1;
-        cudnnSetActivationDescriptor(activationDescriptor, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, dummy);
-        if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-        cudnnActivationForward(cudnnHandle, activationDescriptor,
-          alpha, srcTensorDesc, srcData,
-          beta, dstTensorDesc, dstData);
-        if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ACTIVATION_FORWARD_LIB, System.nanoTime() - t0);
-      }
+
+      // Allocate descriptors
+      srcTensorDesc = allocateTensorDescriptor((int)N, (int)C, (int)H, (int)W);
+      dstTensorDesc = allocateTensorDescriptor((int)N, (int)C, (int)H, (int)W);
+      cudnnActivationDescriptor activationDescriptor = new cudnnActivationDescriptor();
+      cudnnCreateActivationDescriptor(activationDescriptor);
+      double dummy = -1;
+      cudnnSetActivationDescriptor(activationDescriptor, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, dummy);
+      if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+      cudnnActivationForward(cudnnHandle, activationDescriptor,
+        alpha, srcTensorDesc, srcData,
+        beta, dstTensorDesc, dstData);
+      if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ACTIVATION_FORWARD_LIB, System.nanoTime() - t0);
     }
     finally {
       long t1=0;
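performReLU keeps the size-based split that the old relu() had: inputs with at least numDoublesIn2GB elements go to the hand-written "relu" CUDA kernel, smaller ones go through cuDNN's activation forward (performCuDNNReLU). The dispatch shape, sketched with hypothetical callbacks standing in for the real kernel and cuDNN calls:

    public class ReluDispatchSketch {
        static final long NUM_DOUBLES_IN_2GB = (2L * 1024 * 1024 * 1024) / Double.BYTES;

        interface EltwiseRelu { void apply(long numElems); }

        static void relu(long numElems, EltwiseRelu customKernel, EltwiseRelu cudnnPath) {
            if( numElems >= NUM_DOUBLES_IN_2GB )
                customKernel.apply(numElems);   // custom "relu" kernel path
            else
                cudnnPath.apply(numElems);      // cudnnActivationForward path
        }
    }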
@@ -880,6 +898,47 @@ public class LibMatrixCUDA {
       if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t1);
     }
   }
+
+  private static void performReLU(String instName, Pointer srcData, Pointer dstData, long N, long C, long H, long W) throws DMLRuntimeException {
+    long t0=0;
+    if(N*H*W >= numDoublesIn2GB) {
+      // Invokes relu(double* A, double* ret, int rlen, int clen)
+      if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+      kernels.launchKernel("relu",
+        ExecutionConfig.getConfigForSimpleMatrixOperations((int)N, (int) (H*W)),
+        srcData, dstData, (int)N, (int) H*W);
+      if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - t0);
+    }
+    else {
+      performCuDNNReLU(instName, srcData, dstData, N, 1, H, W);
+    }
+  }
+
+
+  /**
+   * Performs the relu operation on the GPU.
+   * @param ec currently active {@link ExecutionContext}
+   * @param instName the invoking instruction's name for record {@link Statistics}.
+   * @param in input matrix
+   * @param outputName name of the output matrix
+   * @throws DMLRuntimeException if an error occurs
+   */
+  public static void relu(ExecutionContext ec, String instName, MatrixObject in, String outputName) throws DMLRuntimeException {
+    if(isInSparseFormat(in)) {
+      // TODO: FIXME: Implement sparse relu kernel
+      ((JCudaObject)in.getGPUObject()).sparseToDense(instName);
+    }
+
+    long N = in.getNumRows();
+    long H = in.getNumColumns();
+    long W = 1;
+    Pointer srcData = ((JCudaObject)in.getGPUObject()).jcudaDenseMatrixPtr;
+
+    MatrixObject output = ec.getMatrixObject(outputName);
+    getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
+    Pointer dstData = ((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;
+    performReLU(instName, srcData, dstData, N, 1, H, W);
+  }