This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new dd4fb0a feat: Introduce flexible `TypeSchema.repr(ty_map)` (#94)
dd4fb0a is described below
commit dd4fb0ae92ab69b1cc0280e812a18a58273d4ae6
Author: Junru Shao <[email protected]>
AuthorDate: Wed Oct 8 11:49:32 2025 -0700
feat: Introduce flexible `TypeSchema.repr(ty_map)` (#94)
This PR adds a new API, `TypeSchema.repr(ty_map: ...)`, which allows
downstream applications to override the string representation of specific
types, e.g. rendering `list` as `Sequence` and `dict` as `Mapping`. This is
useful in schema generation. Untyped `list`/`dict` schemas now render as
`list[Any]`/`dict[Any, Any]`.
---
python/tvm_ffi/core.pyi | 514 ++++++++++++++++++++++++++++++++----
python/tvm_ffi/cython/type_info.pxi | 45 ++--
src/ffi/extra/testing.cc | 2 +
tests/cpp/test_metadata.cc | 4 +
tests/python/test_metadata.py | 30 +++
5 files changed, 523 insertions(+), 72 deletions(-)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index e7c38f7..54b44cf 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -33,7 +33,35 @@ _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType |
None], str] | None
__dlpack_version__: tuple[int, int]
class Object:
- """Base class of all TVM FFI objects."""
+ """Base class of all TVM FFI objects.
+
+ This is the root Python type for objects backed by the TVM FFI
+ runtime. Each instance holds a handle to an underlying C++ runtime
+ object. Python subclasses typically correspond to C++ runtime
+ types and are registered via ``tvm_ffi.register_object``.
+
+ Notes
+ -----
+ - Equality of two ``Object`` instances compares the underlying
+ handles unless the concrete type provides its own implementation.
+ Use :py:meth:`same_as` to check explicitly whether two
+ references point to the same underlying object.
+ - Most users interact with subclasses (e.g. :class:`Tensor`,
+ :class:`Function`) rather than ``Object`` directly.
+
+ Examples
+ --------
+ Constructing objects is typically performed by Python wrappers that
+ call into registered constructors on the FFI side.
+
+ .. code-block:: python
+
+ # Create a testing object through its registered FFI constructor
+ obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
+ assert isinstance(obj, tvm_ffi.Object)
+ assert obj.same_as(obj)
+
+ """
def __ctypes_handle__(self) -> Any: ...
def __chandle__(self) -> int: ...
@@ -54,28 +82,128 @@ class Object:
The arguments to the constructor
"""
- def same_as(self, other: Any) -> bool: ...
- def _move(self) -> ObjectRValueRef: ...
- def __move_handle_from__(self, other: Object) -> None: ...
+ def same_as(self, other: Any) -> bool:
+ """Return ``True`` if both references point to the same object.
+
+ This checks identity of the underlying FFI handle rather than
+ performing a structural, value-based comparison.
+
+ Parameters
+ ----------
+ other : Any
+ The object to compare against.
+
+ Returns
+ -------
+ bool
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = tvm_ffi.testing.create_object("testing.TestObjectBase")
+ y = x
+ z = tvm_ffi.testing.create_object("testing.TestObjectBase")
+ assert x.same_as(y)
+ assert not x.same_as(z)
+
+ """
+
+ def _move(self) -> ObjectRValueRef:
+ """Create an rvalue reference that transfers ownership.
+
+ The returned :class:`ObjectRValueRef` indicates move semantics
+ to the FFI layer, and is intended for performance-sensitive
+ paths that wish to avoid an additional retain/release pair.
+
+ Notes
+ -----
+ After a successful move, the original object should be treated
+ as invalid on the FFI side. Do not rely on the handle after
+ transferring.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ use_count = tvm_ffi.get_global_func("testing.object_use_count")
+ x = tvm_ffi.convert([1, 2])
+ assert use_count(x) == 1
+ _ = tvm_ffi.convert(lambda o: o)(x._move())  # the caller moves ``x``
+ assert x.__ctypes_handle__().value is None
+
+ """
+
+ def __move_handle_from__(self, other: Object) -> None:
+ """Steal the FFI handle from ``other``.
+
+ Internal helper used by the runtime to implement move
+ semantics. Users should prefer :py:meth:`_move`.
+ """
class ObjectConvertible:
- """Base class for all classes that can be converted to Object."""
+ """Base class for Python classes convertible to :class:`Object`.
- def asobject(self) -> Object: ...
+ Subclasses implement :py:meth:`asobject` to produce an
+ :class:`Object` instance used by the FFI runtime.
+ """
+
+ def asobject(self) -> Object:
+ """Return an :class:`Object` view of this value.
+
+ This method is used by the conversion helpers (e.g.
+ :func:`tvm_ffi.convert`) when a Python value needs to be passed
+ into FFI calls.
+
+ Returns
+ -------
+ tvm_ffi.core.Object
+
+ """
class ObjectRValueRef:
- """Represent an RValue ref to an object that can be moved."""
+ """Rvalue reference wrapper used to express move semantics.
+
+ Instances are created from :py:meth:`Object._move` and signal to
+ the FFI layer that ownership of the underlying handle can be
+ transferred.
+ """
obj: Object
- def __init__(self, obj: Object) -> None: ...
+ def __init__(self, obj: Object) -> None:
+ """Construct from an existing :class:`Object`.
+
+ Parameters
+ ----------
+ obj : Object
+ The source object from which to move the underlying handle.
+
+ """
class OpaquePyObject(Object):
- """Opaque PyObject container."""
+ """Wrapper that carries an arbitrary Python object across the FFI.
+
+ The contained object is held with correct reference counting, and
+ can be recovered on the Python side using :py:meth:`pyobject`.
- def pyobject(self) -> Any: ...
+ Notes
+ -----
+ ``OpaquePyObject`` is useful when a Python value must traverse the
+ FFI boundary without conversion into a native FFI type.
+
+ """
+
+ def pyobject(self) -> Any:
+ """Return the original Python object held by this wrapper."""
class PyNativeObject:
- """Base class of all TVM objects that also subclass python's builtin
types."""
+ """Base class for TVM objects that also inherit Python builtins.
+
+ This mixin is used by Python-native proxy types such as
+ :class:`String` and :class:`Bytes`, which subclass :class:`str` and
+ :class:`bytes` respectively while also carrying an attached FFI
+ object for zero-copy exchange with the runtime when beneficial.
+ """
__slots__: list[str]
def __init_tvm_ffi_object_by_constructor__(self, fconstructor: Any, *args:
Any) -> None: ...
@@ -87,17 +215,45 @@ def _set_type_cls(type_info: TypeInfo, type_cls: type) ->
None: ...
def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
class Error(Object):
- """Base class for FFI errors."""
+ """Base class for FFI errors.
+
+ An :class:`Error` is the FFI-side record of an exception raised by
+ an FFI call. It stores the error ``kind`` (e.g.
+ ``"ValueError"``), the message, and a serialized FFI backtrace that
+ can be re-attached to produce a Python traceback.
+
+ Users normally handle the Python exception returned by
+ :py:meth:`py_error`; kinds are registered via :func:`tvm_ffi.error.register_error`.
+ """
+
+ def __init__(self, kind: str, message: str, backtrace: str) -> None:
+ """Construct an error wrapper.
+
+ Parameters
+ ----------
+ kind : str
+ Name of the Python exception type (e.g. ``"ValueError"``).
+ message : str
+ The error message from the FFI side.
+ backtrace : str
+ Serialized backtrace encoded by the runtime.
+
+ """
+
+ def update_backtrace(self, backtrace: str) -> None:
+ """Replace the stored backtrace string with ``backtrace``."""
- def __init__(self, kind: str, message: str, backtrace: str) -> None: ...
- def update_backtrace(self, backtrace: str) -> None: ...
- def py_error(self) -> BaseException: ...
+ def py_error(self) -> BaseException:
+ """Return a Python :class:`BaseException` instance for this error."""
@property
- def kind(self) -> str: ...
+ def kind(self) -> str:
+ """The name of the Python exception class (e.g. ``"ValueError"``)."""
@property
- def message(self) -> str: ...
+ def message(self) -> str:
+ """The error message."""
@property
- def backtrace(self) -> str: ...
+ def backtrace(self) -> str:
+ """The serialized FFI backtrace string."""
def _convert_to_ffi_error(error: BaseException) -> Error: ...
def _env_set_current_stream(
@@ -105,20 +261,40 @@ def _env_set_current_stream(
) -> int | c_void_p: ...
class DataType:
- """DataType wrapper around DLDataType."""
+ """Internal wrapper around ``DLDataType``.
+
+ This is a low-level representation used by the FFI layer. It is
+ not intended as a user-facing API. For user code, prefer
+ :class:`tvm_ffi.dtype`, which behaves like a Python ``str`` and
+ integrates with array libraries.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # Prefer the user-facing helper
+ d = tvm_ffi.dtype("int32")
+ assert d.bits == 32
+ assert str(d) == "int32"
+
+ """
def __init__(self, dtype_str: str) -> None: ...
def __reduce__(self) -> Any: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
@property
- def type_code(self) -> int: ...
+ def type_code(self) -> int:
+ """Integer ``DLDataTypeCode`` of the scalar base type."""
@property
- def bits(self) -> int: ...
+ def bits(self) -> int:
+ """Number of bits of the scalar base type."""
@property
- def lanes(self) -> int: ...
+ def lanes(self) -> int:
+ """Number of vector lanes (``1`` for scalar types)."""
@property
- def itemsize(self) -> int: ...
+ def itemsize(self) -> int:
+ """Size of one element in bytes, derived from ``bits`` and ``lanes``."""
def __str__(self) -> str: ...
def _set_class_dtype(cls: type) -> None: ...
@@ -127,7 +303,18 @@ def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) ->
DataType: ...
def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes:
int) -> DataType: ...
class DLDeviceType(IntEnum):
- """Enum that maps to DLDeviceType."""
+ """Enumeration mirroring DLPack's ``DLDeviceType``.
+
+ Values can be compared against :py:meth:`Device.dlpack_device_type`.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ dev = tvm_ffi.device("cuda", 0)
+ assert dev.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA
+
+ """
kDLCPU = 1
kDLCUDA = 2
@@ -146,21 +333,51 @@ class DLDeviceType(IntEnum):
kDLTrn = 17
class Device:
- """Device represents a device in the ffi system."""
+ """A device descriptor used by TVM FFI and DLPack.
+
+ A :class:`Device` identifies a placement (e.g. CPU, CUDA GPU) and
+ a device index within that placement. Most users construct devices
+ using :func:`tvm_ffi.device`.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ dev = tvm_ffi.device("cuda:0")
+ assert dev.type == "cuda"
+ assert dev.index == 0
+ assert str(dev) == "cuda:0"
+
+ """
+
+ def __init__(self, device_type: str | int, index: int | None = None) ->
None:
+ """Construct a device from a type and optional index.
- def __init__(self, device_type: str | int, index: int | None = None) ->
None: ...
+ Parameters
+ ----------
+ device_type : str or int
+ A device type name (e.g. ``"cpu"``, ``"cuda"``) or a
+ DLPack device type code.
+ index : int, optional
+ Zero-based device index (defaults to ``0`` when omitted).
+
+ """
def __reduce__(self) -> Any: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
- def __device_type_name__(self) -> str: ...
+ def __device_type_name__(self) -> str:
+ """Return the canonical device type name (e.g. ``"cuda"``)."""
@property
- def type(self) -> str: ...
+ def type(self) -> str:
+ """Device type name such as ``"cpu"`` or ``"cuda"``."""
@property
- def index(self) -> int: ...
- def dlpack_device_type(self) -> int: ...
+ def index(self) -> int:
+ """Zero-based device index."""
+ def dlpack_device_type(self) -> int:
+ """Return the DLPack device type code (comparable with :class:`DLDeviceType`)."""
def _set_class_device(cls: type) -> None: ...
@@ -169,17 +386,40 @@ _CLASS_DEVICE: type[Device]
def _shape_obj_get_py_tuple(obj: Any) -> tuple[int, ...]: ...
class Tensor(Object):
- """Tensor object that represents a managed n-dimensional array."""
+ """Managed n-dimensional array compatible with DLPack.
+
+ ``Tensor`` provides zero-copy interoperability with array libraries
+ through the DLPack protocol. Instances are typically created with
+ :func:`from_dlpack` or returned from FFI functions.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import numpy as np
+ x = tvm_ffi.from_dlpack(np.arange(6, dtype="int32"))
+ assert x.shape == (6,)
+ assert x.dtype == tvm_ffi.dtype("int32")
+ # Round-trip through NumPy using DLPack
+ np.testing.assert_equal(np.from_dlpack(x), np.arange(6, dtype="int32"))
+
+ """
@property
- def shape(self) -> tuple[int, ...]: ...
+ def shape(self) -> tuple[int, ...]:
+ """Tensor shape as a tuple of integers."""
@property
- def dtype(self) -> Any: ... # returned as python dtype (str subclass)
+ def dtype(self) -> Any:
+ """Data type as :class:`tvm_ffi.dtype` (``str`` subclass)."""
@property
- def device(self) -> Device: ...
- def _to_dlpack(self) -> Any: ...
- def _to_dlpack_versioned(self) -> Any: ...
- def __dlpack_device__(self) -> tuple[int, int]: ...
+ def device(self) -> Device:
+ """The :class:`Device` on which the tensor is placed."""
+ def _to_dlpack(self) -> Any:
+ """Return a DLPack capsule representing this tensor (internal)."""
+ def _to_dlpack_versioned(self) -> Any:
+ """Return a versioned DLPack capsule (internal)."""
+ def __dlpack_device__(self) -> tuple[int, int]:
+ """Return ``(device_type, device_id)`` per the ``__dlpack_device__`` protocol."""
def __dlpack__(
self,
*,
@@ -187,14 +427,59 @@ class Tensor(Object):
max_version: tuple[int, int] | None = None,
dl_device: tuple[int, int] | None = None,
copy: bool | None = None,
- ) -> Any: ...
+ ) -> Any:
+ """Implement the standard ``__dlpack__`` protocol.
+
+ Parameters
+ ----------
+ stream : Any, optional
+ Framework-specific stream/context object.
+ max_version : tuple[int, int], optional
+ Upper bound on the supported DLPack version of the
+ consumer. When ``None``, use the built-in protocol version.
+ dl_device : tuple[int, int], optional
+ Override the device reported by :py:meth:`__dlpack_device__`.
+ copy : bool, optional
+ If ``True``, request a copy instead of a zero-copy export.
+
+ """
_CLASS_TENSOR: type[Tensor] = Tensor
def _set_class_tensor(cls: type[Tensor]) -> None: ...
def from_dlpack(
ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool
= ...
-) -> Tensor: ...
+) -> Tensor:
+ """Import a foreign array that implements the DLPack producer protocol.
+
+ Parameters
+ ----------
+ ext_tensor : Any
+ An object supporting ``__dlpack__`` and ``__dlpack_device__``.
+ require_alignment : int, optional
+ If greater than zero, require the underlying data pointer to be
+ aligned to this many bytes. Misaligned inputs raise
+ :class:`ValueError`.
+ require_contiguous : bool, optional
+ When ``True``, require the layout to be contiguous. Non-contiguous
+ inputs raise :class:`ValueError`.
+
+ Returns
+ -------
+ Tensor
+ A TVM FFI :class:`Tensor` that references the same memory.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import numpy as np
+ x_np = np.arange(8, dtype="int32")
+ x = tvm_ffi.from_dlpack(x_np)
+ y_np = np.from_dlpack(x)
+ assert np.shares_memory(x_np, y_np)
+
+ """
class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose."""
@@ -208,13 +493,39 @@ class DLTensorTestWrapper:
def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...
class Function(Object):
- """Python class that wraps a function with tvm-ffi ABI."""
+ """Callable wrapper around a TVM FFI function.
+
+ Instances are obtained by converting Python callables with
+ :func:`tvm_ffi.convert`, or by looking up globally-registered FFI
+ functions using :func:`tvm_ffi.get_global_func`.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ @tvm_ffi.register_global_func("my.add")
+ def add(a, b):
+ return a + b
+
+ f = tvm_ffi.get_global_func("my.add")
+ assert isinstance(f, tvm_ffi.Function)
+ assert f(1, 2) == 3
+
+ """
@property
- def release_gil(self) -> bool: ...
+ def release_gil(self) -> bool:
+ """Whether calls to this function release the Python GIL while it runs."""
@release_gil.setter
- def release_gil(self, value: bool) -> None: ...
- def __call__(self, *args: Any) -> Any: ...
+ def release_gil(self, value: bool) -> None:
+ """Set whether calls to this function release the Python GIL."""
+ def __call__(self, *args: Any) -> Any:
+ """Invoke the wrapped FFI function with ``args``.
+
+ Arguments are automatically converted between Python values and
+ FFI-compatible forms. The return value is converted back to a
+ Python object.
+ """
def _register_global_func(
name: str, pyfunc: Callable[..., Any] | Function, override: bool
@@ -225,40 +536,119 @@ def _convert_to_opaque_object(pyobject: Any) ->
OpaquePyObject: ...
def _print_debug_info() -> None: ...
class String(str, PyNativeObject):
+ """UTF-8 string that interoperates with FFI while behaving like ``str``.
+
+ ``String`` is a :class:`str` subclass that can travel across the
+ FFI boundary without copying for large payloads. For most Python
+ APIs, using a plain ``str`` works seamlessly; the runtime converts
+ to and from ``String`` as needed.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ s = tvm_ffi.core.String("hello")
+ assert fecho(s) == "hello"
+ assert fecho("world") == "world"
+
+ """
+
__slots__ = ["__tvm_ffi_object__"]
__tvm_ffi_object__: Object | None
- def __new__(cls, value: str) -> String: ...
+ def __new__(cls, value: str) -> String:
+ """Create a new ``String`` from a Python ``str``."""
# pylint: disable=no-self-argument
- def __from_tvm_ffi_object__(cls, obj: Any) -> String: ...
+ def __from_tvm_ffi_object__(cls, obj: Any) -> String:
+ """Construct a ``String`` from an FFI object (internal)."""
class Bytes(bytes, PyNativeObject):
+ """Byte buffer that interoperates with FFI while behaving like ``bytes``.
+
+ Like :class:`String`, this class enables zero-copy exchange for
+ large data. Most Python code can use ``bytes`` directly; the FFI
+ layer constructs :class:`Bytes` as needed.
+ """
+
__slots__ = ["__tvm_ffi_object__"]
__tvm_ffi_object__: Object | None
- def __new__(cls, value: bytes) -> Bytes: ...
+ def __new__(cls, value: bytes) -> Bytes:
+ """Create a new ``Bytes`` from a Python ``bytes`` value."""
# pylint: disable=no-self-argument
- def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...
+ def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes:
+ """Construct ``Bytes`` from an FFI object (internal)."""
# ---------------------------------------------------------------------------
# Type reflection metadata (from cython/type_info.pxi)
# ---------------------------------------------------------------------------
class TypeSchema:
- """Type schema for a TVM FFI type."""
+ """Type schema that describes a TVM FFI type.
+
+ The schema is expressed using a compact JSON-compatible structure
+ and can be rendered as a Python typing string with
+ :py:meth:`repr`.
+ """
origin: str
args: tuple[TypeSchema, ...] = ()
@staticmethod
- def from_json_obj(obj: dict[str, Any]) -> TypeSchema: ...
+ def from_json_obj(obj: dict[str, Any]) -> TypeSchema:
+ """Construct a :class:`TypeSchema` from a parsed JSON object."""
@staticmethod
- def from_json_str(s: str) -> TypeSchema: ...
+ def from_json_str(s: str) -> TypeSchema:
+ """Construct a :class:`TypeSchema` from a JSON string."""
+ def repr(self, ty_map: Callable[[str], str] | None = None) -> str:
+ """Render a human-readable representation of this schema.
+
+ Parameters
+ ----------
+ ty_map : Callable[[str], str], optional
+ A function applied to the origin name of this schema and,
+ recursively, of every nested argument schema before rendering, e.g.
+ ``"list" -> "Sequence"``, ``"dict" -> "Mapping"``. ``None`` keeps origins.
+
+ Returns
+ -------
+ str
+ A readable string using Python typing syntax: unions render as
+ ``"T1 | T2"``, optionals as ``"T | None"``, callables as
+ ``"Callable[[arg1, ...], ret]"`` (``"Callable[..., Any]"`` when
+ the schema has no arguments), and other schemas as ``"origin"``
+ or ``"origin[arg1, ...]"``.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # From JSON emitted by the runtime
+ s =
TypeSchema.from_json_str('{"type":"Optional","args":[{"type":"int"}]}')
+ assert s.repr() == "int | None"
+
+ # Callable where the first arg is return type, remaining are
parameters
+ s = TypeSchema("Callable", (TypeSchema("int"), TypeSchema("str")))
+ assert s.repr() == "Callable[[str], int]"
+
+ # Custom mapping to stdlib typing collections
+ def _map(t: str) -> str:
+ return {"list": "Sequence", "dict": "Mapping"}.get(t, t)
+
+ s =
TypeSchema.from_json_str('{"type":"dict","args":[{"type":"str"},{"type":"int"}]}')
+ assert s.repr(_map) == "Mapping[str, int]"
+
+ """
class TypeField:
- """Description of a single reflected field on an FFI-backed type."""
+ """Description of a single reflected field on an FFI-backed type.
+
+ Instances are used to synthesize Python properties on generated
+ proxy classes.
+ """
@@ -270,10 +660,15 @@ class TypeField:
setter: Any
dataclass_field: Any | None
- def as_property(self, cls: type) -> property: ...
+ def as_property(self, cls: type) -> property:
+ """Build a :class:`property` that exposes this field on class ``cls``."""
class TypeMethod:
- """Description of a single reflected method on an FFI-backed type."""
+ """Description of a single reflected method on an FFI-backed type.
+
+ Instances are used to synthesize bound callables on generated proxy
+ classes.
+ """
name: str
doc: str | None
@@ -281,10 +676,16 @@ class TypeMethod:
is_static: bool
metadata: dict[str, Any]
- def as_callable(self, cls: type) -> Callable[..., Any]: ...
+ def as_callable(self, cls: type) -> Callable[..., Any]:
+ """Build a Python callable for this method, to be attached to ``cls``."""
class TypeInfo:
- """Aggregated type information required to build a proxy class."""
+ """Aggregated type information required to build a proxy class.
+
+ This structure contains the reflected fields and methods for an FFI
+ type, along with hierarchy information used during Python class
+ synthesis.
+ """
type_cls: type | None
type_index: int
@@ -293,4 +694,5 @@ class TypeInfo:
methods: list[TypeMethod]
parent_type_info: TypeInfo | None
- def prototype_py(self) -> str: ...
+ def prototype_py(self) -> str:
+ """Render a Python prototype string for debugging and testing."""
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index f48299c..ae27217 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -101,27 +101,19 @@ class TypeSchema:
             assert len(args) == 1, "Optional must have exactly one argument"
         elif origin == "list":
             assert len(args) in (0, 1), "list must have 0 or 1 argument"
+            if not args:
+                # An untyped list defaults to ``list[Any]``. The dataclass stays
+                # frozen (hashable), so install the default via object.__setattr__.
+                object.__setattr__(self, "args", (TypeSchema("Any"),))
         elif origin == "dict":
             assert len(args) in (0, 2), "dict must have 0 or 2 arguments"
+            if not args:
+                object.__setattr__(self, "args", (TypeSchema("Any"), TypeSchema("Any")))
         elif origin == "tuple":
             pass  # tuple can have arbitrary number of arguments

     def __repr__(self) -> str:
-        if self.origin == "Union":
-            return " | ".join(repr(a) for a in self.args)
-        elif self.origin == "Optional":
-            return repr(self.args[0]) + " | None"
-        elif self.origin == "Callable":
-            if not self.args:
-                return "Callable[..., Any]"
-            else:
-                arg_ret = self.args[0]
-                arg_args = self.args[1:]
-                return f"Callable[[{', '.join(repr(a) for a in arg_args)}], {repr(arg_ret)}]"
-        elif not self.args:
-            return self.origin
-        else:
-            return f"{self.origin}[{', '.join(repr(a) for a in self.args)}]"
+        return self.repr(ty_map=None)

     @staticmethod
     def from_json_obj(obj: dict[str, Any]) -> "TypeSchema":
@@ -136,6 +128,38 @@ class TypeSchema:
     def from_json_str(s) -> "TypeSchema":
         return TypeSchema.from_json_obj(json.loads(s))

+    def repr(self, ty_map=None) -> str:
+        """Render this schema as a Python typing string.
+
+        Parameters
+        ----------
+        ty_map : Callable[[str], str], optional
+            Applied to the origin name of this schema and of every nested
+            argument schema before rendering. ``None`` keeps names unchanged.
+
+        Returns
+        -------
+        str
+            The rendered type, e.g. ``"int | None"`` or ``"dict[str, int]"``.
+        """
+        name = self.origin if ty_map is None else ty_map(self.origin)
+        rendered = [arg.repr(ty_map) for arg in self.args]
+        # Pick the layout from the schema's own origin, not the mapped name, so
+        # renaming e.g. "Callable" through ``ty_map`` keeps the callable layout.
+        if self.origin == "Union":
+            return " | ".join(rendered)
+        if self.origin == "Optional":
+            return rendered[0] + " | None"
+        if self.origin == "Callable":
+            if not rendered:
+                return f"{name}[..., Any]"
+            # The first argument is the return type; the rest are parameters.
+            ret, params = rendered[0], rendered[1:]
+            return f"{name}[[{', '.join(params)}], {ret}]"
+        if not rendered:
+            return name
+        return f"{name}[{', '.join(rendered)}]"
+


 @dataclasses.dataclass(eq=False)
 class TypeField:
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 0e30906..1898bda 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -389,9 +389,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("testing.schema_id_arr_int", [](Array<int64_t> arr) { return arr; })
.def("testing.schema_id_arr_str", [](Array<String> arr) { return arr; })
.def("testing.schema_id_arr_obj", [](Array<ObjectRef> arr) { return arr;
})
+ .def("testing.schema_id_arr", [](const ArrayObj* arr) { return arr; })  // untyped Array: schema has no element args
.def("testing.schema_id_map_str_int", [](Map<String, int64_t> m) {
return m; })
.def("testing.schema_id_map_str_str", [](Map<String, String> m) { return
m; })
.def("testing.schema_id_map_str_obj", [](Map<String, ObjectRef> m) {
return m; })
+ .def("testing.schema_id_map", [](const MapObj* m) { return m; })  // untyped Map: schema has no key/value args
.def("testing.schema_id_variant_int_str", [](Variant<int64_t, String> v)
{ return v; })
.def_packed("testing.schema_packed", [](PackedArgs args, Any* ret) {})
.def("testing.schema_arr_map_opt",
diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc
index b5fedb8..982893f 100644
--- a/tests/cpp/test_metadata.cc
+++ b/tests/cpp/test_metadata.cc
@@ -102,6 +102,8 @@ TEST(Schema, GlobalFuncTypeSchema) {
EXPECT_EQ(
fetch("testing.schema_id_arr_obj"),
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"ffi.Object"}]},{"type":"ffi.Array","args":[{"type":"ffi.Object"}]}]})");
+ EXPECT_EQ(fetch("testing.schema_id_arr"),  // untyped ArrayObj*: no "args" entry
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array"},{"type":"ffi.Array"}]})");
EXPECT_EQ(
fetch("testing.schema_id_map_str_int"),
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]}]})");
@@ -111,6 +113,8 @@ TEST(Schema, GlobalFuncTypeSchema) {
EXPECT_EQ(
fetch("testing.schema_id_map_str_obj"),
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Object"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Object"}]}]})");
+ EXPECT_EQ(fetch("testing.schema_id_map"),  // untyped MapObj*: no "args" entry
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map"},{"type":"ffi.Map"}]})");
EXPECT_EQ(
fetch("testing.schema_id_variant_int_str"),
R"({"type":"ffi.Function","args":[{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"}]},{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"}]}]})");
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
index fe19822..22e2ce5 100644
--- a/tests/python/test_metadata.py
+++ b/tests/python/test_metadata.py
@@ -22,6 +22,21 @@ from tvm_ffi.core import TypeInfo, TypeSchema
 from tvm_ffi.testing import _SchemaAllTypes


+def _replace_list_dict(ty: str) -> str:
+    """Sample ``ty_map``: render ``list`` as ``Sequence`` and ``dict`` as ``Mapping``."""
+    return {
+        "list": "Sequence",
+        "dict": "Mapping",
+    }.get(ty, ty)
+
+
+def _map_expected(expected: str) -> str:
+    """Apply ``_replace_list_dict`` to each whole-word type name in ``expected``."""
+    import re
+
+    return re.sub(r"\b(?:list|dict)\b", lambda m: _replace_list_dict(m.group(0)), expected)
+
+
 @pytest.mark.parametrize(
     "func_name,expected",
     [
@@ -48,9 +63,11 @@ from tvm_ffi.testing import _SchemaAllTypes
         ("testing.schema_id_arr_int", "Callable[[list[int]], list[int]]"),
         ("testing.schema_id_arr_str", "Callable[[list[str]], list[str]]"),
         ("testing.schema_id_arr_obj", "Callable[[list[Object]], list[Object]]"),
+        ("testing.schema_id_arr", "Callable[[list[Any]], list[Any]]"),
         ("testing.schema_id_map_str_int", "Callable[[dict[str, int]], dict[str, int]]"),
         ("testing.schema_id_map_str_str", "Callable[[dict[str, str]], dict[str, str]]"),
         ("testing.schema_id_map_str_obj", "Callable[[dict[str, Object]], dict[str, Object]]"),
+        ("testing.schema_id_map", "Callable[[dict[Any, Any]], dict[Any, Any]]"),
         ("testing.schema_id_variant_int_str", "Callable[[int | str], int | str]"),
         ("testing.schema_packed", "Callable[..., Any]"),
         (
@@ -67,6 +84,8 @@ def test_schema_global_func(func_name: str, expected: str) -> None:
     metadata: dict[str, Any] = get_global_func_metadata(func_name)
     actual: TypeSchema = TypeSchema.from_json_str(metadata["type_schema"])
     assert str(actual) == expected, f"{func_name}: {actual}"
+    mapped = actual.repr(_replace_list_dict)
+    assert mapped == _map_expected(expected), f"{func_name}: {mapped}"


 @pytest.mark.parametrize(
@@ -95,6 +114,8 @@ def test_schema_field(field_name: str, expected: str) -> None:
         if field.name == field_name:
             actual: TypeSchema = TypeSchema.from_json_str(field.metadata["type_schema"])
             assert str(actual) == expected, f"{field_name}: {actual}"
+            mapped = actual.repr(_replace_list_dict)
+            assert mapped == _map_expected(expected), f"{field_name}: {mapped}"
             break
     else:
         raise ValueError(f"Field not found: {field_name}")
@@ -119,6 +140,8 @@ def test_schema_member_method(method_name: str, expected: str) -> None:
         if method.name == method_name:
             actual: TypeSchema = TypeSchema.from_json_str(method.metadata["type_schema"])
             assert str(actual) == expected, f"{method_name}: {actual}"
+            mapped = actual.repr(_replace_list_dict)
+            assert mapped == _map_expected(expected), f"{method_name}: {mapped}"
             break
     else:
         raise ValueError(f"Method not found: {method_name}")