ptrendx commented on a change in pull request #8294: NCCL integration
URL: https://github.com/apache/incubator-mxnet/pull/8294#discussion_r148906667
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -635,6 +656,302 @@ class CommDevice : public Comm {
   bool inited_;
 };
 
+#if MXNET_USE_NCCL
+class CommNCCL : public Comm {
+ public:
+  CommNCCL() {
+    inited_ = false;
+    pinned_ctx_ = Context::CPUPinned(0);
+  }
+
+  virtual ~CommNCCL() {
+    for (auto e : nccl_data_) {
+      cudaStreamDestroy(e.second.stream);
+      ncclCommDestroy(e.second.comm);
+    }
+  }
+
+  void Init(int key, const NDArrayStorageType stype, const TShape& shape,
+            int dtype = mshadow::kFloat32, Context pinned_ctx = Context::CPUPinned(0)) override {
+    if (stype == kDefaultStorage) {
+      sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
+    } else {
+      LOG(FATAL) << "NCCL KVStore does not support sparse storage type";
+    }
+  }
+
+  const NDArray& Reduce(int key, const std::vector<NDArray>& src,
+                        int priority) override {
+    // avoid the extra copy for a single device, though this can
+    // misbehave under unusual kvstore usage
+    if (src.size() == 1) {
+      return src[0];
+    }
+
+    if (!inited_) {
+      std::vector<Context> devs;
+      for (const auto& a : src) {
+        devs.push_back(a.ctx());
+      }
+      InitNCCL(devs);
+      InitMergeBuffer(devs);
+    }
+
+    std::vector<int> dev_ids;
+    for (auto e : src) {
+      dev_ids.push_back(e.ctx().dev_id);
+    }
+    std::sort(dev_ids.begin(), dev_ids.end());
+    CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";
+
+    auto& buf = merge_buf_[key];
+    int root = buf.merged.ctx().dev_id;
+    size_t root_id = -1;
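+    // find which source array already lives on the merge buffer's device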
+    for (size_t i = 0; i < src.size(); ++i) {
+      if (src[i].ctx().dev_id == root) {
+        root_id = i;
+        break;
+      }
+    }
+
+    auto& reduce = buf.merged;
+
+    std::vector<Engine::VarHandle> const_vars;
+    for (size_t i = 0; i < src.size(); ++i) {
+      const_vars.push_back(src[i].var());
+    }
+    Engine::Get()->PushSync([src, reduce, root_id, this](RunContext rctx) {
+          {
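+            // serialize with GPU memory allocation: a cudaMalloc racing the
+            // grouped NCCL launch below can deadlock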
+            std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
+            int root = nccl_data_[src[root_id].ctx().dev_id].rank;
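+            // queue one call per device inside a single NCCL group, so this
+            // one thread launches the collective on all GPUs at ncclGroupEnd()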
+            ncclGroupStart();
+            for (size_t i = 0; i < src.size(); ++i) {
+              NCCLEntry cur = nccl_data_[src[i].ctx().dev_id];
+              if (i == root_id) {
+                MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
+                  ncclReduce(src[i].data().dptr<DType>(),
+                             reduce.data().dptr<DType>(),
+                             src[i].shape().Size(),
+                             GetNCCLType(src[i].dtype()),
+                             ncclSum,
+                             root,
+                             cur.comm,
+                             cur.stream););
+              } else {
+                MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
+                  ncclReduce(src[i].data().dptr<DType>(),
+                             nullptr,
+                             src[i].shape().Size(),
+                             GetNCCLType(src[i].dtype()),
+                             ncclSum,
+                             root,
+                             cur.comm,
+                             cur.stream););
+              }
+            }
+            ncclGroupEnd();
+          }
+        },
+        Context::CPU(),
+        const_vars,
+        {reduce.var()},
+        FnProperty::kCPUPrioritized,
+        priority,
+        PROFILER_MESSAGE("KVStoreReduce"));
+
+    return buf.merged;
+  }
+
+  void CommSync(const std::vector<const NDArray*>& dst,
+                int priority) override {
+    std::vector<Engine::VarHandle> const_vars;
+    std::vector<Engine::VarHandle> mutate_vars;
+    for (size_t i = 0; i < dst.size(); ++i) {
+        mutate_vars.push_back(dst[i]->var());
+    }
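+    // block until every per-device NCCL stream drains, so the engine
+    // observes the queued collectives as complete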
+    Engine::Get()->PushSync([this](RunContext rctx) {
+          for (auto cur : nccl_data_) {
+            CUDA_CALL(cudaSetDevice(cur.second.dev_id));
+            CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
+          }
+        },
+        Context::CPU(),
+        const_vars,
+        mutate_vars,
+        FnProperty::kCPUPrioritized,
+        priority,
+        PROFILER_MESSAGE("KVStoreStreamSync"));
+  }
+
+  void CommSync(const std::vector<NDArray>& dst,
+                int priority) override {
+    std::vector<Engine::VarHandle> const_vars;
+    std::vector<Engine::VarHandle> mutate_vars;
+    for (size_t i = 0; i < dst.size(); ++i) {
+        mutate_vars.push_back(dst[i].var());
+    }
+    Engine::Get()->PushSync([this](RunContext rctx) {
+          for (auto cur : nccl_data_) {
+            CUDA_CALL(cudaSetDevice(cur.second.dev_id));
+            CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
+          }
+        },
+        Context::CPU(),
+        const_vars,
+        mutate_vars,
+        FnProperty::kCPUPrioritized,
+        priority,
+        PROFILER_MESSAGE("KVStoreStreamSync"));
+  }
+
+  void BroadcastRowSparse(int key, const NDArray& src,
+                          const std::vector<std::pair<NDArray*, NDArray>>& dst,
+                          const bool use_copy,
+                          const int priority) override {
+    LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
+  }
+
+  void Broadcast(int key, const NDArray& src,
+                 const std::vector<NDArray> dst, int priority) override {
+    if (!inited_) {
+      // pick a device by key (spreads roots across GPUs) and copy there first
+      int dev_id = key % dst.size();
+      CopyFromTo(src, dst[dev_id], priority);
+      for (size_t i = 0; i < dst.size(); ++i) {
+        if (i != static_cast<size_t>(dev_id)) {
+          CopyFromTo(dst[dev_id], dst[i], priority);
+        }
+      }
+    } else {
+      auto& buf = merge_buf_[key];
+      int root = src.ctx().dev_id;
+      CHECK_EQ(root, buf.merged.ctx().dev_id);
+      size_t root_id = -1;
+      for (size_t i = 0; i < dst.size(); ++i) {
+        if (dst[i].ctx().dev_id == root) {
+          root_id = i;
+          break;
+        }
+      }
+      std::vector<int> dev_ids;
+      for (size_t i = 0; i < dst.size(); ++i) {
+        auto& bcast = (i == root_id) ? src : dst[i];
+        dev_ids.push_back(bcast.ctx().dev_id);
+      }
+      std::sort(dev_ids.begin(), dev_ids.end());
+      CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";
+      CopyFromTo(src, dst[root_id], priority);
+      if (dst.size() == 1) return;
+      std::vector<Engine::VarHandle> mutable_vars;
+      for (size_t i = 0; i < dst.size(); ++i) {
+        if (i != root_id)
 
 Review comment:
   Ok
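
For readers following the diff: the heart of Reduce is NCCL's grouped-launch
idiom. One thread opens a group, queues one ncclReduce per device, and closes
the group so all calls launch as a single collective; CommSync then fences the
per-device streams. The standalone sketch below reproduces that sequence
outside MXNet. It is illustrative only: ncclCommInitAll stands in for whatever
the PR's InitNCCL does, and the device count, buffer size, and root choice are
arbitrary assumptions.

    // build (hypothetical file name): nvcc grouped_reduce.cu -lnccl -o grouped_reduce
    #include <cuda_runtime.h>
    #include <nccl.h>
    #include <cstdio>
    #include <vector>

    int main() {
      int ndev = 0;
      cudaGetDeviceCount(&ndev);
      if (ndev < 2) { std::printf("need at least 2 GPUs\n"); return 0; }

      const size_t count = 1 << 20;  // elements per device, arbitrary
      const int root = 0;            // reduce into device 0, like buf.merged's device

      // one communicator per device, created in one shot
      std::vector<int> devs(ndev);
      for (int i = 0; i < ndev; ++i) devs[i] = i;
      std::vector<ncclComm_t> comms(ndev);
      ncclCommInitAll(comms.data(), ndev, devs.data());

      std::vector<cudaStream_t> streams(ndev);
      std::vector<float*> send(ndev);
      float* recv = nullptr;
      std::vector<float> host(count, 1.0f);
      for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaStreamCreate(&streams[i]);
        cudaMalloc(&send[i], count * sizeof(float));
        cudaMemcpy(send[i], host.data(), count * sizeof(float),
                   cudaMemcpyHostToDevice);
        if (i == root) cudaMalloc(&recv, count * sizeof(float));
      }

      // the Reduce pattern: queue every per-device call inside one group;
      // nothing launches until ncclGroupEnd(), which starts them together
      ncclGroupStart();
      for (int i = 0; i < ndev; ++i) {
        ncclReduce(send[i], i == root ? recv : nullptr, count, ncclFloat,
                   ncclSum, root, comms[i], streams[i]);
      }
      ncclGroupEnd();

      // the CommSync pattern: drain each device's NCCL stream
      for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaStreamSynchronize(streams[i]);
      }

      float sample = 0.0f;
      cudaSetDevice(root);
      cudaMemcpy(&sample, recv, sizeof(float), cudaMemcpyDeviceToHost);
      std::printf("reduced[0] = %f (expected %d)\n", sample, ndev);

      for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaFree(send[i]);
        cudaStreamDestroy(streams[i]);
        ncclCommDestroy(comms[i]);
      }
      cudaSetDevice(root);
      cudaFree(recv);
      return 0;
    }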
