eric-haibin-lin commented on a change in pull request #13346: Aggregate SGD URL: https://github.com/apache/incubator-mxnet/pull/13346#discussion_r247729674
########## File path: src/operator/optimizer_op.cc ########## @@ -313,6 +315,209 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +NNVM_REGISTER_OP(multi_sgd_update) +.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights * 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights); + }) +.set_attr_parser(ParamParser<MultiSGDParam>) +.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 2>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>) +.set_attr<nnvm::FListInputNames>("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights; + std::vector<std::string> ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + } + return ret; + }) +.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, type_identity, 2>) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_sgd_mom_update) +.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage +type is the same as momentum's storage type, +only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: + + for row in gradient.indices: + v[row] = momentum[row] * v[row] - learning_rate * gradient[row] + weight[row] += v[row] + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights); + }) +.set_attr_parser(ParamParser<MultiSGDMomParam>) +.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 3>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>) +.set_attr<nnvm::FListInputNames>("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights; + std::vector<std::string> ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + } + return ret; + }) +.set_attr<nnvm::FMutateInputs>("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector<uint32_t> ret; + const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) 
+.set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, type_identity, 3>) +.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum") +.add_arguments(MultiSGDMomParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_update) +.describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed); + return static_cast<uint32_t>(param.num_weights); + }) +.set_attr_parser(ParamParser<MultiSGDParam>) +.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 3>) +.set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDParam, 3, 1>) +.set_attr<nnvm::FListInputNames>("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights; + std::vector<std::string> ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + return ret; + }) +.set_attr<nnvm::FMutateInputs>("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector<uint32_t> ret; + const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, single_precision, 3>) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_mom_update) 
+.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage Review comment: Could you remove the docs on sparse data type support? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services