This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 1ae001e7012910ac2ff76d2b344832b19014ed78 Author: tqchen <[email protected]> AuthorDate: Sat Aug 17 15:58:44 2024 -0400 [FFI] Any support Co-authored-by: Junru Shao <[email protected]> --- ffi/include/tvm/ffi/any.h | 217 ++++++++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/object.h | 119 ++++++++++++-------- ffi/include/tvm/ffi/type_traits.h | 179 +++++++++++++++++++++++++++++ ffi/tests/example/test_any.cc | 145 ++++++++++++++++++++++++ ffi/tests/example/test_c_ffi_abi.cc | 18 +++ ffi/tests/example/test_error.cc | 18 +++ ffi/tests/example/test_object.cc | 126 ++++++++------------- 7 files changed, 698 insertions(+), 124 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h new file mode 100644 index 0000000000..1568fe98b2 --- /dev/null +++ b/ffi/include/tvm/ffi/any.h @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/any.h + * \brief Any value support. + */ +#ifndef TVM_FFI_ANY_H_ +#define TVM_FFI_ANY_H_ + +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/type_traits.h> + +namespace tvm { +namespace ffi { + +class Any; + +/*! + * \brief AnyView allows us to take un-managed reference view of any value. 
+ */ +class AnyView { + protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + // Any can see AnyView + friend class Any; + + public: + // NOTE: the following two functions use STL naming style + // since they are common functions appearing in FFI. + /*! + * \brief Reset any view to None + */ + void reset() { data_.type_index = TypeIndex::kTVMFFINone; } + /*! + * \brief Swap this view with another AnyView + * \param other The other AnyView + */ + void swap(AnyView& other) noexcept { // NOLINT(*) + std::swap(data_, other.data_); + } + // default constructors + AnyView() { data_.type_index = TypeIndex::kTVMFFINone; } + ~AnyView() = default; + // constructors from any view + AnyView(const AnyView&) = default; + AnyView& operator=(const AnyView&) = default; + AnyView(AnyView&& other) noexcept : data_(other.data_) { other.reset(); } + AnyView& operator=(AnyView&& other) noexcept { + // copy-and-swap idiom + AnyView(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + // constructor from general types; non-owning: objects in `other` must outlive the view + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + AnyView(const T& other) { // NOLINT(*) + TypeTraits<T>::ConvertToAnyView(other, &data_); + } + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + AnyView& operator=(const T& other) { // NOLINT(*) + // copy-and-swap idiom + AnyView(other).swap(*this); // NOLINT(*) + return *this; + } + + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + std::optional<T> TryAs() const { + return TypeTraits<T>::TryConvertFromAnyView(&data_); + } + + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + operator T() const { + std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); + if (opt.has_value()) { + return std::move(opt.value()); + } + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) + << "` to `" <<
TypeTraits<T>::TypeStr() << "`"; + } + // The following functions are only used for testing purposes + /*! + * \return The underlying supporting data of any view + * \note This function is used only for testing purposes. + */ + TVMFFIAny AsTVMFFIAny() const { return data_; } + /*! + * \brief Create an AnyView from TVMFFIAny + * \param data the underlying ffi data. + */ + static AnyView FromTVMFFIAny(TVMFFIAny data) { + AnyView view; + view.data_ = data; + return view; + } +}; + +// layout assert to ensure we can freely cast between the two types +static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); + +namespace details { +/*! + * \brief Helper function to inplace convert any view to any. + * \param data The pointer that represents the format as any view. + * \param extra_any_bytes Indicate that the data may contain extra bytes following + * the TVMFFIAny data structure. This is reserved for future possible optimizations + * of small-string and extended any object. + */ +TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, + [[maybe_unused]] size_t extra_any_bytes = 0) { + // TODO: string conversion; currently only object values get their refcount bumped. + if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectInternal::IncRefObjectInAny(data); + } +} +} // namespace details + +/*! + * \brief Owning any value: holds a reference count on the stored object, if any. + */ +class Any { + protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + + public: + /*! + * \brief Reset any to None + */ + void reset() { + if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectInternal::DecRefObjectInAny(&data_); + } + data_.type_index = TVMFFITypeIndex::kTVMFFINone; + } + /*!
+ * \brief Swap this Any with another Any + * \param other The other Any + */ + void swap(Any& other) noexcept { // NOLINT(*) + std::swap(data_, other.data_); + } + + // default constructors + Any() { data_.type_index = TypeIndex::kTVMFFINone; } + ~Any() { this->reset(); } + // constructors from Any + Any(const Any& other) : data_(other.data_) { + if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectInternal::IncRefObjectInAny(&data_); + } + } + Any(Any&& other) noexcept : data_(other.data_) { other.data_.type_index = TypeIndex::kTVMFFINone; } + Any& operator=(const Any& other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + Any& operator=(Any&& other) noexcept { + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + // convert from/to AnyView + Any(const AnyView& other) : data_(other.data_) { details::InplaceConvertAnyViewToAny(&data_); } + Any& operator=(const AnyView& other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + /*! \brief Any can be converted to AnyView in zero cost.
*/ + operator AnyView() const { return AnyView::FromTVMFFIAny(data_); } + // constructor from general types + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + Any(T other) { // NOLINT(*) + TypeTraits<T>::MoveToManagedAny(std::move(other), &data_); + } + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + Any& operator=(T other) { // NOLINT(*) + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + std::optional<T> TryAs() const { + return TypeTraits<T>::TryConvertFromAnyView(&data_); + } + + template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> + operator T() const { + std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); + if (opt.has_value()) { + return std::move(opt.value()); + } + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) + << "` to `" << TypeTraits<T>::TypeStr() << "`"; + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 5a6b552e1c..a1547b1900 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -71,6 +71,19 @@ TVM_FFI_DLL int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t stat */ TVM_FFI_DLL const TypeInfo* ObjectGetTypeInfo(int32_t type_index); #endif // TVM_FFI_ALLOW_DYN_TYPE + +/*! + * \brief Check if the type_index is an instance of TargetType. + * + * \tparam TargetType The target object type to be checked. + * + * \param object_type_index The type index to be checked, caller + * ensures that the index is already within the object index range. + * + * \return Whether object_type_index is TargetType or one of its subclasses. + */ +template <typename TargetType> +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); } // namespace details /*!
@@ -130,38 +143,7 @@ class Object { */ template <typename TargetType> bool IsInstance() const { - // Everything is a subclass of object. - if constexpr (std::is_same<TargetType, Object>::value) return true; - - if constexpr (TargetType::_type_final) { - // if the target type is a final type - // then we only need to check the equivalence. - return header_.type_index == TargetType::RuntimeTypeIndex(); - } - - // if target type is a non-leaf type - // Check if type index falls into the range of reserved slots. - int32_t target_type_index = TargetType::RuntimeTypeIndex(); - int32_t begin = target_type_index; - // The condition will be optimized by constant-folding. - if constexpr (TargetType::_type_child_slots != 0) { - int32_t end = begin + TargetType::_type_child_slots; - if (header_.type_index >= begin && header_.type_index < end) return true; - } else { - if (header_.type_index == begin) return true; - } - if (!TargetType::_type_child_slots_can_overflow) return false; - // Invariance: parent index is always smaller than the child. - if (header_.type_index < target_type_index) return false; - // Do a runtime lookup of type information -#if TVM_FFI_ALLOW_DYN_TYPE - // the function checks that the info exists - const TypeInfo* type_info = details::ObjectGetTypeInfo(header_.type_index); - return (type_info->type_depth > TargetType::_type_depth && - type_info->type_acenstors[TargetType::_type_depth] == target_type_index); -#else - return false; -#endif + return details::IsObjectInstance<TargetType>(header_.type_index); // same check Any conversion uses } // Information about the object @@ -551,21 +533,42 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); namespace details { -// auxiliary class to enable static type info table at depth -template <int depth> -struct TypeInfoAtDepth : public TypeInfo { - /*!
\brief extra type acenstors fields */ - int32_t _type_acenstors[depth]; - - TypeInfoAtDepth(const char* type_key, int32_t static_type_index) { - this->type_key = type_key; - this->type_key_hash = 0; - this->type_index = static_type_index; - this->type_depth = depth; - this->type_acenstors = _type_acenstors; +template <typename TargetType> +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { + static_assert(std::is_base_of_v<Object, TargetType>); // needs the _type_* traits of Object subclasses + // Everything is a subclass of object. + if constexpr (std::is_same<TargetType, Object>::value) return true; + + if constexpr (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return object_type_index == TargetType::RuntimeTypeIndex(); } -}; + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + int32_t target_type_index = TargetType::RuntimeTypeIndex(); + int32_t begin = target_type_index; + // The condition will be optimized by constant-folding. + if constexpr (TargetType::_type_child_slots != 0) { + int32_t end = begin + TargetType::_type_child_slots; + if (object_type_index >= begin && object_type_index < end) return true; + } else { + if (object_type_index == begin) return true; + } + if (!TargetType::_type_child_slots_can_overflow) return false; + // Invariant: a parent's type index is always smaller than its children's. + if (object_type_index < target_type_index) return false; + // Reserved slots may have overflowed: fall back to a runtime lookup of the ancestor table +#if TVM_FFI_ALLOW_DYN_TYPE + // the function checks that the info exists + const TypeInfo* type_info = details::ObjectGetTypeInfo(object_type_index); + return (type_info->type_depth > TargetType::_type_depth && + type_info->type_acenstors[TargetType::_type_depth] == target_type_index); +#else + return false; // overflowed children cannot be resolved without the dynamic type table +#endif +} /*! \brief Namespace to internally manipulate object class.
* \note These functions are only supposed to be used by internal @@ -582,6 +585,31 @@ struct ObjectInternal { static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) { return tvm::ffi::ObjectPtr<T>(raw_ptr); } + + template <typename T> + static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { // NOTE(review): assumes the header sits at offset 0 of Object (inverse of GetHeader) — confirm + return tvm::ffi::ObjectPtr<T>(reinterpret_cast<Object*>(obj_ptr)); + } + + // Interactions with Any system: adjust the refcount of the object held in a TVMFFIAny + static TVM_FFI_INLINE void DecRefObjectInAny(TVMFFIAny* src) { + reinterpret_cast<Object*>(src->v_obj)->DecRef(); + } + + static TVM_FFI_INLINE void IncRefObjectInAny(TVMFFIAny* src) { + reinterpret_cast<Object*>(src->v_obj)->IncRef(); + } + + static TVM_FFI_INLINE TVMFFIObject* GetTVMFFIObjectPtrFromObjectRef(const ObjectRef& src) { // borrowed: refcount unchanged + return GetHeader(src.data_.data_); + } + + static TVM_FFI_INLINE TVMFFIObject* MoveTVMFFIObjectPtrFromObjectRef(ObjectRef* src) { // takes over src's reference; src is left null + Object* obj_ptr = src->data_.data_; + src->data_.data_ = nullptr; + return GetHeader(obj_ptr); + } + // Create objectptr by moving from an existing address of object and setting its // address to nullptr template <typename T> @@ -592,7 +620,6 @@ struct ObjectInternal { return ptr; } }; - } // namespace details } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h new file mode 100644 index 0000000000..0a1d7d9941 --- /dev/null +++ b/ffi/include/tvm/ffi/type_traits.h @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/type_traits.h + * \brief Type traits describing conversions between C++ values and the FFI Any. + */ +#ifndef TVM_FFI_TYPE_TRAITS_H_ +#define TVM_FFI_TYPE_TRAITS_H_ + +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/internal_utils.h> +#include <tvm/ffi/object.h> + +#include <optional> +#include <type_traits> + +namespace tvm { +namespace ffi { + +/*! + * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. + * + * Enabled specializations (enabled = true) implement: + * - void ConvertToAnyView(const T& src, TVMFFIAny* result); + * Convert a value to AnyView (no reference count change). + * - void MoveToManagedAny(T src, TVMFFIAny* result); + * Move a value into a managed Any (object ownership is transferred). + * - std::optional<T> TryConvertFromAnyView(const TVMFFIAny* src); + * Try to convert AnyView to a value type. + * - std::string TypeStr(); + * Type name used in conversion error messages.
*/ +template <typename, typename = void> +struct TypeTraits { + static constexpr bool enabled = false; +}; + +// Integer POD values +template <typename Int> +struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(const Int& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIInt; + result->v_int64 = static_cast<int64_t>(src); + } + + static TVM_FFI_INLINE void MoveToManagedAny(Int src, TVMFFIAny* result) { + ConvertToAnyView(src, result); + } + + static TVM_FFI_INLINE std::optional<Int> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIInt) { + return std::make_optional<Int>(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "int"; } +}; + +// Float POD values +template <typename Float> +struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(const Float& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIFloat; + result->v_float64 = static_cast<double>(src); + } + + static TVM_FFI_INLINE void MoveToManagedAny(Float src, TVMFFIAny* result) { + ConvertToAnyView(src, result); + } + + static TVM_FFI_INLINE std::optional<Float> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIFloat) { + return std::make_optional<Float>(src->v_float64); + } else if (src->type_index == TypeIndex::kTVMFFIInt) { + return std::make_optional<Float>(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "float"; } +}; + +// Traits for object +template <typename TObjRef> +struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef>>> { + using ContainerType = typename TObjRef::ContainerType; + + static constexpr bool enabled = true; + + static TVM_FFI_INLINE
void ConvertToAnyView(const TObjRef& src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectInternal::GetTVMFFIObjectPtrFromObjectRef(src); + result->type_index = obj_ptr != nullptr ? obj_ptr->type_index : TypeIndex::kTVMFFINone; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void MoveToManagedAny(TObjRef src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectInternal::MoveTVMFFIObjectPtrFromObjectRef(&src); + result->type_index = obj_ptr != nullptr ? obj_ptr->type_index : TypeIndex::kTVMFFINone; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE std::optional<TObjRef> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { +#if TVM_FFI_ALLOW_DYN_TYPE + if (details::IsObjectInstance<ContainerType>(src->type_index)) { + return TObjRef(details::ObjectInternal::ObjectPtrFromUnowned<Object>(src->v_obj)); + } +#else + TVM_FFI_THROW(RuntimeError) + << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on"; +#endif + } else if (src->type_index == TypeIndex::kTVMFFINone) { + if (!TObjRef::_type_is_nullable) return std::nullopt; + return TObjRef(ObjectPtr<Object>()); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; } +}; + +/*!
+ * \brief Get type key from type index + * \param type_index The input type index + * \return the type key; object indices use the registered key ("object.Object" without dyn types) + */ +inline std::string TypeIndex2TypeKey(int32_t type_index) { + switch (type_index) { + case TypeIndex::kTVMFFINone: + return "None"; + case TypeIndex::kTVMFFIInt: + return "int"; + case TypeIndex::kTVMFFIFloat: + return "double"; + case TypeIndex::kTVMFFIOpaquePtr: + return "void*"; + case TypeIndex::kTVMFFIDataType: + return "DataType"; + case TypeIndex::kTVMFFIDevice: + return "Device"; + case TypeIndex::kTVMFFIRawStr: + return "const char*"; + default: { + TVM_FFI_ICHECK_GE(type_index, TypeIndex::kTVMFFIStaticObjectBegin) + << "Unknown type_index=" << type_index; +#if TVM_FFI_ALLOW_DYN_TYPE + const TypeInfo* type_info = details::ObjectGetTypeInfo(type_index); + return type_info->type_key; +#else + return "object.Object"; +#endif + } + } +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc new file mode 100644 index 0000000000..55ef21405e --- /dev/null +++ b/ffi/tests/example/test_any.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +#include <gtest/gtest.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/memory.h> + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Any, Int) { + AnyView view0; + EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + std::optional<int64_t> opt_v0 = view0.TryAs<int64_t>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] int64_t v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `int`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1 = 1; + EXPECT_EQ(view1.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view1.AsTVMFFIAny().v_int64, 1); + + int32_t int_v1 = view1; + EXPECT_EQ(int_v1, 1); + + int64_t v1 = 2; + view0 = v1; + EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view0.AsTVMFFIAny().v_int64, 2); +} + +TEST(Any, Float) { + AnyView view0; + EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + std::optional<double> opt_v0 = view0.TryAs<double>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] double v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `float`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1_int = 1; // int values are accepted when reading back as float + float float_v1 = view1_int; + EXPECT_EQ(float_v1, 1); + + AnyView view2 = 2.2; + EXPECT_EQ(view2.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + EXPECT_EQ(view2.AsTVMFFIAny().v_float64, 2.2); + + float v1 = 2; + view0 = v1; + EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + EXPECT_EQ(view0.AsTVMFFIAny().v_float64, 2); +} + +TEST(Any, Object) { + AnyView view0; +
EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + // int object is not nullable, so None does not convert to TInt + std::optional<TInt> opt_v0 = view0.TryAs<TInt>(); + EXPECT_TRUE(!opt_v0.has_value()); + + TInt v1(11); + EXPECT_EQ(v1.use_count(), 1); + // view won't increase refcount + AnyView view1 = v1; + EXPECT_EQ(v1.use_count(), 1); + // any will trigger ref count increase + Any any1 = v1; + EXPECT_EQ(v1.use_count(), 2); + // copy to another view + AnyView view2 = any1; + EXPECT_EQ(v1.use_count(), 2); + + // convert that triggers error + EXPECT_THROW( + { + try { + [[maybe_unused]] TFloat v0 = view1; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + std::cout << what; // NOTE(review): leftover debug print; <iostream> is only included transitively — consider removing + EXPECT_NE(what.find("Cannot convert from type `test.Int` to `test.Float`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + // Try to convert to number + TNumber number0 = any1; + EXPECT_EQ(v1.use_count(), 3); + EXPECT_TRUE(number0.as<TIntObj>()); + EXPECT_EQ(number0.as<TIntObj>()->value, 11); + EXPECT_TRUE(!any1.TryAs<int>().has_value()); + + TInt int1 = view2; + EXPECT_EQ(v1.use_count(), 4); + any1.reset(); + EXPECT_EQ(v1.use_count(), 3); +} + +} // namespace diff --git a/ffi/tests/example/test_c_ffi_abi.cc b/ffi/tests/example/test_c_ffi_abi.cc index d936247653..9cbc7e67e1 100644 --- a/ffi/tests/example/test_c_ffi_abi.cc +++ b/ffi/tests/example/test_c_ffi_abi.cc @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include <gtest/gtest.h> #include <tvm/ffi/c_api.h> diff --git a/ffi/tests/example/test_error.cc b/ffi/tests/example/test_error.cc index 015551c953..bf43f97ad5 100644 --- a/ffi/tests/example/test_error.cc +++ b/ffi/tests/example/test_error.cc @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include <gtest/gtest.h> #include <tvm/ffi/error.h> diff --git a/ffi/tests/example/test_object.cc b/ffi/tests/example/test_object.cc index 4395fccf52..f931c99419 100644 --- a/ffi/tests/example/test_object.cc +++ b/ffi/tests/example/test_object.cc @@ -1,65 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include <gtest/gtest.h> #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> +#include "./testing_object.h" // presumably declares the TNumber/TInt/TFloat fixtures that replace the classes removed below + namespace { using namespace tvm::ffi; - -class NumberObj : public Object { - public: - // declare as one slot, with float as overflow - static constexpr uint32_t _type_child_slots = 1; - static constexpr const char* _type_key = "test.Number"; - TVM_FFI_DECLARE_BASE_OBJECT_INFO(NumberObj, Object); -}; - -class Number : public ObjectRef { - public: - TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Number, ObjectRef, NumberObj); -}; - -class IntObj : public NumberObj { - public: - int64_t value; - - IntObj(int64_t value) : value(value) {} - - static constexpr const char* _type_key = "test.Int"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(IntObj, NumberObj); -}; - -class Int : public Number { - public: - explicit Int(int64_t value) { - data_ = make_object<IntObj>(value); - } - - TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Int, Number, IntObj); -}; - -class FloatObj : public NumberObj { - public: - double value; - - FloatObj(double value) : value(value) {} - - static constexpr const char* _type_key = "test.Float"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(FloatObj, NumberObj); -}; - -class Float : public Number { - public: - explicit Float(double value) { - data_ = make_object<FloatObj>(value); - } - - TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Float, Number, FloatObj); -}; +using namespace
tvm::ffi::testing; TEST(Object, RefCounter) { - ObjectPtr<IntObj> a = make_object<IntObj>(11); - ObjectPtr<IntObj> b = a; + ObjectPtr<TIntObj> a = make_object<TIntObj>(11); + ObjectPtr<TIntObj> b = a; EXPECT_EQ(a->value, 11); @@ -69,7 +39,7 @@ TEST(Object, RefCounter) { EXPECT_TRUE(b == nullptr); EXPECT_EQ(b.use_count(), 0); - ObjectPtr<IntObj> c = std::move(a); + ObjectPtr<TIntObj> c = std::move(a); EXPECT_EQ(c.use_count(), 1); EXPECT_TRUE(a == nullptr); @@ -77,50 +47,50 @@ TEST(Object, RefCounter) { } TEST(Object, TypeInfo) { - const TypeInfo* info = tvm::ffi::details::ObjectGetTypeInfo(IntObj::RuntimeTypeIndex()); + const TypeInfo* info = tvm::ffi::details::ObjectGetTypeInfo(TIntObj::RuntimeTypeIndex()); EXPECT_TRUE(info != nullptr); - EXPECT_EQ(info->type_index, IntObj::RuntimeTypeIndex()); + EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex()); EXPECT_EQ(info->type_depth, 2); EXPECT_EQ(info->type_acenstors[0], Object::_type_index); - EXPECT_EQ(info->type_acenstors[1], NumberObj::_type_index); + EXPECT_EQ(info->type_acenstors[1], TNumberObj::_type_index); EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin); } TEST(Object, InstanceCheck) { - ObjectPtr<Object> a = make_object<IntObj>(11); - ObjectPtr<Object> b = make_object<FloatObj>(11); + ObjectPtr<Object> a = make_object<TIntObj>(11); + ObjectPtr<Object> b = make_object<TFloatObj>(11); EXPECT_TRUE(a->IsInstance<Object>()); - EXPECT_TRUE(a->IsInstance<NumberObj>()); - EXPECT_TRUE(a->IsInstance<IntObj>()); - EXPECT_TRUE(!a->IsInstance<FloatObj>()); + EXPECT_TRUE(a->IsInstance<TNumberObj>()); + EXPECT_TRUE(a->IsInstance<TIntObj>()); + EXPECT_TRUE(!a->IsInstance<TFloatObj>()); EXPECT_TRUE(a->IsInstance<Object>()); - EXPECT_TRUE(b->IsInstance<NumberObj>()); - EXPECT_TRUE(!b->IsInstance<IntObj>()); - EXPECT_TRUE(b->IsInstance<FloatObj>()); + EXPECT_TRUE(b->IsInstance<TNumberObj>()); + EXPECT_TRUE(!b->IsInstance<TIntObj>()); + EXPECT_TRUE(b->IsInstance<TFloatObj>()); } TEST(ObjectRef, as) { - ObjectRef
a = Int(10); - ObjectRef b = Float(20); + ObjectRef a = TInt(10); + ObjectRef b = TFloat(20); // nullable object ObjectRef c(nullptr); - EXPECT_TRUE(a.as<IntObj>() != nullptr); - EXPECT_TRUE(a.as<FloatObj>() == nullptr); - EXPECT_TRUE(a.as<NumberObj>() != nullptr); + EXPECT_TRUE(a.as<TIntObj>() != nullptr); + EXPECT_TRUE(a.as<TFloatObj>() == nullptr); + EXPECT_TRUE(a.as<TNumberObj>() != nullptr); - EXPECT_TRUE(b.as<IntObj>() == nullptr); - EXPECT_TRUE(b.as<FloatObj>() != nullptr); - EXPECT_TRUE(b.as<NumberObj>() != nullptr); + EXPECT_TRUE(b.as<TIntObj>() == nullptr); + EXPECT_TRUE(b.as<TFloatObj>() != nullptr); + EXPECT_TRUE(b.as<TNumberObj>() != nullptr); - EXPECT_TRUE(c.as<IntObj>() == nullptr); - EXPECT_TRUE(c.as<FloatObj>() == nullptr); - EXPECT_TRUE(c.as<NumberObj>() == nullptr); + EXPECT_TRUE(c.as<TIntObj>() == nullptr); + EXPECT_TRUE(c.as<TFloatObj>() == nullptr); + EXPECT_TRUE(c.as<TNumberObj>() == nullptr); - EXPECT_EQ(a.as<IntObj>()->value, 10); - EXPECT_EQ(b.as<FloatObj>()->value, 20); + EXPECT_EQ(a.as<TIntObj>()->value, 10); + EXPECT_EQ(b.as<TFloatObj>()->value, 20); } } // namespace
