liquanfeng commented on issue #15101:
URL: https://github.com/apache/tvm/issues/15101#issuecomment-1779026435

   This is very useful work! I tried it on `tests/python/relax/test_codegen_cudnn.py::test_conv2d_offload`, adapted along the lines of `tests/python/relax/test_vm_multi_device.py::test_multi_device`, as shown below:
   ```python
   import numpy as np

   import tvm
   import tvm.testing
   import tvm.topi.testing
   from tvm import relax
   from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
   from tvm.script import relax as R, ir as I

   from tvm.script.ir_builder import IRBuilder
   from tvm.script.ir_builder import relax as relax_builder

   # Reproducer: partition an NHWC conv2d + bias for cuDNN, then copy the
   # result to the "llvm" vdevice and multiply it by 2 there.
   data_shape, weight_shape, dtype = (
       (16, 32, 32, 16),  # data, laid out as NHWC (see data_layout below)
       (32, 3, 3, 16),  # weight, laid out as OHWI (see kernel_layout below)
       "float32",
   )

   input_np = np.random.randn(*data_shape).astype(dtype)
   weight_np = np.random.randn(*weight_shape).astype(dtype)

   oc = weight_shape[0]  # output channels: the "O" of the OHWI weight
   bias_np = np.random.randn(1, 1, 1, oc).astype(dtype)
   # One array per parameter of `main`, in declaration order: data, weight, bias.
   args = (input_np, weight_np, bias_np)

   with IRBuilder() as builder:
       with relax_builder.function():
           R.func_name("main")
           data = R.arg("data", R.Tensor(data_shape, dtype))
           weight = R.arg("weight", R.Tensor(weight_shape, dtype))
           bias = R.arg("bias", R.Tensor((1, 1, 1, oc), dtype))

           with R.dataflow() as frame:
               output = R.emit(
                   R.nn.conv2d(
                       data,
                       weight,
                       out_dtype=dtype,
                       padding=(1, 1),
                       data_layout="NHWC",
                       kernel_layout="OHWI",
                   )
               )
               output = R.emit(output + bias)

               # Intent: move the result to the "llvm" vdevice so that the
               # multiply below runs there.
               # NOTE(review): this builds a new VDevice object instead of reusing
               # the "llvm" entry registered in global_infos further down. The
               # failing check (GetDeviceIndex: "The vdevice is not in the
               # ir_module.") looks this vdevice up in the module; confirm whether
               # that lookup compares by identity or by value.
               output = R.emit(relax.op.to_vdevice(output, I.vdevice("llvm")))
               output = R.emit(R.multiply(output, R.const(2, "float32")))
               R.output(output)

           R.func_ret_value(frame.output_vars[0])

   func = builder.get()
   mod = tvm.IRModule(
       {"main": func},
       global_infos={
           "vdevice": [
               I.vdevice("cuda", 0),
               I.vdevice("llvm"),
           ]
       },
   )

   # BYOC: partition supported patterns for cuDNN, then run external codegen.
   mod = partition_for_cudnn(mod)
   mod = relax.transform.RunCodegen()(mod)

   # Listed in the same order as global_infos["vdevice"]: cuda:0, then llvm.
   devs = [tvm.device("cuda", 0), tvm.device("llvm")]
   mod = relax.transform.RealizeVDevice()(mod)
   mod = relax.transform.LegalizeOps()(mod)
   mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

   with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
       ex = relax.build(mod)
   vm = relax.VirtualMachine(ex, devs)
   f = vm["main"]
   # Upload the three arguments of `main` (data, weight, bias). Iterating over
   # `input_np` instead would yield its 16 batch slices as 16 separate tensors.
   inputs = [tvm.nd.array(arr, tvm.device("cuda", 0)) for arr in args]

   print(f(*inputs).numpy())
   ```
   It raises the following error:
   ```
   Traceback (most recent call last):
     File "/workspace/yongwww/tvm/tests/byoc.py", line 77, in <module>
       ex = relax.build(mod)
     File "/workspace/yongwww/tvm/python/tvm/relax/vm_build.py", line 334, in build
       new_mod = lowering_passes(mod)
     File "/workspace/yongwww/tvm/python/tvm/ir/transform.py", line 238, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/workspace/yongwww/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
       raise_last_ffi_error()
     File "/workspace/yongwww/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
       raise py_err
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     24: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
     23: tvm::transform::Pass::operator()(tvm::IRModule) const
     22: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     21: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     19: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     18: _ZN3tvm7runtime13PackedFuncObj
     17: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
     16: tvm::relax::CallTIRMutator::Run()
     15: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     14: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     13: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
     12: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
     11: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     10: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     9: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
     8: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
     7: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
     6: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
     5: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
     4: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
     3: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     2: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     1: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
     0: tvm::relax::GetDeviceIndex(tvm::IRModule const&, tvm::VDevice const&)
     File "/workspace/yongwww/tvm/src/relax/transform/utils.h", line 384
   TVMError: The vdevice is not in the ir_module.
   Is there a problem with BYOC here, or am I using it incorrectly?


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to