ctcyang commented on a change in pull request #11591: [MXNET-331] Single machine All Reduce Topology-aware Communication (Updated)
URL: https://github.com/apache/incubator-mxnet/pull/11591#discussion_r202492787
 
 

 ##########
 File path: src/kvstore/comm_tree.h
 ##########
 @@ -0,0 +1,500 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+#ifndef MXNET_KVSTORE_COMM_TREE_H_
+#define MXNET_KVSTORE_COMM_TREE_H_
+#include <dmlc/omp.h>
+#include <string>
+#include <algorithm>
+#include <utility>
+#include <limits>
+#include <vector>
+#include <tuple>
+#include <thread>
+#include <map>
+#include "mxnet/ndarray.h"
+#include "gradient_compression.h"
+#include "../ndarray/ndarray_function.h"
+#include "../operator/tensor/sparse_retain-inl.h"
+#include "./kvstore_utils.h"
+#include "./gpu_topology.h"
+namespace mxnet {
+namespace kvstore {
+/**
+ * \brief an implementation of Comm that performs reduction on device
+ * directly using tree.
+ *
+ * It is faster if the aggregate device-to-device bandwidth is larger than
+ * the device-to-CPU bandwidth, which is often true for 4 or 8 GPUs. But it
+ * uses more device memory.
+ */
+class CommDeviceTree : public CommDevice {
+ public:
+  CommDeviceTree() {
+    inited_ = false;
+    gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000);
+    backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0);
+    link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7);
+  }
+
+  virtual ~CommDeviceTree() { }
+
+  void Init(int key, const NDArrayStorageType stype, const TShape& shape,
+            int dtype = mshadow::kFloat32) override {
+    tree_sorted_key_attrs_.emplace_back(key, shape, dtype);
+    sorted_key_attrs_.emplace_back(key, shape, dtype);
+  }
+
+  void InitBuffersAndComm(const std::vector<NDArray>& src) {
+    if (!inited_) {
+      for (const auto& a : src) {
+        devs_.push_back(a.ctx());
+      }
+      QueryTopology();
+      // Note: delayed allocation set to true, because we do not want to
+      // allocate both in TreeBufferEntry and BufferEntry, so we use a size_t
+      // to keep track of each key's shape within BufferEntry
+      // -this information is required for inherited Reduce- and
+      //  BroadcastRowSparse
+      InitMergeBuffer(devs_);
+      InitMergeBufferTree();
+      if (dmlc::GetEnv("MXNET_ENABLE_GPU_P2P", 1)) {
+        EnableP2P();
+      }
+    }
+  }
+
+  // src is sliced shape
+  // copy_buf not sliced
+  // merged not sliced
+  const NDArray& ReduceInner(int key, const std::vector<NDArray>& src,
+                             int root, int merged_row, int priority) {
+    std::vector<std::vector<NDArray>> reduce(devs_.size());
+
+    TreeBufferEntry& random_buf = tree_merge_buf_[0][key];
+    const NDArrayStorageType stype = random_buf.merged[0].storage_type();
+    std::vector<size_t>& topology = topology_[root];
+    NDArray buf_slice;
+
+    if (stype == kDefaultStorage) {
+      // Copy everything into buf.merged for each gpu
 
 Review comment:
   I tested the throughput difference of a similar intra-GPU `CopyFromTo` by
   commenting out Line 306 in `comm_tree.h`. This ruins the correctness of the
   output, but it gives an idea of how much can be saved. At the time, I was
   considering a combined PushPull API, so that the `dst` array is known at the
   time of push. This saves one `CopyFromTo` during Broadcast, from the
   temporary buffer `buf.merged` to the `dst` array.
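   
   To make the saving concrete, here is a toy sketch (plain C++, not MXNet
   code: `Array`, `Reduce`, `PushThenPull`, and `PushPull` are made-up
   stand-ins for NDArray, the tree reduce, and the KVStore calls) of why
   knowing `dst` at push time removes one copy:
   
   ```cpp
   // Toy sketch, not MXNet code: separate Push/Pull (v6) vs. a combined
   // PushPull (v7). dst stands in for the per-GPU output arrays.
   #include <cstdio>
   #include <vector>
   
   using Array = std::vector<float>;
   
   // Sum the per-GPU gradients element-wise into *out (stands in for the
   // tree reduce).
   void Reduce(const std::vector<Array>& src, Array* out) {
     for (size_t i = 0; i < out->size(); ++i) {
       float sum = 0.f;
       for (const Array& a : src) sum += a[i];
       (*out)[i] = sum;
     }
   }
   
   // v6: dst is unknown at push time, so the result must land in the
   // temporary buffer first; Broadcast then copies buf_merged into every dst.
   void PushThenPull(const std::vector<Array>& src, Array* buf_merged,
                     std::vector<Array>* dst) {
     Reduce(src, buf_merged);                // push: reduce into buf.merged
     for (Array& d : *dst) d = *buf_merged;  // pull: one CopyFromTo per GPU
   }
   
   // v7: dst is known at push time, so the reduce writes dst[0] directly and
   // the fan-out starts from there, leaving one fewer copy overall.
   void PushPull(const std::vector<Array>& src, std::vector<Array>* dst) {
     Reduce(src, &(*dst)[0]);
     for (size_t i = 1; i < dst->size(); ++i) (*dst)[i] = (*dst)[0];
   }
   
   int main() {
     std::vector<Array> grads = {{1.f, 2.f}, {3.f, 4.f}};  // 2 GPUs
     Array buf(2);
     std::vector<Array> dst6(2, Array(2)), dst7(2, Array(2));
     PushThenPull(grads, &buf, &dst6);
     PushPull(grads, &dst7);
     std::printf("v6: {%g, %g}  v7: {%g, %g}\n",
                 dst6[1][0], dst6[1][1], dst7[1][0], dst7[1][1]);
   }
   ```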
   
   Testing VGG-16 on an older commit, I got the numbers below. The geomean
   difference across these batch sizes suggests that getting rid of one
   `CopyFromTo` makes it 2.2% faster (see the sketch after the table for how
   the geomean is computed). If you two think these two optimizations are
   worthwhile--(i) eliminating the copy from `buf.merged` in reduce, and
   (ii) eliminating the copy from `buf.merged` to the `dst` array in broadcast
   by using the PushPull interface--I can add them as a separate PR after this
   one is accepted, because the PushPull interface depends on Intel's C API
   addition (#10696).
   
   ```
   v6: Push, Pull interface - One more intra-GPU CopyFromTo than v7
   v7: PushPull combined AllReduce interface
   BS: Batch size per GPU (8 GPUs total)
   
   Throughput (samples/s)
   
   fp32
   BS | v6   | v7
   ----------------
   4  | 711  | 745
   8  | 999  | 1035
   16 | 1449 | 1478
   32 | 1638 | 1672
   64 | 1695 | 1739
   
   fp16
   BS | v6   | v7
   ----------------
   8  | 1552 | 1599
   16 | 2127 | 2163
   32 | 2916 | 2910
   64 | 2720 | 2775
   128| 2518 | 2532
   ```
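   
   For reference, here is a minimal sketch (plain C++17; the numbers are
   copied from the table above) showing how the 2.2% figure falls out as the
   geometric mean of the per-row v7/v6 speedups:
   
   ```cpp
   // Minimal sketch reproducing the 2.2% geomean figure from the table.
   #include <cmath>
   #include <cstdio>
   #include <utility>
   #include <vector>
   
   int main() {
     // {v6, v7} throughput pairs: the five fp32 rows, then the five fp16 rows.
     std::vector<std::pair<double, double>> rows = {
         {711, 745},   {999, 1035},  {1449, 1478}, {1638, 1672}, {1695, 1739},
         {1552, 1599}, {2127, 2163}, {2916, 2910}, {2720, 2775}, {2518, 2532}};
     double log_sum = 0.0;
     for (const auto& [v6, v7] : rows) log_sum += std::log(v7 / v6);
     const double geomean = std::exp(log_sum / rows.size());
     std::printf("geomean speedup: %.1f%%\n", (geomean - 1.0) * 100.0);  // 2.2
   }
   ```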
