piiswrong closed pull request #11436 ("handle the case that inputs and outputs of a graph share NDArrays"). URL: https://github.com/apache/incubator-mxnet/pull/11436
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 5a3d44c04ce..2181c5cab87 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -591,6 +591,7 @@ void CachedOp::StaticRunOps( const Context& default_ctx, const nnvm::Graph& g, const OpStatePtr& state_ptr, + const std::vector<NDArray *> &state_arrays, size_t start_nid, size_t end_nid) { static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState"); @@ -624,7 +625,7 @@ void CachedOp::StaticRunOps( ndinputs.clear(); ndinputs.reserve(node.inputs.size()); for (const auto& j : node.inputs) { - ndinputs.emplace_back(state.arrays[idx.entry_id(j)]); + ndinputs.emplace_back(state_arrays[idx.entry_id(j)]); CHECK(!ndinputs.back()->is_none()); } ndoutputs.clear(); @@ -633,7 +634,7 @@ void CachedOp::StaticRunOps( req.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { size_t eid = idx.entry_id(i, j); - ndoutputs.emplace_back(state.arrays[eid]); + ndoutputs.emplace_back(state_arrays[eid]); req.push_back(state.array_reqs[eid]); CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none()); } @@ -688,25 +689,29 @@ OpStatePtr CachedOp::StaticForward( StaticAllocMemory(state_ptr, recording, false); } + // We are going to add input and output arrays to the array list. + // The input and output arrays should only be valid for this run, + // so we shouldn't modify the state's array list. 
+ auto arrays = state.arrays; if (config_.static_shape) { for (auto i : config_.param_indices) { auto nid = idx.input_nodes()[i]; - if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) { + if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) { match = false; auto ptr = &state.buff[idx.entry_id(nid, 0)]; - CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr); - *state.arrays[idx.entry_id(nid, 0)] = *inputs[i]; + CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr); + *arrays[idx.entry_id(nid, 0)] = *inputs[i]; state.dynamic_entries[idx.entry_id(nid, 0)] = false; } } for (auto i : config_.data_indices) { auto eid = idx.entry_id(idx.input_nodes()[i], 0); - state.arrays[eid] = inputs[i]; + arrays[eid] = inputs[i]; } } else { for (size_t i = 0; i < num_inputs(); ++i) { auto nid = idx.input_nodes()[i]; - state.arrays[idx.entry_id(nid, 0)] = inputs[i]; + arrays[idx.entry_id(nid, 0)] = inputs[i]; } } @@ -720,13 +725,16 @@ OpStatePtr CachedOp::StaticForward( for (size_t i = 0; i < outputs.size(); ++i) { auto eid = idx.entry_id(idx.outputs()[i]); - state.arrays[eid] = outputs[i]; + // An input and an output may share the same array. + if (!arrays[eid]->is_none()) + *outputs[i] = arrays[eid]->Detach(); + arrays[eid] = outputs[i]; if (!outputs[i]->is_none()) continue; *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]), shapes[eid], default_ctx, true, dtypes[eid]); } - StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes()); + StaticRunOps(default_ctx, g, state_ptr, arrays, 0, idx.num_nodes()); return recording ? state_ptr : OpStatePtr(); } @@ -891,7 +899,11 @@ void CachedOp::DynamicBackward( } for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { if (reqs[i] == kNullOp) continue; - arrays[idx.entry_id(idx.outputs()[j++])] = outputs[i]; + auto eid = idx.entry_id(idx.outputs()[j++]); + // An input and an output may share the same array. 
+ if (!arrays[eid]->is_none()) + *outputs[i] = arrays[eid]->Detach(); + arrays[eid] = outputs[i]; } // Allocate NDArrays @@ -952,6 +964,15 @@ void CachedOp::StaticBackward( StaticAllocMemory(state_ptr, true, true); } + // We are going to add input and output arrays to the array list. + // The input and output arrays should only be valid for this run, + // so we shouldn't modify the state's array list. + auto arrays = state.arrays; + for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { + auto eid = state.info.bwd_input_eid[i]; + if (state.dynamic_entries[eid]) arrays[eid] = inputs[i]; + } + if (config_.static_shape) { for (auto i : config_.param_indices) { const auto iter = fwd_input_to_grad_output_.find(i); @@ -959,11 +980,14 @@ void CachedOp::StaticBackward( auto entry = grad_graph_.outputs[iter->second]; if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); - if (!state.arrays[eid]->IsSame(*outputs[iter->second]) || + if (!arrays[eid]->IsSame(*outputs[iter->second]) || !(state.array_reqs[eid] == reqs[iter->second])) { match = false; state.array_reqs[eid] = reqs[iter->second]; - *state.arrays[eid] = *outputs[iter->second]; + // An input and an output may share the same array. + if (!arrays[eid]->is_none()) + *outputs[iter->second] = arrays[eid]->Detach(); + *arrays[eid] = *outputs[iter->second]; state.dynamic_entries[eid] = false; } } @@ -974,7 +998,10 @@ void CachedOp::StaticBackward( if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[iter->second]; - state.arrays[eid] = outputs[iter->second]; + // An input and an output may share the same array. 
+ if (!arrays[eid]->is_none()) + *outputs[iter->second] = arrays[eid]->Detach(); + arrays[eid] = outputs[iter->second]; } } else { for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { @@ -982,7 +1009,10 @@ void CachedOp::StaticBackward( if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[i]; - state.arrays[eid] = outputs[i]; + // An input and an output may share the same array. + if (!arrays[eid]->is_none()) + *outputs[i] = arrays[eid]->Detach(); + arrays[eid] = outputs[i]; } } @@ -990,12 +1020,7 @@ void CachedOp::StaticBackward( StaticInitExec(state_ptr, true, true); } - for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { - auto eid = state.info.bwd_input_eid[i]; - if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i]; - } - - StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes()); + StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes()); } void CachedOp::Backward( diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 6b94c67a94e..370ef02b5f2 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -154,6 +154,7 @@ class CachedOp { const Context& default_ctx, const nnvm::Graph& g, const OpStatePtr& state_ptr, + const std::vector<NDArray *> &state_arrays, size_t start_nid, size_t end_nid); OpStatePtr StaticForward( diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index cd3cc685bdd..43497687887 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1361,6 +1361,56 @@ def test_hybrid_static_memory_recording(): net(x) +def test_share_inputs_outputs(): + class TestIOBackward(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(TestIOBackward, self).__init__(prefix=prefix, params=params) + + def hybrid_forward(self, F, in1, in2): + return in1 + in2 + + class TestIOForward(gluon.HybridBlock): + def 
__init__(self, prefix=None, params=None): + super(TestIOForward, self).__init__(prefix=prefix, params=params) + + def hybrid_forward(self, F, in1): + return in1 + + d1 = mx.nd.arange(10) + d2 = mx.nd.arange(10) + + params=[{'inline_limit':0}, + {'inline_limit':0, 'static_alloc':True}, + {'inline_limit':0, 'static_alloc':True, 'static_shape':True}] + # Test the case that inputs and outputs of a forward graph share NDArrays. + for param in params: + t = TestIOForward() + t.hybridize(**param) + for i in range(5): + d1.attach_grad() + out_grad = mx.nd.random.uniform(shape=(10)) + res = t(d1) + assert_almost_equal(res.asnumpy(), d1.asnumpy()) + + param = deepcopy(params[2]) + param['param_indices'] = (1) + param['data_indices'] = (0) + params.append(param) + # Test the case that inputs and outputs of a backward graph share NDArrays. + for param in params: + t = TestIOBackward() + t.hybridize(**param) + for i in range(5): + d1.attach_grad() + d2.attach_grad() + out_grad = mx.nd.random.uniform(shape=(10)) + with mx.autograd.record(): + res = t(d1, d2) + res.backward(out_grad=out_grad) + assert_almost_equal(out_grad.asnumpy(), d1.grad.asnumpy()) + assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy()) + + if __name__ == '__main__': import nose nose.runmodule() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services