This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 28fe3cc feat: Metadata and JSON-based type schemas for fields,
methods and global funcs (#36)
28fe3cc is described below
commit 28fe3cc7fc23b725ddefbb3ca1e1899a7dfd530c
Author: Junru Shao <[email protected]>
AuthorDate: Fri Oct 3 16:07:06 2025 -0700
feat: Metadata and JSON-based type schemas for fields, methods and global
funcs (#36)
This PR adds a metadata mechanism to TVM-FFI's registry, which makes it
possible to attach metadata to global functions, object fields, and
object member functions, and to expose that metadata. As part of the
metadata, this PR generates a JSON-based type schema that is exposed
through both the C++ and Python APIs.
This enhancement provides richer, machine-readable type information,
which can be invaluable for tooling, static analysis, and dynamic
dispatch mechanisms built on top of the TVM FFI.
**Key Changes:**
[C++] JSON-based schema generation `tvm::ffi::details::TypeSchema<T>`,
which supports:
- Fundamental types, e.g., `int`, `float`, `bool`, `String`, `Bytes`,
`DLDataType`, `DLDevice`
- Complex container types, e.g., `Array<T>`, `Map<K, V>`, `Optional<T>`,
`Tuple<Args...>`, `Variant<Args...>`
- Typed and type-erased functions, e.g. `Callable`, `Callable[[Args...],
Ret]`.
[C++] Trait class `tvm::ffi::reflection::Metadata` is introduced,
allowing arbitrary key-value pairs to be attached to fields and methods
during registration. Example:
```C++
TVM_FFI_STATIC_INIT_BLOCK() {
GlobalDef()
.def("my_add", my_add_func,
Metadata{{"description", "Adds two integers"}, {"version", 1}});
}
```
[Python] `tvm_ffi.core.TypeField` and `tvm_ffi.core.TypeMethod` now
include a `metadata: dict[str, Any]` dictionary, making the attached
JSON metadata accessible from Python.
[Python] A new `tvm_ffi.core.TypeSchema` class is added, capable of
parsing the JSON schema strings into a structured Python object,
providing a more convenient way to inspect type information.
```python
schema_str: str = "{\"type\":\"ffi.Function\",\"args\":[{\"type\":\"int\"},{\"type\":\"int\"}]}"
schema_obj: TypeSchema = TypeSchema.from_json_str(schema_str)
print(schema_obj)
# Callable[[int], int]
```
[Python] A new Python function `tvm_ffi.get_global_func_metadata(name:
str)` is exposed to retrieve the full metadata (including `type_schema`)
for any registered global function.
```python
# (Assuming 'testing.schema_id_int' is registered in C++ as in
# src/ffi/extra/testing.cc)
metadata: dict[str, Any] = get_global_func_metadata("testing.schema_id_int")
# metadata will contain:
# {
# "type_schema": "{\"type\":\"ffi.Function\",\"args\":[{\"type\":\"int\"},{\"type\":\"int\"}]}",
# "bool_attr": True,
# "int_attr": 1,
# "str_attr": "hello"
# }
```
**Internal Utilities:**
- A new `tvm::ffi::EscapeString` utility is added for proper string
escaping.
- `FieldInfoBuilder` and `MethodInfoBuilder` are added to temporarily
hold this metadata before it's serialized into the `TVMFFIFieldInfo` and
`TVMFFIMethodInfo` C structs.
---
include/tvm/ffi/base_details.h | 16 +++
include/tvm/ffi/container/array.h | 7 ++
include/tvm/ffi/container/map.h | 8 ++
include/tvm/ffi/container/tensor.h | 3 +
include/tvm/ffi/container/tuple.h | 11 +-
include/tvm/ffi/container/variant.h | 8 ++
include/tvm/ffi/dtype.h | 3 +
include/tvm/ffi/function.h | 6 ++
include/tvm/ffi/function_details.h | 43 +++++++-
include/tvm/ffi/reflection/registry.h | 143 +++++++++++++++++++++----
include/tvm/ffi/rvalue_ref.h | 8 ++
include/tvm/ffi/string.h | 72 ++++++++++++-
include/tvm/ffi/type_traits.h | 34 ++++++
python/tvm_ffi/__init__.py | 2 +
python/tvm_ffi/core.pyi | 13 +++
python/tvm_ffi/cython/object.pxi | 8 +-
python/tvm_ffi/cython/type_info.pxi | 82 ++++++++++++++
python/tvm_ffi/registry.py | 20 ++++
python/tvm_ffi/testing.py | 4 +
src/ffi/extra/json_writer.cc | 49 +--------
src/ffi/extra/testing.cc | 174 ++++++++++++++++++++++++++++++
src/ffi/function.cc | 9 +-
tests/cpp/test_metadata.cc | 195 ++++++++++++++++++++++++++++++++++
tests/python/test_metadata.py | 157 +++++++++++++++++++++++++++
24 files changed, 999 insertions(+), 76 deletions(-)
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index c20f0e5..20e60eb 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -290,6 +290,22 @@ TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const
TVMFFIAny* data) {
return StableHashBytes(reinterpret_cast<const void*>(data),
sizeof(data->v_uint64));
}
+/*!
+ * \brief Helper to generate a JSON-based type schema for a given type.
+ * \tparam T The type to generate the schema for. `T` is assumed to be neither
+ * const-qualified nor reference-qualified.
+ * \note Only declared here; definitions expose a static `v()` that returns the schema string.
+ */
+template <typename T>
+struct TypeSchemaImpl;
+/*!
+ * \brief Entry point for generating a JSON-based type schema for a given type.
+ * \tparam T The type to generate the schema for.
+ * \note This alias removes reference and then const qualifiers from `T` before
+ * passing it to `TypeSchemaImpl`, so `TypeSchema<const X&>` names `TypeSchemaImpl<X>`.
+ */
+template <typename T>
+using TypeSchema =
TypeSchemaImpl<std::remove_const_t<std::remove_reference_t<T>>>;
+
} // namespace details
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/container/array.h
b/include/tvm/ffi/container/array.h
index db025c0..3e8184a 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -1135,6 +1135,13 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
}
TVM_FFI_INLINE static std::string TypeStr() { return "Array<" +
details::Type2Str<T>::v() + ">"; }
+  /*!
+   * \brief JSON type schema of `Array<T>`.
+   * \return An object whose "type" is the array static type key and whose "args"
+   *         holds the schema of the element type `T`.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema += StaticTypeKey::kTVMFFIArray;
+    schema += "\",\"args\":[";
+    schema += details::TypeSchema<T>::v();
+    schema += "]}";
+    return schema;
+  }
};
namespace details {
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index 08f9a47..f366b12 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -1755,6 +1755,14 @@ struct TypeTraits<Map<K, V>> : public
ObjectRefTypeTraitsBase<Map<K, V>> {
TVM_FFI_INLINE static std::string TypeStr() {
return "Map<" + details::Type2Str<K>::v() + ", " +
details::Type2Str<V>::v() + ">";
}
+  /*!
+   * \brief JSON type schema of `Map<K, V>`.
+   * \return An object whose "type" is the map static type key and whose "args"
+   *         holds the key schema followed by the value schema.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema += StaticTypeKey::kTVMFFIMap;
+    schema += "\",\"args\":[";
+    schema += details::TypeSchema<K>::v();
+    schema += ",";
+    schema += details::TypeSchema<V>::v();
+    schema += "]}";
+    return schema;
+  }
};
namespace details {
diff --git a/include/tvm/ffi/container/tensor.h
b/include/tvm/ffi/container/tensor.h
index f451aaf..c1d7e4d 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -615,6 +615,9 @@ struct TypeTraits<TensorView> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIDLTensorPtr; }
+  /*! \brief JSON type schema: an object whose "type" is the DLTensor pointer static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIDLTensorPtr).append("\"}");
+    return schema;
+  }
};
} // namespace ffi
diff --git a/include/tvm/ffi/container/tuple.h
b/include/tvm/ffi/container/tuple.h
index 7534240..4bbeb0a 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -267,8 +267,7 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
return true;
}
- TVM_FFI_INLINE static std::optional<Tuple<Types...>>
TryCastFromAnyView(const TVMFFIAny* src //
- ) {
+ TVM_FFI_INLINE static std::optional<Tuple<Types...>>
TryCastFromAnyView(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
if (n->size() != sizeof...(Types)) return std::nullopt;
@@ -305,6 +304,14 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
TVM_FFI_INLINE static std::string TypeStr() {
return details::ContainerTypeStr<Types...>("Tuple");
}
+  /*!
+   * \brief JSON type schema of `Tuple<Types...>`.
+   * \return An object of type "Tuple" whose "args" lists the schema of every
+   *         element type, comma-separated and in declaration order.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"Tuple\",\"args\":[";
+    bool needs_comma = false;
+    auto append_arg = [&schema, &needs_comma](const std::string& arg) {
+      if (needs_comma) schema += ",";
+      schema += arg;
+      needs_comma = true;
+    };
+    // Comma fold: evaluated left to right, one call per element type.
+    (append_arg(details::TypeSchema<Types>::v()), ...);
+    schema += "]}";
+    return schema;
+  }
};
namespace details {
diff --git a/include/tvm/ffi/container/variant.h
b/include/tvm/ffi/container/variant.h
index cae5a67..016ec19 100644
--- a/include/tvm/ffi/container/variant.h
+++ b/include/tvm/ffi/container/variant.h
@@ -280,6 +280,14 @@ struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
details::ContainerTypeStr<V...>("Variant"); }
+  /*!
+   * \brief JSON type schema of `Variant<V...>`.
+   * \return An object of type "Variant" whose "args" lists the schema of every
+   *         alternative, comma-separated and in declaration order.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"Variant\",\"args\":[";
+    bool needs_comma = false;
+    auto append_arg = [&schema, &needs_comma](const std::string& arg) {
+      if (needs_comma) schema += ",";
+      schema += arg;
+      needs_comma = true;
+    };
+    // Comma fold: evaluated left to right, one call per alternative.
+    (append_arg(details::TypeSchema<V>::v()), ...);
+    schema += "]}";
+    return schema;
+  }
};
template <typename... V>
diff --git a/include/tvm/ffi/dtype.h b/include/tvm/ffi/dtype.h
index c0f9f68..84727bb 100644
--- a/include/tvm/ffi/dtype.h
+++ b/include/tvm/ffi/dtype.h
@@ -178,6 +178,9 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
ffi::StaticTypeKey::kTVMFFIDataType; }
+  /*! \brief JSON type schema: an object whose "type" is the data type static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(ffi::StaticTypeKey::kTVMFFIDataType).append("\"}");
+    return schema;
+  }
};
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index ba3ec0d..357fca4 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -740,6 +740,11 @@ class TypedFunction<R(Args...)> {
bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
+  /*!
+   * \brief Get the type schema of `TypedFunction<R(Args...)>` in JSON format.
+   * \return The schema string produced by `details::FuncFunctorImpl<R, Args...>`.
+   */
+  static std::string TypeSchema() {
+    using SignatureInfo = details::FuncFunctorImpl<R, Args...>;
+    return SignatureInfo::TypeSchema();
+  }
private:
/*! \brief The internal packed function */
@@ -780,6 +785,7 @@ struct TypeTraits<TypedFunction<FType>> : public
TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
details::FunctionInfo<FType>::Sig(); }
+  /*! \brief JSON type schema; forwards to `TypedFunction<FType>::TypeSchema()`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    using Typed = TypedFunction<FType>;
+    return Typed::TypeSchema();
+  }
};
/*!
diff --git a/include/tvm/ffi/function_details.h
b/include/tvm/ffi/function_details.h
index 20ca44c..43c0fa4 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -34,6 +34,8 @@
namespace tvm {
namespace ffi {
+// forward declaration
+class Function;
namespace details {
template <typename ArgType>
@@ -76,7 +78,6 @@ struct FuncFunctorImpl {
/*! \brief Whether this function can be converted to ffi::Function via
FromTyped */
static constexpr bool unpacked_supported = (ArgSupported<Args> && ...) &&
(RetSupported<R>);
#endif
-
TVM_FFI_INLINE static std::string Sig() {
using IdxSeq = std::make_index_sequence<sizeof...(Args)>;
std::ostringstream ss;
@@ -85,6 +86,14 @@ struct FuncFunctorImpl {
ss << ") -> " << Type2Str<R>::v();
return ss.str();
}
+  /*!
+   * \brief JSON type schema of a function with signature `R(Args...)`.
+   * \return An object whose "type" is the function static type key and whose
+   *         "args" lists the return type schema first, followed by the schema
+   *         of each argument type in order.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema += StaticTypeKey::kTVMFFIFunction;
+    schema += "\",\"args\":[";
+    schema += details::TypeSchema<R>::v();
+    // The return type always comes first, so every argument is comma-prefixed.
+    (schema.append(",").append(details::TypeSchema<Args>::v()), ...);
+    schema += "]}";
+    return schema;
+  }
};
template <typename T>
@@ -102,11 +111,15 @@ struct FunctionInfoHelper<R (T::*)(Args...) const> :
FuncFunctorImpl<R, Args...>
*/
template <typename T>
struct FunctionInfo : FunctionInfoHelper<decltype(&T::operator())> {};
-
template <typename R, typename... Args>
struct FunctionInfo<R(Args...)> : FuncFunctorImpl<R, Args...> {};
template <typename R, typename... Args>
struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+// Support pointer-to-member functions used in reflection (e.g. &Class::method).
+// Both the non-const and the const-qualified forms map to FuncFunctorImpl<R, Args...>;
+// the class type itself does not take part in the resulting signature or schema.
+template <typename Class, typename R, typename... Args>
+struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+template <typename Class, typename R, typename... Args>
+struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, Args...>
{};
/*! \brief Using static function to output typed function signature */
typedef std::string (*FGetFuncSignature)();
@@ -204,6 +217,32 @@ TVM_FFI_INLINE static Error MoveFromSafeCallRaised() {
TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) {
TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error));
}
+
+/*!
+ * \brief Default schema generator: defers to `TypeTraits<U>::TypeSchema()`,
+ *        where `U` is `T` with reference and const qualifiers removed.
+ */
+template <typename T>
+struct TypeSchemaImpl {
+  static std::string v() {
+    return TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>::TypeSchema();
+  }
+};
+
+/*! \brief `void` is reported with the None static type key. */
+template <>
+struct TypeSchemaImpl<void> {
+  static std::string v() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFINone).append("\"}");
+    return schema;
+  }
+};
+
+/*! \brief `Any` is reported with the Any static type key. */
+template <>
+struct TypeSchemaImpl<Any> {
+  static std::string v() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIAny).append("\"}");
+    return schema;
+  }
+};
+
+/*! \brief `AnyView` is reported with the same Any static type key as `Any`. */
+template <>
+struct TypeSchemaImpl<AnyView> {
+  static std::string v() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIAny).append("\"}");
+    return schema;
+  }
+};
+
} // namespace details
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 30f3a11..87a2f05 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -25,27 +25,117 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>
+#include <iterator>
+#include <optional>
+#include <sstream>
#include <string>
#include <utility>
+#include <vector>
namespace tvm {
namespace ffi {
/*! \brief Reflection namespace */
namespace reflection {
+/*!
+ * \brief Type of the temporary metadata held in FieldInfoBuilder and MethodInfoBuilder
+ * before it is filled into the final C metadata: an ordered list of key-value pairs.
+ * \note NOTE(review): identifiers starting with an underscore followed by an uppercase
+ * letter are reserved for the implementation in C++; consider renaming in a follow-up.
+ */
+using _MetadataType = std::vector<std::pair<String, Any>>;
+/*!
+ * \brief Builder for TVMFFIFieldInfo: the C struct plus the not-yet-serialized metadata.
+ * \sa TVMFFIFieldInfo
+ */
+struct FieldInfoBuilder : public TVMFFIFieldInfo {
+ /*! \brief Temporary metadata info to be filled into TVMFFIFieldInfo::metadata */
+ _MetadataType metadata_;
+};
+/*!
+ * \brief Builder for TVMFFIMethodInfo: the C struct plus the not-yet-serialized metadata.
+ * \sa TVMFFIMethodInfo
+ */
+struct MethodInfoBuilder : public TVMFFIMethodInfo {
+ /*! \brief Temporary metadata info to be filled into TVMFFIMethodInfo::metadata */
+ _MetadataType metadata_;
+};
/*!
- * \brief Trait that can be used to set field info
+ * \brief Trait that can be used to set information attached to a field or a
method.
* \sa DefaultValue, AttachFieldFlag
*/
-struct FieldInfoTrait {};
+struct InfoTrait {};
+
+/*!
+ * \brief User-supplied metadata attached to a field or a method.
+ *
+ * Holds an ordered list of key-value pairs. `Apply` appends a copy of the pairs
+ * to a builder's temporary `metadata_`; `ToJSON` later serializes such a list
+ * into a JSON object string.
+ */
+class Metadata : public InfoTrait {
+ public:
+  /*!
+   * \brief Constructor
+   * \param dict The initial dictionary
+   */
+  explicit Metadata(std::initializer_list<std::pair<String, Any>> dict) : dict_(dict) {}
+  /*!
+   * \brief Append a copy of the metadata to a `FieldInfoBuilder`
+   * \param info The field info builder.
+   */
+  inline void Apply(FieldInfoBuilder* info) const { this->Apply(&info->metadata_); }
+  /*!
+   * \brief Append a copy of the metadata to a `MethodInfoBuilder`
+   * \param info The method info builder.
+   */
+  inline void Apply(MethodInfoBuilder* info) const { this->Apply(&info->metadata_); }
+
+ private:
+  friend class GlobalDef;
+  template <typename T>
+  friend class ObjectDef;
+  /*!
+   * \brief Append a copy of the metadata to a vector of key-value pairs.
+   * \param out The output vector.
+   * \note This method is const, so the entries cannot be moved out of `dict_`;
+   *       they are copied.
+   */
+  inline void Apply(_MetadataType* out) const {
+    out->insert(out->end(), dict_.begin(), dict_.end());
+  }
+  /*!
+   * \brief Convert the metadata to a JSON object string.
+   * \param metadata The key-value pairs to serialize, emitted in order.
+   * \return The JSON text, e.g. `{"key":1,"flag":true,"name":"x"}`.
+   * \throws TypeError if a value is not an int, a bool or a string.
+   */
+  static std::string ToJSON(const _MetadataType& metadata) {
+    std::ostringstream os;
+    os << "{";
+    bool first = true;
+    for (const auto& [key, value] : metadata) {
+      if (!first) {
+        os << ",";
+      }
+      first = false;
+      // EscapeString returns the quoted and escaped form, so keys containing
+      // quotes, backslashes or control characters still yield valid JSON.
+      os << EscapeString(key).c_str() << ":";
+      // Read integers as 64-bit so values outside the `int` range are not narrowed.
+      if (std::optional<int64_t> int_value = value.as<int64_t>()) {
+        os << *int_value;
+      } else if (std::optional<bool> bool_value = value.as<bool>()) {
+        os << (*bool_value ? "true" : "false");
+      } else if (std::optional<String> str_value = value.as<String>()) {
+        os << EscapeString(*str_value).c_str();
+      } else {
+        TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `"
+                                         << key << "`, the type is " << value.GetTypeKey();
+      }
+    }
+    os << "}";
+    return os.str();
+  }
+
+  /*! \brief The key-value pairs supplied at construction, in insertion order. */
+  std::vector<std::pair<String, Any>> dict_;
+};
/*!
* \brief Trait that can be used to set field default value
*/
-class DefaultValue : public FieldInfoTrait {
+class DefaultValue : public InfoTrait {
public:
/*!
* \brief Constructor
@@ -69,7 +159,7 @@ class DefaultValue : public FieldInfoTrait {
/*!
* \brief Trait that can be used to attach field flag
*/
-class AttachFieldFlag : public FieldInfoTrait {
+class AttachFieldFlag : public InfoTrait {
public:
/*!
* \brief Attach a field flag to the field
@@ -154,8 +244,8 @@ class ReflectionDefBase {
}
template <typename T>
- TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const
T& value) {
- if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
+ TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder* info, const
T& value) {
+ if constexpr (std::is_base_of_v<InfoTrait, std::decay_t<T>>) {
value.Apply(info);
}
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
@@ -164,7 +254,10 @@ class ReflectionDefBase {
}
template <typename T>
- TVM_FFI_INLINE static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info,
const T& value) {
+ TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder* info,
const T& value) {
+ if constexpr (std::is_base_of_v<InfoTrait, std::decay_t<T>>) {
+ value.Apply(info);
+ }
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value,
std::char_traits<char>::length(value)};
}
@@ -244,14 +337,15 @@ class GlobalDef : public ReflectionDefBase {
*
* \param name The name of the function.
* \param func The function to be registered.
- * \param extra The extra arguments that can be docstring or subclass of
FieldInfoTrait.
+ * \param extra The extra arguments that can be docstring or subclass of
InfoTrait.
*
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
+ using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func),
std::string(name)),
- std::forward<Extra>(extra)...);
+ FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
}
@@ -263,13 +357,14 @@ class GlobalDef : public ReflectionDefBase {
*
* \param name The name of the function.
* \param func The function to be registered.
- * \param extra The extra arguments that can be docstring or subclass of
FieldInfoTrait.
+ * \param extra The extra arguments that can be docstring or subclass of
InfoTrait.
*
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
- RegisterFunc(name, ffi::Function::FromPacked(func),
std::forward<Extra>(extra)...);
+ RegisterFunc(name, ffi::Function::FromPacked(func),
details::TypeSchemaImpl<Function>::v(),
+ std::forward<Extra>(extra)...);
return *this;
}
@@ -289,23 +384,24 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
+ using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
- std::forward<Extra>(extra)...);
+ FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
}
private:
template <typename... Extra>
- void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
- TVMFFIMethodInfo info;
+ void RegisterFunc(const char* name, ffi::Function func, String type_schema,
Extra&&... extra) {
+ MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
- info.metadata = TVMFFIByteArray{nullptr, 0};
info.flags = 0;
- // obtain the method function
info.method = AnyView(func).CopyToTVMFFIAny();
- // apply method info traits
+ info.metadata_.emplace_back("type_schema", type_schema);
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+ std::string metadata_str = Metadata::ToJSON(info.metadata_);
+ info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
}
};
@@ -430,7 +526,7 @@ class ObjectDef : public ReflectionDefBase {
void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
ExtraArgs&&... extra_args) {
static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a
base class of Class");
- TVMFFIFieldInfo info;
+ FieldInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
// store byte offset and setter, getter
@@ -447,20 +543,22 @@ class ObjectDef : public ReflectionDefBase {
// initialize default value to nullptr
info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
- info.metadata = TVMFFIByteArray{nullptr, 0};
+ info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
+ std::string metadata_str = Metadata::ToJSON(info.metadata_);
+ info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
}
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func,
Extra&&... extra) {
- TVMFFIMethodInfo info;
+ using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+ MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
- info.metadata = TVMFFIByteArray{nullptr, 0};
info.flags = 0;
if (is_static) {
info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
@@ -468,8 +566,11 @@ class ObjectDef : public ReflectionDefBase {
// obtain the method function
Function method = GetMethod(std::string(type_key_) + "." + name,
std::forward<Func>(func));
info.method = AnyView(method).CopyToTVMFFIAny();
+ info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema());
// apply method info traits
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+ std::string metadata_str = Metadata::ToJSON(info.metadata_);
+ info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
}
diff --git a/include/tvm/ffi/rvalue_ref.h b/include/tvm/ffi/rvalue_ref.h
index ebbec58..5e6ca01 100644
--- a/include/tvm/ffi/rvalue_ref.h
+++ b/include/tvm/ffi/rvalue_ref.h
@@ -148,6 +148,14 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() {
return "RValueRef<" + TypeTraits<TObjRef>::TypeStr() + ">";
}
+
+  /*!
+   * \brief JSON type schema of `RValueRef<TObjRef>`.
+   * \return An object whose "type" is the rvalue-ref static type key and whose
+   *         "args" holds the schema of the wrapped `TObjRef`.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema += StaticTypeKey::kTVMFFIObjectRValueRef;
+    schema += "\",\"args\":[";
+    schema += TypeTraits<TObjRef>::TypeSchema();
+    schema += "]}";
+    return schema;
+  }
};
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/string.h b/include/tvm/ffi/string.h
index a1529d7..a99247b 100644
--- a/include/tvm/ffi/string.h
+++ b/include/tvm/ffi/string.h
@@ -32,10 +32,9 @@
#include <cstddef>
#include <cstring>
-#include <initializer_list>
+#include <sstream>
#include <string>
#include <string_view>
-#include <type_traits>
#include <utility>
// Note: We place string in tvm/ffi instead of tvm/ffi/container
@@ -44,6 +43,18 @@
// The following dependency relation holds
// any -> string -> object
+/// \cond Doxygen_Suppress
+// TVM_FFI_SNPRINTF: portable alias for a bounded printf-to-buffer call.
+// MSVC builds use `_snprintf_s`; every other toolchain uses `snprintf`.
+// NOTE(review): the `#pragma warning(push)` below needs a matching
+// `#pragma warning(pop)`, which is not visible in this hunk -- confirm it exists.
+#ifdef _MSC_VER
+#define TVM_FFI_SNPRINTF _snprintf_s
+#pragma warning(push)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4127)
+#pragma warning(disable : 4702)
+#else
+#define TVM_FFI_SNPRINTF snprintf
+#endif
+/// \endcond
+
namespace tvm {
namespace ffi {
namespace details {
@@ -664,6 +675,52 @@ class String {
friend String operator+(const char* lhs, const String& rhs);
};
+/*!
+ * \brief Return an escaped version of the string
+ *
+ * The characters `"`, `\`, `/`, backspace, form feed, newline, carriage return
+ * and tab are written as two-character escapes. Any other byte below 0x20, and
+ * 0x7f, is written as `\u00XX` with lowercase hex digits. All remaining bytes
+ * are copied through unchanged.
+ *
+ * \param value The input string
+ * \return The escaped string, quoted with double quotes
+ */
+inline String EscapeString(const String& value) {
+  // The hex digits are produced from a table rather than a printf-style call,
+  // so no formatting return value has to be checked and no platform-specific
+  // printf variant is involved.
+  static constexpr char kHexDigits[] = "0123456789abcdef";
+  const char* data = value.data();
+  const size_t size = value.size();
+  std::string out;
+  out.reserve(size + 2);  // at least the payload plus the two quotes
+  out.push_back('"');
+  for (size_t i = 0; i < size; ++i) {
+    const unsigned char ch = static_cast<unsigned char>(data[i]);
+    switch (ch) {
+      case '"':
+        out += "\\\"";
+        break;
+      case '\\':
+        out += "\\\\";
+        break;
+      case '/':
+        out += "\\/";
+        break;
+      case '\b':
+        out += "\\b";
+        break;
+      case '\f':
+        out += "\\f";
+        break;
+      case '\n':
+        out += "\\n";
+        break;
+      case '\r':
+        out += "\\r";
+        break;
+      case '\t':
+        out += "\\t";
+        break;
+      default:
+        if (ch < 0x20 || ch == 0x7f) {
+          // this is a control character, print as \u00XX
+          out += "\\u00";
+          out.push_back(kHexDigits[ch >> 4]);
+          out.push_back(kHexDigits[ch & 0x0f]);
+        } else {
+          out.push_back(static_cast<char>(ch));
+        }
+        break;
+    }
+  }
+  out.push_back('"');
+  return String(std::move(out));
+}
+
/*! \brief Convert TVMFFIByteArray to std::string_view */
TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) {
return std::string_view(str.data, str.size);
@@ -712,6 +769,9 @@ struct TypeTraits<Bytes> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; }
+  /*! \brief JSON type schema: an object whose "type" is the bytes static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIBytes).append("\"}");
+    return schema;
+  }
};
template <>
@@ -755,6 +815,9 @@ struct TypeTraits<String> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return "str"; }
+  /*! \brief JSON type schema: an object whose "type" is the string static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIStr).append("\"}");
+    return schema;
+  }
};
// const char*, requirement: not nullable, do not retain ownership
@@ -799,6 +862,7 @@ struct TypeTraits<const char*> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return "const char*"; }
+  /*! \brief JSON type schema: a fixed object whose "type" is the literal `const char*`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    static constexpr const char* kSchema = "{\"type\":\"const char*\"}";
+    return std::string(kSchema);
+  }
};
// TVMFFIByteArray, requirement: not nullable, do not retain ownership
@@ -827,6 +891,9 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase
{
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIByteArrayPtr; }
+  /*! \brief JSON type schema: an object whose "type" is the byte-array pointer static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIByteArrayPtr).append("\"}");
+    return schema;
+  }
};
template <>
@@ -847,6 +914,7 @@ struct TypeTraits<std::string>
}
TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; }
+  /*! \brief JSON type schema: a fixed object whose "type" is the literal `std::string`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    static constexpr const char* kSchema = "{\"type\":\"std::string\"}";
+    return std::string(kSchema);
+  }
TVM_FFI_INLINE static std::string ConvertFallbackValue(const char* src) {
return std::string(src);
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index 0fce70b..742881a 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -167,6 +167,9 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFINone; }
+  /*! \brief JSON type schema: an object whose "type" is the None static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFINone).append("\"}");
+    return schema;
+  }
};
/**
@@ -226,6 +229,9 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIBool; }
+  /*! \brief JSON type schema: an object whose "type" is the bool static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIBool).append("\"}");
+    return schema;
+  }
};
// Bool type, allow implicit casting from int
@@ -262,6 +268,9 @@ struct TypeTraits<bool> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIBool; }
+  /*! \brief JSON type schema: an object whose "type" is the bool static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIBool).append("\"}");
+    return schema;
+  }
};
// Integer POD values
@@ -299,6 +308,9 @@ struct TypeTraits<Int,
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIInt; }
+  /*! \brief JSON type schema: every integral type reports the int static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIInt).append("\"}");
+    return schema;
+  }
};
/// \cond Doxygen_Suppress
@@ -351,6 +363,9 @@ struct TypeTraits<IntEnum,
std::enable_if_t<is_integeral_enum_v<IntEnum>>> : pub
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIInt; }
+  /*! \brief JSON type schema: integral enums report the int static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIInt).append("\"}");
+    return schema;
+  }
};
// Float POD values
@@ -392,6 +407,9 @@ struct TypeTraits<Float,
std::enable_if_t<std::is_floating_point_v<Float>>>
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIFloat; }
+  /*! \brief JSON type schema: every floating-point type reports the float static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIFloat).append("\"}");
+    return schema;
+  }
};
// void*
@@ -431,6 +449,9 @@ struct TypeTraits<void*> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIOpaquePtr; }
+  /*! \brief JSON type schema: an object whose "type" is the opaque pointer static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIOpaquePtr).append("\"}");
+    return schema;
+  }
};
// Device
@@ -471,6 +492,9 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIDevice; }
+  /*! \brief JSON type schema: an object whose "type" is the device static type key. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(StaticTypeKey::kTVMFFIDevice).append("\"}");
+    return schema;
+  }
};
// DLTensor*, requirement: not nullable, do not retain ownership
@@ -514,6 +538,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return "DLTensor*"; }
+  /*! \brief JSON type schema: a fixed object whose "type" is the literal `DLTensor*`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    static constexpr const char* kSchema = "{\"type\":\"DLTensor*\"}";
+    return std::string(kSchema);
+  }
};
// Traits for ObjectRef, None to ObjectRef will always fail.
@@ -599,6 +624,9 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
}
TVM_FFI_INLINE static std::string TypeStr() { return
ContainerType::_type_key; }
+  /*! \brief JSON type schema: an object whose "type" is `ContainerType::_type_key`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(ContainerType::_type_key).append("\"}");
+    return schema;
+  }
};
template <typename TObjRef>
@@ -725,6 +753,9 @@ struct TypeTraits<TObject*,
std::enable_if_t<std::is_base_of_v<Object, TObject>>
}
TVM_FFI_INLINE static std::string TypeStr() { return TObject::_type_key; }
+  /*! \brief JSON type schema: an object whose "type" is `TObject::_type_key`. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"";
+    schema.append(TObject::_type_key).append("\"}");
+    return schema;
+  }
};
template <typename T>
@@ -786,6 +817,9 @@ struct TypeTraits<Optional<T>> : public TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() {
return "Optional<" + TypeTraits<T>::TypeStr() + ">";
}
+  /*!
+   * \brief JSON type schema of `Optional<T>`.
+   * \return An object of type "Optional" whose "args" holds the schema of `T`.
+   */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string schema = "{\"type\":\"Optional\",\"args\":[";
+    schema += details::TypeSchema<T>::v();
+    schema += "]}";
+    return schema;
+  }
};
} // namespace ffi
} // namespace tvm
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index caf994f..720968d 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -30,6 +30,7 @@ from .registry import (
register_object,
register_global_func,
get_global_func,
+ get_global_func_metadata,
remove_global_func,
init_ffi_api,
)
@@ -71,6 +72,7 @@ __all__ = [
"dtype",
"from_dlpack",
"get_global_func",
+ "get_global_func_metadata",
"init_ffi_api",
"load_module",
"register_error",
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 0c90ebf..e7c38f7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -246,6 +246,17 @@ class Bytes(bytes, PyNativeObject):
# Type reflection metadata (from cython/type_info.pxi)
# ---------------------------------------------------------------------------
+class TypeSchema:
+ """Type schema for a TVM FFI type."""
+
+ origin: str
+ args: tuple[TypeSchema, ...] = ()
+
+ @staticmethod
+ def from_json_obj(obj: dict[str, Any]) -> TypeSchema: ...
+ @staticmethod
+ def from_json_str(s: str) -> TypeSchema: ...
+
class TypeField:
"""Description of a single reflected field on an FFI-backed type."""
@@ -254,6 +265,7 @@ class TypeField:
size: int
offset: int
frozen: bool
+ metadata: dict[str, Any]
getter: Any
setter: Any
dataclass_field: Any | None
@@ -267,6 +279,7 @@ class TypeMethod:
doc: str | None
func: Any
is_static: bool
+ metadata: dict[str, Any]
def as_callable(self, cls: type) -> Callable[..., Any]: ...
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 0bb1e03..b52b328 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import warnings
+import json
from typing import Any
@@ -340,8 +340,10 @@ cdef _type_info_create_from_type_key(object type_cls, str
type_key):
cdef int32_t type_index
cdef list ancestors = []
cdef int ancestor
+ cdef dict metadata_obj
cdef object fields = []
cdef object methods = []
+ cdef str type_schema_json
cdef FieldGetter getter
cdef FieldSetter setter
cdef ByteArrayArg type_key_arg = ByteArrayArg(c_str(type_key))
@@ -359,6 +361,7 @@ cdef _type_info_create_from_type_key(object type_cls, str
type_key):
setter = FieldSetter.__new__(FieldSetter)
(<FieldSetter>setter).setter = field.setter
(<FieldSetter>setter).offset = field.offset
+ metadata_obj = json.loads(bytearray_to_str(&field.metadata)) if
field.metadata.size != 0 else {}
fields.append(
TypeField(
name=bytearray_to_str(&field.name),
@@ -366,6 +369,7 @@ cdef _type_info_create_from_type_key(object type_cls, str
type_key):
size=field.size,
offset=field.offset,
frozen=(field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0,
+ metadata=metadata_obj,
getter=getter,
setter=setter,
)
@@ -373,12 +377,14 @@ cdef _type_info_create_from_type_key(object type_cls, str
type_key):
for i in range(info.num_methods):
method = &(info.methods[i])
+ metadata_obj = json.loads(bytearray_to_str(&method.metadata)) if
method.metadata.size != 0 else {}
methods.append(
TypeMethod(
name=bytearray_to_str(&method.name),
doc=bytearray_to_str(&method.doc) if method.doc.size != 0 else
None,
func=_get_method_from_method_info(method),
is_static=(method.flags &
kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0,
+ metadata=metadata_obj,
)
)
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 0ef53e3..337e3db 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import dataclasses
+import json
from typing import Optional, Any
from io import StringIO
@@ -58,6 +59,82 @@ cdef class FieldSetter:
raise_existing_error()
raise move_from_last_error().py_error()
+_TYPE_SCHEMA_ORIGIN_CONVERTER = {
+ # A few Python-native types
+ "Variant": "Union",
+ "Optional": "Optional",
+ "Tuple": "tuple",
+ "ffi.Function": "Callable",
+ "ffi.Array": "list",
+ "ffi.Map": "dict",
+ "ffi.OpaquePyObject": "Any",
+ "ffi.Object": "Object",
+ "ffi.Tensor": "Tensor",
+ "DLTensor*": "Tensor",
+ # ctypes types
+ "void*": "ctypes.c_void_p",
+ # bytes
+ "TVMFFIByteArray*": "bytes",
+ "ffi.SmallBytes": "bytes",
+ "ffi.Bytes": "bytes",
+ # strings
+ "std::string": "str",
+ "const char*": "str",
+ "ffi.SmallStr": "str",
+ "ffi.String": "str",
+}
+
+
[email protected](repr=False, frozen=True)
+class TypeSchema:
+ """Type schema for a TVM FFI type."""
+ origin: str
+ args: tuple[TypeSchema, ...] = ()
+
+ def __post_init__(self):
+ origin = self.origin
+ args = self.args
+ if origin == "Union":
+ assert len(args) >= 2, "Union must have at least two arguments"
+ elif origin == "Optional":
+ assert len(args) == 1, "Optional must have exactly one argument"
+ elif origin == "list":
+ assert len(args) == 1, "list must have exactly one argument"
+ elif origin == "dict":
+ assert len(args) == 2, "dict must have exactly two arguments"
+ elif origin == "tuple":
+ pass # tuple can have arbitrary number of arguments
+
+ def __repr__(self) -> str:
+ if self.origin == "Union":
+ return " | ".join(repr(a) for a in self.args)
+ elif self.origin == "Optional":
+ return repr(self.args[0]) + " | None"
+ elif self.origin == "Callable":
+ if not self.args:
+ return "Callable[..., Any]"
+ else:
+ arg_ret = self.args[0]
+ arg_args = self.args[1:]
+ return f"Callable[[{', '.join(repr(a) for a in arg_args)}],
{repr(arg_ret)}]"
+ elif not self.args:
+ return self.origin
+ else:
+ return f"{self.origin}[{', '.join(repr(a) for a in self.args)}]"
+
+ @staticmethod
+ def from_json_obj(obj: dict[str, Any]) -> "TypeSchema":
+ assert isinstance(obj, dict) and "type" in obj, obj
+ origin = obj["type"]
+ origin = _TYPE_SCHEMA_ORIGIN_CONVERTER.get(origin, origin)
+ args = obj.get("args", ())
+ args = tuple(TypeSchema.from_json_obj(a) for a in args)
+ return TypeSchema(origin, args)
+
+ @staticmethod
+ def from_json_str(s) -> "TypeSchema":
+ return TypeSchema.from_json_obj(json.loads(s))
+
@dataclasses.dataclass(eq=False)
class TypeField:
@@ -68,6 +145,7 @@ class TypeField:
size: int
offset: int
frozen: bool
+ metadata: dict[str, Any]
getter: FieldGetter
setter: FieldSetter
dataclass_field: Any = None
@@ -103,8 +181,12 @@ class TypeMethod:
name: str
doc: Optional[str]
func: object
+ metadata: dict[str, Any]
is_static: bool
+ def __post_init__(self):
+ assert callable(self.func)
+
def as_callable(self, object cls):
"""Create a Python method attribute for this method on ``cls``."""
cdef str name = self.name
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 45c1da0..6a9dad6 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -18,6 +18,7 @@
from __future__ import annotations
+import json
import sys
from typing import Any, Callable, Literal, overload
@@ -58,6 +59,7 @@ def register_object(type_key: str | type | None = None) ->
Callable[[type], type
raise ValueError(f"Cannot find object type index for
{object_name}")
info = core._register_object_by_index(type_index, cls)
_add_class_attrs(type_cls=cls, type_info=info)
+ setattr(cls, "__tvm_ffi_type_info__", info)
return cls
if isinstance(type_key, str):
@@ -199,6 +201,23 @@ def remove_global_func(name: str) -> None:
get_global_func("ffi.FunctionRemoveGlobal")(name)
+def get_global_func_metadata(name: str) -> dict[str, Any]:
+ """Get the metadata of a global function by name.
+
+ Parameters
+ ----------
+ name : str
+ The name of the global function
+
+ Returns
+ -------
+ metadata : dict
+ The metadata of the function
+
+ """
+ return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name))
+
+
def init_ffi_api(namespace: str, target_module_name: str | None = None) ->
None:
"""Initialize register ffi api functions into a given module.
@@ -265,6 +284,7 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
__all__ = [
"get_global_func",
+ "get_global_func_metadata",
"init_ffi_api",
"list_global_func_names",
"register_global_func",
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index dbf8636..097e298 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -122,3 +122,7 @@ class _TestCxxInitSubset:
required_field: int
optional_field: int = field(init=False)
note: str = field(default_factory=lambda: "py-default", init=False)
+
+
+@register_object("testing.SchemaAllTypes")
+class _SchemaAllTypes: ...
diff --git a/src/ffi/extra/json_writer.cc b/src/ffi/extra/json_writer.cc
index 1a4636d..e480a8f 100644
--- a/src/ffi/extra/json_writer.cc
+++ b/src/ffi/extra/json_writer.cc
@@ -35,16 +35,6 @@
#include <limits>
#include <string>
-#ifdef _MSC_VER
-#define TVM_FFI_SNPRINTF _snprintf_s
-#pragma warning(push)
-#pragma warning(disable : 4244)
-#pragma warning(disable : 4127)
-#pragma warning(disable : 4702)
-#else
-#define TVM_FFI_SNPRINTF snprintf
-#endif
-
namespace tvm {
namespace ffi {
namespace json {
@@ -188,41 +178,8 @@ class JSONWriter {
}
void WriteString(const String& value) {
- *out_iter_++ = '"';
- const char* data = value.data();
- const size_t size = value.size();
- for (size_t i = 0; i < size; ++i) {
- switch (data[i]) {
-// handle escape characters per JSON spec(RFC 8259)
-#define HANDLE_ESCAPE_CHAR(pattern, val) \
- case pattern: \
- WriteLiteral(val, std::char_traits<char>::length(val)); \
- break
- HANDLE_ESCAPE_CHAR('\"', "\\\"");
- HANDLE_ESCAPE_CHAR('\\', "\\\\");
- HANDLE_ESCAPE_CHAR('/', "\\/");
- HANDLE_ESCAPE_CHAR('\b', "\\b");
- HANDLE_ESCAPE_CHAR('\f', "\\f");
- HANDLE_ESCAPE_CHAR('\n', "\\n");
- HANDLE_ESCAPE_CHAR('\r', "\\r");
- HANDLE_ESCAPE_CHAR('\t', "\\t");
-#undef HANDLE_ESCAPE_CHAR
- default: {
- uint8_t u8_val = static_cast<uint8_t>(data[i]);
- // this is a control character, print as \uXXXX
- if (u8_val < 0x20 || u8_val == 0x7f) {
- char buffer[8];
- int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "\\u%04x",
- static_cast<int32_t>(data[i]) & 0xff);
- WriteLiteral(buffer, size);
- } else {
- *out_iter_++ = data[i];
- }
- break;
- }
- }
- }
- *out_iter_++ = '"';
+ String escaped = EscapeString(value);
+ std::copy(escaped.data(), escaped.data() + escaped.size(), out_iter_);
}
void WriteArray(const json::Array& value) {
@@ -303,5 +260,3 @@ TVM_FFI_STATIC_INIT_BLOCK() {
} // namespace json
} // namespace ffi
} // namespace tvm
-
-#undef TVM_FFI_SNPRINTF
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 43ab8d2..555a14b 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -17,10 +17,16 @@
* under the License.
*/
// This file is used for testing the FFI API.
+#include <dlpack/dlpack.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/dtype.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <chrono>
@@ -238,3 +244,171 @@ TVM_FFI_STATIC_INIT_BLOCK() {
} // namespace ffi
} // namespace tvm
+
+// -----------------------------------------------------------------------------
+// Additional comprehensive schema coverage
+// -----------------------------------------------------------------------------
+namespace tvm {
+namespace ffi {
+
+// A class with a wide variety of field types and method signatures
+class SchemaAllTypesObj : public Object {
+ public:
+ // POD and builtin types
+ bool v_bool;
+ int64_t v_int;
+ double v_float;
+ DLDevice v_device;
+ DLDataType v_dtype;
+
+ // Atomic object types
+ String v_string;
+ Bytes v_bytes;
+
+ // Containers and combinations
+ Optional<int64_t> v_opt_int;
+ Optional<String> v_opt_str;
+ Array<int64_t> v_arr_int;
+ Array<String> v_arr_str;
+ Map<String, int64_t> v_map_str_int;
+ Map<String, Array<int64_t>> v_map_str_arr_int;
+ Variant<String, Array<int64_t>, Map<String, int64_t>> v_variant;
+ Optional<Array<Variant<int64_t, String>>> v_opt_arr_variant;
+
+ // Constructor used by the make_with static creator registered below
+ SchemaAllTypesObj(int64_t vi, double vf, String s) // NOLINT(*): explicit
not necessary here
+ : v_bool(true),
+ v_int(vi),
+ v_float(vf),
+ v_device(TVMFFIDLDeviceFromIntPair(kDLCPU, 0)),
+ v_dtype(StringToDLDataType("float32")),
+ v_string(std::move(s)),
+ v_variant(String("v")) {}
+
+ // Some methods to exercise RegisterMethod
+ int64_t AddInt(int64_t x) const { return v_int + x; }
+ Array<int64_t> AppendInt(Array<int64_t> xs, int64_t y) const {
+ xs.push_back(y);
+ return xs;
+ }
+ Optional<String> MaybeConcat(Optional<String> a, Optional<String> b) const {
+ if (a.has_value() && b.has_value()) return String(a.value() + b.value());
+ if (a.has_value()) return a;
+ if (b.has_value()) return b;
+ return Optional<String>(std::nullopt);
+ }
+ Map<String, Array<int64_t>> MergeMap(Map<String, Array<int64_t>> lhs,
+ Map<String, Array<int64_t>> rhs) const {
+ for (const auto& kv : rhs) {
+ if (!lhs.count(kv.first)) {
+ lhs.Set(kv.first, kv.second);
+ } else {
+ Array<int64_t> arr = lhs[kv.first];
+ for (const auto& v : kv.second) arr.push_back(v);
+ lhs.Set(kv.first, arr);
+ }
+ }
+ return lhs;
+ }
+
+ static constexpr bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.SchemaAllTypes", SchemaAllTypesObj,
Object);
+};
+
+class SchemaAllTypes : public ObjectRef {
+ public:
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SchemaAllTypes, ObjectRef,
SchemaAllTypesObj);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ // Register fields of various types (RegisterField usage)
+ refl::ObjectDef<SchemaAllTypesObj>()
+ .def_rw("v_bool", &SchemaAllTypesObj::v_bool, "bool field",
+ refl::Metadata{{"bool_attr", true}, //
+ {"int_attr", 1}, //
+ {"str_attr", "hello"}})
+ .def_rw("v_int", &SchemaAllTypesObj::v_int, refl::DefaultValue(0), "int
field")
+ .def_rw("v_float", &SchemaAllTypesObj::v_float, refl::DefaultValue(0.0),
"float field")
+ .def_rw("v_device", &SchemaAllTypesObj::v_device, "device field")
+ .def_rw("v_dtype", &SchemaAllTypesObj::v_dtype, "dtype field")
+ .def_rw("v_string", &SchemaAllTypesObj::v_string,
refl::DefaultValue("s"), "string field")
+ .def_rw("v_bytes", &SchemaAllTypesObj::v_bytes, "bytes field")
+ .def_rw("v_opt_int", &SchemaAllTypesObj::v_opt_int, "optional int")
+ .def_rw("v_opt_str", &SchemaAllTypesObj::v_opt_str, "optional str")
+ .def_rw("v_arr_int", &SchemaAllTypesObj::v_arr_int, "array<int>")
+ .def_rw("v_arr_str", &SchemaAllTypesObj::v_arr_str, "array<str>")
+ .def_rw("v_map_str_int", &SchemaAllTypesObj::v_map_str_int,
"map<str,int>")
+ .def_rw("v_map_str_arr_int", &SchemaAllTypesObj::v_map_str_arr_int,
"map<str,array<int>>")
+ .def_rw("v_variant", &SchemaAllTypesObj::v_variant,
"variant<str,array<int>,map<str,int>>")
+ .def_rw("v_opt_arr_variant", &SchemaAllTypesObj::v_opt_arr_variant,
+ "optional<array<variant<int,str>>>")
+ // Register methods (RegisterMethod usage)
+ .def("add_int", &SchemaAllTypesObj::AddInt, "add int method",
+ refl::Metadata{{"bool_attr", true}, //
+ {"int_attr", 1}, //
+ {"str_attr", "hello"}})
+ .def("append_int", &SchemaAllTypesObj::AppendInt, "append int to array")
+ .def("maybe_concat", &SchemaAllTypesObj::MaybeConcat, "optional concat")
+ .def("merge_map", &SchemaAllTypesObj::MergeMap, "merge maps")
+ // Register static creator (also a static method)
+ .def_static(
+ "make_with",
+ [](int64_t vi, double vf, String s) {
+ return SchemaAllTypes(make_object<SchemaAllTypesObj>(vi, vf,
std::move(s)));
+ },
+ "Constructor with subset of fields");
+
+ // Global typed functions to exercise RegisterFunc with various schemas
+ refl::GlobalDef()
+ .def(
+ "testing.schema_id_int", [](int64_t x) { return x; },
+ refl::Metadata{{"bool_attr", true}, //
+ {"int_attr", 1}, //
+ {"str_attr", "hello"}})
+ .def("testing.schema_id_float", [](double x) { return x; })
+ .def("testing.schema_id_bool", [](bool x) { return x; })
+ .def("testing.schema_id_device", [](DLDevice d) { return d; })
+ .def("testing.schema_id_dtype", [](DLDataType dt) { return dt; })
+ .def("testing.schema_id_string", [](String s) { return s; })
+ .def("testing.schema_id_bytes", [](Bytes b) { return b; })
+ .def("testing.schema_id_func", [](Function f) -> Function { return f; })
+ .def("testing.schema_id_any", [](Any a) { return a; })
+ .def("testing.schema_id_object", [](ObjectRef o) { return o; })
+ .def("testing.schema_id_dltensor", [](DLTensor* t) { return t; })
+ .def("testing.schema_id_tensor", [](Tensor t) { return t; })
+ .def("testing.schema_tensor_view_input", [](TensorView t) -> void {})
+ .def("testing.schema_id_opt_int", [](Optional<int64_t> o) { return o; })
+ .def("testing.schema_id_opt_str", [](Optional<String> o) { return o; })
+ .def("testing.schema_id_opt_obj", [](Optional<ObjectRef> o) { return o;
})
+ .def("testing.schema_id_arr_int", [](Array<int64_t> arr) { return arr; })
+ .def("testing.schema_id_arr_str", [](Array<String> arr) { return arr; })
+ .def("testing.schema_id_arr_obj", [](Array<ObjectRef> arr) { return arr;
})
+ .def("testing.schema_id_map_str_int", [](Map<String, int64_t> m) {
return m; })
+ .def("testing.schema_id_map_str_str", [](Map<String, String> m) { return
m; })
+ .def("testing.schema_id_map_str_obj", [](Map<String, ObjectRef> m) {
return m; })
+ .def("testing.schema_id_variant_int_str", [](Variant<int64_t, String> v)
{ return v; })
+ .def_packed("testing.schema_packed", [](PackedArgs args, Any* ret) {})
+ .def("testing.schema_arr_map_opt",
+ [](Array<Optional<int64_t>> arr, Map<String, Array<int64_t>> mp,
+ Optional<String> os) -> Map<String, Array<int64_t>> {
+ // no-op combine
+ if (os.has_value()) {
+ Array<int64_t> extra;
+ for (size_t i = 0; i < arr.size(); ++i) {
+ if (arr[i].has_value()) extra.push_back(arr[i].value());
+ }
+ mp.Set(os.value(), extra);
+ }
+ return mp;
+ })
+ .def(
+ "testing.schema_variant_mix",
+ [](Variant<int64_t, String, Array<int64_t>> v) { return v; },
"variant passthrough")
+ .def("testing.schema_no_args", []() { return 1; })
+ .def("testing.schema_no_return", [](int64_t x) {})
+ .def("testing.schema_no_args_no_return", []() {});
+}
+
+} // namespace ffi
+} // namespace tvm
diff --git a/src/ffi/function.cc b/src/ffi/function.cc
index 954608f..fe55677 100644
--- a/src/ffi/function.cc
+++ b/src/ffi/function.cc
@@ -225,5 +225,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return tvm::ffi::Function::FromTyped(return_functor);
})
.def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return
val; })
- .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return
val; });
+ .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return
val; })
+ .def("ffi.GetGlobalFuncMetadata", [](tvm::ffi::String name) ->
tvm::ffi::String {
+ const auto* f = tvm::ffi::GlobalFunctionTable::Global()->Get(name);
+ if (f == nullptr) {
+ TVM_FFI_THROW(RuntimeError) << "Global Function is not found: " <<
name;
+ }
+ return f->metadata_data;
+ });
}
diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc
new file mode 100644
index 0000000..e0fa054
--- /dev/null
+++ b/tests/cpp/test_metadata.cc
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/extra/json.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/string.h>
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::reflection;
+
+static std::string ParseMetadataToSchema(const String& metadata) {
+ return json::Parse(metadata)
+ .cast<Map<String, Any>>()["type_schema"] //
+ .cast<String>(); //
+}
+
+static std::string ParseMetadataToSchema(const TVMFFIByteArray& metadata) {
+ return json::Parse(String(metadata))
+ .cast<Map<String, Any>>()["type_schema"] //
+ .cast<String>(); //
+}
+
+TEST(Schema, GlobalFuncTypeSchema) {
+ // Helper to fetch global function type schema via the exposed utility
+ Function get_metadata =
Function::GetGlobalRequired("ffi.GetGlobalFuncMetadata");
+ auto fetch = [&](const char* name) -> std::string {
+ String metadata = get_metadata(String(name)).cast<String>();
+ return ParseMetadataToSchema(metadata);
+ };
+ // Simple IDs
+ EXPECT_EQ(fetch("testing.schema_id_int"),
+
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"int"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_float"),
+
R"({"type":"ffi.Function","args":[{"type":"float"},{"type":"float"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_bool"),
+
R"({"type":"ffi.Function","args":[{"type":"bool"},{"type":"bool"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_device"),
+
R"({"type":"ffi.Function","args":[{"type":"Device"},{"type":"Device"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_dtype"),
+
R"({"type":"ffi.Function","args":[{"type":"DataType"},{"type":"DataType"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_string"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.String"},{"type":"ffi.String"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_bytes"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Bytes"},{"type":"ffi.Bytes"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_func"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Function"},{"type":"ffi.Function"}]})");
+
+ EXPECT_EQ(fetch("testing.schema_id_any"),
+
R"({"type":"ffi.Function","args":[{"type":"Any"},{"type":"Any"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_object"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Object"},{"type":"ffi.Object"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_dltensor"),
+
R"({"type":"ffi.Function","args":[{"type":"DLTensor*"},{"type":"DLTensor*"}]})");
+ EXPECT_EQ(fetch("testing.schema_id_tensor"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Tensor"},{"type":"ffi.Tensor"}]})");
+ EXPECT_EQ(fetch("testing.schema_tensor_view_input"),
+
R"({"type":"ffi.Function","args":[{"type":"None"},{"type":"DLTensor*"}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_opt_int"),
+
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"int"}]},{"type":"Optional","args":[{"type":"int"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_opt_str"),
+
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_opt_obj"),
+
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.Object"}]},{"type":"Optional","args":[{"type":"ffi.Object"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_arr_int"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"ffi.Array","args":[{"type":"int"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_arr_str"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"ffi.String"}]},{"type":"ffi.Array","args":[{"type":"ffi.String"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_arr_obj"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"ffi.Object"}]},{"type":"ffi.Array","args":[{"type":"ffi.Object"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_map_str_int"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_map_str_str"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.String"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.String"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_map_str_obj"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Object"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Object"}]}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_variant_int_str"),
+
R"({"type":"ffi.Function","args":[{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"}]},{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"}]}]})");
+
+ // Packed function registered via def_packed: schema is plain ffi.Function
+ EXPECT_EQ(fetch("testing.schema_packed"), R"({"type":"ffi.Function"})");
+
+ // Mixed containers and optionals
+ EXPECT_EQ(
+ fetch("testing.schema_arr_map_opt"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Array","args":[{"type":"Optional","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
+
+ EXPECT_EQ(
+ fetch("testing.schema_variant_mix"),
+
R"({"type":"ffi.Function","args":[{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
+
+ // No-arg and no-return combinations
+ EXPECT_EQ(fetch("testing.schema_no_args"),
R"({"type":"ffi.Function","args":[{"type":"int"}]})");
+ EXPECT_EQ(fetch("testing.schema_no_return"),
+
R"({"type":"ffi.Function","args":[{"type":"None"},{"type":"int"}]})");
+ EXPECT_EQ(fetch("testing.schema_no_args_no_return"),
+ R"({"type":"ffi.Function","args":[{"type":"None"}]})");
+}
+
+TEST(Schema, FieldTypeSchemas) {
+ // Validate type schema JSON on fields of testing.SchemaAllTypes
+ const char* kTypeKey = "testing.SchemaAllTypes";
+ // Helper to fetch a field's type schema by name
+ auto field_schema = [&](const char* field_name) -> std::string {
+ const TVMFFIFieldInfo* info = GetFieldInfo(kTypeKey, field_name);
+ return ParseMetadataToSchema(info->metadata);
+ };
+
+ EXPECT_EQ(field_schema("v_bool"), R"({"type":"bool"})");
+ EXPECT_EQ(field_schema("v_int"), R"({"type":"int"})");
+ EXPECT_EQ(field_schema("v_float"), R"({"type":"float"})");
+ EXPECT_EQ(field_schema("v_device"), R"({"type":"Device"})");
+ EXPECT_EQ(field_schema("v_dtype"), R"({"type":"DataType"})");
+ EXPECT_EQ(field_schema("v_string"), R"({"type":"ffi.String"})");
+ EXPECT_EQ(field_schema("v_bytes"), R"({"type":"ffi.Bytes"})");
+ EXPECT_EQ(field_schema("v_opt_int"),
R"({"type":"Optional","args":[{"type":"int"}]})");
+ EXPECT_EQ(field_schema("v_opt_str"),
R"({"type":"Optional","args":[{"type":"ffi.String"}]})");
+ EXPECT_EQ(field_schema("v_arr_int"),
R"({"type":"ffi.Array","args":[{"type":"int"}]})");
+ EXPECT_EQ(field_schema("v_arr_str"),
R"({"type":"ffi.Array","args":[{"type":"ffi.String"}]})");
+ EXPECT_EQ(field_schema("v_map_str_int"),
+
R"({"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]})");
+ EXPECT_EQ(
+ field_schema("v_map_str_arr_int"),
+
R"({"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]})");
+ EXPECT_EQ(
+ field_schema("v_variant"),
+
R"({"type":"Variant","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"int"}]}]})");
+ EXPECT_EQ(
+ field_schema("v_opt_arr_variant"),
+
R"({"type":"Optional","args":[{"type":"ffi.Array","args":[{"type":"Variant","args":[{"type":"int"},{"type":"ffi.String"}]}]}]})");
+}
+
+TEST(Schema, MethodTypeSchemas) {
+ const char* kTypeKey = "testing.SchemaAllTypes";
+ auto method_schema = [&](const char* method_name) -> std::string {
+ const TVMFFIMethodInfo* info = GetMethodInfo(kTypeKey, method_name);
+ return ParseMetadataToSchema(info->metadata);
+ };
+
+ // Instance methods
+ EXPECT_EQ(method_schema("add_int"),
+
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"int"}]})");
+ EXPECT_EQ(
+ method_schema("append_int"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"int"}]})");
+ EXPECT_EQ(
+ method_schema("maybe_concat"),
+
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
+ EXPECT_EQ(
+ method_schema("merge_map"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
+
+ // Static method make_with: return type is the object type itself.
+ // Build expected JSON as ffi.Function with return type = type_key and args = (int, float, str)
+ EXPECT_EQ(
+ method_schema("make_with"),
+
R"({"type":"ffi.Function","args":[{"type":"testing.SchemaAllTypes"},{"type":"int"},{"type":"float"},{"type":"ffi.String"}]})");
+}
+
+} // namespace
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
new file mode 100644
index 0000000..527ffba
--- /dev/null
+++ b/tests/python/test_metadata.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any
+
+import pytest
+from tvm_ffi import get_global_func_metadata
+from tvm_ffi.core import TypeInfo, TypeSchema
+from tvm_ffi.testing import _SchemaAllTypes
+
+
[email protected](
+ "func_name,expected",
+ [
+ ("testing.schema_id_int", "Callable[[int], int]"),
+ ("testing.schema_id_float", "Callable[[float], float]"),
+ ("testing.schema_id_bool", "Callable[[bool], bool]"),
+ ("testing.schema_id_device", "Callable[[Device], Device]"),
+ ("testing.schema_id_dtype", "Callable[[DataType], DataType]"),
+ ("testing.schema_id_string", "Callable[[str], str]"),
+ ("testing.schema_id_bytes", "Callable[[bytes], bytes]"),
+ ("testing.schema_id_func", "Callable[[Callable[..., Any]],
Callable[..., Any]]"),
+ ("testing.schema_id_any", "Callable[[Any], Any]"),
+ ("testing.schema_id_object", "Callable[[Object], Object]"),
+ ("testing.schema_id_dltensor", "Callable[[Tensor], Tensor]"),
+ ("testing.schema_id_tensor", "Callable[[Tensor], Tensor]"),
+ ("testing.schema_tensor_view_input", "Callable[[Tensor], None]"),
+ ("testing.schema_id_opt_int", "Callable[[int | None], int | None]"),
+ ("testing.schema_id_opt_str", "Callable[[str | None], str | None]"),
+ ("testing.schema_id_opt_obj", "Callable[[Object | None], Object |
None]"),
+ ("testing.schema_id_arr_int", "Callable[[list[int]], list[int]]"),
+ ("testing.schema_id_arr_str", "Callable[[list[str]], list[str]]"),
+ ("testing.schema_id_arr_obj", "Callable[[list[Object]],
list[Object]]"),
+ ("testing.schema_id_map_str_int", "Callable[[dict[str, int]],
dict[str, int]]"),
+ ("testing.schema_id_map_str_str", "Callable[[dict[str, str]],
dict[str, str]]"),
+ ("testing.schema_id_map_str_obj", "Callable[[dict[str, Object]],
dict[str, Object]]"),
+ ("testing.schema_id_variant_int_str", "Callable[[int | str], int |
str]"),
+ ("testing.schema_packed", "Callable[..., Any]"),
+ (
+ "testing.schema_arr_map_opt",
+ "Callable[[list[int | None], dict[str, list[int]], str | None],
dict[str, list[int]]]",
+ ),
+ ("testing.schema_variant_mix", "Callable[[int | str | list[int]], int
| str | list[int]]"),
+ ("testing.schema_no_args", "Callable[[], int]"),
+ ("testing.schema_no_return", "Callable[[int], None]"),
+ ("testing.schema_no_args_no_return", "Callable[[], None]"),
+ ],
+)
+def test_schema_global_func(func_name: str, expected: str) -> None:
+ metadata: dict[str, Any] = get_global_func_metadata(func_name)
+ actual: TypeSchema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(actual) == expected, f"{func_name}: {actual}"
+
+
[email protected](
+ "field_name,expected",
+ [
+ ("v_bool", "bool"),
+ ("v_int", "int"),
+ ("v_float", "float"),
+ ("v_device", "Device"),
+ ("v_dtype", "DataType"),
+ ("v_string", "str"),
+ ("v_bytes", "bytes"),
+ ("v_opt_int", "int | None"),
+ ("v_opt_str", "str | None"),
+ ("v_arr_int", "list[int]"),
+ ("v_arr_str", "list[str]"),
+ ("v_map_str_int", "dict[str, int]"),
+ ("v_map_str_arr_int", "dict[str, list[int]]"),
+ ("v_variant", "str | list[int] | dict[str, int]"),
+ ("v_opt_arr_variant", "list[int | str] | None"),
+ ],
+)
+def test_schema_field(field_name: str, expected: str) -> None:
+ type_info: TypeInfo = getattr(_SchemaAllTypes, "__tvm_ffi_type_info__")
+ for field in type_info.fields:
+ if field.name == field_name:
+ actual: TypeSchema =
TypeSchema.from_json_str(field.metadata["type_schema"])
+ assert str(actual) == expected, f"{field_name}: {actual}"
+ break
+ else:
+ raise ValueError(f"Field not found: {field_name}")
+
+
[email protected](
+ "method_name,expected",
+ [
+ ("add_int", "Callable[[int], int]"),
+ ("append_int", "Callable[[list[int], int], list[int]]"),
+ ("maybe_concat", "Callable[[str | None, str | None], str | None]"),
+ (
+ "merge_map",
+ "Callable[[dict[str, list[int]], dict[str, list[int]]], dict[str,
list[int]]]",
+ ),
+ ("make_with", "Callable[[int, float, str], testing.SchemaAllTypes]"),
+ ],
+)
+def test_schema_member_method(method_name: str, expected: str) -> None:
+ type_info: TypeInfo = getattr(_SchemaAllTypes, "__tvm_ffi_type_info__")
+ for method in type_info.methods:
+ if method.name == method_name:
+ actual: TypeSchema =
TypeSchema.from_json_str(method.metadata["type_schema"])
+ assert str(actual) == expected, f"{method_name}: {actual}"
+ break
+ else:
+ raise ValueError(f"Method not found: {method_name}")
+
+
+def test_metadata_global_func() -> None:
+ metadata: dict[str, Any] =
get_global_func_metadata("testing.schema_id_int")
+ assert len(metadata) == 4
+ assert "type_schema" in metadata
+ assert metadata["bool_attr"] is True
+ assert metadata["int_attr"] == 1
+ assert metadata["str_attr"] == "hello"
+
+
+def test_metadata_field() -> None:
+ type_info: TypeInfo = getattr(_SchemaAllTypes, "__tvm_ffi_type_info__")
+ for field in type_info.fields:
+ if field.name == "v_bool":
+ assert len(field.metadata) == 4
+ assert "type_schema" in field.metadata
+ assert field.metadata["bool_attr"] is True
+ assert field.metadata["int_attr"] == 1
+ assert field.metadata["str_attr"] == "hello"
+ break
+ else:
+ raise ValueError("Field not found: v_bool")
+
+
+def test_metadata_member_method() -> None:
+ type_info: TypeInfo = getattr(_SchemaAllTypes, "__tvm_ffi_type_info__")
+ for method in type_info.methods:
+ if method.name == "add_int":
+ assert len(method.metadata) == 4
+ assert "type_schema" in method.metadata
+ assert method.metadata["bool_attr"] is True
+ assert method.metadata["int_attr"] == 1
+ assert method.metadata["str_attr"] == "hello"
+ break
+ else:
+ raise ValueError("Method not found: add_int")