This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new de93d56  [FEAT] Split SEqHashDef into recursive and non-recursive 
variants (#583)
de93d56 is described below

commit de93d564a13c7f4d74255b907dd2370ad9874cd8
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon May 11 12:57:02 2026 -0400

    [FEAT] Split SEqHashDef into recursive and non-recursive variants (#583)
    
    The single SEqHashDef flag treated every nested free var inside a
    def-region field as a fresh def. That conflates two different binding
    shapes:
    
    - Recursive (function parameters): the value var and any free vars
      inside its sub-fields are co-introduced at the same binding site.
    
    - Non-recursive (normal binding): only the immediate value var binds;
      free vars in the var's sub-fields are use references that must
      resolve against an outer-scope binding.
    
    This PR splits the flag in two so each binding shape can be expressed
    directly, and updates the structural-equal / structural-hash machinery
    to distinguish them when descending into a FreeVar's own sub-fields.
    The custom `__ffi_s_equal__` / `__ffi_s_hash__` callback signature
    gains a typed kind for the def-mode argument; legacy callers passing
    a bool continue to compile and preserve their meaning via the standard
    bool->int coercion (true -> 1 = Recursive, false -> 0 = None).
    
    The Cython / Python bindings mirror the rename and add the new
    non-recursive flag to the dataclass field vocabulary. Regression tests
    cover the four recursive / non-recursive corner cases on a new
    `TDefHolder` test type whose `def_recursive` and `def_non_recursive`
    fields hold a FreeVar with a nested FreeVar sub-field.
---
 docs/concepts/structural_eq_hash.rst          |  72 +++++++++++---
 include/tvm/ffi/c_api.h                       |  68 ++++++++++++-
 include/tvm/ffi/reflection/registry.h         |  21 +++-
 python/tvm_ffi/cython/base.pxi                |   8 +-
 python/tvm_ffi/cython/object.pxi              |   8 +-
 python/tvm_ffi/cython/type_info.pxi           |  10 +-
 python/tvm_ffi/dataclasses/field.py           |  30 ++++--
 src/ffi/extra/structural_equal.cc             | 117 ++++++++++++++---------
 src/ffi/extra/structural_hash.cc              | 132 ++++++++++++++++----------
 tests/cpp/extra/test_structural_equal_hash.cc |  64 +++++++++++++
 tests/cpp/test_reflection.cc                  |   2 +
 tests/cpp/testing_object.h                    |  77 ++++++++++++++-
 12 files changed, 481 insertions(+), 128 deletions(-)

diff --git a/docs/concepts/structural_eq_hash.rst 
b/docs/concepts/structural_eq_hash.rst
index 55614c3..8dff838 100644
--- a/docs/concepts/structural_eq_hash.rst
+++ b/docs/concepts/structural_eq_hash.rst
@@ -618,14 +618,14 @@ Use for:
   redundant to compare.
 - **Debug annotations** — names, comments, metadata for human consumption.
 
-``structural_eq="def"`` — Definition region
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+``structural_eq="def-recursive"`` / ``"def-non-recursive"`` — Definition region
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. code-block:: python
 
    @py_class(structural_eq="tree")
    class Lambda(Object):
-       params: list[Var] = field(structural_eq="def")
+       params: list[Var] = field(structural_eq="def-recursive")
        body: Expr
 
 **Meaning**: "This field introduces new variable bindings. When comparing
@@ -633,15 +633,44 @@ or hashing this field, allow new variable correspondences 
to be
 established."
 
 This is the counterpart to ``"var"``. A ``"var"`` type says "I am a
-variable"; ``structural_eq="def"`` says "this field is where variables are
-defined." Together they enable alpha-equivalence: comparing functions up
-to consistent variable renaming.
+variable"; the ``"def-*"`` flags on a field say "this field is where
+variables are defined." Together they enable alpha-equivalence:
+comparing functions up to consistent variable renaming.
+
+There are two flavors of definition region, distinguished by what
+happens when a ``"var"`` reached through the field carries its own
+sub-fields (for example, a shape annotation in the var's type):
+
+- ``"def-recursive"`` (alias: ``"def"``) — the variable's sub-fields
+  stay inside the definition region. Any free variables encountered
+  in those sub-fields are themselves treated as fresh definitions at
+  the same site. One example is **function parameter lists**, where
+  the value var and any shape parameters in its type are co-introduced
+  together at the function boundary.
+
+- ``"def-non-recursive"`` — only the immediate variable(s) reached
+  through the field bind. The variable's sub-fields are walked
+  outside the definition region, so any free variables there are
+  *use* references that must resolve against an outer-scope binding.
+  One example is a **normal binding** whose value type references
+  outer-scope shape parameters (a ``let v = expr`` where ``v``'s
+  type refers to vars defined earlier).
+
+When the distinction does not matter (no nested free vars under the
+bound variable), either flavor works and ``"def-recursive"`` is the
+conventional default — that's why the bare ``"def"`` alias resolves
+to it.
 
 Use for:
 
-- **Function parameter lists**
-- **Let-binding left-hand sides**
-- **Any field that introduces names into scope**
+- **Function parameter lists** — ``"def-recursive"`` so shape
+  parameters in each param's type co-introduce at the same site.
+- **Normal binding left-hand sides** (let bindings, for-loop
+  iterators) whose value type references outer-scope vars —
+  ``"def-non-recursive"`` so those references don't rebind.
+- **Any field that introduces names into scope** — pick the flavor
+  that matches the binding form's contract; default to
+  ``"def-recursive"`` when in doubt.
 
 
 .. _sequal-shash:
@@ -679,7 +708,7 @@ Signatures
 
    (self, other, eq_cb) -> bool
 
-   eq_cb(lhs, rhs, def_region: bool, field_name: str) -> bool
+   eq_cb(lhs, rhs, def_region_kind: int, field_name: str) -> bool
 
 ``__s_hash__``:
 
@@ -687,12 +716,25 @@ Signatures
 
    (self, init_hash: int, hash_cb) -> int
 
-   hash_cb(value, init_hash: int, def_region: bool) -> int
+   hash_cb(value, init_hash: int, def_region_kind: int) -> int
+
+The ``def_region_kind`` argument on each recursive call mirrors the
+field-level ``"def-*"`` flags and controls whether the sub-value is
+compared/hashed inside a definition region:
+
+- ``0`` — no new def region; the enclosing region's kind, if any, stays in effect (matches ``None`` on a field).
+- ``1`` — recursive def region (matches ``"def-recursive"``, alias
+  ``"def"``).
+- ``2`` — non-recursive def region (matches ``"def-non-recursive"``).
+
+For back-compat with the original single-flag API, the callback also
+accepts a plain ``bool``: ``True`` is treated as ``1`` (recursive) and
+``False`` as ``0`` (no new def region is entered). The Python examples below
+use ``True`` / ``False`` for that reason; pass an explicit ``2`` (or
+the ``kTVMFFIDefRegionKindNonRecursive`` enum value from C++) when the
+non-recursive kind is needed.
 
-The ``def_region`` flag on each recursive call controls whether the
-sub-value is compared/hashed in a definition region (enabling new
-variable bindings, just like ``field(structural_eq="def")``).  The
-``field_name`` argument on ``eq_cb`` is used only for mismatch path
+The ``field_name`` argument on ``eq_cb`` is used only for mismatch path
 reporting from :py:func:`~tvm_ffi.get_first_structural_mismatch`.
 
 Example (Python)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 0d2c9df..52b6337 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -858,11 +858,13 @@ typedef enum {
    */
   kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3,
   /*!
-   * \brief The field enters a def region where var can be defined/matched.
+   * \brief The field enters a recursive def region.
    *
    * This is an optional meta-data for structural eq/hash.
+   *
+   * \sa TVMFFIDefRegionKind for the def-region semantics.
    */
-  kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
+  kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4,
   /*!
    * \brief The default_value_or_factory is a callable factory function () -> 
Any.
    *
@@ -922,6 +924,17 @@ typedef enum {
    * ``(field_addr_as_OpaquePtr, value_as_AnyView)``.
    */
   kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11,
+  /*!
+   * \brief The field enters a non-recursive def region.
+   *
+   * This is an optional meta-data for structural eq/hash.
+   *
+   * \sa TVMFFIDefRegionKind for the def-region semantics.
+   *
+   * \note Bit 1 << 12 is used here because bits 1 << 5 .. 1 << 11 are
+   *       already taken by other field flags above.
+   */
+  kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12,
 #ifdef __cplusplus
 };
 #else
@@ -993,6 +1006,57 @@ typedef enum {
 } TVMFFISEqHashKind;
 #endif
 
+/*!
+ * \brief Kind of def region a structural-equal / structural-hash callback is
+ *        currently in when visiting a field.
+ */
+#ifdef __cplusplus
+enum TVMFFIDefRegionKind : int32_t {
+#else
+typedef enum {
+#endif
+  /*!
+   * \brief Not in a def region.
+   *
+   * Free vars reachable through this field are uses; they must already
+   * be bound by an enclosing def region or equality / hashing falls
+   * back to pointer identity.
+   */
+  kTVMFFIDefRegionKindNone = 0,
+  /*!
+   * \brief In a recursive def region.
+   *
+   * When we see a free var for the first time, we define the var, and
+   * the sub-fields of the var (e.g. its struct_info / type_annotation /
+   * shape) are also still in the def region — any free vars discovered
+   * inside those sub-fields are themselves treated as fresh defs at the
+   * same site.
+   *
+   * One example is function parameter lists: the value var and any
+   * shape parameters in its type are co-introduced at the same binding
+   * site.
+   */
+  kTVMFFIDefRegionKindRecursive = 1,
+  /*!
+   * \brief In a non-recursive def region.
+   *
+   * When we see a free var for the first time, we define the var, but
+   * the sub-fields of the var are NOT in the def region — they are
+   * treated as use references that must resolve against an outer
+   * binding. Free vars found in those sub-fields therefore do not rebind;
+   * if not already bound, they fall back to pointer identity / pointer hash.
+   *
+   * One example is a normal binding whose value type contains shape
+   * parameters: the value var is introduced fresh, but its shape
+   * parameters reference vars defined in an outer scope.
+   */
+  kTVMFFIDefRegionKindNonRecursive = 2,
+#ifdef __cplusplus
+};
+#else
+} TVMFFIDefRegionKind;
+#endif
+
 /*!
  * \brief Information support for optional object reflection.
  */
diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index 543321c..f810e5f 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -213,10 +213,25 @@ class AttachFieldFlag : public InfoTrait {
   explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}
 
   /*!
-   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef
+   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+   *
+   * The field enters a recursive def region: free vars discovered both at
+   * the field's value and inside that value's sub-fields bind as fresh
+   * defs at the same site. Use for "function-style" bindings.
+   */
+  TVM_FFI_INLINE static AttachFieldFlag SEqHashDefRecursive() {
+    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefRecursive);
+  }
+  /*!
+   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
+   *
+   * The field enters a non-recursive def region: only the immediate free
+   * var at the field's value binds; free vars in its sub-fields are uses
+   * that must already be bound by an outer def region. Use for "let-style"
+   * bindings whose sub-fields reference outer-scope vars.
    */
-  TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
-    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef);
+  TVM_FFI_INLINE static AttachFieldFlag SEqHashDefNonRecursive() {
+    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive);
   }
   /*!
    * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index c5c28a1..a249207 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -207,7 +207,7 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
         kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
         kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3
-        kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4
+        kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4
         kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5
         kTVMFFIFieldFlagBitMaskReprOff = 1 << 6
         kTVMFFIFieldFlagBitMaskCompareOff = 1 << 7
@@ -215,6 +215,7 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFIFieldFlagBitMaskInitOff = 1 << 9
         kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10
         kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11
+        kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12
 
     ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
     ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) 
noexcept
@@ -248,6 +249,11 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFISEqHashKindConstTreeNode = 4
         kTVMFFISEqHashKindUniqueInstance = 5
 
+    cdef enum TVMFFIDefRegionKind:
+        kTVMFFIDefRegionKindNone = 0
+        kTVMFFIDefRegionKindRecursive = 1
+        kTVMFFIDefRegionKindNonRecursive = 2
+
     ctypedef struct TVMFFITypeMetadata:
         TVMFFIByteArray doc
         TVMFFIObjectCreator creator
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 803ead2..c564b42 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -541,11 +541,13 @@ cdef _type_info_create_from_type_key(object type_cls, str 
type_key):
                 c_default_factory = make_ret(owned_default)
             else:
                 c_default = make_ret(owned_default)
-        # Decode SEqHashIgnore / SEqHashDef into the Field.structural_eq 
vocabulary.
+        # Decode SEqHashIgnore / SEqHashDef* into the Field.structural_eq 
vocabulary.
         if (field.flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) != 0:
             c_structural_eq = "ignore"
-        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDef) != 0:
-            c_structural_eq = "def"
+        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive) != 0:
+            c_structural_eq = "def-recursive"
+        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 
0:
+            c_structural_eq = "def-non-recursive"
         else:
             c_structural_eq = None
         fields.append(
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index d7f39be..960d5cd 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -908,8 +908,14 @@ cdef _register_one_field(
     cdef object field_structure = getattr(py_field, "structural_eq", None)
     if field_structure == "ignore":
         flags |= kTVMFFIFieldFlagBitMaskSEqHashIgnore
-    elif field_structure == "def":
-        flags |= kTVMFFIFieldFlagBitMaskSEqHashDef
+    elif field_structure == "def" or field_structure == "def-recursive":
+        # ``"def"`` is the legacy short form, kept as a Python-side synonym for
+        # ``"def-recursive"`` since the C-level rename of the underlying flag
+        # (``kTVMFFIFieldFlagBitMaskSEqHashDef`` -> ``...SEqHashDefRecursive``)
+        # only changed the constant name, not the recursive semantics.
+        flags |= kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+    elif field_structure == "def-non-recursive":
+        flags |= kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
     info.flags = flags
 
     # --- native layout ---
diff --git a/python/tvm_ffi/dataclasses/field.py 
b/python/tvm_ffi/dataclasses/field.py
index 53f576c..08c2ea3 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -87,9 +87,19 @@ class Field:
           structural comparison and hashing.
         - ``"ignore"``: the field is excluded from structural equality
           and hashing entirely (e.g. source spans, caches).
-        - ``"def"``: the field is a **definition region** that introduces
-          new variable bindings.  Free variables encountered inside this
-          field are mapped by position, enabling alpha-equivalence.
+        - ``"def-recursive"`` (alias: ``"def"``): the field is a
+          **recursive definition region** that introduces new variable
+          bindings.  Free variables encountered anywhere in this field's
+          subtree (including inside the var's own sub-fields) are
+          mapped by position. One example is function parameter lists,
+          where the value var and any shape parameters in its type are
+          co-introduced at the same site.
+        - ``"def-non-recursive"``: the field is a **non-recursive
+          definition region**.  Only the immediate free var(s) at this
+          field's value bind; free vars inside their sub-fields must
+          resolve against an outer binding (use semantics). One example
+          is a normal binding whose value type contains shape
+          parameters that reference outer-scope vars.
     doc : str | None
         Optional docstring for the field.
 
@@ -125,8 +135,12 @@ class Field:
     doc: str | None
 
     #: Valid values for the *structural_eq* parameter.
+    #:
+    #: ``"def"`` is kept as a Python-side alias for ``"def-recursive"`` to
+    #: preserve back-compat with code written against the old single-flag
+    #: ``SEqHashDef`` API.
     _VALID_STRUCTURAL_EQ_VALUES: ClassVar[frozenset[str | None]] = frozenset(
-        {None, "ignore", "def"}
+        {None, "ignore", "def", "def-recursive", "def-non-recursive"}
     )
 
     def __init__(  # noqa: PLR0913
@@ -226,8 +240,12 @@ def field(
     structural_eq
         Structural equality/hashing annotation. ``None`` (default) means
         the field participates normally. ``"ignore"`` excludes the field
-        from structural comparison and hashing. ``"def"`` marks the field
-        as a definition region for variable binding.
+        from structural comparison and hashing. ``"def-recursive"``
+        (alias ``"def"``) marks the field as a recursive definition
+        region: free vars in the field's whole subtree bind. 
``"def-non-recursive"``
+        marks it as a non-recursive definition region: only immediate
+        free vars bind; nested free vars must resolve against an outer
+        binding.
     doc
         Optional docstring for the field.
 
diff --git a/src/ffi/extra/structural_equal.cc 
b/src/ffi/extra/structural_equal.cc
index 5f4db3c..237287d 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -133,7 +133,7 @@ class StructEqualHandler {
     }
   }
 
-  bool CompareObject(ObjectRef lhs, ObjectRef rhs) {
+  bool CompareObject(const ObjectRef& lhs, const ObjectRef& rhs) {
     // NOTE: invariant: lhs and rhs are already the same type
     const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
     if (type_info->metadata == nullptr) {
@@ -171,6 +171,41 @@ class StructEqualHandler {
       }
     }
 
+    if (structural_eq_hash_kind != kTVMFFISEqHashKindFreeVar) {
+      bool success = CompareFields(lhs, rhs, type_info);
+      if (success && structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
+        // record the equality mapping for DAG nodes
+        equal_map_lhs_[lhs] = rhs;
+        equal_map_rhs_[rhs] = lhs;
+      }
+      return success;
+    }
+    // FreeVar path. In a non-recursive def region the FreeVar's own
+    // sub-fields are walked outside the def region (nested free vars
+    // there must resolve against an outer binding, not rebind), so we
+    // clamp ``def_region_kind_`` to ``kNone`` around the CompareFields
+    // call and restore before the binding decision below.
+    TVMFFIDefRegionKind saved_def_region_kind = def_region_kind_;
+    if (def_region_kind_ == kTVMFFIDefRegionKindNonRecursive) {
+      def_region_kind_ = kTVMFFIDefRegionKindNone;
+    }
+    bool success = CompareFields(lhs, rhs, type_info);
+    def_region_kind_ = saved_def_region_kind;
+    if (!success) return false;
+    // FreeVar that is not yet mapped: bind it iff identity-equal or we
+    // are inside a def region.
+    if (lhs.same_as(rhs) || def_region_kind_ != kTVMFFIDefRegionKindNone) {
+      equal_map_lhs_[lhs] = rhs;
+      equal_map_rhs_[rhs] = lhs;
+      return true;
+    }
+    return false;
+  }
+
+  // Compare an object's fields (generic walk or via the custom
+  // __ffi_s_equal__ callback). Does not touch the FreeVar def-region
+  // clamp — the caller (CompareObject) handles that for the FreeVar case.
+  bool CompareFields(const ObjectRef& lhs, const ObjectRef& rhs, const 
TVMFFITypeInfo* type_info) {
     static reflection::TypeAttrColumn custom_s_equal =
         reflection::TypeAttrColumn(reflection::type_attr::kSEqual);
 
@@ -184,12 +219,17 @@ class StructEqualHandler {
         reflection::FieldGetter getter(field_info);
         Any lhs_value = getter(lhs);
         Any rhs_value = getter(rhs);
-        // field is in def region, enable free var mapping
-        if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
-          bool allow_free_var = true;
-          std::swap(allow_free_var, map_free_vars_);
+        // Dispatch on the def-region flags.
+        constexpr int64_t kSEqHashDefAny = 
kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+                                           
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+        if (field_info->flags & kSEqHashDefAny) {
+          TVMFFIDefRegionKind new_kind =
+              (field_info->flags & 
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive)
+                  ? kTVMFFIDefRegionKindNonRecursive
+                  : kTVMFFIDefRegionKindRecursive;
+          std::swap(new_kind, def_region_kind_);
           success = CompareAny(lhs_value, rhs_value);
-          std::swap(allow_free_var, map_free_vars_);
+          std::swap(new_kind, def_region_kind_);
         } else {
           success = CompareAny(lhs_value, rhs_value);
         }
@@ -212,19 +252,20 @@ class StructEqualHandler {
       // run custom equal function defined via __s_equal__ type attribute
       if (s_equal_callback_ == nullptr) {
         s_equal_callback_ = ffi::Function::FromTyped(
-            [this](AnyView lhs, AnyView rhs, bool def_region, AnyView 
field_name) {
+            // Third parameter is a ``TVMFFIDefRegionKind`` (passed on the wire
+            // as ``int`` to keep the FFI signature stable across language
+            // boundaries).
+            [this](AnyView inner_lhs, AnyView inner_rhs, int def_region_kind, 
AnyView field_name) {
               // NOTE: we explicitly make field_name as AnyView to avoid copy 
overhead initially
               // and only cast to string if mismatch happens
-              bool success = true;
-              if (def_region) {
-                bool allow_free_var = true;
-                std::swap(allow_free_var, map_free_vars_);
-                success = CompareAny(lhs, rhs);
-                std::swap(allow_free_var, map_free_vars_);
-              } else {
-                success = CompareAny(lhs, rhs);
-              }
-              if (!success) {
+              TVMFFIDefRegionKind new_kind =
+                  (def_region_kind == kTVMFFIDefRegionKindNone)
+                      ? def_region_kind_
+                      : static_cast<TVMFFIDefRegionKind>(def_region_kind);
+              std::swap(new_kind, def_region_kind_);
+              bool sub_success = CompareAny(inner_lhs, inner_rhs);
+              std::swap(new_kind, def_region_kind_);
+              if (!sub_success) {
                 if (mismatch_lhs_reverse_path_ != nullptr) {
                   String field_name_str = field_name.cast<String>();
                   mismatch_lhs_reverse_path_->emplace_back(
@@ -233,38 +274,14 @@ class StructEqualHandler {
                       reflection::AccessStep::Attr(field_name_str));
                 }
               }
-              return success;
+              return sub_success;
             });
       }
       success = custom_s_equal[type_info->type_index]
                     .cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
                     .cast<bool>();
     }
-
-    if (success) {
-      if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
-        // we are in a free var case that is not yet mapped.
-        // in this case, either map_free_vars_ should be set to true, or 
map_free_vars_ should be
-        // set
-        if (lhs.same_as(rhs) || map_free_vars_) {
-          // record the equality
-          equal_map_lhs_[lhs] = rhs;
-          equal_map_rhs_[rhs] = lhs;
-          return true;
-        } else {
-          return false;
-        }
-      }
-      // if we have a success mapping and in graph/var mode, record the 
equality mapping
-      if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
-        // record the equality
-        equal_map_lhs_[lhs] = rhs;
-        equal_map_rhs_[rhs] = lhs;
-      }
-      return true;
-    } else {
-      return false;
-    }
+    return success;
   }
 
   template <typename MapType>
@@ -413,8 +430,12 @@ class StructEqualHandler {
     }
     return rhs_obj;
   }
-  // whether we map free variables that are not defined
-  bool map_free_vars_{false};
+  // Current def-region kind. ``kNone`` means we are not in a def region;
+  // free vars discovered here do not bind (they must already be bound by an
+  // outer scope or comparison falls back to pointer identity). ``kRecursive``
+  // and ``kNonRecursive`` both enable binding; they differ only in how a
+  // FreeVar's own sub-fields are walked afterwards (see CompareObject).
+  TVMFFIDefRegionKind def_region_kind_{kTVMFFIDefRegionKindNone};
   // whether we compare tensor data
   bool skip_tensor_content_{false};
   // the root lhs for result printing
@@ -431,7 +452,8 @@ class StructEqualHandler {
 bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars,
                             bool skip_tensor_content) {
   StructEqualHandler handler;
-  handler.map_free_vars_ = map_free_vars;
+  handler.def_region_kind_ =
+      map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
   handler.skip_tensor_content_ = skip_tensor_content;
   return handler.CompareAny(lhs, rhs);
 }
@@ -441,7 +463,8 @@ Optional<reflection::AccessPathPair> 
StructuralEqual::GetFirstMismatch(const Any
                                                                        bool 
map_free_vars,
                                                                        bool 
skip_tensor_content) {
   StructEqualHandler handler;
-  handler.map_free_vars_ = map_free_vars;
+  handler.def_region_kind_ =
+      map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
   handler.skip_tensor_content_ = skip_tensor_content;
   std::vector<reflection::AccessStep> lhs_reverse_path;
   std::vector<reflection::AccessStep> rhs_reverse_path;
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index 8ab96f0..21c5545 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -102,7 +102,7 @@ class StructuralHashHandler {
     }
   }
 
-  uint64_t HashObject(ObjectRef obj) {
+  uint64_t HashObject(const ObjectRef& obj) {
     // NOTE: invariant: lhs and rhs are already the same type
     const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
     if (type_info->metadata == nullptr) {
@@ -127,71 +127,102 @@ class StructuralHashHandler {
       return it->second;
     }
 
+    uint64_t hash_value;
+    if (structural_eq_hash_kind != kTVMFFISEqHashKindFreeVar) {
+      hash_value = HashFields(obj, type_info, obj->GetTypeKeyHash());
+    } else {
+      // FreeVar path. In a non-recursive def region the FreeVar's own
+      // sub-fields are walked outside the def region (nested free vars
+      // there hash by pointer, matching use semantics), so we clamp
+      // ``def_region_kind_`` to ``kNone`` around the HashFields call and
+      // restore before the FreeVar-level injection below.
+      //
+      // We always call HashFields, even in use mode where the returned
+      // ``hash_value`` is discarded by the pointer-hash fallback. The
+      // walk's side effect on ``free_var_counter_`` (incremented for
+      // every nested FreeVar reached via SEqHashDef-tagged sub-fields)
+      // is observable to FreeVars hashed later in the same traversal;
+      // skipping the walk would silently change those subsequent hashes.
+      TVMFFIDefRegionKind saved_def_region_kind = def_region_kind_;
+      if (def_region_kind_ == kTVMFFIDefRegionKindNonRecursive) {
+        def_region_kind_ = kTVMFFIDefRegionKindNone;
+      }
+      hash_value = HashFields(obj, type_info, obj->GetTypeKeyHash());
+      def_region_kind_ = saved_def_region_kind;
+      if (def_region_kind_ != kTVMFFIDefRegionKindNone) {
+        // use lexical order of free var and its type
+        hash_value = details::StableHashCombine(hash_value, 
free_var_counter_++);
+      } else {
+        // Fallback to pointer hash; we are not in a def region.
+        hash_value = std::hash<const Object*>()(obj.get());
+      }
+    }
+
+    // if it is a DAG node, also record the lexical order of graph counter
+    // this helps to distinguish DAG from trees.
+    if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
+      hash_value = details::StableHashCombine(hash_value, 
graph_node_counter_++);
+    }
+    // record the hash value for this object
+    hash_memo_[obj] = hash_value;
+    return hash_value;
+  }
+
+  // Hash an object's fields (generic walk or via the custom __ffi_s_hash__
+  // callback). Does not touch the FreeVar def-region clamp — that lives
+  // inline in HashObject's FreeVar branch, which wraps this helper.
+  uint64_t HashFields(const ObjectRef& obj, const TVMFFITypeInfo* type_info, 
uint64_t init_hash) {
     static reflection::TypeAttrColumn custom_s_hash =
         reflection::TypeAttrColumn(reflection::type_attr::kSHash);
 
-    // compute the hash value
-    uint64_t hash_value = obj->GetTypeKeyHash();
     if (custom_s_hash[type_info->type_index] == nullptr) {
       // go over the content and hash the fields
       reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* 
field_info) {
         // skip fields that are marked as structural eq hash ignore
         if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) {
-          // get the field value from both side
           reflection::FieldGetter getter(field_info);
           Any field_value = getter(obj);
-          // field is in def region, enable free var mapping
-          if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
-            bool allow_free_var = true;
-            std::swap(allow_free_var, map_free_vars_);
-            hash_value = details::StableHashCombine(hash_value, 
HashAny(field_value));
-            std::swap(allow_free_var, map_free_vars_);
+          // Dispatch on the def-region flags (mirror of the equality side).
+          constexpr int64_t kSEqHashDefAny = 
kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+                                             
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+          if (field_info->flags & kSEqHashDefAny) {
+            TVMFFIDefRegionKind new_kind =
+                (field_info->flags & 
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive)
+                    ? kTVMFFIDefRegionKindNonRecursive
+                    : kTVMFFIDefRegionKindRecursive;
+            std::swap(new_kind, def_region_kind_);
+            init_hash = details::StableHashCombine(init_hash, 
HashAny(field_value));
+            std::swap(new_kind, def_region_kind_);
           } else {
-            hash_value = details::StableHashCombine(hash_value, 
HashAny(field_value));
+            init_hash = details::StableHashCombine(init_hash, 
HashAny(field_value));
           }
         }
       });
     } else {
       if (s_hash_callback_ == nullptr) {
         s_hash_callback_ =
-            ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, 
bool def_region) {
-              if (def_region) {
-                bool allow_free_var = true;
-                std::swap(allow_free_var, map_free_vars_);
-                uint64_t hash_value = HashAny(val);
-                std::swap(allow_free_var, map_free_vars_);
-                return 
static_cast<int64_t>(details::StableHashCombine(init_hash, hash_value));
-              } else {
-                // we explicitly bitcast the result from `uint64_t` to 
`int64_t`.
-                // The range of `uint64_t` is too large to fit as `int64_t`, 
so if we don't bitcast,
-                // it will trigger an overflow error in `uint64_t` -> `Any` 
conversion.
-                return 
static_cast<int64_t>(details::StableHashCombine(init_hash, HashAny(val)));
-              }
+            // Third parameter is a ``TVMFFIDefRegionKind`` (passed on the wire
+            // as ``int`` to keep the FFI signature stable across language
+            // boundaries).
+            ffi::Function::FromTyped([this](AnyView val, uint64_t inner_init, 
int def_region_kind) {
+              TVMFFIDefRegionKind new_kind =
+                  (def_region_kind == kTVMFFIDefRegionKindNone)
+                      ? def_region_kind_
+                      : static_cast<TVMFFIDefRegionKind>(def_region_kind);
+              std::swap(new_kind, def_region_kind_);
+              uint64_t hv = HashAny(val);
+              std::swap(new_kind, def_region_kind_);
+              // we explicitly bitcast the result from `uint64_t` to `int64_t`.
+              // The range of `uint64_t` is too large to fit as `int64_t`, so 
if we don't bitcast,
+              // it will trigger an overflow error in `uint64_t` -> `Any` 
conversion.
+              return 
static_cast<int64_t>(details::StableHashCombine(inner_init, hv));
             });
       }
-      hash_value =
-          custom_s_hash[type_info->type_index]
-              .cast<ffi::Function>()(obj, static_cast<int64_t>(hash_value), 
s_hash_callback_)
-              .cast<uint64_t>();
-    }
-
-    if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
-      if (map_free_vars_) {
-        // use lexical order of free var and its type
-        hash_value = details::StableHashCombine(hash_value, 
free_var_counter_++);
-      } else {
-        // Fallback to pointer hash, we are not mapping free var.
-        hash_value = std::hash<const Object*>()(obj.get());
-      }
+      init_hash = custom_s_hash[type_info->type_index]
+                      .cast<ffi::Function>()(obj, 
static_cast<int64_t>(init_hash), s_hash_callback_)
+                      .cast<uint64_t>();
     }
-    // if it is a DAG node, also record the lexical order of graph counter
-    // this helps to distinguish DAG from trees.
-    if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
-      hash_value = details::StableHashCombine(hash_value, 
graph_node_counter_++);
-    }
-    // record the hash value for this object
-    hash_memo_[obj] = hash_value;
-    return hash_value;
+    return init_hash;
   }
 
   // NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -317,7 +348,11 @@ class StructuralHashHandler {
     return hash_value;
   }
 
-  bool map_free_vars_{false};
+  // Current def-region kind. ``kNone`` means we are not in a def region; free
+  // vars hash by pointer. ``kRecursive`` and ``kNonRecursive`` both enable
+  // ``free_var_counter_``-based hashing; the kind is set by the field-flag
+  // walk or by the custom-callback path (see HashObject).
+  TVMFFIDefRegionKind def_region_kind_{kTVMFFIDefRegionKindNone};
   bool skip_tensor_content_{false};
   // free var counter.
   uint32_t free_var_counter_{0};
@@ -331,7 +366,8 @@ class StructuralHashHandler {
 
 uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool 
skip_tensor_content) {
   StructuralHashHandler handler;
-  handler.map_free_vars_ = map_free_vars;
+  handler.def_region_kind_ =
+      map_free_vars ? kTVMFFIDefRegionKindRecursive : kTVMFFIDefRegionKindNone;
   handler.skip_tensor_content_ = skip_tensor_content;
   return handler.HashAny(value);
 }
diff --git a/tests/cpp/extra/test_structural_equal_hash.cc 
b/tests/cpp/extra/test_structural_equal_hash.cc
index ad081e3..6ce380a 100644
--- a/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/tests/cpp/extra/test_structural_equal_hash.cc
@@ -229,6 +229,70 @@ TEST(StructuralEqualHash, CustomTreeNode) {
   EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
 }
 
+// Regression tests for the SEqHashDefRecursive vs SEqHashDefNonRecursive
+// distinction. ``TDefHolder`` has two sibling fields:
+//   - ``def_recursive``     tagged AttachFieldFlag::SEqHashDefRecursive()
+//   - ``def_non_recursive`` tagged AttachFieldFlag::SEqHashDefNonRecursive()
+// each holding a ``TVarWithDep`` (a FreeVar with a sub-field ``dep`` that
+// can itself reference another FreeVar). The four sub-cases below cover
+// the observable behaviors of the two flags.
+TEST(StructuralEqualHash, NonRecursiveDef) {
+  {
+    // (a) Recursive flag rebinds nested FreeVars transitively.
+    // Each side passes the same object for both fields (lhs: a, rhs: b), so
+    // the case isolates the recursive field's rebinding.
+    SCOPED_TRACE("recursive flag rebinds nested FreeVars");
+    TVarWithDep a("a", TVar("m"));
+    TVarWithDep b("b", TVar("n"));
+    TDefHolder lhs(/*def_recursive=*/a, /*def_non_recursive=*/a);
+    TDefHolder rhs(/*def_recursive=*/b, /*def_non_recursive=*/b);
+    EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+    EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+              StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+  }
+  {
+    // (b) Non-recursive flag does NOT rebind nested FreeVars: the top-level
+    // FreeVar binds but the nested ``dep`` is clamped out of the def region.
+    // With no outer binding for "p"/"q", equality must fail.
+    SCOPED_TRACE("non-recursive flag does not rebind nested FreeVars");
+    TVarWithDep shared("shared", std::nullopt);
+    TVarWithDep c_with_dep("c", TVar("p"));
+    TVarWithDep d_with_dep("d", TVar("q"));
+    TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+    TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+    EXPECT_FALSE(StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false));
+  }
+  {
+    // (c) Non-recursive flag works if nested FreeVars resolve via an outer
+    // binding — here we cheat by wiring the same pointer, so the nested
+    // FreeVar passes the same-as pointer check without needing the def
+    // region to be on inside its sub-field walk.
+    SCOPED_TRACE("nested FreeVars resolve via outer pointer identity");
+    TVar shared_dep("dep");
+    TVarWithDep c_with_dep("c", shared_dep);
+    TVarWithDep d_with_dep("d", shared_dep);
+    TVarWithDep shared("shared", std::nullopt);
+    TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+    TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+    EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+    EXPECT_EQ(StructuralHash()(lhs), StructuralHash()(rhs));
+  }
+  {
+    // (d) Top-level FreeVar still binds under non-recursive — only the
+    // FreeVar's sub-fields are clamped out; the binding step itself for
+    // the immediate FreeVar is not suppressed.
+    SCOPED_TRACE("top-level FreeVar still binds under non-recursive flag");
+    TVarWithDep shared("shared", std::nullopt);
+    TVarWithDep c_no_dep("c", std::nullopt);
+    TVarWithDep d_no_dep("d", std::nullopt);
+    TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_no_dep);
+    TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_no_dep);
+    EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+    EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+              StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+  }
+}
+
 TEST(StructuralEqualHash, List) {
   List<int> a = {1, 2, 3};
   List<int> b = {1, 2, 3};
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index 0593eca..6f99a74 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -66,6 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   TFloatObj::RegisterReflection();
   TPrimExprObj::RegisterReflection();
   TVarObj::RegisterReflection();
+  TVarWithDepObj::RegisterReflection();
+  TDefHolderObj::RegisterReflection();
   TFuncObj::RegisterReflection();
   TCustomFuncObj::RegisterReflection();
   TAllFieldsObj::RegisterReflection();
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index 48b1a01..d1bf0a9 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -206,6 +206,81 @@ class TVar : public ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj);
 };
 
+// FreeVar test object that has a sub-field referencing another FreeVar.
+// This models the "var with nested vars" case (analogous to a relax::Var
+// whose struct_info contains tir shape vars). It is used to exercise the
+// difference between SEqHashDefRecursive and SEqHashDefNonRecursive at the
+// FFI layer: under recursive semantics the nested ``dep`` var rebinds
+// transitively; under non-recursive semantics it is treated as a use of an
+// outer-scope binding and equality fails when no such outer binding exists.
+class TVarWithDepObj : public Object {
+ public:
+  std::string name;
+  // Optional dependency var; when null, this object behaves like a plain
+  // FreeVar with no nested free vars.
+  Optional<ObjectRef> dep;
+
+  TVarWithDepObj(std::string name, Optional<ObjectRef> dep)
+      : name(std::move(name)), dep(std::move(dep)) {}
+  explicit TVarWithDepObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TVarWithDepObj>()
+        .def_ro("name", &TVarWithDepObj::name, 
refl::AttachFieldFlag::SEqHashIgnore())
+        // ``dep`` participates in structural equality without any def flag,
+        // so it is a USE position. Whether the FreeVar in ``dep`` may rebind
+        // is decided by the def flag on whichever outer field reaches this
+        // object.
+        .def_ro("dep", &TVarWithDepObj::dep);
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindFreeVar;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.VarWithDep", TVarWithDepObj, Object);
+};
+
+class TVarWithDep : public ObjectRef {
+ public:
+  explicit TVarWithDep(std::string name, Optional<ObjectRef> dep = 
std::nullopt) {
+    data_ = make_object<TVarWithDepObj>(std::move(name), std::move(dep));
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVarWithDep, ObjectRef, 
TVarWithDepObj);
+};
+
+// Holder with one recursive-def field and one non-recursive-def field.
+// Used by the StructuralEqualHash.NonRecursiveDef test cases.
+class TDefHolderObj : public Object {
+ public:
+  TVarWithDep def_recursive;
+  TVarWithDep def_non_recursive;
+
+  TDefHolderObj(TVarWithDep def_recursive, TVarWithDep def_non_recursive)
+      : def_recursive(std::move(def_recursive)), 
def_non_recursive(std::move(def_non_recursive)) {}
+  explicit TDefHolderObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TDefHolderObj>()
+        .def_ro("def_recursive", &TDefHolderObj::def_recursive,
+                refl::AttachFieldFlag::SEqHashDefRecursive())
+        .def_ro("def_non_recursive", &TDefHolderObj::def_non_recursive,
+                refl::AttachFieldFlag::SEqHashDefNonRecursive());
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindTreeNode;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.DefHolder", TDefHolderObj, Object);
+};
+
+class TDefHolder : public ObjectRef {
+ public:
+  explicit TDefHolder(TVarWithDep def_recursive, TVarWithDep 
def_non_recursive) {
+    data_ = make_object<TDefHolderObj>(std::move(def_recursive), 
std::move(def_non_recursive));
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TDefHolder, ObjectRef, 
TDefHolderObj);
+};
+
 class TFuncObj : public Object {
  public:
   Array<TVar> params;
@@ -220,7 +295,7 @@ class TFuncObj : public Object {
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<TFuncObj>()
-        .def_ro("params", &TFuncObj::params, 
refl::AttachFieldFlag::SEqHashDef())
+        .def_ro("params", &TFuncObj::params, 
refl::AttachFieldFlag::SEqHashDefRecursive())
         .def_ro("body", &TFuncObj::body)
         .def_ro("comment", &TFuncObj::comment, 
refl::AttachFieldFlag::SEqHashIgnore());
   }


Reply via email to