junrushao commented on code in PR #438:
URL: https://github.com/apache/tvm-ffi/pull/438#discussion_r2780822687
##########
python/tvm_ffi/registry.py:
##########
@@ -335,25 +335,97 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
if not hasattr(type_cls, name): # skip already defined attributes
setattr(type_cls, name, field.as_property(type_cls))
has_c_init = False
+ has_shallow_copy = False
for method in type_info.methods:
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
has_c_init = True
- if not hasattr(type_cls, name):
+ if name == "__ffi_shallow_copy__":
+ has_shallow_copy = True
+ # Always override: shallow copy is type-specific and must not be inherited
+ setattr(type_cls, name, method.as_callable(type_cls))
+ elif not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
if "__init__" not in type_cls.__dict__:
if has_c_init:
setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
elif not issubclass(type_cls, core.PyNativeObject):
setattr(type_cls, "__init__", __init__invalid)
+ _setup_copy_methods(type_cls, has_shallow_copy)
return type_cls
+def _setup_copy_methods(type_cls: type, has_shallow_copy: bool) -> None:
+ """Set up __copy__, __deepcopy__, __replace__ based on copy support."""
+ if has_shallow_copy:
+ if "__copy__" not in type_cls.__dict__:
+ setattr(type_cls, "__copy__", _copy_supported)
+ if "__deepcopy__" not in type_cls.__dict__:
+ setattr(type_cls, "__deepcopy__", _deepcopy_supported)
+ if "__replace__" not in type_cls.__dict__:
+ setattr(type_cls, "__replace__", _replace_supported)
+ else:
+ if "__copy__" not in type_cls.__dict__:
+ setattr(type_cls, "__copy__", _copy_unsupported)
+ if "__deepcopy__" not in type_cls.__dict__:
+ setattr(type_cls, "__deepcopy__", _deepcopy_unsupported)
+ if "__replace__" not in type_cls.__dict__:
+ setattr(type_cls, "__replace__", _replace_unsupported)
+
+
def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("The __init__ method of this class is not implemented.")
+def _copy_supported(self: Any) -> Any:
+ return self.__ffi_shallow_copy__()
+
+
+def _deepcopy_supported(self: Any, memo: Any = None) -> Any:
+ return _get_deep_copy_func()(self)
+
+
+def _replace_supported(self: Any, **kwargs: Any) -> Any:
+ import copy # noqa: PLC0415
+
+ obj = copy.copy(self)
+ for key, value in kwargs.items():
+ setattr(obj, key, value)
+ return obj
+
+
+def _copy_unsupported(self: Any) -> Any:
+ raise TypeError(
+ f"Type `{type(self).__name__}` does not support copy. "
+ f"Enable it with refl::enable_copy() in the C++ type registration."
+ )
+
+
+def _deepcopy_unsupported(self: Any, memo: Any = None) -> Any:
+ raise TypeError(
+ f"Type `{type(self).__name__}` does not support deepcopy. "
+ f"Enable it with refl::enable_copy() in the C++ type registration."
+ )
+
+
+def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
+ raise TypeError(
+ f"Type `{type(self).__name__}` does not support replace. "
+ f"Enable it with refl::enable_copy() in the C++ type registration."
+ )
+
+
+def _get_deep_copy_func() -> core.Function:
+ global _deep_copy_func_cache # noqa: PLW0603
+ if _deep_copy_func_cache is None:
+ _deep_copy_func_cache = get_global_func("ffi.DeepCopy")
+ return _deep_copy_func_cache
+
+
+_deep_copy_func_cache: core.Function | None = None
Review Comment:
We probably don't need this cache (`_deep_copy_func_cache`) - just call `get_global_func` every time.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]