This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 838e256  Optimize NMS part 2 (#14352)
838e256 is described below

commit 838e256cbfbf0856c42b268b801f7f55271bcff2
Author: Przemyslaw Tredak <ptre...@gmail.com>
AuthorDate: Thu Mar 7 18:57:08 2019 -0800

    Optimize NMS part 2 (#14352)
    
    * Optimize NMS part 2
    
    * Guarding ldg intrinsics
---
 src/operator/contrib/bounding_box-common.h | 10 +++++++
 src/operator/contrib/bounding_box-inl.cuh  | 44 ++++++++++++++++++++++++++++++
 src/operator/contrib/bounding_box-inl.h    | 27 +++++++++---------
 3 files changed, 68 insertions(+), 13 deletions(-)

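For context: the CPU and GPU versions of NMSCalculateBatchStart introduced in this commit compute the same quantity. valid_batch_id holds the batch id of each kept box in ascending order, and batch_start[b] is the number of kept boxes whose batch id is less than b (equivalently, the index of the first box belonging to batch b or later), with batch_start[num_batch] equal to the total count. A minimal host-side reference of that invariant, not part of the commit and with hypothetical names, could look like:

#include <cstdint>
#include <vector>

// Reference-only sketch: batch_start[b] counts entries of valid_batch_id that
// are smaller than b, mirroring F<mshadow_op::less_than> followed by a sum
// reduction in the CPU path of this commit.
std::vector<int32_t> batch_start_reference(const std::vector<int32_t>& valid_batch_id,
                                           int num_batch) {
  std::vector<int32_t> batch_start(num_batch + 1, 0);
  for (int b = 0; b <= num_batch; ++b) {
    int32_t count = 0;
    for (int32_t id : valid_batch_id) {
      if (id < b) ++count;
    }
    batch_start[b] = count;
  }
  return batch_start;
}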
diff --git a/src/operator/contrib/bounding_box-common.h b/src/operator/contrib/bounding_box-common.h
index 70215ab..4c9b1b8 100644
--- a/src/operator/contrib/bounding_box-common.h
+++ b/src/operator/contrib/bounding_box-common.h
@@ -112,6 +112,16 @@ struct nms_impl {
   }
 };
 
+namespace mshadow_op {
+struct less_than : public mxnet_op::tunable {
+  // elementwise comparison: returns 1 when a < b, 0 otherwise
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    return static_cast<DType>(a < b);
+  }
+};  // struct less_than
+}   // namespace mshadow_op
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh
index 4b7cf34..e7f5567 100644
--- a/src/operator/contrib/bounding_box-inl.cuh
+++ b/src/operator/contrib/bounding_box-inl.cuh
@@ -280,6 +280,50 @@ void NMSApply(mshadow::Stream<gpu> *s,
   }
 }
 
+__launch_bounds__(512)
+__global__ void nms_calculate_batch_start_kernel(int32_t * batch_start,
+                                                 int32_t * valid_batch_id,
+                                                 size_t N,
+                                                 int num_batch) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < N) {
+#if __CUDA_ARCH__ >= 350
+    const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
+    const int32_t my = __ldg(valid_batch_id + tid);
+#else
+    const int32_t previous = tid > 0 ? valid_batch_id[tid - 1] : -1;
+    const int32_t my = valid_batch_id[tid];
+#endif
+    if (my > previous) {
+      for (int32_t current = previous + 1; current <= my; ++current) {
+        batch_start[current] = tid;
+      }
+    }
+    if (tid == N - 1) {
+      for (int32_t current = my + 1; current <= num_batch; ++current) {
+        batch_start[current] = tid + 1;
+      }
+    }
+  }
+}
+
+inline void NMSCalculateBatchStart(mshadow::Stream<gpu> *s,
+                                   mshadow::Tensor<gpu, 1, int32_t>* batch_start,
+                                   mshadow::Tensor<gpu, 1, int32_t>* valid_batch_id,
+                                   int num_batch) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  auto stream = mshadow::Stream<gpu>::GetStream(s);
+  constexpr int block_size = 512;
+  const int num_elements = valid_batch_id->size(0);
+  const int blocks = (num_elements + block_size - 1) / block_size;
+  nms_calculate_batch_start_kernel<<<blocks, block_size, 0, stream>>>(batch_start->dptr_,
+                                                                      valid_batch_id->dptr_,
+                                                                      num_elements,
+                                                                      num_batch);
+}
+
 }  // namespace op
 }  // namespace mxnet
 
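The "Guarding ldg intrinsics" part of the commit is the #if __CUDA_ARCH__ >= 350 block in the kernel above: __ldg routes the load through the read-only data cache but only exists on sm_35 and newer, so older architectures fall back to a plain global load. As a standalone illustration (not taken from the commit), the same guard can be wrapped in a helper:

// Illustrative helper only, not part of this commit: use the read-only data
// cache load where the architecture supports it, otherwise a regular load.
__device__ __forceinline__ int32_t guarded_load(const int32_t* ptr) {
#if __CUDA_ARCH__ >= 350
  return __ldg(ptr);  // read-only cache load, available from sm_35 onwards
#else
  return *ptr;        // plain global load on older architectures
#endif
}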
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index 35ab19d..8610dcc 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -162,15 +162,6 @@ int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
   return j;
 }
 
-namespace mshadow_op {
-struct less_than : public mxnet_op::tunable {
-  // a is x, b is sigma
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return static_cast<DType>(a < b);
-  }
-};  // struct equal_to
-}   // namespace mshadow_op
 
 struct corner_to_center {
   template<typename DType>
@@ -277,6 +268,19 @@ void NMSApply(mshadow::Stream<cpu> *s,
   }
 }
 
+inline void NMSCalculateBatchStart(mshadow::Stream<cpu> *s,
+                                   mshadow::Tensor<cpu, 1, int32_t>* batch_start,
+                                   mshadow::Tensor<cpu, 1, int32_t>* valid_batch_id,
+                                   int num_batch) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  for (int b = 0; b < num_batch + 1; b++) {
+    slice<0>(*batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
+        F<mshadow_op::less_than>(*valid_batch_id, ScalarExp<int32_t>(b)), 0);
+  }
+}
+
 /*!
    * \brief Assign output of nms by indexing input
    *
@@ -435,10 +439,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
 
     // calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
     valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
-    for (int b = 0; b < num_batch + 1; b++) {
-      slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
-        F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
-    }
+    mxnet::op::NMSCalculateBatchStart(s, &batch_start, &valid_batch_id, num_batch);
 
     // pre-compute areas of candidates
     areas = 0;

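A note on the call-site change at the end of the diff: BoxNMSForward now makes a single call to NMSCalculateBatchStart, and the CPU/GPU split is handled by the two overloads taking mshadow::Stream<cpu>* and mshadow::Stream<gpu>*, so overload resolution picks the backend when the operator is instantiated. A stripped-down sketch of that dispatch pattern (all names hypothetical, not MXNet code):

#include <iostream>

// Hypothetical stand-ins for mshadow's device tags and stream type.
struct cpu {};
struct gpu {};
template <typename Device> struct Stream {};

// Two overloads, mirroring the CPU reduction path and the GPU kernel path.
void CalculateBatchStart(Stream<cpu>*) { std::cout << "CPU reduction path\n"; }
void CalculateBatchStart(Stream<gpu>*) { std::cout << "GPU kernel path\n"; }

template <typename Device>
void Forward(Stream<Device>* s) {
  CalculateBatchStart(s);  // overload resolution selects the backend
}

int main() {
  Stream<cpu> cs;
  Stream<gpu> gs;
  Forward(&cs);  // CPU reduction path
  Forward(&gs);  // GPU kernel path
  return 0;
}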