This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 8fcd924 feat: Introduce `tvm.registry.get_registered_type_keys()` (#249)
8fcd924 is described below
commit 8fcd9245186df3d6570e641dfc1c84239a9f9a40
Author: Junru Shao <[email protected]>
AuthorDate: Sun Nov 9 15:03:01 2025 -0800
feat: Introduce `tvm.registry.get_registered_type_keys()` (#249)
This API will be useful for library stub generation, which needs to query
the object types registered on the C++ side. The new function is exposed as
`tvm_ffi.registry.get_registered_type_keys()`.
---
python/tvm_ffi/_ffi_api.py | 1 +
python/tvm_ffi/registry.py | 17 +++++++++++++++--
src/ffi/object.cc | 21 +++++++++++++++++++++
tests/python/test_object.py | 30 +++++++++++++++++++++++++++++-
4 files changed, 66 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 5a957d6..316a6fb 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
def FunctionRemoveGlobal(_0: str, /) -> bool: ...
def GetFirstStructuralMismatch(_0: Any, _1: Any, _2: bool, _3: bool, /) -> tuple[AccessPath, AccessPath] | None: ...
def GetGlobalFuncMetadata(_0: str, /) -> str: ...
+ def GetRegisteredTypeKeys() -> Sequence[str]: ...
def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
def Map(*args: Any) -> Any: ...
def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index ce1d15f..74f56cb 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import json
import sys
-from typing import Any, Callable, Literal, TypeVar, overload
+from typing import Any, Callable, Literal, Sequence, TypeVar, overload
from . import core
from .core import TypeInfo
@@ -268,7 +268,7 @@ def get_global_func_metadata(name: str) -> dict[str, Any]:
Register a Python callable as a global FFI function.
"""
- return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name))
+ return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name) or "{}")
def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
@@ -346,9 +346,22 @@ def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("The __init__ method of this class is not implemented.")
+def get_registered_type_keys() -> Sequence[str]:
+ """Get the list of valid type keys registered to TVM-FFI.
+
+ Returns
+ -------
+ type_keys
+ List of valid type keys.
+
+ """
+ return get_global_func("ffi.GetRegisteredTypeKeys")()
+
+
__all__ = [
"get_global_func",
"get_global_func_metadata",
+ "get_registered_type_keys",
"init_ffi_api",
"list_global_func_names",
"register_global_func",
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 1671ba8..e8a232d 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -21,6 +21,7 @@
* \brief Registry to record dynamic types
*/
#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
@@ -193,6 +194,16 @@ class TypeTable {
return entry;
}
+ Array<String> GetRegisteredTypeKeys() const {
+ Array<String> ret;
+ for (const auto& entry : type_table_) {
+ if (entry) {
+ ret.push_back(entry->type_key_data);
+ }
+ }
+ return ret;
+ }
+
void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {
Entry* entry = GetTypeEntry(type_index);
TVMFFIFieldInfo field_data = *info;
@@ -537,3 +548,13 @@ int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) {
tvm::ffi::TypeTraits<tvm::ffi::Bytes>::MoveToAny(tvm::ffi::Bytes(input->data, input->size), out);
TVM_FFI_SAFE_CALL_END();
}
+
+namespace {
+TVM_FFI_STATIC_INIT_BLOCK() {
+ using namespace tvm::ffi;
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def_method("ffi.GetRegisteredTypeKeys", []() -> Array<String> {
+ return tvm::ffi::TypeTable::Global()->GetRegisteredTypeKeys();
+ });
+}
+} // namespace
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index c39b0be..e49d3ec 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import sys
-from typing import Any
+from typing import Any, Sequence
import pytest
import tvm_ffi
@@ -228,3 +228,31 @@ def test_type_info_attachment(test_cls: type, type_key: str, parent_cls: type |
assert parent_type_info.type_cls is parent_cls, (
f"Expected parent type {parent_cls}, but got {parent_type_info.type_cls}"
)
+
+
+def test_get_registered_type_keys() -> None:
+ keys = tvm_ffi.registry.get_registered_type_keys()
+ assert isinstance(keys, Sequence)
+ assert all(isinstance(k, str) for k in keys)
+ keys = set(keys)
+ assert "ffi.Object" in keys
+ assert "ffi.String" in keys
+ for ty in [
+ "None",
+ "int",
+ "bool",
+ "float",
+ "void*",
+ "DataType",
+ "Device",
+ "DLTensor*",
+ "const char*",
+ "TVMFFIByteArray*",
+ "ObjectRValueRef",
+ ]:
+ assert ty in keys, f"Expected to find `{ty}` in registered type keys, but it was not found."
+ keys.remove(ty)
+ for ty in keys:
+ assert ty.startswith("ffi.") or ty.startswith("testing."), (
+ f"Expected type key `{ty}` to start with `ffi.` or `testing.`"
+ )