eric-haibin-lin commented on a change in pull request #8732: rsp push and rsp 
pull for comm device, used in kvstore('device')
URL: https://github.com/apache/incubator-mxnet/pull/8732#discussion_r153328289
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -619,7 +655,62 @@ class CommDevice : public Comm {
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
                           const bool use_copy,
                           const int priority) override {
-    LOG(FATAL) << "Not implemented yet";
+    using namespace mshadow;
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "BroadcastRowSparse expects row-sparse src NDArray";
+
+    bool is_same_rowid = true;
+    for (size_t i = 1; i < dst.size(); ++i) {
+      if (dst[i].second.var() != dst[0].second.var()) {
+        is_same_rowid = false;
+      }
+    }
+
+    for (size_t i = 0; i < dst.size(); ++i) {
+      if (is_same_rowid && i != 0) {
+        CopyFromTo(*dst[0].first, dst[i].first, priority);
+        continue;
+      }
+
+      NDArray* out = dst[i].first;
+      NDArray row_id = dst[i].second;
+      if (use_copy) {
+        CopyFromTo(src, out, priority);
+      } else {
+        CHECK_EQ(out->storage_type(), kRowSparseStorage)
+                 << "BroadcastRowSparse expects row_sparse dst NDArray";
+        const bool is_diff_ctx = out->ctx() != src.ctx();
+        NDArray src_gpu = is_diff_ctx? NDArray(kRowSparseStorage, src.shape(),
+            out->ctx(), true, src.dtype(), src.aux_types()) : src;
+        if (is_diff_ctx) {
+          CopyFromTo(src, &src_gpu, priority);
+        }
+        NDArray row_id_gpu = NDArray(row_id.shape(), out->ctx(), false, 
kInt64);
+        const TBlob& indices = row_id_gpu.data();
+        CopyFromTo(row_id, &row_id_gpu, priority);
+
+        Engine::Get()->PushAsync([=](RunContext rctx, 
Engine::CallbackOnComplete on_complete) {
+            NDArray temp = *out;
+            switch (temp.ctx().dev_mask()) {
+              case cpu::kDevMask: {
+                
mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+                  src_gpu, indices, kWriteTo, &temp);
+                break;
+              }
+#if MXNET_USE_CUDA
+              case gpu::kDevMask: {
+                
mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+                  src_gpu, indices, kWriteTo, &temp);
+                break;
+              }
 
 Review comment:
  Is a `Stream->Wait()` call missing after the GPU `SparseRetainOpForwardRspWrapper` call?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to