haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator
URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198671449
 
 

 ##########
 File path: src/operator/tensor/indexing_op.h
 ##########
 @@ -805,17 +836,259 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[take_::kOut].shape_;
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
-      Kernel<Take, xpu>::Launch(s, oshape.Size(),
-                                outputs[take_::kOut].dptr<DType>(),
-                                inputs[take_::kArr].dptr<DType>(),
-                                inputs[take_::kIdx].dptr<IType>(),
-                                oshape.Size()/idxshape.Size(), arrshape[0]);
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
+struct TakeGradGeneralKernel {
+  /*!
+   * \brief Map function for general case of take grad
+   * \param tid           global thread id
+   * \param arr_grad      ptr to in_grad
+   * \param ograd         ptr to out_grad
+   * \param src_indptr    ptr to indptr to src indices
+   * \param original_idx  ptr to original indices of the inputs
+   * \param in_strides    strides of inputs
+   * \param out_strides   strides of outputs
+   * \param in_ndims      # of dims of input tensor
+   * \param out_ndims     # of dims of output tensor
+   * \param idx_ndims     # of dims of indices tensor
+   * \param axis          axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
+                                  const IType* src_indptr, const IType* original_idx,
+                                  mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis) {
+    const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+    const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+    const int in_mid_index = in_rest_index / in_strides[axis];
+    const int in_tail_index = (axis == in_ndims - 1) ?
+                              0 : (in_rest_index % in_strides[axis]);
+    for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+      const int out_mid_index = original_idx[i];
+      int target = in_tail_index + out_mid_index * in_strides[axis];
+      target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+      arr_grad[tid] += ograd[target];
+    }
+  }
+};
+
+template<bool clip = true>
+void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
+                        const OpContext& ctx,
+                        const TBlob& arr,
+                        const TBlob& idx,
+                        const TBlob& ograd,
+                        const int axis) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
+  const TShape& arrshape = arr.shape_;
+  const TShape& idxshape = idx.shape_;
+  const TShape& oshape = ograd.shape_;
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    // get size of temporary storage for sort
+    char* temp_storage_ptr = nullptr;
+    int* src_indptr_ptr = nullptr;
 
 Review comment:
   Since we're using cub's DeviceHistogram for the histogramming of indices here, we need to stick to int32; for now int32 should suffice. Alternatively, we could switch to our own histogram kernel, which supports all index types, but that would be slower than cub's implementation.
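
   For context, a minimal sketch of the two-phase CUB call this refers to (not the PR's exact code; the `HistogramIndices` helper name, its parameters, and the `cudaMalloc` workspace handling are illustrative assumptions). CUB accumulates bin counts into a caller-supplied integer counter array, which is why the index histogram / indptr buffers stay int32 regardless of the index dtype:

```cpp
// Hedged sketch, assuming indices have already been converted to int32 on device.
#include <cub/cub.cuh>

cudaError_t HistogramIndices(const int* d_idx,    // device ptr to int32 indices
                             int num_indices,     // number of index values
                             int axis_dim,        // size of the take axis == number of bins
                             int* d_hist,         // [axis_dim] bin counts, int32 as CUB expects
                             cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call with a null workspace only computes the required temp storage size.
  cudaError_t err = cub::DeviceHistogram::HistogramEven(
      d_temp, temp_bytes, d_idx, d_hist,
      axis_dim + 1,      // num_levels = bins + 1
      0, axis_dim,       // sample range [0, axis_dim)
      num_indices, stream);
  if (err != cudaSuccess) return err;
  if ((err = cudaMalloc(&d_temp, temp_bytes)) != cudaSuccess) return err;
  // Second call performs the actual histogramming of the indices.
  err = cub::DeviceHistogram::HistogramEven(
      d_temp, temp_bytes, d_idx, d_hist,
      axis_dim + 1, 0, axis_dim, num_indices, stream);
  cudaFree(d_temp);
  return err;
}
```

   In the operator itself the workspace would presumably come from the op's requested temp resource rather than a raw `cudaMalloc`; the sketch only shows where the int32 constraint on the counter array comes from.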

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
