lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r269112331
##########
File path: src/operator/rnn-inl.h
##########
@@ -436,387 +566,897 @@ class RNNOp : public Operator{
     if (param_.state_outputs) {
       hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
     }
-    DType* cx_ptr = NULL;
-    DType* cy_ptr = NULL;
+    DType * cx_ptr = NULL;
+    DType * cy_ptr = NULL;
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
-    }
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);

     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
                                                       param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+    DType* work_cpu_space = NULL;
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(s, in_data, out_data);
+    }
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+            mshadow::Shape1(temp_size + work_cpu_space_size), s);
+
+    work_cpu_space = temp_space.dptr_ + temp_size;
+
+#if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+#endif
+
+#if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+#endif

     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
-      }
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+#else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+#endif
+    } else {
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+#else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space.dptr_,
+                                          workspace_byte_));
+#endif
+    }
+#endif

-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
+    if (ctx_.dev_type == kCPU) {
+      if (!work_cpu_space) {
+        Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+            .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+        work_cpu_space = workspace.dptr_;
+      }
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
+      }
-
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
-    } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
     }
   }

-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
        << "unsupported dropout value, should be 0 <= dropout < 1";

-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    // kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];

     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;

     DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }

-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;

     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;

     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+    const size_t work_cpu_space_size =
+        GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                            param_.state_size, direction, param_.mode);
+    DType* work_cpu_space = NULL;
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(s, in_data, out_data);
     }
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+            mshadow::Shape1(temp_size + work_cpu_space_size), s);
+    work_cpu_space = temp_space.dptr_ + temp_size;
+#if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space.dptr_,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+#else
+    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+                                    rnn_desc_,
+                                    param_.seq_length_,
+                                    y_desc_vec_.data(),
+                                    y.dptr_,
+                                    dy_desc_vec_.data(),
+                                    dy.dptr_,
+                                    dhy_desc_,
+                                    dhy_ptr,
+                                    dcy_desc_,
+                                    dcy_ptr,
+                                    w_desc_,
+                                    w.dptr_,
+                                    hx_desc_,
+                                    hx.dptr_,
+                                    cx_desc_,
+                                    cx_ptr,
+                                    dx_desc_vec_.data(),
+                                    dx.dptr_,
+                                    dhx_desc_,
+                                    dhx.dptr_,
+                                    dcx_desc_,
+                                    dcx_ptr,
+                                    temp_space.dptr_,
+                                    workspace_byte_,
+                                    reserve_space_.dptr,
+                                    reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+                                       rnn_desc_,
+                                       param_.seq_length_,
+                                       x_desc_vec_.data(),
+                                       x.dptr_,
+                                       hx_desc_,
+                                       hx.dptr_,
+                                       y_desc_vec_.data(),
+                                       y.dptr_,
+                                       temp_space.dptr_,
+                                       workspace_byte_,
+                                       dw_desc_,
+                                       dw.dptr_,
+                                       reserve_space_.dptr,
+                                       reserve_space_byte_));
+#endif
+#endif
+
+    if (ctx_.dev_type == kCPU) {
+      if (!work_cpu_space) {
+        Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+            .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+        work_cpu_space = workspace.dptr_;
+      }
+      size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                             param_.seq_length_, param_.batch_size_,
+                                             param_.state_size, param_.mode);

-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+      if (!init_space_ || reserve_cpu_space_size_ != r_size) {
+        LOG(FATAL) << "Check forward init error";
+      }

-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    if (param_.mode == rnn_enum::kLstm) {
-      return {"data", "parameters", "state", "state_cell"};
-    } else {
-      return {"data", "parameters", "state"};
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+      RNNBackward<DType>(work_cpu_space,
+                         reserve_space_ptr,
+                         param_.num_layers,
+                         direction,
+                         param_.seq_length_,
+                         param_.batch_size_,
+                         param_.input_size_,
+                         param_.state_size,
+                         x.dptr_,
+                         hx.dptr_,
+                         cx_ptr,
+                         w.dptr_,
+                         y.dptr_,
+                         dy.dptr_,
+                         dhy_ptr,
+                         dcy_ptr,
+                         dx.dptr_,
+                         dhx.dptr_,
+                         dcx_ptr,
+                         dw.dptr_,
+                         db_ptr,
+                         req[rnn_enum::kData],
+                         req[rnn_enum::kParams],
+                         req[rnn_enum::kState],
+                         // State cell should be present for LSTMs, but is absent for other RNNs.
+                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
+                         param_.p,
+                         param_.mode);
+    }
   }

-  std::vector<std::string> ListOutputs() const override {
-    std::vector<std::string> outputs = {"output"};
-    if (!param_.state_outputs)
-      return outputs;
-    else
-      outputs.emplace_back("state");
-    if (param_.mode == rnn_enum::kLstm)
-      outputs.emplace_back("state_cell");
-    return outputs;
-  }
-
-  int NumOutputs() const override {
-    int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
-    int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
-    return num_outputs;
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
+ private:
+  inline void Init(mshadow::Stream<xpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
     using namespace mshadow;
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
-    } else {
-      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
-    }
-    const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
-    if (dshape.ndim() == 0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
-    // data: [sequence len, batch, input dimension]
-    int batch_size = dshape[1];
-    int input_size = dshape[2];
-    int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
-    int layer_size = (param_.projection_size.has_value()) ?
-                     param_.projection_size.value() : param_.state_size;
-    SHAPE_ASSIGN_CHECK(*in_shape,
-                       rnn_enum::kState,
-                       Shape3(total_layers, batch_size, layer_size));
-    if (param_.mode == rnn_enum::kLstm)
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                         rnn_enum::kStateCell,
-                         Shape3(total_layers, batch_size, param_.state_size));
-
-    // calculate parameter vector length
-    int param_size = GetRnnParamSize(param_.num_layers,
-                                     input_size,
-                                     param_.state_size,
-                                     numDirections,
-                                     param_.mode,
-                                     param_.projection_size);
-    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
-    out_shape->clear();
-    // output: [sequence len, batch, output size]
-    mxnet::TShape oshape = dshape;
-    if (param_.projection_size.has_value()) {
-      oshape[2] = numDirections * param_.projection_size.value();
-    } else {
-      oshape[2] = numDirections * param_.state_size;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    // kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
+    }
-    out_shape->push_back(oshape);
-    if (!param_.state_outputs) {
-      return true;
-    } else {
-      // outStateShape: [layer_num, batch, state size]
-      mxnet::TShape outStateShape = dshape;
-      outStateShape[0] = total_layers;
-      outStateShape[1] = batch_size;
-      if (param_.projection_size.has_value()) {
-        outStateShape[2] = param_.projection_size.value();
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+#if CUDNN_MAJOR >= 5
+    format_ = CUDNN_TENSOR_NCHW;
+#endif
+
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      // get input + output tensors
+      Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+      Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+      param_.seq_length_ = x.shape_[0];
+      param_.batch_size_ = x.shape_[1];
+      param_.input_size_ = x.shape_[2];
+
+      // Tensor Descriptors
+      std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+      int dimA[3];
+      int strideA[3];
+      for (int i = 0; i < param_.seq_length_; i++) {
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.input_size_;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+      }
+      x_desc_vec_ = x_vec;
+      y_desc_vec_ = y_vec;
+      dx_desc_vec_ = dx_vec;
+      dy_desc_vec_ = dy_vec;
+
+      // set the state tensors
+      dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimA[1] = param_.batch_size_;
+      dimA[2] = param_.state_size;
+      strideA[0] = dimA[2] * dimA[1];
+      strideA[1] = dimA[2];
+      strideA[2] = 1;
+#if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+#endif
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+
+      // Create Dropout descriptors
+      if (param_.p > 0) {
+        CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
+        dropout_size_ = dropout_byte_ / sizeof(DType);
+        dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));

Review comment:
   using ctx.requested to fix it
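
For context: the flagged line allocates the cuDNN dropout state buffer directly with Storage::Get()->Alloc, so the operator itself must track and free that memory. Below is a minimal sketch of the direction the comment suggests, with the buffer obtained through the operator's requested resources instead. This is an illustration only, not the actual patch: it assumes the OpContext is threaded through to Init() (the signature above only receives the stream and blobs), reuses members already present in this file (dropout_desc_, dropout_byte_, dropout_size_, seed_), and uses the generic rnn_enum::kTempSpace request; the resource kind ultimately adopted may differ.

    // Sketch: take the dropout state buffer from the per-op resource request
    // instead of Storage::Get()->Alloc(), so the engine owns the memory and
    // the destructor has nothing to free.
    if (param_.p > 0) {
      CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
      dropout_size_ = dropout_byte_ / sizeof(DType);
      // Assumes `ctx` (the OpContext) is available at this point in Init().
      mshadow::Tensor<xpu, 1, DType> dropout_states =
          ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
              mshadow::Shape1(dropout_size_), s);
      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_,
                                           s->dnn_handle_,
                                           param_.p,
                                           dropout_states.dptr_,
                                           dropout_byte_,
                                           seed_));
    }

One caveat with plain temp space: cuDNN requires the states buffer handed to cudnnSetDropoutDescriptor to remain valid and unmodified for every later call that uses the descriptor, whereas generic temp space can be recycled between operators. A dedicated requested resource that persists for the lifetime of the op is therefore the safer form of the same idea.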