This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0455a11  add handling for grad req type other than kNullOp for indices (#11983)
0455a11 is described below

commit 0455a112c9aaa681b3e3e194975eb46d8ac4fcc3
Author: Hao Jin <haoj...@users.noreply.github.com>
AuthorDate: Wed Aug 15 15:31:14 2018 -0700

    add handling for grad req type other than kNullOp for indices (#11983)
---
 src/operator/tensor/indexing_op.h      |  9 +++++++--
 tests/python/unittest/test_operator.py | 27 +++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index edaf939..1daf0a2 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1034,8 +1034,8 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 2U);
-  CHECK_EQ(req[take_::kIdx], kNullOp)
-    << "take layer doesn't support gradient into index";
+  CHECK_NE(req[take_::kIdx], kAddTo)
+    << "take layer doesn't support gradient of req type kAddTo to index";
 
   const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
 
@@ -1052,6 +1052,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
       const TShape& arrshape = outputs[0].shape_;
       const TShape& oshape = inputs[0].shape_;
 
+      if (req[take_::kIdx] != kNullOp) {
+        mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>());
+      }
+
       const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
       int idxndim = idxshape.ndim();
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 66e850f..f1aec12 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3812,6 +3812,31 @@ def test_take():
         exe.backward([mx.nd.array(grad_out)])
         assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in)
 
+    def check_autograd_req():
+        row_len = 2
+        col_len = 8
+        shape = (row_len, col_len)
+        sc = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype="float32")
+        sc.attach_grad()
+        i = mx.nd.array([0], dtype="int64")
+        j = mx.nd.array([0], dtype="int64")
+        with mx.autograd.record(train_mode=True):
+            xs = []
+            for _ in range(row_len):
+                x_i = []
+                for _ in range(col_len):
+                    x_ij = sc.take(i).squeeze(axis=0).take(j).squeeze(axis=0)
+                    x_i.append(x_ij)
+                    j = j + 1
+                i = i + 1
+                j = j - col_len  # reset j
+                xs.append(mx.nd.stack(*x_i))
+            x = mx.nd.stack(*xs)
+            x = x.sum()
+
+        x.backward()
+        assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())
+
     for mode in ['clip', 'wrap']:
         for data_ndim in range(1, 5):
             for idx_ndim in range(1, 4):
@@ -3824,6 +3849,8 @@ def test_take():
                         idx_shape += (np.random.randint(low=1, high=5), )
                     check_output_n_grad(data_shape, idx_shape, axis, mode)
 
+    check_autograd_req()
+
 
 @with_seed()
 def test_grid_generator():
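
For reference, a minimal sketch of the scenario this patch addresses, assuming the standard MXNet NDArray/autograd API (array names, shapes, and the exact triggering conditions below are illustrative, not taken from the patch): when the indices passed to take() are themselves the output of a recorded op, the take backward can see a grad req other than kNullOp for the index input; with this change the kernel zero-fills the index gradient instead of failing the old CHECK, and the data gradient is computed as before.

    import mxnet as mx

    data = mx.nd.random.uniform(-1.0, 1.0, shape=(3, 4))
    data.attach_grad()
    idx = mx.nd.array([0], dtype='int64')
    with mx.autograd.record():
        # 'idx + 1' is a recorded op, so take's backward may receive a
        # non-kNullOp grad req for the index input; the patched kernel
        # writes zeros for that gradient instead of aborting.
        out = data.take(idx + 1).sum()
    out.backward()
    print(data.grad.asnumpy())  # row 1 is all ones, rows 0 and 2 stay zero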
