Repository: systemml
Updated Branches:
  refs/heads/master 512fb9e11 -> 3702df7c1


[SYSTEMML-445] Improved the performance of batchnorm backward

- Added a custom kernel for computing dgamma in the batch normalization
layer.
- Fixed a minor bug in the GPUDenseInputPointerFetcher class: its
dimension-validation error messages hard-coded the variable name
subgrp_means instead of reporting the variable being checked.

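For reference, the new kernel computes the elementwise term that is then
reduced per channel. A minimal CPU sketch of the same math (illustrative
only, not part of this commit; names mirror the kernel parameters, and
ema_var is used as a precomputed inverse standard deviation, matching the
cache_inv_var input of the rewrite pattern below):

  // One thread/iteration per element of the (N x CHW) inputs.
  for (int tid = 0; tid < N * CHW; tid++) {
    int c = (tid % CHW) / HW;  // channel index of this element
    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
  }
  // dgamma is then channel_sums(ret, C, HW): a sum over N and HW per channel.
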
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3702df7c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3702df7c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3702df7c

Branch: refs/heads/master
Commit: 3702df7c1890b8c87c42715260240c604a5c3c64
Parents: 512fb9e
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Tue Oct 9 14:58:09 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Tue Oct 9 14:58:09 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  21 +++
 src/main/cpp/kernels/SystemML.ptx               | 188 ++++++++++++++++---
 src/main/java/org/apache/sysml/hops/DnnOp.java  |   8 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   3 +-
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  22 ++-
 .../org/apache/sysml/lops/DnnTransform.java     |   7 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  51 ++++-
 .../gpu/GPUDenseInputPointerFetcher.java        |   4 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  19 +-
 10 files changed, 285 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index a53d07a..26d7f43 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2385,3 +2385,24 @@ extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int
   invVar(X, C, eps, size);
 }
 
+template <typename T>
+__device__ void backward_dgamma_tmp(T *ema_mean, T *dout, T *X, T *ema_var, T *ret, int N, int C,
+                         int HW, int CHW, unsigned int NCHW) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int ix = tid / CHW;
+  int iy = tid % CHW;
+  if (ix < N && iy < CHW) {
+    int c = iy / HW;
+    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
+  }
+}
+
+extern "C" __global__ void backward_dgamma_tmp_d(double *ema_mean, double *dout, double *X, double *ema_var, double *ret, 
+       int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
+
+extern "C" __global__ void backward_dgamma_tmp_f(float *ema_mean, float *dout, float *X, float *ema_var, float *ret, 
+       int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}

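The wrappers map one GPU thread to one element of the (N x CHW) output. A
hedged host-side launch sketch for the double-precision entry point (the
launcher function and the 256-thread block size are assumptions for
illustration; SystemML itself sizes the grid via ExecutionConfig at runtime):

  // Hypothetical launcher; assumes the five device pointers are already
  // allocated and populated.
  void launchBackwardDGammaTmpD(double *ema_mean, double *dout, double *X,
                                double *ema_var, double *ret,
                                int N, int C, int HW, int CHW) {
    unsigned int NCHW = (unsigned int)N * (unsigned int)CHW;
    int threads = 256;                              // assumed block size
    int blocks = (int)((NCHW + threads - 1) / threads);
    backward_dgamma_tmp_d<<<blocks, threads>>>(ema_mean, dout, X, ema_var,
                                               ret, N, C, HW, CHW, NCHW);
  }
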
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index ac04967..3043373 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -15084,12 +15084,146 @@ BB123_2:
        ret;
 }
 
+       // .globl       backward_dgamma_tmp_d
+.visible .entry backward_dgamma_tmp_d(
+       .param .u64 backward_dgamma_tmp_d_param_0,
+       .param .u64 backward_dgamma_tmp_d_param_1,
+       .param .u64 backward_dgamma_tmp_d_param_2,
+       .param .u64 backward_dgamma_tmp_d_param_3,
+       .param .u64 backward_dgamma_tmp_d_param_4,
+       .param .u32 backward_dgamma_tmp_d_param_5,
+       .param .u32 backward_dgamma_tmp_d_param_6,
+       .param .u32 backward_dgamma_tmp_d_param_7,
+       .param .u32 backward_dgamma_tmp_d_param_8,
+       .param .u32 backward_dgamma_tmp_d_param_9
+)
+{
+       .reg .pred      %p<4>;
+       .reg .b32       %r<11>;
+       .reg .f64       %fd<8>;
+       .reg .b64       %rd<18>;
+
+
+       ld.param.u64    %rd1, [backward_dgamma_tmp_d_param_0];
+       ld.param.u64    %rd2, [backward_dgamma_tmp_d_param_1];
+       ld.param.u64    %rd3, [backward_dgamma_tmp_d_param_2];
+       ld.param.u64    %rd4, [backward_dgamma_tmp_d_param_3];
+       ld.param.u64    %rd5, [backward_dgamma_tmp_d_param_4];
+       ld.param.u32    %r4, [backward_dgamma_tmp_d_param_5];
+       ld.param.u32    %r2, [backward_dgamma_tmp_d_param_7];
+       ld.param.u32    %r3, [backward_dgamma_tmp_d_param_8];
+       mov.u32         %r5, %ctaid.x;
+       mov.u32         %r6, %ntid.x;
+       mov.u32         %r7, %tid.x;
+       mad.lo.s32      %r1, %r6, %r5, %r7;
+       div.s32         %r8, %r1, %r3;
+       setp.lt.s32     %p1, %r8, %r4;
+       setp.gt.s32     %p2, %r3, -1;
+       and.pred        %p3, %p1, %p2;
+       @!%p3 bra       BB124_2;
+       bra.uni         BB124_1;
+
+BB124_1:
+       rem.s32         %r9, %r1, %r3;
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       cvta.to.global.u64      %rd9, %rd3;
+       add.s64         %rd10, %rd9, %rd7;
+       div.s32         %r10, %r9, %r2;
+       cvta.to.global.u64      %rd11, %rd1;
+       mul.wide.s32    %rd12, %r10, 8;
+       add.s64         %rd13, %rd11, %rd12;
+       ld.global.f64   %fd1, [%rd13];
+       ld.global.f64   %fd2, [%rd10];
+       sub.f64         %fd3, %fd2, %fd1;
+       cvta.to.global.u64      %rd14, %rd4;
+       add.s64         %rd15, %rd14, %rd12;
+       ld.global.f64   %fd4, [%rd15];
+       mul.f64         %fd5, %fd3, %fd4;
+       ld.global.f64   %fd6, [%rd8];
+       mul.f64         %fd7, %fd6, %fd5;
+       cvta.to.global.u64      %rd16, %rd5;
+       add.s64         %rd17, %rd16, %rd7;
+       st.global.f64   [%rd17], %fd7;
+
+BB124_2:
+       ret;
+}
+
+       // .globl       backward_dgamma_tmp_f
+.visible .entry backward_dgamma_tmp_f(
+       .param .u64 backward_dgamma_tmp_f_param_0,
+       .param .u64 backward_dgamma_tmp_f_param_1,
+       .param .u64 backward_dgamma_tmp_f_param_2,
+       .param .u64 backward_dgamma_tmp_f_param_3,
+       .param .u64 backward_dgamma_tmp_f_param_4,
+       .param .u32 backward_dgamma_tmp_f_param_5,
+       .param .u32 backward_dgamma_tmp_f_param_6,
+       .param .u32 backward_dgamma_tmp_f_param_7,
+       .param .u32 backward_dgamma_tmp_f_param_8,
+       .param .u32 backward_dgamma_tmp_f_param_9
+)
+{
+       .reg .pred      %p<4>;
+       .reg .b32       %r<11>;
+       .reg .f64       %fd<8>;
+       .reg .b64       %rd<18>;
+
+
+       ld.param.u64    %rd1, [backward_dgamma_tmp_f_param_0];
+       ld.param.u64    %rd2, [backward_dgamma_tmp_f_param_1];
+       ld.param.u64    %rd3, [backward_dgamma_tmp_f_param_2];
+       ld.param.u64    %rd4, [backward_dgamma_tmp_f_param_3];
+       ld.param.u64    %rd5, [backward_dgamma_tmp_f_param_4];
+       ld.param.u32    %r4, [backward_dgamma_tmp_f_param_5];
+       ld.param.u32    %r2, [backward_dgamma_tmp_f_param_7];
+       ld.param.u32    %r3, [backward_dgamma_tmp_f_param_8];
+       mov.u32         %r5, %ctaid.x;
+       mov.u32         %r6, %ntid.x;
+       mov.u32         %r7, %tid.x;
+       mad.lo.s32      %r1, %r6, %r5, %r7;
+       div.s32         %r8, %r1, %r3;
+       setp.lt.s32     %p1, %r8, %r4;
+       setp.gt.s32     %p2, %r3, -1;
+       and.pred        %p3, %p1, %p2;
+       @!%p3 bra       BB125_2;
+       bra.uni         BB125_1;
+
+BB125_1:
+       rem.s32         %r9, %r1, %r3;
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       cvta.to.global.u64      %rd9, %rd3;
+       add.s64         %rd10, %rd9, %rd7;
+       div.s32         %r10, %r9, %r2;
+       cvta.to.global.u64      %rd11, %rd1;
+       mul.wide.s32    %rd12, %r10, 8;
+       add.s64         %rd13, %rd11, %rd12;
+       ld.global.f64   %fd1, [%rd13];
+       ld.global.f64   %fd2, [%rd10];
+       sub.f64         %fd3, %fd2, %fd1;
+       cvta.to.global.u64      %rd14, %rd4;
+       add.s64         %rd15, %rd14, %rd12;
+       ld.global.f64   %fd4, [%rd15];
+       mul.f64         %fd5, %fd3, %fd4;
+       ld.global.f64   %fd6, [%rd8];
+       mul.f64         %fd7, %fd6, %fd5;
+       cvta.to.global.u64      %rd16, %rd5;
+       add.s64         %rd17, %rd16, %rd7;
+       st.global.f64   [%rd17], %fd7;
+
+BB125_2:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
        .param .b64 __internal_trig_reduction_slowpathd_param_0,
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot124[40];
+       .local .align 8 .b8     __local_depot126[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -15098,7 +15232,7 @@ BB123_2:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot124;
+       mov.u64         %rd101, __local_depot126;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -15112,7 +15246,7 @@ BB123_2:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB124_13;
+       @%p1 bra        BB126_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -15125,7 +15259,7 @@ BB123_2:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB124_4;
+       @%p2 bra        BB126_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -15142,7 +15276,7 @@ BB123_2:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB124_3:
+BB126_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -15172,15 +15306,15 @@ BB124_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB124_3;
+       @%p3 bra        BB126_3;
 
-BB124_4:
+BB126_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB124_6;
+       @%p4 bra        BB126_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -15192,7 +15326,7 @@ BB124_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB124_6:
+BB126_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -15209,7 +15343,7 @@ BB124_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB124_8;
+       @%p6 bra        BB126_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -15229,10 +15363,10 @@ BB124_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB124_8:
+BB126_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB124_10;
+       @%p7 bra        BB126_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -15240,7 +15374,7 @@ BB124_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB124_10:
+BB126_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -15261,7 +15395,7 @@ BB124_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB124_12;
+       @%p8 bra        BB126_12;
 
        // inline asm
        {
@@ -15280,7 +15414,7 @@ BB124_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB124_12:
+BB126_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -15295,7 +15429,7 @@ BB124_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB124_13:
+BB126_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -15323,7 +15457,7 @@ BB124_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB125_2;
+       @%p1 bra        BB127_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -15337,13 +15471,13 @@ BB124_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB125_2:
+BB127_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB125_4;
+       @%p2 bra        BB127_4;
 
        {
        .reg .b32 %temp; 
@@ -15357,7 +15491,7 @@ BB125_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB125_4:
+BB127_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -15520,13 +15654,13 @@ BB125_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB125_7;
+       @%p4 bra        BB127_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB125_7;
+       @%p6 bra        BB127_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -15548,26 +15682,26 @@ BB125_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB125_7:
+BB127_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB125_9;
+       @%p7 bra        BB127_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB125_10;
+       @%p8 bra        BB127_10;
 
-BB125_9:
+BB127_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB125_10:
+BB127_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index c4ce466..7cf5061 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -141,6 +141,7 @@ public class DnnOp extends MultiThreadedHop
                        case UPDATE_EMA:
                        case INV_VAR:
                        case BATCH_NORM2D_BACKWARD_DX:
+                       case BATCH_NORM2D_BACKWARD_DGAMMA:
                        {       
                                // GPU-specific operators
                                setLops(constructDnnLops(ExecType.GPU, inputs));
@@ -181,6 +182,7 @@ public class DnnOp extends MultiThreadedHop
                        case CHANNEL_SUMS:
                        case UPDATE_EMA:
                                return 3;
+                       case BATCH_NORM2D_BACKWARD_DGAMMA:
                        case UPDATE_NESTEROV_X:
                                return 4;
                        default:
@@ -538,7 +540,7 @@ public class DnnOp extends MultiThreadedHop
                
                if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
                        op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
                        // Same dimension as the first input
                        MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
                        ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -755,7 +757,7 @@ public class DnnOp extends MultiThreadedHop
        {
                if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || 
                        op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
                        // Same dimension as the first input
                        Hop input1 = getInput().get(0);
                        setDim1(input1.getDim1());
@@ -873,7 +875,7 @@ public class DnnOp extends MultiThreadedHop
                if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
                        op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
                        op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
                        throw new RuntimeException("getDim method should not be invoked for " + op.name());
                }
                try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index c8356e0..82a6669 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1101,7 +1101,7 @@ public abstract class Hop implements ParseInfo
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
                BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
                UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
-               BATCH_NORM2D_BACKWARD_DX
+               BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
        }
        
        public enum DataGenMethod {
@@ -1182,6 +1182,7 @@ public abstract class Hop implements ParseInfo
                HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
                HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
                HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
+               HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA);
        }
 
        protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 577adc3..ab40d7b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -170,6 +170,25 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
                return hi;
        };
        
+       // Avoids unnecessary intermediates:
+       // mean = cache_mean
+       // centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+       // norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+       // # Compute gradients during training
+       // dgamma = util::channel_sums(dout*norm, C, Hin, Win)
+       private static final HopDagPatternMatcher _batchNormDGamma;
+       static {
+               _batchNormDGamma = util_channel_sums(
+                               mult(   leaf("dout", MATRIX).fitsOnGPU(3),
+                                               bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), 
+                               leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
+       }
+       private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
+               LOG.debug("Applied batchNormDGamma rewrite.");
+               Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
+                               "ema_mean", "dout", "X", "ema_var");
+               return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+       };
                
        // Pattern 3:
        private static final HopDagPatternMatcher _batchNormTest;
@@ -282,8 +301,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
                if(_rewriters == null) {
                        ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
                        rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer));
-                       rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
                        rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer));
+                       rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
+                       // rewriters.add(new HopPatternRewriter("batchNormDGamma", _batchNormDGamma, _batchNormDGammaReplacer));
                        rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer));
                        rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer));
                        rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer));

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 2d2d5f1..3496b5b 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -33,7 +33,7 @@ public class DnnTransform extends Lop
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
                BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, 
                UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
-               BATCH_NORM2D_BACKWARD_DX
+               BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
        }
        
        private OperationTypes operation;
@@ -174,6 +174,9 @@ public class DnnTransform extends Lop
                case UPDATE_NESTEROV_X:
                        return "update_nesterov_x";
                        
+               case BATCH_NORM2D_BACKWARD_DGAMMA:
+                       return "batch_norm2d_bwd_dgamma";
+                       
                case BATCH_NORM2D_TEST:
                        return "batch_norm2d_test";
                
@@ -254,7 +257,7 @@ public class DnnTransform extends Lop
        
        @Override
        public String getInstructions(String input1, String input2, String input3, String input4, String output) {
-               if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+               if(operation == OperationTypes.UPDATE_NESTEROV_X || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA) {
                        StringBuilder sb = new StringBuilder();
                        sb.append( getExecType() );
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 3480504..c8a0e8d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -66,6 +66,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "reshape_colmeans",      GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "inv_var",                       GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_bwd_dx",   GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "batch_norm2d_bwd_dgamma",   GPUINSTRUCTION_TYPE.Dnn);
                
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 6094b6c..4ad4155 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -127,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
        public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, 
                        double intermediateMemoryBudget) throws DMLRuntimeException {
                super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-               if( !( opcode.equals("update_nesterov_x")) ) {
+               if( !( opcode.equals("update_nesterov_x") || opcode.equals("batch_norm2d_bwd_dgamma")) ) {
                        throw new DMLRuntimeException("Incorrect opcode: " + opcode);
                }
                _input1 = in1;
@@ -339,6 +339,15 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out = new CPOperand(parts[5]);
                        return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
                }
+               else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+                       InstructionUtils.checkNumFields(parts, 5);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand in4 = new CPOperand(parts[4]);
+                       CPOperand out = new CPOperand(parts[5]);
+                       return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
+               }
                else if (opcode.equalsIgnoreCase("lstm")) {
                        InstructionUtils.checkNumFields(parts, 8);
                        CPOperand in1 = new CPOperand(parts[1]);
@@ -586,6 +595,42 @@ public class DnnGPUInstruction extends GPUInstruction {
                }
        }
        
+       // Inputs (in order): ema_mean, dout, X, ema_var
+       private void processBatchNorm2dBackwardDGammaInstruction(ExecutionContext ec) {
+               try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+                       fetcher.add("ema_mean", _input1).add("dout", _input2).add("X", _input3)
+                       .add("ema_var", _input4);
+                       MatrixObject ema_mean = fetcher.getInputMatrixObject("ema_mean");
+                       MatrixObject dout = fetcher.getInputMatrixObject("dout");
+                       long C = ema_mean.getNumRows();
+                       long N = dout.getNumRows();
+                       long CHW = dout.getNumColumns();
+                       fetcher.validateDimensions("ema_mean", C, 1);
+                       fetcher.validateDimensions("dout", N, CHW);
+                       fetcher.validateDimensions("X", N, CHW);
+                       fetcher.validateDimensions("ema_var", C, 1);
+                       if(CHW % C != 0) {
+                               throw new DMLRuntimeException("Incorrect dimensions: C=" + C + ", CHW=" + CHW);
+                       }
+                       long HW = CHW / C;
+                       // temporary (N x CHW) buffer holding dout * ((X - ema_mean) * ema_var)
+                       Pointer tmp = gCtx.allocate(instName, N*CHW*LibMatrixCUDA.sizeOfDataType);
+                       LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("backward_dgamma_tmp", 
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(N*CHW)),
+                                       fetcher.getInputPointer("ema_mean"), 
+                                       fetcher.getInputPointer("dout"),
+                                       fetcher.getInputPointer("X"),
+                                       fetcher.getInputPointer("ema_var"),
+                                       tmp,
+                                       // N, C, HW, CHW, NCHW
+                                       toInt(N), toInt(C), toInt(HW), toInt(CHW), N*CHW);
+                       
+                       // reduce the temporary per channel to obtain the (C, 1) dgamma output
+                       LibMatrixCUDA.channelSums(gCtx, instName, 
+                                       tmp, fetcher.getOutputPointer(C, 1), N, C, HW);
+                       gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+               }
+       }
+       
        private static int toInt(long num) throws DMLRuntimeException {
                if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
                        throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
@@ -734,6 +779,10 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processNesterovUpdateInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+                       processBatchNorm2dBackwardDGammaInstruction(ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
                        processUpdateEMAVarInstruction(ec);
                        return;

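The instruction thus runs in two GPU steps; a compact host-side restatement
of the flow above (launchBackwardDGammaTmpD and channelSums are hypothetical
stand-ins for the Java calls; error handling omitted):

  // inputs: ema_mean (C x 1), ema_var (C x 1), dout (N x CHW), X (N x CHW)
  double *tmp;
  cudaMalloc(&tmp, (size_t)N * CHW * sizeof(double)); // intermediate (N x CHW)
  // step 1: fused elementwise term, one thread per element
  launchBackwardDGammaTmpD(ema_mean, dout, X, ema_var, tmp, N, C, HW, CHW);
  // step 2: per-channel reduction to the (C x 1) dgamma output
  channelSums(tmp, out, N, C, HW);
  cudaFree(tmp);
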
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
index 1ab3420..06bd1df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -94,10 +94,10 @@ public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
        public void validateDimensions(String var, long numRows, long numCols) {
                MatrixObject mo = getInputMatrixObject(var);
                if(numRows > 0 && mo.getNumRows() != numRows) {
-                       throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows());
+                       throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows());
                }
                else if(numCols > 0 && mo.getNumColumns() != numCols) {
-                       throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns());
+                       throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns());
                }
        }
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 46ab3f7..00aa578 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -362,10 +362,25 @@ public class LibMatrixCUDA {
                }
                Pointer imagePointer = getDensePointer(gCtx, input, instName);
                Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName);
-               
+               channelSums(gCtx, instName, imagePointer, outputPointer, N, C, HW);
+       }
+       
+       /**
+        * Performs the channel_sums operation: out = rowSums(matrix(colSums(A), rows=C, cols=HW))
+        * 
+        * @param gCtx a valid {@link GPUContext}
+        * @param instName the invoking instruction's name for record {@link Statistics}.
+        * @param imagePointer  input image pointer
+        * @param outputPointer output pointer
+        * @param N number of rows
+        * @param C number of channels
+        * @param HW height*width
+        */
+       public static void channelSums(GPUContext gCtx, String instName, Pointer imagePointer, Pointer outputPointer, long N, long C, long HW) {
+               int cols = toInt(C*HW);
                // We can replace this with CuDNN tensor reduce
                Pointer tmp = gCtx.allocate(instName, cols*sizeOfDataType);
-               reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols);
+               reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, toInt(N), cols);
                reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW));
                gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
        }

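The refactored channelSums keeps the existing two-kernel reduction (colSums
to a C*HW row, then rowSums of its C x HW reshape). A reference loop with
the same semantics (illustrative only, assuming a dense row-major N x C*HW
input):

  // out[c] = sum over n in [0,N) and k in [0,HW) of image[n][c*HW + k]
  for (int c = 0; c < C; c++) {
    double acc = 0.0;
    for (int n = 0; n < N; n++)
      for (int k = 0; k < HW; k++)
        acc += image[(size_t)n * C * HW + (size_t)c * HW + k];
    out[c] = acc;
  }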