Cookiee235 opened a new issue, #17310: URL: https://github.com/apache/tvm/issues/17310
### Actual behavior ``` Traceback (most recent call last): File "/share_container/optfuzz/res/res_ut/res_executions/30_test.py", line 50, in <module> ex = relax.build(mod, target='llvm') ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build mod = pipeline(mod) ^^^^^^^^^^^^^ File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__ return _ffi_transform_api.RunPass(self, mod) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__ raise_last_ffi_error() File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error raise py_err File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline mod = seq(mod) ^^^^^^^^ File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__ return _ffi_transform_api.RunPass(self, mod) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__ raise_last_ffi_error() File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error raise py_err tvm._ffi.base.TVMError: Traceback (most recent call last): 38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue) 37: tvm::transform::Pass::operator()(tvm::IRModule) const 36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const 35: 
tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const 34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const 33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const 32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1 31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const 30: tvm::relax::CallTIRMutator::Run() 29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&) 28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) 27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef 26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*) 25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >) 24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&) 23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) 22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef 21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*) 20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&) 19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*) 18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&) 17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*) 16: 
_ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl 15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*) 14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&) 13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) 12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef 11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*) 10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String) 9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String) 8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&) 7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) 6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef 5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*) 4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*) 3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&) 2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*) 1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&) 0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&) File "/software/tvm/src/relax/ir/block_builder.cc", line 159 TVMError: Argument 0 type mismatch: expected R.Tensor((64, 64, 56, 56), dtype="float32"), given R.Tensor((1, 64, 56, 56), dtype="float32") ``` ### Steps to reproduce ``` import tvm from tvm import relax from tvm.script import ir as I from tvm.script import tir as T from tvm.script import relax as R @I.ir_module class Module: @T.prim_func(private=True) def conv2d(data: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), 
weight1: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): pad_temp = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(58), T.int64(58))) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(58), T.int64(58)): with T.block("pad_temp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)]) T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3]) pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(57) and T.int64(1) <= v_i3 and v_i3 < T.int64(57), data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0)) for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(64), T.int64(3), T.int64(3)): with T.block("conv2d_nchw"): v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx]) T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], weight1[v_ff, v_rc, v_ry, v_rx]) T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx]) with T.init(): conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0) conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * weight1[v_ff, v_rc, v_ry, v_rx] @T.prim_func def relu(data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")): # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data[i, j, k, l]) T.writes(out[i, j, k, l]) out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0)) @R.function def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((64, 64, 56, 56), dtype="float32"): cls = Module with R.dataflow(): conv1 = 
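R.call_tir(cls.conv2d, (data, weight1), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32")) relu1 = R.call_tir(cls.relu, (conv1,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32")) R.output(relu1) return relu1 mod = Module mod.show() ex = relax.build(mod, target='llvm') ```

The Relax IR above passes the IR validity (well-formedness) check, but `relax.build(mod, target='llvm')` crashes on it inside the `CallTIRRewrite` pass. Note that the second `R.call_tir` passes `conv1`, declared as `R.Tensor((1, 64, 56, 56), dtype="float32")`, to `cls.relu`, whose buffers are declared with shape `(64, 64, 56, 56)`. This is exactly the "Argument 0 type mismatch" in the error message. The open question is why the validity check accepts this module instead of rejecting it before the build. Could you help me review it? Thanks a lot!

CC @Lunderberg @junrushao

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at: us...@infra.apache.org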