This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit dfbf0e587ebc112f2bfa5a8a2bd36a597c849ec9 Author: tqchen <[email protected]> AuthorDate: Sun Aug 18 17:14:29 2024 -0400 [FFI] Function support Co-authored-by: Junru Shao <[email protected]> --- ffi/include/tvm/ffi/any.h | 30 ++- ffi/include/tvm/ffi/c_api.h | 16 ++ ffi/include/tvm/ffi/function.h | 341 +++++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/function_details.h | 262 +++++++++++++++++++++++++ ffi/include/tvm/ffi/internal_utils.h | 26 ++- ffi/include/tvm/ffi/type_traits.h | 9 +- ffi/src/ffi/traceback.h | 6 + ffi/tests/example/test_any.cc | 1 - ffi/tests/example/test_function.cc | 123 ++++++++++++ ffi/tests/example/testing_object.h | 82 ++++++++ 10 files changed, 886 insertions(+), 10 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 1568fe98b2..6f2935f8a6 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -55,6 +55,10 @@ class AnyView { void swap(AnyView& other) { // NOLINT(*) std::swap(data_, other.data_); } + /*! \return the internal type index */ + int32_t type_index() const { + return data_.type_index; + } // default constructors AnyView() { data_.type_index = TypeIndex::kTVMFFINone; } ~AnyView() = default; @@ -88,7 +92,7 @@ class AnyView { operator T() const { std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); if (opt.has_value()) { - return std::move(opt.value()); + return std::move(*opt); } TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) << "` to `" << TypeTraits<T>::TypeStr() << "`"; @@ -110,9 +114,6 @@ class AnyView { } }; -// layout assert to ensure we can freely cast between the two types -static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); - namespace details { /*! * \brief Helper function to inplace convert any view to any. @@ -155,7 +156,10 @@ class Any { void swap(Any& other) { // NOLINT(*) std::swap(data_, other.data_); } - + /*! \return the internal type index */ + int32_t type_index() const { + return data_.type_index; + } // default constructors Any() { data_.type_index = TypeIndex::kTVMFFINone; } ~Any() { this->reset(); } @@ -205,13 +209,27 @@ class Any { operator T() const { std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); if (opt.has_value()) { - return std::move(opt.value()); + return std::move(*opt); } TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) << "` to `" << TypeTraits<T>::TypeStr() << "`"; } + + // FFI related operations + /*! + * Move the current data to FFI any + * \parma result the output to nmove to + */ + void MoveToTVMFFIAny(TVMFFIAny* result) { + *result = data_; + data_.type_index = TypeIndex::kTVMFFINone; + } }; +// layout assert to ensure we can freely cast between the two types +static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); +static_assert(sizeof(Any) == sizeof(TVMFFIAny)); + } // namespace ffi } // namespace tvm #endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index e0597d0dda..65341feb83 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -164,6 +164,22 @@ typedef struct { const int32_t* type_acenstors; } TVMFFITypeInfo; +/*! + * \brief Type that defines C-style safe call convention + * + * Safe call explicitly catches exception on function boundary. + * + * \param func The function handle + * \param num_args Number if input arguments + * \param args The input arguments to the call. + * \param result Store output result + * + * \return The call return 0 if call is successful. + * It returns non-zero value if there is an error. + * When error happens, the exception object will be stored in result. + */ +typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* result); + #ifdef __cplusplus } // TVM_FFI_EXTERN_C #endif diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h new file mode 100644 index 0000000000..81b0d8fab2 --- /dev/null +++ b/ffi/include/tvm/ffi/function.h @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/function.h + * \brief A managed function in the TVM FFI. + */ +#ifndef TVM_FFI_FUNCTION_H_ +#define TVM_FFI_FUNCTION_H_ + +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/internal_utils.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/function_details.h> + +namespace tvm { +namespace ffi { + +/*! + * \brief Object container class that backs ffi::Function + * \note Do not use this function directly, use ffi::Function + */ +class FunctionObj : public Object { + public: + typedef void (*FCall)(const FunctionObj*, int32_t, const AnyView* , Any*); + /*! \brief A C++ style call implementation */ + FCall call; + /*! \brief A C API compatible call with exception catching. */ + TVMFFISafeCallType safe_call; + + TVM_FFI_INLINE void CallPacked(int32_t num_args, const AnyView* args, Any* result) const { + this->call(this, num_args, args, result); + } + + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunc; + static constexpr const char* _type_key = "object.Function"; + + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object); + + protected: + /*! \brief Make default constructor protected. */ + FunctionObj() {} + + // Implementing safe call style + static int32_t SafeCall(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* result) { + FunctionObj* self = static_cast<FunctionObj*>(func); + try { + self->call(self, num_args, reinterpret_cast<const AnyView*>(args), reinterpret_cast<Any*>(result)); + return 0; + } catch (const tvm::ffi::Error& err) { + Any(std::move(err)).MoveToTVMFFIAny(result); + return 1; + } catch (const std::runtime_error& err) { + Any( + tvm::ffi::Error("RuntimeError", err.what(), "") + ).MoveToTVMFFIAny(result); + return 1; + } + TVM_FFI_UNREACHABLE(); + } + + friend class Function; +}; + +namespace details { +/*! + * \brief Derived object class for constructing FunctionObj backed by a TCallable + * + * This is a helper class that + */ +template <typename TCallable> +class FunctionObjImpl : public FunctionObj { + public: + using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type; + /*! \brief The type of derived object class */ + using TSelf = FunctionObjImpl<TCallable>; + /*! + * \brief Derived object class for constructing PackedFuncObj. + * \param callable The type-erased callable object. + */ + explicit FunctionObjImpl(TCallable callable) + : callable_(callable) { + this->call = Call; + this->safe_call = SafeCall; + } + + private: + // implementation of call + static void Call(const FunctionObj* func, int32_t num_args, const AnyView* args, Any* result) { + (static_cast<const TSelf*>(func))->callable_(num_args, args, result); + } + + /*! \brief Type-erased filed for storing callable object*/ + mutable TStorage callable_; +}; + +/*! + * \brief Base class to provide a common implementation to redirect call to safecall + * \tparam Derived The derived class in CRTP-idiom + */ +template <typename Derived> +struct RedirectCallToSafeCall { + static void Call(const FunctionObj* func, int32_t num_args, const AnyView* args , Any* rv) { + Derived* self = static_cast<Derived*>(const_cast<FunctionObj*>(func)); + int ret_code = self->RedirectSafeCall( + num_args, reinterpret_cast<const TVMFFIAny*>(args), + reinterpret_cast<TVMFFIAny*>(rv)); + if (ret_code != 0) { + if (std::optional<tvm::ffi::Error> err = rv->TryAs<tvm::ffi::Error>()) { + throw std::move(*err); + } else { + TVM_FFI_THROW(RuntimeError) << "Error encountered when calling a tvm::ffi::Function"; + } + } + } + + static int32_t SafeCall(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* rv) { + Derived* self = reinterpret_cast<Derived*>(func); + return self->RedirectSafeCall(num_args, args, rv); + } +}; + +/*! + * \brief FunctionObj specialization that leverages C-style callback definitions. + */ +class ExternCFunctionObjImpl : + public FunctionObj, + public RedirectCallToSafeCall<ExternCFunctionObjImpl> { + public: + using RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall; + + ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) + : self_(self), safe_call_(safe_call), deleter_(deleter) { + this->call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::Call; + this->safe_call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall; + } + + ~ExternCFunctionObjImpl() { + deleter_(self_); + } + + TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, TVMFFIAny* rv) const { + return safe_call_(self_, num_args, args, rv); + } + + private: + void* self_; + TVMFFISafeCallType safe_call_; + void (*deleter_)(void* self); +}; + +/*! + * \brief FunctionObj specialization that wraps an external function. + */ +class ImportedFunctionObjImpl : + public FunctionObj, + public RedirectCallToSafeCall<ImportedFunctionObjImpl> { + public: + using RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall; + + explicit ImportedFunctionObjImpl(ObjectPtr<Object> data) + : data_(data) { + this->call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::Call; + this->safe_call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall; + } + + TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, TVMFFIAny* rv) const { + FunctionObj* func = const_cast<FunctionObj*>(static_cast<const FunctionObj*>(data_.get())); + return func->safe_call(func, num_args, args, rv); + } + + private: + ObjectPtr<Object> data_; +}; + +// Helper class to set packed arguments +class PackedArgsSetter { + public: + PackedArgsSetter(AnyView* args) : args_(args) {} + + // NOTE: setter needs to be very carefully designed + // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) + // that is why we need T&& and std::forward here + template<typename T> + TVM_FFI_INLINE void operator()(size_t i, T&& value) const { + args_[i].operator=(std::forward<T>(value)); + } + + private: + AnyView* args_; +}; +} // namespace details + +/*! + * \brief ffi::Function is a type-erased function. + * The arguments are passed by packed format. + */ +class Function : public ObjectRef { + public: + /*! \brief Constructor from null */ + Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + */ + template <typename TCallable> + static Function FromPacked(TCallable packed_call) { + static_assert( + std::is_convertible_v<TCallable, std::function<void(int32_t, const AnyView*, Any*)>>, + "tvm::ffi::Function::FromPacked requires input function signature to match packed func format" + ); + using ObjType = typename details::FunctionObjImpl<TCallable>; + Function func; + func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call)); + return func; + } + /*! + * \brief Import a possibly externally defined function to this dll + * \param other Function defined in another dynamic library. + * + * \note This function will redirect the call to safe_call in other. + * It will try to detect if the function is already from the same DLL + * and directly return the original function if so. + * + * \return The imported function. + */ + static Function ImportFromExternDLL(Function other) { + const FunctionObj* other_func = static_cast<const FunctionObj*>(other.get()); + // the other function comes from the same dll, no action needed + if (other_func->safe_call == FunctionObj::SafeCall || + other_func->safe_call == details::ImportedFunctionObjImpl::SafeCall || + other_func->safe_call == details::ExternCFunctionObjImpl::SafeCall) { + return other; + } + // the other function coems from a different library + Function func; + func.data_ = make_object<details::ImportedFunctionObjImpl>(std::move(other.data_)); + return func; + } + /*! + * \brief Create ffi::Function from a C style callbacks. + * \param self Resource handle to the function + * \param safe_call The safe_call definition in C. + * \param deleter The deleter to release the resource of self. + * \return The created function. + */ + static Function FromExternC(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) { + // the other function coems from a different library + Function func; + func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter); + return func; + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + */ + template <typename TCallable> + static Function FromUnpacked(TCallable callable) { + using FuncInfo = details::FunctionInfo<TCallable>; + auto call_packed = [callable](int32_t num_args, const AnyView* args, Any* rv) -> void { + details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(nullptr, callable, num_args, args, rv); + }; + return FromPacked(call_packed); + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + * \param name optional name attacked to the function. + */ + template <typename TCallable> + static Function FromUnpacked(TCallable callable, std::string name) { + using FuncInfo = details::FunctionInfo<TCallable>; + auto call_packed = [callable, name](int32_t num_args, const AnyView* args, Any* rv) -> void { + details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(&name, callable, + num_args, args, rv); + }; + return FromPacked(call_packed); + } + /*! + * \brief Call function by directly passing in unpacked arguments. + * + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * + * \code + * // Example code on how to call packed function + * void CallFFIFunction(tvm::ffi::Function f) { + * // call like normal functions by pass in arguments + * // return value is automatically converted back + * int rvalue = f(1, 2.0); + * } + * \endcode + */ + template <typename... Args> + TVM_FFI_INLINE Any operator()(Args&&... args) const { + const int kNumArgs = sizeof...(Args); + const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; + AnyView args_pack[kArraySize]; + details::for_each(details::PackedArgsSetter(args_pack), std::forward<Args>(args)...); + Any result; + static_cast<FunctionObj*>(data_.get())->CallPacked(kNumArgs, args_pack, &result); + return result; + } + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param rv The return value. + */ + TVM_FFI_INLINE void CallPacked(int32_t num_args, const AnyView* args, Any* result) const { + static_cast<FunctionObj*>(data_.get())->CallPacked(num_args, args, result); + } + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj); +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_OBJECT_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h new file mode 100644 index 0000000000..f22258029d --- /dev/null +++ b/ffi/include/tvm/ffi/function_details.h @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/function_details.h + * \brief Implements the funciton signature reflection + */ +#ifndef TVM_FFI_FUNCTION_DETAILS_H_ +#define TVM_FFI_FUNCTION_DETAILS_H_ + +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/internal_utils.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/error.h> + +namespace tvm { +namespace ffi { +namespace details { + +template <typename Type> +struct Type2Str { + static std::string v() { + return TypeTraitsNoCR<Type>::TypeStr(); + } +}; + +template <> +struct Type2Str<Any> { + static const char* v() { + return "Any"; + } +}; + +template <> +struct Type2Str<AnyView> { + static const char* v() { + return "AnyView"; + } +}; +template <> +struct Type2Str<void> { + static const char* v() { + return "void"; + } +}; + +template <typename ArgType> +struct Arg2Str { + template <size_t i> + static TVM_FFI_INLINE void Apply(std::ostream& os) { + using Arg = std::tuple_element_t<i, ArgType>; + if constexpr (i != 0) { + os << ", "; + } + os << i << ": " << Type2Str<Arg>::v(); + } + template <size_t... I> + static TVM_FFI_INLINE void Run(std::ostream &os, std::index_sequence<I...>) { + using TExpander = int[]; + (void)TExpander{0, (Apply<I>(os), 0)...}; + } +}; + +template <typename T> +static constexpr bool ArgSupported = ( + std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, Any> || + std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> || + TypeTraitsNoCR<T>::enabled +); + +// NOTE: return type can only support non-reference managed returns +template <typename T> +static constexpr bool RetSupported = ( + std::is_same_v<T, Any> || std::is_void_v<T> || TypeTraits<T>::enabled +); + +template <typename R, typename... Args> +struct FuncFunctorImpl { + using FType = R(Args...); + using ArgType = std::tuple<Args...>; + using RetType = R; + /*! \brief total number of arguments*/ + static constexpr size_t num_args = sizeof...(Args); + /*! \brief Whether this function can be converted to ffi::Function via FromUnpacked */ + static constexpr bool unpacked_supported = (ArgSupported<Args> && ...) && (RetSupported<R>); + + static TVM_FFI_INLINE std::string Sig() { + using IdxSeq = std::make_index_sequence<sizeof...(Args)>; + std::ostringstream ss; + ss << "("; + Arg2Str<std::tuple<Args...>>::Run(ss, IdxSeq{}); + ss << ") -> " << Type2Str<R>::v(); + return ss.str(); + } +}; + +template <typename T> +struct FunctionInfoHelper; + +template <typename T, typename R, typename... Args> +struct FunctionInfoHelper<R (T::*)(Args...)>: FuncFunctorImpl<R, Args...> {}; +template <typename T, typename R, typename... Args> +struct FunctionInfoHelper<R (T::*)(Args...) const>: FuncFunctorImpl<R, Args...> {}; + +/*! + * \brief Template class to get function signature of a function or functor. + * \tparam T The function/functor type. + * \note We need a decltype redirection because this helps lambda types. + */ +template <typename T> +struct FunctionInfo : FunctionInfoHelper<decltype(&T::operator())> {}; + +template <typename R, typename... Args> +struct FunctionInfo<R(Args...)> : FuncFunctorImpl<R, Args...> {}; +template <typename R, typename... Args> +struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {}; + +/*! \brief Using static function to output TypedPackedFunc signature */ +typedef std::string (*FGetFuncSignature)(); + +template<typename T> +TVM_FFI_INLINE std::optional<T> TryAs(AnyView arg) { + return arg.TryAs<T>(); +} +template<> +TVM_FFI_INLINE std::optional<Any> TryAs<Any>(AnyView arg) { + return Any(arg); +} +template<> +TVM_FFI_INLINE std::optional<AnyView> TryAs<AnyView>(AnyView arg) { + return arg; +} + +/*! + * \brief Auxilary argument value with context for error reporting + */ +class MovableArgValueWithContext { + public: + /*! + * \brief move constructor from another return value. + * \param args The argument list + * \param arg_index In a function call, this argument is at index arg_index (0-indexed). + * \param optional_name Name of the function being called. Can be nullptr if the function is not. + * \param f_sig Pointer to static function outputting signature of the function being called. + * named. + */ + TVM_FFI_INLINE MovableArgValueWithContext( + const AnyView* args, int32_t arg_index, + const std::string* optional_name, + FGetFuncSignature f_sig) + : args_(args), + arg_index_(arg_index), + optional_name_(optional_name), + f_sig_(f_sig) {} + + template <typename Type> + TVM_FFI_INLINE operator Type() { + using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>; + std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]); + if (opt.has_value()) { + return std::move(*opt); + } + TVM_FFI_THROW(TypeError) + << "Mismatched type on argument #" << arg_index_ << " when calling: `" + << (optional_name_ == nullptr ? "" : *optional_name_) + << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" + << Type2Str<Type>::v() << "` but got `" + << TypeIndex2TypeKey(args_[arg_index_].type_index()) << "`"; + } + + private: + const AnyView* args_; + int32_t arg_index_; + const std::string* optional_name_; + FGetFuncSignature f_sig_; +}; + +template <typename R, int nleft, int index, typename F> +struct unpack_call_dispatcher { + template <typename... Args> + TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, const F& f, + int32_t num_args, const AnyView* args, Any* rv, + Args&&... unpacked_args) { + // construct a movable argument value + // which allows potential move of argument to the input of F. + unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run( + optional_name, f_sig, f, num_args, args, rv, std::forward<Args>(unpacked_args)..., + MovableArgValueWithContext(args, index, optional_name, f_sig)); + } +}; + +template <typename R, int index, typename F> +struct unpack_call_dispatcher<R, 0, index, F> { + template <typename... Args> + TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, + int32_t, const AnyView*, Any* rv, + Args&&... unpacked_args) { + using RetType = decltype(f(std::forward<Args>(unpacked_args)...)); + if constexpr (std::is_same_v<RetType, R>) { + *rv = f(std::forward<Args>(unpacked_args)...); + } else { + *rv = R(f(std::forward<Args>(unpacked_args)...)); + } + } +}; + +template <int index, typename F> +struct unpack_call_dispatcher<void, 0, index, F> { + template <typename... Args> + TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, const F& f, + int32_t num_args, const AnyView* args, Any* rv, + Args&&... unpacked_args) { + f(std::forward<Args>(unpacked_args)...); + } +}; + +template <typename R, int nargs, typename F> +TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, + int32_t num_args, const AnyView* args, Any* rv) { + using FuncInfo = FunctionInfo<F>; + FGetFuncSignature f_sig = FuncInfo::Sig; + static_assert(FuncInfo::unpacked_supported, "The function signature cannot support unpacked call"); + if (nargs != num_args) { + TVM_FFI_THROW(TypeError) + << "Mismatched number of arguments when calling: `" + << (optional_name == nullptr ? "" : *optional_name) + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " + << nargs << " but got " << num_args << " arguments"; + } + unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, num_args, args, rv); +} + +template <typename FType> +struct unpack_call_by_signature {}; + +template <typename R, typename... Args> +struct unpack_call_by_signature<R(Args...)> { + template <typename F> + TVM_FFI_INLINE static void run(const F& f, int32_t num_args, const AnyView* args, Any* rv) { + unpack_call<R, sizeof...(Args)>(nullptr, f, num_args, args, rv); + } +}; + +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/internal_utils.h b/ffi/include/tvm/ffi/internal_utils.h index 900b7fd2e2..ffe1dd9e1a 100644 --- a/ffi/include/tvm/ffi/internal_utils.h +++ b/ffi/include/tvm/ffi/internal_utils.h @@ -17,8 +17,8 @@ * under the License. */ /*! - * \file tvm/ffi/internal_utils.h - * \brief Utility functions and macros for internal use + * \file tvm/ffi/base_details.h + * \brief Internal use utilities */ #ifndef TVM_FFI_INTERNAL_UTILS_H_ #define TVM_FFI_INTERNAL_UTILS_H_ @@ -26,6 +26,7 @@ #include <tvm/ffi/c_api.h> #include <cstddef> +#include <utility> #if defined(_MSC_VER) #define TVM_FFI_INLINE __forceinline @@ -110,6 +111,27 @@ TVM_FFI_INLINE int32_t AtomicLoadRelaxed(const int32_t* ptr) { return __atomic_load_n(raw_ptr, __ATOMIC_RELAXED); #endif } + +// for each iterator +template <bool stop, std::size_t I, typename F> +struct for_each_dispatcher { + template <typename T, typename... Args> + static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) + f(I, std::forward<T>(value)); + for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...); + } +}; + +template <std::size_t I, typename F> +struct for_each_dispatcher<true, I, F> { + static void run(const F&) {} // NOLINT(*) +}; + +template <typename F, typename... Args> +void for_each(const F& f, Args&&... args) { // NOLINT(*) + for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...); +} + } // namespace details } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 0a1d7d9941..18c7ab2a35 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -52,6 +52,13 @@ struct TypeTraits { static constexpr bool enabled = false; }; +/*! + * \brief TypeTraits that removes const and reference keywords. + * \tparam T the original type + */ +template <typename T> +using TypeTraitsNoCR = TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>; + // Integer POD values template <typename Int> struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { @@ -153,7 +160,7 @@ inline std::string TypeIndex2TypeKey(int32_t type_index) { case TypeIndex::kTVMFFIInt: return "int"; case TypeIndex::kTVMFFIFloat: - return "double"; + return "float"; case TypeIndex::kTVMFFIOpaquePtr: return "void*"; case TypeIndex::kTVMFFIDataType: diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h index 2bc9e523a7..e4ba2a589d 100644 --- a/ffi/src/ffi/traceback.h +++ b/ffi/src/ffi/traceback.h @@ -57,6 +57,12 @@ inline bool ShouldExcludeFrame(const char* filename, const char* symbol) { if (strstr(filename, "include/tvm/ffi/error.h")) { return true; } + if (strstr(filename, "include/tvm/ffi/function_details.h")) { + return true; + } + if (strstr(filename, "include/tvm/ffi/function.h")) { + return true; + } if (strstr(filename, "src/ffi/traceback.cc")) { return true; } diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc index 55ef21405e..791ecfe68c 100644 --- a/ffi/tests/example/test_any.cc +++ b/ffi/tests/example/test_any.cc @@ -122,7 +122,6 @@ TEST(Any, Object) { } catch (const Error& error) { EXPECT_EQ(error->kind, "TypeError"); std::string what = error.what(); - std::cout << what; EXPECT_NE(what.find("Cannot convert from type `test.Int` to `test.Float`"), std::string::npos); throw; diff --git a/ffi/tests/example/test_function.cc b/ffi/tests/example/test_function.cc new file mode 100644 index 0000000000..46c69c3739 --- /dev/null +++ b/ffi/tests/example/test_function.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <gtest/gtest.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/memory.h> +#include <tvm/ffi/function.h> + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Func, FromPacked) { + Function fadd1 = Function::FromPacked( + [](int32_t num_args, const AnyView* args, Any* rv) { + EXPECT_EQ(num_args, 1); + int32_t a = args[0]; + *rv = a + 1; + } + ); + int b = fadd1(1); + EXPECT_EQ(b, 2); + + Function fadd2 = Function::FromPacked( + [](int32_t num_args, const AnyView* args, Any* rv) { + EXPECT_EQ(num_args, 1); + TInt a = args[0]; + EXPECT_EQ(a.use_count(), 2); + *rv = a->value + 1; + } + ); + EXPECT_EQ(fadd2(TInt(12)).operator int(), 13); +} + +TEST(Func, FromUnpacked) { + // try decution + Function fadd1 = Function::FromUnpacked( + [](const int32_t& a) -> int { return a + 1; } + ); + int b = fadd1(1); + EXPECT_EQ(b, 2); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(1.1); + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + EXPECT_STREQ( + error->message.c_str(), + "Mismatched type on argument #0 when calling: `(0: int) -> int`. " + "Expected `int` but got `float`"); + throw; + } + }, + ::tvm::ffi::Error); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(); + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + EXPECT_STREQ( + error->message.c_str(), + "Mismatched number of arguments when calling: `(0: int) -> int`. " + "Expected 1 but got 0 arguments"); + throw; + } + }, + ::tvm::ffi::Error); + + // try decution + Function fpass_and_return = Function::FromUnpacked( + [](TInt x, int value, AnyView z) -> Function { + EXPECT_EQ(x.use_count(), 2); + EXPECT_EQ(x->value, value); + if (auto opt = z.TryAs<int>()) { + EXPECT_EQ(value, *opt); + } + return Function::FromUnpacked([value](int x) -> int { return x + value; }); + }, + "fpass_and_return"); + TInt a(11); + Function fret = fpass_and_return(std::move(a), 11, 11); + EXPECT_EQ(fret(12).operator int(), 23); + + EXPECT_THROW( + { + try { + fpass_and_return(); + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + EXPECT_STREQ(error->message.c_str(), + "Mismatched number of arguments when calling: " + "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> object.Function`. " + "Expected 3 but got 0 arguments"); + throw; + } + }, + ::tvm::ffi::Error); +} +} // namespace diff --git a/ffi/tests/example/testing_object.h b/ffi/tests/example/testing_object.h new file mode 100644 index 0000000000..00cc58b7b7 --- /dev/null +++ b/ffi/tests/example/testing_object.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_FFI_TESTING_OBJECT_H_ +#define TVM_FFI_TESTING_OBJECT_H_ +#include <tvm/ffi/memory.h> +#include <tvm/ffi/object.h> + +namespace tvm { +namespace ffi { +namespace testing { + +class TNumberObj : public Object { + public: + // declare as one slot, with float as overflow + static constexpr uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "test.Number"; + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TNumberObj, Object); +}; + +class TNumber : public ObjectRef { + public: + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TNumber, ObjectRef, TNumberObj); +}; + +class TIntObj : public TNumberObj { + public: + int64_t value; + + TIntObj(int64_t value) : value(value) {} + + static constexpr const char* _type_key = "test.Int"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); +}; + +class TInt : public TNumber { + public: + explicit TInt(int64_t value) { + data_ = make_object<TIntObj>(value); + } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); +}; + +class TFloatObj : public TNumberObj { + public: + double value; + + TFloatObj(double value) : value(value) {} + + static constexpr const char* _type_key = "test.Float"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj); +}; + +class TFloat : public TNumber { + public: + explicit TFloat(double value) { + data_ = make_object<TFloatObj>(value); + } + + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj); +}; +} // namespace testing +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_TESTING_OBJECT_H_ \ No newline at end of file
