This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s3 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bd915404443b1536b63be8fb2c9852faa5fef29f Author: tqchen <tianqi.tc...@gmail.com> AuthorDate: Mon May 5 17:34:37 2025 -0400 [FFI] Make Error to be ABI invariant --- ffi/include/tvm/ffi/c_api.h | 14 ++++++------ ffi/include/tvm/ffi/error.h | 47 +++++++++++++++++++++++++---------------- ffi/src/ffi/error.cc | 7 ------ python/tvm/ffi/cython/base.pxi | 2 +- python/tvm/ffi/cython/error.pxi | 2 +- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 61738d1082..1d495d9c5e 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -224,6 +224,12 @@ typedef struct { * \brief The traceback of the error. */ TVMFFIByteArray traceback; + /*! + * \brief Function pointer that replaces the traceback of the error; must not throw. + * \param self Handle of the error object that owns this cell. + * \param traceback The new traceback; implementations copy the bytes. + */ + void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback); } TVMFFIErrorCell; /*! @@ -483,14 +489,6 @@ TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message, const TVMFFIByteArray* traceback); -/*! - * \brief Update the traceback of an Error object. - * \param obj The error handle. - * \param traceback The traceback to update. - */ -TVM_FFI_DLL void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj, - const TVMFFIByteArray* traceback); - /*! * \brief Check if there are any signals raised in the surrounding env. * \return 0 when success, nonzero when failure happens diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index 4810754f17..239a0e500b 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -81,26 +81,39 @@ struct EnvErrorAlreadySet : public std::exception {}; */ class ErrorObj : public Object, public TVMFFIErrorCell { public: - /*! - * \brief Update the traceback of the error object. - * \param traceback The traceback to update.
- */ - void UpdateTraceback(const TVMFFIByteArray* traceback_str) { - this->traceback_data_ = std::string(traceback_str->data, traceback_str->size); - this->traceback = TVMFFIByteArray{this->traceback_data_.data(), this->traceback_data_.length()}; - } - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; static constexpr const char* _type_key = "object.Error"; TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object); +}; + +namespace details { +class ErrorObjFromStd : public ErrorObj { + public: + ErrorObjFromStd(std::string kind, std::string message, std::string traceback) + : kind_data_(std::move(kind)), message_data_(std::move(message)), traceback_data_(std::move(traceback)) { + this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; + this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; + this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()}; + this->update_traceback = UpdateTraceback; + } private: - friend class Error; + /*! + * \brief Implementation of update_traceback: copy \p traceback into the error at handle \p self. + * \param traceback The traceback to update. + */ + static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback) noexcept { + ErrorObjFromStd* obj = static_cast<ErrorObjFromStd*>(static_cast<Object*>(self)); + obj->traceback_data_ = std::string(traceback->data, traceback->size); + obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()}; + } + std::string kind_data_; std::string message_data_; std::string traceback_data_; }; +} // namespace details /*!
* \brief Managed reference to ErrorObj @@ -109,14 +122,7 @@ class ErrorObj : public Object, public TVMFFIErrorCell { class Error : public ObjectRef, public std::exception { public: Error(std::string kind, std::string message, std::string traceback) { - ObjectPtr<ErrorObj> n = make_object<ErrorObj>(); - n->kind_data_ = std::move(kind); - n->message_data_ = std::move(message); - n->traceback_data_ = std::move(traceback); - n->kind = TVMFFIByteArray{n->kind_data_.data(), n->kind_data_.length()}; - n->message = TVMFFIByteArray{n->message_data_.data(), n->message_data_.length()}; - n->traceback = TVMFFIByteArray{n->traceback_data_.data(), n->traceback_data_.length()}; - data_ = std::move(n); + data_ = make_object<details::ErrorObjFromStd>(std::move(kind), std::move(message), std::move(traceback)); } Error(std::string kind, std::string message, const TVMFFIByteArray* traceback) @@ -137,6 +143,11 @@ class Error : public ObjectRef, public std::exception { return std::string(obj->traceback.data, obj->traceback.size); } + void UpdateTraceback(const TVMFFIByteArray* traceback_str) { + ErrorObj* obj = static_cast<ErrorObj*>(data_.get()); + obj->update_traceback(static_cast<Object*>(obj), traceback_str); + } + const char* what() const noexcept(true) override { thread_local std::string what_data; ErrorObj* obj = static_cast<ErrorObj*>(data_.get()); diff --git a/ffi/src/ffi/error.cc b/ffi/src/ffi/error.cc index 4dcfb67714..c8c77e510d 100644 --- a/ffi/src/ffi/error.cc +++ b/ffi/src/ffi/error.cc @@ -67,13 +67,6 @@ void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) { tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result); } -void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj, const TVMFFIByteArray* traceback) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - static_cast<tvm::ffi::ErrorObj*>(reinterpret_cast<tvm::ffi::Object*>(obj)) - ->UpdateTraceback(traceback); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorUpdateTraceback); -} - TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray*
message, const TVMFFIByteArray* traceback) { TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 42db97809d..8fe23cd23b 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -128,6 +128,7 @@ cdef extern from "tvm/ffi/c_api.h": TVMFFIByteArray kind TVMFFIByteArray message TVMFFIByteArray traceback + void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback) ctypedef int (*TVMFFISafeCallType)( void* ctx, const TVMFFIAny* args, int32_t num_args, @@ -144,7 +145,6 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil - void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle error, TVMFFIByteArray* traceback) nogil TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message, TVMFFIByteArray* traceback) nogil int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi index 73aa86572d..3a19573b8f 100644 --- a/python/tvm/ffi/cython/error.pxi +++ b/python/tvm/ffi/cython/error.pxi @@ -89,7 +89,7 @@ cdef class Error(Object): The traceback to update. """ cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback)) - TVMFFIErrorUpdateTraceback(self.chandle, traceback_arg.cptr()) + TVMFFIErrorGetCellPtr(self.chandle).update_traceback(self.chandle, traceback_arg.cptr()) def py_error(self): """