This is an automated email from the ASF dual-hosted git repository. areusch pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 3f788b4 [FIX,VM] Fix get_outputs on the vm with a single output (#7902) 3f788b4 is described below commit 3f788b41c1a93bf929a475f182a1fd0fc9f9f142 Author: Tristan Konolige <tristan.konol...@gmail.com> AuthorDate: Wed May 5 09:57:37 2021 -0700 [FIX,VM] Fix get_outputs on the vm with a single output (#7902) * [FIX,VM] Fix get_outputs on the vm with a single output The VM uses an ADT for multiple outputs and an NDArray for a single output. The single output case was not being handled. * check if the user specified the correct index --- src/runtime/vm/vm.cc | 18 +++++++++++++++--- tests/python/relay/test_vm.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index a0edb3b..17a66e4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -142,11 +142,23 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, }); } else if (name == "get_output") { return TypedPackedFunc<NDArray(int64_t)>([this](int64_t index) { - return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]); + if (this->return_register_.as<ADTObj>()) { + return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]); + } else { + CHECK_EQ(index, 0) << "VM output contains only one item, but you are trying to get the " + << index << "th."; + return Downcast<NDArray>(this->return_register_); + } }); } else if (name == "get_num_outputs") { - return TypedPackedFunc<int64_t(void)>( - [this]() -> int64_t { return Downcast<ADT>(this->return_register_).size(); }); + return TypedPackedFunc<int64_t(void)>([this]() -> int64_t { + // single output is an NDArray not an ADT + if (this->return_register_.as<ADTObj>()) { + return Downcast<ADT>(this->return_register_).size(); + } else { + return 1; + } + }); } else if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size() % 3, 0); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 7e79049..8f51869 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -852,5 +852,42 @@ def test_vm_rpc(): server.terminate() +def test_get_output_single(): + target = tvm.target.Target("llvm") + + # Build a IRModule. + x = relay.var("x", shape=(10,)) + f = relay.Function([x], x + x) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + vm_exec = vm.compile(mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + inp = np.ones(10, dtype="float32") + vm_factory.invoke_stateful("main", inp) + outputs = vm_factory.get_outputs() + assert len(outputs) == 1 + np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp) + + +def test_get_output_multiple(): + target = tvm.target.Target("llvm") + + # Build a IRModule. + x = relay.var("x", shape=(10,)) + f = relay.Function([x], relay.Tuple([x + x, x])) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + vm_exec = vm.compile(mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + inp = np.ones(10, dtype="float32") + vm_factory.invoke_stateful("main", inp) + outputs = vm_factory.get_outputs() + assert len(outputs) == 2 + np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp) + np.testing.assert_allclose(outputs[1].asnumpy(), inp) + + if __name__ == "__main__": pytest.main([__file__])