This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new df58a05 feat(typing): Parameterizable `Array[T]` and `Map[K, V]` (#37)
df58a05 is described below
commit df58a05ec400dc5e91ac86146aedaf3a8273b7bd
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 16:39:19 2025 -0700
feat(typing): Parameterizable `Array[T]` and `Map[K, V]` (#37)
Depends on #35.
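Previously, the `Array` and `Map` types were not properly parameterized,
even though they inherit from `collections.abc.Sequence` and `Mapping`.
This PR makes them generic (`Array[T]` and `Map[K, V]`), so they can be
parameterized in type annotations, e.g. `Array[Any]` and `Map[Any, Any]`.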
---
python/tvm_ffi/container.py | 185 +++++++++++++++++++++++++++-----------------
python/tvm_ffi/testing.py | 4 +-
2 files changed, 115 insertions(+), 74 deletions(-)
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 6f29dfd..e3b2fb7 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -18,9 +18,12 @@
from __future__ import annotations
-import collections.abc
+import operator
+from collections.abc import ItemsView as ItemsViewBase
from collections.abc import Iterator, Mapping, Sequence
-from typing import Any, Callable
+from collections.abc import KeysView as KeysViewBase
+from collections.abc import ValuesView as ValuesViewBase
+from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload
from . import _ffi_api, core
from .registry import register_object
@@ -28,12 +31,18 @@ from .registry import register_object
__all__ = ["Array", "Map"]
+T = TypeVar("T")
+K = TypeVar("K")
+V = TypeVar("V")
+_DefaultT = TypeVar("_DefaultT")
+
+
def getitem_helper(
obj: Any,
- elem_getter: Callable[[Any, int], Any],
+ elem_getter: Callable[[Any, int], T],
length: int,
- idx: int | slice,
-) -> Any:
+ idx: SupportsIndex | slice,
+) -> T | list[T]:
"""Implement a pythonic __getitem__ helper.
Parameters
@@ -41,47 +50,46 @@ def getitem_helper(
obj: Any
The original object
- elem_getter : Callable[[Any, int], Any]
+ elem_getter : Callable[[Any, int], T]
A simple function that takes index and return a single element.
length : int
The size of the array
- idx : int or slice
+ idx : SupportsIndex or slice
The argument passed to getitem
Returns
-------
result : object
- The result of getitem
+ The element for integer indices or a ``list`` for slices.
"""
if isinstance(idx, slice):
- start = idx.start if idx.start is not None else 0
- stop = idx.stop if idx.stop is not None else length
- step = idx.step if idx.step is not None else 1
- if start < 0:
- start += length
- if stop < 0:
- stop += length
+ start, stop, step = idx.indices(length)
return [elem_getter(obj, i) for i in range(start, stop, step)]
- if idx < -length or idx >= length:
- raise IndexError(f"Index out of range. size: {length}, got index
{idx}")
- if idx < 0:
- idx += length
- return elem_getter(obj, idx)
+ try:
+ index = operator.index(idx)
+ except TypeError as exc: # pragma: no cover - defensive, matches list
behaviour
+ raise TypeError(f"indices must be integers or slices, not
{type(idx).__name__}") from exc
+
+ if index < -length or index >= length:
+ raise IndexError(f"Index out of range. size: {length}, got index
{index}")
+ if index < 0:
+ index += length
+ return elem_getter(obj, index)
@register_object("ffi.Array")
-class Array(core.Object, collections.abc.Sequence):
+class Array(core.Object, Sequence[T]):
"""Array container that represents a sequence of values in ffi.
:py:func:`tvm_ffi.convert` will map python list/tuple to this class.
Parameters
----------
- input_list : Sequence[Any]
+ input_list : Sequence[T]
The list of values to be stored in the array.
See Also
@@ -100,18 +108,34 @@ class Array(core.Object, collections.abc.Sequence):
"""
- def __init__(self, input_list: Sequence[Any]) -> None:
+ def __init__(self, input_list: Sequence[T]) -> None:
"""Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
- def __getitem__(self, idx: int | slice) -> Any:
- """Return one element or a Python list for a slice."""
- return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)
+ @overload
+ def __getitem__(self, idx: SupportsIndex, /) -> T: ...
+
+ @overload
+ def __getitem__(self, idx: slice, /) -> Array[T]: ...
+
+ def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]:
+ """Return one element or a new :class:`Array` for a slice."""
+ length = len(self)
+ result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx)
+ if isinstance(result, list):
+ return cast(Array[T], type(self)(result))
+ return result
def __len__(self) -> int:
"""Return the number of elements in the array."""
return _ffi_api.ArraySize(self)
+ def __iter__(self) -> Iterator[T]:
+ """Iterate over the elements in the array."""
+ length = len(self)
+ for i in range(length):
+ yield self[i]
+
def __repr__(self) -> str:
"""Return a string representation of the array."""
# exception safety handling for chandle=None
@@ -120,79 +144,87 @@ class Array(core.Object, collections.abc.Sequence):
return "[" + ", ".join([x.__repr__() for x in self]) + "]"
-class KeysView(collections.abc.KeysView):
+class KeysView(KeysViewBase[K]):
"""Helper class to return keys view."""
- def __init__(self, backend_map: Map) -> None:
+ def __init__(self, backend_map: Map[K, V]) -> None:
self._backend_map = backend_map
def __len__(self) -> int:
return len(self._backend_map)
- def __iter__(self) -> Iterator[Any]:
- if self.__len__() == 0:
- return
- functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
- while True:
- k = functor(0)
- yield k
+ def __iter__(self) -> Iterator[K]:
+ size = len(self._backend_map)
+ functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ for _ in range(size):
+ key = cast(K, functor(0))
+ yield key
if not functor(2):
break
- def __contains__(self, k: Any) -> bool:
- return self._backend_map.__contains__(k)
+ def __contains__(self, k: object) -> bool:
+ return k in self._backend_map
-class ValuesView(collections.abc.ValuesView):
+class ValuesView(ValuesViewBase[V]):
"""Helper class to return values view."""
- def __init__(self, backend_map: Map) -> None:
+ def __init__(self, backend_map: Map[K, V]) -> None:
self._backend_map = backend_map
def __len__(self) -> int:
return len(self._backend_map)
- def __iter__(self) -> Iterator[Any]:
- if self.__len__() == 0:
- return
- functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
- while True:
- v = functor(1)
- yield v
+ def __iter__(self) -> Iterator[V]:
+ size = len(self._backend_map)
+ functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ for _ in range(size):
+ value = cast(V, functor(1))
+ yield value
if not functor(2):
break
-class ItemsView(collections.abc.ItemsView):
+class ItemsView(ItemsViewBase[K, V]):
"""Helper class to return items view."""
- def __init__(self, backend_map: Map) -> None:
- self.backend_map = backend_map
+ def __init__(self, backend_map: Map[K, V]) -> None:
+ self._backend_map = backend_map
def __len__(self) -> int:
- return len(self.backend_map)
-
- def __iter__(self) -> Iterator[tuple[Any, Any]]:
- if self.__len__() == 0:
- return
- functor = _ffi_api.MapForwardIterFunctor(self.backend_map)
- while True:
- k = functor(0)
- v = functor(1)
- yield (k, v)
+ return len(self._backend_map)
+
+ def __iter__(self) -> Iterator[tuple[K, V]]:
+ size = len(self._backend_map)
+ functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ for _ in range(size):
+ key = cast(K, functor(0))
+ value = cast(V, functor(1))
+ yield (key, value)
if not functor(2):
break
+ def __contains__(self, item: object) -> bool:
+ if not isinstance(item, tuple) or len(item) != 2:
+ return False
+ key, value = item
+ try:
+ existing_value = self._backend_map[key]
+ except KeyError:
+ return False
+ else:
+ return existing_value == value
+
@register_object("ffi.Map")
-class Map(core.Object, collections.abc.Mapping):
+class Map(core.Object, Mapping[K, V]):
"""Map container.
:py:func:`tvm_ffi.convert` will map python dict to this class.
Parameters
----------
- input_dict : Mapping[Any, Any]
+ input_dict : Mapping[K, V]
The dictionary of values to be stored in the map.
See Also
@@ -213,31 +245,31 @@ class Map(core.Object, collections.abc.Mapping):
"""
- def __init__(self, input_dict: Mapping[Any, Any]) -> None:
+ def __init__(self, input_dict: Mapping[K, V]) -> None:
"""Construct a Map from a Python mapping."""
- list_kvs = []
+ list_kvs: list[Any] = []
for k, v in input_dict.items():
list_kvs.append(k)
list_kvs.append(v)
self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)
- def __getitem__(self, k: Any) -> Any:
+ def __getitem__(self, k: K) -> V:
"""Return the value for key `k` or raise KeyError."""
- return _ffi_api.MapGetItem(self, k)
+ return cast(V, _ffi_api.MapGetItem(self, k))
- def __contains__(self, k: Any) -> bool:
+ def __contains__(self, k: object) -> bool:
"""Return True if the map contains key `k`."""
return _ffi_api.MapCount(self, k) != 0
- def keys(self) -> KeysView:
+ def keys(self) -> KeysView[K]:
"""Return a dynamic view of the map's keys."""
return KeysView(self)
- def values(self) -> ValuesView:
+ def values(self) -> ValuesView[V]:
"""Return a dynamic view of the map's values."""
return ValuesView(self)
- def items(self) -> ItemsView:
+ def items(self) -> ItemsView[K, V]:
"""Get the items from the map."""
return ItemsView(self)
@@ -245,11 +277,17 @@ class Map(core.Object, collections.abc.Mapping):
"""Return the number of items in the map."""
return _ffi_api.MapSize(self)
- def __iter__(self) -> Iterator[Any]:
+ def __iter__(self) -> Iterator[K]:
"""Iterate over the map's keys."""
return iter(self.keys())
- def get(self, key: Any, default: Any | None = None) -> Any:
+ @overload
+ def get(self, key: K) -> V | None: ...
+
+ @overload
+ def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...
+
+ def get(self, key: K, default: V | _DefaultT | None = None) -> V |
_DefaultT | None:
"""Get an element with a default value.
Parameters
@@ -266,7 +304,10 @@ class Map(core.Object, collections.abc.Mapping):
The result value.
"""
- return self[key] if key in self else default
+ try:
+ return self[key]
+ except KeyError:
+ return default
def __repr__(self) -> str:
"""Return a string representation of the map."""
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index f74abc2..6d302bc 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -52,8 +52,8 @@ class TestIntPair(Object):
class TestObjectDerived(TestObjectBase):
"""Test object derived class."""
- v_map: Map
- v_array: Array
+ v_map: Map[Any, Any]
+ v_array: Array[Any]
def create_object(type_key: str, **kwargs: Any) -> Object: