This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new a7ebc65  Allow `tvm_ffi.device(..., id)` where id is numpy or torch scalar (#347)
a7ebc65 is described below

commit a7ebc65f14eecd1592d407f1d5c952c65603a9aa
Author: wrongtest <[email protected]>
AuthorDate: Thu Dec 18 02:17:15 2025 +0800

    Allow `tvm_ffi.device(..., id)` where id is numpy or torch scalar (#347)
    
    This mirrors the behavior of `torch.device`, which accepts all of the following:
    
    ```python
    torch.device("cuda", 0)
    torch.device("cuda", numpy.int32(1))
    torch.device("cuda", torch.tensor(1, dtype=torch.int32))
    ```
    
    ```python
    tvm_ffi.device("cuda", 0)
    tvm_ffi.device("cuda", numpy.int32(1))
    tvm_ffi.device("cuda", torch.tensor(1, dtype=torch.int32))
    ```
    
    ---------
    
    Co-authored-by: baoxinqi <[email protected]>
---
 python/tvm_ffi/cython/device.pxi | 11 ++++++++---
 tests/python/test_device.py      |  5 +++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 9539827..2eb36fc 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from enum import IntEnum
+from numbers import Integral
 from typing import Any, Optional
 
 _CLASS_DEVICE = None
@@ -129,7 +130,7 @@ cdef class Device:
         "trn": DLDeviceType.kDLTrn,
     }
 
-    def __init__(self, device_type: str | int, index: Optional[int] = None) -> None:
+    def __init__(self, device_type: str | int, index: Optional[Integral] = None) -> None:
         device_type_or_name = device_type
         index = index if index is not None else 0
         if isinstance(device_type_or_name, str):
@@ -148,8 +149,12 @@ cdef class Device:
                     raise ValueError(f"Invalid device index: {parts[1]}")
         else:
             device_type = device_type_or_name
-        if not isinstance(index, int):
-            raise TypeError(f"Invalid device index: {index}")
+
+        if not isinstance(index, Integral):
+            if hasattr(index, "item") and callable(index.item):
+                index = index.item()
+            if not isinstance(index, Integral):
+                raise TypeError(f"Invalid device index: {index}")
         self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, index)
 
     def __reduce__(self) -> Any:
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 71dc0e4..33b48e8 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import ctypes
 import pickle
 
+import numpy
 import pytest
 import tvm_ffi
 from tvm_ffi import DLDeviceType
@@ -69,6 +70,10 @@ def test_device_dlpack_device_type(
         (DLDeviceType.kDLCUDA, 0, DLDeviceType.kDLCUDA, 0),
         ("cuda", 3, DLDeviceType.kDLCUDA, 3),
         (DLDeviceType.kDLMetal, 2, DLDeviceType.kDLMetal, 2),
+        # id from numpy
+        ("cpu", numpy.int32(1), DLDeviceType.kDLCPU, 1),
+        # id from torch (disabled: torch is not installed in the test environment)
+        # ("cpu", torch.tensor(1, dtype=torch.int32), DLDeviceType.kDLCPU, 1),
     ],
 )
 def test_device_with_dev_id(

Reply via email to