This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 293b58099b [Docs] Fix RPC tutorial to use set_input + invoke_stateful API (#18855)
293b58099b is described below

commit 293b58099bd6f676d8eac5efe4f8315a46290438
Author: landern <[email protected]>
AuthorDate: Sun Mar 1 03:39:41 2026 +0800

    [Docs] Fix RPC tutorial to use set_input + invoke_stateful API (#18855)
    
    ## Problem
    Issue #18824: The Cross Compilation and RPC tutorial fails when running
    over RPC with the error:
    Mismatched type on argument #0 when calling: vm.builtin.reshape(0: ffi.Tensor, 1: ffi.Shape) -> ffi.Tensor
    
    This happens because `tvm.runtime.tensor()` creates tensors that become
    `DLTensor*` when transmitted via RPC, but VM builtins like
    `vm.builtin.reshape` expect `ffi.Tensor`.
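    
    For reference, the failing pattern from the tutorial's original Step 5 looks
    roughly like this (a sketch; `vm`, `dev`, and `remote_params` are handles set
    up earlier in the tutorial):
    
    ```python
    import numpy as np
    import tvm
    
    # Input tensor created against the remote device handle
    input_data = np.random.randn(1, 1, 28, 28).astype("float32")
    remote_input = tvm.runtime.tensor(input_data, dev)
    
    # Direct call: fine locally, but over RPC the argument arrives as a
    # DLTensor*, which builtins such as vm.builtin.reshape reject.
    output = vm["main"](remote_input, *remote_params)
    ```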
    
    ## Solution
    Update the tutorial to use the `set_input()` + `invoke_stateful()` API
    instead of the direct function call `vm["main"](...)`. This API is designed
    for RPC and handles the tensor type conversion internally.
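    
    In the updated tutorial the call site becomes roughly the following (same
    `vm`, `remote_input`, and `remote_params` as in the sketch above):
    
    ```python
    # Stage the arguments on the VM; the RPC-transmitted DLTensor* is
    # converted to ffi.Tensor inside the VM rather than at the call boundary.
    vm.set_input("main", remote_input, *remote_params)
    
    # Run "main" against the staged inputs.
    vm.invoke_stateful("main")
    
    # Fetch the results held in the VM's internal state.
    output = vm.get_outputs("main")
    ```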
    
    ## Changes
    - Updated Step 5 to use `vm.set_input()`, `vm.invoke_stateful()`, and
    `vm.get_outputs()`
    - Updated Step 6 to use `time_evaluator("invoke_stateful", ...)` for RPC
    compatibility (see the timing sketch after this list)
    - Added explanatory comments about why this approach is needed
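    
    Timing over RPC follows the same pattern; a rough sketch using the `number`
    and `repeat` values from the tutorial:
    
    ```python
    # Benchmark the invoke_stateful entry point instead of "main" so the
    # measurement goes through the same RPC-safe path as inference above.
    time_f = vm.time_evaluator("invoke_stateful", dev, number=10, repeat=3)
    prof_res = time_f("main")
    print(f"Mean inference time: {prof_res.mean * 1000:.2f} ms")
    ```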
    
    Fixes #18824
    
    Co-authored-by: igotyuandme320 <[email protected]>
---
 docs/how_to/tutorials/cross_compilation_and_rpc.py | 24 ++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py
index a6e9bc8e71..0e6346619a 100644
--- a/docs/how_to/tutorials/cross_compilation_and_rpc.py
+++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py
@@ -462,13 +462,20 @@ def run_pytorch_model_via_rpc():
     # Step 5: Run Inference on Remote Device
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # Execute the model on the remote ARM device and retrieve results
+    #
+    # Note: When running VM over RPC, we use set_input() and invoke_stateful()
+    # instead of direct function call (vm["main"](...)). This is because RPC
+    # transmits tensors as DLTensor*, while VM builtins expect ffi.Tensor.
+    # The set_input API handles this conversion internally.
 
     # Prepare input data
     input_data = np.random.randn(1, 1, 28, 28).astype("float32")
     remote_input = tvm.runtime.tensor(input_data, dev)
 
-    # Run inference on remote device
-    output = vm["main"](remote_input, *remote_params)
+    # Run inference using set_input + invoke_stateful for RPC compatibility
+    vm.set_input("main", remote_input, *remote_params)
+    vm.invoke_stateful("main")
+    output = vm.get_outputs("main")
 
     # Extract result (handle both tuple and single tensor outputs)
     if isinstance(output, tvm.ir.Array) and len(output) > 0:
@@ -482,13 +489,22 @@ def run_pytorch_model_via_rpc():
     print(f"  Output shape: {result_np.shape}")
     print(f"  Predicted class: {np.argmax(result_np)}")
 
+    ######################################################################
+    # Alternative: Direct Function Call (Local Only)
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # Note: The direct call syntax vm["main"](input, *params) works for
+    # local execution but may fail over RPC due to type mismatch between
+    # DLTensor* (RPC) and ffi.Tensor (VM builtins). For RPC, always use
+    # the set_input + invoke_stateful pattern shown above.
+
     ######################################################################
     # Step 6: Performance Evaluation (Optional)
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # Measure inference time on the remote device, excluding network overhead
+    # Note: For RPC, use invoke_stateful with time_evaluator
 
-    time_f = vm.time_evaluator("main", dev, number=10, repeat=3)
-    prof_res = time_f(remote_input, *remote_params)
+    time_f = vm.time_evaluator("invoke_stateful", dev, number=10, repeat=3)
+    prof_res = time_f("main")
     print(f"Inference time on remote device: {prof_res.mean * 1000:.2f} ms")
 
     ######################################################################
