piiswrong closed pull request #11436: handle the case that inputs and outputs of a graph share NDArrays
URL: https://github.com/apache/incubator-mxnet/pull/11436
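For context, the forward-graph case this PR handles can be reproduced with a minimal Gluon sketch, adapted from the test_share_inputs_outputs test added in the diff below (the block name and shapes here are illustrative; the hybridize flags are taken from that test):

import mxnet as mx
from mxnet import gluon

class Passthrough(gluon.HybridBlock):
    """Forward graph that simply returns its input, so the graph's input
    and output entries refer to the same NDArray."""
    def hybrid_forward(self, F, x):
        return x

net = Passthrough()
# inline_limit=0 keeps the trivial graph from being inlined, so the
# CachedOp static-memory path patched below is actually exercised.
net.hybridize(inline_limit=0, static_alloc=True, static_shape=True)

data = mx.nd.arange(10)
for _ in range(5):
    out = net(data)
    # With the fix, repeated runs keep returning the input's values
    # instead of a clobbered shared entry.
    assert (out.asnumpy() == data.asnumpy()).all()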
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5a3d44c04ce..2181c5cab87 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -591,6 +591,7 @@ void CachedOp::StaticRunOps(
     const Context& default_ctx,
     const nnvm::Graph& g,
     const OpStatePtr& state_ptr,
+    const std::vector<NDArray *> &state_arrays,
     size_t start_nid,
     size_t end_nid) {
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
@@ -624,7 +625,7 @@ void CachedOp::StaticRunOps(
       ndinputs.clear();
       ndinputs.reserve(node.inputs.size());
       for (const auto& j : node.inputs) {
-        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+        ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
         CHECK(!ndinputs.back()->is_none());
       }
       ndoutputs.clear();
@@ -633,7 +634,7 @@ void CachedOp::StaticRunOps(
       req.reserve(num_outputs);
       for (size_t j = 0; j < num_outputs; ++j) {
         size_t eid = idx.entry_id(i, j);
-        ndoutputs.emplace_back(state.arrays[eid]);
+        ndoutputs.emplace_back(state_arrays[eid]);
         req.push_back(state.array_reqs[eid]);
         CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
       }
@@ -688,25 +689,29 @@ OpStatePtr CachedOp::StaticForward(
     StaticAllocMemory(state_ptr, recording, false);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       auto nid = idx.input_nodes()[i];
-      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+      if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
         match = false;
         auto ptr = &state.buff[idx.entry_id(nid, 0)];
-        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
-        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
+        CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr);
+        *arrays[idx.entry_id(nid, 0)] = *inputs[i];
         state.dynamic_entries[idx.entry_id(nid, 0)] = false;
       }
     }
     for (auto i : config_.data_indices) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      state.arrays[eid] = inputs[i];
+      arrays[eid] = inputs[i];
     }
   } else {
     for (size_t i = 0; i < num_inputs(); ++i) {
       auto nid = idx.input_nodes()[i];
-      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+      arrays[idx.entry_id(nid, 0)] = inputs[i];
     }
   }
 
@@ -720,13 +725,16 @@ OpStatePtr CachedOp::StaticForward(
 
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto eid = idx.entry_id(idx.outputs()[i]);
-    state.arrays[eid] = outputs[i];
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
     if (!outputs[i]->is_none()) continue;
     *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                           shapes[eid], default_ctx, true, dtypes[eid]);
   }
 
-  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, 0, idx.num_nodes());
 
   return recording ? state_ptr : OpStatePtr();
 }
@@ -891,7 +899,11 @@ void CachedOp::DynamicBackward(
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
-    arrays[idx.entry_id(idx.outputs()[j++])] = outputs[i];
+    auto eid = idx.entry_id(idx.outputs()[j++]);
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
   }
 
   // Allocate NDArrays
@@ -952,6 +964,15 @@ void CachedOp::StaticBackward(
     StaticAllocMemory(state_ptr, true, true);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
+  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+    auto eid = state.info.bwd_input_eid[i];
+    if (state.dynamic_entries[eid]) arrays[eid] = inputs[i];
+  }
+
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       const auto iter = fwd_input_to_grad_output_.find(i);
@@ -959,11 +980,14 @@ void CachedOp::StaticBackward(
       auto entry = grad_graph_.outputs[iter->second];
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
-      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+      if (!arrays[eid]->IsSame(*outputs[iter->second]) ||
           !(state.array_reqs[eid] == reqs[iter->second])) {
         match = false;
         state.array_reqs[eid] = reqs[iter->second];
-        *state.arrays[eid] = *outputs[iter->second];
+        // An input and an output may share the same array.
+        if (!arrays[eid]->is_none())
+          *outputs[iter->second] = arrays[eid]->Detach();
+        *arrays[eid] = *outputs[iter->second];
         state.dynamic_entries[eid] = false;
       }
     }
@@ -974,7 +998,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[iter->second];
-      state.arrays[eid] = outputs[iter->second];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[iter->second] = arrays[eid]->Detach();
+      arrays[eid] = outputs[iter->second];
     }
   } else {
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
@@ -982,7 +1009,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[i];
-      state.arrays[eid] = outputs[i];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[i] = arrays[eid]->Detach();
+      arrays[eid] = outputs[i];
     }
   }
 
@@ -990,12 +1020,7 @@ void CachedOp::StaticBackward(
     StaticInitExec(state_ptr, true, true);
   }
 
-  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
-    auto eid = state.info.bwd_input_eid[i];
-    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
-  }
-
-  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes());
 }
 
 void CachedOp::Backward(
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 6b94c67a94e..370ef02b5f2 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -154,6 +154,7 @@ class CachedOp {
       const Context& default_ctx,
       const nnvm::Graph& g,
       const OpStatePtr& state_ptr,
+      const std::vector<NDArray *> &state_arrays,
       size_t start_nid,
       size_t end_nid);
   OpStatePtr StaticForward(
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index cd3cc685bdd..43497687887 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1361,6 +1361,56 @@ def test_hybrid_static_memory_recording():
     net(x)
 
 
+def test_share_inputs_outputs():
+    class TestIOBackward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOBackward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1, in2):
+            return in1 + in2
+
+    class TestIOForward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOForward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1):
+            return in1
+
+    d1 = mx.nd.arange(10)
+    d2 = mx.nd.arange(10)
+
+    params=[{'inline_limit':0},
+            {'inline_limit':0, 'static_alloc':True},
+            {'inline_limit':0, 'static_alloc':True, 'static_shape':True}]
+    # Test the case that inputs and outputs of a forward graph share NDArrays.
+    for param in params:
+        t = TestIOForward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            res = t(d1)
+            assert_almost_equal(res.asnumpy(), d1.asnumpy())
+
+    param = deepcopy(params[2])
+    param['param_indices'] = (1)
+    param['data_indices'] = (0)
+    params.append(param)
+    # Test the case that inputs and outputs of a backward graph share NDArrays.
+    for param in params:
+        t = TestIOBackward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            d2.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            with mx.autograd.record():
+                res = t(d1, d2)
+            res.backward(out_grad=out_grad)
+            assert_almost_equal(out_grad.asnumpy(), d1.grad.asnumpy())
+            assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
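The backward-graph variant of the same scenario can be sketched in the same way, following the TestIOBackward case in the test above (again the hybridize flags and shapes come from that test; this is an illustrative sketch, not part of the patch):

import mxnet as mx
from mxnet import gluon

class Add(gluon.HybridBlock):
    """For in1 + in2, each input gradient equals the head gradient, so
    (per the test above) the backward graph's inputs and outputs can
    end up sharing NDArrays."""
    def hybrid_forward(self, F, in1, in2):
        return in1 + in2

net = Add()
net.hybridize(inline_limit=0, static_alloc=True, static_shape=True)

d1 = mx.nd.arange(10)
d2 = mx.nd.arange(10)
for _ in range(5):
    d1.attach_grad()
    d2.attach_grad()
    out_grad = mx.nd.random.uniform(shape=(10,))
    with mx.autograd.record():
        res = net(d1, d2)
    res.backward(out_grad=out_grad)
    # Each input gradient is the head gradient itself; before the fix the
    # shared entries could be overwritten during the static backward pass.
    assert (d1.grad.asnumpy() == out_grad.asnumpy()).all()
    assert (d2.grad.asnumpy() == out_grad.asnumpy()).all()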


 

