This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s3 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit a81a3f313b3b7a894ae61d27f84ad51e90322502 Author: tqchen <[email protected]> AuthorDate: Sun May 4 21:35:37 2025 -0400 [CYTHON] Phase out legacy FFI --- python/setup.py | 21 -- python/tvm/_ffi/_cy3/__init__.py | 17 -- python/tvm/_ffi/_cython/base.pxi | 236 -------------------- python/tvm/_ffi/_cython/core.pyx | 21 -- python/tvm/_ffi/_cython/ndarray.pxi | 180 --------------- python/tvm/_ffi/_cython/object.pxi | 165 -------------- python/tvm/_ffi/_cython/packed_func.pxi | 376 -------------------------------- 7 files changed, 1016 deletions(-) diff --git a/python/setup.py b/python/setup.py index 02cac5f897..3900e5b3d0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -148,7 +148,6 @@ def config_cython(): subdir = "_cy3" ret = [] - cython_source = "tvm/_ffi/_cython" extra_compile_args = ["-std=c++17", "-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>"] if os.name == "nt": library_dirs = ["tvm", "../build/Release", "../build"] @@ -164,26 +163,6 @@ def config_cython(): library_dirs = None libraries = None - for fn in os.listdir(cython_source): - if not fn.endswith(".pyx"): - continue - ret.append( - Extension( - "tvm._ffi.%s.%s" % (subdir, fn[:-4]), - ["tvm/_ffi/_cython/%s" % fn], - include_dirs=[ - "../include/", - "../3rdparty/dmlc-core/include", - "../ffi/include/", - "../ffi/3rdparty/dlpack/include", - ], - extra_compile_args=extra_compile_args, - library_dirs=library_dirs, - libraries=libraries, - language="c++", - ) - ) - # the latest ffi source for fn in os.listdir("tvm/ffi/cython"): if not fn.endswith(".pyx"): diff --git a/python/tvm/_ffi/_cy3/__init__.py b/python/tvm/_ffi/_cy3/__init__.py deleted file mode 100644 index 159f3254db..0000000000 --- a/python/tvm/_ffi/_cy3/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""cython3 namespace""" diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi deleted file mode 100644 index 036a9c9559..0000000000 --- a/python/tvm/_ffi/_cython/base.pxi +++ /dev/null @@ -1,236 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from ..base import raise_last_ffi_error -from libcpp cimport bool as bool_t -from libcpp.vector cimport vector -from cpython.version cimport PY_MAJOR_VERSION -from cpython cimport pycapsule -from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t -import ctypes - -cdef enum TVMArgTypeCode: - kInt = 0 - kUInt = 1 - kFloat = 2 - kTVMOpaqueHandle = 3 - kTVMNullptr = 4 - kTVMDataType = 5 - kDLDevice = 6 - kTVMDLTensorHandle = 7 - kTVMObjectHandle = 8 - kTVMModuleHandle = 9 - kTVMPackedFuncHandle = 10 - kTVMStr = 11 - kTVMBytes = 12 - kTVMNDArrayHandle = 13 - kTVMObjectRefArg = 14 - kTVMArgBool = 15 - kTVMExtBegin = 16 - -cdef extern from "tvm/runtime/c_runtime_api.h": - ctypedef struct DLDataType: - uint8_t code - uint8_t bits - uint16_t lanes - - ctypedef struct DLDevice: - int device_type - int device_id - - ctypedef struct DLTensor: - void* data - DLDevice device - int ndim - DLDataType dtype - int64_t* shape - int64_t* strides - uint64_t byte_offset - - ctypedef struct DLManagedTensor: - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensor* self) - - ctypedef struct TVMValue: - int64_t v_int64 - bool_t v_bool - double v_float64 - void* v_handle - const char* v_str - DLDataType v_type - DLDevice v_device - -ctypedef int64_t tvm_index_t -ctypedef DLTensor* DLTensorHandle -ctypedef void* TVMStreamHandle -ctypedef void* TVMRetValueHandle -ctypedef void* TVMPackedFuncHandle -ctypedef void* ObjectHandle - -ctypedef struct TVMObject: - uint32_t type_index_ - int32_t ref_counter_ - void (*deleter_)(TVMObject* self) - - -ctypedef int (*TVMPackedCFunc)( - TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, - void* resource_handle) - -ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) - -# NOTE: All of TVM's C API function can be called without gil. -# for API functions that can be run long(e.g. FuncCall) -# we need to explicitly release the GIL as follows. -# -# cdef myfunc(): -# cdef int c_api_ret_code -# with nogil: -# c_api_ret_code = TVMAPIFunc(...) -# CHECK_CALL(c_apt_ret_code) -# -# Explicitly releasing the GIL enables other python threads -# to continue running while we are in TVMAPIFunc. -# Not releasing GIL explicitly is OK(and perhaps desirable) -# for short-running functions, as frequent unlocking also takes time, -# the python interpreter will release GIL in a set period. -# -# We mark the possibly long running function as nogil below. -cdef extern from "tvm/runtime/c_runtime_api.h": - void TVMAPISetLastError(const char* msg) - void TVMAPISetLastPythonError(void* py_object) except + - const char *TVMGetLastError() - int TVMFuncGetGlobal(const char* name, - TVMPackedFuncHandle* out) - int TVMFuncCall(TVMPackedFuncHandle func, - TVMValue* arg_values, - int* type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code) nogil - int TVMFuncFree(TVMPackedFuncHandle func) - int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret) - int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMPackedFuncHandle *out) - int TVMCbArgToReturn(TVMValue* value, int* code) - int TVMArrayAlloc(tvm_index_t* shape, - tvm_index_t ndim, - DLDataType dtype, - DLDevice dev, - DLTensorHandle* out) nogil - int TVMArrayFree(DLTensorHandle handle) nogil - int TVMArrayCopyFromTo(DLTensorHandle src, - DLTensorHandle to, - TVMStreamHandle stream) nogil - int TVMArrayFromDLPack(DLManagedTensor* arr_from, - DLTensorHandle* out) nogil - int TVMArrayToDLPack(DLTensorHandle arr_from, - DLManagedTensor** out) nogil - void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) - int TVMObjectFree(ObjectHandle obj) - int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) - - -# the new FFI C API -cdef extern from "tvm/ffi/c_api.h": - ctypedef void* TVMFFIObjectHandle - - ctypedef struct TVMFFIByteArray: - const char* data - int64_t size - - int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil - TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil - - -cdef inline py_str(const char* x): - return x.decode("utf-8") - - -cdef inline c_str(pystr): - """Create ctypes char * from a python string - Parameters - ---------- - string : string type - python string - - Returns - ------- - str : c_char_p - A char pointer that can be passed to C API - """ - return pystr.encode("utf-8") - - -cdef inline int CHECK_CALL(int ret) except -2: - """Check the return code of the C API function call""" - # -2 brings exception - if ret == -2: - return -2 - if ret != 0: - raise_last_ffi_error() - return 0 - - -cdef inline object ctypes_handle(void* chandle): - """Cast C handle to ctypes handle.""" - return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p) - - -cdef inline void* c_handle(object handle): - """Cast C types handle to c handle.""" - cdef unsigned long long v_ptr - v_ptr = handle.value - return <void*>(v_ptr) - - -# python env API -cdef extern from "Python.h": - int PyErr_CheckSignals() - void* PyGILState_Ensure() - void PyGILState_Release(void*) - void Py_IncRef(void*) - void Py_DecRef(void*) - -cdef extern from "tvm/runtime/c_backend_api.h": - int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) - -cdef _init_env_api(): - # Initialize env api for signal handling - # so backend can call tvm::runtime::EnvCheckSignals to check - # signal when executing a long running function. - # - # Also registers the gil state release and ensure as PyErr_CheckSignals - # function is called with gil released and we need to regrab the gil - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"), <void*>PyErr_CheckSignals)) - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)) - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release)) - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release)) - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), <void*>Py_IncRef)) - CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), <void*>Py_DecRef)) - - -_init_env_api() diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx deleted file mode 100644 index 730f8fc133..0000000000 --- a/python/tvm/_ffi/_cython/core.pyx +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -include "./base.pxi" -include "./object.pxi" -include "./packed_func.pxi" -include "./ndarray.pxi" diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi deleted file mode 100644 index f220e866d3..0000000000 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ /dev/null @@ -1,180 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from ..runtime_ctypes import TVMArrayHandle -from cpython cimport PyCapsule_Destructor - -cdef const char* _c_str_dltensor = "dltensor" -cdef const char* _c_str_used_dltensor = "used_dltensor" - - -cdef void _c_dlpack_deleter(object pycaps): - cdef DLManagedTensor* dltensor - if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor): - dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor) - dltensor.deleter(dltensor) - - -def _from_dlpack(object dltensor): - cdef DLManagedTensor* ptr - cdef DLTensorHandle chandle - cdef int c_api_ret_code - if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): - ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - with nogil: - c_api_ret_code = TVMArrayFromDLPack(ptr, &chandle) - CHECK_CALL(c_api_ret_code) - # set name and destructor to be empty - pycapsule.PyCapsule_SetDestructor(dltensor, NULL) - pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) - return c_make_array(chandle, False, False) - raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once") - - -cdef class NDArrayBase: - cdef DLTensor* chandle - cdef int c_is_view - - cdef inline _set_handle(self, handle): - cdef unsigned long long ptr - if handle is None: - self.chandle = NULL - else: - ptr = ctypes.cast(handle, ctypes.c_void_p).value - self.chandle = <DLTensor*>(ptr) - - property _tvm_handle: - def __get__(self): - return <unsigned long long>self.chandle - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes.cast( - <unsigned long long>self.chandle, TVMArrayHandle) - - def __set__(self, value): - self._set_handle(value) - - property is_view: - def __get__(self): - return self.c_is_view != 0 - - @property - def shape(self): - """Shape of this array""" - return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim)) - - def __init__(self, handle, is_view): - self._set_handle(handle) - self.c_is_view = is_view - - def __dealloc__(self): - cdef int c_api_ret_code - if self.c_is_view == 0: - with nogil: - c_api_ret_code = TVMArrayFree(self.chandle) - CHECK_CALL(c_api_ret_code) - - def _copyto(self, target_nd): - """Internal function that implements copy to target ndarray.""" - cdef int c_api_ret_code - with nogil: - c_api_ret_code = TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL) - CHECK_CALL(c_api_ret_code) - return target_nd - - def to_dlpack(self): - """Produce an array from a DLPack Tensor without copying memory - - Returns - ------- - dlpack : DLPack tensor view of the array data - """ - cdef DLManagedTensor* dltensor - cdef int c_api_ret_code - if self.c_is_view != 0: - raise ValueError("to_dlpack do not work with memory views") - with nogil: - c_api_ret_code = TVMArrayToDLPack(self.chandle, &dltensor) - CHECK_CALL(c_api_ret_code) - return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, <PyCapsule_Destructor>_c_dlpack_deleter) - - -# Import limited object-related function from C++ side to improve the speed -# NOTE: can only use POD-C compatible object in FFI. -cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime": - cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle) - - -cdef c_make_array(void* chandle, is_view, is_container): - global _TVM_ND_CLS - - if is_container: - tindex = ( - <TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_ - if tindex < len(_TVM_ND_CLS): - cls = _TVM_ND_CLS[tindex] - if cls is not None: - ret = cls.__new__(cls) - else: - ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) - else: - ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) - (<NDArrayBase>ret).chandle = <DLTensor*>chandle - (<NDArrayBase>ret).c_is_view = <int>is_view - return ret - else: - ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) - (<NDArrayBase>ret).chandle = <DLTensor*>chandle - (<NDArrayBase>ret).c_is_view = <int>is_view - return ret - - -cdef _TVM_COMPATS = () - -cdef _TVM_EXT_RET = {} - -def _reg_extension(cls, fcreate): - global _TVM_COMPATS - _TVM_COMPATS += (cls,) - if fcreate: - _TVM_EXT_RET[cls._tvm_tcode] = fcreate - -cdef list _TVM_ND_CLS = [] - -cdef _register_ndarray(int index, object cls): - """register object class""" - global _TVM_ND_CLS - while len(_TVM_ND_CLS) <= index: - _TVM_ND_CLS.append(None) - - _TVM_ND_CLS[index] = cls - - -def _make_array(handle, is_view, is_container): - cdef unsigned long long ptr - ptr = ctypes.cast(handle, ctypes.c_void_p).value - return c_make_array(<void*>ptr, is_view, is_container) - -cdef object _CLASS_NDARRAY = None - -def _set_class_ndarray(cls): - global _CLASS_NDARRAY - _CLASS_NDARRAY = cls diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi deleted file mode 100644 index e45a13365c..0000000000 --- a/python/tvm/_ffi/_cython/object.pxi +++ /dev/null @@ -1,165 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Maps object type index to its constructor""" -cdef list OBJECT_TYPE = [] -"""Maps object type to its type index""" -cdef dict OBJECT_INDEX = {} - -def _register_object(int index, object cls): - """register object class""" - if issubclass(cls, NDArrayBase): - _register_ndarray(index, cls) - return - - global OBJECT_TYPE - while len(OBJECT_TYPE) <= index: - OBJECT_TYPE.append(None) - OBJECT_TYPE[index] = cls - OBJECT_INDEX[cls] = index - -def _get_object_type_index(object cls): - """get the type index of object class""" - return OBJECT_INDEX.get(cls) - -cdef inline object make_ret_object(void* chandle): - global OBJECT_TYPE - global _CLASS_OBJECT - cdef unsigned tindex - cdef object cls - cdef object handle - object_type = OBJECT_TYPE - handle = ctypes_handle(chandle) - tindex = TVMFFIObjectGetTypeIndex(chandle) - - if tindex < len(OBJECT_TYPE): - cls = OBJECT_TYPE[tindex] - if cls is not None: - if issubclass(cls, PyNativeObject): - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - (<ObjectBase>obj).chandle = chandle - return cls.__from_tvm_object__(cls, obj) - obj = cls.__new__(cls) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - - (<ObjectBase>obj).chandle = chandle - - # Handle return values that must be converted from the TVM object - # to a python native object. This should be used in cases where - # subclassing the python native object is forbidden. For example, - # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does - # not allow any subclasses. - # if hasattr(obj, '__into_pynative_object__'): - # return obj.__into_pynative_object__) - - return obj - # return obj.__into_pynative_object__() - - -class PyNativeObject: - """Base class of all TVM objects that also subclass python's builtin types.""" - __slots__ = [] - - def __init_tvm_object_by_constructor__(self, fconstructor, *args): - """Initialize the internal tvm_object by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return object is directly set into the object - """ - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - obj.__init_handle_by_constructor__(fconstructor, *args) - self.__tvm_object__ = obj - - -cdef class ObjectBase: - cdef void* chandle - - cdef inline _set_handle(self, handle): - cdef unsigned long long ptr - if handle is None: - self.chandle = NULL - else: - ptr = handle.value - self.chandle = <void*>(ptr) - - property handle: - def __get__(self): - return ctypes_handle(self.chandle) - - def __set__(self, value): - self._set_handle(value) - - def __dealloc__(self): - CHECK_CALL(TVMObjectFree(self.chandle)) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (<PackedFuncBase>fconstructor).chandle, args, &chandle) - self.chandle = chandle - - def same_as(self, other): - """Check object identity. - - Parameters - ---------- - other : object - The other object to compare against. - - Returns - ------- - result : bool - The comparison result. - """ - if not isinstance(other, ObjectBase): - return False - return self.chandle == (<ObjectBase>other).chandle - -def StringGetPyString(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((<ObjectBase>obj).chandle) - return py_str(bytes.data) diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi deleted file mode 100644 index 19af447a11..0000000000 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ /dev/null @@ -1,376 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import ctypes -import traceback -from cpython cimport Py_INCREF, Py_DECREF, PyGILState_Ensure, PyGILState_Release -from numbers import Number, Integral -from ..base import string_types, py2cerror -from ..runtime_ctypes import DataType, Device, TVMByteArray, ObjectRValueRef - - -cdef void tvm_callback_finalize(void* fhandle) with gil: - local_pyfunc = <object>(fhandle) - Py_DECREF(local_pyfunc) - -cdef int tvm_callback(TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, - void* fhandle) with gil: - cdef list pyargs - cdef TVMValue value - cdef int tcode - local_pyfunc = <object>(fhandle) - pyargs = [] - for i in range(num_args): - value = args[i] - tcode = type_codes[i] - if (tcode == kTVMObjectHandle or - tcode == kTVMPackedFuncHandle or - tcode == kTVMModuleHandle or - tcode == kTVMNDArrayHandle or - tcode == kTVMObjectRefArg or - tcode >= kTVMExtBegin): - CHECK_CALL(TVMCbArgToReturn(&value, &tcode)) - - if tcode != kTVMDLTensorHandle: - pyargs.append(make_ret(value, tcode)) - else: - pyargs.append(c_make_array(value.v_handle, True, False)) - try: - rv = local_pyfunc(*pyargs) - except Exception as err: - msg = traceback.format_exc() - msg = py2cerror(msg) - TVMAPISetLastPythonError(<void*>err) - - return -1 - if rv is not None: - if isinstance(rv, tuple): - raise ValueError("PackedFunction can only support one return value") - temp_args = [] - make_arg(rv, &value, &tcode, temp_args) - CHECK_CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1)) - return 0 - - -cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global): - obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC) - (<PackedFuncBase>obj).chandle = chandle - (<PackedFuncBase>obj).is_global = is_global - return obj - - -def convert_to_tvm_func(object pyfunc): - """Convert a python function to TVM function - - Parameters - ---------- - pyfunc : python function - The python function to be converted. - - Returns - ------- - tvmfunc: tvm.Function - The converted tvm function. - """ - cdef TVMPackedFuncHandle chandle - Py_INCREF(pyfunc) - CHECK_CALL(TVMFuncCreateFromCFunc(tvm_callback, - <void*>(pyfunc), - tvm_callback_finalize, - &chandle)) - return make_packed_func(chandle, False) - - -cdef inline int make_arg(object arg, - TVMValue* value, - int* tcode, - list temp_args) except -1: - """Pack arguments into c args tvm call accept""" - cdef unsigned long long ptr - if isinstance(arg, ObjectBase): - value[0].v_handle = (<ObjectBase>arg).chandle - tcode[0] = kTVMObjectHandle - elif isinstance(arg, NDArrayBase): - value[0].v_handle = (<NDArrayBase>arg).chandle - tcode[0] = (kTVMNDArrayHandle if - not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle) - elif isinstance(arg, PyNativeObject): - value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle - tcode[0] = kTVMObjectHandle - elif isinstance(arg, _TVM_COMPATS): - ptr = arg._tvm_handle - value[0].v_handle = (<void*>ptr) - tcode[0] = arg.__class__._tvm_tcode - elif isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. - value[0].v_int64 = arg - tcode[0] = kTVMArgBool - elif isinstance(arg, Integral): - value[0].v_int64 = arg - tcode[0] = kInt - elif isinstance(arg, float): - value[0].v_float64 = arg - tcode[0] = kFloat - elif isinstance(arg, str): - tstr = c_str(arg) - value[0].v_str = tstr - tcode[0] = kTVMStr - temp_args.append(tstr) - elif arg is None: - value[0].v_handle = NULL - tcode[0] = kTVMNullptr - elif isinstance(arg, Number): - value[0].v_float64 = arg - tcode[0] = kFloat - elif isinstance(arg, DataType): - tstr = c_str(str(arg)) - value[0].v_str = tstr - tcode[0] = kTVMStr - temp_args.append(tstr) - elif isinstance(arg, Device): - value[0].v_device = (<DLDevice*>( - <unsigned long long>ctypes.addressof(arg)))[0] - tcode[0] = kDLDevice - elif isinstance(arg, (bytes, bytearray)): - # from_buffer only taeks in bytearray. - if isinstance(arg, bytes): - byte_arr = bytearray(arg) - temp_args.append(byte_arr) - arg = byte_arr - - arr = TVMByteArray() - arr.data = ctypes.cast( - (ctypes.c_byte * len(arg)).from_buffer(arg), - ctypes.POINTER(ctypes.c_byte)) - arr.size = len(arg) - value[0].v_handle = <void*>( - <unsigned long long>ctypes.addressof(arr)) - tcode[0] = kTVMBytes - temp_args.append(arr) - elif isinstance(arg, string_types): - tstr = c_str(arg) - value[0].v_str = tstr - tcode[0] = kTVMStr - temp_args.append(tstr) - elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): - arg = _FUNC_CONVERT_TO_OBJECT(arg) - value[0].v_handle = (<ObjectBase>arg).chandle - tcode[0] = kTVMObjectHandle - temp_args.append(arg) - elif isinstance(arg, _CLASS_MODULE): - value[0].v_handle = c_handle(arg.handle) - tcode[0] = kTVMModuleHandle - elif isinstance(arg, PackedFuncBase): - value[0].v_handle = (<PackedFuncBase>arg).chandle - tcode[0] = kTVMPackedFuncHandle - elif isinstance(arg, ctypes.c_void_p): - value[0].v_handle = c_handle(arg) - tcode[0] = kTVMOpaqueHandle - elif isinstance(arg, ObjectRValueRef): - value[0].v_handle = &((<ObjectBase>(arg.obj)).chandle) - tcode[0] = kTVMObjectRefArg - elif callable(arg): - arg = convert_to_tvm_func(arg) - value[0].v_handle = (<PackedFuncBase>arg).chandle - tcode[0] = kTVMPackedFuncHandle - temp_args.append(arg) - else: - raise TypeError("Don't know how to handle type %s" % type(arg)) - return 0 - - -cdef inline bytearray make_ret_bytes(void* chandle): - handle = ctypes_handle(chandle) - arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0] - size = arr.size - res = bytearray(size) - rptr = (ctypes.c_byte * size).from_buffer(res) - if not ctypes.memmove(rptr, arr.data, size): - raise RuntimeError('memmove failed') - return res - - -cdef inline object make_ret(TVMValue value, int tcode): - """convert result to return value.""" - if tcode == kTVMObjectHandle: - return make_ret_object(value.v_handle) - elif tcode == kTVMNullptr: - return None - elif tcode == kTVMArgBool: - return bool(value.v_int64) - elif tcode == kInt: - return value.v_int64 - elif tcode == kFloat: - return value.v_float64 - elif tcode == kTVMNDArrayHandle: - return c_make_array(value.v_handle, False, True) - elif tcode == kTVMStr: - return py_str(value.v_str) - elif tcode == kTVMBytes: - return make_ret_bytes(value.v_handle) - elif tcode == kTVMOpaqueHandle: - return ctypes_handle(value.v_handle) - elif tcode == kDLDevice: - return Device(value.v_device.device_type, value.v_device.device_id) - elif tcode == kTVMModuleHandle: - return _CLASS_MODULE(ctypes_handle(value.v_handle)) - elif tcode == kTVMPackedFuncHandle: - return make_packed_func(value.v_handle, False) - elif tcode in _TVM_EXT_RET: - return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) - - raise ValueError("Unhandled type code %d" % tcode) - - -cdef inline int FuncCall3(void* chandle, - tuple args, - int nargs, - TVMValue* ret_val, - int* ret_tcode) except -1: - cdef TVMValue[3] values - cdef int[3] tcodes - nargs = len(args) - temp_args = [] - for i in range(nargs): - make_arg(args[i], &values[i], &tcodes[i], temp_args) - - with nogil: - c_api_ret_code = TVMFuncCall(chandle, &values[0], &tcodes[0], - nargs, ret_val, ret_tcode) - - CHECK_CALL(c_api_ret_code) - return 0 - -cdef inline int FuncCall(void* chandle, - tuple args, - TVMValue* ret_val, - int* ret_tcode) except -1: - cdef int nargs - cdef int c_api_ret_code - nargs = len(args) - if nargs <= 3: - FuncCall3(chandle, args, nargs, ret_val, ret_tcode) - return 0 - - cdef vector[TVMValue] values - cdef vector[int] tcodes - values.resize(max(nargs, 1)) - tcodes.resize(max(nargs, 1)) - temp_args = [] - for i in range(nargs): - make_arg(args[i], &values[i], &tcodes[i], temp_args) - - with nogil: - c_api_ret_code = TVMFuncCall(chandle, &values[0], &tcodes[0], - nargs, ret_val, ret_tcode) - CHECK_CALL(c_api_ret_code) - return 0 - - -cdef inline int ConstructorCall(void* constructor_handle, - tuple args, - void** handle) except -1: - """Call contructor of a handle function""" - cdef TVMValue ret_val - cdef int ret_tcode - FuncCall(constructor_handle, args, &ret_val, &ret_tcode) - handle[0] = ret_val.v_handle - return 0 - - -cdef class PackedFuncBase: - cdef TVMPackedFuncHandle chandle - cdef int is_global - - cdef inline _set_handle(self, handle): - if handle is None: - self.chandle = NULL - else: - self.chandle = c_handle(handle) - - property is_global: - def __get__(self): - return self.c_is_global != 0 - - def __set__(self, value): - self.c_is_global = value - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p) - def __set__(self, value): - self._set_handle(value) - - def __init__(self, handle, is_global): - self._set_handle(handle) - self.c_is_global = is_global - - def __dealloc__(self): - if self.is_global == 0: - CHECK_CALL(TVMFuncFree(self.chandle)) - - def __call__(self, *args): - cdef TVMValue ret_val - cdef int ret_tcode - ret_tcode = kTVMNullptr - FuncCall(self.chandle, args, &ret_val, &ret_tcode) - return make_ret(ret_val, ret_tcode) - - -def _get_global_func(name, allow_missing): - cdef TVMPackedFuncHandle chandle - CHECK_CALL(TVMFuncGetGlobal(c_str(name), &chandle)) - if chandle != NULL: - return make_packed_func(chandle, True) - - if allow_missing: - return None - - raise ValueError("Cannot find global function %s" % name) - - -_CLASS_PACKED_FUNC = None -_CLASS_MODULE = None -_CLASS_OBJECT = None -_CLASS_OBJECT_GENERIC = None -_FUNC_CONVERT_TO_OBJECT = None - -def _set_class_module(module_class): - """Initialize the module.""" - global _CLASS_MODULE - _CLASS_MODULE = module_class - -def _set_class_packed_func(func_class): - global _CLASS_PACKED_FUNC - _CLASS_PACKED_FUNC = func_class - -def _set_class_object(obj_class): - global _CLASS_OBJECT - _CLASS_OBJECT = obj_class - -def _set_class_object_generic(object_generic_class, func_convert_to_object): - global _CLASS_OBJECT_GENERIC - global _FUNC_CONVERT_TO_OBJECT - _CLASS_OBJECT_GENERIC = object_generic_class - _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
