access2rohit commented on a change in pull request #18168:
URL: https://github.com/apache/incubator-mxnet/pull/18168#discussion_r418388922



##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1049,29 +1049,55 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
   ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
 }
 
+namespace {  // unnamed namespace to keep scope of the struct within the file
+struct shape_and_stride {
+  index_t in_stride[MXNET_SPECIAL_MAX_NDIM];
+  index_t out_stride[MXNET_SPECIAL_MAX_NDIM];
+  index_t input_shape[MXNET_SPECIAL_MAX_NDIM];
+  index_t output_shape[MXNET_SPECIAL_MAX_NDIM];
+};
+
+inline void PrepareAUXData(shape_and_stride *aux_data,
+                    mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
+                    mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                    int ndim) {
+  index_t iter = ndim - 1;
+  aux_data->out_stride[iter] = 1;
+  aux_data->in_stride[iter] = 1;
+  aux_data->input_shape[iter] = in_shape[iter];
+  aux_data->output_shape[iter] = out_shape[iter];
+  iter--;
+  for (; iter >= 0; --iter) {
+    aux_data->out_stride[iter] = aux_data->out_stride[iter+1] * out_shape[iter+1];
+    aux_data->in_stride[iter] = aux_data->in_stride[iter+1] * in_shape[iter+1];
+    aux_data->input_shape[iter] = in_shape[iter];
+    aux_data->output_shape[iter] = out_shape[iter];
+  }
+}
+}  // unnamed namespace
+
 template<typename OP>
 struct broadcast_kernel {
   template<typename IType, typename OType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
                                   OType *output,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                                  const shape_and_stride& aux_data,
                                   const OpReqType req,
-                                  const uint32_t ndim) {
-    size_t in_stride = 1;
-    size_t out_stride = 1;
+                                  const int ndim) {
     index_t idx = i;
     index_t in_idx = i;
-    for (int iter = ndim - 1; iter >= 0; --iter) {
-      size_t dim_idx = idx % out_shape[iter];
-      in_idx -= dim_idx * out_stride;
-      if (in_shape[iter] != 1) {
-        in_idx += dim_idx * in_stride;
+#pragma unroll 4
+    for (index_t iter = ndim - 1; iter >= 0; --iter) {
+      index_t out_dim_shape = aux_data.output_shape[iter];
+      index_t out_dim_stride = aux_data.out_stride[iter];
+      index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape;

Review comment:
       The compiler doesn't do that:
   nvcc is not intelligent enough to do this.
   gcc doesn't need to, since this doesn't slow down CPU performance in any measurable way.
   
   I will add a comment to explain that this is a modulo operation.
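
   For reference, a minimal standalone sketch (illustrative only, not part of this PR) of the identity the kernel relies on, i.e. that idx - (idx / n) * n computes idx % n for non-negative operands:

       // Standalone check: manual modulo vs. the % operator.
       // Assumptions: index_t is a signed 64-bit integer, idx >= 0, n > 0.
       #include <cassert>
       #include <cstdint>

       int main() {
         using index_t = int64_t;
         for (index_t idx = 0; idx < 1000; ++idx) {
           for (index_t n = 1; n < 64; ++n) {
             // Same form as the kernel's dim_idx computation.
             index_t dim_idx = idx - (idx / n) * n;
             assert(dim_idx == idx % n);
           }
         }
         return 0;
       }

   The benefit here is presumably that the quotient idx / out_dim_shape is already needed by the unravel loop, so deriving the remainder from it avoids a separate modulo; that is the transformation nvcc does not perform automatically.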




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

