haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r199881853
########## File path: src/operator/tensor/indexing_op.h ########## @@ -805,17 +836,259 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const TShape& oshape = outputs[take_::kOut].shape_; Stream<xpu> *s = ctx.get_stream<xpu>(); + const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type - Kernel<Take, xpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr<DType>(), - inputs[take_::kArr].dptr<DType>(), - inputs[take_::kIdx].dptr<IType>(), - oshape.Size()/idxshape.Size(), arrshape[0]); + if (actual_axis == 0) { + if (param.mode == take_::kClip) { + Kernel<Take<true>, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } else { + Kernel<Take<false>, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } + } else { + mshadow::Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + mshadow::Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + if (param.mode == take_::kClip) { + Kernel<Take<true>, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } else if (param.mode == take_::kWrap) { + Kernel<Take<false>, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), 
+ in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } + } + }); + }); +} + +struct TakeGradGeneralKernel { + /*! + * \brief Map function for general case of take grad + * \param tid global thread id + * \param arr_grad ptr to in_grad + * \param ograd ptr to out_grad + * \param src_indptr ptr to indptr to src indices + * \param original_idx ptr to original indices of the inputs + * \param in_strides strides of inputs + * \param out_strides strides of outputs + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template<typename DType, typename IType> + MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, + const IType* src_indptr, const IType* original_idx, + mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis) { + const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; + const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; + const int in_mid_index = in_rest_index / in_strides[axis]; + const int in_tail_index = (axis == in_ndims - 1) ? + 0 : (in_rest_index % in_strides[axis]); + for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { + const int out_mid_index = original_idx[i]; + int target = in_tail_index + out_mid_index * in_strides[axis]; + target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; + arr_grad[tid] += ograd[target]; + } + } +}; + +template<bool clip = true> +void TakeOpBackwardImpl(mshadow::Stream<cpu>* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { + // get size of temporary storage for sort + char* temp_storage_ptr = nullptr; + int* src_indptr_ptr = nullptr; + size_t temp_storage_bytes = SortByKeyWorkspaceSize<int, int, cpu>(idxshape.Size()); + size_t original_idx_bytes = idxshape.Size() * sizeof(int); + size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor<cpu, 1, char> workspace = + ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), s); + int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_); + int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes); + temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + // Reset indptr to zero + Kernel<set_zero, cpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel<range_fwd, cpu>::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + Kernel<mshadow_op::identity_with_cast, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>()); + if (clip) { + Kernel<op_with_req<mshadow_op::clip, kWriteTo>, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, + 0, static_cast<int>(arrshape[axis] - 1)); + } else { + 
Kernel<op_with_req<mshadow_op::mod, kWriteTo>, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis])); + } + Tensor<cpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + Tensor<cpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); Review comment: Sorry, I did not notice this comment earlier. The tensor is used purely for the SortByKey function call, so declaring it closer to that call makes more sense. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services