This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s3 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ab7a2c3e618c7f5d046e47a319fc52d4d3be3727 Author: tqchen <[email protected]> AuthorDate: Sun May 4 13:41:42 2025 -0400 Bring custom hook to ffi layer --- python/tvm/_ffi/__init__.py | 2 +- python/tvm/_ffi/base.py | 9 ++-- python/tvm/ffi/cython/dtype.pxi | 2 +- python/tvm/ffi/cython/object.pxi | 62 ++++++++++++++++++++++ python/tvm/ffi/cython/string.pxi | 2 + python/tvm/ffi/dtype.py | 2 +- python/tvm/ir/container.py | 2 + python/tvm/meta_schedule/cost_model/mlp_model.py | 4 +- python/tvm/rpc/client.py | 2 +- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/_ffi_node_api.py | 3 -- python/tvm/runtime/container.py | 3 ++ python/tvm/runtime/device.py | 5 +- python/tvm/runtime/module.py | 2 +- python/tvm/runtime/ndarray.py | 2 +- python/tvm/runtime/object.py | 62 +++++----------------- python/tvm/tir/expr.py | 4 +- src/node/structural_hash.cc | 30 +++++++++++ src/support/ffi_testing.cc | 1 - tests/python/ffi/test_string.py | 9 ++++ .../python/tvmscript/test_tvmscript_printer_doc.py | 2 - .../test_tvmscript_printer_structural_equal.py | 5 +- 22 files changed, 140 insertions(+), 77 deletions(-) diff --git a/python/tvm/_ffi/__init__.py b/python/tvm/_ffi/__init__.py index 4e8e59b3a8..559ca84635 100644 --- a/python/tvm/_ffi/__init__.py +++ b/python/tvm/_ffi/__init__.py @@ -28,4 +28,4 @@ from . import _pyversion from . import base from .registry import register_object, register_func from .registry import _init_api, get_global_func -from tvm.ffi import register_error +from ..ffi import register_error diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 76293cf125..18ed40fb4c 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -18,13 +18,9 @@ # pylint: disable=invalid-name, import-outside-toplevel """Base library for TVM FFI.""" import ctypes -import functools import os -import re import sys -import types -from typing import Callable, Sequence, Optional import numpy as np @@ -64,10 +60,11 @@ _LIB, _LIB_NAME = _load_lib() # Whether we are runtime only _RUNTIME_ONLY = "runtime" in _LIB_NAME -import tvm.ffi.registry if _RUNTIME_ONLY: - tvm.ffi.registry._SKIP_UNKNOWN_OBJECTS = True + from ..ffi import registry as _tvm_ffi_registry + + _tvm_ffi_registry._SKIP_UNKNOWN_OBJECTS = True # The FFI mode of TVM _FFI_MODE = os.environ.get("TVM_FFI", "auto") diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi index ec045fce65..bbf9e60053 100644 --- a/python/tvm/ffi/cython/dtype.pxi +++ b/python/tvm/ffi/cython/dtype.pxi @@ -29,7 +29,7 @@ def _create_dtype_from_tuple(cls, code, bits, lanes): cdtype.bits = bits cdtype.lanes = lanes ret = cls.__new__(cls) - (<DLDataType>ret).cdtype = cdtype + (<DataType>ret).cdtype = cdtype return ret diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index 9abbb40195..1ac32c3bc6 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -32,6 +32,31 @@ def __object_repr__(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" +def __object_save_json__(obj): + """Object repr function that can be overridden by assigning to it""" + raise NotImplementedError("JSON serialization depends on downstream init") + + +def __object_load_json__(json_str): + """Object repr function that can be overridden by assigning to it""" + raise NotImplementedError("JSON serialization depends on downstream init") + + +def __object_dir__(obj): + """Object dir function that can be overridden by assigning to it""" + return [] + + +def __object_getattr__(obj, name): + """Object getattr function that can be overridden by assigning to it""" + raise AttributeError() + + +def _new_object(cls): + """Helper function for pickle""" + return cls.__new__(cls) + + class ObjectGeneric: """Base class for all classes that can be converted to object.""" @@ -71,9 +96,46 @@ cdef class Object: cdef uint64_t chandle = <uint64_t>self.chandle return chandle + def __reduce__(self): + cls = type(self) + return (_new_object, (cls,), self.__getstate__()) + + def __getstate__(self): + if not self.__chandle__() == 0: + # need to explicit convert to str in case String + # returned and triggered another infinite recursion in get state + return {"handle": str(__object_save_json__(self))} + return {"handle": None} + + def __setstate__(self, state): + # pylint: disable=assigning-non-slot, assignment-from-no-return + handle = state["handle"] + if handle is not None: + self.__init_handle_by_constructor__(__object_load_json__, handle) + else: + self.chandle = NULL + + def __getattr__(self, name): + try: + return __object_getattr__(self, name) + except AttributeError: + raise AttributeError(f"{type(self)} has no attribute {name}") + + def __dir__(self): + return __object_dir__(self) + def __repr__(self): return __object_repr__(self) + def __eq__(self, other): + return self.same_as(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __init_handle_by_load_json__(self, json_str): + raise NotImplementedError("JSON serialization depends on downstream init") + def __init_handle_by_constructor__(self, fconstructor, *args): """Initialize the handle by calling constructor function. diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi index 00ec92b7ec..733ea90301 100644 --- a/python/tvm/ffi/cython/string.pxi +++ b/python/tvm/ffi/cython/string.pxi @@ -65,6 +65,7 @@ class String(str, PyNativeObject): val.__tvm_ffi_object__ = obj return val + _register_object_by_index(kTVMFFIStr, String) @@ -89,6 +90,7 @@ class Bytes(bytes, PyNativeObject): val.__tvm_ffi_object__ = obj return val + _register_object_by_index(kTVMFFIBytes, Bytes) # We special handle str/bytes constructor in cython to avoid extra cyclic deps diff --git a/python/tvm/ffi/dtype.py b/python/tvm/ffi/dtype.py index ca98726560..56b888316d 100644 --- a/python/tvm/ffi/dtype.py +++ b/python/tvm/ffi/dtype.py @@ -17,9 +17,9 @@ """dtype class.""" # pylint: disable=invalid-name from enum import IntEnum +import numpy as np from . import core -import numpy as np class DataTypeCode(IntEnum): diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index 6a013bce8c..4bc6fcae21 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -16,3 +16,5 @@ # under the License. """Additional container data structures used across IR variants.""" from tvm.ffi import Array, Map + +__all__ = ["Array", "Map"] diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py index 8bd050b689..4ee5ba838d 100644 --- a/python/tvm/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -541,7 +541,9 @@ class State: "_workload.json", "_candidates.json" ), ) - except tvm._ffi.base.TVMError: # pylint: disable=protected-access + except ( + tvm._ffi.base.TVMError + ): # pylint: disable=protected-access,broad-exception-caught continue candidates, results = [], [] tuning_records = database.get_all_tuning_records() diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index cf9706c348..f9e677e49e 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=used-before-assignment +# pylint: disable=used-before-assignment,broad-exception-caught """RPC client tools""" import os import socket diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index a630ce1101..774c8dd635 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -41,7 +41,7 @@ from .params import ( load_param_dict_from_file, ) -from tvm.ffi import convert, dtype as DataType, DataTypeCode from . import disco from .support import _regex_match +from ..ffi import convert, dtype as DataType, DataTypeCode diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 9623a87351..395496d16b 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -48,6 +48,3 @@ def LoadJSON(json_str): # Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix. # e.g. TVM_REGISTER_GLOBAL("node.AsRepr") tvm._ffi._init_api("node", __name__) - -# override the default repr function for tvm.ffi.core.Object -tvm.ffi.core.__object_repr__ = AsRepr diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 052701f0d3..3bf149d6b2 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,3 +16,6 @@ # under the License. """Runtime container structures.""" from tvm.ffi import String, Shape as ShapeTuple + + +__all__ = ["ShapeTuple", "String"] diff --git a/python/tvm/runtime/device.py b/python/tvm/runtime/device.py index b83bb8cceb..d9d6abce50 100644 --- a/python/tvm/runtime/device.py +++ b/python/tvm/runtime/device.py @@ -16,13 +16,12 @@ # under the License. """Common runtime ctypes.""" # pylint: disable=invalid-name +import json + import tvm.ffi -from tvm.ffi import dtype as DataType, DataTypeCode from . import _ffi_api -import json - RPC_SESS_MASK = 128 diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 657120f907..bb1fbb5fe3 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -102,7 +102,7 @@ class Module(tvm.ffi.Object): """Runtime Module.""" def __new__(cls): - instance = super().__new__(cls) + instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter instance.entry_name = "__tvm_main__" instance._entry = None return instance diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index f1a7d1c2f7..5581adbbc1 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -33,7 +33,7 @@ import tvm.ffi from . import _ffi_api -from tvm.ffi import ( +from ..ffi import ( device, cpu, cuda, diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 2aa2c08632..688682d197 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -17,58 +17,22 @@ # pylint: disable=invalid-name, unused-import """Runtime Object API""" +from tvm.ffi.core import Object import tvm.ffi.core +from . import _ffi_node_api -from . import _ffi_api, _ffi_node_api - -def _new_object(cls): - """Helper function for pickle""" - return cls.__new__(cls) - - -class Object(tvm.ffi.core.Object): - """Base class for all tvm's runtime objects.""" - - __slots__ = [] - - def __dir__(self): - class_names = dir(self.__class__) - fnames = _ffi_node_api.NodeListAttrNames(self) - size = fnames(-1) - return sorted([fnames(i) for i in range(size)] + class_names) - - def __getattr__(self, name): - try: - return _ffi_node_api.NodeGetAttr(self, name) - except AttributeError: - raise AttributeError(f"{type(self)} has no attribute {name}") from None - - def __hash__(self): - return _ffi_api.ObjectPtrHash(self) - - def __eq__(self, other): - return self.same_as(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __reduce__(self): - cls = type(self) - return (_new_object, (cls,), self.__getstate__()) - - def __getstate__(self): - if not self.__chandle__() == 0: - # need to explicit convert to str in case String - # returned and triggered another infinite recursion in get state - return {"handle": str(_ffi_node_api.SaveJSON(self))} - return {"handle": None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot, assignment-from-no-return - handle = state["handle"] - if handle is not None: - self.__init_handle_by_constructor__(_ffi_node_api.LoadJSON, handle) +def __object_dir__(obj): + class_names = dir(obj.__class__) + fnames = _ffi_node_api.NodeListAttrNames(obj) + size = fnames(-1) + return sorted([fnames(i) for i in range(size)] + class_names) tvm.ffi.core._set_class_object(Object) +# override the default repr function for tvm.ffi.core.Object +tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr +tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON +tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON +tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr +tvm.ffi.core.__object_dir__ = __object_dir__ diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 13e10ba3ac..b293343cae 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -69,7 +69,7 @@ def _dtype_is_float(value): ) # type: ignore -class ExprOp(object): +class ExprOp: """Operator overloading for Expr like expressions.""" # TODO(tkonolige): use inspect to add source information to these objects @@ -395,7 +395,7 @@ class SizeVar(Var): @tvm._ffi.register_object("tir.IterVar") -class IterVar(Object, ExprOp, Scriptable): +class IterVar(ExprOp, Object, Scriptable): """Represent iteration variable. IterVar represents axis iterations in the computation. diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index bba8a59647..94e6768203 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -328,6 +328,22 @@ struct StringObjTrait { } }; +struct BytesObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) { + hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); + } + + static bool SEqualReduce(const ffi::BytesObj* lhs, const ffi::BytesObj* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + if (lhs->size != rhs->size) return false; + if (lhs->data == rhs->data) return true; + return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; + } +}; + struct RefToObjectPtr : public ObjectRef { static ObjectPtr<Object> Get(const ObjectRef& ref) { return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref); @@ -350,6 +366,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; }); +TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait) + .set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return GetRef<ffi::Bytes>(static_cast<const ffi::BytesObj*>(n)).operator std::string(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<ffi::BytesObj>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const ffi::BytesObj*>(node.get()); + p->stream << "b\"" << support::StrEscape(op->data, op->size) << '"'; + }); + struct ModuleNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; static constexpr const std::nullptr_t SHashReduce = nullptr; diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index cfd5c42f6e..9e727f06d4 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -71,7 +71,6 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") *ret = result; }); - TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") .set_body_packed([](TVMArgs args, TVMRetValue* ret) { auto msg = args[0].cast<std::string>(); diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py index 67c82d52f7..98eab5bcb7 100644 --- a/tests/python/ffi/test_string.py +++ b/tests/python/ffi/test_string.py @@ -1,3 +1,4 @@ +import pickle from tvm import ffi as tvm_ffi @@ -12,6 +13,10 @@ def test_string(): assert isinstance(s3, tvm_ffi.String) assert isinstance(s3, str) + s4 = pickle.loads(pickle.dumps(s)) + assert s4 == "hello" + assert isinstance(s4, tvm_ffi.String) + def test_bytes(): fecho = tvm_ffi.get_global_func("testing.echo") @@ -27,3 +32,7 @@ def test_bytes(): b4 = tvm_ffi.convert(bytearray(b"hello")) assert isinstance(b4, tvm_ffi.Bytes) assert isinstance(b4, bytes) + + b5 = pickle.loads(pickle.dumps(b)) + assert b5 == b"hello" + assert isinstance(b5, tvm_ffi.Bytes) diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py index 6353627c58..e3d1280b32 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_doc.py +++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py @@ -307,7 +307,6 @@ def test_assign_doc(lhs, rhs, annotation): def test_invalid_assign_doc(lhs, rhs, annotation): with pytest.raises(ValueError) as e: AssignDoc(lhs, rhs, annotation) - assert "AssignDoc" in str(e.value) @pytest.mark.parametrize( @@ -332,7 +331,6 @@ def test_if_doc(then_branch, else_branch): if not then_branch and not else_branch: with pytest.raises(ValueError) as e: IfDoc(predicate, then_branch, else_branch) - assert "IfDoc" in str(e.value) return else: doc = IfDoc(predicate, then_branch, else_branch) diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index 6f67733a28..58d9402e6f 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -24,12 +24,11 @@ from tvm.script import ir as I, tir as T def _error_message(exception): - splitter = "ValueError: StructuralEqual" - return splitter + str(exception).split(splitter)[1] + return str(exception) def _expected_result(func1, func2, objpath1, objpath2): - return f"""ValueError: StructuralEqual check failed, caused by lhs at {objpath1}: + return f"""StructuralEqual check failed, caused by lhs at {objpath1}: {func1.script(path_to_underline=[objpath1], syntax_sugar=False)} and rhs at {objpath2}: {func2.script(path_to_underline=[objpath2], syntax_sugar=False)}"""
