This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 00fd84b40d73097a06c684fc195b3e058a3ec603
Author: tqchen <[email protected]>
AuthorDate: Sun Sep 15 16:06:14 2024 -0400

    [FFI] Global function registry
---
 ffi/include/tvm/ffi/any.h      |  20 --------
 ffi/include/tvm/ffi/c_api.h    |  36 +++++++++++++-
 ffi/include/tvm/ffi/cast.h     |  80 +++++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/function.h |  98 +++++++++++++++++++++++++++++++++++--
 ffi/include/tvm/ffi/object.h   |  12 ++++-
 ffi/src/ffi/function.cc        | 106 ++++++++++++++++++++++++++++++++++++++++-
 ffi/src/ffi/object.cc          |  16 ++++---
 ffi/src/ffi/traceback.cc       |   4 --
 ffi/tests/cpp/test_function.cc |  16 +++++++
 ffi/tests/cpp/test_string.cc   |   1 +
 10 files changed, 350 insertions(+), 39 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 225bbfd21b..f7073d20e4 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -341,26 +341,6 @@ struct AnyEqual {
     return false;
   }
 };
-
-// Downcast an object
-// NOTE: the implementation is put in here to avoid cyclic dependency
-// with the
-template <typename SubRef, typename BaseRef,
-          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
-TVM_FFI_INLINE SubRef Downcast(BaseRef ref) {
-  if (ref.defined()) {
-    if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
-      TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " 
to "
-                               << SubRef::ContainerType::_type_key << " 
failed.";
-    }
-  } else {
-    if (!SubRef::_type_is_nullable) {
-      TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable 
reference of "
-                               << SubRef::ContainerType::_type_key;
-    }
-  }
-  return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref));
-}
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_ANY_H_
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index e6d3449ac3..c580410581 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -188,6 +188,41 @@ typedef struct {
 typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const 
TVMFFIAny* args,
                                   TVMFFIAny* result);
 
+/*!
+ * \brief Create a FFIFunc from a C callback.
+ *
+ * On success, *out receives a handle that owns a reference to the new function.
+ *
+ * \param self The resource handle of the C callback.
+ * \param safe_call The C callback implementation
+ * \param deleter Deleter used to recycle the resource handle self.
+ * \param out The created function handle.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call,
+                                 void (*deleter)(void* self), 
TVMFFIObjectHandle* out);
+
+/*!
+ * \brief Register the function to runtime's global table.
+ *
+ * The registered function can then be retrieved by name via TVMFFIFuncGetGlobal.
+ *
+ * \param name The name of the function.
+ * \param f The function to be registered; f is borrowed and the caller keeps its reference.
+ * \param override Whether to allow overriding an already registered function.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, 
int override);
+
+/*!
+ * \brief Get a global function.
+ *
+ * \param name The name of the function.
+ * \param out The resulting function handle (owning a reference), or NULL if it does not exist.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out);
+
 /*!
  * \brief Free an object handle by decreasing reference
  * \param obj The object handle.
@@ -211,7 +246,6 @@ TVM_FFI_DLL void TVMFFIMoveFromLastError(TVMFFIAny* result);
  *
  * \param error_view The error in format of any view.
  *        It can be an object, or simply a raw c_str.
- * \note
  */
 TVM_FFI_DLL void TVMFFISetLastError(const TVMFFIAny* error_view);
 
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h
new file mode 100644
index 0000000000..3cdb620d22
--- /dev/null
+++ b/ffi/include/tvm/ffi/cast.h
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/cast.h
+ * \brief Value casting support
+ */
+#ifndef TVM_FFI_CAST_H_
+#define TVM_FFI_CAST_H_
+
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/object.h>
+
+namespace tvm {
+namespace ffi {
+/*!
+ * \brief Get a reference of type RefType from a raw object pointer.
+ *
+ *  Use this when the object must be returned as a reference or
+ *  kept alive beyond the scope of the function: a raw pointer
+ *  does not keep the object alive, the returned reference does.
+ *
+ * \param ptr The object pointer; must be non-null unless RefType is nullable.
+ * \tparam RefType The reference type
+ * \tparam ObjectType The object type
+ * \return The corresponding RefType
+ */
+template <typename RefType, typename ObjectType>
+inline RefType GetRef(const ObjectType* ptr) {
+  static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
+                "Can only cast to the ref of same container type");
+  if (!RefType::_type_is_nullable) {  // null is only acceptable for nullable refs
+    TVM_FFI_ICHECK_NOTNULL(ptr);
+  }
+  return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
+      const_cast<Object*>(static_cast<const Object*>(ptr))));
+}
+
+/*!
+ * \brief Downcast a base reference type to a more specific type.
+ *  Throws TypeError if the runtime type check fails.
+ * \param ref The input reference (may be null only if SubRef is nullable)
+ * \return The corresponding SubRef.
+ * \tparam SubRef The target specific reference type.
+ * \tparam BaseRef The current reference type.
+ */
+template <typename SubRef, typename BaseRef,
+          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
+inline SubRef Downcast(BaseRef ref) {
+  if (ref.defined()) {
+    if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
+      TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " 
to "
+                               << SubRef::ContainerType::_type_key << " 
failed.";
+    }
+  } else {  // ref is null
+    if (!SubRef::_type_is_nullable) {
+      TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable 
reference of "
+                               << SubRef::ContainerType::_type_key;
+    }
+  }
+  return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref));  // checked above
+}
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_CAST_H_
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index cb4cb2b937..680758159a 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -246,11 +246,11 @@ class Function : public ObjectRef {
  public:
   /*! \brief Constructor from null */
   Function(std::nullptr_t) : ObjectRef(nullptr) {}  // NOLINT(*)
-                                                    /*!
-                                                     * \brief Constructing a 
packed function from a callable type
-                                                     *        whose signature 
is consistent with `PackedFunc`
-                                                     * \param packed_call The 
packed function signature
-                                                     */
+  /*!
+   * \brief Constructing a packed function from a callable type
+   *        whose signature is consistent with `PackedFunc`
+   * \param packed_call The packed function signature
+   */
   template <typename TCallable>
   static Function FromPacked(TCallable packed_call) {
     static_assert(
@@ -299,6 +299,31 @@ class Function : public ObjectRef {
     func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, 
deleter);
     return func;
   }
+  /*!
+   * \brief Get a global function by name.
+   * \param name The function name
+   * \return The global function, or a null Function if name is not registered.
+   */
+  static Function GetGlobal(const char* name) {
+    TVMFFIObjectHandle handle;  // set by TVMFFIFuncGetGlobal (NULL when not found)
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFIFuncGetGlobal(name, &handle));
+    if (handle != nullptr) {
+      return Function(  // adopt the reference already owned by handle (no IncRef)
+          
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle)));
+    } else {
+      return Function();
+    }
+  }
+  /*!
+   * \brief Register a function in the global table under the given name.
+   * \param name The name of the function
+   * \param func The function to register.
+   * \param override Whether to replace an existing entry (a duplicate name is an error otherwise).
+   */
+  static void SetGlobal(const char* name, Function func, bool override = 
false) {
+    TVM_FFI_CHECK_SAFE_CALL(
+        TVMFFIFuncSetGlobal(name, 
details::ObjectUnsafe::GetHeader(func.get()), override));
+  }
   /*!
    * \brief Constructing a packed function from a normal function.
    *
@@ -367,8 +392,71 @@ class Function : public ObjectRef {
   bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
 
   TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj);
+
+  class Registry;
 };
 
+/*! \brief Helper that registers a global function by name; see TVM_FFI_REGISTER_GLOBAL. */
+class Function::Registry {
+ public:
+  /*! \brief Constructor; only the name pointer is stored, the string is not copied. */
+  explicit Registry(const char* name) : name_(name) {}
+  /*!
+   * \brief Set the body of the function to the given typed callable.
+   *        Note that this will ignore default arg values and always require 
all arguments to be
+   * provided.
+   *
+   * \code
+   *
+   * int multiply(int x, int y) {
+   *   return x * y;
+   * }
+   *
+   * TVM_FFI_REGISTER_GLOBAL("multiply")
+   * .set_body_typed(multiply); // will have type int(int, int)
+   *
+   * // will have type int(int, int)
+   * TVM_FFI_REGISTER_GLOBAL("sub")
+   * .set_body_typed([](int a, int b) -> int { return a - b; });
+   *
+   * \endcode
+   * \param f The function to forward to.
+   * \tparam FLambda The type of the callable.
+   * \return *this.
+   */
+  template <typename FLambda>
+  Registry& set_body_typed(FLambda f) {
+    return Register(Function::FromUnpacked(f));
+  }
+
+ protected:
+  /*!
+   * \brief Register f as the global function name_ (override left at its default, false).
+   * \param f The body of the function.
+   */
+  Registry& Register(Function f) {
+    Function::SetGlobal(name_, f);
+    return *this;
+  }
+
+  /*! \brief Name of the function; must stay valid while this registry is used. */
+  const char* name_;
+};
+
+#define TVM_FFI_FUNC_REG_VAR_DEF \
+  static TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::Function::Registry& __mk_##TVMFFI  // NOTE(review): bound to a temporary Registry; dangles (unused)
+
+/*!
+ * \brief Register a function globally; must be chained with .set_body_typed(...).
+ * \code
+ *   TVM_FFI_REGISTER_GLOBAL("MyAdd")
+ *   .set_body_typed([](int a, int b) {
+ *      return a + b;
+ *   });
+ * \endcode
+ */
+#define TVM_FFI_REGISTER_GLOBAL(OpName) \
+  TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = 
::tvm::ffi::Function::Registry(OpName)
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_FUNCTION_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index dbb940ab45..6a202ba1f5 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -545,6 +545,13 @@ struct ObjectUnsafe {
     return const_cast<TVMFFIObject*>(&(src->header_));
   }
 
+  template <typename T>
+  static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) {
+    tvm::ffi::ObjectPtr<T> ptr;
+    ptr.data_ = raw_ptr;  // take over the caller's reference; no IncRef
+    return ptr;
+  }
+
   // Create ObjectPtr from unknowned ptr
   template <typename T>
   static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) {
@@ -556,7 +563,10 @@ struct ObjectUnsafe {
     return tvm::ffi::ObjectPtr<T>(reinterpret_cast<Object*>(obj_ptr));
   }
 
-  // Interactions with Any system
+  static TVM_FFI_INLINE void DecRefObjectHandle(TVMFFIObjectHandle handle) {
+    reinterpret_cast<Object*>(handle)->DecRef();  // release one reference held via handle
+  }
+
   static TVM_FFI_INLINE void DecRefObjectInAny(TVMFFIAny* src) {
     reinterpret_cast<Object*>(src->v_obj)->DecRef();
   }
diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc
index 5da4e0178b..f306af09f6 100644
--- a/ffi/src/ffi/function.cc
+++ b/ffi/src/ffi/function.cc
@@ -22,9 +22,14 @@
  */
 #include <tvm/ffi/any.h>
 #include <tvm/ffi/c_api.h>
+#include <tvm/ffi/cast.h>
+#include <tvm/ffi/container/array.h>
 #include <tvm/ffi/error.h>
+#include <tvm/ffi/function.h>
 #include <tvm/ffi/string.h>
 
+#include <unordered_map>
+
 namespace tvm {
 namespace ffi {
 
@@ -49,10 +54,99 @@ class SafeCallContext {
   Any last_error_;
 };
 
+/*!
+ * \brief Global function table that maps names to registered functions.
+ *
+ * Entries are heap-allocated Function objects that are never deleted.
+ * \note We do not use a mutex to guard updates of GlobalFunctionTable.
+ *
+ * The assumption is that updates of GlobalFunctionTable are done
+ * in the main thread during initialization or loading, or are
+ * explicitly locked by the caller.
+ *
+ * Lookups (Get/ListNames) are not synchronized either.
+ */
+class GlobalFunctionTable {
+ public:
+  void Update(const String& name, Function func, bool can_override) {
+    if (table_.count(name)) {  // name already registered
+      if (!can_override) {
+        TVM_FFI_THROW(RuntimeError) << "Global Function `" << name << "` is 
already registered";
+      }
+    }
+    table_[name] = new Function(func);  // a previous entry, if any, is not deleted
+  }
+
+  bool Remove(const String& name) {  // false if name is not registered
+    auto it = table_.find(name);
+    if (it == table_.end()) return false;
+    table_.erase(it);  // the Function* itself is not deleted
+    return true;
+  }
+
+  const Function* Get(const String& name) {  // nullptr if not registered
+    auto it = table_.find(name);
+    if (it == table_.end()) return nullptr;
+    return it->second;  // still owned by the table
+  }
+
+  Array<String> ListNames() const {  // order is unspecified (unordered_map)
+    Array<String> names;
+    names.reserve(table_.size());
+    for (const auto& kv : table_) {
+      names.push_back(kv.first);
+    }
+    return names;
+  }
+
+  static GlobalFunctionTable* Global() {
+    // We deliberately create a new instance via raw new and never delete it.
+    // This is because GlobalFunctionTable can contain callbacks into
+    // the host language (Python), and those resources can become invalid
+    // due to the nondeterministic order of destruction and forking.
+    // The resources will only be reclaimed at program exit.
+    static GlobalFunctionTable* inst = new GlobalFunctionTable();
+    return inst;
+  }
+
+ private:
+  // Functions are deliberately held via raw pointers that are never
+  // deleted, for the same teardown reasons described in Global().
+  std::unordered_map<String, Function*> table_;
+};
 }  // namespace ffi
 }  // namespace tvm
 
-extern "C" {
+int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, void 
(*deleter)(void* self),
+                     TVMFFIObjectHandle* out) {
+  using namespace tvm::ffi;
+  TVM_FFI_SAFE_CALL_BEGIN();
+  Function func = Function::FromExternC(self, safe_call, deleter);
+  *out = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&func);  // *out takes func's reference
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override) {
+  using namespace tvm::ffi;
+  TVM_FFI_SAFE_CALL_BEGIN();
+  GlobalFunctionTable::Global()->Update(name, 
GetRef<Function>(static_cast<FunctionObj*>(f)),
+                                        override != 0);  // f is borrowed, not consumed
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out) {
+  using namespace tvm::ffi;
+  TVM_FFI_SAFE_CALL_BEGIN();
+  const Function* fp = GlobalFunctionTable::Global()->Get(name);
+  if (fp != nullptr) {
+    Function func(*fp);  // copy, so the reference moved into *out belongs to the caller
+    *out = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&func);
+  } else {
+    *out = nullptr;  // not found is reported via NULL, not as an error
+  }
+  TVM_FFI_SAFE_CALL_END();
+}
+
 void TVMFFISetLastError(const TVMFFIAny* error_view) {
   tvm::ffi::SafeCallContext::ThreadLocal()->SetLastError(error_view);
 }
@@ -60,4 +154,12 @@ void TVMFFISetLastError(const TVMFFIAny* error_view) {
 void TVMFFIMoveFromLastError(TVMFFIAny* result) {
   tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromLastError(result);
 }
-}
+
+TVM_FFI_REGISTER_GLOBAL("tvm_ffi.GlobalFunctionRemove")  // returns false if name is absent
+    .set_body_typed([](const tvm::ffi::String& name) -> bool {
+      return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
+    });
+
+TVM_FFI_REGISTER_GLOBAL("tvm_ffi.GlobalFunctionListNames").set_body_typed([]() 
{
+  return tvm::ffi::GlobalFunctionTable::Global()->ListNames();  // Array<String> of all names
+});
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index d0782d1b50..8941753df8 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -34,12 +34,13 @@ namespace tvm {
 namespace ffi {
 
 /*!
- * \brief Type context that manages the type hierarchy information.
+ * \brief Global registry that manages the type hierarchy information.
  *
- * \note We do not use mutex to guard updating of TypeContext
+ * \note We do not use mutex to guard updating of TypeTable
  *
- * The assumption is that updating of TypeContext will be done
- * in the main thread during initialization or loading.
+ * The assumption is that updating of TypeTable will be done
+ * in the main thread during initialization or loading, or
+ * explicitly locked from the caller.
  *
  * Then the followup code will leverage the information
  */
@@ -223,7 +224,11 @@ class TypeTable {
 }  // namespace ffi
 }  // namespace tvm
 
-extern "C" {
+int TVMFFIObjectFree(TVMFFIObjectHandle handle) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle);  // drop one reference
+  TVM_FFI_SAFE_CALL_END();
+}
 
 int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t 
static_type_index,
                                   int32_t type_depth, int32_t num_child_slots,
@@ -240,4 +245,3 @@ const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) 
{
   return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index);
   TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo);
 }
-}  // extern "C"
diff --git a/ffi/src/ffi/traceback.cc b/ffi/src/ffi/traceback.cc
index 70de055719..6de3efe7ca 100644
--- a/ffi/src/ffi/traceback.cc
+++ b/ffi/src/ffi/traceback.cc
@@ -149,15 +149,12 @@ __attribute__((constructor)) void 
install_signal_handler(void) {
 }  // namespace ffi
 }  // namespace tvm
 
-extern "C" {
 const char* TVMFFITraceback(const char*, const char*, int) {
   static thread_local std::string traceback_str;
   traceback_str = ::tvm::ffi::Traceback();
   return traceback_str.c_str();
 }
-}  // extern "C"
 #else
-extern "C" {
 // fallback implementation simply print out the last trace
 const char* TVMFFITraceback(const char* filename, const char* func, int 
lineno) {
   static thread_local std::string traceback_str;
@@ -167,6 +164,5 @@ const char* TVMFFITraceback(const char* filename, const 
char* func, int lineno)
   traceback_str = traceback_stream.str();
   return traceback_str.c_str();
 }
-}  // extern "C"
 #endif  // TVM_FFI_USE_LIBBACKTRACE
 #endif  // _MSC_VER
diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc
index 99e7819b23..367e613665 100644
--- a/ffi/tests/cpp/test_function.cc
+++ b/ffi/tests/cpp/test_function.cc
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
 #include <tvm/ffi/any.h>
+#include <tvm/ffi/container/array.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/memory.h>
 
@@ -112,4 +113,19 @@ TEST(Func, FromUnpacked) {
       },
       ::tvm::ffi::Error);
 }
+
+TEST(Func, Global) {
+  Function::SetGlobal("testing.add1",
+                      Function::FromUnpacked([](const int32_t& a) -> int { 
return a + 1; }));
+  Function fadd1 = Function::GetGlobal("testing.add1");
+  int b = fadd1(1);
+  EXPECT_EQ(b, 2);
+  Function fnot_exist = Function::GetGlobal("testing.not_existing_func");  // unknown name -> null
+  EXPECT_TRUE(fnot_exist == nullptr);
+
+  Array<String> names = 
Function::GetGlobal("tvm_ffi.GlobalFunctionListNames")();
+
+  EXPECT_TRUE(std::find(names.begin(), names.end(), "testing.add1") != 
names.end());  // NOTE(review): std::find needs <algorithm>; confirm it is included
+}
+
 }  // namespace
diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc
index 23310824a6..04348292b7 100644
--- a/ffi/tests/cpp/test_string.cc
+++ b/ffi/tests/cpp/test_string.cc
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
 #include <tvm/ffi/any.h>
+#include <tvm/ffi/cast.h>
 #include <tvm/ffi/string.h>
 
 namespace {

Reply via email to