szha commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r274142661
 
 

 ##########
 File path: src/operator/rnn-inl.h
 ##########
 @@ -438,385 +565,891 @@ class RNNOp : public Operator{
     }
     DType* cx_ptr = NULL;
     DType* cy_ptr = NULL;
-
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
-
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+                              mshadow::Shape1(temp_size), s);
+
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+      (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
-      }
-
-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
-      }
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+    #endif
 
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
+    if (ctx.is_train) {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+      #endif
     } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space.dptr_,
+                                          workspace_byte_));
+      #endif
+    }
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+      DType* work_cpu_space = workspace.dptr_;
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
+      }
     }
   }
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
 
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+    CHECK_EQ(dx.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];
 
     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
 
     DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }
 
-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;
 
     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+        dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
 
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, 
param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
 
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
-
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
 
 Review comment:
   Actually, it won't work, because the temp space resource will already have been released by then. Since the workspace doesn't change in size between forward and backward, it would be great to avoid allocating it repeatedly.
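   For illustration only, here is a minimal sketch (not part of this PR) of one way the CPU workspace could be cached across Forward and Backward, mirroring the reserve_cpu_space_ pattern already used in this diff. The member names temp_init_space_, temp_cpu_space_size_ and temp_cpu_space_ are hypothetical additions to RNNOp.

   ```cpp
   // Hypothetical RNNOp member helper: allocate the CPU workspace once and
   // reuse it, the same way reserve_cpu_space_ is handled above. Assumes the
   // members temp_init_space_ (bool), temp_cpu_space_size_ (size_t) and
   // temp_cpu_space_ (Storage::Handle) exist; they are not in this PR.
   DType* GetCachedWorkspace(size_t work_space_size) {
     if (temp_init_space_ && temp_cpu_space_size_ < work_space_size) {
       Storage::Get()->Free(temp_cpu_space_);  // requested size grew: drop the old block
       temp_init_space_ = false;
     }
     if (!temp_init_space_) {
       temp_cpu_space_ = Storage::Get()->Alloc(work_space_size * sizeof(DType),
                                               Context::CPU());
       temp_cpu_space_size_ = work_space_size;
       temp_init_space_ = true;
     }
     return static_cast<DType*>(temp_cpu_space_.dptr);
   }
   ```

   Forward and Backward could then both call such a helper instead of requesting kTempSpace separately in each pass.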

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
