haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator
URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198671449
########## File path: src/operator/tensor/indexing_op.h ##########

@@ -805,17 +836,259 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[take_::kOut].shape_;
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
-      Kernel<Take, xpu>::Launch(s, oshape.Size(),
-                                outputs[take_::kOut].dptr<DType>(),
-                                inputs[take_::kArr].dptr<DType>(),
-                                inputs[take_::kIdx].dptr<IType>(),
-                                oshape.Size()/idxshape.Size(), arrshape[0]);
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
+struct TakeGradGeneralKernel {
+  /*!
+   * \brief Map function for the general case of take grad
+   * \param tid global thread id
+   * \param arr_grad ptr to in_grad
+   * \param ograd ptr to out_grad
+   * \param src_indptr ptr to indptr of src indices
+   * \param original_idx ptr to original indices of the inputs
+   * \param in_strides strides of inputs
+   * \param out_strides strides of outputs
+   * \param in_ndims # of dims of input tensor
+   * \param out_ndims # of dims of output tensor
+   * \param idx_ndims # of dims of indices tensor
+   * \param axis axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
+                                  const IType* src_indptr, const IType* original_idx,
+                                  mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis) {
+    const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+    const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+    const int in_mid_index = in_rest_index / in_strides[axis];
+    const int in_tail_index = (axis == in_ndims - 1) ?
+                              0 : (in_rest_index % in_strides[axis]);
+    for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+      const int out_mid_index = original_idx[i];
+      int target = in_tail_index + out_mid_index * in_strides[axis];
+      target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+      arr_grad[tid] += ograd[target];
+    }
+  }
+};
+
+template<bool clip = true>
+void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
+                        const OpContext& ctx,
+                        const TBlob& arr,
+                        const TBlob& idx,
+                        const TBlob& ograd,
+                        const int axis) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
+  const TShape& arrshape = arr.shape_;
+  const TShape& idxshape = idx.shape_;
+  const TShape& oshape = ograd.shape_;
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    // get size of temporary storage for sort
+    char* temp_storage_ptr = nullptr;
+    int* src_indptr_ptr = nullptr;

Review comment:
   Since we're using cub's DeviceHistogram for histogramming the indices here, we need to stick to int32; int32 should suffice for now. Alternatively, we could switch to our own histogram kernel, which supports all types, but that would be slower than cub's implementation.
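   To make the int32 constraint concrete, here is a minimal, self-contained sketch of the kind of preprocessing being discussed: cub's DeviceHistogram counts how many taken indices fall on each slot of the indexed axis, and a prefix sum turns those counts into CSR-style offsets like src_indptr. This is illustrative only, not the PR's actual code; the variable names, shapes, and index values are made up.

// take_grad_indptr_sketch.cu -- illustrative sketch, assuming cub's usual
// two-phase (size query, then run) workspace convention.
#include <cub/cub.cuh>
#include <cstdio>

int main() {
  const int axis_dim = 4;                       // length of the indexed axis
  const int num_idx  = 8;                       // number of indices taken
  const int h_idx[num_idx] = {1, 3, 0, 1, 2, 3, 3, 0};

  int *d_idx, *d_hist, *d_indptr;
  cudaMalloc(&d_idx, num_idx * sizeof(int));
  cudaMalloc(&d_hist, axis_dim * sizeof(int));
  cudaMalloc(&d_indptr, (axis_dim + 1) * sizeof(int));
  cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
  cudaMemset(d_indptr, 0, sizeof(int));         // indptr[0] = 0

  // Phase 1: a null workspace pointer only queries the workspace sizes.
  size_t hist_bytes = 0, scan_bytes = 0;
  cub::DeviceHistogram::HistogramEven(nullptr, hist_bytes, d_idx, d_hist,
                                      axis_dim + 1 /* num_levels */,
                                      0 /* lower */, axis_dim /* upper */,
                                      num_idx);
  cub::DeviceScan::InclusiveSum(nullptr, scan_bytes, d_hist,
                                d_indptr + 1, axis_dim);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, hist_bytes > scan_bytes ? hist_bytes : scan_bytes);

  // Phase 2: unit-width bins, so bin i simply counts occurrences of index
  // value i. Note the int* counters -- this is the int32 restriction
  // mentioned in the comment above.
  cub::DeviceHistogram::HistogramEven(d_temp, hist_bytes, d_idx, d_hist,
                                      axis_dim + 1, 0, axis_dim, num_idx);
  // indptr[i + 1] = running total of counts, i.e. CSR-style row offsets.
  cub::DeviceScan::InclusiveSum(d_temp, scan_bytes, d_hist,
                                d_indptr + 1, axis_dim);

  int h_indptr[axis_dim + 1];
  cudaMemcpy(h_indptr, d_indptr, sizeof(h_indptr), cudaMemcpyDeviceToHost);
  for (int i = 0; i <= axis_dim; ++i)
    printf("indptr[%d] = %d\n", i, h_indptr[i]);
  // Expected for the indices above: 0 2 4 5 8.
  cudaFree(d_idx); cudaFree(d_hist); cudaFree(d_indptr); cudaFree(d_temp);
  return 0;
}

   The grad kernel then walks indptr[mid] .. indptr[mid + 1] to find every output row that took this input slot, which is why the counts have to be exact per slot.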
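   Separately, since the (head, mid, tail) arithmetic in TakeGradGeneralKernel is easy to misread, here is a small host-side illustration of how a flat element id is decomposed using the precomputed row-major strides. Again a sketch with a made-up shape and axis, not code from the PR.

// decompose_sketch.cc -- head indexes the dims before `axis`, mid is the
// coordinate along `axis`, tail indexes the trailing dims.
#include <cstdio>

int main() {
  const int ndim = 3;
  const int shape[ndim] = {2, 4, 3};   // input tensor shape
  const int axis = 1;                  // the indexed axis

  // Row-major strides, built the same way as in_strides in TakeOpForward.
  int strides[ndim];
  int stride = 1;
  for (int i = ndim - 1; i >= 0; stride *= shape[i], --i) strides[i] = stride;

  for (int tid = 0; tid < 2 * 4 * 3; ++tid) {
    const int head = (axis == 0) ? 0 : tid / strides[axis - 1];
    const int rest = (axis == 0) ? tid : tid % strides[axis - 1];
    const int mid  = rest / strides[axis];
    const int tail = (axis == ndim - 1) ? 0 : rest % strides[axis];
    // e.g. tid=13 in shape (2, 4, 3) -> head=1 mid=0 tail=1
    printf("tid=%2d -> head=%d mid=%d tail=%d\n", tid, head, mid, tail);
  }
  return 0;
}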