[SYSTEMML-540] Add support for sparse filter dense image conv2d - Also, added support for skipping convolution operations when the image or the filter is empty. - Disabled sparse native convolution operations.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8e3c6f8b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8e3c6f8b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8e3c6f8b Branch: refs/heads/master Commit: 8e3c6f8b8af68830b43c47c06cd6ffc2dc79d1f0 Parents: e0006a2 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Fri Jan 19 11:20:50 2018 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Fri Jan 19 11:20:50 2018 -0800 ---------------------------------------------------------------------- src/main/cpp/kernels/SystemML.cu | 194 + src/main/cpp/kernels/SystemML.ptx | 4324 ++++++++++-------- .../instructions/gpu/GPUInstruction.java | 3 + .../instructions/gpu/context/GPUObject.java | 51 + .../runtime/matrix/data/LibMatrixCUDA.java | 19 + .../runtime/matrix/data/LibMatrixCuDNN.java | 173 +- .../runtime/matrix/data/LibMatrixCuMatMult.java | 2 +- .../runtime/matrix/data/LibMatrixDNNConv2d.java | 13 +- 8 files changed, 2928 insertions(+), 1851 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/8e3c6f8b/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 8eac454..29ae820 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -43,6 +43,200 @@ extern "C" __global__ void float2double_f(float *A, double *ret, int N) { } /** + * Performs an im2col operation on a sparse (CSR) input image, scattering each non-zero into the dense output ret (rows c*RS + r*S + s, columns n*PQ + p*Q + q). One thread handles one non-zero, and its image index n is found by a linear scan of inRowPtr. Only positions reachable from a non-zero are written, so ret is presumably zero-initialized by the caller (TODO confirm) + * + * @param inVal CSR values of the input image matrix (one entry per non-zero) + * @param inRowPtr CSR row pointer; image n owns the non-zeros in [inRowPtr[n], inRowPtr[n+1]) + * @param colInd CSR column indices; each column index is the c*HW + h*W + w offset within one image + * @param ret output matrix allocated on the GPU + * @param nnz number of non-zeros in the input; threads with tid >= nnz do nothing (N, presumably the number of images, is accepted but not read by this kernel) + * @param CHW value of C*H*W (accepted but not read by this kernel) + * @param HW value of H*W + * @param W image width + * @param R filter height + * @param S 
filter width + * @param P height of conv2d output + * @param Q width of conv2d output + * @param PQ value of P*Q + * @param RS value of R*S + * @param NPQ value of N*P*Q, i.e. the number of columns of ret (ret is indexed row-major as ret[crs*NPQ + npq]) + * @param stride_h stride height + * @param stride_w stride width + * @param pad_h padding height + * @param pad_w padding width + */ +template <typename T> +__device__ void sparse_dense_im2col(T *inVal, int *inRowPtr, int *colInd, T *ret, + int nnz, int N, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nnz) { + T value = inVal[tid]; + int n = 0; + while (inRowPtr[n+1] <= tid) { + n++; + } + int chw = colInd[tid]; + int c = chw / HW; + int hw = chw % HW; + int h = hw / W; + int w = hw % W; + + // Constraints: for(int r = 0; r < R; r++) { if(0 <= p && p < P && (h - r + pad_h) % stride_h == 0) { ... } } + // Constraint 1: p >= 0 and p = (h - r + pad_h) / stride_h + // Therefore, r <= h + pad_h + // Constraint 2: p < P and p = (h - r + pad_h) / stride_h + // Therefore, h + pad_h - P*stride_h < r + // Hence max(0, h + pad_h - P*stride_h + 1) <= r <= min(R-1, h + pad_h); s is bounded the same way using w, pad_w, Q, stride_w and S + int rMin = max(0, h + pad_h - P*stride_h + 1); + int rMax = min(R-1, h + pad_h); + int sMin = max(0, w + pad_w - Q*stride_w + 1); + int sMax = min(S-1, w + pad_w); + // Constraint 3: (h - r + pad_h) % stride_h == 0. Advance rMin (and likewise sMin) to the first value meeting it, stopping once the range is empty; the loops below then step by the stride so every visited r and s stays aligned + while((h - rMin + pad_h) % stride_h != 0 && rMin <= rMax) rMin++; + while((w - sMin + pad_w) % stride_w != 0 && sMin <= sMax) sMin++; + + for(int r = rMin; r <= rMax; r += stride_h) { + // Output row p combined with filter row r reads input row (r - pad_h) + p*stride_h; only write the value when that row equals h, with 0 <= p < P + // Therefore, p = (h - r + pad_h) / stride_h. Use the same logic for q.
+ int p = (h - r + pad_h) / stride_h; + int npQ = n*PQ + p*Q; + int outRowIndex = c*RS + r*S; + for(int s = sMin; s <= sMax; s += stride_w) { + int q = (w - s + pad_w) / stride_w; + // chw -> [crs, npq]: write row c*RS + r*S + s, column n*PQ + p*Q + q of ret. NOTE(review): the flat index is int arithmetic and overflows if ret has more than INT_MAX elements; confirm callers bound the output size + ret[(outRowIndex + s)*NPQ + npQ + q] = value; + } + } + } +} + +extern "C" __global__ void sparse_dense_im2col_d(double *inVal, int *inRowPtr, int *colInd, double *ret, + int nnz, int N, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { // double-precision entry point; launch at least nnz threads in a 1-D grid + sparse_dense_im2col(inVal, inRowPtr, colInd, ret, nnz, N, CHW, HW, W, R, S, P, Q, PQ, RS, NPQ, stride_h, stride_w, pad_h, pad_w); +} + +extern "C" __global__ void sparse_dense_im2col_f(float *inVal, int *inRowPtr, int *colInd, float *ret, + int nnz, int N, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { // single-precision entry point; launch at least nnz threads in a 1-D grid + sparse_dense_im2col(inVal, inRowPtr, colInd, ret, nnz, N, CHW, HW, W, R, S, P, Q, PQ, RS, NPQ, stride_h, stride_w, pad_h, pad_w); +} + +/** + * Performs an im2col operation on a dense input image matrix (rows of C*H*W values, one row per image, row-major), scattering each cell into the dense output ret. One thread handles one input cell, and threads with tid >= NCHW do nothing. Output positions that map only to padding are never written, so ret is presumably zero-initialized by the caller (TODO confirm) + * + * @param input input matrix allocated on the GPU + * @param ret output matrix allocated on the GPU + * @param NCHW value of N*C*H*W + * @param CHW value of C*H*W + * @param HW value of H*W + * @param W image width + * @param R filter height + * @param S filter width + * @param P height of conv2d output + * @param Q width of conv2d output + * @param PQ value of P*Q + * @param RS value of R*S + * @param NPQ value of N*P*Q, i.e. the number of columns of ret + * @param stride_h stride height + * @param stride_w stride width + * @param pad_h padding height + * @param pad_w padding width + */ +template <typename T> +__device__ void dense_dense_im2col(T *input, T *ret, + int NCHW, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < NCHW) { + T value = 
input[tid]; + int n = tid / CHW; + int chw = tid % CHW; + int c = chw / HW; + int hw = chw % HW; + int h = hw / W; + int w = hw % W; + + // Constraints: for(int r = 0; r < R; r++) { if(0 <= p && p < P && (h - r + pad_h) % stride_h == 0) { ... } } + // Constraint 1: p >= 0 and p = (h - r + pad_h) / stride_h + // Therefore, r <= h + pad_h + // Constraint 2: p < P and p = (h - r + pad_h) / stride_h + // Therefore, h + pad_h - P*stride_h < r + // Hence max(0, h + pad_h - P*stride_h + 1) <= r <= min(R-1, h + pad_h); s is bounded the same way using w, pad_w, Q, stride_w and S + int rMin = max(0, h + pad_h - P*stride_h + 1); + int rMax = min(R-1, h + pad_h); + int sMin = max(0, w + pad_w - Q*stride_w + 1); + int sMax = min(S-1, w + pad_w); + // Constraint 3: (h - r + pad_h) % stride_h == 0. Advance rMin (and likewise sMin) to the first value meeting it, stopping once the range is empty; the loops below then step by the stride so every visited r and s stays aligned + while((h - rMin + pad_h) % stride_h != 0 && rMin <= rMax) rMin++; + while((w - sMin + pad_w) % stride_w != 0 && sMin <= sMax) sMin++; + + for(int r = rMin; r <= rMax; r += stride_h) { + // Output row p combined with filter row r reads input row (r - pad_h) + p*stride_h; only write the value when that row equals h, with 0 <= p < P + // Therefore, p = (h - r + pad_h) / stride_h. Use the same logic for q.
+ int p = (h - r + pad_h) / stride_h; + int npQ = n*PQ + p*Q; + int outRowIndex = c*RS + r*S; + for(int s = sMin; s <= sMax; s += stride_w) { + int q = (w - s + pad_w) / stride_w; + // chw -> [crs, npq]: write row c*RS + r*S + s, column n*PQ + p*Q + q of ret. NOTE(review): the flat index is int arithmetic and overflows if ret has more than INT_MAX elements; confirm callers bound the output size + ret[(outRowIndex + s)*NPQ + npQ + q] = value; + } + } + } +} + +extern "C" __global__ void dense_dense_im2col_d(double *input, double *ret, + int NCHW, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { // double-precision entry point; launch at least NCHW threads in a 1-D grid + dense_dense_im2col(input, ret, NCHW, CHW, HW, W, R, S, P, Q, PQ, RS, NPQ, stride_h, stride_w, pad_h, pad_w); +} + +extern "C" __global__ void dense_dense_im2col_f(float *input, float *ret, + int NCHW, int CHW, int HW, int W, + int R, int S, int P, int Q, int PQ, int RS, int NPQ, + int stride_h, int stride_w, int pad_h, int pad_w) { // single-precision entry point; launch at least NCHW threads in a 1-D grid + dense_dense_im2col(input, ret, NCHW, CHW, HW, W, R, S, P, Q, PQ, RS, NPQ, stride_h, stride_w, pad_h, pad_w); +} + +/** + * Reorganizes a row-major matrix with dimensions [K, NPQ] into a row-major matrix with dimensions [N, KPQ] that is written to ret: + * input element (k, n*PQ + pq) goes to ret element (n, k*PQ + pq). One thread per element; threads with tid >= NKPQ do nothing + * + * @param knpqPtr input matrix allocated on the GPU + * @param ret output matrix allocated on the GPU + * @param NKPQ number of elements in each of the input and output matrices + * @param NPQ the number of columns of input matrix + * @param KPQ the number of columns of output matrix + * @param PQ value of P*Q + */ +template <typename T> +__device__ void reorg_knpq(T *knpqPtr, T *ret, int NKPQ, int NPQ, int KPQ, int PQ) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < NKPQ) { + int k = tid / NPQ; + int npq = tid % NPQ; + int n = npq / PQ; + int pq = npq % PQ; + ret[n*KPQ + k*PQ + pq] = knpqPtr[tid]; + } +} + +extern "C" __global__ void reorg_knpq_d(double *knpqPtr, double *ret, int NKPQ, int NPQ, int KPQ, int PQ) { // double-precision entry point; launch at least NKPQ threads in a 1-D grid + reorg_knpq(knpqPtr, ret, NKPQ, NPQ, KPQ, PQ); +} + +extern "C" __global__ void reorg_knpq_f(float *knpqPtr, float *ret, int NKPQ, int NPQ, int KPQ, int PQ) { // single-precision entry point; launch at least NKPQ threads in a 1-D grid + reorg_knpq(knpqPtr, ret, NKPQ, NPQ, KPQ, 
PQ); +} + +/** * Performs a slice operation where the input matrix is sparse and the output * matrix is dense. * This function avoids unnecessary sparse to dense conversion of the input