This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit d5b942f94e934e0d4cd61de669bafed17cbababe
Author: tqchen <[email protected]>
AuthorDate: Fri Sep 13 15:20:58 2024 -0400

    [FFI] Optional Support
---
 ffi/include/tvm/ffi/any.h                |  17 +-
 ffi/include/tvm/ffi/container/optional.h | 300 +++++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/object.h             |   2 +-
 ffi/tests/example/test_any.cc            |  22 +--
 ffi/tests/example/test_optional.cc       | 107 +++++++++++
 5 files changed, 431 insertions(+), 17 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 031038d4fc..b7181b2297 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -112,12 +112,12 @@ class AnyView {
    * \return The underlying supporting data of any view
    * \note This function is used only for testing purposes.
    */
-  TVMFFIAny AsTVMFFIAny() const { return data_; }
+  TVMFFIAny CopyToTVMFFIAny() const { return data_; }
   /*!
    * \return Create an AnyView from TVMFFIAny
    * \param data the underlying ffi data.
    */
-  static AnyView FromTVMFFIAny(TVMFFIAny data) {
+  static AnyView CopyFromTVMFFIAny(TVMFFIAny data) {
     AnyView view;
     view.data_ = data;
     return view;
@@ -142,7 +142,10 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* 
data,
 }  // namespace details
 
 /*!
- * \brief
+ * \brief Managed Any that takes a strong reference to a value.
+ *
+ * \note Developer invariant: the TVMFFIAny data_
+ *       held by an Any can always be safely used as an AnyView.
  */
 class Any {
  protected:
@@ -198,7 +201,7 @@ class Any {
     return *this;
   }
   /*! \brief Any can be converted to AnyView in zero cost. */
-  operator AnyView() { return AnyView::FromTVMFFIAny(data_); }
+  operator AnyView() { return AnyView::CopyFromTVMFFIAny(data_); }
   // constructor from general types
   template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>>
   Any(T other) {  // NOLINT(*)
@@ -227,10 +230,14 @@ class Any {
     TVM_FFI_UNREACHABLE();
   }
 
+  /*! \return true when this Any holds None (type index kTVMFFINone). */
+  bool operator==(std::nullptr_t) const {
+    return data_.type_index == TypeIndex::kTVMFFINone;
+  }
+  /*! \return true when this Any holds anything other than None. */
+  bool operator!=(std::nullptr_t) const { return !(*this == nullptr); }
+
   // FFI related operations
   /*!
    * Move the current data to FFI any
-   * \parma result the output to nmove to
+   * \param result the output to move to
    */
   void MoveToTVMFFIAny(TVMFFIAny* result) {
     *result = data_;
diff --git a/ffi/include/tvm/ffi/container/optional.h 
b/ffi/include/tvm/ffi/container/optional.h
new file mode 100644
index 0000000000..5f08db0a5b
--- /dev/null
+++ b/ffi/include/tvm/ffi/container/optional.h
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/container/optional.h
+ * \brief Runtime Optional container types.
+ */
+#ifndef TVM_FFI_CONTAINER_OPTIONAL_H_
+#define TVM_FFI_CONTAINER_OPTIONAL_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/object.h>
+
+#include <optional>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Optional for non-ObjectRef types, backed by std::optional<T>.
+ *
+ * nullptr will be treated as NullOpt.
+ *
+ * \tparam T The held value type; must not derive from ObjectRef
+ *           (ObjectRef types use the nullable specialization instead).
+ */
+template <typename T>
+class Optional<T, std::enable_if_t<!std::is_base_of_v<ObjectRef, T>>> {
+ public:
+  static_assert(!std::is_same_v<std::nullptr_t, T>, "Optional<nullptr> is not well defined");
+  // Copy/move simply forward to std::optional<T>; defaulting them keeps
+  // std::optional's noexcept guarantees.
+  Optional() = default;
+  Optional(const Optional<T>& other) = default;
+  Optional(Optional<T>&& other) = default;
+  Optional<T>& operator=(const Optional<T>& other) = default;
+  Optional<T>& operator=(Optional<T>&& other) = default;
+  // normal value handling.
+  Optional(T other)  // NOLINT(*)
+      : data_(std::move(other)) {}
+  Optional<T>& operator=(T other) {
+    data_ = std::move(other);
+    return *this;
+  }
+  // nullptr handling.
+  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
+  explicit Optional(std::nullptr_t) {}
+  Optional<T>& operator=(std::nullptr_t) {
+    data_ = std::nullopt;
+    return *this;
+  }
+  /*!
+   * \return A not-null container value in the optional.
+   * \note This function performs not-null checking and throws
+   *       RuntimeError when the optional is empty.
+   */
+  T value() const {
+    if (!data_.has_value()) {
+      TVM_FFI_THROW(RuntimeError) << "Bad optional access";
+    }
+    return *data_;
+  }
+  /*!
+   * \param default_value The value returned when the optional is empty.
+   * \return The contained value if present, otherwise default_value.
+   */
+  T value_or(T default_value) const { return data_.value_or(std::move(default_value)); }
+
+  /*! \return Whether the container is not nullptr.*/
+  explicit operator bool() const { return data_.has_value(); }
+
+  /*! \return Whether the container holds a value. */
+  bool has_value() const { return data_.has_value(); }
+
+  bool operator==(const Optional<T>& other) const { return data_ == other.data_; }
+
+  bool operator!=(const Optional<T>& other) const { return data_ != other.data_; }
+
+  // Comparison against a plain value uses std::optional semantics:
+  // an empty optional is unequal to every value.
+  template <typename U>
+  bool operator==(const U& other) const {
+    return data_ == other;
+  }
+
+  template <typename U>
+  bool operator!=(const U& other) const {
+    return data_ != other;
+  }
+
+  // operator overloadings with nullptr
+  bool operator==(std::nullptr_t) const { return !data_.has_value(); }
+  bool operator!=(std::nullptr_t) const { return data_.has_value(); }
+
+  // helper function to move out value; precondition: has_value().
+  // Returns a reference into data_, which stays alive with *this.
+  T&& MoveValueNoCheck() { return std::move(*data_); }
+  // helper function to copy out value; precondition: has_value().
+  T CopyValueNoCheck() const { return *data_; }
+
+ private:
+  std::optional<T> data_;
+};
+
+/*!
+ * \brief Specialization of Optional for ObjectRef.
+ *
+ * In such cases, nullptr is treated as NullOpt.
+ * This specialization reduces the storage cost of
+ * Optional for ObjectRef: it adds no members beyond ObjectRef and
+ * encodes the empty state as a null object pointer.
+ *
+ * \tparam T The original ObjectRef.
+ */
+template <typename T>
+class Optional<T, std::enable_if_t<std::is_base_of_v<ObjectRef, T>>> : public ObjectRef {
+ public:
+  using ContainerType = typename T::ContainerType;
+  static_assert(std::is_base_of<ObjectRef, T>::value, "Optional is only defined for ObjectRef.");
+  // default constructors.
+  Optional() = default;
+  Optional(const Optional<T>& other) : ObjectRef(other.data_) {}
+  Optional(Optional<T>&& other) : ObjectRef(std::move(other.data_)) {}
+  Optional<T>& operator=(const Optional<T>& other) {
+    data_ = other.data_;
+    return *this;
+  }
+  Optional<T>& operator=(Optional<T>&& other) {
+    data_ = std::move(other.data_);
+    return *this;
+  }
+  /*!
+   * \brief Construct from an ObjectPtr
+   *        whose type already matches the ContainerType.
+   * \param ptr The object pointer to hold.
+   */
+  explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+  /*! \brief Nullopt handling */
+  Optional(std::nullopt_t) {}  // NOLINT(*)
+  // nullptr handling.
+  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
+  explicit Optional(std::nullptr_t) {}
+  Optional<T>& operator=(std::nullptr_t) {
+    data_ = nullptr;
+    return *this;
+  }
+  // normal value handling.
+  Optional(T other)  // NOLINT(*)
+      : ObjectRef(std::move(other)) {}
+  Optional<T>& operator=(T other) {
+    ObjectRef::operator=(std::move(other));
+    return *this;
+  }
+  // delete the int constructor
+  // since Optional<Integer>(0) is ambiguous
+  // 0 can be implicitly casted to nullptr_t
+  explicit Optional(int val) = delete;
+  Optional<T>& operator=(int val) = delete;
+  /*!
+   * \brief Move the held value out without a null check.
+   * \note Returns by value: the value is a freshly constructed T, so
+   *       returning T&& would bind to a temporary and dangle.
+   */
+  T MoveOutValueNoCheck() { return T(std::move(data_)); }
+  /*!
+   * \return A not-null container value in the optional.
+   * \note This function performs not-null checking and throws
+   *       RuntimeError when the optional is empty.
+   */
+  T value() const {
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "Bad optional access";
+    }
+    return T(data_);
+  }
+  /*!
+   * \return The internal object pointer with container type of T.
+   * \note This function do not perform not-null checking.
+   */
+  const ContainerType* get() const { return static_cast<const ContainerType*>(data_.get()); }
+  /*!
+   * \return The contained value if the Optional is not null
+   *         otherwise return the default_value.
+   */
+  T value_or(T default_value) const {
+    return data_ != nullptr ? T(data_) : std::move(default_value);
+  }
+
+  /*! \return Whether the container is not nullptr.*/
+  explicit operator bool() const { return *this != nullptr; }
+  /*! \return Whether the container is not nullptr */
+  bool has_value() const { return *this != nullptr; }
+
+  // operator overloadings
+  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
+  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
+  auto operator==(const Optional<T>& other) const {
+    // support case where sub-class returns a symbolic ref type.
+    using RetType = decltype(value() == other.value());
+    if (same_as(other)) return RetType(true);
+    if (*this != nullptr && other != nullptr) {
+      return value() == other.value();
+    } else {
+      // one of them is nullptr.
+      return RetType(false);
+    }
+  }
+  auto operator!=(const Optional<T>& other) const {
+    // support case where sub-class returns a symbolic ref type.
+    using RetType = decltype(value() != other.value());
+    if (same_as(other)) return RetType(false);
+    if (*this != nullptr && other != nullptr) {
+      return value() != other.value();
+    } else {
+      // one of them is nullptr.
+      return RetType(true);
+    }
+  }
+  auto operator==(const T& other) const {
+    using RetType = decltype(value() == other);
+    if (same_as(other)) return RetType(true);
+    if (*this != nullptr) return value() == other;
+    return RetType(false);
+  }
+  auto operator!=(const T& other) const { return !(*this == other); }
+  template <typename U>
+  auto operator==(const U& other) const {
+    using RetType = decltype(value() == other);
+    if (*this == nullptr) return RetType(false);
+    return value() == other;
+  }
+  template <typename U>
+  auto operator!=(const U& other) const {
+    using RetType = decltype(value() != other);
+    if (*this == nullptr) return RetType(true);
+    return value() != other;
+  }
+  static constexpr bool _type_is_nullable = true;
+
+  // helper function to move out value; precondition: has_value().
+  // Returns by value (not T&&) so no reference to a temporary escapes.
+  T MoveValueNoCheck() { return T(std::move(data_)); }
+  // helper function to copy out value; precondition: has_value().
+  T CopyValueNoCheck() const { return T(data_); }
+};
+
+// Optional<T> opts out of the default type traits; the specialization
+// below supplies its conversion rules.
+template <typename T>
+inline constexpr bool use_default_type_traits_v<Optional<T>> = false;
+
+/*!
+ * \brief Type traits for Optional<T>.
+ *
+ * A None value (kTVMFFINone) maps to an empty Optional and vice versa;
+ * every other value is delegated to TypeTraits<T>.
+ */
+template <typename T>
+struct TypeTraits<Optional<T>> : public TypeTraitsBase {
+  static TVM_FFI_INLINE void CopyToAnyView(const Optional<T>& src, TVMFFIAny* result) {
+    if (!src.has_value()) {
+      TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result);
+      return;
+    }
+    TypeTraits<T>::CopyToAnyView(src.CopyValueNoCheck(), result);
+  }
+
+  static TVM_FFI_INLINE void MoveToAny(Optional<T> src, TVMFFIAny* result) {
+    if (!src.has_value()) {
+      TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result);
+      return;
+    }
+    TypeTraits<T>::MoveToAny(src.MoveValueNoCheck(), result);
+  }
+
+  static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+    return TypeTraits<T>::GetMismatchTypeInfo(src);
+  }
+
+  static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) {
+    // None is always accepted; anything else must satisfy T's check.
+    return src->type_index == TypeIndex::kTVMFFINone || TypeTraits<T>::CheckAnyView(src);
+  }
+
+  static TVM_FFI_INLINE Optional<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFINone) {
+      return TypeTraits<T>::CopyFromAnyViewAfterCheck(src);
+    }
+    return Optional<T>(nullptr);
+  }
+
+  static TVM_FFI_INLINE std::optional<Optional<T>> TryCopyFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFINone) {
+      return TypeTraits<T>::TryCopyFromAnyView(src);
+    }
+    return Optional<T>(nullptr);
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() {
+    std::string inner = TypeTraits<T>::TypeStr();
+    return "Optional<" + inner + ">";
+  }
+};
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_CONTAINER_OPTIONAL_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 120a735c82..07170dce8f 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -352,7 +352,7 @@ class ObjectPtr {
 };
 
 // Forward declaration, to prevent circular includes.
-template <typename T>
+template <typename T, typename = void>
 class Optional;
 
 /*! \brief Base class of all object reference */
diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc
index 02d0ad6a23..36d4783dc5 100644
--- a/ffi/tests/example/test_any.cc
+++ b/ffi/tests/example/test_any.cc
@@ -29,7 +29,7 @@ using namespace tvm::ffi::testing;
 
 TEST(Any, Int) {
   AnyView view0;
-  EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
 
   std::optional<int64_t> opt_v0 = view0.TryAs<int64_t>();
   EXPECT_TRUE(!opt_v0.has_value());
@@ -48,21 +48,21 @@ TEST(Any, Int) {
       ::tvm::ffi::Error);
 
   AnyView view1 = 1;
-  EXPECT_EQ(view1.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
-  EXPECT_EQ(view1.AsTVMFFIAny().v_int64, 1);
+  EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
+  EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1);
 
   int32_t int_v1 = view1;
   EXPECT_EQ(int_v1, 1);
 
   int64_t v1 = 2;
   view0 = v1;
-  EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
-  EXPECT_EQ(view0.AsTVMFFIAny().v_int64, 2);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2);
 }
 
 TEST(Any, Float) {
   AnyView view0;
-  EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
 
   std::optional<double> opt_v0 = view0.TryAs<double>();
   EXPECT_TRUE(!opt_v0.has_value());
@@ -85,18 +85,18 @@ TEST(Any, Float) {
   EXPECT_EQ(float_v1, 1);
 
   AnyView view2 = 2.2;
-  EXPECT_EQ(view2.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat);
-  EXPECT_EQ(view2.AsTVMFFIAny().v_float64, 2.2);
+  EXPECT_EQ(view2.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat);
+  EXPECT_EQ(view2.CopyToTVMFFIAny().v_float64, 2.2);
 
   float v1 = 2;
   view0 = v1;
-  EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat);
-  EXPECT_EQ(view0.AsTVMFFIAny().v_float64, 2);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2);
 }
 
 TEST(Any, Object) {
   AnyView view0;
-  EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
+  EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
 
   // int object is not nullable
   std::optional<TInt> opt_v0 = view0.TryAs<TInt>();
diff --git a/ffi/tests/example/test_optional.cc 
b/ffi/tests/example/test_optional.cc
new file mode 100644
index 0000000000..10d9189c23
--- /dev/null
+++ b/ffi/tests/example/test_optional.cc
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/optional.h>
+#include <tvm/ffi/memory.h>
+
+#include "./testing_object.h"
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::testing;
+
+TEST(Optional, TInt) {
+  // The ObjectRef specialization stores nothing beyond the ObjectRef itself.
+  static_assert(sizeof(Optional<TInt>) == sizeof(ObjectRef));
+  Optional<TInt> empty;
+  Optional<TInt> eleven = TInt(11);
+
+  // An empty optional compares equal to nullptr and falls back to the default.
+  EXPECT_FALSE(empty.has_value());
+  EXPECT_TRUE(empty == nullptr);
+  EXPECT_EQ(empty.value_or(TInt(12))->value, 12);
+
+  // A filled optional ignores the default and yields its own value.
+  EXPECT_TRUE(eleven.has_value());
+  EXPECT_TRUE(eleven != nullptr);
+  EXPECT_EQ(eleven.value_or(TInt(12))->value, 11);
+}
+
+TEST(Optional, double) {
+  // Non-ObjectRef values live in a std::optional, which is larger than an ObjectRef.
+  static_assert(sizeof(Optional<double>) > sizeof(ObjectRef));
+  Optional<double> empty;
+  Optional<double> eleven = 11.0;
+
+  EXPECT_FALSE(empty.has_value());
+  EXPECT_TRUE(empty == nullptr);
+  EXPECT_EQ(empty.value_or(12), 12);
+  EXPECT_TRUE(empty != 12);
+
+  EXPECT_TRUE(eleven.has_value());
+  EXPECT_TRUE(eleven != nullptr);
+  EXPECT_EQ(eleven.value_or(12), 11);
+  EXPECT_TRUE(eleven == 11);
+  EXPECT_TRUE(eleven != 12);
+}
+
+TEST(Optional, AnyConvert_int) {
+  Optional<int> one = 1;
+  EXPECT_EQ(one.value(), 1);
+  EXPECT_TRUE(one != nullptr);
+
+  // An engaged Optional<int> converts to a view of the held int.
+  AnyView one_view = one;
+  EXPECT_EQ(one_view.operator int(), 1);
+
+  // A default-constructed Any converts to an empty Optional<int>.
+  Any none_any;
+  Optional<int> from_none = none_any;
+
+  EXPECT_TRUE(from_none == nullptr);
+}
+
+TEST(Optional, AnyConvert_Array) {
+  AnyView view0;
+  Array<Array<TNumber>> arr_nested = {{}, {TInt(1), TFloat(2)}};
+  view0 = arr_nested;
+
+  // Each Optional converted from the view takes one more strong reference.
+  Optional<Array<Array<TNumber>>> opt_arr = view0;
+  EXPECT_EQ(arr_nested.use_count(), 2);
+
+  Optional<Array<Array<TNumber>>> arr1 = view0;
+  EXPECT_EQ(arr_nested.use_count(), 3);
+  EXPECT_EQ(arr1.value()[1][1].as<TFloatObj>()->value, 2);
+
+  // A default-constructed Any converts to an empty Optional.
+  Any any1;
+  Optional<Array<Array<TNumber>>> arr2 = any1;
+  EXPECT_TRUE(arr2 == nullptr);
+
+  // Converting to an incompatible element type raises a TypeError whose
+  // message names the full target type. The local is named distinctly so
+  // it does not shadow arr2 above.
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] Optional<Array<Array<int>>> mismatched = view0;
+        } catch (const Error& error) {
+          EXPECT_EQ(error->kind, "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("to `Optional<Array<Array<int>>>`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+}
+
+}  // namespace

Reply via email to