This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s3
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bd915404443b1536b63be8fb2c9852faa5fef29f
Author: tqchen <tianqi.tc...@gmail.com>
AuthorDate: Mon May 5 17:34:37 2025 -0400

    [FFI] Make Error to be ABI invariant
---
 ffi/include/tvm/ffi/c_api.h     | 14 ++++++------
 ffi/include/tvm/ffi/error.h     | 47 +++++++++++++++++++++++++----------------
 ffi/src/ffi/error.cc            |  7 ------
 python/tvm/ffi/cython/base.pxi  |  2 +-
 python/tvm/ffi/cython/error.pxi |  2 +-
 5 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 61738d1082..1d495d9c5e 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -224,6 +224,12 @@ typedef struct {
    * \brief The traceback of the error.
    */
   TVMFFIByteArray traceback;
+  /*!
+   * \brief Function handle to update the traceback of the error.
+   * \param self The self object handle.
+   * \param traceback The traceback to update.
+   */
+  void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback);
 } TVMFFIErrorCell;
 
 /*!
@@ -483,14 +489,6 @@ TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind,
                                                  const TVMFFIByteArray* message,
                                                  const TVMFFIByteArray* traceback);
 
-/*!
- * \brief Update the traceback of an Error object.
- * \param obj The error handle.
- * \param traceback The traceback to update.
- */
-TVM_FFI_DLL void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj,
-                                            const TVMFFIByteArray* traceback);
-
 /*!
  * \brief Check if there are any signals raised in the surrounding env.
  * \return 0 when success, nonzero when failure happens
diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h
index 4810754f17..239a0e500b 100644
--- a/ffi/include/tvm/ffi/error.h
+++ b/ffi/include/tvm/ffi/error.h
@@ -81,26 +81,39 @@ struct EnvErrorAlreadySet : public std::exception {};
  */
 class ErrorObj : public Object, public TVMFFIErrorCell {
  public:
-  /*!
-   * \brief Update the traceback of the error object.
-   * \param traceback The traceback to update.
-   */
-  void UpdateTraceback(const TVMFFIByteArray* traceback_str) {
-    this->traceback_data_ = std::string(traceback_str->data, traceback_str->size);
-    this->traceback = TVMFFIByteArray{this->traceback_data_.data(), this->traceback_data_.length()};
-  }
-
   static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError;
   static constexpr const char* _type_key = "object.Error";
 
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object);
+};
+
+namespace details {
+class ErrorObjFromStd : public ErrorObj {
+ public:
+  ErrorObjFromStd(std::string kind, std::string message, std::string traceback)
+      : kind_data_(kind), message_data_(message), traceback_data_(traceback) {
+    this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()};
+    this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()};
+    this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()};
+    this->update_traceback = UpdateTraceback;
+  }
 
  private:
-  friend class Error;
+  /*!
+   * \brief Update the traceback of the error object.
+   * \param traceback The traceback to update.
+   */
+  static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback_str) {
+    ErrorObjFromStd* obj = static_cast<ErrorObjFromStd*>(self);
+    obj->traceback_data_ = std::string(traceback_str->data, traceback_str->size);
+    obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()};
+  }
+
   std::string kind_data_;
   std::string message_data_;
   std::string traceback_data_;
 };
+}  // namespace details
 
 /*!
  * \brief Managed reference to ErrorObj
@@ -109,14 +122,7 @@ class ErrorObj : public Object, public TVMFFIErrorCell {
 class Error : public ObjectRef, public std::exception {
  public:
   Error(std::string kind, std::string message, std::string traceback) {
-    ObjectPtr<ErrorObj> n = make_object<ErrorObj>();
-    n->kind_data_ = std::move(kind);
-    n->message_data_ = std::move(message);
-    n->traceback_data_ = std::move(traceback);
-    n->kind = TVMFFIByteArray{n->kind_data_.data(), n->kind_data_.length()};
-    n->message = TVMFFIByteArray{n->message_data_.data(), n->message_data_.length()};
-    n->traceback = TVMFFIByteArray{n->traceback_data_.data(), n->traceback_data_.length()};
-    data_ = std::move(n);
+    data_ = make_object<details::ErrorObjFromStd>(kind, message, traceback);
   }
 
   Error(std::string kind, std::string message, const TVMFFIByteArray* traceback)
@@ -137,6 +143,11 @@ class Error : public ObjectRef, public std::exception {
     return std::string(obj->traceback.data, obj->traceback.size);
   }
 
+  void UpdateTraceback(const TVMFFIByteArray* traceback_str) {
+    ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
+    obj->update_traceback(obj, traceback_str);
+  }
+
   const char* what() const noexcept(true) override {
     thread_local std::string what_data;
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
diff --git a/ffi/src/ffi/error.cc b/ffi/src/ffi/error.cc
index 4dcfb67714..c8c77e510d 100644
--- a/ffi/src/ffi/error.cc
+++ b/ffi/src/ffi/error.cc
@@ -67,13 +67,6 @@ void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) {
   tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result);
 }
 
-void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj, const TVMFFIByteArray* traceback) {
-  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
-  static_cast<tvm::ffi::ErrorObj*>(reinterpret_cast<tvm::ffi::Object*>(obj))
-      ->UpdateTraceback(traceback);
-  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorUpdateTraceback);
-}
-
 TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message,
                                      const TVMFFIByteArray* traceback) {
   TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 42db97809d..8fe23cd23b 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -128,6 +128,7 @@ cdef extern from "tvm/ffi/c_api.h":
         TVMFFIByteArray kind
         TVMFFIByteArray message
         TVMFFIByteArray traceback
+        void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback)
 
     ctypedef int (*TVMFFISafeCallType)(
         void* ctx, const TVMFFIAny* args, int32_t num_args,
@@ -144,7 +145,6 @@ cdef extern from "tvm/ffi/c_api.h":
     int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil
     void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil
     void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil
-    void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle error, TVMFFIByteArray* traceback) nogil
     TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message,
                                          TVMFFIByteArray* traceback) nogil
     int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil
diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi
index 73aa86572d..3a19573b8f 100644
--- a/python/tvm/ffi/cython/error.pxi
+++ b/python/tvm/ffi/cython/error.pxi
@@ -89,7 +89,7 @@ cdef class Error(Object):
             The traceback to update.
         """
         cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback))
-        TVMFFIErrorUpdateTraceback(self.chandle, traceback_arg.cptr())
+        TVMFFIErrorGetCellPtr(self.chandle).update_traceback(self.chandle, traceback_arg.cptr())
 
     def py_error(self):
         """

Reply via email to