piiswrong closed pull request #12118: fix potential floating number overflow, enable float16
URL: https://github.com/apache/incubator-mxnet/pull/12118
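
In short: the patch stores the sort/index bookkeeping in box_nms and
bipartite_matching as int32_t instead of the data's floating DType (large
indices cannot be represented exactly in float16/float32, which is the
"floating number overflow" in the title), widens the type switch to
MSHADOW_REAL_TYPE_SWITCH so float16 inputs are accepted, and implements GPU
SortByKey for half_t through thrust instead of LOG(FATAL). As a rough,
standalone sketch of that thrust idea (illustrative only, not the code in the
diff: less_half_f and sort_scores_fp16 are made-up names, the comparison goes
through float rather than mshadow::half::half_t, and CUDA 9.0+ is assumed):

#include <cuda_fp16.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

// Order __half keys by comparing through float, so no device-side
// __half relational operator is required.
struct less_half_f {
  __host__ __device__ bool operator()(const __half &a, const __half &b) const {
    return __half2float(a) < __half2float(b);
  }
};

// Stable-sort fp16 scores and carry int32_t indices along with them.
void sort_scores_fp16(thrust::device_vector<__half>* keys,
                      thrust::device_vector<int32_t>* indices) {
  thrust::stable_sort_by_key(thrust::device,
                             keys->begin(), keys->end(),
                             indices->begin(), less_half_f());
}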
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh
index fb1dacc11f4..fd5e30b25b2 100644
--- a/src/operator/contrib/bounding_box-inl.cuh
+++ b/src/operator/contrib/bounding_box-inl.cuh
@@ -45,9 +45,9 @@ struct valid_score {
 
 template<typename DType>
 int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
-                 mshadow::Tensor<gpu, 1, DType> out_sorted_index,
+                 mshadow::Tensor<gpu, 1, int32_t> out_sorted_index,
                  mshadow::Tensor<gpu, 1, DType> scores,
-                 mshadow::Tensor<gpu, 1, DType> sorted_index,
+                 mshadow::Tensor<gpu, 1, int32_t> sorted_index,
                  float valid_thresh) {
   valid_score<DType> pred(static_cast<DType>(valid_thresh));
   DType * end_scores = thrust::copy_if(thrust::device, scores.dptr_, scores.dptr_ + scores.MSize(),
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index f739dbc8a52..8e963461ec0 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -150,9 +150,9 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {
 
 template<typename DType>
 int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
-                 mshadow::Tensor<cpu, 1, DType> out_sorted_index,
+                 mshadow::Tensor<cpu, 1, int32_t> out_sorted_index,
                  mshadow::Tensor<cpu, 1, DType> scores,
-                 mshadow::Tensor<cpu, 1, DType> sorted_index,
+                 mshadow::Tensor<cpu, 1, int32_t> sorted_index,
                  float valid_thresh) {
   index_t j = 0;
   for (index_t i = 0; i < scores.size(0); i++) {
@@ -230,7 +230,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
 
 /*!
  * \brief compute areas specialized for nms to reduce computation
- * 
+ *
  * \param i the launched thread index (total thread num_batch * topk)
  * \param out 1d array for areas (size num_batch * num_elem)
  * \param in 1st coordinate of 1st box (buffer + coord_start)
@@ -243,7 +243,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
 struct compute_area {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
-                                  const DType *indices, const DType *batch_start,
+                                  const int32_t *indices, const int32_t *batch_start,
                                   int topk, int num_elem, int stride, int encode) {
     int b = i / topk;
     int k = i % topk;
@@ -302,7 +302,7 @@ MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
    */
 struct nms_impl {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *index, const DType *batch_start,
+  MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
                                   const DType *input, const DType *areas,
                                   int k, int ref, int num,
                                   int stride, int offset_box, int offset_id,
@@ -326,8 +326,7 @@ struct nms_impl {
     intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
     int ref_area_offset = static_cast<int>(index[ref]);
     int pos_area_offset = static_cast<int>(index[pos]);
-    DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] -
-      intersect);
+    DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
     if (iou > thresh) {
       index[pos] = -1;
     }
@@ -336,7 +335,7 @@ struct nms_impl {
 
 /*!
    * \brief Assign output of nms by indexing input
-   * 
+   *
    * \param i the launched thread index (total num_batch)
    * \param out output array [cls, conf, b0, b1, b2, b3]
    * \param record book keeping the selected index for backward
@@ -349,7 +348,7 @@ struct nms_impl {
 struct nms_assign {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input,
-                                  const DType *index, const DType *batch_start,
+                                  const int32_t *index, const int32_t *batch_start,
                                   int k, int num, int stride) {
     int count = 0;
     for (int j = 0; j < k; ++j) {
@@ -404,7 +403,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
   int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2);
   int num_elem = in_shape[indim - 2];
   int width_elem = in_shape[indim - 1];
-  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     Tensor<xpu, 3, DType> data = inputs[box_nms_enum::kData]
      .get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
     Tensor<xpu, 3, DType> out = outputs[box_nms_enum::kOut]
@@ -415,25 +414,33 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     // prepare workspace
     Shape<1> sort_index_shape = Shape1(num_batch * num_elem);
     Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem);
-    index_t workspace_size = 4 * sort_index_shape.Size();
     Shape<1> batch_start_shape = Shape1(num_batch + 1);
-    workspace_size += batch_start_shape.Size();
+
+    // index
+    index_t int32_size = sort_index_shape.Size() * 3 + batch_start_shape.Size();
+    index_t dtype_size = sort_index_shape.Size() * 2;
     if (req[0] == kWriteInplace) {
-      workspace_size += buffer_shape.Size();
+      dtype_size += buffer_shape.Size();
     }
+    // ceil up when sizeof(DType) is larger than sizeof(int32_t)
+    index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1;
+    index_t workspace_size = int32_offset + dtype_size;
     Tensor<xpu, 1, DType> workspace = ctx.requested[box_nms_enum::kTempSpace]
       .get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
-    Tensor<xpu, 1, DType> sorted_index(workspace.dptr_, sort_index_shape, s);
-    Tensor<xpu, 1, DType> scores(sorted_index.dptr_ + sorted_index.MSize(),
+    Tensor<xpu, 1, int32_t> sorted_index(
+      reinterpret_cast<int32_t*>(workspace.dptr_), sort_index_shape, s);
+    Tensor<xpu, 1, int32_t> all_sorted_index(sorted_index.dptr_ + sorted_index.MSize(),
       sort_index_shape, s);
-    Tensor<xpu, 1, DType> batch_id(scores.dptr_ + scores.MSize(), sort_index_shape,
-      s);
-    Tensor<xpu, 1, DType> areas(batch_id.dptr_ + batch_id.MSize(), sort_index_shape, s);
-    Tensor<xpu, 1, DType> batch_start(areas.dptr_ + areas.MSize(), batch_start_shape, s);
+    Tensor<xpu, 1, int32_t> batch_id(
+      all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s);
+    Tensor<xpu, 1, int32_t> batch_start(batch_id.dptr_ + batch_id.MSize(), batch_start_shape, s);
+    Tensor<xpu, 1, DType> scores(workspace.dptr_ + int32_offset,
+      sort_index_shape, s);
+    Tensor<xpu, 1, DType> areas(scores.dptr_ + scores.MSize(), sort_index_shape, s);
     Tensor<xpu, 3, DType> buffer = data;
     if (req[0] == kWriteInplace) {
       // make copy
-      buffer = Tensor<xpu, 3, DType>(batch_start.dptr_ + batch_start.MSize(), buffer_shape, s);
+      buffer = Tensor<xpu, 3, DType>(areas.dptr_ + areas.MSize(), buffer_shape, s);
       buffer = F<mshadow_op::identity>(data);
     }
 
@@ -451,10 +458,10 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     }
 
     // use batch_id and areas as temporary storage
-    Tensor<xpu, 1, DType> all_scores = batch_id;
-    Tensor<xpu, 1, DType> all_sorted_index = areas;
+    Tensor<xpu, 1, DType> all_scores = areas;
+    // Tensor<xpu, 1, DType> all_sorted_index = areas;
     all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
-    all_sorted_index = range<DType>(0, num_batch * num_elem);
+    all_sorted_index = range<int32_t>(0, num_batch * num_elem);
 
     // filter scores but keep original sorted_index value
     // move valid score and index to the front, return valid size
@@ -474,19 +481,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     // only sort the valid scores and batch_id
     Shape<1> valid_score_shape = Shape1(num_valid);
     Tensor<xpu, 1, DType> valid_scores(scores.dptr_, valid_score_shape, s);
-    Tensor<xpu, 1, DType> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
-    Tensor<xpu, 1, DType> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
+    Tensor<xpu, 1, int32_t> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
+    Tensor<xpu, 1, int32_t> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
 
     // sort index by batch_id then score (stable sort)
     mxnet::op::SortByKey(valid_scores, valid_sorted_index, false);
-    valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+    valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
     mxnet::op::SortByKey(valid_batch_id, valid_sorted_index, true);
 
     // calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
-    valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+    valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
     for (int b = 0; b < num_batch + 1; b++) {
       slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
-        F<mshadow_op::less_than>(valid_batch_id, ScalarExp<DType>(b)), 0);
+        F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
     }
 
     // pre-compute areas of candidates
@@ -721,11 +728,11 @@ inline bool MatchingShape(const nnvm::NodeAttrs& attrs,
 struct bipartite_matching {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *row_marker, DType *col_marker,
-                                  const DType *scores, const DType *sorted_index,
+                                  const DType *scores, const int32_t *sorted_index,
                                   int num_batch, int num_row, int num_col,
                                   float threshold, bool is_ascend, int topk) {
     int stride = num_row * num_col;
-    const DType *index = sorted_index + i * stride;
+    const int32_t *index = sorted_index + i * stride;
     const DType *score = scores + i * stride;
     DType *rmarker = row_marker + i * num_row;
     DType *cmarker = col_marker + i * num_col;
@@ -769,7 +776,7 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
   int row = dshape[dshape.ndim() - 2];
   int col = dshape[dshape.ndim() - 1];
   int batch_size = dshape.Size() / row / col;
-  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     Tensor<xpu, 1, DType> scores = inputs[0]
      .get_with_shape<xpu, 1, DType>(Shape1(dshape.Size()), s);
     Tensor<xpu, 2, DType> row_marker = outputs[0]
@@ -777,23 +784,24 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 2, DType> col_marker = outputs[1]
      .get_with_shape<xpu, 2, DType>(Shape2(batch_size, col), s);
     Shape<1> sort_index_shape = Shape1(dshape.Size());
-    index_t workspace_size = sort_index_shape.Size() * 3;
+    index_t workspace_size = sort_index_shape.Size();
+    workspace_size += ((sort_index_shape.Size() * sizeof(int32_t) - 1) / sizeof(DType)) * 2;
     Tensor<xpu, 1, DType> workspace = ctx.requested[0]
       .get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
-    Tensor<xpu, 1, DType> sorted_index(workspace.dptr_,
-      sort_index_shape, s);
-    Tensor<xpu, 1, DType> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
+    Tensor<xpu, 1, DType> scores_copy(workspace.dptr_,
       sort_index_shape, s);
-    Tensor<xpu, 1, DType> scores_copy(batch_id.dptr_ + batch_id.MSize(),
+    Tensor<xpu, 1, int32_t> sorted_index(reinterpret_cast<int32_t*>(
+      scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s);
+    Tensor<xpu, 1, int32_t> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
       sort_index_shape, s);
 
     // sort according to score
     scores_copy = F<mshadow_op::identity>(scores);
-    sorted_index = range<DType>(0, dshape.Size());
+    sorted_index = range<int32_t>(0, dshape.Size());
     mxnet::op::SortByKey(scores_copy, sorted_index, param.is_ascend);
-    batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
+    batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
     mxnet::op::SortByKey(batch_id, scores_copy, true);
-    batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
+    batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
     mxnet::op::SortByKey(batch_id, sorted_index, true);
 
     // bipartite matching, parallelization is limited to batch_size
diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh
index 5ad31053f92..1a8e2325ef4 100644
--- a/src/operator/tensor/sort_op-inl.cuh
+++ b/src/operator/tensor/sort_op-inl.cuh
@@ -24,6 +24,7 @@
  */
 #ifndef MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
 #define MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
+#include <type_traits>
 #include <thrust/device_ptr.h>
 #include <thrust/sort.h>
 #if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
@@ -40,6 +41,29 @@
 
 namespace mxnet {
 namespace op {
+namespace cuda {
+template<typename T>
+struct less_half
+{
+  typedef T first_argument_type;
+  typedef T second_argument_type;
+  typedef bool result_type;
+  __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
+    return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
+  }
+};
+
+template<typename T>
+struct greater_half
+{
+  typedef T first_argument_type;
+  typedef T second_argument_type;
+  typedef bool result_type;
+  __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
+    return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
+  }
+};
+}
 
 template <typename KDType, typename VDType, typename xpu>
 inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
@@ -57,9 +81,12 @@ SortByKeyWorkspaceSize(const size_t num_keys) {
 }
 
 template<typename KDType, typename VDType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
-                      bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
-                      const int begin_bit, const int end_bit) {
+inline typename std::enable_if<!(std::is_same<KDType,mshadow::half::half_t>::value ||
+                                 std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+              mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+              mshadow::Tensor<gpu, 1, char>* workspace,
+              const int begin_bit, const int end_bit) {
   CHECK_EQ(keys.CheckContiguous(), true);
   CHECK_EQ(values.CheckContiguous(), true);
 #if CUDA_VERSION >= 7000
@@ -128,18 +155,100 @@ inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu,
 #endif
 }
 
-template<typename DType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, mshadow::half::half_t> keys,
-  mshadow::Tensor<gpu, 1, DType> values, bool is_ascend,
-  mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
-  LOG(FATAL) << "SortByKey for half_t is not implemented!";
+template<typename KDType, typename VDType>
+inline typename std::enable_if<((!std::is_same<KDType,mshadow::half::half_t>::value) &&
+                                std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+              mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+              mshadow::Tensor<gpu, 1, char>* workspace,
+              const int begin_bit, const int end_bit) {
+  CHECK_EQ(keys.CheckContiguous(), true);
+  CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+  thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
+  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<half*>(values.dptr_));
+  if (is_ascend) {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::less<KDType>());
+  } else {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::greater<KDType>());
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+  LOG(FATAL) << "SortByKey with fp16 values is only supported for CUDA version 
>= 9.0";
+#endif
+}
+
+template<typename KDType, typename VDType>
+inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
+                                (!std::is_same<VDType,mshadow::half::half_t>::value)), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+              mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+              mshadow::Tensor<gpu, 1, char>* workspace,
+              const int begin_bit, const int end_bit) {
+  CHECK_EQ(keys.CheckContiguous(), true);
+  CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<half*>(keys.dptr_));
+  thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
+  if (is_ascend) {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+  } else {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+  LOG(FATAL) << "SortByKey with fp16 keys is only supported for CUDA version 
>= 9.0";
+#endif
+}
+
+// use thrust sorting when keys or values are half_t
+template<typename KDType, typename VDType>
+inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
+                                std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+              mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+              mshadow::Tensor<gpu, 1, char>* workspace,
+              const int begin_bit, const int end_bit) {
+  CHECK_EQ(keys.CheckContiguous(), true);
+  CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<half*>(keys.dptr_));
+  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<half*>(values.dptr_));
+  if (is_ascend) {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+  } else {
+    thrust::stable_sort_by_key(
+      thrust::cuda::par.on(stream),
+      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+  LOG(FATAL) << "SortByKey with fp16 keys and values is only supported for 
CUDA version >= 9.0";
+#endif
 }
 
-template<typename DType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, DType> keys,
-  mshadow::Tensor<gpu, 1, mshadow::half::half_t> values, bool is_ascend,
-  mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
-  LOG(FATAL) << "SortByKey for half_t is not implemented!";
+template<typename KDType, typename VDType>
+inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
+                      bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
+                      const int begin_bit, const int end_bit) {
+  SortByKeyImpl(keys, values, is_ascend, workspace, begin_bit, end_bit);
 }
 
 }  // namespace op
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index a220f08d20d..fc6c1be9c3a 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -28,11 +28,12 @@
 def test_box_nms_op():
     def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0,
                          force=False, in_format='corner', out_format='corner'):
-        data = mx.nd.array(data)
-        out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
-                                coord_start=coord, score_index=score, id_index=cid,
-                                force_suppress=force, in_format=in_format, out_format=out_format)
-        assert_almost_equal(out.asnumpy(), expected)
+        for dtype in ['float16', 'float32', 'float64']:
+            data = mx.nd.array(data, dtype=dtype)
+            out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
+                                    coord_start=coord, score_index=score, id_index=cid,
+                                    force_suppress=force, in_format=in_format, out_format=out_format)
+            assert_almost_equal(out.asnumpy(), expected.astype(dtype), rtol=1e-3, atol=1e-3)
 
     def test_box_nms_backward(data, grad, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1,
                           cid=0, force=False, in_format='corner', out_format='corner'):
@@ -233,13 +234,13 @@ def generate_boxes(dims):
 
 def test_bipartite_matching_op():
     def assert_match(inputs, x, y, threshold, is_ascend=False):
-        inputs = mx.nd.array(inputs)
-        x = np.array(x)
-        y = np.array(y)
-        a, b = mx.nd.contrib.bipartite_matching(inputs, threshold=threshold, is_ascend=is_ascend)
-        print(a, b)
-        assert_array_equal(a.asnumpy().astype('int64'), x.astype('int64'))
-        assert_array_equal(b.asnumpy().astype('int64'), y.astype('int64'))
+        for dtype in ['float16', 'float32', 'float64']:
+            inputs = mx.nd.array(inputs, dtype=dtype)
+            x = np.array(x, dtype=dtype)
+            y = np.array(y, dtype=dtype)
+            a, b = mx.nd.contrib.bipartite_matching(inputs, threshold=threshold, is_ascend=is_ascend)
+            assert_array_equal(a.asnumpy().astype('int64'), x.astype('int64'))
+            assert_array_equal(b.asnumpy().astype('int64'), y.astype('int64'))
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)
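
For reference, after this change the single workspace requested in
BoxNMSForward packs the int32_t index buffers first and the DType buffers
after them, with the int32_t section rounded up to a whole number of DType
elements. A small sketch of the equivalent size arithmetic (hypothetical
helper, not code from the diff; it uses an ordinary ceiling division, which
agrees with the diff's "(n - 1) / sizeof(DType) + 1" form for n > 0):

#include <cstddef>
#include <cstdint>

// Number of DType elements to request for one BoxNMSForward call.
template <typename DType>
size_t nms_workspace_elems(size_t num_batch, size_t num_elem,
                           size_t width_elem, bool copy_input) {
  const size_t sort_index  = num_batch * num_elem;       // one slot per box
  const size_t batch_start = num_batch + 1;              // per-batch prefix offsets
  // sorted_index, all_sorted_index, batch_id, batch_start
  const size_t int32_elems = 3 * sort_index + batch_start;
  size_t dtype_elems = 2 * sort_index;                   // scores, areas
  if (copy_input) {                                      // buffer copy for kWriteInplace
    dtype_elems += num_batch * num_elem * width_elem;
  }
  // Offset, in DType units, where the floating-point section begins:
  // ceil(int32 bytes / sizeof(DType)).
  const size_t int32_offset =
      (int32_elems * sizeof(int32_t) + sizeof(DType) - 1) / sizeof(DType);
  return int32_offset + dtype_elems;
}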
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
