ptrendx commented on a change in pull request #13896: Cudnn dropout
URL: https://github.com/apache/incubator-mxnet/pull/13896#discussion_r253651093
 
 

 ##########
 File path: src/operator/nn/dropout-inl.h
 ##########
 @@ -227,52 +203,181 @@ class DropoutOp {
     }
   };
 
-  void Init(const DropoutParam &param) {
+  explicit DropoutOp(const DropoutParam &param, Context ctx) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
     this->axes_ = param.axes;
+    this->dropout_passthrough_ = true;
+#if MXNET_USE_CUDNN_DROPOUT
+    this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value();
+    this->ctx_ = ctx;
+    if (ctx.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) {
+      dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_));
+      CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+    }
+#endif  // MXNET_USE_CUDNN_DROPOUT
+  }
+
+  ~DropoutOp() {
+#if MXNET_USE_CUDNN_DROPOUT
+    if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) {
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_));
+      CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
+    }
+#endif  // MXNET_USE_CUDNN_DROPOUT
   }
 
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+  inline bool CuDNNAvailable() {
+    return this->pkeep_ > 0 && !this->cudnn_off_;
+  }
+
+  inline void CuDNNForward(const OpContext &ctx,
+                           const TBlob &in,
+                           const TBlob &mask,
+                           const TBlob &out) {
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+
+      // set dropout state.
+      ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - 
this->pkeep_, seed_);
+
+      // describe input/output tensor
+      int dim[4], stride[4];
+      dim[0] = 1;
+      dim[1] = 1;
+      dim[2] = 1;
+      dim[3] = out.Size();
+      stride[0] = out.Size();
+      stride[1] = out.Size();
+      stride[2] = out.Size();
+      stride[3] = 1;
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(x_desc_,
+                                            dtype_,
+                                            4,
+                                            dim,
+                                            stride));
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(y_desc_,
+                                            dtype_,
+                                            4,
+                                            dim,
+                                            stride));
+
+      // perform dropout with cudnn
+      CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, 
&dropout_reserve_byte_));
+      // cudnn uses bits to record the positions that are dropped, so reserve 
bytes is always
+      // 1/8 of input size.
+      CHECK_GE(mask.Size() * sizeof(DType), dropout_reserve_byte_) <<
+        "The size of the mask space is smaller than the required cudnn 
reserved space.";
+      CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_,
+                                     dropout_desc_,
+                                     x_desc_,
+                                     in.dptr<DType>(),
+                                     y_desc_,
+                                     out.dptr<DType>(),
+                                     mask.dptr<DType>(),
+                                     dropout_reserve_byte_));
+  }
+
+  inline void CuDNNBackward(const OpContext &ctx,
+                            const TBlob &out_grad,
+                            const TBlob &mask,
+                            const TBlob &in_grad) {
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+
+      // describe input/output tensor
+      int dim[4], stride[4];
+      dim[0] = 1;
+      dim[1] = 1;
+      dim[2] = 1;
+      dim[3] = in_grad.Size();
+      stride[0] = in_grad.Size();
+      stride[1] = in_grad.Size();
+      stride[2] = in_grad.Size();
+      stride[3] = 1;
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_desc_,
+                                            dtype_,
+                                            4,
+                                            dim,
+                                            stride));
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_desc_,
+                                            dtype_,
+                                            4,
+                                            dim,
+                                            stride));
+
+      // perform dropout with cudnn
+      CUDNN_CALL(cudnnDropoutBackward(s->dnn_handle_,
+                                      dropout_desc_,
+                                      dy_desc_,
+                                      out_grad.dptr<DType>(),
+                                      dx_desc_,
+                                      in_grad.dptr<DType>(),
+                                      mask.dptr<DType>(),
+                                      dropout_reserve_byte_));
+  }
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+
   void Forward(const OpContext &ctx,
                const std::vector<TBlob> &in_data,
                const std::vector<OpReqType> &req,
                const std::vector<TBlob> &out_data) {
+    this->dropout_passthrough_ = true;
     if (req[dropout::kOut] != kNullOp) {
       CHECK_EQ(in_data.size(), 1U);
       if (ctx.is_train) {
         CHECK_EQ(out_data.size(), 2U);
       }
       Stream<xpu> *s = ctx.get_stream<xpu>();
+      const TBlob &in = in_data[dropout::kData];
       const TBlob &out = out_data[dropout::kOut];
-      if (ctx.is_train || this->mode_ == dropout::kAlways) {
-        RandGenerator<xpu, DType> *pgen = 
ctx.requested[0].get_parallel_random<xpu, DType>();
-        CHECK_NOTNULL(pgen);
-        if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, 
in_data, out_data)) {
-          const TBlob &mask = out_data[dropout::kMask];
+      const TBlob &mask = out_data[dropout::kMask];
+      if (this->pkeep_ < 1 && (ctx.is_train || this->mode_ == 
dropout::kAlways)) {
+        this->dropout_passthrough_ = false;
+        if (this->axes_.ndim() == 0) {
+#if MXNET_USE_MKL_DROPOUT
+          if (MKLAvailable()) {
+            MKLForward(ctx, in_data, out_data);
+            return;
+          }
+#endif  // MXNET_USE_MKL_DROPOUT
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+          if (CuDNNAvailable()) {
+            CuDNNForward(ctx, in, mask, out);
+            return;
+          }
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+          RandGenerator<xpu, DType> *pgen = 
ctx.requested[0].get_parallel_random<xpu, DType>();
+          CHECK_NOTNULL(pgen);
           CHECK(req[dropout::kOut] != kAddTo);
 
 Review comment:
   This might become problematic, since kAddTo should be handled automatically 
by the graph pass in the framework, so it's not something the user can really 
control.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to