Cookiee235 commented on issue #17211:
URL: https://github.com/apache/tvm/issues/17211#issuecomment-2256589902

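   A similar bug occurs, as shown below.
   From what I have observed, the well-formed checker usually corrects the
   return type and shape. However, when the return variable of a Relax
   function has a tuple type (`R.Tuple(...)`), the well-formed checker does
   not seem to work: in the example below, `main` is annotated to return a
   `float32` tensor as the second tuple element, yet it actually returns an
   `int32` tensor, and the mismatch is only caught at runtime by the VM.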
   
   ### Actual behavior
   ```
   Traceback (most recent call last):
     File "/share_container/optfuzz/res/bugs/res_type.py", line 82, in <module>
       mod_outputs = vm['main'](input_0, input_1)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
       raise_last_ffi_error()
     File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
   ValueError: Traceback (most recent call last):
     8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::_LookupFunction(tvm::runtime::String const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     7: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     5: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
     4: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
     3: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
     2: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     1: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16Pack
     0: tvm::runtime::relax_vm::CheckTensorInfo(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     File "/software/tvm/src/runtime/relax_vm/builtin.cc", line 247
   ValueError: Check failed: (DataType(ptr->dl_tensor.dtype) == dtype) is false: ErrorContext(fn=main, loc=return, annotation=R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32")))  expect Tensor with dtype float32 but get int32
   ```
   
   ### Steps to reproduce
   ```
   import tvm
   from tvm import relax
   import numpy as np
   
   from tvm.script import ir as I
   from tvm.script import tir as T
   from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       # TIR kernel: fill a 16x16 int32 buffer with ones.
       @T.prim_func(private=True)
       def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
               with T.block("T_full"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads()
                   T.writes(T_full[v_ax0, v_ax1])
                   T_full[v_ax0, v_ax1] = 1

       # TIR kernel: fill a 16x16 int32 buffer with zeros.
       @T.prim_func(private=True)
       def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
               with T.block("T_full"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads()
                   T.writes(T_full[v_ax0, v_ax1])
                   T_full[v_ax0, v_ax1] = 0

       # TIR kernel: fill a 32x32 int32 buffer with zeros.
       @T.prim_func(private=True)
       def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
               with T.block("T_full"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads()
                   T.writes(T_full[v_ax0, v_ax1])
                   T_full[v_ax0, v_ax1] = 0

       # Produce three int32 tensors via call_tir -- zeros (16x16),
       # ones (16x16) and zeros (32x32) -- and return them as a 3-tuple.
       @R.function(private=True)
       def func() -> R.Tuple(
           R.Tensor((16, 16), dtype="int32"),
           R.Tensor((16, 16), dtype="int32"),
           R.Tensor((32, 32), dtype="int32")
       ):
           cls = Module
           A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
           B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
           C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
           return (A, B, C)

       # Return elements 0 and 2 of func()'s result (a 16x16 and a 32x32
       # int32 tensor). The annotations in this function omit shape and dtype,
       # so the concrete types are not spelled out at this level.
       @R.function
       def main_2() -> R.Tuple(R.Tensor, R.Tensor):
           cls = Module
           args: R.Tuple(R.Tensor, R.Tensor, R.Tensor) = cls.func()
           gv1: R.Tensor = args[0]
           gv2: R.Tensor = args[2]
           return (gv1, gv2)

       # Entry point under test. Its return annotation disagrees with the value
       # it actually returns: the second element is annotated float32, but
       # main_2 produces an int32 tensor there. This mismatch is what the report
       # demonstrates; it surfaces only at runtime as the ValueError quoted in
       # the traceback above.
       @R.function
       def main(
           v3_0: R.Tensor((1, 22, 1), dtype="float16"),
           v6_0: R.Tensor((1, 37), dtype="float16")
       ) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32")):
           # If the return value is a tuple, the well-formed checker cannot correct it!
           R.func_attr({"num_input": 1})
           # NOTE(review): num_input is 1 although main declares two parameters;
           # kept exactly as in the original report.
           cls = Module
           with R.dataflow():
               res: R.Tuple(R.Tensor, R.Tensor) = cls.main_2()
               R.output(res)
           return res
   
   
   # Reproduction driver: legalize, fuse and lambda-lift the module, build it
   # for the CPU with the LLVM target, then invoke `main` on the Relax VM.
   mod = Module
   mod.show()  # display the input module before any transformation
   mod = tvm.relax.transform.LegalizeOps()(mod)
   
   mod = relax.transform.FuseTIR()(mod)
   mod = relax.transform.LambdaLift()(mod)
   ex = relax.build(mod, target='llvm')
   vm = relax.VirtualMachine(ex, tvm.cpu())
   
   # Random inputs matching the float16 parameter annotations of `main`
   # ((1, 22, 1) and (1, 37)).
   input_0 = tvm.nd.array(10 * np.random.random([1, 22, 1]).astype('float16'))
   input_1 = tvm.nd.array(10 * np.random.random([1, 37]).astype('float16'))
   # Expected to raise the ValueError shown above: `main` is annotated to
   # return a float32 second element, but the value it returns is int32.
   mod_outputs = vm['main'](input_0, input_1)
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to