This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e21b6a25a8 [Relax] Update BasePyModule with faster DLPack converter 
for tensor conversion (#18331)
e21b6a25a8 is described below

commit e21b6a25a821cdb449bd2ca3ae3975092067d4c0
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Sep 25 15:07:31 2025 -0400

    [Relax] Update BasePyModule with faster DLPack converter for tensor 
conversion (#18331)
    
    This PR enhances `BasePyModule` by integrating a faster DLPack
    converter for efficient tensor conversion between TVM and PyTorch,
    following up on #18306.
---
 python/tvm/relax/base_py_module.py                | 58 ++++++++++++++++-------
 tests/python/relax/test_base_py_module.py         |  2 +-
 tests/python/relax/test_base_py_module_printer.py | 52 ++------------------
 3 files changed, 47 insertions(+), 65 deletions(-)

diff --git a/python/tvm/relax/base_py_module.py 
b/python/tvm/relax/base_py_module.py
index 7a790d28a7..41ef44fb30 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -32,6 +32,13 @@ try:
 except ImportError:
     to_dlpack_legacy = None
 
+try:
+    from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension
+
+    _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension()
+except ImportError:
+    _FASTER_DLPACK_EXTENSION = None
+
 
 class BasePyModule:
     """Base class that allows Python functions in IRModule with DLPack 
conversion.
@@ -369,20 +376,29 @@ class BasePyModule:
         return self._convert_single_pytorch_to_tvm(tensors)
 
     def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
-        """Convert a single PyTorch tensor to TVM Tensor with robust 
fallbacks."""
+        """Convert a single PyTorch tensor to TVM Tensor with faster DLPack 
converter."""
         # pylint: disable=import-outside-toplevel
         import torch
 
         if isinstance(tensor, Tensor):
             return tensor
         if isinstance(tensor, torch.Tensor):
-            # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
+            # 1. Try faster C++ DLPack converter
+            if _FASTER_DLPACK_EXTENSION is not None:
+                try:
+                    dlpack = torch.to_dlpack(tensor)
+                    return tvm.runtime.from_dlpack(dlpack)
+                except (AttributeError, ValueError):
+                    pass  # Fall through to the next method
+
+            # 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
             try:
                 dlpack = torch.to_dlpack(tensor)
                 return tvm.runtime.from_dlpack(dlpack)
             except (AttributeError, ValueError):
                 pass  # Fall through to the next method
-            # 2. Try legacy `torch.utils.dlpack.to_dlpack`
+
+            # 3. Try legacy `torch.utils.dlpack.to_dlpack`
             if to_dlpack_legacy:
                 try:
                     dlpack = to_dlpack_legacy(tensor)
@@ -392,7 +408,8 @@ class BasePyModule:
                         f"Warning: Legacy DLPack conversion failed 
({error_legacy}), "
                         f"using numpy fallback."
                     )
-            # 3. If all DLPack methods fail, use numpy fallback
+
+            # 4. If all DLPack methods fail, use numpy fallback
             numpy_array = tensor.detach().cpu().numpy()
             return tvm.runtime.tensor(numpy_array, device=self.device)
 
@@ -406,28 +423,37 @@ class BasePyModule:
             ) from error
 
     def _convert_tvm_to_pytorch(
-        self, tvm_arrays: Union[Any, List[Any]]
+        self, tvm_tensors: Union[Any, List[Any]]
     ) -> Union["torch.Tensor", List["torch.Tensor"]]:
         """Convert TVM Tensors to PyTorch tensors using DLPack."""
-        if isinstance(tvm_arrays, (list, tuple)):
-            return [self._convert_single_tvm_to_pytorch(arr) for arr in 
tvm_arrays]
-        return self._convert_single_tvm_to_pytorch(tvm_arrays)
+        if isinstance(tvm_tensors, (list, tuple)):
+            return [self._convert_single_tvm_to_pytorch(tensor) for tensor in 
tvm_tensors]
+        return self._convert_single_tvm_to_pytorch(tvm_tensors)
 
-    def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor":
-        """Convert a single TVM Tensor to PyTorch tensor using DLPack."""
+    def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> 
"torch.Tensor":
+        """Convert a single TVM Tensor to PyTorch tensor using faster DLPack 
converter."""
         # pylint: disable=import-outside-toplevel
         import torch
 
-        if isinstance(tvm_array, torch.Tensor):
-            return tvm_array
-        if not isinstance(tvm_array, Tensor):
-            return torch.tensor(tvm_array)
+        if isinstance(tvm_tensor, torch.Tensor):
+            return tvm_tensor
+        if not isinstance(tvm_tensor, Tensor):
+            return torch.tensor(tvm_tensor)
+
+        # 1. Try faster C++ DLPack converter
+        if _FASTER_DLPACK_EXTENSION is not None:
+            try:
+                return torch.from_dlpack(tvm_tensor)
+            except (AttributeError, ValueError):
+                pass  # Fall through to the next method
+
+        # 2. Try standard DLPack conversion
         try:
-            return torch.from_dlpack(tvm_array)
+            return torch.from_dlpack(tvm_tensor)
         # pylint: disable=broad-exception-caught
         except Exception as error:
             print(f"Warning: DLPack conversion from TVM failed ({error}), 
using numpy fallback")
-            numpy_array = tvm_array.numpy()
+            numpy_array = tvm_tensor.numpy()
             return torch.from_numpy(numpy_array)
 
     def get_function(self, name: str) -> Optional[PackedFunc]:
diff --git a/tests/python/relax/test_base_py_module.py 
b/tests/python/relax/test_base_py_module.py
index 19cc5c9eec..1f888991be 100644
--- a/tests/python/relax/test_base_py_module.py
+++ b/tests/python/relax/test_base_py_module.py
@@ -203,4 +203,4 @@ class TestBasePyModule:
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    tvm.testing.main()
diff --git a/tests/python/relax/test_base_py_module_printer.py 
b/tests/python/relax/test_base_py_module_printer.py
index c9d23a7465..a64b3fed5a 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -420,54 +420,6 @@ class ErrorHandlingPyFuncModule(BasePyModule):
                 Output[i] = 0.0
 
 
-if __name__ == "__main__":
-    # This allows the file to be run directly for debugging
-    # In normal pytest usage, these classes are automatically tested by 
TVMScript
-    print("All test modules defined successfully!")
-    print("TVMScript will automatically validate these modules during 
testing.")
-
-    # Demo the printer functionality
-    print("\n" + "=" * 60)
-    print("DEMO: BasePyModule Printer Functionality")
-    print("=" * 60)
-
-    # Test the printer with SimplePyFuncModule
-    try:
-        ir_mod = SimplePyFuncModule
-        device = tvm.cpu()
-        module = BasePyModule(ir_mod, device)
-
-        print("\n1. Testing script() method:")
-        print("-" * 40)
-        script_output = module.script()
-        print(script_output[:500] + "..." if len(script_output) > 500 else 
script_output)
-
-        print("\n2. Testing show() method:")
-        print("-" * 40)
-        module.show()
-
-        print("\n3. Python functions found in pyfuncs:")
-        print("-" * 40)
-        if hasattr(ir_mod, "pyfuncs"):
-            for name, func in ir_mod.pyfuncs.items():
-                print(f"  - {name}: {func}")
-        else:
-            print("  No pyfuncs attribute found")
-
-    except Exception as e:
-        print(f"Demo failed: {e}")
-        print("This is expected for testing-only TVMScript code.")
-
-    # Run all tests using tvm.testing.main()
-    print("\n" + "=" * 60)
-    print("Running all tests with tvm.testing.main()...")
-    print("=" * 60)
-
-    import tvm.testing
-
-    tvm.testing.main()
-
-
 # Pytest test functions to verify the classes work correctly
 def test_simple_pyfunc_module_creation():
     """Test that SimplePyFuncModule can be created."""
@@ -849,3 +801,7 @@ def test_call_py_func_with_base_py_module():
 
     # Use numpy for comparison since we have numpy arrays
     np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, 
atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to the commits mailing list.