This is an automated email from the ASF dual-hosted git repository. jwfromm pushed a commit to branch checkpoint in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 6f95da54b0a836f86d1636bcf38d98f9bfae0674 Author: Josh Fromm <jwfr...@uw.edu> AuthorDate: Wed Mar 24 16:32:06 2021 -0700 Some fixes. --- CMakeLists.txt | 4 ++-- python/tvm/relay/frontend/jit/onnx.py | 4 +--- src/driver/driver_api.cc | 9 +++++++-- src/runtime/graph/graph_runtime.cc | 7 ------- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 769a353..0a5122d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ project(tvm C CXX) include(cmake/utils/Utils.cmake) include(cmake/utils/FindCUDA.cmake) include(cmake/utils/FindOpenCL.cmake) -include(cmake/utils/FindVulkan.cmake) +#include(cmake/utils/FindVulkan.cmake) include(cmake/utils/FindLLVM.cmake) include(cmake/utils/FindROCM.cmake) include(cmake/utils/FindEthosN.cmake) @@ -330,7 +330,7 @@ include(cmake/modules/CUDA.cmake) include(cmake/modules/Hexagon.cmake) include(cmake/modules/OpenCL.cmake) include(cmake/modules/OpenMP.cmake) -include(cmake/modules/Vulkan.cmake) +#include(cmake/modules/Vulkan.cmake) include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) include(cmake/modules/LLVM.cmake) diff --git a/python/tvm/relay/frontend/jit/onnx.py b/python/tvm/relay/frontend/jit/onnx.py index ae10915..e0f5b6b 100644 --- a/python/tvm/relay/frontend/jit/onnx.py +++ b/python/tvm/relay/frontend/jit/onnx.py @@ -46,9 +46,8 @@ def onnx_compile(model_string, target, target_host, opt_level, input_shapes): shape_dict = collections.OrderedDict(input_mapping) irmod, params = tvm.relay.frontend.from_onnx(model, shape_dict, opset=11) - print(irmod) # import ipdb; ipdb.set_trace() - with tvm.relay.build_config(opt_level=opt_level): + with tvm.relay.build_config(opt_level=opt_level, disabled_pass={"AlterOpLayout"}): tuning_logfile = os.getenv("AUTOTVM_TUNING_LOG") if tuning_logfile: with autotvm.apply_history_best(tuning_logfile): @@ -57,7 +56,6 @@ def onnx_compile(model_string, target, target_host, opt_level, input_shapes): else: lib = tvm.relay.build(irmod, target_host=target_host, target=target) - print(lib.graph_json) ctx = tvm.context(target, 0) m = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx)) # m.set_input(**params) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0011dab..a892a86 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -364,8 +364,13 @@ void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor>& inputs, std::vecto set_input(i, &inputs[i]); } - const tvm::PackedFunc* run = tvm::runtime::Registry::Get("tvm_run_with_benchmark"); - (*run)(mod); + // Dont include benchmarking in core run for now + //const tvm::PackedFunc* run = tvm::runtime::Registry::Get("tvm_run_with_benchmark"); + //(*run)(mod); + + // Just directly run the module + tvm::PackedFunc run = mod.GetFunction("run", false); + run(); tvm::PackedFunc get_output = mod.GetFunction("get_output", false); for (size_t i = 0; i < outputs.size(); i++) diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 5df9068..8b697c9 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -90,7 +90,6 @@ void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module modu for(int ind = 0; ind < old_t->ndim; ind++) { s << old_t->shape[ind] << " "; } - LOG(INFO) << s.str(); } } /*! @@ -128,12 +127,6 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { // check the consistency of input ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref)); // ICHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0) << data_ref->data; - for(int i = 0; i < old_t->ndim; i++) { - LOG(INFO) << "OLD " << old_t->shape[i]; - } - for(int i = 0; i < data_ref->ndim; i++) { - LOG(INFO) << "DATA_REF " << data_ref->shape[i]; - } ICHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim)); ICHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type); ICHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);