This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit dfbf0e587ebc112f2bfa5a8a2bd36a597c849ec9
Author: tqchen <[email protected]>
AuthorDate: Sun Aug 18 17:14:29 2024 -0400

    [FFI] Function support
    
    Co-authored-by: Junru Shao <[email protected]>
---
 ffi/include/tvm/ffi/any.h              |  30 ++-
 ffi/include/tvm/ffi/c_api.h            |  16 ++
 ffi/include/tvm/ffi/function.h         | 341 +++++++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/function_details.h | 262 +++++++++++++++++++++++++
 ffi/include/tvm/ffi/internal_utils.h   |  26 ++-
 ffi/include/tvm/ffi/type_traits.h      |   9 +-
 ffi/src/ffi/traceback.h                |   6 +
 ffi/tests/example/test_any.cc          |   1 -
 ffi/tests/example/test_function.cc     | 123 ++++++++++++
 ffi/tests/example/testing_object.h     |  82 ++++++++
 10 files changed, 886 insertions(+), 10 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 1568fe98b2..6f2935f8a6 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -55,6 +55,10 @@ class AnyView {
   void swap(AnyView& other) {  // NOLINT(*)
     std::swap(data_, other.data_);
   }
+  /*! \return the internal type index */
+  int32_t type_index() const {
+    return data_.type_index;
+  }
   // default constructors
   AnyView() { data_.type_index = TypeIndex::kTVMFFINone; }
   ~AnyView() = default;
@@ -88,7 +92,7 @@ class AnyView {
   operator T() const {
     std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
     if (opt.has_value()) {
-      return std::move(opt.value());
+      return std::move(*opt);
     }
     TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << 
TypeIndex2TypeKey(data_.type_index)
                              << "` to `" << TypeTraits<T>::TypeStr() << "`";
@@ -110,9 +114,6 @@ class AnyView {
   }
 };
 
-// layout assert to ensure we can freely cast between the two types
-static_assert(sizeof(AnyView) == sizeof(TVMFFIAny));
-
 namespace details {
 /*!
  * \brief Helper function to inplace convert any view to any.
@@ -155,7 +156,10 @@ class Any {
   void swap(Any& other) {  // NOLINT(*)
     std::swap(data_, other.data_);
   }
-
+  /*! \return the internal type index */
+  int32_t type_index() const {
+    return data_.type_index;
+  }
   // default constructors
   Any() { data_.type_index = TypeIndex::kTVMFFINone; }
   ~Any() { this->reset(); }
@@ -205,13 +209,27 @@ class Any {
   operator T() const {
     std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
     if (opt.has_value()) {
-      return std::move(opt.value());
+      return std::move(*opt);
     }
     TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << 
TypeIndex2TypeKey(data_.type_index)
                              << "` to `" << TypeTraits<T>::TypeStr() << "`";
   }
+
+  // FFI related operations
+  /*!
+   * \brief Move the current data to FFI any
+   * \param result the output to move to
+   */
+  void MoveToTVMFFIAny(TVMFFIAny* result) {
+    *result = data_;
+    data_.type_index = TypeIndex::kTVMFFINone;
+  }
 };
 
+// layout assert to ensure we can freely cast between the two types
+static_assert(sizeof(AnyView) == sizeof(TVMFFIAny));
+static_assert(sizeof(Any) == sizeof(TVMFFIAny));
+
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_ANY_H_
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index e0597d0dda..65341feb83 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -164,6 +164,22 @@ typedef struct {
   const int32_t* type_acenstors;
 } TVMFFITypeInfo;
 
+/*!
+ * \brief Type that defines C-style safe call convention
+ *
+ * Safe call explicitly catches exception on function boundary.
+ *
+ * \param func The function handle
+ * \param num_args Number of input arguments
+ * \param args The input arguments to the call.
+ * \param result Store output result
+ *
+ * \return The call returns 0 if the call is successful.
+ *  It returns non-zero value if there is an error.
+ *  When error happens, the exception object will be stored in result.
+ */
+typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const 
TVMFFIAny* args, TVMFFIAny* result);
+
 #ifdef __cplusplus
 }  // TVM_FFI_EXTERN_C
 #endif
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
new file mode 100644
index 0000000000..81b0d8fab2
--- /dev/null
+++ b/ffi/include/tvm/ffi/function.h
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/function.h
+ * \brief A managed function in the TVM FFI.
+ */
+#ifndef TVM_FFI_FUNCTION_H_
+#define TVM_FFI_FUNCTION_H_
+
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/internal_utils.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/function_details.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Object container class that backs ffi::Function
+ * \note Do not use this class directly, use ffi::Function
+ */
+class FunctionObj : public Object {
+ public:
+  typedef void (*FCall)(const FunctionObj*, int32_t, const AnyView* , Any*);
+  /*! \brief A C++ style call implementation */
+  FCall call;
+  /*! \brief A C API compatible call with exception catching. */
+  TVMFFISafeCallType safe_call;
+
+  TVM_FFI_INLINE void CallPacked(int32_t num_args, const AnyView* args, Any* 
result) const {
+    this->call(this, num_args, args, result);
+  }
+
+  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunc;
+  static constexpr const char* _type_key = "object.Function";
+
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object);
+
+ protected:
+  /*! \brief Make default constructor protected. */
+  FunctionObj() {}
+
+  // Implementing safe call style
+  static int32_t SafeCall(void* func, int32_t num_args, const TVMFFIAny* args, 
TVMFFIAny* result) {
+    FunctionObj* self = static_cast<FunctionObj*>(func);
+    try {
+      self->call(self, num_args, reinterpret_cast<const AnyView*>(args), 
reinterpret_cast<Any*>(result));
+      return 0;
+    } catch (const tvm::ffi::Error& err) {
+      Any(std::move(err)).MoveToTVMFFIAny(result);
+      return 1;
+    } catch (const std::runtime_error& err) {
+     Any(
+      tvm::ffi::Error("RuntimeError", err.what(), "")
+     ).MoveToTVMFFIAny(result);
+      return 1;
+    }
+    TVM_FFI_UNREACHABLE();
+  }
+
+  friend class Function;
+};
+
+namespace details {
+/*!
+ * \brief Derived object class for constructing FunctionObj backed by a 
TCallable
+ *
+ * This is a helper class that stores the callable and forwards packed calls to it.
+ */
+template <typename TCallable>
+class FunctionObjImpl : public FunctionObj {
+ public:
+  using TStorage = typename std::remove_cv<typename 
std::remove_reference<TCallable>::type>::type;
+  /*! \brief The type of derived object class */
+  using TSelf = FunctionObjImpl<TCallable>;
+  /*!
+   * \brief Construct the function object from a callable.
+   * \param callable The callable object to store and invoke on each call.
+   */
+  explicit FunctionObjImpl(TCallable callable)
+      : callable_(callable) {
+    this->call = Call;
+    this->safe_call = SafeCall;
+  }
+
+ private:
+  // implementation of call
+  static void Call(const FunctionObj* func, int32_t num_args, const AnyView* 
args, Any* result) {
+    (static_cast<const TSelf*>(func))->callable_(num_args, args, result);
+  }
+
+  /*! \brief Type-erased field for storing callable object*/
+  mutable TStorage callable_;
+};
+
+/*!
+ * \brief Base class to provide a common implementation to redirect call to 
safecall
+ * \tparam Derived The derived class in CRTP-idiom
+ */
+template <typename Derived>
+struct RedirectCallToSafeCall {
+  static void Call(const FunctionObj* func, int32_t num_args, const AnyView* 
args , Any* rv) {
+    Derived* self = static_cast<Derived*>(const_cast<FunctionObj*>(func));
+    int ret_code = self->RedirectSafeCall(
+      num_args, reinterpret_cast<const TVMFFIAny*>(args),
+      reinterpret_cast<TVMFFIAny*>(rv));
+    if (ret_code != 0) {
+      if (std::optional<tvm::ffi::Error> err = rv->TryAs<tvm::ffi::Error>()) {
+        throw std::move(*err);
+      } else {
+        TVM_FFI_THROW(RuntimeError) << "Error encountered when calling a 
tvm::ffi::Function";
+      }
+    }
+  }
+
+  static int32_t SafeCall(void* func, int32_t num_args, const TVMFFIAny* args, 
TVMFFIAny* rv) {
+    Derived* self = reinterpret_cast<Derived*>(func);
+    return self->RedirectSafeCall(num_args, args, rv);
+  }
+};
+
+/*!
+ * \brief FunctionObj specialization that leverages C-style callback 
definitions.
+ */
+class ExternCFunctionObjImpl :
+  public FunctionObj,
+  public RedirectCallToSafeCall<ExternCFunctionObjImpl> {
+ public:
+  using RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
+
+  ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void 
(*deleter)(void* self))
+    : self_(self), safe_call_(safe_call), deleter_(deleter) {
+    this->call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::Call;
+    this->safe_call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
+  }
+
+  ~ExternCFunctionObjImpl() {
+    deleter_(self_);
+  }
+
+  TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* 
args, TVMFFIAny* rv) const {
+    return safe_call_(self_, num_args, args, rv);
+  }
+
+ private:
+  void* self_;
+  TVMFFISafeCallType safe_call_;
+  void (*deleter_)(void* self);
+};
+
+/*!
+ * \brief FunctionObj specialization that wraps an external function.
+ */
+class ImportedFunctionObjImpl :
+  public FunctionObj,
+  public RedirectCallToSafeCall<ImportedFunctionObjImpl> {
+ public:
+  using RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
+
+  explicit ImportedFunctionObjImpl(ObjectPtr<Object> data)
+    : data_(data) {
+    this->call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::Call;
+    this->safe_call = 
RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
+  }
+
+  TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* 
args, TVMFFIAny* rv) const {
+    FunctionObj* func = const_cast<FunctionObj*>(static_cast<const 
FunctionObj*>(data_.get()));
+    return func->safe_call(func, num_args, args, rv);
+  }
+
+ private:
+  ObjectPtr<Object> data_;
+};
+
+// Helper class to set packed arguments
+class PackedArgsSetter {
+ public:
+  PackedArgsSetter(AnyView* args) : args_(args) {}
+
+  // NOTE: setter needs to be very carefully designed
+  // such that we do not have temp variable conversion(eg. convert from lvalue 
to rvalue)
+  // that is why we need T&& and std::forward here
+  template<typename T>
+  TVM_FFI_INLINE void operator()(size_t i, T&& value) const {
+    args_[i].operator=(std::forward<T>(value));
+  }
+
+ private:
+  AnyView* args_;
+};
+}  // namespace details
+
+/*!
+ * \brief ffi::Function  is a type-erased function.
+ *  The arguments are passed by packed format.
+ */
+class Function : public ObjectRef {
+ public:
+  /*! \brief Constructor from null */
+  Function(std::nullptr_t) : ObjectRef(nullptr) {}  // NOLINT(*)
+   /*!
+   * \brief Constructing a packed function from a callable type
+   *        whose signature is consistent with `PackedFunc`
+   * \param packed_call The packed function signature
+   */
+  template <typename TCallable>
+  static Function FromPacked(TCallable packed_call) {
+    static_assert(
+      std::is_convertible_v<TCallable, std::function<void(int32_t, const 
AnyView*, Any*)>>,
+      "tvm::ffi::Function::FromPacked requires input function signature to 
match packed func format"
+    );
+    using ObjType = typename details::FunctionObjImpl<TCallable>;
+    Function func;
+    func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call));
+    return func;
+  }
+  /*!
+   * \brief Import a possibly externally defined function to this dll
+   * \param other Function defined in another dynamic library.
+   *
+   * \note This function will redirect the call to safe_call in other.
+   *  It will try to detect if the function is already from the same DLL
+   *  and directly return the original function if so.
+   *
+   * \return The imported function.
+   */
+  static Function ImportFromExternDLL(Function other) {
+    const FunctionObj* other_func = static_cast<const 
FunctionObj*>(other.get());
+    // the other function comes from the same dll, no action needed
+    if (other_func->safe_call == FunctionObj::SafeCall ||
+        other_func->safe_call == details::ImportedFunctionObjImpl::SafeCall ||
+        other_func->safe_call == details::ExternCFunctionObjImpl::SafeCall) {
+      return other;
+    }
+    // the other function comes from a different library
+    Function func;
+    func.data_ = 
make_object<details::ImportedFunctionObjImpl>(std::move(other.data_));
+    return func;
+  }
+  /*!
+   * \brief Create ffi::Function from a C style callbacks.
+   * \param self Resource handle to the function
+   * \param safe_call The safe_call definition in C.
+   * \param deleter The deleter to release the resource of self.
+   * \return The created function.
+   */
+  static Function FromExternC(void* self, TVMFFISafeCallType safe_call, void 
(*deleter)(void* self)) {
+    // wrap the C-style callback and its resource handle into a function object
+    Function func;
+    func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, 
deleter);
+    return func;
+  }
+  /*!
+   * \brief Constructing a packed function from a normal function.
+   *
+   * \param callable the function or functor to be wrapped.
+   */
+  template <typename TCallable>
+  static Function FromUnpacked(TCallable callable) {
+    using FuncInfo = details::FunctionInfo<TCallable>;
+    auto call_packed = [callable](int32_t num_args, const AnyView* args, Any* 
rv) -> void {
+      details::unpack_call<typename FuncInfo::RetType, 
FuncInfo::num_args>(nullptr, callable, num_args, args, rv);
+    };
+    return FromPacked(call_packed);
+  }
+  /*!
+   * \brief Constructing a packed function from a normal function.
+   *
+   * \param callable the function or functor to be wrapped.
+   * \param name optional name attached to the function.
+   */
+  template <typename TCallable>
+  static Function FromUnpacked(TCallable callable, std::string name) {
+    using FuncInfo = details::FunctionInfo<TCallable>;
+    auto call_packed = [callable, name](int32_t num_args, const AnyView* args, 
Any* rv) -> void {
+      details::unpack_call<typename FuncInfo::RetType, 
FuncInfo::num_args>(&name, callable,
+                                                                           
num_args, args, rv);
+    };
+    return FromPacked(call_packed);
+  }
+  /*!
+   * \brief Call function by directly passing in unpacked arguments.
+   *
+   * \param args Arguments to be passed.
+   * \tparam Args arguments to be passed.
+   *
+   * \code
+   *   // Example code on how to call packed function
+   *   void CallFFIFunction(tvm::ffi::Function f) {
+   *     // call like normal functions by pass in arguments
+   *     // return value is automatically converted back
+   *     int rvalue = f(1, 2.0);
+   *   }
+   * \endcode
+   */
+  template <typename... Args>
+  TVM_FFI_INLINE Any operator()(Args&&... args) const {
+    const int kNumArgs = sizeof...(Args);
+    const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
+    AnyView args_pack[kArraySize];
+    details::for_each(details::PackedArgsSetter(args_pack), 
std::forward<Args>(args)...);
+    Any result;
+    static_cast<FunctionObj*>(data_.get())->CallPacked(kNumArgs, args_pack, 
&result);
+    return result;
+  }
+  /*!
+   * \brief Call the function in packed format.
+   * \param args The arguments
+   * \param result The return value.
+   */
+  TVM_FFI_INLINE void CallPacked(int32_t num_args, const AnyView* args, Any* 
result) const {
+   static_cast<FunctionObj*>(data_.get())->CallPacked(num_args, args, result);
+  }
+  /*! \return Whether the packed function is nullptr */
+  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
+  /*! \return Whether the packed function is not nullptr */
+  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
+
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj);
+};
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_FUNCTION_H_
diff --git a/ffi/include/tvm/ffi/function_details.h 
b/ffi/include/tvm/ffi/function_details.h
new file mode 100644
index 0000000000..f22258029d
--- /dev/null
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/function_details.h
+ * \brief Implements the function signature reflection
+ */
+#ifndef TVM_FFI_FUNCTION_DETAILS_H_
+#define TVM_FFI_FUNCTION_DETAILS_H_
+
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/internal_utils.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/error.h>
+
+namespace tvm {
+namespace ffi {
+namespace details {
+
+template <typename Type>
+struct Type2Str {
+  static std::string v() {
+    return TypeTraitsNoCR<Type>::TypeStr();
+  }
+};
+
+template <>
+struct Type2Str<Any> {
+  static const char* v() {
+    return "Any";
+  }
+};
+
+template <>
+struct Type2Str<AnyView> {
+  static const char* v() {
+    return "AnyView";
+  }
+};
+template <>
+struct Type2Str<void> {
+  static const char* v() {
+    return "void";
+  }
+};
+
+template <typename ArgType>
+struct Arg2Str {
+  template <size_t i>
+  static TVM_FFI_INLINE void Apply(std::ostream& os) {
+    using Arg = std::tuple_element_t<i, ArgType>;
+    if constexpr (i != 0) {
+      os << ", ";
+    }
+    os << i << ": " << Type2Str<Arg>::v();
+  }
+  template <size_t... I>
+  static TVM_FFI_INLINE void Run(std::ostream &os, std::index_sequence<I...>) {
+    using TExpander = int[];
+    (void)TExpander{0, (Apply<I>(os), 0)...};
+  }
+};
+
+template <typename T>
+static constexpr bool ArgSupported = (
+  std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, Any> ||
+  std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> ||
+  TypeTraitsNoCR<T>::enabled
+);
+
+// NOTE: return type can only support non-reference managed returns
+template <typename T>
+static constexpr bool RetSupported = (
+    std::is_same_v<T, Any> || std::is_void_v<T> || TypeTraits<T>::enabled
+);
+
+template <typename R, typename... Args>
+struct FuncFunctorImpl {
+  using FType = R(Args...);
+  using ArgType = std::tuple<Args...>;
+  using RetType = R;
+  /*! \brief total number of arguments*/
+  static constexpr size_t num_args = sizeof...(Args);
+  /*! \brief Whether this function can be converted to ffi::Function via 
FromUnpacked */
+  static constexpr bool unpacked_supported = (ArgSupported<Args> && ...) && 
(RetSupported<R>);
+
+  static TVM_FFI_INLINE std::string Sig() {
+    using IdxSeq = std::make_index_sequence<sizeof...(Args)>;
+    std::ostringstream ss;
+    ss << "(";
+    Arg2Str<std::tuple<Args...>>::Run(ss, IdxSeq{});
+    ss << ") -> " << Type2Str<R>::v();
+    return ss.str();
+  }
+};
+
+template <typename T>
+struct FunctionInfoHelper;
+
+template <typename T, typename R, typename... Args>
+struct FunctionInfoHelper<R (T::*)(Args...)>: FuncFunctorImpl<R, Args...> {};
+template <typename T, typename R, typename... Args>
+struct FunctionInfoHelper<R (T::*)(Args...) const>: FuncFunctorImpl<R, 
Args...> {};
+
+/*!
+ * \brief Template class to get function signature of a function or functor.
+ * \tparam T The function/functor type.
+ * \note We need a decltype redirection because this helps lambda types.
+ */
+template <typename T>
+struct FunctionInfo : FunctionInfoHelper<decltype(&T::operator())> {};
+
+template <typename R, typename... Args>
+struct FunctionInfo<R(Args...)> : FuncFunctorImpl<R, Args...> {};
+template <typename R, typename... Args>
+struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+
+/*! \brief Using static function to output TypedPackedFunc signature */
+typedef std::string (*FGetFuncSignature)();
+
+template<typename T>
+TVM_FFI_INLINE std::optional<T> TryAs(AnyView arg) {
+  return arg.TryAs<T>();
+}
+template<>
+TVM_FFI_INLINE std::optional<Any> TryAs<Any>(AnyView arg) {
+  return Any(arg);
+}
+template<>
+TVM_FFI_INLINE std::optional<AnyView> TryAs<AnyView>(AnyView arg) {
+  return arg;
+}
+
+/*!
+ * \brief Auxiliary argument value with context for error reporting
+ */
+class MovableArgValueWithContext {
+ public:
+  /*!
+   * \brief Construct an argument value that refers to one entry of the argument list.
+   * \param args The argument list
+   * \param arg_index In a function call, this argument is at index arg_index 
(0-indexed).
+   * \param optional_name Name of the function being called.
+   *        Can be nullptr if the function is not named.
+   * \param f_sig Pointer to static function outputting signature of the function being called.
+   */
+  TVM_FFI_INLINE MovableArgValueWithContext(
+    const AnyView* args, int32_t arg_index,
+    const std::string* optional_name,
+    FGetFuncSignature f_sig)
+      : args_(args),
+        arg_index_(arg_index),
+        optional_name_(optional_name),
+        f_sig_(f_sig) {}
+
+  template <typename Type>
+  TVM_FFI_INLINE operator Type() {
+    using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>;
+    std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]);
+    if (opt.has_value()) {
+      return std::move(*opt);
+    }
+    TVM_FFI_THROW(TypeError)
+      << "Mismatched type on argument #" << arg_index_ << " when calling: `"
+      << (optional_name_ == nullptr ? "" : *optional_name_)
+      << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `"
+      << Type2Str<Type>::v() << "` but got `"
+      << TypeIndex2TypeKey(args_[arg_index_].type_index()) << "`";
+  }
+
+ private:
+  const AnyView* args_;
+  int32_t arg_index_;
+  const std::string* optional_name_;
+  FGetFuncSignature f_sig_;
+};
+
+template <typename R, int nleft, int index, typename F>
+struct unpack_call_dispatcher {
+  template <typename... Args>
+  TVM_FFI_INLINE static void run(const std::string* optional_name, 
FGetFuncSignature f_sig, const F& f,
+                                 int32_t num_args, const AnyView* args, Any* 
rv,
+                                 Args&&... unpacked_args) {
+    // construct a movable argument value
+    // which allows potential move of argument to the input of F.
+    unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
+        optional_name, f_sig, f, num_args, args, rv, 
std::forward<Args>(unpacked_args)...,
+        MovableArgValueWithContext(args, index, optional_name, f_sig));
+  }
+};
+
+template <typename R, int index, typename F>
+struct unpack_call_dispatcher<R, 0, index, F> {
+  template <typename... Args>
+  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const 
F& f,
+                                 int32_t, const AnyView*, Any* rv,
+                                 Args&&... unpacked_args) {
+    using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
+    if constexpr (std::is_same_v<RetType, R>) {
+      *rv = f(std::forward<Args>(unpacked_args)...);
+    } else {
+      *rv = R(f(std::forward<Args>(unpacked_args)...));
+    }
+  }
+};
+
+template <int index, typename F>
+struct unpack_call_dispatcher<void, 0, index, F> {
+  template <typename... Args>
+  TVM_FFI_INLINE static void run(const std::string* optional_name, 
FGetFuncSignature f_sig, const F& f,
+                                    int32_t num_args, const AnyView* args, 
Any* rv,
+                                    Args&&... unpacked_args) {
+    f(std::forward<Args>(unpacked_args)...);
+  }
+};
+
+template <typename R, int nargs, typename F>
+TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f,
+                                int32_t num_args, const AnyView* args, Any* 
rv) {
+  using FuncInfo = FunctionInfo<F>;
+  FGetFuncSignature f_sig = FuncInfo::Sig;
+  static_assert(FuncInfo::unpacked_supported, "The function signature cannot 
support unpacked call");
+  if (nargs != num_args) {
+    TVM_FFI_THROW(TypeError)
+      << "Mismatched number of arguments when calling: `"
+      << (optional_name == nullptr ? "" : *optional_name)
+      << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected "
+      << nargs << " but got " << num_args << " arguments";
+  }
+  unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, 
num_args, args, rv);
+}
+
+template <typename FType>
+struct unpack_call_by_signature {};
+
+template <typename R, typename... Args>
+struct unpack_call_by_signature<R(Args...)> {
+  template <typename F>
+  TVM_FFI_INLINE static void run(const F& f, int32_t num_args, const AnyView* 
args, Any* rv) {
+    unpack_call<R, sizeof...(Args)>(nullptr, f, num_args, args, rv);
+  }
+};
+
+}  // namespace details
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_FUNCTION_DETAILS_H_
diff --git a/ffi/include/tvm/ffi/internal_utils.h 
b/ffi/include/tvm/ffi/internal_utils.h
index 900b7fd2e2..ffe1dd9e1a 100644
--- a/ffi/include/tvm/ffi/internal_utils.h
+++ b/ffi/include/tvm/ffi/internal_utils.h
@@ -17,8 +17,8 @@
  * under the License.
  */
 /*!
- * \file tvm/ffi/internal_utils.h
- * \brief Utility functions and macros for internal use
+ * \file tvm/ffi/base_details.h
+ * \brief Internal use utilities
  */
 #ifndef TVM_FFI_INTERNAL_UTILS_H_
 #define TVM_FFI_INTERNAL_UTILS_H_
@@ -26,6 +26,7 @@
 #include <tvm/ffi/c_api.h>
 
 #include <cstddef>
+#include <utility>
 
 #if defined(_MSC_VER)
 #define TVM_FFI_INLINE __forceinline
@@ -110,6 +111,27 @@ TVM_FFI_INLINE int32_t AtomicLoadRelaxed(const int32_t* 
ptr) {
   return __atomic_load_n(raw_ptr, __ATOMIC_RELAXED);
 #endif
 }
+
+// for each iterator
+template <bool stop, std::size_t I, typename F>
+struct for_each_dispatcher {
+  template <typename T, typename... Args>
+  static void run(const F& f, T&& value, Args&&... args) {  // NOLINT(*)
+    f(I, std::forward<T>(value));
+    for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(f, 
std::forward<Args>(args)...);
+  }
+};
+
+template <std::size_t I, typename F>
+struct for_each_dispatcher<true, I, F> {
+  static void run(const F&) {}  // NOLINT(*)
+};
+
+template <typename F, typename... Args>
+void for_each(const F& f, Args&&... args) {  // NOLINT(*)
+  for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, 
std::forward<Args>(args)...);
+}
+
 }  // namespace details
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index 0a1d7d9941..18c7ab2a35 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -52,6 +52,13 @@ struct TypeTraits {
   static constexpr bool enabled = false;
 };
 
+/*!
+ * \brief TypeTraits that removes const and reference keywords.
+ * \tparam T the original type
+ */
+template <typename T>
+using TypeTraitsNoCR = 
TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>;
+
 // Integer POD values
 template <typename Int>
 struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> {
@@ -153,7 +160,7 @@ inline std::string TypeIndex2TypeKey(int32_t type_index) {
     case TypeIndex::kTVMFFIInt:
       return "int";
     case TypeIndex::kTVMFFIFloat:
-      return "double";
+      return "float";
     case TypeIndex::kTVMFFIOpaquePtr:
       return "void*";
     case TypeIndex::kTVMFFIDataType:
diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h
index 2bc9e523a7..e4ba2a589d 100644
--- a/ffi/src/ffi/traceback.h
+++ b/ffi/src/ffi/traceback.h
@@ -57,6 +57,12 @@ inline bool ShouldExcludeFrame(const char* filename, const 
char* symbol) {
     if (strstr(filename, "include/tvm/ffi/error.h")) {
       return true;
     }
+    if (strstr(filename, "include/tvm/ffi/function_details.h")) {
+      return true;
+    }
+    if (strstr(filename, "include/tvm/ffi/function.h")) {
+      return true;
+    }
     if (strstr(filename, "src/ffi/traceback.cc")) {
       return true;
     }
diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc
index 55ef21405e..791ecfe68c 100644
--- a/ffi/tests/example/test_any.cc
+++ b/ffi/tests/example/test_any.cc
@@ -122,7 +122,6 @@ TEST(Any, Object) {
         } catch (const Error& error) {
           EXPECT_EQ(error->kind, "TypeError");
           std::string what = error.what();
-          std::cout << what;
           EXPECT_NE(what.find("Cannot convert from type `test.Int` to 
`test.Float`"),
                     std::string::npos);
           throw;
diff --git a/ffi/tests/example/test_function.cc 
b/ffi/tests/example/test_function.cc
new file mode 100644
index 0000000000..46c69c3739
--- /dev/null
+++ b/ffi/tests/example/test_function.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/memory.h>
+#include <tvm/ffi/function.h>
+
+#include "./testing_object.h"
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::testing;
+
+TEST(Func, FromPacked) {
+  Function fadd1 = Function::FromPacked(
+    [](int32_t num_args, const AnyView* args, Any* rv) {
+      EXPECT_EQ(num_args, 1);
+      int32_t a = args[0];
+      *rv = a + 1;
+    }
+  );
+  int b = fadd1(1);
+  EXPECT_EQ(b, 2);
+
+  Function fadd2 = Function::FromPacked(
+    [](int32_t num_args, const AnyView* args, Any* rv) {
+      EXPECT_EQ(num_args, 1);
+      TInt a = args[0];
+      EXPECT_EQ(a.use_count(), 2);
+      *rv = a->value + 1;
+    }
+  );
+  EXPECT_EQ(fadd2(TInt(12)).operator int(), 13);
+}
+
+TEST(Func, FromUnpacked) {
+  // try deduction
+  Function fadd1 = Function::FromUnpacked(
+    [](const int32_t& a) -> int { return a +  1; }
+  );
+  int b = fadd1(1);
+  EXPECT_EQ(b, 2);
+
+   // argument type conversion failure triggers error
+  EXPECT_THROW(
+      {
+        try {
+         fadd1(1.1);
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          EXPECT_STREQ(
+            error->message.c_str(),
+            "Mismatched type on argument #0 when calling: `(0: int) -> int`. "
+            "Expected `int` but got `float`");
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+   // mismatched number of arguments triggers error
+   EXPECT_THROW(
+      {
+        try {
+         fadd1();
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          EXPECT_STREQ(
+            error->message.c_str(),
+            "Mismatched number of arguments when calling: `(0: int) -> int`. "
+            "Expected 1 but got 0 arguments");
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  // try deduction
+   Function fpass_and_return = Function::FromUnpacked(
+       [](TInt x, int value, AnyView z) -> Function {
+         EXPECT_EQ(x.use_count(), 2);
+         EXPECT_EQ(x->value, value);
+         if (auto opt = z.TryAs<int>()) {
+           EXPECT_EQ(value, *opt);
+         }
+         return Function::FromUnpacked([value](int x) -> int { return x + 
value; });
+       },
+       "fpass_and_return");
+   TInt a(11);
+   Function fret = fpass_and_return(std::move(a), 11, 11);
+   EXPECT_EQ(fret(12).operator int(), 23);
+
+   EXPECT_THROW(
+       {
+         try {
+           fpass_and_return();
+         } catch (const Error& error) {
+           EXPECT_EQ(error->kind, "TypeError");
+           EXPECT_STREQ(error->message.c_str(),
+                        "Mismatched number of arguments when calling: "
+                        "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> 
object.Function`. "
+                        "Expected 3 but got 0 arguments");
+           throw;
+         }
+       },
+       ::tvm::ffi::Error);
+}
+}  // namespace
diff --git a/ffi/tests/example/testing_object.h 
b/ffi/tests/example/testing_object.h
new file mode 100644
index 0000000000..00cc58b7b7
--- /dev/null
+++ b/ffi/tests/example/testing_object.h
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_FFI_TESTING_OBJECT_H_
+#define TVM_FFI_TESTING_OBJECT_H_
+#include <tvm/ffi/memory.h>
+#include <tvm/ffi/object.h>
+
+namespace tvm {
+namespace ffi {
+namespace testing {
+
+class TNumberObj : public Object {
+ public:
+  // declare as one slot, with float as overflow
+  static constexpr uint32_t _type_child_slots = 1;
+  static constexpr const char* _type_key = "test.Number";
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TNumberObj, Object);
+};
+
+class TNumber : public ObjectRef {
+ public:
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TNumber, ObjectRef, TNumberObj);
+};
+
+class TIntObj : public TNumberObj {
+ public:
+  int64_t value;
+
+  TIntObj(int64_t value) : value(value) {}
+
+  static constexpr const char* _type_key = "test.Int";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj);
+};
+
+class TInt : public TNumber {
+ public:
+  explicit TInt(int64_t value) {
+    data_ = make_object<TIntObj>(value);
+  }
+
+  TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj);
+};
+
+class TFloatObj : public TNumberObj {
+ public:
+  double value;
+
+  TFloatObj(double value) : value(value) {}
+
+  static constexpr const char* _type_key = "test.Float";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
+};
+
+class TFloat : public TNumber {
+ public:
+  explicit TFloat(double value) {
+    data_ = make_object<TFloatObj>(value);
+  }
+
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj);
+};
+}  // namespace testing
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_TESTING_OBJECT_H_
\ No newline at end of file


Reply via email to