This is an automated email from the ASF dual-hosted git repository.

apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b5f376  [MXNET-1413] Adding Large Tensor support for sort operators (#15170)
8b5f376 is described below

commit 8b5f376d17b644385706016caf8b1e58e95d96df
Author: Rohit Kumar Srivastava <srivastava....@osu.edu>
AuthorDate: Fri Jun 21 13:21:20 2019 -0700

    [MXNET-1413] Adding Large Tensor support for sort operators (#15170)
---
 src/operator/tensor/init_op.h         |   2 +-
 src/operator/tensor/ordering_op-inl.h | 190 ++++++++++++++++++++--------------
 tests/nightly/test_large_array.py     |  26 +++++
 tests/python/unittest/test_ndarray.py |  13 ++-
 4 files changed, 149 insertions(+), 82 deletions(-)

diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index fd49153..c7a1054 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -487,7 +487,7 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
 
 struct range_fwd {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(index_t i, int repeat, DType start, DType 
step,
+  MSHADOW_XINLINE static void Map(index_t i, index_t repeat, DType start, 
DType step,
                                   int req, DType* out) {
     KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step);
   }
diff --git a/src/operator/tensor/ordering_op-inl.h 
b/src/operator/tensor/ordering_op-inl.h
index 1dda901..98bca3a 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -81,12 +81,18 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
       .describe("Whether to choose k largest or k smallest elements."
                 " Top K largest elements will be chosen if set to false.");
     DMLC_DECLARE_FIELD(dtype)
+    // TODO(srivrohi): remove support for real data type in mxnet-2.0
     .add_enum("uint8", mshadow::kUint8)
     .add_enum("int32", mshadow::kInt32)
+    .add_enum("int64", mshadow::kInt64)
     .add_enum("float16", mshadow::kFloat16)
     .add_enum("float32", mshadow::kFloat32)
     .add_enum("float64", mshadow::kFloat64)
-    .set_default(mshadow::kFloat32)
+#if MXNET_USE_INT64_TENSOR_SIZE == 1
+    .set_default(mshadow::kInt64)
+#else
+    .set_default(mshadow::kInt32)
+#endif
     .describe("DType of the output indices when ret_typ is \"indices\" or 
\"both\". "
               "An error will be raised if the selected data type cannot 
precisely represent the "
               "indices.");
@@ -116,21 +122,33 @@ struct ArgSortParam : public 
dmlc::Parameter<ArgSortParam> {
     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
       .describe("Whether to sort in ascending or descending order.");
     DMLC_DECLARE_FIELD(dtype)
+    // TODO(srivrohi): remove support for real data type in mxnet-2.0
     .add_enum("uint8", mshadow::kUint8)
     .add_enum("int32", mshadow::kInt32)
+    .add_enum("int64", mshadow::kInt64)
     .add_enum("float16", mshadow::kFloat16)
     .add_enum("float32", mshadow::kFloat32)
     .add_enum("float64", mshadow::kFloat64)
-    .set_default(mshadow::kFloat32)
+#if USE_INT64_TENSOR_SIZE == 1
+    .set_default(mshadow::kInt64)
+#else
+    .set_default(mshadow::kInt32)
+#endif
     .describe("DType of the output indices. It is only valid when ret_typ is 
\"indices\" or"
               " \"both\". An error will be raised if the selected data type 
cannot precisely "
               "represent the indices.");
   }
 };
 
-inline void ParseTopKParam(const mxnet::TShape& src_shape, const TopKParam& 
param,
-                           mxnet::TShape *target_shape, int *batch_size, int 
*element_num,
-                           int *axis, int *k, bool *do_transpose, bool 
*is_ascend) {
+inline void ParseTopKParam(const TShape& src_shape,
+                           const TopKParam& param,
+                           TShape *target_shape,
+                           size_t *batch_size,
+                           index_t *element_num,
+                           int *axis,
+                           index_t *k,
+                           bool *do_transpose,
+                           bool *is_ascend) {
   *do_transpose = false;
   *k = param.k;
   *is_ascend = param.is_ascend;
@@ -179,14 +197,14 @@ using namespace mshadow;
 
 struct fill_ind_to_one {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) {
+  MSHADOW_XINLINE static void Map(int i, const index_t* indices, DType* out) {
     out[indices[i]] = static_cast<DType>(1);
   }
 };
 
 struct fill_ind {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const int* indices, const DType* val,
+  MSHADOW_XINLINE static void Map(int i, const index_t* indices, const DType* 
val,
                                   int req, DType* out) {
     KERNEL_ASSIGN(out[indices[i]], req, val[i]);
   }
@@ -194,39 +212,43 @@ struct fill_ind {
 
 template<typename DType>
 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
-                                   const Tensor<cpu, 1, int>& ind,
+                                   const Tensor<cpu, 1, index_t>& ind,
                                    const Tensor<cpu, 1, char>& work,
-                                   int K, int N, bool is_ascend,
+                                   index_t K, index_t N, bool is_ascend,
                                    Stream<cpu> *s) {
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
   // Batch size.
-  const int M(work.size(0)/(sizeof(DType)*N));
+  const index_t M(work.size(0)/(sizeof(DType)*N));
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < M; ++i) {
+  for (index_t i = 0; i < M; ++i) {
     // Tensor `work` stores the flattened source data, while `dat` stores the 
sorted result.
     DType *vals = reinterpret_cast<DType*>(work.dptr_);
     DType *sorted_vals = dat.dptr_+i*N;
-    int *indices = ind.dptr_+i*N;
+    index_t *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
         std::sort(indices, indices+N,
-                  [&](const int& i1, const int& i2){ return vals[i1] < 
vals[i2]; });
+                  [&](const index_t& i1, const index_t& i2){
+          return vals[i1] < vals[i2]; });
       } else {
         std::partial_sort(indices, indices+K, indices+N,
-                          [&](const int& i1, const int& i2){ return vals[i1] < 
vals[i2]; });
+                          [&](const index_t& i1, const index_t& i2){
+          return vals[i1] < vals[i2]; });
       }
     } else {
       if (full_sort) {
         std::sort(indices, indices+N,
-                  [&](const int& i1, const int& i2){ return vals[i1] > 
vals[i2]; });
+                  [&](const index_t& i1, const index_t& i2){
+          return vals[i1] > vals[i2]; });
       } else {
         std::partial_sort(indices, indices+K, indices+N,
-                          [&](const int& i1, const int& i2){ return vals[i1] > 
vals[i2]; });
+                          [&](const index_t& i1, const index_t& i2){
+          return vals[i1] > vals[i2]; });
       }
     }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
       sorted_vals[j] = vals[indices[j]];
     }
   }
@@ -235,18 +257,19 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, 
DType>& dat,
 #ifdef __CUDACC__
 
 template<typename DType>
-MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, 
bool is_ascend) {
+MSHADOW_XINLINE bool TopKCompare(DType val1, index_t ind1, DType val2, index_t 
ind2,
+                                 bool is_ascend) {
   // Negative indices denote undefined values which are considered arbitrary 
small resp. large.
   return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || 
(!is_ascend && val1 > val2)));
 }
 
 template<typename DType>
-MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int 
*ind2,
+MSHADOW_XINLINE void MergeTopK(index_t K, DType *val1, index_t *ind1, DType 
*val2, index_t *ind2,
                                bool is_ascend) {
   // In-place merge of two sorted top-K lists into val1/ind1. First determine 
the intervals
   // [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
-  int i1(K-1), i2(K-1);
-  for (int i = 0; i < K; ++i) {
+  index_t i1(K-1), i2(K-1);
+  for (index_t i = 0; i < K; ++i) {
     if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
       --i2;
     } else {
@@ -254,7 +277,7 @@ MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int 
*ind1, DType *val2, int *
     }
   }
   // Now merge the lists from back to front.
-  for (int i = K; i--;) {
+  for (index_t i = K; i--;) {
     if (i2 < 0 || i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], 
ind1[i1], is_ascend)) {
       val1[i] = val1[i1];
       ind1[i] = ind1[i1];
@@ -268,28 +291,29 @@ MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int 
*ind1, DType *val2, int *
 }
 
 template<typename DType>
-__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool 
is_ascend) {
+__global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t 
*ind, bool is_ascend) {
   // Buffer for pairwise reduction.
-  extern __shared__ int buff[];
+  extern __shared__ index_t buff[];
   // Start of buffer sections associated with this thread.
-  const int offset(threadIdx.x*K);
-  int *ind_buff = &buff[offset];
+  const index_t offset(threadIdx.x*K);
+  index_t *ind_buff = &buff[offset];
   DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
   // Initialize top-K values for this thread.
-  for (int i = 0; i < K; ++i) {
+  for (index_t i = 0; i < K; ++i) {
     ind_buff[i] = -1;
   }
   // Range of values this thread cares about. Each thread block processes
   // a different batch item (i.e. a different set of ind/val where we
   // have to select the top-K elements). All threads within the same
   // block work on the same batch item.
-  const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
+  const index_t first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
   // Select top-K from this range and store it sorted in the buffer.
   // We assume a small K, so linear insertion is o.k.
-  for (int i = first; i < last; i += blockDim.x) {
+  for (index_t i = first; i < last; i += blockDim.x) {
     DType cur_val(val[i]);
-    int cur_ind(ind[i]);
-    for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], 
ind_buff[j], is_ascend); ) {
+    index_t cur_ind(ind[i]);
+    for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j],
+                                           ind_buff[j], is_ascend); ) {
       if (j+1 < K) {
         val_buff[j+1] = val_buff[j];
         ind_buff[j+1] = ind_buff[j];
@@ -300,7 +324,7 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, 
int *ind, bool is_as
   }
   // Recursive merge of sorted lists for this thread block. Note that 
blockDim.x is not
   // necessary a power of two, therefore the additional checks for last_s.
-  for (unsigned int s = (blockDim.x+1)/2, last_s = blockDim.x;
+  for (index_t s = (blockDim.x+1)/2, last_s = blockDim.x;
        last_s > 1; last_s = s, s = (s+1)/2) {
     __syncthreads();
     if (threadIdx.x < s && threadIdx.x+s < last_s) {
@@ -309,7 +333,7 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, 
int *ind, bool is_as
   }
   // Final updates on master thread.
   if (threadIdx.x == 0) {
-    for (int i = 0; i < K; ++i) {
+    for (index_t i = 0; i < K; ++i) {
       ind[blockIdx.x*N+i] = ind_buff[i];
       val[blockIdx.x*N+i] = val_buff[i];
     }
@@ -318,20 +342,21 @@ __global__ void PartialSortSmallK(int K, int N, DType 
*val, int *ind, bool is_as
 
 template<typename DType>
 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
-                                   const Tensor<gpu, 1, int>& ind,
+                                   const Tensor<gpu, 1, index_t>& ind,
                                    const Tensor<gpu, 1, char>& work,
-                                   int K, int N, bool is_ascend,
+                                   index_t K, index_t N, bool is_ascend,
                                    Stream<gpu> *s) {
   // Use full sort for all but very small K for which we
   // can do a partial sort entirely within shared memory.
   const bool full_sort(K > 5);
   // Batch size.
-  const int M(dat.size(0)/N);
+  const index_t M(dat.size(0)/N);
   if (full_sort) {
     // Divide workspace into two parts. The first one is needed to store batch 
ids.
-    size_t alignment = std::max(sizeof(DType), sizeof(int));
-    size_t id_size = PadBytes(sizeof(int) * ind.size(0), alignment);
-    Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), 
Shape1(ind.size(0)), s);
+    size_t alignment = std::max(sizeof(DType), sizeof(index_t));
+    size_t id_size = PadBytes(sizeof(index_t) * ind.size(0), alignment);
+    Tensor<gpu, 1, index_t> batch_id(reinterpret_cast<index_t*>(work.dptr_),
+                                     Shape1(ind.size(0)), s);
     Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, 
Shape1(work.size(0)-id_size), s);
     mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
     if (M > 1) {
@@ -380,20 +405,22 @@ void TopKImpl(const RunContext &ctx,
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
   Tensor<xpu, 1, DType> sorted_dat;
-  Tensor<xpu, 1, int> indices, sel_indices;
-  int batch_size, element_num;  // number of batches + the size of each batch
+  Tensor<xpu, 1, index_t> indices, sel_indices;
+  size_t batch_size = 0;
+  index_t element_num = 0;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   size_t alignment = std::max(sizeof(DType), sizeof(int));
   mxnet::TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, 
&do_transpose, &is_ascend);
-  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
-    << "'IDType' does not have a sufficient precision to represent the indices 
of the input array. "
-    << "The total element_num is " << element_num << ", but the selected 
IDType can only represent "
-    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
+    << "'index_t' does not have a sufficient precision to represent "
+    << "the indices of the input array. The total element_num is "
+    << element_num << ", but the selected index_t can only represent "
+    << mxnet::common::MaxIntegerValue<index_t>() << " elements";
   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
@@ -404,11 +431,11 @@ void TopKImpl(const RunContext &ctx,
   temp_size = std::max(temp_size,
     mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
-  temp_size += PadBytes(sizeof(int) * src.Size(), alignment);
+  temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
   // Temp space for cpu sorts.
   temp_size = std::max(temp_size, static_cast<size_t>(sizeof(DType) * 
src.Size()));
   size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), 
alignment)
-                                    + PadBytes(sizeof(int) * src.Size(), 
alignment);
+                                    + PadBytes(sizeof(index_t) * src.Size(), 
alignment);
   if (param.ret_typ == topk_enum::kReturnMask) {
     workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment);
   }
@@ -417,14 +444,14 @@ void TopKImpl(const RunContext &ctx,
   sorted_dat = Tensor<xpu, 1, 
DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape1(src.Size()), s);  // contain 
sorted dat
   workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
-  indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
+  indices = Tensor<xpu, 1, 
index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the 
original matrix
-  workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);
+  workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
 
   if (param.ret_typ == topk_enum::kReturnMask) {
-    sel_indices = Tensor<xpu, 1, 
int>(reinterpret_cast<int*>(workspace_curr_ptr),
+    sel_indices = Tensor<xpu, 1, 
index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
-    workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment);
+    workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, 
alignment);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
   }
 
@@ -454,7 +481,7 @@ void TopKImpl(const RunContext &ctx,
     workspace_curr_ptr += temp_size;
   }
 
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 
1,
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 
index_t{0}, index_t{1},
     kWriteTo, indices.dptr_);
   CHECK_EQ(indices.CheckContiguous(), true);
 
@@ -551,7 +578,7 @@ void TopK(const nnvm::NodeAttrs& attrs,
     });
   } else {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], 
outputs, param);
+      TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, 
inputs[0], outputs, param);
     });
   }
 }
@@ -569,7 +596,8 @@ void Sort(const nnvm::NodeAttrs& attrs,
   topk_param.k = 0;
   topk_param.ret_typ = topk_enum::kReturnValue;
   MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], 
outputs, topk_param);
+    TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, 
inputs[0],
+                                  outputs, topk_param);
   });
 }
 
@@ -605,30 +633,32 @@ void TopKBackwardImpl(const OpContext &ctx,
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == 
topk_enum::kReturnBoth);
-  int batch_size, element_num;  // number of batches + the size of each batch
+  size_t batch_size = 0;
+  index_t element_num = 0;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   mxnet::TShape target_shape;
   ParseTopKParam(outputs[0].shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, 
&do_transpose, &is_ascend);
   CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
-    << "'IDType' does not have a sufficient precision to represent the indices 
of the input array. "
-    << "The total element_num is " << element_num << ", but the selected 
IDType can only represent "
+    << "'IDType' does not have a sufficient precision to represent "
+    << "the indices of the input array. The total element_num is " << 
element_num
+    << ", but the selected index_t can only represent "
     << mxnet::common::MaxIntegerValue<IDType>() << " elements";
-  Tensor<xpu, 1, int> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k + 
batch_size), s);
-  Tensor<xpu, 1, int> sel_indices =
-    Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
-  Tensor<xpu, 1, int> batch_shift =
-    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), 
s);
+  Tensor<xpu, 1, index_t> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + 
batch_size), s);
+  Tensor<xpu, 1, index_t> sel_indices =
+    Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
+  Tensor<xpu, 1, index_t> batch_shift =
+    Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, 
Shape1(batch_size), s);
 
   Tensor<xpu, 2, DType> out_grad =
     inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 
1), s);
   Tensor<xpu, 2, DType> in_grad =
     outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 
1), s);
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, 
kWriteTo,
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, index_t{0}, 
element_num, kWriteTo,
                                            batch_shift.dptr_);
   if (do_transpose) {
     Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
@@ -639,13 +669,13 @@ void TopKBackwardImpl(const OpContext &ctx,
                                          mxnet::TShape(Shape3(src_shape[0], 
src_shape[2], k))),
                             Shape3(0, 2, 1)),
                           Shape1(batch_size * k));
-    sel_indices += tcast<int>(indices);
+    sel_indices += tcast<index_t>(indices);
     sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], 
src_shape[2], src_shape[1]),
                                     Shape3(0, 2, 1));
   } else {
     Tensor<xpu, 2, IDType> indices =
       inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
-    sel_indices = reshape(tcast<int>(indices) +
+    sel_indices = reshape(tcast<index_t>(indices) +
                           broadcast_to(inplace_reshape(batch_shift, 
Shape2(batch_size, 1)),
                                        mxnet::TShape(Shape2(batch_size, k))),
                           Shape1(batch_size * k));
@@ -680,7 +710,7 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
     });
   } else if (param.ret_typ == topk_enum::kReturnValue) {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
+      TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
     });
   } else {
     LOG(FATAL) << "Not Implemented";
@@ -715,14 +745,11 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
   size_t out_size = out_attrs->size();
   CHECK_EQ(in_size, 1);
   CHECK(out_size == 1 || out_size == 2);
+  //  out_attr[0] -> stores value
+  //  out_attr[1] -> stores indices
   if (out_size > 1) {
-    if (param.ret_typ == topk_enum::kReturnValue) {
-      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+    CHECK(type_assign(&(*out_attrs)[1], param.dtype))
         << "Failed to set the type of ret_indices.";
-    } else {
-      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
-        << "Failed to set the type of ret_indices.";
-    }
   }
   if (param.ret_typ == topk_enum::kReturnIndices) {
     CHECK(type_assign(&(*out_attrs)[0], param.dtype))
@@ -752,11 +779,12 @@ inline bool TopKShapeImpl(const TopKParam& param,
     CHECK_EQ(out_attrs->size(), 2U);
   }
   mxnet::TShape& in_shape = (*in_attrs)[0];
-  int batch_size, element_num;  // number of batches + the size of each batch
+  size_t batch_size = 0;
+  index_t element_num = 0;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   mxnet::TShape target_shape;
   ParseTopKParam(in_shape, param,
     &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, 
&is_ascend);
@@ -785,8 +813,12 @@ inline bool SortType(const nnvm::NodeAttrs& attrs,
   size_t out_size = out_attrs->size();
   CHECK_EQ(in_size, 1);
   CHECK_EQ(out_size, 2);
+#if MXNET_USE_INT64_TENSOR_SIZE == 1
+  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
+#else
   CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
-          << "Failed to set the type of ret_indices to int32.";
+#endif
+      << "Failed to set the type of ret_indices";
   CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of 
input, in_attrs[0]="
                                                  << (*in_attrs)[0];
   CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of 
output, out_attrs[0]="
@@ -816,7 +848,7 @@ inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *out_attrs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
   CHECK(type_assign(&(*out_attrs)[0], param.dtype))
-          << "Failed to set the type of ret_indices to int32.";
+      << "Failed to set the type of ret_indices.";
   return true;
 }
 
diff --git a/tests/nightly/test_large_array.py 
b/tests/nightly/test_large_array.py
index cbba608..0df481a 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -326,6 +326,32 @@ def test_softmax():
     assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
 
 
+def test_argsort():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
+    mx.nd.waitall()
+    assert (s[0].asnumpy() == (LARGE_X - 1)).all()
+
+
+def test_sort():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    s = nd.sort(b, axis=0, is_ascend=False)
+    assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all()
+    s = nd.sort(b, is_ascend=False)
+    assert np.sum(s[0].asnumpy() == 0).all()
+
+
+def test_topk():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    k = nd.topk(b, k=10, axis=0, dtype=np.int64)
+    assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
+    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", 
is_ascend=False)
+    assert np.all(ind == val)
+    b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
+    l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
+    assert l.sum() == np.sum(np.arange(0, SMALL_Y))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_ndarray.py 
b/tests/python/unittest/test_ndarray.py
index e531590..d84b4f0 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -29,6 +29,7 @@ from mxnet.test_utils import default_context
 from mxnet.test_utils import np_reduce
 from mxnet.test_utils import same
 from mxnet.test_utils import random_sample, rand_shape_nd
+from mxnet import runtime
 from numpy.testing import assert_allclose
 import mxnet.autograd
 
@@ -747,6 +748,7 @@ def test_linspace():
 def test_order():
     ctx = default_context()
     dat_size = 5
+    is_large_tensor_enabled = 
runtime.Features().is_enabled('INT64_TENSOR_SIZE')
     def gt_topk(dat, axis, ret_typ, k, is_ascend):
         if ret_typ == "indices":
             if is_ascend:
@@ -819,7 +821,11 @@ def test_order():
 
         # test for ret_typ=indices
         nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, 
is_ascend=True).asnumpy()
-        assert nd_ret_topk.dtype == np.float32  # Test the default dtype
+        # Test the default dtype
+        if is_large_tensor_enabled:
+            assert nd_ret_topk.dtype == np.int64
+        else:
+            assert nd_ret_topk.dtype == np.int32
         gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
         assert_almost_equal(nd_ret_topk, gt)
         nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, 
is_ascend=False, dtype=np.float64).asnumpy()
@@ -860,7 +866,10 @@ def test_order():
         nd_ret_topk_val = nd_ret_topk_val.asnumpy()
         nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
         assert nd_ret_topk_val.dtype == dtype
-        assert nd_ret_topk_ind.dtype == np.float32
+        if is_large_tensor_enabled:
+            assert nd_ret_topk_ind.dtype == np.int64
+        else:
+            assert nd_ret_topk_ind.dtype == np.int32
         gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
         gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
         assert_almost_equal(nd_ret_topk_val, gt_val)

Reply via email to