This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new de93d56 [FEAT] Split SEqHashDef into recursive and non-recursive variants (#583)
de93d56 is described below
commit de93d564a13c7f4d74255b907dd2370ad9874cd8
Author: Tianqi Chen
AuthorDate: Mon May 11 12:57:02 2026 -0400
[FEAT] Split SEqHashDef into recursive and non-recursive variants (#583)
The single SEqHashDef flag treated every nested free var inside a
def-region field as a fresh def. That conflates two different binding
shapes:
- Recursive (function parameters): the value var and any free vars
  inside its sub-fields are co-introduced at the same binding site.
- Non-recursive (normal binding): only the immediate value var binds;
  free vars in the var's sub-fields are use references that must
  resolve against an outer-scope binding.
This PR splits the flag in two so each binding shape can be expressed
directly, and updates the structural-equal / structural-hash machinery
to distinguish them when descending into a FreeVar's own sub-fields.
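As a rough illustration of the difference, here is a toy Python model of the FreeVar branch of `StructEqualHandler::CompareObject`. It is a sketch under stated assumptions, not the library's code. `ToyVar`, `ToyEqual`, `compare_def_field` and the Python `DefRegionKind` enum are all hypothetical names. The enum only copies the numeric values of `TVMFFIDefRegionKind`. Each var has a single sub-field, `dep`, which is loosely modeled on the `TVarWithDep` test type.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, Optional


class DefRegionKind(IntEnum):
    # Hypothetical mirror of the numeric values of TVMFFIDefRegionKind.
    NONE = 0
    RECURSIVE = 1
    NON_RECURSIVE = 2


@dataclass(eq=False)  # eq=False keeps identity-based hashing for dict keys
class ToyVar:
    # Hypothetical stand-in for a FreeVar whose sub-field may hold another FreeVar.
    name: str
    dep: Optional["ToyVar"] = None


class ToyEqual:
    """Toy model of the FreeVar path of structural equality."""

    def __init__(self) -> None:
        self.lhs_to_rhs: Dict[ToyVar, ToyVar] = {}
        self.rhs_to_lhs: Dict[ToyVar, ToyVar] = {}
        self.kind = DefRegionKind.NONE

    def compare_var(self, lhs: Optional[ToyVar], rhs: Optional[ToyVar]) -> bool:
        if lhs is None or rhs is None:
            return lhs is rhs
        # Already-bound vars must keep their existing correspondence.
        if lhs in self.lhs_to_rhs or rhs in self.rhs_to_lhs:
            return self.lhs_to_rhs.get(lhs) is rhs
        # Non-recursive region: walk the var's own sub-fields as uses.
        saved = self.kind
        if self.kind == DefRegionKind.NON_RECURSIVE:
            self.kind = DefRegionKind.NONE
        ok = self.compare_var(lhs.dep, rhs.dep)
        self.kind = saved
        if not ok:
            return False
        # Bind iff identity-equal or inside any def region.
        if lhs is rhs or self.kind != DefRegionKind.NONE:
            self.lhs_to_rhs[lhs] = rhs
            self.rhs_to_lhs[rhs] = lhs
            return True
        return False

    def compare_def_field(self, lhs: ToyVar, rhs: ToyVar, kind: DefRegionKind) -> bool:
        # Like the field-flag dispatch: swap in the field's kind, then restore it.
        saved, self.kind = self.kind, kind
        try:
            return self.compare_var(lhs, rhs)
        finally:
            self.kind = saved


a, b = ToyVar("a"), ToyVar("b")
x, y = ToyVar("x", dep=a), ToyVar("y", dep=b)
print(ToyEqual().compare_def_field(x, y, DefRegionKind.RECURSIVE))      # True
print(ToyEqual().compare_def_field(x, y, DefRegionKind.NON_RECURSIVE))  # False

outer = ToyEqual()
outer.compare_def_field(a, b, DefRegionKind.RECURSIVE)  # outer scope binds a <-> b
print(outer.compare_def_field(x, y, DefRegionKind.NON_RECURSIVE))       # True
```

In the recursive region, the nested vars `a` and `b` are bound together with `x` and `y`. In the non-recursive region they are compared as uses, so they only match once an outer scope has already bound them.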
The third argument of the custom `__ffi_s_equal__` / `__ffi_s_hash__`
callback changes from a bool to an int that carries a
`TVMFFIDefRegionKind` value. Legacy callers that pass a bool still
compile and keep their meaning through the standard bool->int coercion
(true -> 1 = Recursive, false -> 0 = None).
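The rule the callbacks apply to that argument can be sketched in a few lines of Python. The sketch follows the lambdas in `structural_equal.cc` / `structural_hash.cc`. A nonzero kind replaces the current kind for the sub-value. A kind of 0 leaves any enclosing def region in force rather than exiting it. `resolve_callback_kind` and the Python `DefRegionKind` enum are hypothetical names used only for this illustration.

```python
from enum import IntEnum


class DefRegionKind(IntEnum):
    # Hypothetical mirror of the numeric values of TVMFFIDefRegionKind.
    NONE = 0
    RECURSIVE = 1
    NON_RECURSIVE = 2


def resolve_callback_kind(current: DefRegionKind, requested) -> DefRegionKind:
    """Pick the def-region kind used for a callback's sub-value (toy model).

    ``requested`` is the third argument a custom s_equal / s_hash hook
    passed: an int, or a legacy bool.
    """
    # int(True) == 1 and int(False) == 0, so legacy bools land on RECURSIVE / NONE.
    kind = DefRegionKind(int(requested))
    # NONE does not leave an enclosing def region; the sub-value inherits it.
    return current if kind == DefRegionKind.NONE else kind


print(resolve_callback_kind(DefRegionKind.NONE, True).name)       # RECURSIVE
print(resolve_callback_kind(DefRegionKind.RECURSIVE, False).name)  # RECURSIVE
print(resolve_callback_kind(DefRegionKind.NONE, 2).name)           # NON_RECURSIVE
```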
The Cython / Python bindings mirror the rename and add the new
non-recursive flag to the dataclass field vocabulary. Regression tests
cover the four recursive / non-recursive corner cases on a new
`TDefHolder` test type whose `def_recursive` and `def_non_recursive`
fields hold a FreeVar with a nested FreeVar sub-field.
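Hashing has to agree with equality in these cases. Below is a toy Python model of the FreeVar branch of `StructuralHashHandler::HashObject`, which applies the same clamp. Inside a def region a var hashes by its position (`free_var_counter_` in the real handler). Outside one, it falls back to identity. `ToyVar`, `ToyHash` and `hash_def_field` are hypothetical names. To keep the comparison exact, the model returns a nested tuple rather than a 64-bit hash.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, Optional, Tuple


class DefRegionKind(IntEnum):
    # Hypothetical mirror of the numeric values of TVMFFIDefRegionKind.
    NONE = 0
    RECURSIVE = 1
    NON_RECURSIVE = 2


@dataclass(eq=False)  # identity-based hashing, so vars can key the memo
class ToyVar:
    name: str
    dep: Optional["ToyVar"] = None


class ToyHash:
    """Toy model of the FreeVar path of structural hashing (tuple fingerprints)."""

    def __init__(self) -> None:
        self.kind = DefRegionKind.NONE
        self.free_var_counter = 0
        self.memo: Dict[ToyVar, Tuple] = {}

    def hash_var(self, var: Optional[ToyVar]) -> Tuple:
        if var is None:
            return ("none",)
        if var in self.memo:
            return self.memo[var]
        # Same clamp as the equality side; the sub-field walk always runs.
        saved = self.kind
        if self.kind == DefRegionKind.NON_RECURSIVE:
            self.kind = DefRegionKind.NONE
        fields = self.hash_var(var.dep)
        self.kind = saved
        if self.kind != DefRegionKind.NONE:
            # Def site: hash by lexical order, not by the var's identity.
            result = ("def", fields, self.free_var_counter)
            self.free_var_counter += 1
        else:
            # Unbound use: fall back to identity.
            result = ("ptr", id(var))
        self.memo[var] = result
        return result

    def hash_def_field(self, var: ToyVar, kind: DefRegionKind) -> Tuple:
        saved, self.kind = self.kind, kind
        try:
            return self.hash_var(var)
        finally:
            self.kind = saved


a, b = ToyVar("a"), ToyVar("b")
x, y = ToyVar("x", dep=a), ToyVar("y", dep=b)
rec_x = ToyHash().hash_def_field(x, DefRegionKind.RECURSIVE)
rec_y = ToyHash().hash_def_field(y, DefRegionKind.RECURSIVE)
print(rec_x == rec_y)  # True: alpha-equivalent under a recursive region
nonrec_x = ToyHash().hash_def_field(x, DefRegionKind.NON_RECURSIVE)
nonrec_y = ToyHash().hash_def_field(y, DefRegionKind.NON_RECURSIVE)
print(nonrec_x == nonrec_y)  # False: a and b are distinct unbound uses
```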
---
docs/concepts/structural_eq_hash.rst | 72 +++++++++++---
include/tvm/ffi/c_api.h | 68 ++++++++++++-
include/tvm/ffi/reflection/registry.h | 21 +++-
python/tvm_ffi/cython/base.pxi | 8 +-
python/tvm_ffi/cython/object.pxi | 8 +-
python/tvm_ffi/cython/type_info.pxi | 10 +-
python/tvm_ffi/dataclasses/field.py | 30 ++++--
src/ffi/extra/structural_equal.cc | 117 ++++++++++++++---------
src/ffi/extra/structural_hash.cc | 132 ++++++++++++++++----------
tests/cpp/extra/test_structural_equal_hash.cc | 64 +++++++++++++
tests/cpp/test_reflection.cc | 2 +
tests/cpp/testing_object.h | 77 ++++++++++++++-
12 files changed, 481 insertions(+), 128 deletions(-)
diff --git a/docs/concepts/structural_eq_hash.rst b/docs/concepts/structural_eq_hash.rst
index 55614c3..8dff838 100644
--- a/docs/concepts/structural_eq_hash.rst
+++ b/docs/concepts/structural_eq_hash.rst
@@ -618,14 +618,14 @@ Use for:
redundant to compare.
- **Debug annotations** — names, comments, metadata for human consumption.
-``structural_eq="def"`` — Definition region
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+``structural_eq="def-recursive"`` / ``"def-non-recursive"`` — Definition region
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
@py_class(structural_eq="tree")
class Lambda(Object):
- params: list[Var] = field(structural_eq="def")
+ params: list[Var] = field(structural_eq="def-recursive")
body: Expr
**Meaning**: "This field introduces new variable bindings. When comparing
@@ -633,15 +633,44 @@ or hashing this field, allow new variable correspondences to be
established."
This is the counterpart to ``"var"``. A ``"var"`` type says "I am a
-variable"; ``structural_eq="def"`` says "this field is where variables are
-defined." Together they enable alpha-equivalence: comparing functions up
-to consistent variable renaming.
+variable"; the ``"def-*"`` flags on a field say "this field is where
+variables are defined." Together they enable alpha-equivalence:
+comparing functions up to consistent variable renaming.
+
+There are two flavors of definition region, distinguished by what
+happens when a ``"var"`` reached through the field carries its own
+sub-fields (for example, a shape annotation in the var's type):
+
+- ``"def-recursive"`` (alias: ``"def"``) — the variable's sub-fields
+ stay inside the definition region. Any free variables encountered
+ in those sub-fields are themselves treated as fresh definitions at
+ the same site. One example is **function parameter lists**, where
+ the value var and any shape parameters in its type are co-introduced
+ together at the function boundary.
+
+- ``"def-non-recursive"`` — only the immediate variable(s) reached
+ through the field bind. The variable's sub-fields are walked
+ outside the definition region, so any free variables there are
+ *use* references that must resolve against an outer-scope binding.
+ One example is a **normal binding** whose value type references
+ outer-scope shape parameters (a ``let v = expr`` where ``v``'s
+ type refers to vars defined earlier).
+
+When the distinction does not matter (no nested free vars under the
+bound variable), either flavor works and ``"def-recursive"`` is the
+conventional default — that's why the bare ``"def"`` alias resolves
+to it.
Use for:
-- **Function parameter lists**
-- **Let-binding left-hand sides**
-- **Any field that introduces names into scope**
+- **Function parameter lists** — ``"def-recursive"`` so shape
+ parameters in each param's type co-introduce at the same site.
+- **Normal binding left-hand sides** (let bindings, for-loop
+ iterators) whose value type references outer-scope vars —
+ ``"def-non-recursive"`` so those references don't rebind.
+- **Any field that introduces names into scope** — pick the flavor
+ that matches the binding form's contract; default to
+ ``"def-recursive"`` when in doubt.
.. _sequal-shash:
@@ -679,7 +708,7 @@ Signatures
(self, other, eq_cb) -> bool
- eq_cb(lhs, rhs, def_region: bool, field_name: str) -> bool
+ eq_cb(lhs, rhs, def_region_kind: int, field_name: str) -> bool
``__s_hash__``:
@@ -687,12 +716,25 @@ Signatures
(self, init_hash: int, hash_cb) -> int
- hash_cb(value, init_hash: int, def_region: bool) -> int
+ hash_cb(value, init_hash: int, def_region_kind: int) -> int
+
+The ``def_region_kind`` argument on each recursive call mirrors the
+field-level ``"def-*"`` flags and controls whether the sub-value is
+compared/hashed inside a definition region:
+
+- ``0`` — not in a def region (matches ``None`` on a field).
+- ``1`` — recursive def region (matches ``"def-recursive"``, alias
+ ``"def"``).
+- ``2`` — non-recursive def region (matches ``"def-non-recursive"``).
+
+For back-compat with the original single-flag API, the callback also
+accepts a plain ``bool``: ``True`` is treated as ``1`` (recursive) and
+``False`` as ``0`` (not in a def region). The Python examples below
+use ``True`` / ``False`` for that reason; pass an explicit ``2`` (or
+the ``kTVMFFIDefRegionKindNonRecursive`` enum value from C++) when the
+non-recursive kind is needed.
-The ``def_region`` flag on each recursive call controls whether the
-sub-value is compared/hashed in a definition region (enabling new
-variable bindings, just like ``field(structural_eq="def")``). The
-``field_name`` argument on ``eq_cb`` is used only for mismatch path
+The ``field_name`` argument on ``eq_cb`` is used only for mismatch path
reporting from :py:func:`~tvm_ffi.get_first_structural_mismatch`.
Example (Python)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 0d2c9df..52b6337 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -858,11 +858,13 @@ typedef enum {
*/
kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3,
/*!
- * \brief The field enters a def region where var can be defined/matched.
+ * \brief The field enters a recursive def region.
*
* This is an optional meta-data for structural eq/hash.
+ *
+ * \sa TVMFFIDefRegionKind for the def-region semantics.
*/
- kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
+ kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4,
/*!
* \brief The default_value_or_factory is a callable factory function () -> Any.
*
@@ -922,6 +924,17 @@ typedef enum {
* ``(field_addr_as_OpaquePtr, value_as_AnyView)``.
*/
kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11,
+ /*!
+ * \brief The field enters a non-recursive def region.
+ *
+ * This is an optional meta-data for structural eq/hash.
+ *
+ * \sa TVMFFIDefRegionKind for the def-region semantics.
+ *
+ * \note Bit 1 << 12 is used here because bits 1 << 5 .. 1 << 11 are
+ * already taken by other field flags above.
+ */
+ kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12,
#ifdef __cplusplus
};
#else
@@ -993,6 +1006,57 @@ typedef enum {
} TVMFFISEqHashKind;
#endif
+/*!
+ * \brief Kind of def region a structural-equal / structural-hash callback is
+ * currently in when visiting a field.
+ */
+#ifdef __cplusplus
+enum TVMFFIDefRegionKind : int32_t {
+#else
+typedef enum {
+#endif
+ /*!
+ * \brief Not in a def region.
+ *
+ * Free vars reachable through this field are uses; they must already
+ * be bound by an enclosing def region or equality / hashing falls
+ * back to pointer identity.
+ */
+ kTVMFFIDefRegionKindNone = 0,
+ /*!
+ * \brief In a recursive def region.
+ *
+ * When we see a free var for the first time, we define the var, and
+ * the sub-fields of the var (e.g. its struct_info / type_annotation /
+ * shape) are also still in the def region — any free vars discovered
+ * inside those sub-fields are themselves treated as fresh defs at the
+ * same site.
+ *
+ * One example is function parameter lists: the value var and any
+ * shape parameters in its type are co-introduced at the same binding
+ * site.
+ */
+ kTVMFFIDefRegionKindRecursive = 1,
+ /*!
+ * \brief In a non-recursive def region.
+ *
+ * When we see a free var for the first time, we define the var, but
+ * the sub-fields of the var are NOT in the def region — they are
+ * treated as use references that must resolve against an outer
+ * binding. Free vars found in those sub-fields therefore do not
+ * rebind; if they are not already bound, equality fails.
+ *
+ * One example is a normal binding whose value type contains shape
+ * parameters: the value var is introduced fresh, but its shape
+ * parameters reference vars defined in an outer scope.
+ */
+ kTVMFFIDefRegionKindNonRecursive = 2,
+#ifdef __cplusplus
+};
+#else
+} TVMFFIDefRegionKind;
+#endif
+
/*!
* \brief Information support for optional object reflection.
*/
diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h
index 543321c..f810e5f 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -213,10 +213,25 @@ class AttachFieldFlag : public InfoTrait {
explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}
/*!
- * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef
+ * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+ *
+ * The field enters a recursive def region: free vars discovered both at
+ * the field's value and inside that value's sub-fields bind as fresh
+ * defs at the same site. Use for "function-style" bindings.
+ */
+ TVM_FFI_INLINE static AttachFieldFlag SEqHashDefRecursive() {
+ return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefRecursive);
+ }
+ /*!
+ * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
+ *
+ * The field enters a non-recursive def region: only the immediate free
+ * var at the field's value binds; free vars in its sub-fields are uses
+ * that must already be bound by an outer def region. Use for "let-style"
+ * bindings whose sub-fields reference outer-scope vars.
*/
- TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
- return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef);
+ TVM_FFI_INLINE static AttachFieldFlag SEqHashDefNonRecursive() {
+ return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive);
}
/*!
* \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index c5c28a1..a249207 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -207,7 +207,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3
- kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4
+ kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4
kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5
kTVMFFIFieldFlagBitMaskReprOff = 1 << 6
kTVMFFIFieldFlagBitMaskCompareOff = 1 << 7
@@ -215,6 +215,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIFieldFlagBitMaskInitOff = 1 << 9
kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10
kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11
+ kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12
ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept
@@ -248,6 +249,11 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFISEqHashKindConstTreeNode = 4
kTVMFFISEqHashKindUniqueInstance = 5
+ cdef enum TVMFFIDefRegionKind:
+ kTVMFFIDefRegionKindNone = 0
+ kTVMFFIDefRegionKindRecursive = 1
+ kTVMFFIDefRegionKindNonRecursive = 2
+
ctypedef struct TVMFFITypeMetadata:
TVMFFIByteArray doc
TVMFFIObjectCreator creator
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 803ead2..c564b42 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -541,11 +541,13 @@ cdef _type_info_create_from_type_key(object type_cls, str type_key):
c_default_factory = make_ret(owned_default)
else:
c_default = make_ret(owned_default)
- # Decode SEqHashIgnore / SEqHashDef into the Field.structural_eq vocabulary.
+ # Decode SEqHashIgnore / SEqHashDef* into the Field.structural_eq vocabulary.
if (field.flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) != 0:
c_structural_eq = "ignore"
- elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDef) != 0:
- c_structural_eq = "def"
+ elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive) != 0:
+ c_structural_eq = "def-recursive"
+ elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0:
+ c_structural_eq = "def-non-recursive"
else:
c_structural_eq = None
fields.append(
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index d7f39be..960d5cd 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -908,8 +908,14 @@ cdef _register_one_field(
cdef object field_structure = getattr(py_field, "structural_eq", None)
if field_structure == "ignore":
flags |= kTVMFFIFieldFlagBitMaskSEqHashIgnore
- elif field_structure == "def":
- flags |= kTVMFFIFieldFlagBitMaskSEqHashDef
+ elif field_structure == "def" or field_structure == "def-recursive":
+ # ``"def"`` is the legacy short form, kept as a Python-side synonym for
+ # ``"def-recursive"`` since the C-level rename of the underlying flag
+ # (``kTVMFFIFieldFlagBitMaskSEqHashDef`` -> ``...SEqHashDefRecursive``)
+ # only changed the constant name, not the recursive semantics.
+ flags |= kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+ elif field_structure == "def-non-recursive":
+ flags |= kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
info.flags = flags
# --- native layout ---
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index 53f576c..08c2ea3 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -87,9 +87,19 @@ class Field:
structural comparison and hashing.
- ``"ignore"``: the field is excluded from structural equality
and hashing entirely (e.g. source spans, caches).
- - ``"def"``: the field is a **definition region** that introduces
- new variable bindings. Free variables encountered inside this
- field are mapped by position, enabling alpha-equivalence.
+ - ``"def-recursive"`` (alias: ``"def"``): the field is a
+ **recursive definition region** that introduces new variable
+ bindings. Free variables encountered anywhere in this field's
+ subtree (including inside the var's own sub-fields) are
+ mapped by position. One example is function parameter lists,
+ where the value var and any shape parameters in its type are
+ co-introduced at the same site.
+ - ``"def-non-recursive"``: the field is a **non-recursive
+ definition region**. Only the immediate free var(s) at this
+ field's value bind; free vars inside their sub-fields must
+ resolve against an outer binding (use semantics). One example
+ is a normal binding whose value type contains shape
+ parameters that reference outer-scope vars.
doc : str | None
Optional docstring for the field.
@@ -125,8 +135,12 @@ class Field:
doc: str | None
#: Valid values for the *structural_eq* parameter.
+ #:
+ #: ``"def"`` is kept as a Python-side alias for ``"def-recursive"`` to
+ #: preserve back-compat with code written against the old single-flag
+ #: ``SEqHashDef`` API.
_VALID_STRUCTURAL_EQ_VALUES: ClassVar[frozenset[str | None]] = frozenset(
- {None, "ignore", "def"}
+ {None, "ignore", "def", "def-recursive", "def-non-recursive"}
)
def __init__( # noqa: PLR0913
@@ -226,8 +240,12 @@ def field(
structural_eq
Structural equality/hashing annotation. ``None`` (default) means
the field participates normally. ``"ignore"`` excludes the field
- from structural comparison and hashing. ``"def"`` marks the field
- as a definition region for variable binding.
+ from structural comparison and hashing. ``"def-recursive"``
+ (alias ``"def"``) marks the field as a recursive definition
+ region: free vars in the field's whole subtree bind. ``"def-non-recursive"``
+ marks it as a non-recursive definition region: only immediate
+ free vars bind; nested free vars must resolve against an outer
+ binding.
doc
Optional docstring for the field.
diff --git a/src/ffi/extra/structural_equal.cc b/src/ffi/extra/structural_equal.cc
index 5f4db3c..237287d 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -133,7 +133,7 @@ class StructEqualHandler {
}
}
- bool CompareObject(ObjectRef lhs, ObjectRef rhs) {
+ bool CompareObject(const ObjectRef& lhs, const ObjectRef& rhs) {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
if (type_info->metadata == nullptr) {
@@ -171,6 +171,41 @@ class StructEqualHandler {
}
}
+ if (structural_eq_hash_kind != kTVMFFISEqHashKindFreeVar) {
+ bool success = CompareFields(lhs, rhs, type_info);
+ if (success && structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
+ // record the equality mapping for DAG nodes
+ equal_map_lhs_[lhs] = rhs;
+ equal_map_rhs_[rhs] = lhs;
+ }
+ return success;
+ }
+ // FreeVar path. In a non-recursive def region the FreeVar's own
+ // sub-fields are walked outside the def region (nested free vars
+ // there must resolve against an outer binding, not rebind), so we
+ // clamp ``def_region_kind_`` to ``kNone`` around the CompareFields
+ // call and restore before the binding decision below.
+ TVMFFIDefRegionKind saved_def_region_kind = def_region_kind_;
+ if (def_region_kind_ == kTVMFFIDefRegionKindNonRecursive) {
+ def_region_kind_ = kTVMFFIDefRegionKindNone;
+ }
+ bool success = CompareFields(lhs, rhs, type_info);
+ def_region_kind_ = saved_def_region_kind;
+ if (!success) return false;
+ // FreeVar that is not yet mapped: bind it iff identity-equal or we
+ // are inside a def region.
+ if (lhs.same_as(rhs) || def_region_kind_ != kTVMFFIDefRegionKindNone) {
+ equal_map_lhs_[lhs] = rhs;
+ equal_map_rhs_[rhs] = lhs;
+ return true;
+ }
+ return false;
+ }
+
+ // Compare an object's fields (generic walk or via the custom
+ // __ffi_s_equal__ callback). Does not touch the FreeVar def-region
+ // clamp — the caller (CompareObject) handles that for the FreeVar case.
+ bool CompareFields(const ObjectRef& lhs, const ObjectRef& rhs, const TVMFFITypeInfo* type_info) {
static reflection::TypeAttrColumn custom_s_equal =
reflection::TypeAttrColumn(reflection::type_attr::kSEqual);
@@ -184,12 +219,17 @@ class StructEqualHandler {
reflection::FieldGetter getter(field_info);
Any lhs_value = getter(lhs);
Any rhs_value = getter(rhs);
- // field is in def region, enable free var mapping
- if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
+ // Dispatch on the def-region flags.
+ constexpr int64_t kSEqHashDefAny = kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+ kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+ if (field_info->flags & kSEqHashDefAny) {
+ TVMFFIDefRegionKind new_kind =
+ (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive)
+ ? kTVMFFIDefRegionKindNonRecursive
+ : kTVMFFIDefRegionKindRecursive;
+ std::swap(new_kind, def_region_kind_);
success = CompareAny(lhs_value, rhs_value);
- std::swap(allow_free_var, map_free_vars_);
+ std::swap(new_kind, def_region_kind_);
} else {
success = CompareAny(lhs_value, rhs_value);
}
@@ -212,19 +252,20 @@ class StructEqualHandler {
// run custom equal function defined via __s_equal__ type attribute
if (s_equal_callback_ == nullptr) {
s_equal_callback_ = ffi::Function::FromTyped(
- [this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) {
+ // Third parameter is a ``TVMFFIDefRegionKind`` (passed on the wire
+ // as ``int`` to keep the FFI signature stable across language
+ // boundaries).
+ [this](AnyView inner_lhs, AnyView inner_rhs, int def_region_kind, AnyView field_name) {
// NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially
// and only cast to string if mismatch happens
- bool success = true;
- if (def_region) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
- success = CompareAny(lhs, rhs);
- std::swap(allow_free_var, map_free_vars_);
- } else {
- success = CompareAny(lhs, rhs);
- }
- if (!success) {
+ TVMFFIDefRegionKind new_kind =
+ (def_region_kind == kTVMFFIDefRegionKindNone)
+ ? def_region_kind_
+ : static_cast<TVMFFIDefRegionKind>(def_region_kind);
+ std::swap(new_kind, def_region_kind_);
+ bool sub_success = CompareAny(inner_lhs, inner_rhs);
+ std::swap(new_kind, def_region_kind_);
+ if (!sub_success) {
if (mismatch_lhs_reverse_path_ != nullptr) {
String field_name_str = field_name.cast<String>();
mismatch_lhs_reverse_path_->emplace_back(
@@ -233,38 +274,14 @@ class StructEqualHandler {
reflection::AccessStep::Attr(field_name_str));
}
}
- return success;
+ return sub_success;
});
}
success = custom_s_equal[type_info->type_index]
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
.cast<bool>();
}
-
- if (success) {
- if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
- // we are in a free var case that is not yet mapped.
- // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be
- // set
- if (lhs.same_as(rhs) || map_free_vars_) {
- // record the equality
- equal_map_lhs_[lhs] = rhs;
- equal_map_rhs_[rhs] = lhs;
- return true;
- } else {
- return false;
- }
- }
- // if we have a success mapping and in graph/var mode, record the equality mapping
- if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
- // record the equality
- equal_map_lhs_[lhs] = rhs;
- equal_map_rhs_[rhs] = lhs;
- }
- return true;
- } else {
- return false;
- }
+ return success;
}
template <typename MapType>
@@ -413,8 +430,12 @@ class StructEqualHandler {
}
return rhs_obj;
}
- // whether we map free variables that are not defined
- bool map_free_vars_{false};
+ // Current def-region kind. ``kNone`` means we are not in a def region;
+ // free vars discovered here do not bind (they must already be bound by an
+ // outer scope or comparison falls back to pointer identity). ``kRecursive``
+ // and ``kNonRecursive`` enable binding for the field-flag-driven walk and
+ // for the custom-callback path respectively (see CompareObject).
+ TVMFFIDefRegionKind def_region_kind_{kTVMFFIDefRegionKindNone};
// whether we compare tensor data
bool skip_tensor_content_{false};
// the root lhs for result printing
@@ -431,7 +452,8 @@ class StructEqualHandler {
bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars,
bool skip_tensor_content) {
StructEqualHandler handler;
- handler.map_free_vars_ = map_free_vars;
+ handler.def_region_kind_ =
+ map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
handler.skip_tensor_content_ = skip_tensor_content;
return handler.CompareAny(lhs, rhs);
}
@@ -441,7 +463,8 @@ Optional<reflection::AccessPathPair> StructuralEqual::GetFirstMismatch(const Any
bool map_free_vars,
bool skip_tensor_content) {
StructEqualHandler handler;
- handler.map_free_vars_ = map_free_vars;
+ handler.def_region_kind_ =
+ map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
handler.skip_tensor_content_ = skip_tensor_content;
std::vector<reflection::AccessStep> lhs_reverse_path;
std::vector<reflection::AccessStep> rhs_reverse_path;
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index 8ab96f0..21c5545 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -102,7 +102,7 @@ class StructuralHashHandler {
}
}
- uint64_t HashObject(ObjectRef obj) {
+ uint64_t HashObject(const ObjectRef& obj) {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
if (type_info->metadata == nullptr) {
@@ -127,71 +127,102 @@ class StructuralHashHandler {
return it->second;
}
+ uint64_t hash_value;
+ if (structural_eq_hash_kind != kTVMFFISEqHashKindFreeVar) {
+ hash_value = HashFields(obj, type_info, obj->GetTypeKeyHash());
+ } else {
+ // FreeVar path. In a non-recursive def region the FreeVar's own
+ // sub-fields are walked outside the def region (nested free vars
+ // there hash by pointer, matching use semantics), so we clamp
+ // ``def_region_kind_`` to ``kNone`` around the HashFields call and
+ // restore before the FreeVar-level injection below.
+ //
+ // We always call HashFields, even in use mode where the returned
+ // ``hash_value`` is discarded by the pointer-hash fallback. The
+ // walk's side effect on ``free_var_counter_`` (incremented for
+ // every nested FreeVar reached via SEqHashDef-tagged sub-fields)
+ // is observable to FreeVars hashed later in the same traversal;
+ // skipping the walk would silently change those subsequent hashes.
+ TVMFFIDefRegionKind saved_def_region_kind = def_region_kind_;
+ if (def_region_kind_ == kTVMFFIDefRegionKindNonRecursive) {
+ def_region_kind_ = kTVMFFIDefRegionKindNone;
+ }
+ hash_value = HashFields(obj, type_info, obj->GetTypeKeyHash());
+ def_region_kind_ = saved_def_region_kind;
+ if (def_region_kind_ != kTVMFFIDefRegionKindNone) {
+ // use lexical order of free var and its type
+ hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
+ } else {
+ // Fallback to pointer hash; we are not in a def region.
+ hash_value = std::hash<const Object*>()(obj.get());
+ }
+ }
+
+ // if it is a DAG node, also record the lexical order of graph counter
+ // this helps to distinguish DAG from trees.
+ if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
+ hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
+ }
+ // record the hash value for this object
+ hash_memo_[obj] = hash_value;
+ return hash_value;
+ }
+
+ // Hash an object's fields (generic walk or via the custom __ffi_s_hash__
+ // callback). Does not touch the FreeVar def-region clamp — that lives
+ // inline in HashObject's FreeVar branch, which wraps this helper.
+ uint64_t HashFields(const ObjectRef& obj, const TVMFFITypeInfo* type_info, uint64_t init_hash) {
static reflection::TypeAttrColumn custom_s_hash =
reflection::TypeAttrColumn(reflection::type_attr::kSHash);
- // compute the hash value
- uint64_t hash_value = obj->GetTypeKeyHash();
if (custom_s_hash[type_info->type_index] == nullptr) {
// go over the content and hash the fields
reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) {
- // get the field value from both side
reflection::FieldGetter getter(field_info);
Any field_value = getter(obj);
- // field is in def region, enable free var mapping
- if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
- hash_value = details::StableHashCombine(hash_value, HashAny(field_value));
- std::swap(allow_free_var, map_free_vars_);
+ // Dispatch on the def-region flags (mirror of the equality side).
+ constexpr int64_t kSEqHashDefAny = kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+ kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+ if (field_info->flags & kSEqHashDefAny) {
+ TVMFFIDefRegionKind new_kind =
+ (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive)
+ ? kTVMFFIDefRegionKindNonRecursive
+ : kTVMFFIDefRegionKindRecursive;
+ std::swap(new_kind, def_region_kind_);
+ init_hash = details::StableHashCombine(init_hash, HashAny(field_value));
+ std::swap(new_kind, def_region_kind_);
} else {
- hash_value = details::StableHashCombine(hash_value, HashAny(field_value));
+ init_hash = details::StableHashCombine(init_hash, HashAny(field_value));
}
}
});
} else {
if (s_hash_callback_ == nullptr) {
s_hash_callback_ =
- ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) {
- if (def_region) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
- uint64_t hash_value = HashAny(val);
- std::swap(allow_free_var, map_free_vars_);
- return static_cast<int64_t>(details::StableHashCombine(init_hash, hash_value));
- } else {
- // we explicitly bitcast the result from `uint64_t` to `int64_t`.
- // The range of `uint64_t` is too large to fit as `int64_t`, so if we don't bitcast,
- // it will trigger an overflow error in `uint64_t` -> `Any` conversion.
- return static_cast<int64_t>(details::StableHashCombine(init_hash, HashAny(val)));
- }
+ // Third parameter is a ``TVMFFIDefRegionKind`` (passed on the wire
+ // as ``int`` to keep the FFI signature stable across language
+ // boundaries).
+ ffi::Function::FromTyped([this](AnyView val, uint64_t inner_init, int def_region_kind) {
+ TVMFFIDefRegionKind new_kind =
+ (def_region_kind == kTVMFFIDefRegionKindNone)
+ ? def_region_kind_
+ : static_cast<TVMFFIDefRegionKind>(def_region_kind);
+ std::swap(new_kind, def_region_kind_);
+ uint64_t hv = HashAny(val);
+ std::swap(new_kind, def_region_kind_);
+ // we explicitly bitcast the result from `uint64_t` to `int64_t`.
+ // The range of `uint64_t` is too large to fit as `int64_t`, so if we don't bitcast,
+ // it will trigger an overflow error in `uint64_t` -> `Any` conversion.
+ return static_cast<int64_t>(details::StableHashCombine(inner_init, hv));
});
}
- hash_value =
- custom_s_hash[type_info->type_index]
- .cast<ffi::Function>()(obj, static_cast<int64_t>(hash_value), s_hash_callback_)
- .cast<uint64_t>();
- }
-
- if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
- if (map_free_vars_) {
- // use lexical order of free var and its type
- hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
- } else {
- // Fallback to pointer hash, we are not mapping free var.
- hash_value = std::hash<const Object*>()(obj.get());
- }
+ init_hash = custom_s_hash[type_info->type_index]
+ .cast<ffi::Function>()(obj, static_cast<int64_t>(init_hash), s_hash_callback_)
+ .cast<uint64_t>();
}
- // if it is a DAG node, also record the lexical order of graph counter
- // this helps to distinguish DAG from trees.
- if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
- hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
- }
- // record the hash value for this object
- hash_memo_[obj] = hash_value;
- return hash_value;
+ return init_hash;
}
// NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -317,7 +348,11 @@ class StructuralHashHandler {
return hash_value;
}
- bool map_free_vars_{false};
+ // Current def-region kind. ``kNone`` means we are not in a def region; free
+ // vars hash by pointer. ``kRecursive`` and ``kNonRecursive`` enable
+ // ``free_var_counter_``-based hashing for the field-flag-driven walk and
+ // for the custom-callback path respectively (see HashObject).
+ TVMFFIDefRegionKind def_region_kind_{kTVMFFIDefRegionKindNone};
bool skip_tensor_content_{false};
// free var counter.
uint32_t free_var_counter_{0};
@@ -331,7 +366,8 @@ class StructuralHashHandler {
uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_tensor_content) {
StructuralHashHandler handler;
- handler.map_free_vars_ = map_free_vars;
+ handler.def_region_kind_ =
+ map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
handler.skip_tensor_content_ = skip_tensor_content;
return handler.HashAny(value);
}
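When a def region is active, the handler hashes a free var by the order in which it is first encountered (`free_var_counter_`) instead of by its address, and `StructuralHash::Hash` opens a recursive region when `map_free_vars` is true. Below is a minimal sketch of that rule, assuming free vars can be identified by address and that a var seen twice reuses its memoized hash (as `hash_memo_` does in the removed code). `ToyVarHasher`, `ToyKind`, and `Combine` are hypothetical; `Combine` is a boost-style mix standing in for `details::StableHashCombine`, not a copy of it.

```cpp
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical reduction of the def-region state to the two kinds used here.
enum class ToyKind { kNone, kRecursive };

// Mirrors the mapping in StructuralHash::Hash.
inline ToyKind KindFromMapFreeVars(bool map_free_vars) {
  return map_free_vars ? ToyKind::kRecursive : ToyKind::kNone;
}

// Boost-style mixing step; a stand-in, not details::StableHashCombine.
inline std::uint64_t Combine(std::uint64_t seed, std::uint64_t v) {
  return seed ^ (v + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

struct ToyVarHasher {
  explicit ToyVarHasher(ToyKind k) : kind(k) {}

  ToyKind kind;
  std::uint32_t free_var_counter = 0;
  std::unordered_map<const void*, std::uint64_t> memo;

  // Hash one free var, identified by address; repeats reuse the memo entry.
  std::uint64_t HashVar(const void* var) {
    auto it = memo.find(var);
    if (it != memo.end()) return it->second;
    std::uint64_t h = (kind == ToyKind::kNone)
                          ? std::hash<const void*>()(var)    // by address
                          : Combine(0, free_var_counter++);  // by first-seen order
    memo[var] = h;
    return h;
  }

  // Hash the free-var occurrences of a term, in traversal order.
  std::uint64_t HashSeq(const std::vector<const void*>& vars) {
    std::uint64_t h = 0;
    for (const void* v : vars) h = Combine(h, HashVar(v));
    return h;
  }
};
```

Because only first-seen order enters the hash, the occurrence pattern `x, y, x` hashes the same as `u, v, u` but differently from `u, v, v`. With `map_free_vars` false, each var hashes by address, so renamed terms generally hash differently.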
diff --git a/tests/cpp/extra/test_structural_equal_hash.cc b/tests/cpp/extra/test_structural_equal_hash.cc
index ad081e3..6ce380a 100644
--- a/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/tests/cpp/extra/test_structural_equal_hash.cc
@@ -229,6 +229,70 @@ TEST(StructuralEqualHash, CustomTreeNode) {
EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}
+// Regression tests for the SEqHashDefRecursive vs SEqHashDefNonRecursive
+// distinction. ``TDefHolder`` has two sibling fields:
+// - ``def_recursive`` tagged AttachFieldFlag::SEqHashDefRecursive()
+// - ``def_non_recursive`` tagged AttachFieldFlag::SEqHashDefNonRecursive()
+// each holding a ``TVarWithDep`` (a FreeVar with a sub-field ``dep`` that
+// can itself reference another FreeVar). The four sub-cases below cover
+// the observable behaviors of the two flags.
+TEST(StructuralEqualHash, NonRecursiveDef) {
+ {
+ // (a) Recursive flag rebinds nested FreeVars transitively.
+ // On each side ``def_non_recursive`` is the same object as ``def_recursive``,
+ // so it matches through the binding the recursive field already established;
+ // the case isolates the recursive field's rebinding.
+ SCOPED_TRACE("recursive flag rebinds nested FreeVars");
+ TVarWithDep a("a", TVar("m"));
+ TVarWithDep b("b", TVar("n"));
+ TDefHolder lhs(/*def_recursive=*/a, /*def_non_recursive=*/a);
+ TDefHolder rhs(/*def_recursive=*/b, /*def_non_recursive=*/b);
+ EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+ EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+ StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+ }
+ {
+ // (b) Non-recursive flag does NOT rebind nested FreeVars: the top-level
+ // FreeVar binds but the nested ``dep`` is clamped out of the def region.
+ // With no outer binding for "p"/"q", equality must fail.
+ SCOPED_TRACE("non-recursive flag does not rebind nested FreeVars");
+ TVarWithDep shared("shared", std::nullopt);
+ TVarWithDep c_with_dep("c", TVar("p"));
+ TVarWithDep d_with_dep("d", TVar("q"));
+ TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+ TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+ EXPECT_FALSE(StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false));
+ }
+ {
+ // (c) Non-recursive flag works if nested FreeVars resolve via an outer
+ // binding. Here we cheat by wiring the same pointer on both sides, so the
+ // nested FreeVar passes the pointer-identity check without needing the def
+ // region to be active during its sub-field walk.
+ SCOPED_TRACE("nested FreeVars resolve via outer pointer identity");
+ TVar shared_dep("dep");
+ TVarWithDep c_with_dep("c", shared_dep);
+ TVarWithDep d_with_dep("d", shared_dep);
+ TVarWithDep shared("shared", std::nullopt);
+ TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+ TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+ EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+ EXPECT_EQ(StructuralHash()(lhs), StructuralHash()(rhs));
+ }
+ {
+ // (d) Top-level FreeVar still binds under the non-recursive flag. Only the
+ // FreeVar's sub-fields are clamped out; binding the immediate FreeVar
+ // itself is not suppressed.
+ SCOPED_TRACE("top-level FreeVar still binds under non-recursive flag");
+ TVarWithDep shared("shared", std::nullopt);
+ TVarWithDep c_no_dep("c", std::nullopt);
+ TVarWithDep d_no_dep("d", std::nullopt);
+ TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_no_dep);
+ TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_no_dep);
+ EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+ EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+ StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+ }
+}
+
TEST(StructuralEqualHash, List) {
List<int> a = {1, 2, 3};
List<int> b = {1, 2, 3};
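The four cases above follow from one rule: a field tagged recursive lets both the FreeVar and any FreeVars nested in its sub-fields bind, while a field tagged non-recursive lets only the immediate FreeVar bind and treats nested FreeVars as uses that must already be bound or be pointer-identical. The toy model below encodes that rule over a hypothetical `ToyVar` type shaped like `TVarWithDep`. It is an illustrative sketch, not the tvm-ffi `StructuralEqual` implementation; `ToyEqual`, `ToyHolder`, `HolderEqual`, and `MakeVar` are invented for this example.

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>

// Toy FreeVar with an optional nested FreeVar, shaped like TVarWithDep.
struct ToyVar {
  std::string name;             // ignored by equality, like SEqHashIgnore
  std::shared_ptr<ToyVar> dep;  // optional nested free var
};
using VarRef = std::shared_ptr<ToyVar>;

enum class Region { kNone, kRecursive, kNonRecursive };

class ToyEqual {
 public:
  bool EqualVar(const VarRef& l, const VarRef& r, Region region) {
    if (!l || !r) return l == r;  // equal only if both are null
    auto it = fwd_.find(l.get());
    if (it != fwd_.end()) return it->second == r.get();  // already bound
    // Use position: an unbound var matches only itself.
    if (region == Region::kNone) return l == r;
    // Def position: bind l <-> r unless r is already bound to another var.
    if (bwd_.count(r.get()) != 0) return false;
    fwd_[l.get()] = r.get();
    bwd_[r.get()] = l.get();
    // Recursive: nested vars are co-introduced and may bind as well.
    // Non-recursive: nested vars are uses resolved against outer bindings.
    Region nested =
        region == Region::kRecursive ? Region::kRecursive : Region::kNone;
    return EqualVar(l->dep, r->dep, nested);
  }

 private:
  std::map<const ToyVar*, const ToyVar*> fwd_;  // lhs -> rhs
  std::map<const ToyVar*, const ToyVar*> bwd_;  // rhs -> lhs
};

// Toy counterpart of TDefHolder: one recursive and one non-recursive field.
struct ToyHolder {
  VarRef def_recursive;
  VarRef def_non_recursive;
};

inline bool HolderEqual(const ToyHolder& lhs, const ToyHolder& rhs) {
  ToyEqual eq;  // fresh binding maps for each comparison
  return eq.EqualVar(lhs.def_recursive, rhs.def_recursive, Region::kRecursive) &&
         eq.EqualVar(lhs.def_non_recursive, rhs.def_non_recursive,
                     Region::kNonRecursive);
}

inline VarRef MakeVar(std::string name, VarRef dep = nullptr) {
  return std::make_shared<ToyVar>(ToyVar{std::move(name), std::move(dep)});
}
```

Feeding cases (a) through (d) to `HolderEqual` reproduces the expected outcomes: (a), (c), and (d) compare equal, and (b) does not, because `p` and `q` reach a use position with no binding.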
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index 0593eca..6f99a74 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -66,6 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TFloatObj::RegisterReflection();
TPrimExprObj::RegisterReflection();
TVarObj::RegisterReflection();
+ TVarWithDepObj::RegisterReflection();
+ TDefHolderObj::RegisterReflection();
TFuncObj::RegisterReflection();
TCustomFuncObj::RegisterReflection();
TAllFieldsObj::RegisterReflection();
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index 48b1a01..d1bf0a9 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -206,6 +206,81 @@ class TVar : public ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj);
};
+// FreeVar test object that has a sub-field referencing another FreeVar.
+// This models the "var with nested vars" case (analogous to a relax::Var
+// whose struct_info contains tir shape vars). It is used to exercise the
+// difference between SEqHashDefRecursive and SEqHashDefNonRecursive at the
+// FFI layer: under recursive semantics the nested ``dep`` var rebinds
+// transitively; under non-recursive semantics it is treated as a use of an
+// outer-scope binding and equality fails when no such outer binding exists.
+class TVarWithDepObj : public Object {
+ public:
+ std::string name;
+ // Optional dependency var; when null, this object behaves like a plain
+ // FreeVar with no nested free vars.
+ Optional<ObjectRef> dep;
+
+ TVarWithDepObj(std::string name, Optional<ObjectRef> dep)
+ : name(std::move(name)), dep(std::move(dep)) {}
+ explicit TVarWithDepObj(UnsafeInit) {}
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TVarWithDepObj>()
+ .def_ro("name", &TVarWithDepObj::name, refl::AttachFieldFlag::SEqHashIgnore())
+ // ``dep`` participates in structural equality without any def flag,
+ // so it is a USE position. Whether the FreeVar in ``dep`` may rebind
+ // is decided by the def flag on whichever outer field reaches this
+ // object.
+ .def_ro("dep", &TVarWithDepObj::dep);
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.VarWithDep", TVarWithDepObj, Object);
+};
+
+class TVarWithDep : public ObjectRef {
+ public:
+ explicit TVarWithDep(std::string name, Optional<ObjectRef> dep = std::nullopt) {
+ data_ = make_object<TVarWithDepObj>(std::move(name), std::move(dep));
+ }
+
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVarWithDep, ObjectRef, TVarWithDepObj);
+};
+
+// Holder with one recursive-def field and one non-recursive-def field.
+// Used by StructuralEqualHash.NonRecursiveDef tests below.
+class TDefHolderObj : public Object {
+ public:
+ TVarWithDep def_recursive;
+ TVarWithDep def_non_recursive;
+
+ TDefHolderObj(TVarWithDep def_recursive, TVarWithDep def_non_recursive)
+ : def_recursive(std::move(def_recursive)), def_non_recursive(std::move(def_non_recursive)) {}
+ explicit TDefHolderObj(UnsafeInit) {}
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TDefHolderObj>()
+ .def_ro("def_recursive", &TDefHolderObj::def_recursive,
+ refl::AttachFieldFlag::SEqHashDefRecursive())
+ .def_ro("def_non_recursive", &TDefHolderObj::def_non_recursive,
+ refl::AttachFieldFlag::SEqHashDefNonRecursive());
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.DefHolder", TDefHolderObj, Object);
+};
+
+class TDefHolder : public ObjectRef {
+ public:
+ explicit TDefHolder(TVarWithDep def_recursive, TVarWithDep def_non_recursive) {
+ data_ = make_object<TDefHolderObj>(std::move(def_recursive), std::move(def_non_recursive));
+ }
+
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TDefHolder, ObjectRef, TDefHolderObj);
+};
+
class TFuncObj : public Object {
public:
Array<TVar> params;
@@ -220,7 +295,7 @@ class TFuncObj : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TFuncObj>()
- .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDefRecursive())
.def_ro("body", &TFuncObj::body)
.def_ro("comment", &TFuncObj::comment,
refl::AttachFieldFlag::SEqHashIgnore());
}