This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit dabd53e4c4800acb2df124a51aaf9f2303beeebe Author: tqchen <[email protected]> AuthorDate: Sat May 3 09:05:08 2025 -0400 [FFI] Use stl style get for Variant/Tuple --- ffi/include/tvm/ffi/container/tuple.h | 5 ++++- ffi/include/tvm/ffi/container/variant.h | 4 ++-- ffi/tests/cpp/test_tuple.cc | 32 ++++++++++++++++---------------- ffi/tests/cpp/test_variant.cc | 16 ++++++++-------- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 260237a08c..63c36467fe 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -91,9 +91,10 @@ class Tuple : public ObjectRef { * * \tparam I The index of the element to get * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. */ template <size_t I> - auto Get() const { + auto get() const { static_assert(I < sizeof...(Types), "Tuple index out of bounds"); using ReturnType = std::tuple_element_t<I, std::tuple<Types...>>; const Any* ptr = GetArrayObj()->begin() + I; @@ -109,6 +110,8 @@ class Tuple : public ObjectRef { * * \note This function will perform copy on write if underlying * container is not uniquely owned. + * We use CamelCase since Set can cause copy on write + * and is more complicated than simple field setter. */ template <size_t I, typename U> void Set(U&& item) { diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index caf11557e6..1455a5b34a 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -93,12 +93,12 @@ class Variant { } template <typename T, typename = enable_if_variant_contains_t<T>> - TVM_FFI_INLINE T Get() const& { + TVM_FFI_INLINE T get() const& { return data_.cast<T>(); } template <typename T, typename = enable_if_variant_contains_t<T>> - TVM_FFI_INLINE T Get() && { + TVM_FFI_INLINE T get() && { return std::move(data_).cast<T>(); } diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index 9d24e60056..1fe9ca74e8 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -28,16 +28,16 @@ using namespace tvm::ffi::testing; TEST(Tuple, Basic) { Tuple<int, float> tuple0(1, 2.0f); - EXPECT_EQ(tuple0.Get<0>(), 1); - EXPECT_EQ(tuple0.Get<1>(), 2.0f); + EXPECT_EQ(tuple0.get<0>(), 1); + EXPECT_EQ(tuple0.get<1>(), 2.0f); Tuple<int, float> tuple1 = tuple0; EXPECT_EQ(tuple0.use_count(), 2); // test copy on write tuple1.Set<0>(3); - EXPECT_EQ(tuple0.Get<0>(), 1); - EXPECT_EQ(tuple1.Get<0>(), 3); + EXPECT_EQ(tuple0.get<0>(), 1); + EXPECT_EQ(tuple1.get<0>(), 3); EXPECT_EQ(tuple0.use_count(), 1); EXPECT_EQ(tuple1.use_count(), 1); @@ -45,7 +45,7 @@ TEST(Tuple, Basic) { // copy on write not triggered because // tuple1 is unique. tuple1.Set<1>(4); - EXPECT_EQ(tuple1.Get<1>(), 4.0f); + EXPECT_EQ(tuple1.get<1>(), 4.0f); EXPECT_EQ(tuple1.use_count(), 1); // default state @@ -53,15 +53,15 @@ TEST(Tuple, Basic) { EXPECT_EQ(tuple2.use_count(), 1); tuple2.Set<0>(1); tuple2.Set<1>(2.0f); - EXPECT_EQ(tuple2.Get<0>(), 1); - EXPECT_EQ(tuple2.Get<1>(), 2.0f); + EXPECT_EQ(tuple2.get<0>(), 1); + EXPECT_EQ(tuple2.get<1>(), 2.0f); // tuple of object and primitive Tuple<TInt, int> tuple3(1, 2); - EXPECT_EQ(tuple3.Get<0>()->value, 1); - EXPECT_EQ(tuple3.Get<1>(), 2); + EXPECT_EQ(tuple3.get<0>()->value, 1); + EXPECT_EQ(tuple3.get<1>(), 2); tuple3.Set<0>(4); - EXPECT_EQ(tuple3.Get<0>()->value, 4); + EXPECT_EQ(tuple3.get<0>()->value, 4); } TEST(Tuple, AnyConvert) { @@ -78,16 +78,16 @@ TEST(Tuple, AnyConvert) { Any any0 = view0; // trigger a copy due to implict conversion - auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>(); + auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>() ; EXPECT_TRUE(!tuple0.same_as(tuple2)); - EXPECT_EQ(tuple2.Get<0>()->value, 1); - EXPECT_EQ(tuple2.Get<1>()->value, 2); + EXPECT_EQ(tuple2.get<0>()->value, 1); + EXPECT_EQ(tuple2.get<1>()->value, 2); } TEST(Tuple, FromUnpacked) { // try decution Function fadd1 = Function::FromUnpacked([](const Tuple<int, TPrimExpr>& a) -> int { - return a.Get<0>() + static_cast<int>(a.Get<1>()->value); + return a.get<0>() + static_cast<int>(a.get<1>()->value); }); int b = fadd1(Tuple<int, float>(1, 2)).cast<int>(); EXPECT_EQ(b, 3); @@ -130,8 +130,8 @@ TEST(Tuple, FromUnpacked) { TEST(Tuple, Upcast) { Tuple<int, float> t0(1, 2.0f); Tuple<Any, Any> t1 = t0; - EXPECT_EQ(t1.Get<0>().cast<int>(), 1); - EXPECT_EQ(t1.Get<1>().cast<float>(), 2.0f); + EXPECT_EQ(t1.get<0>().cast<int>(), 1); + EXPECT_EQ(t1.get<1>().cast<float>(), 2.0f); static_assert(details::type_contains_v<Tuple<Any, Any>, Tuple<int, float>>); static_assert(details::type_contains_v<Tuple<Any, float>, Tuple<int, float>>); static_assert(details::type_contains_v<Tuple<TNumber, float>, Tuple<TInt, float>>); diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index 65b8a1c9e6..db29bdac50 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -31,13 +31,13 @@ using namespace tvm::ffi::testing; TEST(Variant, Basic) { Variant<int, float> v1 = 1; - EXPECT_EQ(v1.Get<int>(), 1); + EXPECT_EQ(v1.get<int>(), 1); EXPECT_EQ(v1.as<float>().value(), 1.0f); Variant<int, float> v2 = 2.0f; - EXPECT_EQ(v2.Get<float>(), 2.0f); + EXPECT_EQ(v2.get<float>(), 2.0f); v2 = v1; - EXPECT_EQ(v2.Get<int>(), 1); + EXPECT_EQ(v2.get<int>(), 1); } TEST(Variant, AnyConvert) { @@ -48,13 +48,13 @@ TEST(Variant, AnyConvert) { // implicit convert to variant Any any0 = 1; auto v1 = any0.cast<Variant<TPrimExpr, Array<TPrimExpr>>>(); - EXPECT_EQ(v1.Get<TPrimExpr>()->value, 1); + EXPECT_EQ(v1.get<TPrimExpr>()->value, 1); // move from any to variant Variant<TInt, int> v2 = TInt(1); Any any1 = std::move(v2); auto v3 = std::move(any1).cast<Variant<TInt, int>>(); - auto v4 = std::move(v3).Get<TInt>(); + auto v4 = std::move(v3).get<TInt>(); EXPECT_EQ(v4->value, 1); EXPECT_EQ(v4.use_count(), 1); } @@ -78,7 +78,7 @@ TEST(Variant, FromUnpacked) { if (auto opt_int = a.as<int>()) { return opt_int.value() + 1; } else { - return a.Get<TInt>()->value + 1; + return a.get<TInt>()->value + 1; } }); int b = fadd1(1).cast<int>(); @@ -104,7 +104,7 @@ TEST(Variant, FromUnpacked) { if (auto opt_int = a[0].as<int>()) { return opt_int.value() + 1; } else { - return a[0].Get<TInt>()->value + 1; + return a[0].get<TInt>()->value + 1; } }); int c = fadd2(Array<Any>({1, 2})).cast<int>(); @@ -131,7 +131,7 @@ TEST(Variant, Upcast) { Array<int> a0 = {1, 2, 3}; static_assert(details::type_contains_v<Array<Variant<int, float>>, Array<int>>); Array<Variant<int, float>> a1 = a0; - EXPECT_EQ(a1[0].Get<int>(), 1); + EXPECT_EQ(a1[0].get<int>(), 1); } } // namespace
