This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 695ac175b919240816c99783d2dff567839a5536
Author: tqchen <[email protected]>
AuthorDate: Sun Dec 29 07:38:59 2024 +0800

    [FFI] bool, device, tensor, raw str
---
 ffi/include/tvm/ffi/c_api.h            |   2 +-
 ffi/include/tvm/ffi/function_details.h |   7 +-
 ffi/include/tvm/ffi/reflection.h       |  24 ++--
 ffi/include/tvm/ffi/string.h           |  26 +++++
 ffi/include/tvm/ffi/type_traits.h      | 196 +++++++++++++++++++++++++++++++--
 ffi/src/ffi/object.cc                  |   2 +-
 ffi/tests/cpp/test_any.cc              | 129 ++++++++++++++++++++++
 ffi/tests/cpp/test_function.cc         |   4 +
 ffi/tests/cpp/test_reflection.cc       |   6 +-
 9 files changed, 365 insertions(+), 31 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 18c44e40da..5a626d03c0 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -358,7 +358,7 @@ TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, 
const TVMFFIFieldInf
  * \brief Register type method information for rutnime reflection.
  * \param type_index The type index
  * \param info The method info to be registered.
- * \return 0 when success, nonzero when failure happens 
+ * \return 0 when success, nonzero when failure happens
  */
 TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const 
TVMFFIMethodInfo* info);
 
diff --git a/ffi/include/tvm/ffi/function_details.h 
b/ffi/include/tvm/ffi/function_details.h
index 0f6701e509..90d842d000 100644
--- a/ffi/include/tvm/ffi/function_details.h
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -161,7 +161,7 @@ class MovableArgValueWithContext {
   template <typename Type>
   TVM_FFI_INLINE operator Type() {
     using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>;
-  std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]);
+    std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]);
     if (opt.has_value()) {
       return std::move(*opt);
     }
@@ -210,9 +210,10 @@ struct unpack_call_dispatcher<R, 0, index, F> {
 template <int index, typename F>
 struct unpack_call_dispatcher<void, 0, index, F> {
   template <typename... Args>
-  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature ,
-                                 const F& , int32_t , const AnyView* , Any* ,
-                                 Args&&... unpacked_args) {
+  // Invoke `f` on the fully unpacked arguments and discard its (void) result. `f` must stay
+  // named because the body calls it; the other parameters are unused here and left unnamed.
+  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, int32_t,
+                                 const AnyView*, Any*, Args&&... unpacked_args) {
     f(std::forward<Args>(unpacked_args)...);
   }
 };
diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h
index 51573d1bc2..3e6a1ecc37 100644
--- a/ffi/include/tvm/ffi/reflection.h
+++ b/ffi/include/tvm/ffi/reflection.h
@@ -52,7 +52,7 @@ struct Type2FieldStaticTypeIndex<T, 
std::enable_if_t<TypeTraits<T>::enabled>> {
  * \returns The byteoffset
  */
 template <typename Class, typename T>
-inline int64_t GetFieldByteOffset(T Class::*field_ptr) {
+inline int64_t GetFieldByteOffset(T Class::* field_ptr) {
   return 
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
 }
 
@@ -61,13 +61,13 @@ class ReflectionDef {
   explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {}
 
   template <typename Class, typename T>
-  ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) {
+  ReflectionDef& def_readonly(const char* name, T Class::* field_ptr) {
     RegisterField(name, field_ptr, true);
     return *this;
   }
 
   template <typename Class, typename T>
-  ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) {
+  ReflectionDef& def_readwrite(const char* name, T Class::* field_ptr) {
     RegisterField(name, field_ptr, false);
     return *this;
   }
@@ -76,7 +76,7 @@ class ReflectionDef {
 
  private:
   template <typename Class, typename T>
-  void RegisterField(const char* name, T Class::*field_ptr, bool readonly) {
+  void RegisterField(const char* name, T Class::* field_ptr, bool readonly) {
     TVMFFIFieldInfo info;
     info.name = name;
     info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value;
@@ -126,23 +126,19 @@ inline const TVMFFIFieldInfo* 
GetReflectionFieldInfo(const char* type_key, const
  */
 class ReflectionFieldGetter {
  public:
-  explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {
-  }
+  explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
 
-  Any operator()(const Object* obj_ptr) const { 
+  Any operator()(const Object* obj_ptr) const {
     Any result;
     const void* addr = reinterpret_cast<const char*>(obj_ptr) + 
field_info_->byte_offset;
-    TVM_FFI_CHECK_SAFE_CALL(field_info_->getter(const_cast<void*>(addr), 
reinterpret_cast<TVMFFIAny*>(&result)));
+    TVM_FFI_CHECK_SAFE_CALL(
+        field_info_->getter(const_cast<void*>(addr), 
reinterpret_cast<TVMFFIAny*>(&result)));
     return result;
   }
 
-  Any operator()(const ObjectPtr<Object>& obj_ptr) const { 
-    return operator()(obj_ptr.get());
-  }
+  Any operator()(const ObjectPtr<Object>& obj_ptr) const { return 
operator()(obj_ptr.get()); }
 
-  Any operator()(const ObjectRef& obj) const { 
-    return operator()(obj.get());
-  }  
+  Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); }
 
  private:
   const TVMFFIFieldInfo* field_info_;
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 7cbed2f7e8..debd361e7d 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -28,6 +28,7 @@
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
+#include <tvm/ffi/type_traits.h>
 
 #include <cstddef>
 #include <cstring>
@@ -305,6 +306,30 @@ class String : public ObjectRef {
   friend struct AnyEqual;
 };
 
+template <>
+inline constexpr bool use_default_type_traits_v<String> = false;
+
+// String-specific traits: raw C strings (kTVMFFIRawStr) also convert, by copying into a String.
+template <>
+struct TypeTraits<String> : public ObjectRefTypeTraitsBase<String> {
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    const bool is_raw_str = src->type_index == TypeIndex::kTVMFFIRawStr;
+    return is_raw_str || ObjectRefTypeTraitsBase<String>::CheckAnyView(src);
+  }
+
+  static TVM_FFI_INLINE String CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIRawStr) {
+      return ObjectRefTypeTraitsBase<String>::CopyFromAnyViewAfterCheck(src);
+    }
+    return String(src->v_c_str);
+  }
+
+  static TVM_FFI_INLINE std::optional<String> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFIRawStr) return String(src->v_c_str);
+    return ObjectRefTypeTraitsBase<String>::TryCopyFromAnyView(src);
+  }
+};
+
 inline String operator+(const String& lhs, const String& rhs) {
   size_t lhs_size = lhs.size();
   size_t rhs_size = rhs.size();
@@ -421,6 +446,7 @@ inline int String::memncmp(const char* lhs, const char* 
rhs, size_t lhs_count, s
     return 0;
   }
 }
+
 }  // namespace ffi
 }  // namespace tvm
 
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index cf88eb7319..69b3738be0 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -44,6 +44,8 @@ inline std::string TypeIndex2TypeKey(int32_t type_index) {
   switch (type_index) {
     case TypeIndex::kTVMFFINone:
       return "None";
+    case TypeIndex::kTVMFFIBool:
+      return "bool";
     case TypeIndex::kTVMFFIInt:
       return "int";
     case TypeIndex::kTVMFFIFloat:
@@ -154,14 +156,14 @@ struct TypeTraits<Int, 
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
   static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { 
CopyToAnyView(src, result); }
 
   static TVM_FFI_INLINE std::optional<Int> TryCopyFromAnyView(const TVMFFIAny* 
src) {
-    if (src->type_index == TypeIndex::kTVMFFIInt) {
+    if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == 
TypeIndex::kTVMFFIBool) {
       return std::make_optional<Int>(src->v_int64);
     }
     return std::nullopt;
   }
 
   static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
-    return src->type_index == TypeIndex::kTVMFFIInt;
+    return src->type_index == TypeIndex::kTVMFFIInt || src->type_index == 
TypeIndex::kTVMFFIBool;
   }
 
   static TVM_FFI_INLINE int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
@@ -171,6 +173,36 @@ struct TypeTraits<Int, 
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
   static TVM_FFI_INLINE std::string TypeStr() { return "int"; }
 };
 
+// bool: carried as an int64 payload tagged kTVMFFIBool; int payloads also convert to bool.
+template <>
+struct TypeTraits<bool> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool;
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIBool || src->type_index == TypeIndex::kTVMFFIInt;
+  }
+
+  static TVM_FFI_INLINE bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return src->v_int64 != 0;
+  }
+
+  static TVM_FFI_INLINE std::optional<bool> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyView(src)) {
+      return std::nullopt;
+    }
+    return CopyFromAnyViewAfterCheck(src);
+  }
+
+  static TVM_FFI_INLINE void CopyToAnyView(const bool& src, TVMFFIAny* result) {
+    result->type_index = TypeIndex::kTVMFFIBool;
+    result->v_int64 = src ? 1 : 0;
+  }
+
+  static TVM_FFI_INLINE void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "bool"; }
+};
+
 // Float POD values
 template <typename Float>
 struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>>
@@ -187,14 +219,16 @@ struct TypeTraits<Float, 
std::enable_if_t<std::is_floating_point_v<Float>>>
   static TVM_FFI_INLINE std::optional<Float> TryCopyFromAnyView(const 
TVMFFIAny* src) {
     if (src->type_index == TypeIndex::kTVMFFIFloat) {
       return std::make_optional<Float>(src->v_float64);
-    } else if (src->type_index == TypeIndex::kTVMFFIInt) {
+    } else if (src->type_index == TypeIndex::kTVMFFIInt ||
+               src->type_index == TypeIndex::kTVMFFIBool) {
       return std::make_optional<Float>(src->v_int64);
     }
     return std::nullopt;
   }
 
   static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
-    return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == 
TypeIndex::kTVMFFIInt;
+    return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == 
TypeIndex::kTVMFFIInt ||
+           src->type_index == TypeIndex::kTVMFFIBool;
   }
 
   static TVM_FFI_INLINE Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
@@ -244,11 +278,154 @@ struct TypeTraits<void*> : public TypeTraitsBase {
   static TVM_FFI_INLINE std::string TypeStr() { return "void*"; }
 };
 
+// DLDataType: copied by value into the v_dtype slot under the kTVMFFIDataType tag.
+template <>
+struct TypeTraits<DLDataType> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType;
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIDataType;
+  }
+
+  static TVM_FFI_INLINE DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return src->v_dtype;
+  }
+
+  static TVM_FFI_INLINE std::optional<DLDataType> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyView(src)) {
+      return std::nullopt;
+    }
+    return CopyFromAnyViewAfterCheck(src);
+  }
+
+  static TVM_FFI_INLINE void CopyToAnyView(const DLDataType& src, TVMFFIAny* result) {
+    result->type_index = TypeIndex::kTVMFFIDataType;
+    result->v_dtype = src;
+  }
+
+  // Moving is identical to copying: the value is stored inline in the Any.
+  static TVM_FFI_INLINE void MoveToAny(DLDataType src, TVMFFIAny* result) {
+    CopyToAnyView(src, result);
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "DataType"; }
+};
+
+// DLDevice: copied by value into the v_device slot under the kTVMFFIDevice tag.
+template <>
+struct TypeTraits<DLDevice> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice;
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIDevice;
+  }
+
+  static TVM_FFI_INLINE DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return src->v_device;
+  }
+
+  static TVM_FFI_INLINE std::optional<DLDevice> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyView(src)) {
+      return std::nullopt;
+    }
+    return CopyFromAnyViewAfterCheck(src);
+  }
+
+  static TVM_FFI_INLINE void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) {
+    result->type_index = TypeIndex::kTVMFFIDevice;
+    result->v_device = src;
+  }
+
+  // Moving is identical to copying: the value is stored inline in the Any.
+  static TVM_FFI_INLINE void MoveToAny(DLDevice src, TVMFFIAny* result) {
+    CopyToAnyView(src, result);
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "Device"; }
+};
+
+// DLTensor*: borrowed, non-null pointer stored in v_ptr; MoveToAny throws since Any cannot own it.
+template <>
+struct TypeTraits<DLTensor*> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr;
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIDLTensorPtr;
+  }
+
+  static TVM_FFI_INLINE DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return static_cast<DLTensor*>(src->v_ptr);
+  }
+
+  static TVM_FFI_INLINE std::optional<DLTensor*> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyView(src)) {
+      return std::nullopt;
+    }
+    return CopyFromAnyViewAfterCheck(src);
+  }
+
+  static TVM_FFI_INLINE void CopyToAnyView(DLTensor* src, TVMFFIAny* result) {
+    TVM_FFI_ICHECK_NOTNULL(src);
+    result->type_index = TypeIndex::kTVMFFIDLTensorPtr;
+    result->v_ptr = src;
+  }
+
+  static TVM_FFI_INLINE void MoveToAny(DLTensor* src, TVMFFIAny* result) {
+    TVM_FFI_THROW(RuntimeError)
+        << "DLTensor* cannot be held in Any as it does not retain ownership, use NDArray instead";
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "DLTensor*"; }
+};
+
+// const char*: borrowed, non-null C string stored in v_c_str; MoveToAny throws (no ownership).
+template <>
+struct TypeTraits<const char*> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
+
+  static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result) {
+    TVM_FFI_ICHECK_NOTNULL(src);
+    result->type_index = TypeIndex::kTVMFFIRawStr;
+    result->v_c_str = src;
+  }
+
+  static TVM_FFI_INLINE void MoveToAny(const char* src, TVMFFIAny* result) {
+    TVM_FFI_THROW(RuntimeError)
+        << "const char* cannot be held in Any as it does not retain ownership, use String instead";
+  }
+
+  static TVM_FFI_INLINE std::optional<const char*> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (CheckAnyView(src)) {
+      return src->v_c_str;
+    }
+    return std::nullopt;
+  }
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIRawStr;
+  }
+
+  static TVM_FFI_INLINE const char* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return src->v_c_str;
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "const char*"; }
+};
+
+// Char arrays view as raw strings (kTVMFFIRawStr). Only the AnyView direction exists, and the
+// array is borrowed (its address is stored), so it must outlive the view.
+template <int N>
+struct TypeTraits<char[N]> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
+  static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny* result) {
+    result->v_c_str = src;
+    result->type_index = TypeIndex::kTVMFFIRawStr;
+  }
+};
+
 // Traits for ObjectRef
 template <typename TObjRef>
-struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, 
TObjRef> &&
-                                            
use_default_type_traits_v<TObjRef>>>
-    : public TypeTraitsBase {
+struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
   static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject;
   using ContainerType = typename TObjRef::ContainerType;
 
@@ -292,6 +469,11 @@ struct TypeTraits<TObjRef, 
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef
   static TVM_FFI_INLINE std::string TypeStr() { return 
ContainerType::_type_key; }
 };
 
+template <typename TObjRef>  // default traits for every ObjectRef subclass
+struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef> &&
+                                            use_default_type_traits_v<TObjRef>>>
+    : public ObjectRefTypeTraitsBase<TObjRef> {};  // opt out: use_default_type_traits_v<T> = false
+
 // Traits for ObjectPtr
 template <typename T>
 struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase {
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index a0fe306e78..1cf37ac448 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -172,7 +172,7 @@ class TypeTable {
     return it->second;
   }
 
-  Entry* GetTypeEntry(int32_t type_index) { 
+  Entry* GetTypeEntry(int32_t type_index) {
     Entry* entry = nullptr;
     if (type_index >= 0 && static_cast<size_t>(type_index) < 
type_table_.size()) {
       entry = type_table_[type_index].get();
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 2cb3e8b95b..1fd312ecd2 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -60,6 +60,37 @@ TEST(Any, Int) {
   EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2);
 }
 
+TEST(Any, Bool) {
+  AnyView view0;
+  std::optional<bool> opt_v0 = view0.TryAs<bool>();
+  EXPECT_FALSE(opt_v0.has_value());
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] bool v0 = view0;
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot convert from type `None` to `bool`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  AnyView view1 = true;
+  EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool);
+  EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1);
+
+  int32_t int_v1 = view1;
+  EXPECT_EQ(int_v1, 1);
+  EXPECT_TRUE(AnyView(2).TryAs<bool>().value_or(false));  // int payloads convert to bool too
+  bool v1 = false;
+  view0 = v1;
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 0);
+}
+
 TEST(Any, Float) {
   AnyView view0;
   EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
@@ -94,6 +125,104 @@ TEST(Any, Float) {
   EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2);
 }
 
+TEST(Any, DataType) {
+  AnyView empty_view;
+  EXPECT_EQ(empty_view.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
+
+  std::optional<DLDataType> maybe_dtype = empty_view.TryAs<DLDataType>();
+  EXPECT_FALSE(maybe_dtype.has_value());
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] DLDataType unused = empty_view;
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot convert from type `None` to `DataType`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  DLDataType float32{kDLFloat, 32, 1};
+
+  AnyView float32_view = float32;
+  DLDataType float32_back = float32_view;
+  EXPECT_EQ(float32_back.code, kDLFloat);
+  EXPECT_EQ(float32_back.bits, 32);
+  EXPECT_EQ(float32_back.lanes, 1);
+
+  Any owned = DLDataType{kDLInt, 16, 2};
+  TVMFFIAny raw;
+  owned.MoveToTVMFFIAny(&raw);
+  EXPECT_EQ(raw.type_index, TypeIndex::kTVMFFIDataType);
+  EXPECT_EQ(raw.v_dtype.code, kDLInt);
+  EXPECT_EQ(raw.v_dtype.bits, 16);
+  EXPECT_EQ(raw.v_dtype.lanes, 2);
+}
+
+TEST(Any, Device) {
+  AnyView empty_view;
+  EXPECT_EQ(empty_view.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
+
+  std::optional<DLDevice> maybe_device = empty_view.TryAs<DLDevice>();
+  EXPECT_FALSE(maybe_device.has_value());
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] DLDevice unused = empty_view;
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot convert from type `None` to `Device`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  DLDevice cuda1{kDLCUDA, 1};
+
+  AnyView cuda1_view = cuda1;
+  DLDevice cuda1_back = cuda1_view;
+  EXPECT_EQ(cuda1_back.device_type, kDLCUDA);
+  EXPECT_EQ(cuda1_back.device_id, 1);
+
+  Any owned = DLDevice{kDLCPU, 0};
+  TVMFFIAny raw;
+  owned.MoveToTVMFFIAny(&raw);
+  EXPECT_EQ(raw.type_index, TypeIndex::kTVMFFIDevice);
+  EXPECT_EQ(raw.v_device.device_type, kDLCPU);
+  EXPECT_EQ(raw.v_device.device_id, 0);
+}
+
+TEST(Any, DLTensor) {
+  AnyView empty_view;
+
+  std::optional<DLTensor*> maybe_tensor = empty_view.TryAs<DLTensor*>();
+  EXPECT_FALSE(maybe_tensor.has_value());
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] DLTensor* unused = empty_view;
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot convert from type `None` to `DLTensor*`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  DLTensor tensor{};  // contents are irrelevant: only the address is round-tripped
+
+  AnyView tensor_view = &tensor;
+  DLTensor* tensor_back = tensor_view;
+  EXPECT_EQ(tensor_back, &tensor);
+}
+
 TEST(Any, Object) {
   AnyView view0;
   EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc
index 367e613665..ff098d70c9 100644
--- a/ffi/tests/cpp/test_function.cc
+++ b/ffi/tests/cpp/test_function.cc
@@ -112,6 +112,10 @@ TEST(Func, FromUnpacked) {
         }
       },
       ::tvm::ffi::Error);
+
+  Function fconcact =
+      Function::FromUnpacked([](const String& a, const String& b) -> String { 
return a + b; });
+  EXPECT_EQ(fconcact("abc", "def").operator String(), "abcdef");
 }
 
 TEST(Func, Global) {
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index c7901a4ca3..9de0015009 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -20,6 +20,7 @@
 #include <gtest/gtest.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection.h>
+
 #include "./testing_object.h"
 
 namespace {
@@ -38,12 +39,9 @@ TEST(Reflection, GetFieldByteOffset) {
   EXPECT_EQ(details::GetFieldByteOffset(&A::y), 12);
 }
 
-
 TEST(Reflection, FieldGetter) {
   ObjectRef a = TInt(10);
-  details::ReflectionFieldGetter getter(
-    details::GetReflectionFieldInfo("test.Int", "value")
-  );
+  details::ReflectionFieldGetter 
getter(details::GetReflectionFieldInfo("test.Int", "value"));
   EXPECT_EQ(getter(a).operator int(), 10);
 }
 }  // namespace

Reply via email to