sxjscience commented on a change in pull request #14894: Accelerate ROIPooling layer
URL: https://github.com/apache/incubator-mxnet/pull/14894#discussion_r303173596
 
 

 ##########
 File path: src/operator/roi_pooling.cc
 ##########
 @@ -137,91 +140,18 @@ template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
                                const Tensor<cpu, 4, Dtype> &out_grad,
                                const Tensor<cpu, 2, Dtype> &bbox,
-                               const Tensor<cpu, 4, Dtype> &max_idx,
+                               const Tensor<cpu, 4, index_t> &max_idx,
                                const float spatial_scale_) {
   const Dtype *top_diff = out_grad.dptr_;
-  const Dtype *bottom_rois = bbox.dptr_;
   Dtype *bottom_diff = in_grad.dptr_;
-  Dtype *argmax_data = max_idx.dptr_;
-
-  const int batch_size_ = in_grad.size(0);
-  const int channels_ = in_grad.size(1);
-  const int height_ = in_grad.size(2);
-  const int width_ = in_grad.size(3);
-  const int pooled_height_ = out_grad.size(2);
-  const int pooled_width_ = out_grad.size(3);
-
-  const int num_rois = bbox.size(0);
-
-  for (int b = 0; b < batch_size_; ++b) {
-    for (int c = 0; c < channels_; ++c) {
-      for (int h = 0; h < height_; ++h) {
-        for (int w = 0; w < width_; ++w) {
-          int offset_bottom_diff = (b * channels_ + c) * height_ * width_;
-          offset_bottom_diff += h * width_ + w;
-
-          Dtype gradient = 0;
-          // Accumulate gradient over all ROIs that pooled this element
-          for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
-            const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
-            int roi_batch_ind = offset_bottom_rois[0];
-            assert(roi_batch_ind >= 0);
-            assert(roi_batch_ind < batch_size_);
-            if (b != roi_batch_ind) {
-              continue;
-            }
+  index_t *argmax_data = max_idx.dptr_;
 
-            int roi_start_w = std::round(offset_bottom_rois[1] * 
spatial_scale_);
-            int roi_start_h = std::round(offset_bottom_rois[2] * 
spatial_scale_);
-            int roi_end_w = std::round(offset_bottom_rois[3] * spatial_scale_);
-            int roi_end_h = std::round(offset_bottom_rois[4] * spatial_scale_);
+  const index_t count = out_grad.shape_.Size();
 
-            bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
-                           h >= roi_start_h && h <= roi_end_h);
-            if (!in_roi) {
-              continue;
-            }
-
-            // force malformed ROIs to be 1 * 1
-            int roi_height = max(roi_end_h - roi_start_h + 1, 1);
-            int roi_width = max(roi_end_w - roi_start_w + 1, 1);
-            const Dtype bin_size_h = static_cast<Dtype>(roi_height)
-                                     / static_cast<Dtype>(pooled_height_);
-            const Dtype bin_size_w = static_cast<Dtype>(roi_width)
-                                     / static_cast<Dtype>(pooled_width_);
-
-            // compute pooled regions correspond to original (h, w) point
-            int phstart = static_cast<int>(floor(static_cast<Dtype>(h - 
roi_start_h)
-                                                 / bin_size_h));
-            int pwstart = static_cast<int>(floor(static_cast<Dtype>(w - 
roi_start_w)
-                                                 / bin_size_w));
-            int phend = static_cast<int>(ceil(static_cast<Dtype>(h - 
roi_start_h + 1)
-                                              / bin_size_h));
-            int pwend = static_cast<int>(ceil(static_cast<Dtype>(w - 
roi_start_w + 1)
-                                              / bin_size_w));
-
-            // clip to boundaries of pooled region
-            phstart = min(max(phstart, 0), pooled_height_);
-            phend = min(max(phend, 0), pooled_height_);
-            pwstart = min(max(pwstart, 0), pooled_width_);
-            pwend = min(max(pwend, 0), pooled_width_);
-
-            // accumulate over gradients in pooled regions
-            int offset = (roi_n * channels_ + c) * pooled_height_ * 
pooled_width_;
-            const Dtype* offset_top_diff = top_diff + offset;
-            const Dtype* offset_argmax_data = argmax_data + offset;
-            for (int ph = phstart; ph < phend; ++ph) {
-              for (int pw = pwstart; pw < pwend; ++pw) {
-                const int pooled_index = ph * pooled_width_ + pw;
-                if (static_cast<int>(offset_argmax_data[pooled_index]) == h * 
width_ + w) {
-                  gradient += offset_top_diff[pooled_index];
-                }
-              }
-            }
-          }
-          bottom_diff[offset_bottom_diff] += gradient;
-        }
-      }
+  for (int index = 0; index < count; ++index) {
+    index_t max_idx = argmax_data[index];
+    if (max_idx >= 0) {
+      bottom_diff[max_idx] += top_diff[index];
 
 Review comment:
   It's correct, sorry for the misreading.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to