rahul003 commented on a change in pull request #8294: NCCL integration URL: https://github.com/apache/incubator-mxnet/pull/8294#discussion_r148692803
########## File path: src/kvstore/comm.h ########## @@ -635,6 +656,302 @@ class CommDevice : public Comm { bool inited_; }; +#if MXNET_USE_NCCL +class CommNCCL : public Comm { + public: + CommNCCL() { + inited_ = false; + pinned_ctx_ = Context::CPUPinned(0); + } + + virtual ~CommNCCL() { + for (auto e : nccl_data_) { + cudaStreamDestroy(e.second.stream); + ncclCommDestroy(e.second.comm); + } + } + + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32, Context pinned_ctx = Context::CPUPinned(0)) override { + if (stype == kDefaultStorage) { + sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + } else { + LOG(FATAL) << "NCCL KVStore does not support sparse storage type"; + } + } + + const NDArray& Reduce(int key, const std::vector<NDArray>& src, + int priority) override { + // avoid extra copy for single device, but it may bring problems for + // abnormal usage of kvstore + if (src.size() == 1) { + return src[0]; + } + + if (!inited_) { + std::vector<Context> devs; + for (const auto& a : src) { + devs.push_back(a.ctx()); + } + InitNCCL(devs); + InitMergeBuffer(devs); + } + + std::vector<int> dev_ids; + for (auto e : src) { + dev_ids.push_back(e.ctx().dev_id); + } + std::sort(dev_ids.begin(), dev_ids.end()); + CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only single set of devices"; Review comment: Do you want to check here that the set of devices doesn't change during the training? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services