gemini-code-assist[bot] commented on code in PR #421:
URL: https://github.com/apache/tvm-ffi/pull/421#discussion_r2724029320
##########
python/tvm_ffi/dataclasses/field.py:
##########
@@ -160,11 +166,14 @@ class PyBase:
raise TypeError("`init` must be a bool")
if not isinstance(repr, bool):
raise TypeError("`repr` must be a bool")
+<<<<<<< HEAD
Review Comment:

There appears to be an unresolved merge conflict marker `<<<<<<< HEAD` left
in the code. This will cause a syntax error and needs to be removed.
##########
python/tvm_ffi/dataclasses/_utils.py:
##########
@@ -162,7 +162,189 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__
-def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_eq(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate an ``__eq__`` method that compares all fields with ``compare=True``.
+
+ The generated method compares all fields with ``compare=True`` in the order
+ they appear in the type hierarchy.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate the eq method
+ if not compare_fields:
+ # No fields to compare, all instances are equal
+ body_lines = ["return True"]
+ else:
+ # Build field comparisons
+ comparisons = " and ".join(
+ f"self.{field_name} == other.{field_name}" for field_name in compare_fields
+ )
Review Comment:

Generating a long chain of `and` comparisons is harder to read than a single
tuple comparison and may be less efficient. Python's tuple comparison is
implemented in C and compares elements pairwise, stopping at the first
mismatch. You can simplify this by comparing tuples of field values:
```python
comparisons = f"({', '.join(f'self.{f}' for f in compare_fields)},) == ({', '.join(f'other.{f}' for f in compare_fields)},)"
```
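To make the suggestion concrete, here is a minimal, self-contained sketch of an exec-based `__eq__` built from a tuple comparison. `make_tuple_eq` and `Point` are hypothetical names for illustration only; the helper takes field names directly rather than a `TypeInfo`, and unlike the PR's empty-field branch it always performs the type check first.

```python
from typing import Any, Callable


def make_tuple_eq(compare_fields: list[str]) -> Callable[..., Any]:
    """Compile an ``__eq__`` that compares tuples of the given field names.

    Hypothetical standalone helper: it follows the exec-based approach of
    ``method_eq`` but takes the field names directly instead of a TypeInfo.
    """
    # "self.a, self.b, " -> "(self.a, self.b, )"; an empty list gives "()".
    lhs = "".join(f"self.{f}, " for f in compare_fields)
    rhs = "".join(f"other.{f}, " for f in compare_fields)
    source = (
        "def __eq__(self, other):\n"
        "    if not isinstance(other, type(self)):\n"
        "        return NotImplemented\n"
        f"    return ({lhs}) == ({rhs})\n"
    )
    namespace: dict[str, Any] = {}
    exec(source, {}, namespace)
    return namespace["__eq__"]


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


Point.__eq__ = make_tuple_eq(["x", "y"])
```

One behavioral difference to keep in mind: tuple equality checks identity before calling `==` on each element, so a field holding the same `float("nan")` object compares equal to itself, whereas the `and` chain would report it unequal.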
##########
tests/python/test_dataclasses_c_class.py:
##########
@@ -184,3 +185,92 @@ def test_field_kw_only_with_default() -> None:
def test_kw_only_sentinel_exists() -> None:
assert isinstance(KW_ONLY, _KW_ONLY_TYPE)
+
+
+def test_cxx_class_eq() -> None:
+ """Test that eq=True generates __eq__ and __ne__ methods."""
+ # Use the already registered _TestCxxClassBase which has eq=True by default
+ obj1 = _TestCxxClassBase(v_i64=123, v_i32=456)
+ obj2 = _TestCxxClassBase(v_i64=123, v_i32=456)
+ obj3 = _TestCxxClassBase(v_i64=789, v_i32=456)
+
+ # Test __eq__
+ assert obj1 == obj2
+ assert not (obj1 == obj3)
+
+ # Test __ne__
+ assert obj1 != obj3
+ assert not (obj1 != obj2)
+
+ # Test with different types
+ assert obj1 != "not an object"
+ assert not (obj1 == "not an object")
+
+
+def test_cxx_class_order() -> None:
+ """Test that order=True generates ordering methods."""
+ from tvm_ffi.dataclasses import c_class
+
+ # Create a test class with order=True using a different type key
+ # We need to use a type that supports ordering, so we'll test with _TestCxxClassDerived
+ # which should inherit comparison methods if order=True is set
+ # For now, let's test that ordering methods can be generated by checking if they exist
+ # on a class that was registered with order=True
+ # Note: Since _TestCxxClassBase doesn't have order=True, we'll test the functionality
+ # by creating a new class that would have order=True if we could register it
+ # Instead, let's verify that the methods would be generated correctly by testing
+ # the comparison logic on _TestCxxClassBase instances
+ obj1 = _TestCxxClassBase(v_i64=1, v_i32=2)
+ obj2 = _TestCxxClassBase(v_i64=1, v_i32=3)
+ obj3 = _TestCxxClassBase(v_i64=2, v_i32=1)
+ obj4 = _TestCxxClassBase(v_i64=1, v_i32=2)
+
+ # Check if ordering methods exist (they might not if order=False was used)
+ has_ordering = any(method in _TestCxxClassBase.__dict__ for method in ["__lt__", "__le__", "__gt__", "__ge__"])
+
+ if has_ordering:
+ # Test __lt__ (less than)
+ assert obj1 < obj2 # v_i64 equal, v_i32: 2 < 3
+ assert obj1 < obj3 # v_i64: 1 < 2
+ assert not (obj1 < obj4) # equal
+
+ # Test __le__ (less than or equal)
+ assert obj1 <= obj2
+ assert obj1 <= obj4 # equal
+ assert not (obj2 <= obj1)
+
+ # Test __gt__ (greater than)
+ assert obj2 > obj1
+ assert obj3 > obj1
+ assert not (obj1 > obj4) # equal
+
+ # Test __ge__ (greater than or equal)
+ assert obj2 >= obj1
+ assert obj1 >= obj4 # equal
+ assert not (obj1 >= obj2)
+ else:
+ # If ordering methods don't exist, that's expected if order=False was used
+ # We'll just verify that the class exists and can be instantiated
+ assert obj1 is not None
+ assert obj2 is not None
Review Comment:

This test for ordering methods is ineffective. The `_TestCxxClassBase` is
likely decorated with `order=False` (the default), so `has_ordering` will be
`False`, and the actual comparison assertions will be skipped. The comments in
the test also seem contradictory about how the test is supposed to work.
To properly test this feature, you should have a test class that is
decorated with `order=True`. This might require defining a new C++ class and
registering it with a unique type key for testing purposes, or finding a way to
re-decorate an existing class with `order=True` in a separate test namespace to
avoid registration conflicts.
##########
python/tvm_ffi/dataclasses/_utils.py:
##########
@@ -162,7 +162,189 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__
-def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_eq(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate an ``__eq__`` method that compares all fields with ``compare=True``.
+
+ The generated method compares all fields with ``compare=True`` in the order
+ they appear in the type hierarchy.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate the eq method
+ if not compare_fields:
+ # No fields to compare, all instances are equal
+ body_lines = ["return True"]
+ else:
+ # Build field comparisons
+ comparisons = " and ".join(
+ f"self.{field_name} == other.{field_name}" for field_name in compare_fields
+ )
+ body_lines = [
+ "if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ f"return {comparisons}",
+ ]
+
+ source_lines = ["def __eq__(self, other: object) -> bool:"]
+ source_lines.extend(f" {line}" for line in body_lines)
+ source = "\n".join(source_lines)
+
+ # Note: Code generation in this case is guaranteed to be safe,
+ # because the generated code does not contain any untrusted input.
+ namespace: dict[str, Any] = {}
+ exec(source, {}, namespace)
+ __eq__ = namespace["__eq__"]
+ return __eq__
+
+
+def method_ne(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate a ``__ne__`` method that compares all fields with ``compare=True``.
+
+ The generated method is the negation of ``__eq__``.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate the ne method
+ if not compare_fields:
+ # No fields to compare, all instances are equal, so ne always returns False
+ body_lines = ["return False"]
+ else:
+ # Build field comparisons
+ comparisons = " or ".join(
+ f"self.{field_name} != other.{field_name}" for field_name in compare_fields
+ )
+ body_lines = [
+ "if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ f"return {comparisons}",
+ ]
+
+ source_lines = ["def __ne__(self, other: object) -> bool:"]
+ source_lines.extend(f" {line}" for line in body_lines)
+ source = "\n".join(source_lines)
+
+ # Note: Code generation in this case is guaranteed to be safe,
+ # because the generated code does not contain any untrusted input.
+ namespace: dict[str, Any] = {}
+ exec(source, {}, namespace)
+ __ne__ = namespace["__ne__"]
+ return __ne__
+
+
+def method_order(type_cls: type, type_info: TypeInfo) -> dict[str, Callable[..., bool]]:
+ """Generate ordering methods (``__lt__``, ``__le__``, ``__gt__``, ``__ge__``).
+
+ The generated methods compare all fields with ``compare=True`` in the order
+ they appear in the type hierarchy, using lexicographic comparison.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate lexicographic comparison logic
+ if not compare_fields:
+ # No fields to compare, all instances are equal
+ comparison_body = "False"
+ else:
+ # Build lexicographic comparison: compare field by field
+ # For each field, check if all previous fields are equal and current field is less
+ comparison_parts: list[str] = []
+ for i, field_name in enumerate(compare_fields):
+ if i == 0:
+ # First field: just compare directly
+ comparison_parts.append(f"self.{field_name} < other.{field_name}")
+ else:
+ # Subsequent fields: all previous must be equal, then compare current
+ eq_checks = " and ".join(
+ f"self.{f} == other.{f}" for f in compare_fields[:i]
+ )
+ comparison_parts.append(
+ f"({eq_checks} and self.{field_name} < other.{field_name})"
+ )
+ comparison_body = " or ".join(comparison_parts)
+
+ # Generate __lt__
+ source_lines_lt = [
+ "def __lt__(self, other: object) -> bool:",
+ " if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ f" return {comparison_body}",
+ ]
+ source_lt = "\n".join(source_lines_lt)
+ namespace_lt: dict[str, Any] = {}
+ exec(source_lt, {}, namespace_lt)
+ __lt__ = namespace_lt["__lt__"]
+
+ # Generate __le__ (less than or equal)
+ source_lines_le = [
+ "def __le__(self, other: object) -> bool:",
+ " if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ " return self < other or self == other",
+ ]
+ source_le = "\n".join(source_lines_le)
+ namespace_le: dict[str, Any] = {}
+ exec(source_le, {}, namespace_le)
+ __le__ = namespace_le["__le__"]
+
+ # Generate __gt__ (greater than)
+ source_lines_gt = [
+ "def __gt__(self, other: object) -> bool:",
+ " if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ " return other < self",
+ ]
+ source_gt = "\n".join(source_lines_gt)
+ namespace_gt: dict[str, Any] = {}
+ exec(source_gt, {}, namespace_gt)
+ __gt__ = namespace_gt["__gt__"]
+
+ # Generate __ge__ (greater than or equal)
+ source_lines_ge = [
+ "def __ge__(self, other: object) -> bool:",
+ " if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ " return self > other or self == other",
+ ]
+ source_ge = "\n".join(source_lines_ge)
+ namespace_ge: dict[str, Any] = {}
+ exec(source_ge, {}, namespace_ge)
+ __ge__ = namespace_ge["__ge__"]
+
+ return {
+ "__lt__": __lt__,
+ "__le__": __le__,
+ "__gt__": __gt__,
+ "__ge__": __ge__,
+ }
Review Comment:

This function can be significantly simplified.
1. **Use tuple comparison for `__lt__`**: The lexicographical comparison
logic can be simplified by creating tuples of the fields and comparing them
directly. This is more concise, readable, and often more performant.
2. **Use `functools.total_ordering`**: Instead of manually generating
`__le__`, `__gt__`, and `__ge__`, you can define `__eq__` and `__lt__` and then
use the `@functools.total_ordering` class decorator to automatically generate
the other rich comparison methods. This would require a small change in
`c_class.py` to apply the decorator when `order=True`.
With these changes, this function could be renamed to `method_lt` and only
be responsible for generating the `__lt__` method.
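A minimal sketch of both suggestions together, under stated assumptions: `compile_compare` is a hypothetical helper standing in for a `method_lt` that would take a `TypeInfo`, and the plain `Version` class stands in for a registered `c_class`. Only `__eq__` and `__lt__` are attached; `functools.total_ordering` then derives `__le__`, `__gt__`, and `__ge__`.

```python
import functools
from typing import Any, Callable


def compile_compare(name: str, op: str, compare_fields: list[str]) -> Callable[..., Any]:
    """Compile a rich-comparison method ``name`` applying ``op`` to field tuples.

    Hypothetical helper; a real ``method_lt`` would take a TypeInfo instead.
    """
    lhs = "".join(f"self.{f}, " for f in compare_fields)
    rhs = "".join(f"other.{f}, " for f in compare_fields)
    source = (
        f"def {name}(self, other):\n"
        "    if not isinstance(other, type(self)):\n"
        "        return NotImplemented\n"
        # Tuple comparison is lexicographic, replacing the hand-built or/and chain.
        f"    return ({lhs}) {op} ({rhs})\n"
    )
    namespace: dict[str, Any] = {}
    exec(source, {}, namespace)
    return namespace[name]


class Version:
    def __init__(self, major: int, minor: int) -> None:
        self.major = major
        self.minor = minor


# Attach only __eq__ and __lt__, then let total_ordering derive <=, >, >=,
# as the review suggests c_class could do when order=True.
Version.__eq__ = compile_compare("__eq__", "==", ["major", "minor"])
Version.__lt__ = compile_compare("__lt__", "<", ["major", "minor"])
Version = functools.total_ordering(Version)
```

The trade-off: the `functools` documentation notes that the derived methods run slower and produce more complex stack traces than hand-written ones, so generating all four directly remains a reasonable choice if comparison speed matters.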
##########
python/tvm_ffi/dataclasses/_utils.py:
##########
@@ -162,7 +162,189 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__
-def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_eq(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate an ``__eq__`` method that compares all fields with ``compare=True``.
+
+ The generated method compares all fields with ``compare=True`` in the order
+ they appear in the type hierarchy.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate the eq method
+ if not compare_fields:
+ # No fields to compare, all instances are equal
+ body_lines = ["return True"]
+ else:
+ # Build field comparisons
+ comparisons = " and ".join(
+ f"self.{field_name} == other.{field_name}" for field_name in compare_fields
+ )
+ body_lines = [
+ "if not isinstance(other, type(self)):",
+ " return NotImplemented",
+ f"return {comparisons}",
+ ]
+
+ source_lines = ["def __eq__(self, other: object) -> bool:"]
+ source_lines.extend(f" {line}" for line in body_lines)
+ source = "\n".join(source_lines)
+
+ # Note: Code generation in this case is guaranteed to be safe,
+ # because the generated code does not contain any untrusted input.
+ namespace: dict[str, Any] = {}
+ exec(source, {}, namespace)
+ __eq__ = namespace["__eq__"]
+ return __eq__
+
+
+def method_ne(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate a ``__ne__`` method that compares all fields with ``compare=True``.
+
+ The generated method is the negation of ``__eq__``.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
+
+ # Step 2. Generate the ne method
+ if not compare_fields:
+ # No fields to compare, all instances are equal, so ne always returns False
+ body_lines = ["return False"]
+ else:
+ # Build field comparisons
+ comparisons = " or ".join(
+ f"self.{field_name} != other.{field_name}" for field_name in compare_fields
+ )
Review Comment:

Similar to `__eq__`, this can be simplified and made more efficient by using
tuple comparison.
```python
comparisons = f"({', '.join(f'self.{f}' for f in compare_fields)},) != ({', '.join(f'other.{f}' for f in compare_fields)},)"
```
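A self-contained sketch of the tuple-based `__ne__`, assuming the same exec-based pattern as `method_ne`; `make_tuple_ne` and `Pair` are hypothetical names used only for illustration. With no compare fields, the generated expression is `() != ()`, which is `False`, matching the PR's empty-field branch.

```python
from typing import Any, Callable


def make_tuple_ne(compare_fields: list[str]) -> Callable[..., Any]:
    """Compile a ``__ne__`` that is a single tuple ``!=`` comparison.

    Hypothetical standalone helper mirroring ``method_ne``'s exec approach.
    """
    lhs = "".join(f"self.{f}, " for f in compare_fields)
    rhs = "".join(f"other.{f}, " for f in compare_fields)
    source = (
        "def __ne__(self, other):\n"
        "    if not isinstance(other, type(self)):\n"
        "        return NotImplemented\n"
        # With no fields this is `() != ()`, i.e. False, like the PR's
        # `return False` branch.
        f"    return ({lhs}) != ({rhs})\n"
    )
    namespace: dict[str, Any] = {}
    exec(source, {}, namespace)
    return namespace["__ne__"]


class Pair:
    def __init__(self, a: int, b: int) -> None:
        self.a = a
        self.b = b


Pair.__ne__ = make_tuple_ne(["a", "b"])
```

A design alternative worth weighing: in Python 3, `object.__ne__` already delegates to `__eq__` and inverts its result (passing `NotImplemented` through), so once `__eq__` is generated, a separate `__ne__` may be redundant.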
##########
python/tvm_ffi/dataclasses/_utils.py:
##########
@@ -162,7 +162,189 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__
-def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_eq(type_cls: type, type_info: TypeInfo) -> Callable[..., bool]:
+ """Generate an ``__eq__`` method that compares all fields with ``compare=True``.
+
+ The generated method compares all fields with ``compare=True`` in the order
+ they appear in the type hierarchy.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should be compared
+ compare_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.compare:
+ compare_fields.append(field.name)
Review Comment:

The logic for collecting and filtering fields to be compared is duplicated
across `method_eq`, `method_ne`, and `method_order`. To improve maintainability
and reduce code duplication, consider extracting this logic into a private
helper function, for example `_get_compare_fields(type_info)`.
##########
python/tvm_ffi/dataclasses/c_class.py:
##########
@@ -128,9 +134,13 @@ class MyClass:
"""
def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]:  # noqa: UP006
- nonlocal init, repr
+ nonlocal init, kw_only, repr, eq, order
init = init and "__init__" not in super_type_cls.__dict__
repr = repr and "__repr__" not in super_type_cls.__dict__
+ eq = eq and "__eq__" not in super_type_cls.__dict__
Review Comment:

The check for whether to generate equality methods only considers `__eq__`.
If a user defines `__ne__` but not `__eq__`, this logic will proceed to
generate and overwrite the user's `__ne__`. To avoid this, you should also
check for the presence of `__ne__`. Typically, if either `__eq__` or `__ne__`
is user-defined, neither should be auto-generated.
```suggestion
eq = eq and not ("__eq__" in super_type_cls.__dict__ or "__ne__" in super_type_cls.__dict__)
```
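The suggested condition can be checked in isolation. `should_generate_eq`, `Plain`, and `CustomNe` are hypothetical names for this sketch; the logic is the one-line suggestion above, wrapped in a function so it can be exercised without `c_class`.

```python
def should_generate_eq(eq: bool, cls: type) -> bool:
    """Return True if __eq__/__ne__ should be auto-generated for ``cls``.

    Hypothetical helper mirroring the suggested check in c_class.py.
    """
    # Like the existing __init__/__repr__ checks, only the class's own
    # __dict__ is consulted, so methods inherited from a base do not count.
    user_defined = "__eq__" in cls.__dict__ or "__ne__" in cls.__dict__
    return eq and not user_defined


class Plain:
    pass


class CustomNe:
    def __ne__(self, other: object) -> bool:
        return False
```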
##########
tests/python/test_dataclasses_c_class.py:
##########
@@ -184,3 +185,92 @@ def test_field_kw_only_with_default() -> None:
def test_kw_only_sentinel_exists() -> None:
assert isinstance(KW_ONLY, _KW_ONLY_TYPE)
+
+
+def test_cxx_class_eq() -> None:
+ """Test that eq=True generates __eq__ and __ne__ methods."""
+ # Use the already registered _TestCxxClassBase which has eq=True by default
+ obj1 = _TestCxxClassBase(v_i64=123, v_i32=456)
+ obj2 = _TestCxxClassBase(v_i64=123, v_i32=456)
+ obj3 = _TestCxxClassBase(v_i64=789, v_i32=456)
+
+ # Test __eq__
+ assert obj1 == obj2
+ assert not (obj1 == obj3)
+
+ # Test __ne__
+ assert obj1 != obj3
+ assert not (obj1 != obj2)
+
+ # Test with different types
+ assert obj1 != "not an object"
+ assert not (obj1 == "not an object")
+
+
+def test_cxx_class_order() -> None:
+ """Test that order=True generates ordering methods."""
+ from tvm_ffi.dataclasses import c_class
+
+ # Create a test class with order=True using a different type key
+ # We need to use a type that supports ordering, so we'll test with _TestCxxClassDerived
+ # which should inherit comparison methods if order=True is set
+ # For now, let's test that ordering methods can be generated by checking if they exist
+ # on a class that was registered with order=True
+ # Note: Since _TestCxxClassBase doesn't have order=True, we'll test the functionality
+ # by creating a new class that would have order=True if we could register it
+ # Instead, let's verify that the methods would be generated correctly by testing
+ # the comparison logic on _TestCxxClassBase instances
+ obj1 = _TestCxxClassBase(v_i64=1, v_i32=2)
+ obj2 = _TestCxxClassBase(v_i64=1, v_i32=3)
+ obj3 = _TestCxxClassBase(v_i64=2, v_i32=1)
+ obj4 = _TestCxxClassBase(v_i64=1, v_i32=2)
+
+ # Check if ordering methods exist (they might not if order=False was used)
+ has_ordering = any(method in _TestCxxClassBase.__dict__ for method in ["__lt__", "__le__", "__gt__", "__ge__"])
+
+ if has_ordering:
+ # Test __lt__ (less than)
+ assert obj1 < obj2 # v_i64 equal, v_i32: 2 < 3
+ assert obj1 < obj3 # v_i64: 1 < 2
+ assert not (obj1 < obj4) # equal
+
+ # Test __le__ (less than or equal)
+ assert obj1 <= obj2
+ assert obj1 <= obj4 # equal
+ assert not (obj2 <= obj1)
+
+ # Test __gt__ (greater than)
+ assert obj2 > obj1
+ assert obj3 > obj1
+ assert not (obj1 > obj4) # equal
+
+ # Test __ge__ (greater than or equal)
+ assert obj2 >= obj1
+ assert obj1 >= obj4 # equal
+ assert not (obj1 >= obj2)
+ else:
+ # If ordering methods don't exist, that's expected if order=False was used
+ # We'll just verify that the class exists and can be instantiated
+ assert obj1 is not None
+ assert obj2 is not None
+
+
+def test_cxx_class_compare_field() -> None:
+ """Test that compare parameter in field() controls comparison."""
+ # Since we can't re-register testing.TestCxxClassBase, we'll test the compare
+ # functionality by verifying that _TestCxxClassBase uses all fields in comparison
+ # (since they all have compare=True by default)
+ obj1 = _TestCxxClassBase(v_i64=1, v_i32=100)
+ obj2 = _TestCxxClassBase(v_i64=1, v_i32=100) # Same values
+
+ # Should be equal because all fields match
+ assert obj1 == obj2
+
+ # If v_i64 differs, they should not be equal
+ obj3 = _TestCxxClassBase(v_i64=2, v_i32=100)
+ assert obj1 != obj3
+
+ # If v_i32 differs, they should not be equal
+ obj4 = _TestCxxClassBase(v_i64=1, v_i32=200)
+ assert obj1 != obj4
Review Comment:

This test only covers the case where `compare=True` for all fields, which is
the default. It doesn't test the case where a field has `compare=False` and
should be excluded from comparisons.
To make this test comprehensive, you should create a test case with a class
where one or more fields are defined with `field(compare=False)` and assert
that changes in those fields do *not* affect the equality outcome.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.