sxjscience commented on a change in pull request #18089: URL: https://github.com/apache/incubator-mxnet/pull/18089#discussion_r430856261
########## File path: src/operator/tensor/index_add_backward.cc ########## @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index_add-inl.cc + * \brief CPU implementation of index_add operator +*/ +#include <vector> +#include "./index_add-inl.h" + +namespace mxnet { +namespace op { + +template<typename xpu, typename DType> +void IndexAddOpBackwardACalc(mshadow::Stream<xpu> *s, + DType* grad_a, const DType* ograd, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> stride, + const int tail_size, const int ind_num, + const int ind_ndim, const int32_t* ind_vec, + const int req, const int out_ndim) { + using namespace mxnet_op; + using namespace mshadow; + Kernel<IndexAddBackwardAKernel<DType>, xpu>::Launch( + s, ind_num, grad_a, ograd, stride, tail_size, ind_num, ind_ndim, ind_vec, req, out_ndim); +} + +template<typename DType> +struct IndexAddBackwardValCPUKernel { + MSHADOW_XINLINE static void Map(size_t i, DType* grad_val, + const DType* ograd, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape, + const int ograd_tail_size, const int ind_num, + const int ind_ndim, const int32_t* ind_vec, + const int out_ndim) { + index_t id = 0; + int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim; + for (int dim = 0; dim < ind_ndim; ++dim) { + id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i]; + } + id *= ograd_tail_size; + #pragma omp parallel for + for (int _i = 0; _i < ograd_tail_size; ++_i) { + mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id = + mxnet_op::unravel(_i, ograd_tail_shape); + mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id; + for (int _j = seg; _j < seg + out_ndim; ++_j) { + val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j]; + } + val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i; + index_t val_dest = mxnet_op::dot(val_id, val_stride); + #pragma omp critical + { + grad_val[val_dest] += ograd[id + _i]; + } + } + } +}; + +template<typename xpu, typename DType> +void IndexAddOpBackwardValCalc(mshadow::Stream<xpu> *s, + DType* grad_val, const DType* ograd, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride, + const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape, + const int tail_size, const int ind_num, + const int ind_ndim, const int32_t* ind_vec, + const int out_ndim) { + using namespace mxnet_op; + using namespace mshadow; + Kernel<IndexAddBackwardValCPUKernel<DType>, xpu>::Launch( + s, ind_num, grad_val, ograd, ograd_tail_shape, ograd_pre_stride, + val_stride, val_shape, tail_size, ind_num, ind_ndim, ind_vec, out_ndim); +} Review comment: It's not necessary to use Kernel::Launch. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: users@infra.apache.org