Repository: systemml
Updated Branches:
  refs/heads/master 512fb9e11 -> 3702df7c1

[SYSTEMML-445] Improved the performance of batchnorm backward

- Added a custom kernel for computing dgamma in the batch normalization layer.
- Also fixed a minor bug in the GPUDenseInputPointerFetcher class, whose
  dimension-validation error messages hardcoded the variable name subgrp_means.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3702df7c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3702df7c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3702df7c

Branch: refs/heads/master
Commit: 3702df7c1890b8c87c42715260240c604a5c3c64
Parents: 512fb9e
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Tue Oct 9 14:58:09 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Tue Oct 9 14:58:09 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  21 +++
 src/main/cpp/kernels/SystemML.ptx               | 188 ++++++++++++++++---
 src/main/java/org/apache/sysml/hops/DnnOp.java  |   8 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   3 +-
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  22 ++-
 .../org/apache/sysml/lops/DnnTransform.java     |   7 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  51 ++++-
 .../gpu/GPUDenseInputPointerFetcher.java        |   4 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  19 +-
 10 files changed, 285 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index a53d07a..26d7f43 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2385,3 +2385,24 @@ extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int
   invVar(X, C, eps, size);
 }
 
+template <typename T>
+__device__ void backward_dgamma_tmp(T *ema_mean, T *dout, T *X, T *ema_var, T *ret, int N, int C,
+    int HW, int CHW, unsigned int NCHW) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int ix = tid / CHW;
+  int iy = tid % CHW;
+  if (ix < N && iy < CHW) {
+    int c = iy / HW;
+    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
+  }
+}
+
+extern "C" __global__ void backward_dgamma_tmp_d(double *ema_mean, double *dout, double *X, double *ema_var, double *ret,
+    int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
+
+extern "C" __global__ void backward_dgamma_tmp_f(float *ema_mean, float *dout, float *X, float *ema_var, float *ret,
+    int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index ac04967..3043373 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -15084,12 +15084,146 @@ BB123_2: ret; } + // .globl backward_dgamma_tmp_d +.visible .entry backward_dgamma_tmp_d( + .param .u64 backward_dgamma_tmp_d_param_0, + .param .u64 backward_dgamma_tmp_d_param_1, + .param .u64 backward_dgamma_tmp_d_param_2, + .param .u64 backward_dgamma_tmp_d_param_3, + .param .u64 backward_dgamma_tmp_d_param_4, + .param .u32 
backward_dgamma_tmp_d_param_5, + .param .u32 backward_dgamma_tmp_d_param_6, + .param .u32 backward_dgamma_tmp_d_param_7, + .param .u32 backward_dgamma_tmp_d_param_8, + .param .u32 backward_dgamma_tmp_d_param_9 +) +{ + .reg .pred %p<4>; + .reg .b32 %r<11>; + .reg .f64 %fd<8>; + .reg .b64 %rd<18>; + + + ld.param.u64 %rd1, [backward_dgamma_tmp_d_param_0]; + ld.param.u64 %rd2, [backward_dgamma_tmp_d_param_1]; + ld.param.u64 %rd3, [backward_dgamma_tmp_d_param_2]; + ld.param.u64 %rd4, [backward_dgamma_tmp_d_param_3]; + ld.param.u64 %rd5, [backward_dgamma_tmp_d_param_4]; + ld.param.u32 %r4, [backward_dgamma_tmp_d_param_5]; + ld.param.u32 %r2, [backward_dgamma_tmp_d_param_7]; + ld.param.u32 %r3, [backward_dgamma_tmp_d_param_8]; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %ntid.x; + mov.u32 %r7, %tid.x; + mad.lo.s32 %r1, %r6, %r5, %r7; + div.s32 %r8, %r1, %r3; + setp.lt.s32 %p1, %r8, %r4; + setp.gt.s32 %p2, %r3, -1; + and.pred %p3, %p1, %p2; + @!%p3 bra BB124_2; + bra.uni BB124_1; + +BB124_1: + rem.s32 %r9, %r1, %r3; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + div.s32 %r10, %r9, %r2; + cvta.to.global.u64 %rd11, %rd1; + mul.wide.s32 %rd12, %r10, 8; + add.s64 %rd13, %rd11, %rd12; + ld.global.f64 %fd1, [%rd13]; + ld.global.f64 %fd2, [%rd10]; + sub.f64 %fd3, %fd2, %fd1; + cvta.to.global.u64 %rd14, %rd4; + add.s64 %rd15, %rd14, %rd12; + ld.global.f64 %fd4, [%rd15]; + mul.f64 %fd5, %fd3, %fd4; + ld.global.f64 %fd6, [%rd8]; + mul.f64 %fd7, %fd6, %fd5; + cvta.to.global.u64 %rd16, %rd5; + add.s64 %rd17, %rd16, %rd7; + st.global.f64 [%rd17], %fd7; + +BB124_2: + ret; +} + + // .globl backward_dgamma_tmp_f +.visible .entry backward_dgamma_tmp_f( + .param .u64 backward_dgamma_tmp_f_param_0, + .param .u64 backward_dgamma_tmp_f_param_1, + .param .u64 backward_dgamma_tmp_f_param_2, + .param .u64 backward_dgamma_tmp_f_param_3, + .param .u64 backward_dgamma_tmp_f_param_4, + .param .u32 backward_dgamma_tmp_f_param_5, + .param .u32 backward_dgamma_tmp_f_param_6, + .param .u32 backward_dgamma_tmp_f_param_7, + .param .u32 backward_dgamma_tmp_f_param_8, + .param .u32 backward_dgamma_tmp_f_param_9 +) +{ + .reg .pred %p<4>; + .reg .b32 %r<11>; + .reg .f64 %fd<8>; + .reg .b64 %rd<18>; + + + ld.param.u64 %rd1, [backward_dgamma_tmp_f_param_0]; + ld.param.u64 %rd2, [backward_dgamma_tmp_f_param_1]; + ld.param.u64 %rd3, [backward_dgamma_tmp_f_param_2]; + ld.param.u64 %rd4, [backward_dgamma_tmp_f_param_3]; + ld.param.u64 %rd5, [backward_dgamma_tmp_f_param_4]; + ld.param.u32 %r4, [backward_dgamma_tmp_f_param_5]; + ld.param.u32 %r2, [backward_dgamma_tmp_f_param_7]; + ld.param.u32 %r3, [backward_dgamma_tmp_f_param_8]; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %ntid.x; + mov.u32 %r7, %tid.x; + mad.lo.s32 %r1, %r6, %r5, %r7; + div.s32 %r8, %r1, %r3; + setp.lt.s32 %p1, %r8, %r4; + setp.gt.s32 %p2, %r3, -1; + and.pred %p3, %p1, %p2; + @!%p3 bra BB125_2; + bra.uni BB125_1; + +BB125_1: + rem.s32 %r9, %r1, %r3; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + div.s32 %r10, %r9, %r2; + cvta.to.global.u64 %rd11, %rd1; + mul.wide.s32 %rd12, %r10, 8; + add.s64 %rd13, %rd11, %rd12; + ld.global.f64 %fd1, [%rd13]; + ld.global.f64 %fd2, [%rd10]; + sub.f64 %fd3, %fd2, %fd1; + cvta.to.global.u64 %rd14, %rd4; + add.s64 %rd15, %rd14, %rd12; + ld.global.f64 %fd4, [%rd15]; + mul.f64 %fd5, %fd3, %fd4; + ld.global.f64 %fd6, [%rd8]; + mul.f64 %fd7, 
%fd6, %fd5; + cvta.to.global.u64 %rd16, %rd5; + add.s64 %rd17, %rd16, %rd7; + st.global.f64 [%rd17], %fd7; + +BB125_2: + ret; +} + .func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd( .param .b64 __internal_trig_reduction_slowpathd_param_0, .param .b64 __internal_trig_reduction_slowpathd_param_1 ) { - .local .align 8 .b8 __local_depot124[40]; + .local .align 8 .b8 __local_depot126[40]; .reg .b64 %SP; .reg .b64 %SPL; .reg .pred %p<9>; @@ -15098,7 +15232,7 @@ BB123_2: .reg .b64 %rd<102>; - mov.u64 %rd101, __local_depot124; + mov.u64 %rd101, __local_depot126; cvta.local.u64 %SP, %rd101; ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0]; ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1]; @@ -15112,7 +15246,7 @@ BB123_2: shr.u32 %r3, %r1, 20; bfe.u32 %r4, %r1, 20, 11; setp.eq.s32 %p1, %r4, 2047; - @%p1 bra BB124_13; + @%p1 bra BB126_13; add.s32 %r15, %r4, -1024; shr.u32 %r16, %r15, 6; @@ -15125,7 +15259,7 @@ BB123_2: mov.u64 %rd94, 0; setp.ge.s32 %p2, %r5, %r6; mov.u64 %rd93, %rd1; - @%p2 bra BB124_4; + @%p2 bra BB126_4; mov.b64 %rd41, %fd4; shl.b64 %rd42, %rd41, 11; @@ -15142,7 +15276,7 @@ BB123_2: mov.u64 %rd91, %rd1; mov.u32 %r39, %r5; -BB124_3: +BB126_3: .pragma "nounroll"; ld.const.u64 %rd47, [%rd89]; // inline asm @@ -15172,15 +15306,15 @@ BB124_3: add.s64 %rd93, %rd93, 8; add.s64 %rd89, %rd89, 8; setp.lt.s32 %p3, %r39, %r6; - @%p3 bra BB124_3; + @%p3 bra BB126_3; -BB124_4: +BB126_4: st.local.u64 [%rd93], %rd94; ld.local.u64 %rd95, [%rd1+16]; ld.local.u64 %rd96, [%rd1+24]; and.b32 %r9, %r3, 63; setp.eq.s32 %p4, %r9, 0; - @%p4 bra BB124_6; + @%p4 bra BB126_6; mov.u32 %r27, 64; sub.s32 %r28, %r27, %r9; @@ -15192,7 +15326,7 @@ BB124_4: shr.u64 %rd55, %rd54, %r28; or.b64 %rd95, %rd55, %rd53; -BB124_6: +BB126_6: cvta.to.local.u64 %rd56, %rd37; shr.u64 %rd57, %rd96, 62; cvt.u32.u64 %r29, %rd57; @@ -15209,7 +15343,7 @@ BB124_6: selp.b32 %r34, %r32, %r33, %p5; st.local.u32 [%rd56], %r34; setp.eq.s32 %p6, %r31, 0; - @%p6 bra BB124_8; + @%p6 bra BB126_8; mov.u64 %rd64, 0; // inline asm @@ -15229,10 +15363,10 @@ BB124_6: // inline asm xor.b32 %r40, %r40, -2147483648; -BB124_8: +BB126_8: clz.b64 %r41, %rd98; setp.eq.s32 %p7, %r41, 0; - @%p7 bra BB124_10; + @%p7 bra BB126_10; shl.b64 %rd67, %rd98, %r41; mov.u32 %r35, 64; @@ -15240,7 +15374,7 @@ BB124_8: shr.u64 %rd68, %rd97, %r36; or.b64 %rd98, %rd68, %rd67; -BB124_10: +BB126_10: mov.u64 %rd72, -3958705157555305931; // inline asm { @@ -15261,7 +15395,7 @@ BB124_10: } // inline asm setp.lt.s64 %p8, %rd100, 1; - @%p8 bra BB124_12; + @%p8 bra BB126_12; // inline asm { @@ -15280,7 +15414,7 @@ BB124_10: // inline asm add.s32 %r41, %r41, 1; -BB124_12: +BB126_12: cvt.u64.u32 %rd79, %r40; shl.b64 %rd80, %rd79, 32; mov.u32 %r37, 1022; @@ -15295,7 +15429,7 @@ BB124_12: or.b64 %rd88, %rd87, %rd80; mov.b64 %fd4, %rd88; -BB124_13: +BB126_13: st.param.f64 [func_retval0+0], %fd4; ret; } @@ -15323,7 +15457,7 @@ BB124_13: } shr.u32 %r51, %r50, 20; setp.ne.s32 %p1, %r51, 0; - @%p1 bra BB125_2; + @%p1 bra BB127_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -15337,13 +15471,13 @@ BB124_13: shr.u32 %r16, %r50, 20; add.s32 %r51, %r16, -54; -BB125_2: +BB127_2: add.s32 %r52, %r51, -1023; and.b32 %r17, %r50, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd135, {%r49, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB125_4; + @%p2 bra BB127_4; { .reg .b32 %temp; @@ -15357,7 +15491,7 @@ BB125_2: mov.b64 %fd135, {%r19, %r21}; add.s32 %r52, %r51, -1022; -BB125_4: +BB127_4: add.f64 %fd15, %fd135, 0d3FF0000000000000; 
rcp.approx.ftz.f64 %fd16, %fd15; neg.f64 %fd17, %fd15; @@ -15520,13 +15654,13 @@ BB125_4: mov.b32 %f2, %r35; abs.f32 %f1, %f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB125_7; + @%p4 bra BB127_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd129, %fd4, 0d7FF0000000000000; selp.f64 %fd136, 0d0000000000000000, %fd129, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB125_7; + @%p6 bra BB127_7; mov.f64 %fd134, 0d4338000000000000; mov.f64 %fd133, 0d3FF71547652B82FE; @@ -15548,26 +15682,26 @@ BB125_4: mov.b64 %fd131, {%r44, %r43}; mul.f64 %fd136, %fd130, %fd131; -BB125_7: +BB127_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd136; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB125_9; + @%p7 bra BB127_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd136; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB125_10; + @%p8 bra BB127_10; -BB125_9: +BB127_9: fma.rn.f64 %fd136, %fd136, %fd5, %fd136; -BB125_10: +BB127_10: st.param.f64 [func_retval0+0], %fd136; ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/DnnOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java index c4ce466..7cf5061 100644 --- a/src/main/java/org/apache/sysml/hops/DnnOp.java +++ b/src/main/java/org/apache/sysml/hops/DnnOp.java @@ -141,6 +141,7 @@ public class DnnOp extends MultiThreadedHop case UPDATE_EMA: case INV_VAR: case BATCH_NORM2D_BACKWARD_DX: + case BATCH_NORM2D_BACKWARD_DGAMMA: { // GPU-specific operators setLops(constructDnnLops(ExecType.GPU, inputs)); @@ -181,6 +182,7 @@ public class DnnOp extends MultiThreadedHop case CHANNEL_SUMS: case UPDATE_EMA: return 3; + case BATCH_NORM2D_BACKWARD_DGAMMA: case UPDATE_NESTEROV_X: return 4; default: @@ -538,7 +540,7 @@ public class DnnOp extends MultiThreadedHop if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || - op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) { // Same dimension as the first input MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); ret[0] = mc[0].rowsKnown() ? 
mc[0].getRows() : -1; @@ -755,7 +757,7 @@ public class DnnOp extends MultiThreadedHop { if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || - op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) { // Same dimension as the first input Hop input1 = getInput().get(0); setDim1(input1.getDim1()); @@ -873,7 +875,7 @@ public class DnnOp extends MultiThreadedHop if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS || op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || - op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) { throw new RuntimeException("getDim method should not be invoked for " + op.name()); } try { http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index c8356e0..82a6669 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1101,7 +1101,7 @@ public abstract class Hop implements ParseInfo CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS, UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR, - BATCH_NORM2D_BACKWARD_DX + BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA } public enum DataGenMethod { @@ -1182,6 +1182,7 @@ public abstract class Hop implements ParseInfo HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA); HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR); HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX); + HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA); } protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops; http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index 577adc3..ab40d7b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -170,6 +170,25 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher { return hi; }; + // Avoids unnecessary intermediates: + // mean = cache_mean + // centered = bias_add(X, -mean) # shape (N, C*Hin*Win) + // norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) + // # Compute gradients during training + // dgamma = util::channel_sums(dout*norm, C, Hin, Win) + private static final HopDagPatternMatcher _batchNormDGamma; + static { + _batchNormDGamma = util_channel_sums( 
+ mult( leaf("dout", MATRIX).fitsOnGPU(3), + bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), + leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR)); + } + private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> { + LOG.debug("Applied batchNormDGamma rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, + "ema_mean", "dout", "X", "ema_var"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + }; // Pattern 3: private static final HopDagPatternMatcher _batchNormTest; @@ -282,8 +301,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher { if(_rewriters == null) { ArrayList<HopPatternRewriter> rewriters = new ArrayList<>(); rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer)); - rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer)); rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer)); + rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer)); + // rewriters.add(new HopPatternRewriter("batchNormDGamma", _batchNormDGamma, _batchNormDGammaReplacer)); rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer)); rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer)); rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer)); http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/lops/DnnTransform.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java index 2d2d5f1..3496b5b 100644 --- a/src/main/java/org/apache/sysml/lops/DnnTransform.java +++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java @@ -33,7 +33,7 @@ public class DnnTransform extends Lop CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR, - BATCH_NORM2D_BACKWARD_DX + BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA } private OperationTypes operation; @@ -174,6 +174,9 @@ public class DnnTransform extends Lop case UPDATE_NESTEROV_X: return "update_nesterov_x"; + case BATCH_NORM2D_BACKWARD_DGAMMA: + return "batch_norm2d_bwd_dgamma"; + case BATCH_NORM2D_TEST: return "batch_norm2d_test"; @@ -254,7 +257,7 @@ public class DnnTransform extends Lop @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { - if(operation == OperationTypes.UPDATE_NESTEROV_X) { + if(operation == OperationTypes.UPDATE_NESTEROV_X || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA) { StringBuilder sb = new StringBuilder(); sb.append( getExecType() ); http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 3480504..c8a0e8d 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -66,6 +66,7 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "reshape_colmeans", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "inv_var", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d_bwd_dx", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "batch_norm2d_bwd_dgamma", GPUINSTRUCTION_TYPE.Dnn); // Matrix Multiply Operators String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary); http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java index 6094b6c..4ad4155 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java @@ -127,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction { public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException { super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if( !( opcode.equals("update_nesterov_x")) ) { + if( !( opcode.equals("update_nesterov_x") || opcode.equals("batch_norm2d_bwd_dgamma")) ) { throw new DMLRuntimeException("Incorrect opcode: " + opcode); } _input1 = in1; @@ -339,6 +339,15 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out = new CPOperand(parts[5]); return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0); } + else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) { + InstructionUtils.checkNumFields(parts, 5); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand out = new CPOperand(parts[5]); + return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0); + } else if (opcode.equalsIgnoreCase("lstm")) { InstructionUtils.checkNumFields(parts, 8); CPOperand in1 = new CPOperand(parts[1]); @@ -586,6 +595,42 @@ public class DnnGPUInstruction extends GPUInstruction { } } + // "ema_mean", "dout", "X", "ema_var" + private void processBatchNorm2dBackwardDGammaInstruction(ExecutionContext ec) { + try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) { + fetcher.add("ema_mean", _input1).add("dout", _input2).add("X", _input3) + .add("ema_var", _input4); + MatrixObject ema_mean = fetcher.getInputMatrixObject("ema_mean"); + MatrixObject dout = fetcher.getInputMatrixObject("dout"); + long C = ema_mean.getNumRows(); + long N = dout.getNumRows(); + long CHW = dout.getNumColumns(); + fetcher.validateDimensions("ema_mean", C, 1); + fetcher.validateDimensions("dout", N, CHW); + fetcher.validateDimensions("X", N, CHW); + fetcher.validateDimensions("ema_var", C, 1); + if(CHW % C != 0) { + throw new DMLRuntimeException("Incorrect dimensions: C=" + C + ", CHW=" + CHW); + } + long HW = CHW / C; + Pointer tmp = gCtx.allocate(instName, N*CHW*LibMatrixCUDA.sizeOfDataType); + 
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("backward_dgamma_tmp", + ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(N*CHW)), + fetcher.getInputPointer("ema_mean"), + fetcher.getInputPointer("dout"), + fetcher.getInputPointer("X"), + fetcher.getInputPointer("ema_var"), + tmp, + // N, C, HW, CHW, NCHW + toInt(N), toInt(C), toInt(HW), toInt(CHW), toInt(N*CHW)); + + LibMatrixCUDA.channelSums(gCtx, instName, + tmp, fetcher.getOutputPointer(C, 1), N, C, HW); + gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); + } + } + private static int toInt(long num) throws DMLRuntimeException { + if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) { + throw new DMLRuntimeException("GPU : Exceeded supported size " + num); @@ -734,6 +779,10 @@ public class DnnGPUInstruction extends GPUInstruction { processNesterovUpdateInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) { + processBatchNorm2dBackwardDGammaInstruction(ec); + return; + } else if (instOpcode.equalsIgnoreCase("update_ema_var")) { processUpdateEMAVarInstruction(ec); return; http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java index 1ab3420..06bd1df 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java @@ -94,10 +94,10 @@ public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable { public void validateDimensions(String var, long numRows, long numCols) { MatrixObject mo = getInputMatrixObject(var); if(numRows > 0 && mo.getNumRows() != numRows) { - throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows()); + throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows()); } else if(numCols > 0 && mo.getNumColumns() != numCols) { - throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns()); + throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns()); } } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index 46ab3f7..00aa578 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -362,10 +362,25 @@ } Pointer imagePointer = getDensePointer(gCtx, input, instName); Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName); - + channelSums(gCtx, instName, imagePointer, outputPointer, N, C, HW); + } + + /** + * Performs the channel_sums operation: out = rowSums(matrix(colSums(A), rows=C, cols=HW)) + * + * 
@param gCtx a valid {@link GPUContext} + * @param instName the invoking instruction's name for recording {@link Statistics}. + * @param imagePointer input image pointer + * @param outputPointer output pointer + * @param N number of rows + * @param C number of channels + * @param HW height*width + */ + public static void channelSums(GPUContext gCtx, String instName, Pointer imagePointer, Pointer outputPointer, long N, long C, long HW) { + int cols = toInt(C*HW); // We can replace this with CuDNN tensor reduce Pointer tmp = gCtx.allocate(instName, cols*sizeOfDataType); - reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols); + reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, toInt(N), cols); reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW)); gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE); }
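
----------------------------------------------------------------------
For reference, the data flow the new instruction implements can be reproduced
outside SystemML as a small standalone CUDA program: step 1 is the fused
elementwise kernel, which fills tmp[n,c,hw] = dout * (X - ema_mean[c]) *
ema_var[c] in one pass (treating ema_var as the precomputed inverse standard
deviation, i.e. cache_inv_var in the rewrite pattern above); step 2 reduces
tmp over N and HW to one value per channel. This is a minimal sketch, not the
committed code: channel_sums_naive is a simplified stand-in for
LibMatrixCUDA.channelSums (reduce_col_sum followed by reduce_row_sum), the
host plumbing replaces GPUDenseInputPointerFetcher, and all host-side names
are illustrative.

#include <cstdio>
#include <cuda_runtime.h>

// Step 1: same indexing and arithmetic as the committed backward_dgamma_tmp
// kernel (the unused NCHW parameter is omitted here).
__global__ void backward_dgamma_tmp(const double *ema_mean, const double *dout,
                                    const double *X, const double *ema_var,
                                    double *ret, int N, int C, int HW, int CHW) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int ix = tid / CHW;          // image (row) index
  int iy = tid % CHW;          // offset inside one image
  if (ix < N && iy < CHW) {
    int c = iy / HW;           // channel of this element
    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
  }
}

// Step 2: naive stand-in for channel_sums, dgamma[c] = sum_{n,hw} tmp[n,c,hw].
__global__ void channel_sums_naive(const double *tmp, double *dgamma,
                                   int N, int C, int HW) {
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < C) {
    double acc = 0.0;
    for (int n = 0; n < N; n++)
      for (int i = 0; i < HW; i++)
        acc += tmp[n * C * HW + c * HW + i];
    dgamma[c] = acc;
  }
}

int main() {
  const int N = 2, C = 3, HW = 4, CHW = C * HW, NCHW = N * CHW;
  double h_mean[C], h_var[C], h_X[NCHW], h_dout[NCHW], h_dgamma[C];
  for (int c = 0; c < C; c++) { h_mean[c] = 0.1 * c; h_var[c] = 1.0 + c; }
  for (int i = 0; i < NCHW; i++) { h_X[i] = 0.01 * i; h_dout[i] = 1.0; }

  double *d_mean, *d_var, *d_X, *d_dout, *d_tmp, *d_dgamma;
  cudaMalloc(&d_mean, C * sizeof(double));
  cudaMalloc(&d_var, C * sizeof(double));
  cudaMalloc(&d_X, NCHW * sizeof(double));
  cudaMalloc(&d_dout, NCHW * sizeof(double));
  cudaMalloc(&d_tmp, NCHW * sizeof(double));    // N x CHW intermediate
  cudaMalloc(&d_dgamma, C * sizeof(double));
  cudaMemcpy(d_mean, h_mean, C * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(d_var, h_var, C * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(d_X, h_X, NCHW * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dout, h_dout, NCHW * sizeof(double), cudaMemcpyHostToDevice);

  // One thread per element of the N x CHW input, as in
  // ExecutionConfig.getConfigForSimpleVectorOperations.
  backward_dgamma_tmp<<<(NCHW + 255) / 256, 256>>>(d_mean, d_dout, d_X, d_var,
                                                   d_tmp, N, C, HW, CHW);
  channel_sums_naive<<<1, C>>>(d_tmp, d_dgamma, N, C, HW);
  cudaMemcpy(h_dgamma, d_dgamma, C * sizeof(double), cudaMemcpyDeviceToHost);

  for (int c = 0; c < C; c++) {  // CPU reference for the same computation
    double ref = 0.0;
    for (int n = 0; n < N; n++)
      for (int i = 0; i < HW; i++) {
        int idx = n * CHW + c * HW + i;
        ref += h_dout[idx] * (h_X[idx] - h_mean[c]) * h_var[c];
      }
    printf("c=%d gpu=%f cpu=%f\n", c, h_dgamma[c], ref);
  }
  cudaFree(d_mean); cudaFree(d_var); cudaFree(d_X);
  cudaFree(d_dout); cudaFree(d_tmp); cudaFree(d_dgamma);
  return 0;
}

Compiled with, for example, "nvcc dgamma_check.cu -o dgamma_check" (file name
hypothetical), the GPU and CPU columns should agree. The point of the fused
kernel is visible here: without it, the pattern bias_multiply(bias_add(X,
-ema_mean), ema_var) * dout would materialize two extra N x CHW
intermediates, which is exactly the cost the rewrite comment ("Avoids
unnecessary intermediates") refers to.
----------------------------------------------------------------------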