This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new c0add28  [FIX] Fix device type override (#40)
c0add28 is described below

commit c0add281b0973aacc88f3185a84e0768c67306f4
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 22 11:23:53 2025 -0400

    [FIX] Fix device type override (#40)
    
    This PR fixes the behavior of `device` so that the device type override
    works correctly, and adds unit tests.
---
 pyproject.toml              |  2 +-
 python/tvm_ffi/_tensor.py   |  6 +++---
 tests/python/test_device.py | 12 ++++++++++++
 tests/python/test_tensor.py | 18 ++++++++++++++++++
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1e97de7..d2fd897 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b5"
+version = "0.1.0b6"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 8d06bd2..0cc09f1 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -21,9 +21,8 @@
 from numbers import Integral
 from typing import Any, Optional, Union
 
-from . import _ffi_api, registry
+from . import _ffi_api, core, registry
 from .core import (
-    _CLASS_DEVICE,
     Device,
     DLDeviceType,
     PyNativeObject,
@@ -86,7 +85,8 @@ def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = No
       assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)
 
     """
-    return _CLASS_DEVICE(device_type, index)
+    # must refer to core._CLASS_DEVICE so we pick up override here
+    return core._CLASS_DEVICE(device_type, index)
 
 
 __all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"]
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 30c964a..9441c9f 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -95,3 +95,15 @@ def test_device_pickle() -> None:
     device_pickled = pickle.loads(pickle.dumps(device))
     assert device_pickled.dlpack_device_type() == device.dlpack_device_type()
     assert device_pickled.index == device.index
+
+
+def test_device_class_override() -> None:
+    class MyDevice(tvm_ffi.Device):  # tvm_ffi.device() should build this once set
+        pass
+
+    old_device = tvm_ffi.core._CLASS_DEVICE
+    tvm_ffi.core._set_class_device(MyDevice)
+    try:  # restore the global override even if the assertion fails
+        assert isinstance(tvm_ffi.device("cuda", 0), MyDevice)
+    finally:
+        tvm_ffi.core._set_class_device(old_device)
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 6d1da26..4c2e9a8 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -66,3 +66,21 @@ def test_tensor_auto_dlpack() -> None:
     assert y.shape == x.shape
     assert y.device == x.device
     np.testing.assert_equal(y.numpy(), x.numpy())
+
+
+def test_tensor_class_override() -> None:
+    class MyTensor(tvm_ffi.Tensor):  # FFI-returned tensors should use this once set
+        pass
+
+    data = np.zeros((10, 8, 4, 2), dtype="int16")
+    if not hasattr(data, "__dlpack__"):  # check before touching global state
+        return
+    old_tensor = tvm_ffi.core._CLASS_TENSOR
+    tvm_ffi.core._set_class_tensor(MyTensor)
+    try:  # restore the global override even if an assertion fails
+        x = tvm_ffi.from_dlpack(data)
+        fecho = tvm_ffi.get_global_func("testing.echo")
+        y = fecho(x)
+        assert isinstance(y, MyTensor)
+    finally:
+        tvm_ffi.core._set_class_tensor(old_tensor)

Reply via email to