piiswrong closed pull request #9200: Fix the gradient of gather_nd
URL: https://github.com/apache/incubator-mxnet/pull/9200
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

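Summary of the change (editorial note): the gradient of `gather_nd` was previously
computed with `scatter_nd`, which handles duplicate indices non-deterministically
instead of accumulating them, so gradients could be wrong whenever the same output
location was gathered more than once. The patch adds a dedicated `_backward_gather_nd`
operator that accumulates contributions on both CPU and GPU (the GPU path needs a new
`atomicAdd` overload for signed int64), registers it as the gradient of `gather_nd`,
and drops the claim that `gather_nd` and `scatter_nd` are inverse functions.

A minimal usage sketch, adapted from the new unit test (it assumes an MXNet build
that contains this patch)::

    import mxnet as mx

    data = mx.nd.array([2, 3, 0])
    # The index (1, 1) appears twice, so its contributions must be summed.
    idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype='int32')
    out = mx.nd._internal._backward_gather_nd(data, idx, shape=(2, 2))
    print(out.asnumpy())  # accumulates 2 + 3 at (1, 1): [[0, 0], [0, 5]]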

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index a1c37a9478..9d3388b235 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
   } while (assumed != old);
 }
 
+// Overload atomicAdd to work for signed int64 on all architectures
+static inline  __device__  void atomicAdd(int64_t *address, int64_t val) {
+  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val));  // NOLINT
+}
+
 template <typename DType>
 __device__ inline DType ldg(const DType* address) {
 #if __CUDA_ARCH__ >= 350
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 15ad59f552..081e40a621 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -132,6 +132,50 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "ndim=" << NDim << "too large "; \
   }
 
+#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...)        \
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    {                                                      \
+      typedef mshadow::half::half_t DType;                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    LOG(FATAL) << "This operation does not "               \
+                  "support int8 or uint8";                 \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    LOG(FATAL) << "This operation does not "               \
+                  "support int8 or uint8";                 \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
+
 
 /*!
  * \brief assign the val to out according
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 735da31b8b..10905b538f 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
 }
 
 
+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+#pragma omp parallel for
+  for (int i = 0; i < N; i++) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+#pragma omp atomic
+      out[offset + j] += data[i * K + j];
+    }
+  }
+}
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+  for (int i = 0; i < N; i++) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+      out[offset + j] += data[i * K + j];
+    }
+  }
+}
+
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -443,8 +483,7 @@ Examples::
 
 NNVM_REGISTER_OP(gather_nd)
 .describe(R"code(Gather elements or slices from `data` and store to a tensor 
whose
-shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
-to each other.
+shape is defined by `indices`.
 
 Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
@@ -476,13 +515,14 @@ Examples::
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     auto p = nnvm::Node::Create();
-    p->attrs.op = nnvm::Op::Get("scatter_nd");
+    p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
     p->attrs.name = n->attrs.name + "_backward";
     p->inputs.push_back(ograds[0]);
     p->inputs.push_back(n->inputs[1]);
     p->control_deps.emplace_back(n);
     auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
                          {n->inputs[1]}, nullptr, &n);
+
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
     ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
@@ -492,10 +532,8 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol", "data")
 .add_argument("indices", "NDArray-or-Symbol", "indices");
 
-
 NNVM_REGISTER_OP(scatter_nd)
 .describe(R"code(Scatters data into a new tensor according to indices.
-`gather_nd` and `scatter_nd` are inverse functions to each other.
 
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
@@ -510,6 +548,12 @@ The elements in output is defined as follows::
 
 all other entries in output are 0.
 
+.. warning::
+
+    If the indices have duplicates, the result will be non-deterministic and
+    the gradient of `scatter_nd` will not be correct!!
+
+
 Examples::
 
   data = [2, 3, 0]
@@ -548,11 +592,73 @@ Examples::
 .add_argument("indices", "NDArray-or-Symbol", "indices")
 .add_arguments(ScatterNDParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_backward_gather_nd)
+.describe(R"code(Accumulates data according to indices and get the result. 
It's the backward of
+`gather_nd`.
+
+Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
+`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
+where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
+
+The elements in output is defined as follows::
+
+  output[indices[0, y_0, ..., y_{K-1}],
+         ...,
+         indices[M-1, y_0, ..., y_{K-1}],
+         x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+all other entries in output are 0 or the original value if AddTo is triggered.
+
+Examples::
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [0, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
+
+  # The difference between scatter_nd and scatter_nd_acc is the latter will accumulate
+  #  the values that point to the same index.
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [1, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
+
+)code")
+.set_num_outputs(1)
+.set_num_inputs(2)
+.set_attr_parser(ParamParser<ScatterNDParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "indices"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
+.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
+.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    auto p = nnvm::Node::Create();
+    p->attrs.op = nnvm::Op::Get("gather_nd");
+    p->attrs.name = n->attrs.name + "_backward";
+    p->inputs.push_back(ograds[0]);
+    p->inputs.push_back(n->inputs[1]);
+    p->control_deps.emplace_back(n);
+    auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
+                         {n->inputs[1]}, nullptr, &n);
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
+    ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
+    return ret;
+  })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.add_argument("data", "NDArray-or-Symbol", "data")
+.add_argument("indices", "NDArray-or-Symbol", "indices")
+.add_arguments(ScatterNDParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .describe(R"code(This operator has the same functionality as scatter_nd
 except that it does not reset the elements not indexed by the input
 index `NDArray` in the input data `NDArray`.
-
 .. note:: This operator is for internal use only.
 
 Examples::
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 4021f2b3a2..762d8fd64c 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -179,6 +179,32 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   });
 }
 
+struct backward_gather_nd_gpu {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+                                  const mshadow::Shape<10> strides,
+                                  DType* out, const DType* data,
+                                  const IType* indices) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+      atomicAdd(out + (offset + j), data[i * K + j]);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s) {
+  mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
 
@@ -209,6 +235,9 @@ NNVM_REGISTER_OP(gather_nd)
 NNVM_REGISTER_OP(scatter_nd)
 .set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
 
+NNVM_REGISTER_OP(_backward_gather_nd)
+.set_attr<FCompute>("FCompute<gpu>", GatherNDBackward<gpu>);
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
 }  // namespace op
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 4043e76cfd..7323f81c09 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   int K = oshape.ProdShape(M, oshape.ndim());
   mshadow::Shape<10> strides;
  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {
+    Fill<true>(s, outputs[0], req[0], 0);
+  }
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
-    if (kWriteTo == req[0]) {
-      Fill<true>(s, outputs[0], req[0], 0);
-    }
    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
       mxnet_op::Kernel<scatter_nd, xpu>::Launch(
         s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1143,6 +1143,64 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s);
+
+template<typename xpu>
+void GatherNDBackward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TShape& oshape = outputs[0].shape_;
+  const TShape& ishape = inputs[1].shape_;
+  int M = ishape[0];
+  int N = ishape.Size() / M;
+  int K = oshape.ProdShape(M, oshape.ndim());
+  mshadow::Shape<10> strides;
+  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {
+    Fill<true>(s, outputs[0], req[0], 0);
+  }
+  MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
+      GatherNDBackwardImpl(N, M, K, strides,
+                           outputs[0].dptr<DType>(),
+                           inputs[0].dptr<DType>(),
+                           inputs[1].dptr<IType>(),
+                           s);
+    });
+  });
+}
+
 /*!
  * This is for internal use only.
  * DO NOT call this function unless you have to.
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ba1b99183f..4b980b5d04 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4315,21 +4315,32 @@ def check(data, idx):
         npdata = np.zeros_like(data.asnumpy())
         npdata[npidx] = y.asnumpy()
         assert (npdata == data.grad.asnumpy()).all()
-        assert (mx.nd.scatter_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
-
-    data = mx.nd.arange(360, dtype='int32').reshape((3,4,5,6))
-    idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
-
-    check(data, idx)
-
-    idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
-
-    check(data, idx)
-
-    data = mx.nd.array([2, 3, 0])
-    idx = mx.nd.array([[1, 1, 0], [0, 1, 0]])
-
-        assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+        assert (mx.nd._internal._backward_gather_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
+    for dtype in ['int32', 'int64', 'float16', 'float32', 'float64']:
+        data = mx.nd.arange(360, dtype=dtype).reshape((3,4,5,6))
+        idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
+        check(data, idx)
+
+        idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
+
+        check(data, idx)
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype='int32')
+        assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [0, 5]]).all()
+        data_npy = np.random.randint(0, 10, (100,))
+        data = mx.nd.array(data_npy, dtype=dtype)
+        idx = mx.nd.zeros(shape=(1, 100), dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data_npy.sum())
+        if dtype == 'int64':
+            data = mx.nd.array([2123162361283621, -31231236374787,
+                                -112372937128970, -1378278798172378], dtype=dtype)
+            idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
+            assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())
 
 def compare_forw_backw_unary_op(
         name, forward_mxnet_call, forward_numpy_call,
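
For reference, the accumulation that `_backward_gather_nd` performs can be sketched
in NumPy as follows. This is an illustrative editorial sketch, not part of the patch
(the helper name `backward_gather_nd_ref` is invented here); `np.add.at` plays the
role of the atomic/OpenMP accumulation in the C++ and CUDA kernels above::

    import numpy as np

    def backward_gather_nd_ref(data, indices, shape):
        """Reference for _backward_gather_nd: accumulate `data` slices into `shape`.

        data    : array of shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
        indices : integer array of shape (M, Y_0, ..., Y_{K-1})
        shape   : output shape (X_0, ..., X_{N-1})
        """
        M = indices.shape[0]
        out = np.zeros(shape, dtype=data.dtype)
        # Flatten the Y_* axes so each column of `indices` addresses one output slice.
        idx = tuple(indices.reshape(M, -1))
        src = data.reshape((-1,) + tuple(shape[M:]))
        # Accumulate (rather than overwrite) slices that map to the same index.
        np.add.at(out, idx, src)
        return out

    # Matches the docstring example: duplicate indices are summed.
    out = backward_gather_nd_ref(np.array([2, 3, 0]),
                                 np.array([[1, 1, 0], [1, 1, 0]]),
                                 (2, 2))
    # out == [[0, 0], [0, 5]]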


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
