This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refl-clean in repository https://gitbox.apache.org/repos/asf/tvm.git
commit a65c23a65ec2289feb284d28c41b5f2081482dba Author: tqchen <[email protected]> AuthorDate: Tue Aug 5 13:20:32 2025 -0400 [REFACTOR] Phase out getattr based attribute handling This PR phases out getattr-based attribute handling, as it is slower and introduces an extra code path. This does mean that if an Object is not explicitly registered on the Python side, we will no longer be able to access its fields by name. This is likely desirable as well, since we would like to enable faster usage that relies on the Python side being updated rather than on this fallback behavior. --- python/tvm/ffi/cython/object.pxi | 24 ------------ python/tvm/ir/attrs.py | 6 +++ python/tvm/runtime/_ffi_node_api.py | 8 ---- python/tvm/runtime/object.py | 9 ----- python/tvm/te/tensor.py | 20 ---------- python/tvm/testing/__init__.py | 1 + python/tvm/tir/transform/transform.py | 42 ++++++++++++++++++++ src/node/reflection.cc | 70 ---------------------------------- src/tir/transforms/hoist_expression.cc | 4 +- tests/python/ir/test_ir_attrs.py | 4 +- 10 files changed, 54 insertions(+), 134 deletions(-) diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index 4efedf35d8..474ab0980c 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -42,16 +42,6 @@ def __object_load_json__(json_str): raise NotImplementedError("JSON serialization depends on downstream init") -def __object_dir__(obj): - """Object dir function that can be overridden by assigning to it""" - return [] - - -def __object_getattr__(obj, name): - """Object getattr function that can be overridden by assigning to it""" - raise AttributeError() - - def _new_object(cls): """Helper function for pickle""" return cls.__new__(cls) @@ -121,20 +111,6 @@ cdef class Object: else: self.chandle = NULL - def __getattr__(self, name): - if self.chandle == NULL: - raise AttributeError(f"{type(self)} has no attribute {name}") - try: - return __object_getattr__(self, name) - except AttributeError: - raise
AttributeError(f"{type(self)} has no attribute {name}") - - def __dir__(self): - # exception safety handling for chandle=None - if self.chandle == NULL: - return [] - return __object_dir__(self) - def __repr__(self): # exception safety handling for chandle=None if self.chandle == NULL: diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index e7de1a9f90..4d296af063 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -101,6 +101,12 @@ class DictAttrs(Attrs): def __contains__(self, k): return self._dict().__contains__(k) + def __getattr__(self, name): + try: + return self._dict().__getitem__(name) + except KeyError: + raise AttributeError(f"DictAttrs has no attribute {name}") + def items(self): """Get items from the map.""" return self._dict().items() diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index aef9ded9cc..4a0edd449c 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -28,14 +28,6 @@ def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" -def NodeListAttrNames(obj): - return lambda x: 0 - - -def NodeGetAttr(obj, name): - raise AttributeError() - - def SaveJSON(obj): raise RuntimeError("Do not support object serialization in runtime only mode") diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 688682d197..3ba26a8695 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -22,17 +22,8 @@ import tvm.ffi.core from . 
import _ffi_node_api -def __object_dir__(obj): - class_names = dir(obj.__class__) - fnames = _ffi_node_api.NodeListAttrNames(obj) - size = fnames(-1) - return sorted([fnames(i) for i in range(size)] + class_names) - - tvm.ffi.core._set_class_object(Object) # override the default repr function for tvm.ffi.core.Object tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON -tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr -tvm.ffi.core.__object_dir__ = __object_dir__ diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 489ec38ba5..73b995a45e 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -84,26 +84,6 @@ class Tensor(DataProducer, _expr.ExprOp): """Dimension of the tensor.""" return len(self.shape) - @property - def axis(self): - """Axis of the tensor.""" - return self.__getattr__("axis") - - @property - def op(self): - """The corressponding :py:class:`Operation`.""" - return self.__getattr__("op") - - @property - def value_index(self): - """The output value index the tensor corresponds to.""" - return self.__getattr__("value_index") - - @property - def shape(self): - """The output shape of the tensor.""" - return self.__getattr__("shape") - @property def name(self): op = self.op diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index ea798242b4..5f64db0f88 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -43,3 +43,4 @@ from .popen_pool import ( ) from .runner import local_run, rpc_run from .utils import * +from . import attrs diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 81ce63b797..9cb27dc49e 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -23,6 +23,8 @@ from typing import Callable, Optional from . import _ffi_api from . 
import function_pass as _fpass +from ...ffi import register_object +from ...ir import Attrs def Apply(ftransform): @@ -48,6 +50,11 @@ def Apply(ftransform): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore +@register_object("tir.transform.LoopPartitionConfig") +class LoopPartitionConfig(Attrs): + """Config for loop partition pass""" + + def LoopPartition(): """Inject virtual thread loops. @@ -87,6 +94,11 @@ def InjectVirtualThread(): return _ffi_api.InjectVirtualThread() # type: ignore +@register_object("tir.transform.InjectDoubleBufferConfig") +class InjectDoubleBufferConfig(Attrs): + """Config for inject double buffer pass""" + + def InjectDoubleBuffer(): """Inject double buffer statements. @@ -149,6 +161,11 @@ def PointerValueTypeRewrite(): return _ffi_api.PointerValueTypeRewrite() # type: ignore +@register_object("tir.transform.UnrollLoopConfig") +class UnrollLoopConfig(Attrs): + """Config for unroll loop pass""" + + def UnrollLoop(): """Unroll the constant loop marked by unroll. @@ -162,6 +179,11 @@ def UnrollLoop(): return _ffi_api.UnrollLoop() # type: ignore +@register_object("tir.transform.ReduceBranchingThroughOvercomputeConfig") +class ReduceBranchingThroughOvercomputeConfig(Attrs): + """Config for reduce branching through overcompute pass""" + + def ReduceBranchingThroughOvercompute(): """Reduce branching by introducing overcompute @@ -173,6 +195,11 @@ def ReduceBranchingThroughOvercompute(): return _ffi_api.ReduceBranchingThroughOvercompute() # type: ignore +@register_object("tir.transform.RemoveNoOpConfig") +class RemoveNoOpConfig(Attrs): + """Config for remove no op pass""" + + def RemoveNoOp(): """Remove No Op from the Stmt. 
@@ -277,6 +304,11 @@ def RewriteUnsafeSelect(): return _ffi_api.RewriteUnsafeSelect() # type: ignore +@register_object("tir.transform.SimplifyConfig") +class SimplifyConfig(Attrs): + """Config for simplify pass""" + + def Simplify(): """Run arithmetic simplifications on the statements and expressions. @@ -607,6 +639,11 @@ def VerifyVTCMLimit(limit=None): return _ffi_api.VerifyVTCMLimit(limit) # type: ignore +@register_object("tir.transform.HoistIfThenElseConfig") +class HoistIfThenElseConfig(Attrs): + """Config for hoist if then else pass""" + + # pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant: Optional[str] = None): """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. @@ -686,6 +723,11 @@ class HoistedLetBindings(enum.Flag): """ Enable all hoisting of let bindings """ +@register_object("tir.transform.HoistExpressionConfig") +class HoistExpressionConfig(Attrs): + """Config for hoist expression pass""" + + def HoistExpression(): """Generalized verison of HoistIfThenElse. diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 6db751a80f..47ea23eddf 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -33,74 +33,6 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -// Expose to FFI APIs. 
-void NodeGetAttr(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast<Object*>(args[0].cast<const Object*>()); - String field_name = args[1].cast<String>(); - - bool success; - if (field_name == "type_key") { - *ret = self->GetTypeKey(); - success = true; - } else if (!self->IsInstance<DictAttrsNode>()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - success = false; - // use new reflection mechanism - if (type_info->metadata != nullptr) { - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - ffi::reflection::FieldGetter field_getter(field_info); - *ret = field_getter(self); - success = true; - } - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self); - auto it = dnode->dict.find(field_name); - if (it != dnode->dict.end()) { - success = true; - *ret = (*it).second; - } else { - success = false; - } - } - if (!success) { - TVM_FFI_THROW(AttributeError) << self->GetTypeKey() << " object has no attribute `" - << field_name << "`"; - } -} - -void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast<Object*>(args[0].cast<const Object*>()); - - std::vector<String> names; - if (!self->IsInstance<DictAttrsNode>()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - if (type_info->metadata != nullptr) { - // use new reflection mechanism - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - names.push_back(std::string(field_info->name.data, field_info->name.size)); - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self); - for (const auto& kv : dnode->dict) { - names.push_back(kv.first); - } - } - - *ret = ffi::Function::FromPacked([names](ffi::PackedArgs args, ffi::Any* rv) { - int64_t i = args[0].cast<int64_t>(); - if (i == -1) { - *rv = 
static_cast<int64_t>(names.size()); - } else { - *rv = names[i]; - } - }); -} // API function to make node. // args format: @@ -124,8 +56,6 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_packed("node.NodeGetAttr", NodeGetAttr) - .def_packed("node.NodeListAttrNames", NodeListAttrNames) .def_packed("node.MakeNode", MakeNode); }); diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index d89114c68a..1548ea1da6 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -82,7 +82,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter<HoistExpressionCo return static_cast<int>(flag) & hoisted_let_bindings; } - static constexpr const char* _type_key = "tir.transforms.HoistExpressionConfig"; + static constexpr const char* _type_key = "tir.transform.HoistExpressionConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(HoistExpressionConfigNode, Object); }; @@ -112,7 +112,7 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter<HoistIfThenElseCo "Hoist if cond with block scope variables", refl::DefaultValue(false)); } - static constexpr const char* _type_key = "tir.transforms.HoistIfThenElseConfig"; + static constexpr const char* _type_key = "tir.transform.HoistIfThenElseConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(HoistIfThenElseConfigNode, Object); }; diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py index 48c38c1556..905069059f 100644 --- a/tests/python/ir/test_ir_attrs.py +++ b/tests/python/ir/test_ir_attrs.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm + +# needed for attrs +import tvm.testing import pytest -import tvm.ir._ffi_api def test_make_attrs():
