larroy commented on a change in pull request #8915: NVLink communication 
pattern updated 
URL: https://github.com/apache/incubator-mxnet/pull/8915#discussion_r160936864
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -526,101 +541,238 @@ class CommDevice : public Comm {
     }
 
     InitBuffersAndComm(src);
+    auto& stage = stage_buf_[key];
     auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
-
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
-      for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+
+    if (buf.merged.is_none() && stage.copy_buf.empty()) {
+      stage.copy_buf.resize(src.size() - 1);
+      for (size_t i = 0; i < src.size() - 1; ++i)
+        stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                    false, stage.merged.dtype());
+    } else if (!buf.merged.is_none()) {
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(g1.size());
+        for (size_t i = 0; i < g1.size(); ++i)
+          buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), 
false,
+                                    buf.merged.dtype());
+      }
+      if (stage.copy_buf.empty()) {
+        stage.copy_buf.resize(g2.size() - 1);
+        for (size_t i = 0; i < g2.size() - 1; ++i)
+          stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                      false, stage.merged.dtype());
       }
     }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
+    std::vector<NDArray> reduce_s(stage.copy_buf.size() + 1);
+    for (size_t i = 0, j = 0; i < src.size(); ++i) {
+      int id = src[i].ctx().dev_id;
+      if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) ||
+          (buf.merged.is_none() && i == 0)) {
+        CopyFromTo(src[i], &(stage.merged), priority);
+        reduce_s[0] = stage.merged;
+      } else if (id >= NVLINK_SUPPORT || buf.merged.is_none()) {
+        CopyFromTo(src[i], &(stage.copy_buf[j]), priority);
+        reduce_s[j + 1] = stage.copy_buf[j];
+        j++;
+      }
     }
+    ElementwiseSum(reduce_s, &stage.merged);
+    // Main reduce result on gpu 0 including the partial result from gpu
+    // NVLINK_SUPPORT
+    if (!buf.merged.is_none()) {
+      std::vector<NDArray> reduce(buf.copy_buf.size() + 1);
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if (id == buf.merged.ctx().dev_id) {
+          CopyFromTo(src[i], &(buf.merged), priority);
+          reduce[0] = buf.merged;
+        } else if (id < NVLINK_SUPPORT) {
+          CopyFromTo(src[i], &(buf.copy_buf[j]), priority);
+          reduce[j + 1] = buf.copy_buf[j];
+          j++;
+        }
+      }
 
-    ElementwiseSum(reduce, &buf.merged);
+      CopyFromTo(stage.merged, &(buf.copy_buf[buf.copy_buf.size() - 1]),
+                 priority);
+      reduce[reduce.size() - 1] = buf.copy_buf[buf.copy_buf.size() - 1];
+      ElementwiseSum(reduce, &buf.merged);
+    } else {
+      return stage.merged;
+    }
     return buf.merged;
   }
 
   const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
                                   int priority) {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-    if (buf.copy_buf.empty()) {
+    auto& stage = stage_buf_[key];
+    if (buf.merged.is_none() && stage.copy_buf.empty()) {
       // one buf for each context
-      buf.copy_buf.resize(src.size());
-      buf.compressed_recv_buf.resize(src.size());
-      buf.compressed_send_buf.resize(src.size());
-      buf.residual.resize(src.size());
+      stage.copy_buf.resize(src.size());
+      stage.compressed_recv_buf.resize(src.size());
+      stage.compressed_send_buf.resize(src.size());
+      stage.residual.resize(src.size());
 
       for (size_t i = 0; i < src.size(); ++i) {
-        buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = 0;
+        stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                    false, stage.merged.dtype());
+        stage.residual[i] = NDArray(stage.merged.shape(), src[i].ctx(), false,
+                                    stage.merged.dtype());
+        stage.residual[i] = 0;
+        int64_t small_size =
+            gc_->GetCompressedSize(stage.merged.shape().Size());
+        stage.compressed_recv_buf[i] =
+            NDArray(TShape{small_size}, stage.merged.ctx(), false,
+                    stage.merged.dtype());
+        stage.compressed_send_buf[i] = NDArray(TShape{small_size}, 
src[i].ctx(),
+                                               false, stage.merged.dtype());
+      }
+    } else if (!buf.merged.is_none()) {
+      if (buf.copy_buf.empty() && stage.copy_buf.empty()) {
+        buf.copy_buf.resize(g1.size() + 1);
+        buf.compressed_recv_buf.resize(g1.size() + 1);
+        buf.compressed_send_buf.resize(g1.size() + 1);
+        buf.residual.resize(g1.size() + 1);
+        stage.copy_buf.resize(g2.size());
+        stage.compressed_recv_buf.resize(g2.size());
+        stage.compressed_send_buf.resize(g2.size());
+        stage.residual.resize(g2.size());
+        for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) {
+          int id = src[i].ctx().dev_id;
+          if (id < NVLINK_SUPPORT) {
+            buf.copy_buf[j] = NDArray(buf.merged.shape(), buf.merged.ctx(),
+                                      false, buf.merged.dtype());
+            buf.residual[j] = NDArray(buf.merged.shape(), src[i].ctx(), false,
+                                      buf.merged.dtype());
+            buf.residual[j] = 0;
+            int64_t small_size =
+                gc_->GetCompressedSize(buf.merged.shape().Size());
+            buf.compressed_recv_buf[j] =
+                NDArray(TShape{small_size}, buf.merged.ctx(), false,
+                        buf.merged.dtype());
+            buf.compressed_send_buf[j] = NDArray(
+                TShape{small_size}, src[i].ctx(), false, buf.merged.dtype());
+            j++;
+          } else {
+            stage.copy_buf[k] =
+                NDArray(stage.merged.shape(), stage.merged.ctx(), false,
+                        stage.merged.dtype());
+            stage.residual[k] = NDArray(stage.merged.shape(), src[i].ctx(),
+                                        false, stage.merged.dtype());
+            stage.residual[k] = 0;
+            int64_t small_size =
+                gc_->GetCompressedSize(stage.merged.shape().Size());
+            stage.compressed_recv_buf[k] =
+                NDArray(TShape{small_size}, stage.merged.ctx(), false,
+                        stage.merged.dtype());
+            stage.compressed_send_buf[k] = NDArray(
+                TShape{small_size}, src[i].ctx(), false, stage.merged.dtype());
+            k++;
+          }
+        }
+        buf.copy_buf[g1.size()] = NDArray(buf.merged.shape(), buf.merged.ctx(),
+                                          false, buf.merged.dtype());
+        buf.residual[g1.size()] = NDArray(
+            buf.merged.shape(), stage.merged.ctx(), false, buf.merged.dtype());
+        buf.residual[g1.size()] = 0;
         int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size());
-        buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, 
buf.merged.ctx(),
-                                        false, buf.merged.dtype());
-        buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(),
-                                        false, buf.merged.dtype());
+        buf.compressed_recv_buf[g1.size()] = NDArray(
+            TShape{small_size}, buf.merged.ctx(), false, buf.merged.dtype());
+        buf.compressed_send_buf[g1.size()] = NDArray(
+            TShape{small_size}, stage.merged.ctx(), false, buf.merged.dtype());
       }
     }
+    std::vector<NDArray> reduce_s(stage.copy_buf.size());
+    std::vector<NDArray> reduce(buf.copy_buf.size());
+
+    for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) {
+      int id = src[i].ctx().dev_id;
+      if (id >= NVLINK_SUPPORT || buf.merged.is_none()) {
+        // compress before copy
+        // this is done even if the data is on same context as copy_buf because
+        // we don't want the training to be biased towards data on this GPU
+        gc_->Quantize(src[i], &(stage.compressed_send_buf[j]),
+                      &(stage.residual[j]), priority);
+
+        if (stage.compressed_send_buf[j].ctx() !=
+            stage.compressed_recv_buf[j].ctx()) {
+          CopyFromTo(stage.compressed_send_buf[j],
+                     &(stage.compressed_recv_buf[j]), priority);
+        } else {
+          // avoid memory copy when they are on same context
+          stage.compressed_recv_buf[j] = stage.compressed_send_buf[j];
+        }
 
-    for (size_t i = 0; i < src.size(); ++i) {
-      // compress before copy
-      // this is done even if the data is on same context as copy_buf because
-      // we don't want the training to be biased towards data on this GPU
-      gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), 
priority);
-
-      if (buf.compressed_send_buf[i].ctx() != 
buf.compressed_recv_buf[i].ctx()) {
-        CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), 
priority);
+        gc_->Dequantize(stage.compressed_recv_buf[j], &(stage.copy_buf[j]),
+                        priority);
+        reduce_s[j] = stage.copy_buf[j];
+        j++;
       } else {
-        // avoid memory copy when they are on same context
-        buf.compressed_recv_buf[i] = buf.compressed_send_buf[i];
-      }
+        gc_->Quantize(src[i], &(buf.compressed_send_buf[k]), 
&(buf.residual[k]),
+                      priority);
+
+        if (buf.compressed_send_buf[k].ctx() !=
+            buf.compressed_recv_buf[k].ctx()) {
+          CopyFromTo(buf.compressed_send_buf[k], &(buf.compressed_recv_buf[k]),
+                     priority);
+        } else {
+          // avoid memory copy when they are on same context
+          buf.compressed_recv_buf[k] = buf.compressed_send_buf[k];
+        }
 
-      gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), 
priority);
-      reduce[i] = buf.copy_buf[i];
+        gc_->Dequantize(buf.compressed_recv_buf[k], &(buf.copy_buf[k]),
+                        priority);
+        reduce[k] = buf.copy_buf[k];
+        k++;
+      }
+    }
+    ElementwiseSum(reduce_s, &stage.merged);
+    if (buf.merged.is_none()) {
+      return stage.merged;
+    } else {
+      gc_->Quantize(stage.merged, &buf.compressed_send_buf[g1.size()],
+                    &(buf.residual[g1.size()]), priority);
+      CopyFromTo(buf.compressed_send_buf[g1.size()],
+                 &(buf.compressed_recv_buf[g1.size()]), priority);
+      gc_->Dequantize(buf.compressed_recv_buf[g1.size()],
+                      &(buf.copy_buf[g1.size()]), priority);
+      reduce[reduce.size() - 1] = buf.copy_buf[g1.size()];
+      ElementwiseSum(reduce, &buf.merged);
     }
-    ElementwiseSum(reduce, &buf.merged);
+
     return buf.merged;
   }
 
-  void Broadcast(int key, const NDArray& src,
-                 const std::vector<NDArray*> dst, int priority) override {
+  void Broadcast(int key, const NDArray& src, const std::vector<NDArray*> dst,
 
 Review comment:
   Shouldn't `dst` be passed by const reference (e.g. `const std::vector<NDArray*>& dst`) instead of by value, to avoid copying the vector of pointers on every call?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to