ganler opened a new pull request #9582: URL: https://github.com/apache/tvm/pull/9582
### Environment & Flags

- Model: ![image](https://user-images.githubusercontent.com/38074777/143319329-69de4c95-80e1-4e29-bd0a-52e2f7917ee6.png)
- Target: llvm
- TVM git hash: 0cb633777a2f65a06579b1cea7ef11e3dd659498
- Compiler: clang version 13.0.0
- Linux: Linux ise-manjaro 5.10.70-1-MANJARO #1 SMP PREEMPT Thu Sep 30 15:29:01 UTC 2021 x86_64 GNU/Linux

### Bug Description

When compiling this model, TVM performs type (shape) inference related to the NCHWc layout. During NCHWc->NCHW shape inference, the expression `c + C` is evaluated.

![image](https://user-images.githubusercontent.com/38074777/143319870-d90a076c-f64d-4eb7-97bf-4cde64b9554f.png)

However, the model's shape axis values are int64, while some other values that are initialized by default are int32. This data type mismatch makes TVM fail in `ExprMutator`, because binary operators require both operands to have the same type.

![image](https://user-images.githubusercontent.com/38074777/143320278-b42d06da-d64f-4c00-bf9c-9fd3a3c1b277.png)

### Fix

Initialize those artificial constants as int64 instead of int32 (`PrimExpr(0)` is treated as int32).

### Full failure log

```
> INIT:: # Edge in this DSO: 1244181; # Edge total: 1244181
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Traceback (most recent call last):
  File "nnsmith/backend_executor.py", line 60, in <module>
    run_backend_same_proc(args.model, args.input, bknd)
  File "/home/ganler/Documents/nnsmith/nnsmith/difftest.py", line 62, in run_backend_same_proc
    outputs = backend.predict(model_path, inputs)
  File "/home/ganler/Documents/nnsmith/nnsmith/backends/tvm_graph.py", line 74, in predict
    self.load_model(model)
  File "/home/ganler/Documents/nnsmith/nnsmith/backends/tvm_graph.py", line 68, in load_model
    executor = relay.build_module.create_executor(
  File "/home/ganler/Documents/tvm/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/home/ganler/Documents/tvm/python/tvm/relay/build_module.py", line 591, in _make_executor
    mod = build(self.mod, target=self.target)
  File "/home/ganler/Documents/tvm/python/tvm/relay/build_module.py", line 449, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/home/ganler/Documents/tvm/python/tvm/relay/build_module.py", line 189, in build
    self._build(mod, target, target_host, executor, runtime, mod_name)
  File "/home/ganler/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  16: TVMFuncCall
  15: _ZNSt17_Function_handlerIFvN3tvm7runtime7TVMArgsEPNS1_11TVMRetValue
  14: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  13: tvm::relay::backend::RelayBuildModule::Build(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&, tvm::relay::Executor const&, tvm::relay::Runtime const&, tvm::runtime::String)
  12: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  11: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve()
  28: TVMFuncCall
  27: _ZNSt17_Function_handlerIFvN3tvm7runtime7TVMArgsEPNS1_11TVMRetValue
  26: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  25: tvm::relay::backend::RelayBuildModule::Build(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&, tvm::relay::Executor const&, tvm::relay::Runtime const&, tvm::runtime::String)
  24: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  23: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  22: tvm::transform::Pass::operator()(tvm::IRModule) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  13: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  12: tvm::relay::TypeSolver::Solve()
  11: tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const
  10: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  9: bool tvm::relay::Conv2DWinogradRel<tvm::relay::Conv2DAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  8: tvm::tir::BijectiveLayout::ForwardShape(tvm::runtime::Array<tvm::PrimExpr, void> const&) const
  7: tvm::tir::TransformShape(tvm::runtime::Array<tvm::PrimExpr, void> const&, tvm::runtime::Array<tvm::tir::IterVar, void> const&, tvm::runtime::Array<tvm::tir::IterVar, void> const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
  6: tvm::PrimExpr tvm::tir::Substitute<tvm::PrimExpr>(tvm::PrimExpr, std::unordered_map<tvm::tir::VarNode const*, tvm::PrimExpr, std::hash<tvm::tir::VarNode const*>, std::equal_to<tvm::tir::VarNode const*>, std::allocator<std::pair<tvm::tir::VarNode const* const, tvm::PrimExpr> > > const&)
  5: tvm::tir::Substitute(tvm::PrimExpr, std::function<tvm::runtime::Optional<tvm::PrimExpr> (tvm::tir::Var const&)>)
  4: non-virtual thunk to tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
  3: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
  2: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10I
  1: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::AddNode const*)
  0: tvm::tir::Add::Add(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/ganler/Documents/tvm/src/relay/analysis/type_solver.cc", line 622
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (false) is false: [16:20:42] /home/ganler/Documents/tvm/src/tir/ir/expr.cc:226:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (a.dtype() == b.dtype()) is false: TypeError: mismatched types. int64 vs. int32
```

cc: @YuchenJin @junrushao1994
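The failure comes down to one rule visible in the log: `tir::Add` is only constructed when `a.dtype() == b.dtype()`, and an untyped literal such as `PrimExpr(0)` is treated as int32. The snippet below is a hypothetical, self-contained toy model of that rule in plain Python. `ToyConst` and `ToyAdd` are illustrative stand-ins, not TVM classes, and the axis extent `16` is an arbitrary example value. It reproduces the `int64 vs. int32` message and shows that creating the artificial constant as int64 avoids it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyConst:
    """Hypothetical stand-in for a typed integer constant (not a TVM class)."""
    value: int
    # An untyped literal such as PrimExpr(0) is treated as int32,
    # so the toy default mirrors that behavior.
    dtype: str = "int32"


@dataclass(frozen=True)
class ToyAdd:
    """Hypothetical stand-in for an Add node that enforces matching dtypes."""
    a: object
    b: object

    def __post_init__(self):
        # Mirrors the `a.dtype() == b.dtype()` check reported in the log.
        if self.a.dtype != self.b.dtype:
            raise TypeError(f"mismatched types. {self.a.dtype} vs. {self.b.dtype}")

    @property
    def dtype(self):
        return self.a.dtype


# Shape axis taken from the model: int64 (extent 16 is an arbitrary example).
axis = ToyConst(16, "int64")

# Before the fix: the artificial constant defaults to int32.
try:
    ToyAdd(axis, ToyConst(0))
except TypeError as err:
    print(err)  # mismatched types. int64 vs. int32

# After the fix: the artificial constant is created as int64.
print(ToyAdd(axis, ToyConst(0, "int64")).dtype)  # int64
```

In this sketch, giving the constant an explicit dtype at construction time, rather than relying on the literal's default, is what keeps both operands of the addition in agreement.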