This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f1b459  [FEAT][REFLECTION] Add tvm::ffi::reflection::overload_cast to pick overloaded function (#582)
6f1b459 is described below

commit 6f1b459532bd54c942cda3d932ea4f3f629c6df0
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun May 10 17:36:25 2026 -0400

    [FEAT][REFLECTION] Add tvm::ffi::reflection::overload_cast to pick overloaded function (#582)
    
    ## Summary
    
    Introduce `tvm::ffi::reflection::overload_cast<Args...>` — a `constexpr`
    helper for picking a specific overload of an overloaded callable by
    spelling out a parameter-type prefix that uniquely identifies it.
    Trailing parameter types (if any) are deduced from the picked overload's
    signature. The result is a typed function pointer (member or free)
    usable wherever a typed fn ptr is required, including as a non-type
    template argument.
    
    If the prefix matches multiple overloads, the call is ambiguous and the
    caller must spell out more parameter types until exactly one overload
    matches.
    Const-qualified members are picked via the
    `tvm::ffi::reflection::const_` tag.
---
 include/tvm/ffi/reflection/registry.h | 117 ++++++++++++++++++++++++++++++++++
 tests/cpp/test_reflection.cc          | 108 +++++++++++++++++++++++++++++++
 2 files changed, 225 insertions(+)

diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h
index 3e715fe..543321c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -1034,6 +1034,123 @@ inline void EnsureTypeAttrColumn(std::string_view name) {
                                                  reinterpret_cast<const 
TVMFFIAny*>(&any_view)));
 }
 
+/// \cond Doxygen_Suppress
+namespace details {
+
+/*!
+ * \brief Implementation struct for overload_cast.
+ *
+ * Provides operator() overloads for each callable kind (free function,
+ * non-const member, const member) in full-match and prefix-match flavors.
+ * Member overloads take a trailing tag: std::false_type (defaulted) selects
+ * a non-const member, std::true_type (the const_ tag, never defaulted)
+ * selects a const member, so const/non-const pairs stay unambiguous.
+ */
+template <typename... Args>
+struct OverloadCastImpl {
+  // The first triplet matches when Args... is the complete parameter list
+  // of the picked overload; the second also accepts trailing parameters
+  // Rest... after Args.... If Args... prefixes several overloads, the
+  // prefix forms deduce against more than one of them, so that argument is
+  // non-deduced and those candidates drop out; spelling the full list of one
+  // overload then selects it via the first triplet. If one overload fits
+  // both forms, partial ordering prefers the first (more specialized) one.
+
+  template <typename Ret>
+  constexpr auto operator()(Ret (*fn)(Args...)) const noexcept {
+    return fn;
+  }
+  template <typename Ret, typename Cls>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args...), std::false_type = {}) const noexcept {
+    return pmf;
+  }
+  template <typename Ret, typename Cls>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args...) const, std::true_type) const noexcept {
+    return pmf;
+  }
+
+  template <typename Ret, typename... Rest>
+  constexpr auto operator()(Ret (*fn)(Args..., Rest...)) const noexcept {
+    return fn;
+  }
+  template <typename Ret, typename Cls, typename... Rest>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...),
+                            std::false_type = {}) const noexcept {
+    return pmf;
+  }
+  template <typename Ret, typename Cls, typename... Rest>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...) const,
+                            std::true_type) const noexcept {
+    return pmf;
+  }
+};
+
+}  // namespace details
+/// \endcond
+
+/*!
+ * \brief Cast an overloaded callable to a specific overload, picked by
+ *        spelling out a parameter-type prefix that uniquely identifies it.
+ *
+ * `Args...` is matched against the leading parameters of each candidate
+ * overload; the trailing parameter types (if any) are deduced from the
+ * picked overload's signature. The returned value is a constexpr function
+ * pointer (member or free) and can be used wherever a typed function
+ * pointer is required, including as a non-type template argument.
+ *
+ * If the prefix matches multiple overloads (e.g. two overloads share the
+ * same leading parameters), the call is ambiguous and the caller must
+ * spell out more parameter types until exactly one overload matches.
+ *
+ * \note When picking a const-qualified member function, `refl::const_` must
+ *       be passed as the second argument even when it is the only overload
+ *       of its name. Without the tag the call does not compile.
+ *
+ * \note This helper can be more permissive than some `overload_cast` variants
+ *       in existing packages that require the full parameter list to be
+ *       spelled out: here a parameter-type prefix is accepted and the
+ *       trailing types are deduced from the picked overload.
+ *
+ * \code{.cpp}
+ *   class Pet {
+ *    public:
+ *     void Set(int);
+ *     void Set(int, int);
+ *     int  Feed(const Cat*, int amount);
+ *     int  Feed(const Dog*, int amount);
+ *     int  Get(int);
+ *     int  Get(int) const;
+ *   };
+ *
+ *   namespace refl = tvm::ffi::reflection;
+ *
+ *   // Spell only the disambiguating first arg; the trailing `int amount`
+ *   // is deduced from the picked overload's signature.
+ *   auto p_feed_cat = refl::overload_cast<const Cat*>(&Pet::Feed);
+ *   //   decltype(p_feed_cat) == int (Pet::*)(const Cat*, int)
+ *
+ *   // Set(int) and Set(int, int) share a prefix: spell all of Set(int)'s params.
+ *   auto p_set_int = refl::overload_cast<int>(&Pet::Set);
+ *
+ *   // Const-qualified member — opt in via the const_ tag:
+ *   auto p_get_const = refl::overload_cast<int>(&Pet::Get, refl::const_);
+ *
+ *   // Use directly as a non-type template argument:
+ *   template <auto F> struct UseAsTemplateArg { ... };
+ *   using U = UseAsTemplateArg<refl::overload_cast<const Cat*>(&Pet::Feed)>;
+ * \endcode
+ */
+template <typename... Args>
+inline constexpr details::OverloadCastImpl<Args...> overload_cast = {};
+
+/// \cond Doxygen_Suppress
+// Tag that selects the const-qualified member overloads of overload_cast.
+// Doc emission is suppressed: the trailing underscore in `const_` is read as
+// RST hyperlink-reference syntax on the exhale-generated per-variable page.
+// The overload_cast docstring above still mentions it as an inline literal.
+inline constexpr std::true_type const_{};
+/// \endcond
+
 }  // namespace reflection
 }  // namespace ffi
 }  // namespace tvm
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index f9d567f..0593eca 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -625,4 +625,112 @@ TEST(Reflection, AutoInitSimpleTooManyArgs) {
   EXPECT_THROW(auto_init(int64_t{1}, int64_t{2}, int64_t{3}), std::exception);
 }
 
+// ---------------------------------------------------------------------------
+// overload_cast — pick an overload by prefix-matching its parameter types.
+// ---------------------------------------------------------------------------
+
+namespace overload_cast_test {
+
+struct Cat {};
+struct Dog {};
+
+// Pet::Feed overloads differ only in their first parameter and share a
+// trailing `int amount` that overload_cast deduces. Pet::Get has a const and
+// a non-const overload with identical parameters (picked via const_).
+struct Pet {
+  int Feed(const Cat*, int amount) { return amount + 100; }
+  int Feed(const Dog*, int amount) { return amount + 200; }
+  int Get(int v) { return v + 4000; }
+  int Get(int v) const { return v + 5000; }
+};
+
+// Mix::Run overloads all begin with `int`, so a one-type prefix is
+// ambiguous; spelling further parameter types selects a single overload.
+struct Mix {
+  int Run(int, int) { return 1000; }
+  int Run(int, double) { return 2000; }
+  int Run(int, int, int) { return 3000; }
+};
+
+int FreeFeed(const Cat*, int amount) { return amount + 6000; }
+int FreeFeed(const Dog*, int amount) { return amount + 7000; }
+
+template <auto Method>
+struct CallVia {
+  template <typename Obj, typename... Params>
+  static auto Run(Obj&& obj, Params&&... params) {
+    return (std::forward<Obj>(obj).*Method)(std::forward<Params>(params)...);
+  }
+};
+
+}  // namespace overload_cast_test
+
+TEST(OverloadCast, PrefixMatch) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet p;
+  Cat cat;
+  Dog dog;
+
+  // (a) Member overloads with a distinct first parameter: spelling only that
+  //     prefix selects one overload, and the trailing `int amount` is deduced
+  //     from its signature (the static_assert pins the resulting type).
+  auto p_cat = refl::overload_cast<const Cat*>(&Pet::Feed);
+  static_assert(std::is_same_v<decltype(p_cat), int (Pet::*)(const Cat*, int)>,
+                "prefix match must deduce trailing arg types");
+  EXPECT_EQ((p.*p_cat)(&cat, 7), 107);
+
+  auto p_dog = refl::overload_cast<const Dog*>(&Pet::Feed);
+  EXPECT_EQ((p.*p_dog)(&dog, 12), 212);
+
+  // (b) Free-function overloads of the same shape: the trailing arg is deduced.
+  auto p_free_cat = refl::overload_cast<const Cat*>(&FreeFeed);
+  EXPECT_EQ(p_free_cat(&cat, 7), 6007);
+  auto p_free_dog = refl::overload_cast<const Dog*>(&FreeFeed);
+  EXPECT_EQ(p_free_dog(&dog, 7), 7007);
+}
+
+TEST(OverloadCast, AmbiguousPrefixRequiresMoreSpelling) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Mix m;
+
+  // Mix::Run overloads: Run(int, int), Run(int, double), Run(int, int, int).
+  // A bare <int> prefix would be ambiguous: every overload starts with int.
+  // <int, int> also prefixes Run(int, int, int), but as the complete list of
+  // Run(int, int) it still resolves to that overload.
+  EXPECT_EQ((m.*refl::overload_cast<int, int>(&Mix::Run))(0, 0), 1000);
+  EXPECT_EQ((m.*refl::overload_cast<int, double>(&Mix::Run))(0, 0.0), 2000);
+  EXPECT_EQ((m.*refl::overload_cast<int, int, int>(&Mix::Run))(0, 0, 0), 3000);
+}
+
+TEST(OverloadCast, ConstQualifiedMember) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet p;
+  const Pet& cp = p;
+
+  // Non-const overload: no tag; the defaulted std::false_type selects it.
+  EXPECT_EQ((p.*refl::overload_cast<int>(&Pet::Get))(7), 4007);
+
+  // Const overload: pass the const_ tag. Only the const-member forms accept a
+  // const pointer-to-member and their tag has no default, so without it the
+  // call picks the non-const Get (or fails if no non-const overload exists).
+  EXPECT_EQ((cp.*refl::overload_cast<int>(&Pet::Get, refl::const_))(7), 5007);
+}
+
+TEST(OverloadCast, NonTypeTemplateArgument) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet p;
+  Mix m;
+  Cat cat;
+
+  // Prefix match used directly as a non-type template argument.
+  EXPECT_EQ((CallVia<refl::overload_cast<const Cat*>(&Pet::Feed)>::Run(p, &cat, 7)), 107);
+
+  // Fully spelled 3-arg overload used as a non-type template argument.
+  EXPECT_EQ((CallVia<refl::overload_cast<int, int, int>(&Mix::Run)>::Run(m, 0, 0, 0)), 3000);
+}
+
 }  // namespace

Reply via email to