This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch checkpoint
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 388e1df3aeab41fa18bec37a80d76c24f5210983
Author: Chris Sullivan <csulli...@octoml.ai>
AuthorDate: Thu Aug 20 22:18:58 2020 -0700

    [AMD:ONNXRT:TVM] Include input shapes during compilation.
---
 include/tvm/driver/jit_interface.h    |  6 ++++--
 python/tvm/relay/frontend/jit/onnx.py |  7 +++----
 src/driver/driver_api.cc              | 19 ++++++++++++++-----
 3 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/include/tvm/driver/jit_interface.h b/include/tvm/driver/jit_interface.h
index e0906f1..e9203ee 100644
--- a/include/tvm/driver/jit_interface.h
+++ b/include/tvm/driver/jit_interface.h
@@ -2,7 +2,9 @@
 
 #ifdef __cplusplus
 extern "C" {
-  EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level);
-  EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor> inputs, std::vector<DLTensor> outputs, tvm::runtime::TVMRetValue* ret);
+  EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level, const std::vector<std::vector<int64_t>>& input_shapes);
+  EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor>& inputs, std::vector<DLTensor>& outputs, tvm::runtime::TVMRetValue* ret);
+
+}  // TVM_EXTERN_C
 
 #endif
diff --git a/python/tvm/relay/frontend/jit/onnx.py b/python/tvm/relay/frontend/jit/onnx.py
index 9545395..3672bbe 100644
--- a/python/tvm/relay/frontend/jit/onnx.py
+++ b/python/tvm/relay/frontend/jit/onnx.py
@@ -19,13 +19,12 @@ import tvm
 import tvm.relay
 
 @tvm.register_func("tvm_onnx_import_and_compile")
-def onnx_compile(model_string, target, target_host, opt_level):
+def onnx_compile(model_string, target, target_host, opt_level, input_shapes):
     model = onnx.load_model_from_string(bytes(model_string))
 
-    # input shape from data
-    input_shape = {model.graph.input[0].name: (6,)}
+    input_shapes = {name : shape for (name, shape) in zip([i.name for i in model.graph.input], input_shapes)}
 
-    irmod, params = tvm.relay.frontend.from_onnx(model, input_shape, opset=11)
+    irmod, params = tvm.relay.frontend.from_onnx(model, input_shapes, opset=11)
 
     with tvm.relay.build_config(opt_level=opt_level):
         graph, lib, params = tvm.relay.build(irmod, target_host=target_host, target=target, params=params)
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b876c38..d55c0ae 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -338,16 +338,25 @@ runtime::Module build(const IRModule& funcs, const Target& target, const Target&
 }  // namespace tvm
 
-
-tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level)
+tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level, const std::vector<std::vector<int64_t>>& input_shapes)
 {
+  tvm::Array<tvm::Array<tvm::Integer>> shapes;
+  for (size_t i = 0; i < input_shapes.size(); i++)
+  {
+    tvm::Array<tvm::Integer> shape;
+    for (auto& dim : input_shapes[i])
+    {
+      shape.push_back(tvm::Integer(dim));
+    }
+    shapes.push_back(shape);
+  }
+
   const tvm::PackedFunc* compile = tvm::runtime::Registry::Get("tvm_onnx_import_and_compile");
-  tvm::runtime::Module mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, target, target_host, opt_level);
+  tvm::runtime::Module mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, target, target_host, opt_level, shapes);
   return mod;
-
 }
 
 
-void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor> inputs, std::vector<DLTensor> outputs, tvm::runtime::TVMRetValue* ret)
+void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor>& inputs, std::vector<DLTensor>& outputs, tvm::runtime::TVMRetValue* ret)
 {
   tvm::PackedFunc set_input = mod.GetFunction("set_input_zero_copy", false);
   for (size_t i = 0; i < inputs.size(); i++)
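
Note (reviewer sketch, not part of this commit): with this change the caller
of TVMCompile supplies one shape per graph input, and the shapes are forwarded
to the registered "tvm_onnx_import_and_compile" Python function, which zips
them with the ONNX graph input names. A minimal caller-side example in C++,
assuming a hypothetical model.onnx on disk, LLVM target strings, and a single
(1, 3, 224, 224) input:

    #include <cstdint>
    #include <fstream>
    #include <sstream>
    #include <string>
    #include <vector>

    #include <tvm/driver/jit_interface.h>

    int main() {
      // Read the serialized ONNX model into a string (path is illustrative).
      std::ifstream file("model.onnx", std::ios::binary);
      std::stringstream buffer;
      buffer << file.rdbuf();
      std::string onnx_txt = buffer.str();

      // One shape vector per graph input, in model.graph.input order; the
      // Python callback pairs these with the input names.
      std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};

      tvm::runtime::Module mod = TVMCompile(onnx_txt, /*target=*/"llvm",
                                            /*target_host=*/"llvm",
                                            /*opt_level=*/3, input_shapes);
      return 0;
    }

Inference would then go through TVMRun, passing the input and output DLTensor
vectors by reference as in the updated header.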