This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a54af64872 [Relax][Backend] Implement R.call_py_func operator for calling Python functions from compiled TVM (#18326)
a54af64872 is described below

commit a54af64872c68913309541f6f30e75da3921ef77
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Sep 21 23:36:18 2025 -0400

    [Relax][Backend] Implement R.call_py_func operator for calling Python functions from compiled TVM (#18326)
    
    This PR implements the `R.call_py_func` operator, which allows compiled
    TVM Relax modules to call Python functions at runtime. The operator is
    lowered to a new VM builtin (`vm.builtin.call_py_func`) that looks up the
    target function in a runtime registry, so compiled Relax code can
    interoperate with Python callables such as PyTorch operations.
    
    #### Simple Usage with BasePyModule
    ```python
    @I.ir_module
    class MyModule(BasePyModule):
        @I.pyfunc
        def torch_relu(self, x):
            return torch.relu(x)
    
        @R.function
        def forward(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
            return R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
    ```
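    
    #### Invoking the Module (Sketch)
    A minimal sketch of driving the module above: the device-taking constructor
    matches the tests added in this PR, while calling the wrapped `forward`
    attribute with a PyTorch tensor is an assumption about `BasePyModule`'s
    argument conversion.
    ```python
    import torch
    import tvm
    
    device = tvm.cpu()
    module = MyModule(device)  # compiles the IRModule and registers the @I.pyfunc functions
    
    x = torch.randn(10, dtype=torch.float32)
    y = module.forward(x)  # the compiled Relax function dispatches to torch_relu at runtime
    ```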
    
    #### Direct VM Backend Usage (Manual)
    ```python
    # Manually register Python function with VM backend
    register_func = tvm.get_global_func("vm.builtin.register_py_func")
    register_func("my_func", my_python_function)
    
    # Use in Relax function (compiled to VM backend)
    @R.function
    def test(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
        return R.call_py_func("my_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
    
    # Manual cleanup (required for direct VM backend usage)
    clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
    clear_func()
    ```
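    
    #### Compiling and Running (Sketch)
    A rough sketch of executing the `test` function above through the Relax VM
    after `my_func` has been registered. `TestModule` (an IRModule wrapping
    `test`), the `tvm.compile` entry point, and the target string are
    assumptions and may differ across TVM versions.
    ```python
    import numpy as np
    import tvm
    from tvm import relax
    
    ex = tvm.compile(TestModule, target="llvm")  # hypothetical module holding `test`
    vm = relax.VirtualMachine(ex, tvm.cpu())
    
    x = tvm.runtime.tensor(np.ones((5,), dtype="float32"))
    out = vm["test"](x)  # vm.builtin.call_py_func looks up "my_func" in the registry
    ```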
---
 python/tvm/relax/base_py_module.py                | 38 +++++++++
 src/relax/backend/vm/codegen_vm.cc                |  1 -
 src/relax/backend/vm/lower_runtime_builtin.cc     | 20 +++++
 src/runtime/vm/builtin.cc                         | 74 +++++++++++++++++
 tests/python/relax/test_base_py_module_printer.py | 96 ++++++++++-------------
 tests/python/relax/test_relax_operators.py        | 76 ++++++++++++++++++
 6 files changed, 248 insertions(+), 57 deletions(-)

diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 52f813dc6b..7a790d28a7 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -45,6 +45,14 @@ class BasePyModule:
     Only IRModules that inherit from this class are allowed to contain Python functions.
     """
 
+    def __del__(self):
+        """Clean up registered Python functions on module destruction."""
+        try:
+            clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+            clear_func()
+        except (ValueError, AttributeError):
+            pass
+
     def __init__(
         self,
         ir_mod: IRModule,
@@ -100,6 +108,7 @@ class BasePyModule:
         self._compile_functions()
         self._wrap_tir_functions()
         self._wrap_relax_functions()
+        self._register_python_functions()
 
     def _collect_function_names(self):
         """Collect names of TIR and Relax functions from IRModule."""
@@ -177,6 +186,35 @@ class BasePyModule:
 
             setattr(self, func_name, _create_relax_wrapper(func_name))
 
+    def _register_python_functions(self):
+        """Register Python functions with the VM runtime for call_py_func support."""
+        if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
+            return
+
+        try:
+            register_py_func = tvm.get_global_func("vm.builtin.register_py_func")
+        except ValueError:
+            return
+
+        for func_name, py_func in self.ir_mod.pyfuncs.items():
+
+            def create_py_func_wrapper(name, original_func):
+                def wrapper(*args, **kwargs):
+                    converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args]
+                    converted_kwargs = {
+                        k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()
+                    }
+
+                    result = original_func(self, *converted_args, **converted_kwargs)
+
+                    return self._convert_pytorch_to_tvm(result)
+
+                wrapper.__name__ = name
+                return wrapper
+
+            wrapped_func = create_py_func_wrapper(func_name, py_func)
+            register_py_func(func_name, wrapped_func)
+
     def call_tir(self, tir_func, args, out_sinfo):
         """Call a TIR function with PyTorch tensors."""
         # Try to get function name from different sources
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 96dac05cb6..e2d9b5b068 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -368,7 +368,6 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
 
     builder_->EmitCall(func, args, dst_reg);
   }
-
   void EmitNormalCall(const Call& call_node, RegName dst_reg) {
     Instruction::Arg func = VisitExpr(call_node->op);
     std::vector<Instruction::Arg> args = VisitArray(call_node->args);
diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
index d52155c615..71b8413e98 100644
--- a/src/relax/backend/vm/lower_runtime_builtin.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -24,6 +24,7 @@
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/op.h>
 #include <tvm/relax/backend.h>
+#include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/type.h>
@@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
       return ShapeOf(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call);
+    } else if (call->op == call_py_func_op_) {
+      return CallPyFunc(call);
     } else if (call->op == to_vdevice_op_) {
       return ToDevice(call);
     } else if (call->op == make_closure_op_) {
@@ -139,6 +142,21 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
     return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
   }
 
+  Expr CallPyFunc(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->struct_info_.defined());
+
+    // Create tuple with function name and arguments tuple
+    ffi::Array<Expr> tuple_fields;
+    tuple_fields.push_back(call_node->args[0]);  // function name
+    tuple_fields.push_back(call_node->args[1]);  // arguments tuple
+    auto combined_tuple = Tuple(tuple_fields);
+
+    // Direct call to vm.builtin.call_py_func
+    return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args,
+                call_node->span);
+  }
+
   Expr ToDevice(const Call& call_node) {
     // TODO(yongwww): replace ToVDeviceAttrs with related Expr
     ICHECK(call_node->args.size() == 1);
@@ -198,6 +216,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const Op& reshape_op_ = Op::Get("relax.reshape");
   const Op& shape_of_op_ = Op::Get("relax.shape_of");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+  const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -216,6 +235,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
   const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
   const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
+  const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
   const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 362a7e4c89..41c011678e 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -34,6 +34,8 @@
 #include <tvm/runtime/vm/bytecode.h>
 #include <tvm/runtime/vm/vm.h>
 
+#include <unordered_map>
+
 namespace tvm {
 namespace runtime {
 namespace vm {
@@ -430,6 +432,78 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       });
 }
 
+//-------------------------------------
+//  Python function call support
+//-------------------------------------
+
+// Global registry for Python functions
+static std::unordered_map<std::string, ffi::Function> py_func_registry;
+
+/*!
+ * \brief Clear the Python function registry on shutdown
+ */
+void ClearPyFuncRegistry() { py_func_registry.clear(); }
+
+/*!
+ * \brief Register a Python function for call_py_func
+ * \param name The function name
+ * \param func The Python function wrapped as ffi::Function
+ */
+void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; }
+
+/*!
+ * \brief Get a registered Python function
+ * \param name The function name
+ * \return The Python function
+ */
+ffi::Function GetPyFunc(const std::string& name) {
+  auto it = py_func_registry.find(name);
+  if (it == py_func_registry.end()) {
+    LOG(FATAL) << "Python function '" << name << "' not found in registry";
+  }
+  return it->second;
+}
+
+/*!
+ * \brief Call a Python function from VM
+ * \param args The packed function arguments (tuple containing function name and arguments)
+ * \param rv The return value
+ */
+void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) {
+  // args[0] should be a tuple containing (func_name, args_tuple)
+  if (args.size() != 1) {
+    LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)";
+  }
+
+  auto tuple_arg = args[0].cast<ffi::Array<ffi::Any>>();
+  if (tuple_arg.size() != 2) {
+    LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)";
+  }
+
+  // Get function name
+  std::string func_name = tuple_arg[0].cast<ffi::String>();
+
+  // Get arguments tuple
+  auto func_args = tuple_arg[1].cast<ffi::Array<ffi::Any>>();
+
+  // Look up Python function in registry
+  ffi::Function py_func = GetPyFunc(func_name);
+
+  // Call the Python function with the arguments
+  std::vector<ffi::AnyView> py_args_vec(func_args.begin(), func_args.end());
+  ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size());
+  py_func.CallPacked(py_args, rv);
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+      .def_packed("vm.builtin.call_py_func", CallPyFunc)
+      .def("vm.builtin.register_py_func", RegisterPyFunc)
+      .def("vm.builtin.get_py_func", GetPyFunc)
+      .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry);
+}
+
 //-------------------------------------
 //  Builtin runtime operators.
 //-------------------------------------
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index 6e87174fda..c9d23a7465 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -760,43 +760,54 @@ def test_python_functions_in_irmodule():
         pytest.fail("pyfuncs attribute not found in IRModule")
 
 
-def test_call_py_func_validation():
-    """Test call_py_func validation and error handling."""
+def test_call_py_func_with_base_py_module():
+    """Test R.call_py_func with BasePyModule."""
     import torch
+    import numpy as np
+    from tvm.relax.op import call_py_func
+    from tvm.relax.expr import StringImm
+    from tvm.relax import Var, TensorStructInfo
 
-    @I.ir_module
-    class ValidationTestModule(BasePyModule):
-        """Test module for validation."""
+    # Test 1: Operator creation and basic properties
+    x = Var("x", TensorStructInfo((5,), "float32"))
+    y = Var("y", TensorStructInfo((5,), "float32"))
 
-        @I.pyfunc
-        def valid_func(self, x):
-            """Valid Python function."""
-            return x * 2
+    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
 
+    assert call_expr.op.name == "relax.call_py_func"
+    assert call_expr.args[0].value == "test_func"
+    assert len(call_expr.args) == 2
+
+    # Test 2: Compilation validation
+    try:
+        call_py_func(
+            "invalid",
+            (Var("x", TensorStructInfo((5,), "float32")),),
+            out_sinfo=R.Tensor((5,), "float32"),
+        )
+        assert False, "Should raise type error"
+    except Exception as e:
+        assert "Mismatched type" in str(e) or "Expected" in str(e)
+
+    # Test 3: Validation and error handling
+    @I.ir_module
+    class ValidationTestModule(BasePyModule):
         @R.function
         def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
-            # This should cause a validation error
             result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
             return result
 
     device = tvm.cpu()
     module = ValidationTestModule(device)
 
-    # Test that calling non-existent function raises error
     x = torch.randn(5, dtype=torch.float32)
 
     with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
         module.call_py_func("non_existent_func", [x])
 
-
-def test_call_py_func_in_relax_function():
-    """Test using call_py_func within Relax functions."""
-    import torch
-
+    # Test 4: Using call_py_func within Relax functions
     @I.ir_module
     class RelaxCallPyFuncModule(BasePyModule):
-        """Test module with call_py_func in Relax functions."""
-
         @I.pyfunc
         def torch_relu(self, x):
             """PyTorch ReLU implementation."""
@@ -809,9 +820,7 @@ def test_call_py_func_in_relax_function():
 
         @R.function
         def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
-            # Use Python function for ReLU
             relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
-            # Use Python function for softmax
             final_result = R.call_py_func(
                 "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
             )
@@ -820,7 +829,6 @@ def test_call_py_func_in_relax_function():
     device = tvm.cpu()
     module = RelaxCallPyFuncModule(device)
 
-    # Test the mixed computation
     x = torch.randn(10, dtype=torch.float32)
 
     expected = torch.softmax(torch.relu(x), dim=0)
@@ -828,40 +836,16 @@ def test_call_py_func_in_relax_function():
     relu_result = module.call_py_func("torch_relu", [x])
     final_result = module.call_py_func("torch_softmax", [relu_result])
 
-    assert torch.allclose(final_result, expected, atol=1e-5)
-
-
-def test_call_py_func_operator_creation():
-    """Test R.call_py_func operator creation and basic properties."""
-    from tvm.relax.op import call_py_func
-    from tvm.relax.expr import StringImm
-    from tvm.relax import Var, TensorStructInfo
-
-    # Create variables
-    x = Var("x", TensorStructInfo((5,), "float32"))
-    y = Var("y", TensorStructInfo((5,), "float32"))
-
-    # Create call_py_func call
-    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
-
-    # Verify operator properties
-    assert call_expr.op.name == "relax.call_py_func"
-    assert call_expr.args[0].value == "test_func"
-    assert len(call_expr.args) == 2
-
+    # Convert to numpy for comparison
+    if isinstance(final_result, tvm.runtime.Tensor):
+        final_result_np = final_result.numpy()
+    else:
+        final_result_np = final_result
 
-def test_call_py_func_compilation_validation():
-    """Test call_py_func compilation validation."""
-    from tvm.relax.op import call_py_func
-    from tvm.relax import Var, TensorStructInfo
+    if isinstance(expected, torch.Tensor):
+        expected_np = expected.numpy()
+    else:
+        expected_np = expected
 
-    # Test operator parameter validation
-    try:
-        call_py_func(
-            "invalid",
-            (Var("x", TensorStructInfo((5,), "float32")),),
-            out_sinfo=R.Tensor((5,), "float32"),
-        )
-        assert False, "Should raise type error"
-    except Exception as e:
-        assert "Mismatched type" in str(e) or "Expected" in str(e)
+    # Use numpy for comparison since we have numpy arrays
+    np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 8558f6e911..897082dd79 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -409,6 +409,82 @@ def test_op_call_inplace_packed(exec_mode):
     assert (result[1].numpy() == sum).all()
 
 
+def test_op_call_py_func(exec_mode):
+    """Test R.call_py_func operator functionality."""
+    import torch
+
+    def torch_relu(x):
+        if isinstance(x, tvm.runtime.Tensor):
+            x_torch = torch.from_numpy(x.numpy())
+        elif hasattr(x, "asnumpy"):
+            x_torch = torch.from_numpy(x.asnumpy())
+        else:
+            x_np = np.array(x)
+            if isinstance(x_np, tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(x_np.numpy())
+            elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+                if x_torch.ndim > 1:
+                    x_torch = x_torch.flatten()
+            else:
+                x_torch = torch.from_numpy(x_np)
+        result = torch.relu(x_torch)
+        return tvm.runtime.tensor(result.numpy())
+
+    def torch_sigmoid(x):
+        if isinstance(x, tvm.runtime.Tensor):
+            x_torch = torch.from_numpy(x.numpy())
+        elif hasattr(x, "asnumpy"):
+            x_torch = torch.from_numpy(x.asnumpy())
+        else:
+            x_np = np.array(x)
+            if isinstance(x_np, tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(x_np.numpy())
+            elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+                if x_torch.ndim > 1:
+                    x_torch = x_torch.flatten()
+            else:
+                x_torch = torch.from_numpy(x_np)
+        result = torch.sigmoid(x_torch)
+        return tvm.runtime.tensor(result.numpy())
+
+    register_func = tvm.get_global_func("vm.builtin.register_py_func")
+    register_func("torch_relu", torch_relu)
+    register_func("torch_sigmoid", torch_sigmoid)
+
+    @tvm.script.ir_module
+    class CallPyFuncTest:
+        @R.function
+        def simple_call(x: R.Tensor((3,), "float32")):
+            result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32"))
+            return result
+
+        @R.function
+        def multiple_calls(x: R.Tensor((2,), "float32")):
+            y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32"))
+            z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), "float32"))
+            return z
+
+    np.random.seed(0)
+    x_data = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    x_tvm = tvm.runtime.tensor(x_data)
+
+    result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode)
+    expected = np.maximum(x_data, 0.0)
+    assert (result.numpy() == expected).all()
+
+    y_data = np.array([-0.5, 0.5], dtype=np.float32)
+    y_tvm = tvm.runtime.tensor(y_data)
+
+    result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode)
+    expected2 = 1.0 / (1.0 + np.exp(-np.maximum(y_data, 0.0)))
+    assert (result2.numpy() == expected2).all()
+
+    clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+    clear_func()
+
+
 def test_op_to_device(exec_mode):
     @tvm.script.ir_module
     class CallToDevice:
