This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 00fd84b40d73097a06c684fc195b3e058a3ec603 Author: tqchen <[email protected]> AuthorDate: Sun Sep 15 16:06:14 2024 -0400 [FFI] Global function registry --- ffi/include/tvm/ffi/any.h | 20 -------- ffi/include/tvm/ffi/c_api.h | 36 +++++++++++++- ffi/include/tvm/ffi/cast.h | 80 +++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/function.h | 98 +++++++++++++++++++++++++++++++++++-- ffi/include/tvm/ffi/object.h | 12 ++++- ffi/src/ffi/function.cc | 106 ++++++++++++++++++++++++++++++++++++++++- ffi/src/ffi/object.cc | 16 ++++--- ffi/src/ffi/traceback.cc | 4 -- ffi/tests/cpp/test_function.cc | 16 +++++++ ffi/tests/cpp/test_string.cc | 1 + 10 files changed, 350 insertions(+), 39 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 225bbfd21b..f7073d20e4 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -341,26 +341,6 @@ struct AnyEqual { return false; } }; - -// Downcast an object -// NOTE: the implementation is put in here to avoid cyclic dependency -// with the -template <typename SubRef, typename BaseRef, - typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>> -TVM_FFI_INLINE SubRef Downcast(BaseRef ref) { - if (ref.defined()) { - if (!ref->template IsInstance<typename SubRef::ContainerType>()) { - TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; - } - } else { - if (!SubRef::_type_is_nullable) { - TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable reference of " - << SubRef::ContainerType::_type_key; - } - } - return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref)); -} } // namespace ffi } // namespace tvm #endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index e6d3449ac3..c580410581 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -188,6 +188,41 @@ typedef struct { typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* result); +/*! + * \brief Create a FFIFunc by passing in callbacks from C callback. + * + * The registered function then can be pulled by the backend by the name. + * + * \param self The resource handle of the C callback. + * \param safe_call The C callback implementation + * \param deleter deleter to recycle + * \param out The output of the function. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self), TVMFFIObjectHandle* out); + +/*! + * \brief Register the function to runtime's global table. + * + * The registered function then can be pulled by the backend by the name. + * + * \param name The name of the function. + * \param f The function to be registered. + * \param override Whether allow override already registered function. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override); + +/*! + * \brief Get a global function. + * + * \param name The name of the function. + * \param out the result function pointer, NULL if it does not exist. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out); + /*! * \brief Free an object handle by decreasing reference * \param obj The object handle. @@ -211,7 +246,6 @@ TVM_FFI_DLL void TVMFFIMoveFromLastError(TVMFFIAny* result); * * \param error_view The error in format of any view. * It can be an object, or simply a raw c_str. - * \note */ TVM_FFI_DLL void TVMFFISetLastError(const TVMFFIAny* error_view); diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h new file mode 100644 index 0000000000..3cdb620d22 --- /dev/null +++ b/ffi/include/tvm/ffi/cast.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/cast.h + * \brief Value casting support + */ +#ifndef TVM_FFI_CAST_H_ +#define TVM_FFI_CAST_H_ + +#include <tvm/ffi/error.h> +#include <tvm/ffi/object.h> + +namespace tvm { +namespace ffi { +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the object alive beyond the scope of the function. + * + * \param ptr The object pointer + * \tparam RefType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template <typename RefType, typename ObjectType> +inline RefType GetRef(const ObjectType* ptr) { + static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>, + "Can only cast to the ref of same container type"); + if (!RefType::_type_is_nullable) { + TVM_FFI_ICHECK_NOTNULL(ptr); + } + return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>( + const_cast<Object*>(static_cast<const Object*>(ptr)))); +} + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template <typename SubRef, typename BaseRef, + typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>> +inline SubRef Downcast(BaseRef ref) { + if (ref.defined()) { + if (!ref->template IsInstance<typename SubRef::ContainerType>()) { + TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } + } else { + if (!SubRef::_type_is_nullable) { + TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } + } + return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref)); +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CAST_H_ diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index cb4cb2b937..680758159a 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -246,11 +246,11 @@ class Function : public ObjectRef { public: /*! \brief Constructor from null */ Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` - * \param packed_call The packed function signature - */ + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + */ template <typename TCallable> static Function FromPacked(TCallable packed_call) { static_assert( @@ -299,6 +299,31 @@ class Function : public ObjectRef { func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter); return func; } + /*! + * \brief Get global function by name + * \param name The function name + * \return The global function. + */ + static Function GetGlobal(const char* name) { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFuncGetGlobal(name, &handle)); + if (handle != nullptr) { + return Function( + details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle))); + } else { + return Function(); + } + } + /*! + * \brief Set global function by name + * \param name The name of the function + * \param func The function + * \param override Whether to override when there is duplication. + */ + static void SetGlobal(const char* name, Function func, bool override = false) { + TVM_FFI_CHECK_SAFE_CALL( + TVMFFIFuncSetGlobal(name, details::ObjectUnsafe::GetHeader(func.get()), override)); + } /*! * \brief Constructing a packed function from a normal function. * @@ -367,8 +392,71 @@ class Function : public ObjectRef { bool operator!=(std::nullptr_t) const { return data_ != nullptr; } TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj); + + class Registry; }; +/*! \brief Registry for global function */ +class Function::Registry { + public: + /*! \brief constructor */ + explicit Registry(const char* name) : name_(name) {} + /*! + * \brief set the body of the function to the given function. + * Note that this will ignore default arg values and always require all arguments to be + * provided. + * + * \code + * + * int multiply(int x, int y) { + * return x * y; + * } + * + * TVM_REGISTER_GLOBAL("multiply") + * .set_body_typed(multiply); // will have type int(int, int) + * + * // will have type int(int, int) + * TVM_REGISTER_GLOBAL("sub") + * .set_body_typed([](int a, int b) -> int { return a - b; }); + * + * \endcode + * + * \param f The function to forward to. + * \tparam FLambda The signature of the function. + */ + template <typename FLambda> + Registry& set_body_typed(FLambda f) { + return Register(Function::FromUnpacked(f)); + } + + protected: + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + Registry& Register(Function f) { + Function::SetGlobal(name_, f); + return *this; + } + + /*! \brief name of the function */ + const char* name_; +}; + +#define TVM_FFI_FUNC_REG_VAR_DEF \ + static TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::Function::Registry& __mk_##TVMFFI + +/*! + * \brief Register a function globally. + * \code + * TVM_FFI_REGISTER_GLOBAL("MyAdd") + * .set_body_typed([](int a, int b) { + * return a + b; + * }); + * \endcode + */ +#define TVM_FFI_REGISTER_GLOBAL(OpName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) } // namespace ffi } // namespace tvm #endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index dbb940ab45..6a202ba1f5 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -545,6 +545,13 @@ struct ObjectUnsafe { return const_cast<TVMFFIObject*>(&(src->header_)); } + template <typename T> + static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) { + tvm::ffi::ObjectPtr<T> ptr; + ptr.data_ = raw_ptr; + return ptr; + } + // Create ObjectPtr from unknowned ptr template <typename T> static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) { @@ -556,7 +563,10 @@ struct ObjectUnsafe { return tvm::ffi::ObjectPtr<T>(reinterpret_cast<Object*>(obj_ptr)); } - // Interactions with Any system + static TVM_FFI_INLINE void DecRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast<Object*>(handle)->DecRef(); + } + static TVM_FFI_INLINE void DecRefObjectInAny(TVMFFIAny* src) { reinterpret_cast<Object*>(src->v_obj)->DecRef(); } diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index 5da4e0178b..f306af09f6 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -22,9 +22,14 @@ */ #include <tvm/ffi/any.h> #include <tvm/ffi/c_api.h> +#include <tvm/ffi/cast.h> +#include <tvm/ffi/container/array.h> #include <tvm/ffi/error.h> +#include <tvm/ffi/function.h> #include <tvm/ffi/string.h> +#include <unordered_map> + namespace tvm { namespace ffi { @@ -49,10 +54,99 @@ class SafeCallContext { Any last_error_; }; +/*! + * \brief Global function table. + * + + * \note We do not use mutex to guard updating of GlobalFunctionTable + * + * The assumption is that updating of GlobalFunctionTable will be done + * in the main thread during initialization or loading, or + * explicitly locked from the caller. + * + * Then the followup code will leverage the information + */ +class GlobalFunctionTable { + public: + void Update(const String& name, Function func, bool can_override) { + if (table_.count(name)) { + if (!can_override) { + TVM_FFI_THROW(RuntimeError) << "Global Function `" << name << "` is already registered"; + } + } + table_[name] = new Function(func); + } + + bool Remove(const String& name) { + auto it = table_.find(name); + if (it == table_.end()) return false; + table_.erase(it); + return true; + } + + const Function* Get(const String& name) { + auto it = table_.find(name); + if (it == table_.end()) return nullptr; + return it->second; + } + + Array<String> ListNames() const { + Array<String> names; + names.reserve(table_.size()); + for (const auto& kv : table_) { + names.push_back(kv.first); + } + return names; + } + + static GlobalFunctionTable* Global() { + // We deliberately create a new instance via raw new + // This is because GlobalFunctionTable can contain callbacks into + // the host language (Python) and the resource can become invalid + // indeterministic order of destruction and forking. + // The resources will only be recycled during program exit. + static GlobalFunctionTable* inst = new GlobalFunctionTable(); + return inst; + } + + private: + // deliberately track function pointer without recycling + // to avoid + std::unordered_map<String, Function*> table_; +}; } // namespace ffi } // namespace tvm -extern "C" { +int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), + TVMFFIObjectHandle* out) { + using namespace tvm::ffi; + TVM_FFI_SAFE_CALL_BEGIN(); + Function func = Function::FromExternC(self, safe_call, deleter); + *out = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&func); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override) { + using namespace tvm::ffi; + TVM_FFI_SAFE_CALL_BEGIN(); + GlobalFunctionTable::Global()->Update(name, GetRef<Function>(static_cast<FunctionObj*>(f)), + override != 0); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out) { + using namespace tvm::ffi; + TVM_FFI_SAFE_CALL_BEGIN(); + const Function* fp = GlobalFunctionTable::Global()->Get(name); + if (fp != nullptr) { + Function func(*fp); + *out = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&func); + } else { + *out = nullptr; + } + TVM_FFI_SAFE_CALL_END(); +} + void TVMFFISetLastError(const TVMFFIAny* error_view) { tvm::ffi::SafeCallContext::ThreadLocal()->SetLastError(error_view); } @@ -60,4 +154,12 @@ void TVMFFISetLastError(const TVMFFIAny* error_view) { void TVMFFIMoveFromLastError(TVMFFIAny* result) { tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromLastError(result); } -} + +TVM_FFI_REGISTER_GLOBAL("tvm_ffi.GlobalFunctionRemove") + .set_body_typed([](const tvm::ffi::String& name) -> bool { + return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); + }); + +TVM_FFI_REGISTER_GLOBAL("tvm_ffi.GlobalFunctionListNames").set_body_typed([]() { + return tvm::ffi::GlobalFunctionTable::Global()->ListNames(); +}); diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index d0782d1b50..8941753df8 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -34,12 +34,13 @@ namespace tvm { namespace ffi { /*! - * \brief Type context that manages the type hierarchy information. + * \brief Global registry that manages * - * \note We do not use mutex to guard updating of TypeContext + * \note We do not use mutex to guard updating of TypeTable * - * The assumption is that updating of TypeContext will be done - * in the main thread during initialization or loading. + * The assumption is that updating of TypeTable will be done + * in the main thread during initialization or loading, or + * explicitly locked from the caller. * * Then the followup code will leverage the information */ @@ -223,7 +224,11 @@ class TypeTable { } // namespace ffi } // namespace tvm -extern "C" { +int TVMFFIObjectFree(TVMFFIObjectHandle handle) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); + TVM_FFI_SAFE_CALL_END(); +} int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t static_type_index, int32_t type_depth, int32_t num_child_slots, @@ -240,4 +245,3 @@ const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index); TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); } -} // extern "C" diff --git a/ffi/src/ffi/traceback.cc b/ffi/src/ffi/traceback.cc index 70de055719..6de3efe7ca 100644 --- a/ffi/src/ffi/traceback.cc +++ b/ffi/src/ffi/traceback.cc @@ -149,15 +149,12 @@ __attribute__((constructor)) void install_signal_handler(void) { } // namespace ffi } // namespace tvm -extern "C" { const char* TVMFFITraceback(const char*, const char*, int) { static thread_local std::string traceback_str; traceback_str = ::tvm::ffi::Traceback(); return traceback_str.c_str(); } -} // extern "C" #else -extern "C" { // fallback implementation simply print out the last trace const char* TVMFFITraceback(const char* filename, const char* func, int lineno) { static thread_local std::string traceback_str; @@ -167,6 +164,5 @@ const char* TVMFFITraceback(const char* filename, const char* func, int lineno) traceback_str = traceback_stream.str(); return traceback_str.c_str(); } -} // extern "C" #endif // TVM_FFI_USE_LIBBACKTRACE #endif // _MSC_VER diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc index 99e7819b23..367e613665 100644 --- a/ffi/tests/cpp/test_function.cc +++ b/ffi/tests/cpp/test_function.cc @@ -18,6 +18,7 @@ */ #include <gtest/gtest.h> #include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> #include <tvm/ffi/function.h> #include <tvm/ffi/memory.h> @@ -112,4 +113,19 @@ TEST(Func, FromUnpacked) { }, ::tvm::ffi::Error); } + +TEST(Func, Global) { + Function::SetGlobal("testing.add1", + Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; })); + Function fadd1 = Function::GetGlobal("testing.add1"); + int b = fadd1(1); + EXPECT_EQ(b, 2); + Function fnot_exist = Function::GetGlobal("testing.not_existing_func"); + EXPECT_TRUE(fnot_exist == nullptr); + + Array<String> names = Function::GetGlobal("tvm_ffi.GlobalFunctionListNames")(); + + EXPECT_TRUE(std::find(names.begin(), names.end(), "testing.add1") != names.end()); +} + } // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index 23310824a6..04348292b7 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -18,6 +18,7 @@ */ #include <gtest/gtest.h> #include <tvm/ffi/any.h> +#include <tvm/ffi/cast.h> #include <tvm/ffi/string.h> namespace {
