anirudh2290 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator
URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198700885
########## File path: src/operator/tensor/indexing_op.h ########## @@ -321,10 +309,53 @@ struct Take { MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int M, const int K) { int j = static_cast<int>(idx[i/M]); - if (j <= 0) j = 0; - else if (j >= K) j = K - 1; + if (clip) { + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + } else { + j = j % K; + j += (j < 0) ? K : 0; + } out_data[i] = in_data[j * M + i % M]; } + + /*! + * \brief Map function for take operator + * \param i global thread id + * \param out_data ptr to output buffer + * \param in_data ptr to input buffer + * \param idx ptr to indices buffer + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template<typename DType, typename IType> + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, + const mshadow::Shape<10> in_stride, + const mshadow::Shape<10> out_stride, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis_dim, const int axis) { + // i is the global flattened index in the output + const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]); Review comment: okay makes sense. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services