This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e21b6a25a8 [Relax] Update BasePyModule with faster DLPack converter for tensor conversion (#18331)
e21b6a25a8 is described below
commit e21b6a25a821cdb449bd2ca3ae3975092067d4c0
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Sep 25 15:07:31 2025 -0400
[Relax] Update BasePyModule with faster DLPack converter for tensor conversion (#18331)
This PR enhances `BasePyModule` by integrating the optional faster C++ DLPack
converter (loaded via `tvm_ffi._optional_torch_c_dlpack`) for efficient tensor
conversion between TVM and PyTorch, following #18306.
---
python/tvm/relax/base_py_module.py | 58 ++++++++++++++++-------
tests/python/relax/test_base_py_module.py | 2 +-
tests/python/relax/test_base_py_module_printer.py | 52 ++------------------
3 files changed, 47 insertions(+), 65 deletions(-)
diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 7a790d28a7..41ef44fb30 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -32,6 +32,13 @@ try:
except ImportError:
to_dlpack_legacy = None
+try:
+ from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension
+
+ _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension()
+except ImportError:
+ _FASTER_DLPACK_EXTENSION = None
+
class BasePyModule:
"""Base class that allows Python functions in IRModule with DLPack
conversion.
@@ -369,20 +376,29 @@ class BasePyModule:
return self._convert_single_pytorch_to_tvm(tensors)
def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
- """Convert a single PyTorch tensor to TVM Tensor with robust
fallbacks."""
+ """Convert a single PyTorch tensor to TVM Tensor with faster DLPack
converter."""
# pylint: disable=import-outside-toplevel
import torch
if isinstance(tensor, Tensor):
return tensor
if isinstance(tensor, torch.Tensor):
- # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
+ # 1. Try faster C++ DLPack converter
+ if _FASTER_DLPACK_EXTENSION is not None:
+ try:
+ dlpack = torch.to_dlpack(tensor)
+ return tvm.runtime.from_dlpack(dlpack)
+ except (AttributeError, ValueError):
+ pass # Fall through to the next method
+
+ # 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
try:
dlpack = torch.to_dlpack(tensor)
return tvm.runtime.from_dlpack(dlpack)
except (AttributeError, ValueError):
pass # Fall through to the next method
- # 2. Try legacy `torch.utils.dlpack.to_dlpack`
+
+ # 3. Try legacy `torch.utils.dlpack.to_dlpack`
if to_dlpack_legacy:
try:
dlpack = to_dlpack_legacy(tensor)
@@ -392,7 +408,8 @@ class BasePyModule:
f"Warning: Legacy DLPack conversion failed
({error_legacy}), "
f"using numpy fallback."
)
- # 3. If all DLPack methods fail, use numpy fallback
+
+ # 4. If all DLPack methods fail, use numpy fallback
numpy_array = tensor.detach().cpu().numpy()
return tvm.runtime.tensor(numpy_array, device=self.device)
@@ -406,28 +423,37 @@ class BasePyModule:
) from error
def _convert_tvm_to_pytorch(
- self, tvm_arrays: Union[Any, List[Any]]
+ self, tvm_tensors: Union[Any, List[Any]]
) -> Union["torch.Tensor", List["torch.Tensor"]]:
"""Convert TVM Tensors to PyTorch tensors using DLPack."""
- if isinstance(tvm_arrays, (list, tuple)):
- return [self._convert_single_tvm_to_pytorch(arr) for arr in
tvm_arrays]
- return self._convert_single_tvm_to_pytorch(tvm_arrays)
+ if isinstance(tvm_tensors, (list, tuple)):
+ return [self._convert_single_tvm_to_pytorch(tensor) for tensor in
tvm_tensors]
+ return self._convert_single_tvm_to_pytorch(tvm_tensors)
- def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor":
- """Convert a single TVM Tensor to PyTorch tensor using DLPack."""
+ def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) ->
"torch.Tensor":
+ """Convert a single TVM Tensor to PyTorch tensor using faster DLPack
converter."""
# pylint: disable=import-outside-toplevel
import torch
- if isinstance(tvm_array, torch.Tensor):
- return tvm_array
- if not isinstance(tvm_array, Tensor):
- return torch.tensor(tvm_array)
+ if isinstance(tvm_tensor, torch.Tensor):
+ return tvm_tensor
+ if not isinstance(tvm_tensor, Tensor):
+ return torch.tensor(tvm_tensor)
+
+ # 1. Try faster C++ DLPack converter
+ if _FASTER_DLPACK_EXTENSION is not None:
+ try:
+ return torch.from_dlpack(tvm_tensor)
+ except (AttributeError, ValueError):
+ pass # Fall through to the next method
+
+ # 2. Try standard DLPack conversion
try:
- return torch.from_dlpack(tvm_array)
+ return torch.from_dlpack(tvm_tensor)
# pylint: disable=broad-exception-caught
except Exception as error:
print(f"Warning: DLPack conversion from TVM failed ({error}),
using numpy fallback")
- numpy_array = tvm_array.numpy()
+ numpy_array = tvm_tensor.numpy()
return torch.from_numpy(numpy_array)
def get_function(self, name: str) -> Optional[PackedFunc]:
diff --git a/tests/python/relax/test_base_py_module.py b/tests/python/relax/test_base_py_module.py
index 19cc5c9eec..1f888991be 100644
--- a/tests/python/relax/test_base_py_module.py
+++ b/tests/python/relax/test_base_py_module.py
@@ -203,4 +203,4 @@ class TestBasePyModule:
if __name__ == "__main__":
- pytest.main([__file__])
+ tvm.testing.main()
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index c9d23a7465..a64b3fed5a 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -420,54 +420,6 @@ class ErrorHandlingPyFuncModule(BasePyModule):
Output[i] = 0.0
-if __name__ == "__main__":
- # This allows the file to be run directly for debugging
- # In normal pytest usage, these classes are automatically tested by
TVMScript
- print("All test modules defined successfully!")
- print("TVMScript will automatically validate these modules during
testing.")
-
- # Demo the printer functionality
- print("\n" + "=" * 60)
- print("DEMO: BasePyModule Printer Functionality")
- print("=" * 60)
-
- # Test the printer with SimplePyFuncModule
- try:
- ir_mod = SimplePyFuncModule
- device = tvm.cpu()
- module = BasePyModule(ir_mod, device)
-
- print("\n1. Testing script() method:")
- print("-" * 40)
- script_output = module.script()
- print(script_output[:500] + "..." if len(script_output) > 500 else
script_output)
-
- print("\n2. Testing show() method:")
- print("-" * 40)
- module.show()
-
- print("\n3. Python functions found in pyfuncs:")
- print("-" * 40)
- if hasattr(ir_mod, "pyfuncs"):
- for name, func in ir_mod.pyfuncs.items():
- print(f" - {name}: {func}")
- else:
- print(" No pyfuncs attribute found")
-
- except Exception as e:
- print(f"Demo failed: {e}")
- print("This is expected for testing-only TVMScript code.")
-
- # Run all tests using tvm.testing.main()
- print("\n" + "=" * 60)
- print("Running all tests with tvm.testing.main()...")
- print("=" * 60)
-
- import tvm.testing
-
- tvm.testing.main()
-
-
# Pytest test functions to verify the classes work correctly
def test_simple_pyfunc_module_creation():
"""Test that SimplePyFuncModule can be created."""
@@ -849,3 +801,7 @@ def test_call_py_func_with_base_py_module():
# Use numpy for comparison since we have numpy arrays
np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5,
atol=1e-5)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()