Cookiee235 opened a new issue, #17231:
URL: https://github.com/apache/tvm/issues/17231

   ### Actual behavior
   
   ```
   Traceback (most recent call last):
     File "/share_container/optfuzz/res/bugs/simple/res_undefined.py", line 49, in <module>
       compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/share_container/optfuzz/res/bugs/simple/res_undefined.py", line 41, in compile_mod
       ex = relax.build(mod, target="llvm")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 340, in build
       mod = _vmcodegen(builder, mod, exec_mode)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 176, in _vmcodegen
       return _ffi_api.VMCodeGen(builder, mod)  # type:ignore
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
       raise_last_ffi_error()
     File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
   tvm.error.InternalError: Traceback (most recent call last):
     7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::relax::ExecBuilder, tvm::IRModule)>::AssignTypedLambda<tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule)>(tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     6: tvm::relax::relax_vm::VMCodeGen(tvm::relax::ExecBuilder, tvm::IRModule)
     5: tvm::relax::relax_vm::CodeGenVM::Run(tvm::relax::ExecBuilder, tvm::IRModule)
     4: tvm::relax::relax_vm::CodeGenVM::Codegen(tvm::relax::Function const&)
     3: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     2: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
     1: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     0: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::VarNode const*)
     File "/software/tvm-lunder/src/relax/backend/vm/codegen_vm.cc", line 232
   InternalError: Check failed: (it != this->var_arg_map_.end()) is false: Var w1_t is not defined
   ```
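
   For reference, the failing check at `codegen_vm.cc:232` is the lookup of a `Var` in the VM codegen's `var_arg_map_`, so the error means codegen reached a use of `w1_t` without having seen a binding for it in the function being compiled. A quick way to inspect what the pass produced (a diagnostic sketch using `mod` from the reproduction below, not a fix):

   ```
   lifted = relax.transform.LiftTransformParams()(mod)
   # Expect `main` to be split into a rewritten `main` plus `main_transform_params`.
   lifted.show()
   ```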
   
   
   ### Steps to reproduce
   ```
   import tvm
   from tvm import relax
   import numpy as np
   from tvm.script import ir as I
   from tvm.script import tir as T
   from tvm.script import relax as R

   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def transpose(w1: T.Buffer((T.int64(256), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(256)), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(256), T.int64(256)):
               with T.block("T_transpose"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(w1[v_ax1, v_ax0])
                   T.writes(T_transpose[v_ax0, v_ax1])
                   T_transpose[v_ax0, v_ax1] = w1[v_ax1, v_ax0]

       @R.function(private=False)
       def main(x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256, 256), dtype="float32"):
           # `num_input: 1` marks only `x` as a runtime input; `w1` is a weight
           # eligible for lifting by LiftTransformParams.
           R.func_attr({"num_input": 1})
           cls = Module
           with R.dataflow():
               w1_t = R.call_tir(cls.transpose, (w1,), out_sinfo=R.Tensor((256, 256), dtype="float32"))
               R.output(w1_t)
           return w1_t

   mod = Module
   mod.show()
   mod = tvm.relax.transform.LegalizeOps()(mod)


   input_0 = tvm.nd.array(10 * np.random.random([256, 256]).astype('float32'))
   input_1 = tvm.nd.array(10 * np.random.random([256, 256]).astype('float32'))

   def compile_mod(mod):
       mod = relax.transform.FuseTIR()(mod)
       mod = relax.transform.LambdaLift()(mod)
       ex = relax.build(mod, target="llvm")
       vm = relax.VirtualMachine(ex, tvm.cpu())
       return vm


   # Compiling and running the unlifted module works.
   compiled_before = compile_mod(mod)
   before_outputs = compiled_before["main"](input_0, input_1)

   # Building the lifted module fails: "Var w1_t is not defined".
   compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
   transformed_weights = compiled_after["main_transform_params"]([input_1])
   after_outputs = compiled_after["main"](input_0, *transformed_weights)
   ```
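
   The "before" compile and run are presumably not required to hit this; a minimized sketch of just the failing path (same `Module` as above, mirroring `compile_mod`, unverified beyond the traceback):

   ```
   lifted = relax.transform.LiftTransformParams()(mod)
   lifted = relax.transform.FuseTIR()(lifted)
   lifted = relax.transform.LambdaLift()(lifted)
   relax.build(lifted, target="llvm")  # InternalError: Var w1_t is not defined
   ```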
   
   cc @Lunderberg @junrushao 
   

