szha commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r274142661
##########
File path: src/operator/rnn-inl.h
##########

@@ -438,385 +565,891 @@ class RNNOp : public Operator{
     }
     DType* cx_ptr = NULL;
     DType* cy_ptr = NULL;
-
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
-
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+            mshadow::Shape1(temp_size), s);
+
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
-      }
-
-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
-      }
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+    #endif

-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
+    if (ctx.is_train) {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+      #endif
     } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space.dptr_,
+                                          workspace_byte_));
+      #endif
+    }
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+          .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+      DType* work_cpu_space = workspace.dptr_;
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
+      }
     }
   }

-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+    CHECK_EQ(dx.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];
     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
     DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }
-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;
     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
-
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
-
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(

Review comment:
   Actually it won't work as the temp space resource would have been released. Since it doesn't change in size between forward and backward, it would be great to avoid repeated allocation for this.
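For illustration only (this is not code from the PR): a minimal standalone C++ sketch of the pattern the comment asks for, where a scratch buffer whose size is fixed after the first forward pass is cached and reused by backward instead of being requested again. The class and member names (CachedWorkspace, alloc_count_) are hypothetical; in the stateful RNNOp the analogous change would be to keep the workspace in a member handle, the way reserve_cpu_space_ is already cached in this diff, rather than calling ctx.requested[rnn_enum::kTempSpace] in both Forward and Backward.

    // Standalone sketch: allocate the workspace once, reuse it across
    // forward and backward; re-allocate only if a larger size is requested.
    #include <cstddef>
    #include <cstdlib>
    #include <iostream>

    class CachedWorkspace {
     public:
      // Returns a buffer of at least `bytes`; allocates only when the cached
      // buffer is missing or too small.
      void* Get(std::size_t bytes) {
        if (ptr_ == nullptr || capacity_ < bytes) {
          std::free(ptr_);
          ptr_ = std::malloc(bytes);
          capacity_ = bytes;
          ++alloc_count_;
        }
        return ptr_;
      }
      std::size_t alloc_count() const { return alloc_count_; }
      ~CachedWorkspace() { std::free(ptr_); }

     private:
      void* ptr_ = nullptr;
      std::size_t capacity_ = 0;
      std::size_t alloc_count_ = 0;
    };

    int main() {
      CachedWorkspace ws;
      const std::size_t workspace_bytes = 1 << 20;  // fixed between forward and backward
      void* fwd = ws.Get(workspace_bytes);          // forward: the only allocation
      void* bwd = ws.Get(workspace_bytes);          // backward: reuses the same buffer
      std::cout << "same buffer: " << (fwd == bwd)
                << ", allocations: " << ws.alloc_count() << "\n";
      return 0;
    }

The trade-off of such caching is that the memory is then owned by the operator instance for its lifetime rather than coming from the shared temp-space pool.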
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services