eric-haibin-lin commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r274756641
##########
File path: src/operator/rnn-inl.h
##########

@@ -438,385 +572,891 @@ class RNNOp : public Operator{
     }
     DType* cx_ptr = NULL;
     DType* cy_ptr = NULL;
-
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+#if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+#endif
+
+#if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+#endif
     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space_.dptr,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+#else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space_.dptr,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+#endif
+    } else {
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space_.dptr,
+                                            workspace_byte_));
+#else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space_.dptr,
+                                          workspace_byte_));
+#endif
+    }
+#endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+        Storage::Get()->Free(temp_cpu_space_);
+        temp_init_space_ = false;
       }
-
-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
+      if (!temp_init_space_) {
+        temp_cpu_space_ = Storage::Get()->Alloc
+            (work_cpu_space_size * sizeof(DType), Context::CPU());
+        temp_cpu_space_size_ = work_cpu_space_size;
+        temp_init_space_ = true;
+      }
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
       }
-
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
-    } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
-    }
     }
   }

-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+    CHECK_EQ(dx.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];

     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;

     DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }

-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;

     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;

-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
+#if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space_.dptr,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space_.dptr,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+#else
+    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+                                    rnn_desc_,
+                                    param_.seq_length_,
+                                    y_desc_vec_.data(),
+                                    y.dptr_,
+                                    dy_desc_vec_.data(),
+                                    dy.dptr_,
+                                    dhy_desc_,
+                                    dhy_ptr,
+                                    dcy_desc_,
+                                    dcy_ptr,
+                                    w_desc_,
+                                    w.dptr_,
+                                    hx_desc_,
+                                    hx.dptr_,
+                                    cx_desc_,
+                                    cx_ptr,
+                                    dx_desc_vec_.data(),
+                                    dx.dptr_,
+                                    dhx_desc_,
+                                    dhx.dptr_,
+                                    dcx_desc_,
+                                    dcx_ptr,
+                                    temp_space_.dptr,
+                                    workspace_byte_,
+                                    reserve_space_.dptr,
+                                    reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+                                       rnn_desc_,
+                                       param_.seq_length_,
+                                       x_desc_vec_.data(),
+                                       x.dptr_,
+                                       hx_desc_,
+                                       hx.dptr_,
+                                       y_desc_vec_.data(),
+                                       y.dptr_,
+                                       temp_space_.dptr,
+                                       workspace_byte_,
+                                       dw_desc_,
+                                       dw.dptr_,
+                                       reserve_space_.dptr,
+                                       reserve_space_byte_));
+#endif
+#endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
+        LOG(FATAL) << "Check temp init error";
+      }
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+      size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                             param_.seq_length_, param_.batch_size_,
+                                             param_.state_size, param_.mode);

-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+      if (!init_space_ || reserve_cpu_space_size_ != r_size) {
+        LOG(FATAL) << "Check forward init error";
+      }

-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    if (param_.mode == rnn_enum::kLstm) {
-      return {"data", "parameters", "state", "state_cell"};
-    } else {
-      return {"data", "parameters", "state"};
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+      RNNBackward<DType>(work_cpu_space,
+                         reserve_space_ptr,
+                         param_.num_layers,
+                         direction,
+                         param_.seq_length_,
+                         param_.batch_size_,
+                         param_.input_size_,
+                         param_.state_size,
+                         x.dptr_,
+                         hx.dptr_,
+                         cx_ptr,
+                         w.dptr_,
+                         y.dptr_,
+                         dy.dptr_,
+                         dhy_ptr,
+                         dcy_ptr,
+                         dx.dptr_,
+                         dhx.dptr_,
+                         dcx_ptr,
+                         dw.dptr_,
+                         db_ptr,
+                         req[rnn_enum::kData],
+                         req[rnn_enum::kParams],
+                         req[rnn_enum::kState],
+                         // State cell should be present for LSTMs, but is absent for other RNNs.
+                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
+                         param_.p,
+                         param_.mode);
+    }
   }

-  std::vector<std::string> ListOutputs() const override {
-    std::vector<std::string> outputs = {"output"};
-    if (!param_.state_outputs)
-      return outputs;
-    else
-      outputs.emplace_back("state");
-    if (param_.mode == rnn_enum::kLstm)
-      outputs.emplace_back("state_cell");
-    return outputs;
-  }
-
-  int NumOutputs() const override {
-    int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
-    int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
-    return num_outputs;
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
+ private:
+  inline void Init(const OpContext &ctx,
+                   mshadow::Stream<xpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
     using namespace mshadow;
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
-    } else {
-      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
-    }
-    const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
-    if (dshape.ndim() == 0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
-    // data: [sequence len, batch, input dimension]
-    int batch_size = dshape[1];
-    int input_size = dshape[2];
-    int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
-    int layer_size = (param_.projection_size.has_value()) ?
-        param_.projection_size.value() : param_.state_size;
-    SHAPE_ASSIGN_CHECK(*in_shape,
-                       rnn_enum::kState,
-                       Shape3(total_layers, batch_size, layer_size));
-    if (param_.mode == rnn_enum::kLstm)
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                         rnn_enum::kStateCell,
-                         Shape3(total_layers, batch_size, param_.state_size));
-
-    // calculate parameter vector length
-    int param_size = GetRnnParamSize(param_.num_layers,
-                                     input_size,
-                                     param_.state_size,
-                                     numDirections,
-                                     param_.mode,
-                                     param_.projection_size);
-    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
-    out_shape->clear();
-    // output: [sequence len, batch, output size]
-    mxnet::TShape oshape = dshape;
-    if (param_.projection_size.has_value()) {
-      oshape[2] = numDirections * param_.projection_size.value();
-    } else {
-      oshape[2] = numDirections * param_.state_size;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    out_shape->push_back(oshape);
-    if (!param_.state_outputs) {
-      return true;
-    } else {
-      // outStateShape: [layer_num, batch, state size]
-      mxnet::TShape outStateShape = dshape;
-      outStateShape[0] = total_layers;
-      outStateShape[1] = batch_size;
-      if (param_.projection_size.has_value()) {
-        outStateShape[2] = param_.projection_size.value();
-      } else {
-        outStateShape[2] = param_.state_size;
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+#if CUDNN_MAJOR >= 5
+    format_ = CUDNN_TENSOR_NCHW;
+#endif
+
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      // get input + output tensors
+      Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+      Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+      param_.seq_length_ = x.shape_[0];
+      param_.batch_size_ = x.shape_[1];
+      param_.input_size_ = x.shape_[2];
+
+      // Tensor Descriptors
+      std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+      int dimA[3];
+      int strideA[3];
+      for (int i = 0; i < param_.seq_length_; i++) {
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.input_size_;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+      }
-      out_shape->push_back(outStateShape);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm) {
-        mxnet::TShape cellStateShape = dshape;
-        cellStateShape[0] = total_layers;
-        cellStateShape[1] = batch_size;
-        cellStateShape[2] = param_.state_size;
-        out_shape->push_back(cellStateShape);
+      x_desc_vec_ = x_vec;
+      y_desc_vec_ = y_vec;
+      dx_desc_vec_ = dx_vec;
+      dy_desc_vec_ = dy_vec;
+
+      // set the state tensors
+      dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimA[1] = param_.batch_size_;
+      dimA[2] = param_.state_size;
+      strideA[0] = dimA[2] * dimA[1];
+      strideA[1] = dimA[2];
+      strideA[2] = 1;
+#if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+          param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+#endif
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+#else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+#endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+
+      // Create Dropout descriptors
+      DType* dropout_states_ = NULL;
+      if (param_.p > 0) {
+        ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc
+            (&dropout_desc_, s, 1.0f - param_.p, seed_);
+      } else {
+        dropout_byte_ = 0;
       }
-      return true;
-    }
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
+      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
+                                           param_.p,  // discard probability
+                                           dropout_states_, dropout_byte_,
+                                           seed_));
+
+      // RNN descriptors
+#if CUDNN_MAJOR >= 6
+      cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+      CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.state_size,
+                                          param_.num_layers,
+                                          dropout_desc_,
+                                          input_mode_,
+                                          direction_,
+                                          mode_,
+                                          rnn_algo,
+                                          dtype_));
+#else
+      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
+                                       param_.state_size,
+                                       param_.num_layers,
+                                       dropout_desc_,
+                                       input_mode_,
+                                       direction_,
+                                       mode_,
+                                       dtype_));
+#endif
+#if CUDNN_MAJOR >= 7
+      cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
+      if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) {
+        math_type = CUDNN_TENSOR_OP_MATH;
+      }
+#if CUDNN_VERSION >= 7200
+      if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
+          (DataType<DType>::kFlag != kFloat16))
+        math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
+#endif
+      CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
+#endif
+#if USE_CUDNN_LSTM_PROJ
+      if (param_.projection_size.has_value()) {
+        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
+                                               rnn_desc_,
+                                               param_.projection_size.value(),
+                                               0));
       }
+#endif
+      // Get temp space sizes
+      CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          &workspace_byte_));
+      CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_,
+                                                rnn_desc_,
+                                                param_.seq_length_,
+                                                x_desc_vec_.data(),
+                                                &reserve_space_byte_));
+      workspace_size_ = workspace_byte_ / sizeof(DType);
+      // Allocate the reserve space
+      reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
+      // Allocate the temp space
+      temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id));
+      // Check that number of params are correct
+      size_t cudnn_param_size;
+      CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
+                                       rnn_desc_,
+                                       x_desc_vec_[0],
+                                       &cudnn_param_size,
+                                       dtype_));
+      CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);
+      // Set param descriptors
+      int dim_w[3] = {1, 1, 1};
+      dim_w[0] = w.shape_[0];
+      CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_,
+                                            dtype_,
+                                            format_,
+                                            3,
+                                            dim_w));
+      CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_,
+                                            dtype_,
+                                            format_,
+                                            3,
+                                            dim_w));
+
+      // Query weight layout
+      // cudnnFilterDescriptor_t m_desc;
+      // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS);
+      // DType *p;
+      // int n = 2;
+      // int64_t last = 0;
+      // if (param_.mode == rnn_enum::kLstm) n = 8;
+      // else if (param_.mode == rnn_enum::kGru) n = 6;
+
+      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
+      //   for (int j = 0; j < n; ++j) {
+      //     CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_,
+      //       i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
+      //     last = ((int64_t)(p - NULL))/sizeof(DType);
+      //     cudnnDataType_t t;
+      //     cudnnTensorFormat_t f;
+      //     int ndim = 5;
+      //     int dims[5] = {0, 0, 0, 0, 0};
+      //     CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]),
+      //       CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << "w: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
+      //     for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i];
+      //   }
+      // }
+
+      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
+      //   for (int j = 0; j < n; ++j) {
+      //     CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0],
+      //       w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
+      //     last = ((int64_t)(p - NULL))/sizeof(DType);
+      //     LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
+      //   }
+      // }
     }
-    out_type->clear();
-    out_type->push_back(dtype);
-    if (!param_.state_outputs) {
-      return true;
+#endif
+  }
+
+#if MXNET_USE_CUDNN_RNN
+  cudnnDataType_t dtype_;
+  bool init_cudnn_;
+  cudnnRNNDescriptor_t rnn_desc_;
+  cudnnRNNMode_t mode_;
+  cudnnDirectionMode_t direction_;
+  cudnnRNNInputMode_t input_mode_;
+  cudnnDropoutDescriptor_t dropout_desc_;
+  Storage::Handle reserve_space_, temp_space_;
+  uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
+  int workspace_size_;
+  std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
+#if USE_CUDNN_LSTM_PROJ
+  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
+#endif
+  cudnnTensorDescriptor_t hx_desc_, cx_desc_;
+  cudnnTensorDescriptor_t hy_desc_, cy_desc_;
+  cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
+  cudnnTensorDescriptor_t dhy_desc_, dcy_desc_;
+
+  cudnnFilterDescriptor_t w_desc_, dw_desc_;
+  // Allow TensorCore algo policy
+  bool cudnn_tensor_core_;
+
+#if CUDNN_MAJOR >= 5
+  cudnnTensorFormat_t format_;
+#endif
+#endif
+  bool init_space_, temp_init_space_;
+  size_t reserve_cpu_space_size_, temp_cpu_space_size_;
+  Storage::Handle reserve_cpu_space_, temp_cpu_space_;
+};  // class RNNOp
+
+static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
+                                 const Context ctx,
+                                 const mxnet::ShapeVector &in_shapes,
+                                 const std::vector<int> &in_types) {
+  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+  OpStatePtr state = OpStatePtr();
+  MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, {
+    if (ctx.dev_type == kGPU) {
+      state = OpStatePtr::Create<RNNOp<gpu, DType>>(param, ctx);
     } else {
-      out_type->push_back(dtype);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_type->push_back(dtype);
-      return true;
+      state = OpStatePtr::Create<RNNOp<cpu, DType>>(param, ctx);
    }
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new RNNProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "RNN";
-  }
+  });
+  return state;
+}

-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
-        in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+template<typename xpu>
+void RNNStatefulCompute(const OpStatePtr& state,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  int dtype = inputs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    RNNOp<xpu, DType>& op = state.get_state<RNNOp<xpu, DType>>();
+    op.Forward(ctx, inputs, req, outputs);
+  });
+}

-    if (param_.state_outputs) {
-      dep.push_back(out_data[rnn_enum::kStateOut]);
-      dep.push_back(out_grad[rnn_enum::kStateOut]);
+/*
+index description

Review comment:
   We can probably make some enums like `RNNOpInputs` for these indices.
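For reference, `rnn-inl.h` already uses this pattern for the op's own inputs and outputs; a minimal sketch of the kind of enum the comment proposes, reconstructed from the `rnn_enum::` identifiers that appear in the diff above (the `/* index description` block itself is truncated here, so any members beyond these would be illustrative only):

    // Existing convention in rnn-inl.h: raw TBlob indices get names in a namespace.
    namespace rnn_enum {
    enum RNNOpInputs { kData, kParams, kState, kStateCell };   // kStateCell is LSTM-only
    enum RNNOpOutputs { kOut, kStateOut, kStateCellOut };      // state outputs are optional
    }  // namespace rnn_enum

The suggestion is to give the indices documented in the truncated "index description" comment the same treatment, so call sites read e.g. `inputs[rnn_enum::kData]` rather than `inputs[0]`.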