This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new df58a05  feat(typing): Parameterizable `Array[T]` and `Map[K, V]` (#37)
df58a05 is described below

commit df58a05ec400dc5e91ac86146aedaf3a8273b7bd
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 16:39:19 2025 -0700

    feat(typing): Parameterizable `Array[T]` and `Map[K, V]` (#37)
    
    Depends on #35.
    
    Previously, the `Array` and `Map` types were not properly parameterized,
    even though they inherit from `collections.abc.Sequence` and `Mapping`.
    This PR makes them generic, so `Array[T]` and `Map[K, V]` can be used in
    type annotations.
---
 python/tvm_ffi/container.py | 185 +++++++++++++++++++++++++++-----------------
 python/tvm_ffi/testing.py   |   4 +-
 2 files changed, 115 insertions(+), 74 deletions(-)

diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 6f29dfd..e3b2fb7 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -18,9 +18,12 @@
 
 from __future__ import annotations
 
-import collections.abc
+import operator
+from collections.abc import ItemsView as ItemsViewBase
 from collections.abc import Iterator, Mapping, Sequence
-from typing import Any, Callable
+from collections.abc import KeysView as KeysViewBase
+from collections.abc import ValuesView as ValuesViewBase
+from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload
 
 from . import _ffi_api, core
 from .registry import register_object
@@ -28,12 +31,18 @@ from .registry import register_object
 __all__ = ["Array", "Map"]
 
 
+T = TypeVar("T")
+K = TypeVar("K")
+V = TypeVar("V")
+_DefaultT = TypeVar("_DefaultT")
+
+
 def getitem_helper(
     obj: Any,
-    elem_getter: Callable[[Any, int], Any],
+    elem_getter: Callable[[Any, int], T],
     length: int,
-    idx: int | slice,
-) -> Any:
+    idx: SupportsIndex | slice,
+) -> T | list[T]:
     """Implement a pythonic __getitem__ helper.
 
     Parameters
@@ -41,47 +50,46 @@ def getitem_helper(
     obj: Any
         The original object
 
-    elem_getter : Callable[[Any, int], Any]
+    elem_getter : Callable[[Any, int], T]
         A simple function that takes index and return a single element.
 
     length : int
         The size of the array
 
-    idx : int or slice
+    idx : SupportsIndex or slice
         The argument passed to getitem
 
     Returns
     -------
     result : object
-        The result of getitem
+        The element for integer indices or a ``list`` for slices.
 
     """
     if isinstance(idx, slice):
-        start = idx.start if idx.start is not None else 0
-        stop = idx.stop if idx.stop is not None else length
-        step = idx.step if idx.step is not None else 1
-        if start < 0:
-            start += length
-        if stop < 0:
-            stop += length
+        start, stop, step = idx.indices(length)
         return [elem_getter(obj, i) for i in range(start, stop, step)]
 
-    if idx < -length or idx >= length:
-        raise IndexError(f"Index out of range. size: {length}, got index {idx}")
-    if idx < 0:
-        idx += length
-    return elem_getter(obj, idx)
+    try:
+        index = operator.index(idx)
+    except TypeError as exc:  # pragma: no cover - defensive, matches list behaviour
+        raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from exc
+
+    if index < -length or index >= length:
+        raise IndexError(f"Index out of range. size: {length}, got index {index}")
+    if index < 0:
+        index += length
+    return elem_getter(obj, index)
 
 
 @register_object("ffi.Array")
-class Array(core.Object, collections.abc.Sequence):
+class Array(core.Object, Sequence[T]):
     """Array container that represents a sequence of values in ffi.
 
     :py:func:`tvm_ffi.convert` will map python list/tuple to this class.
 
     Parameters
     ----------
-    input_list : Sequence[Any]
+    input_list : Sequence[T]
         The list of values to be stored in the array.
 
     See Also
@@ -100,18 +108,34 @@ class Array(core.Object, collections.abc.Sequence):
 
     """
 
-    def __init__(self, input_list: Sequence[Any]) -> None:
+    def __init__(self, input_list: Sequence[T]) -> None:
         """Construct an Array from a Python sequence."""
         self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
 
-    def __getitem__(self, idx: int | slice) -> Any:
-        """Return one element or a Python list for a slice."""
-        return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)
+    @overload
+    def __getitem__(self, idx: SupportsIndex, /) -> T: ...
+
+    @overload
+    def __getitem__(self, idx: slice, /) -> Array[T]: ...
+
+    def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]:
+        """Return one element or a new :class:`Array` for a slice."""
+        length = len(self)
+        result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx)
+        if isinstance(result, list):
+            return cast(Array[T], type(self)(result))
+        return result
 
     def __len__(self) -> int:
         """Return the number of elements in the array."""
         return _ffi_api.ArraySize(self)
 
+    def __iter__(self) -> Iterator[T]:
+        """Iterate over the elements in the array."""
+        length = len(self)
+        for i in range(length):
+            yield self[i]
+
     def __repr__(self) -> str:
         """Return a string representation of the array."""
         # exception safety handling for chandle=None
@@ -120,79 +144,87 @@ class Array(core.Object, collections.abc.Sequence):
         return "[" + ", ".join([x.__repr__() for x in self]) + "]"
 
 
-class KeysView(collections.abc.KeysView):
+class KeysView(KeysViewBase[K]):
     """Helper class to return keys view."""
 
-    def __init__(self, backend_map: Map) -> None:
+    def __init__(self, backend_map: Map[K, V]) -> None:
         self._backend_map = backend_map
 
     def __len__(self) -> int:
         return len(self._backend_map)
 
-    def __iter__(self) -> Iterator[Any]:
-        if self.__len__() == 0:
-            return
-        functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
-        while True:
-            k = functor(0)
-            yield k
+    def __iter__(self) -> Iterator[K]:
+        size = len(self._backend_map)
+        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
+        for _ in range(size):
+            key = cast(K, functor(0))
+            yield key
             if not functor(2):
                 break
 
-    def __contains__(self, k: Any) -> bool:
-        return self._backend_map.__contains__(k)
+    def __contains__(self, k: object) -> bool:
+        return k in self._backend_map
 
 
-class ValuesView(collections.abc.ValuesView):
+class ValuesView(ValuesViewBase[V]):
     """Helper class to return values view."""
 
-    def __init__(self, backend_map: Map) -> None:
+    def __init__(self, backend_map: Map[K, V]) -> None:
         self._backend_map = backend_map
 
     def __len__(self) -> int:
         return len(self._backend_map)
 
-    def __iter__(self) -> Iterator[Any]:
-        if self.__len__() == 0:
-            return
-        functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
-        while True:
-            v = functor(1)
-            yield v
+    def __iter__(self) -> Iterator[V]:
+        size = len(self._backend_map)
+        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
+        for _ in range(size):
+            value = cast(V, functor(1))
+            yield value
             if not functor(2):
                 break
 
 
-class ItemsView(collections.abc.ItemsView):
+class ItemsView(ItemsViewBase[K, V]):
     """Helper class to return items view."""
 
-    def __init__(self, backend_map: Map) -> None:
-        self.backend_map = backend_map
+    def __init__(self, backend_map: Map[K, V]) -> None:
+        self._backend_map = backend_map
 
     def __len__(self) -> int:
-        return len(self.backend_map)
-
-    def __iter__(self) -> Iterator[tuple[Any, Any]]:
-        if self.__len__() == 0:
-            return
-        functor = _ffi_api.MapForwardIterFunctor(self.backend_map)
-        while True:
-            k = functor(0)
-            v = functor(1)
-            yield (k, v)
+        return len(self._backend_map)
+
+    def __iter__(self) -> Iterator[tuple[K, V]]:
+        size = len(self._backend_map)
+        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
+        for _ in range(size):
+            key = cast(K, functor(0))
+            value = cast(V, functor(1))
+            yield (key, value)
             if not functor(2):
                 break
 
+    def __contains__(self, item: object) -> bool:
+        if not isinstance(item, tuple) or len(item) != 2:
+            return False
+        key, value = item
+        try:
+            existing_value = self._backend_map[key]
+        except KeyError:
+            return False
+        else:
+            return existing_value == value
+
 
 @register_object("ffi.Map")
-class Map(core.Object, collections.abc.Mapping):
+class Map(core.Object, Mapping[K, V]):
     """Map container.
 
     :py:func:`tvm_ffi.convert` will map python dict to this class.
 
     Parameters
     ----------
-    input_dict : Mapping[Any, Any]
+    input_dict : Mapping[K, V]
         The dictionary of values to be stored in the map.
 
     See Also
@@ -213,31 +245,31 @@ class Map(core.Object, collections.abc.Mapping):
 
     """
 
-    def __init__(self, input_dict: Mapping[Any, Any]) -> None:
+    def __init__(self, input_dict: Mapping[K, V]) -> None:
         """Construct a Map from a Python mapping."""
-        list_kvs = []
+        list_kvs: list[Any] = []
         for k, v in input_dict.items():
             list_kvs.append(k)
             list_kvs.append(v)
         self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)
 
-    def __getitem__(self, k: Any) -> Any:
+    def __getitem__(self, k: K) -> V:
         """Return the value for key `k` or raise KeyError."""
-        return _ffi_api.MapGetItem(self, k)
+        return cast(V, _ffi_api.MapGetItem(self, k))
 
-    def __contains__(self, k: Any) -> bool:
+    def __contains__(self, k: object) -> bool:
         """Return True if the map contains key `k`."""
         return _ffi_api.MapCount(self, k) != 0
 
-    def keys(self) -> KeysView:
+    def keys(self) -> KeysView[K]:
         """Return a dynamic view of the map's keys."""
         return KeysView(self)
 
-    def values(self) -> ValuesView:
+    def values(self) -> ValuesView[V]:
         """Return a dynamic view of the map's values."""
         return ValuesView(self)
 
-    def items(self) -> ItemsView:
+    def items(self) -> ItemsView[K, V]:
         """Get the items from the map."""
         return ItemsView(self)
 
@@ -245,11 +277,17 @@ class Map(core.Object, collections.abc.Mapping):
         """Return the number of items in the map."""
         return _ffi_api.MapSize(self)
 
-    def __iter__(self) -> Iterator[Any]:
+    def __iter__(self) -> Iterator[K]:
         """Iterate over the map's keys."""
         return iter(self.keys())
 
-    def get(self, key: Any, default: Any | None = None) -> Any:
+    @overload
+    def get(self, key: K) -> V | None: ...
+
+    @overload
+    def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...
+
+    def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None:
         """Get an element with a default value.
 
         Parameters
@@ -266,7 +304,10 @@ class Map(core.Object, collections.abc.Mapping):
             The result value.
 
         """
-        return self[key] if key in self else default
+        try:
+            return self[key]
+        except KeyError:
+            return default
 
     def __repr__(self) -> str:
         """Return a string representation of the map."""
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index f74abc2..6d302bc 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -52,8 +52,8 @@ class TestIntPair(Object):
 class TestObjectDerived(TestObjectBase):
     """Test object derived class."""
 
-    v_map: Map
-    v_array: Array
+    v_map: Map[Any, Any]
+    v_array: Array[Any]
 
 
 def create_object(type_key: str, **kwargs: Any) -> Object:

Reply via email to