This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch seqhashdef-non-recursive in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
commit 107ad5194ccdeead9c03764b48b5982966bacec9 Author: tqchen <[email protected]> AuthorDate: Sun May 10 00:36:58 2026 +0000 [FEAT] Split SEqHashDef into recursive and non-recursive variants The single ``SEqHashDef`` flag treated every nested free var inside a def-region field as a fresh def. That conflates two different binding shapes: - *Recursive* (function-style): the value var and any free vars inside its sub-fields (e.g. shape vars in a relax::Var's struct_info) are co-introduced at the same site. - *Non-recursive* (let-style): only the immediate value var binds; free vars in the var's sub-fields are use references that must resolve against an outer-scope binding. This change: - Renames the C constant ``kTVMFFIFieldFlagBitMaskSEqHashDef`` -> ``...SEqHashDefRecursive`` (no alias) and adds a new sibling ``...SEqHashDefNonRecursive`` at bit ``1 << 12`` (the next free bit). - Renames ``AttachFieldFlag::SEqHashDef()`` -> ``SEqHashDefRecursive()`` and adds ``SEqHashDefNonRecursive()``. - Adds a ``TVMFFIFieldDefKind`` C enum (None=0, Recursive=1, NonRecursive=2) for the custom ``__s_equal__`` / ``__s_hash__`` callback's def-mode parameter. The wire type stays ``int``, so legacy callers that pass a ``bool`` still compile and preserve their meaning via the standard bool->int coercion (true -> 1 = Recursive, false -> 0 = None). - Updates the structural_equal / structural_hash dispatch to clamp ``map_free_vars_`` to false when descending into a FreeVar's own sub-fields under the non-recursive flag (the binding step itself still runs with the caller's setting, so the immediate FreeVar still binds). - Mirrors the rename and the new enum into the Cython base.pxi / type_info.pxi / object.pxi layer; preserves Python ``"def"`` as a back-compat alias for ``"def-recursive"`` and adds ``"def-non-recursive"`` to the dataclass field vocabulary. 
- Adds C++ regression tests in tests/cpp/extra/test_structural_equal_hash.cc covering the four recursive / non-recursive corner cases on a new ``TDefHolder`` test type whose ``def_recursive`` and ``def_non_recursive`` fields hold a FreeVar with a nested FreeVar sub-field. --- include/tvm/ffi/c_api.h | 92 ++++++++++++++++++++++++++- include/tvm/ffi/reflection/registry.h | 21 +++++- python/tvm_ffi/cython/base.pxi | 8 ++- python/tvm_ffi/cython/object.pxi | 8 ++- python/tvm_ffi/cython/type_info.pxi | 10 ++- python/tvm_ffi/dataclasses/field.py | 29 +++++++-- src/ffi/extra/structural_equal.cc | 80 ++++++++++++++++++++--- src/ffi/extra/structural_hash.cc | 69 +++++++++++++++++--- tests/cpp/extra/test_structural_equal_hash.cc | 77 ++++++++++++++++++++++ tests/cpp/test_reflection.cc | 2 + tests/cpp/testing_object.h | 77 +++++++++++++++++++++- 11 files changed, 435 insertions(+), 38 deletions(-) diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h index 0d2c9df..e69fd6c 100644 --- a/include/tvm/ffi/c_api.h +++ b/include/tvm/ffi/c_api.h @@ -858,11 +858,21 @@ typedef enum { */ kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, /*! - * \brief The field enters a def region where var can be defined/matched. + * \brief The field enters a recursive def region. + * + * When equality / hashing first encounters a free var reachable through + * this field, the var binds. Sub-fields of that var (e.g. its + * struct_info / type_annotation / shape) remain in the def region — any + * free vars discovered transitively also bind as fresh defs at the same + * site. + * + * Use for "function-style" bindings where the value var and its shape + * vars are co-introduced (e.g. ``relax::FunctionNode::params``, + * ``tirx::AllocBufferNode::buffer``). * * This is an optional meta-data for structural eq/hash. */ - kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, + kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4, /*! * \brief The default_value_or_factory is a callable factory function () -> Any. 
* @@ -922,6 +932,27 @@ typedef enum { * ``(field_addr_as_OpaquePtr, value_as_AnyView)``. */ kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11, + /*! + * \brief The field enters a non-recursive def region. + * + * When equality / hashing first encounters a free var reachable through + * this field, the var binds. Sub-fields of that var are NOT in the def + * region — they are use references that must resolve against an outer + * binding. Free vars found in those sub-fields therefore do not rebind; + * if they are not already bound, equality fails (or hashing falls back + * to pointer identity). + * + * Use for "let-style" bindings where only the value var is introduced + * and its shape / type sub-fields refer to outer-scope vars + * (e.g. ``relax::BindingNode::var``, ``tirx::LetNode::var``, + * ``tirx::ForNode::loop_var``). + * + * This is an optional meta-data for structural eq/hash. + * + * \note Bit 1 << 12 is used here because bits 1 << 5 .. 1 << 11 are + * already taken by other field flags above. + */ + kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12, #ifdef __cplusplus }; #else @@ -993,6 +1024,63 @@ typedef enum { } TVMFFISEqHashKind; #endif +/*! + * \brief Kind of def region a structural-equal / structural-hash callback is + * currently in when visiting a field. + * + * The numeric values are stable: a legacy ``bool def_region`` argument that + * is implicitly coerced to ``int`` will land on ``kTVMFFIFieldDefKindNone`` + * (false → 0) or ``kTVMFFIFieldDefKindRecursive`` (true → 1), preserving + * the meaning of any pre-existing call site that passes a bool. + */ +#ifdef __cplusplus +enum TVMFFIFieldDefKind : int32_t { +#else +typedef enum { +#endif + /*! + * \brief Not in a def region. + * + * Free vars reachable through this field are uses; they must already + * be bound by an enclosing def region or equality / hashing falls + * back to pointer identity. + */ + kTVMFFIFieldDefKindNone = 0, + /*! + * \brief In a recursive def region. 
+ * + * When we see a free var for the first time, we define the var, and + * the sub-fields of the var (e.g. its struct_info / type_annotation / + * shape) are also still in the def region — any free vars discovered + * inside those sub-fields are themselves treated as fresh defs at the + * same site. + * + * Use for "function-style" bindings where the value var and its shape + * vars are co-introduced (e.g. ``relax::FunctionNode::params``, + * ``tirx::AllocBufferNode::buffer``). + */ + kTVMFFIFieldDefKindRecursive = 1, + /*! + * \brief In a non-recursive def region. + * + * When we see a free var for the first time, we define the var, but + * the sub-fields of the var are NOT in the def region — they are + * treated as use references that must resolve against an outer + * binding. Free vars found in those sub-fields therefore do not + * rebind; if they are not already bound, equality fails. + * + * Use for "let-style" bindings where only the value var is introduced + * and its shape / type sub-fields refer to outer-scope vars + * (e.g. ``relax::BindingNode::var``, ``tirx::LetNode::var``, + * ``tirx::ForNode::loop_var``). + */ + kTVMFFIFieldDefKindNonRecursive = 2, +#ifdef __cplusplus +}; +#else +} TVMFFIFieldDefKind; +#endif + /*! * \brief Information support for optional object reflection. */ diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h index 3e715fe..09ceb8b 100644 --- a/include/tvm/ffi/reflection/registry.h +++ b/include/tvm/ffi/reflection/registry.h @@ -213,10 +213,25 @@ class AttachFieldFlag : public InfoTrait { explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefRecursive + * + * The field enters a recursive def region: free vars discovered both at + * the field's value and inside that value's sub-fields bind as fresh + * defs at the same site. Use for "function-style" bindings. 
+ */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashDefRecursive() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefRecursive); + } + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive + * + * The field enters a non-recursive def region: only the immediate free + * var at the field's value binds; free vars in its sub-fields are uses + * that must already be bound by an outer def region. Use for "let-style" + * bindings whose sub-fields reference outer-scope vars. */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); + TVM_FFI_INLINE static AttachFieldFlag SEqHashDefNonRecursive() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive); } /*! * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi index c5c28a1..10d96c4 100644 --- a/python/tvm_ffi/cython/base.pxi +++ b/python/tvm_ffi/cython/base.pxi @@ -207,7 +207,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3 - kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4 + kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4 kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5 kTVMFFIFieldFlagBitMaskReprOff = 1 << 6 kTVMFFIFieldFlagBitMaskCompareOff = 1 << 7 @@ -215,6 +215,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIFieldFlagBitMaskInitOff = 1 << 9 kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10 kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11 + kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12 ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept @@ -248,6 +249,11 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFISEqHashKindConstTreeNode = 4 kTVMFFISEqHashKindUniqueInstance = 5 + cdef enum TVMFFIFieldDefKind: + 
kTVMFFIFieldDefKindNone = 0 + kTVMFFIFieldDefKindRecursive = 1 + kTVMFFIFieldDefKindNonRecursive = 2 + ctypedef struct TVMFFITypeMetadata: TVMFFIByteArray doc TVMFFIObjectCreator creator diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi index 803ead2..c564b42 100644 --- a/python/tvm_ffi/cython/object.pxi +++ b/python/tvm_ffi/cython/object.pxi @@ -541,11 +541,13 @@ cdef _type_info_create_from_type_key(object type_cls, str type_key): c_default_factory = make_ret(owned_default) else: c_default = make_ret(owned_default) - # Decode SEqHashIgnore / SEqHashDef into the Field.structural_eq vocabulary. + # Decode SEqHashIgnore / SEqHashDef* into the Field.structural_eq vocabulary. if (field.flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) != 0: c_structural_eq = "ignore" - elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDef) != 0: - c_structural_eq = "def" + elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive) != 0: + c_structural_eq = "def-recursive" + elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0: + c_structural_eq = "def-non-recursive" else: c_structural_eq = None fields.append( diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi index d7f39be..960d5cd 100644 --- a/python/tvm_ffi/cython/type_info.pxi +++ b/python/tvm_ffi/cython/type_info.pxi @@ -908,8 +908,14 @@ cdef _register_one_field( cdef object field_structure = getattr(py_field, "structural_eq", None) if field_structure == "ignore": flags |= kTVMFFIFieldFlagBitMaskSEqHashIgnore - elif field_structure == "def": - flags |= kTVMFFIFieldFlagBitMaskSEqHashDef + elif field_structure == "def" or field_structure == "def-recursive": + # ``"def"`` is the legacy short form, kept as a Python-side synonym for + # ``"def-recursive"`` since the C-level rename of the underlying flag + # (``kTVMFFIFieldFlagBitMaskSEqHashDef`` -> ``...SEqHashDefRecursive``) + # only changed the constant name, not the recursive semantics. 
+ flags |= kTVMFFIFieldFlagBitMaskSEqHashDefRecursive + elif field_structure == "def-non-recursive": + flags |= kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive info.flags = flags # --- native layout --- diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py index 53f576c..b332cd4 100644 --- a/python/tvm_ffi/dataclasses/field.py +++ b/python/tvm_ffi/dataclasses/field.py @@ -87,9 +87,18 @@ class Field: structural comparison and hashing. - ``"ignore"``: the field is excluded from structural equality and hashing entirely (e.g. source spans, caches). - - ``"def"``: the field is a **definition region** that introduces - new variable bindings. Free variables encountered inside this - field are mapped by position, enabling alpha-equivalence. + - ``"def-recursive"`` (alias: ``"def"``): the field is a + **recursive definition region** that introduces new variable + bindings. Free variables encountered anywhere in this field's + subtree (including inside the var's own sub-fields) are + mapped by position. Use for "function-style" bindings where + the value var and its shape vars are co-introduced. + - ``"def-non-recursive"``: the field is a **non-recursive + definition region**. Only the immediate free var(s) at this + field's value bind; free vars inside their sub-fields must + resolve against an outer binding (use semantics). Use for + "let-style" bindings whose sub-fields reference outer-scope + vars. doc : str | None Optional docstring for the field. @@ -125,8 +134,12 @@ class Field: doc: str | None #: Valid values for the *structural_eq* parameter. + #: + #: ``"def"`` is kept as a Python-side alias for ``"def-recursive"`` to + #: preserve back-compat with code written against the old single-flag + #: ``SEqHashDef`` API. 
_VALID_STRUCTURAL_EQ_VALUES: ClassVar[frozenset[str | None]] = frozenset( - {None, "ignore", "def"} + {None, "ignore", "def", "def-recursive", "def-non-recursive"} ) def __init__( # noqa: PLR0913 @@ -226,8 +239,12 @@ def field( structural_eq Structural equality/hashing annotation. ``None`` (default) means the field participates normally. ``"ignore"`` excludes the field - from structural comparison and hashing. ``"def"`` marks the field - as a definition region for variable binding. + from structural comparison and hashing. ``"def-recursive"`` + (alias ``"def"``) marks the field as a recursive definition + region: free vars in the field's whole subtree bind. ``"def-non-recursive"`` + marks it as a non-recursive definition region: only immediate + free vars bind; nested free vars must resolve against an outer + binding. doc Optional docstring for the field. diff --git a/src/ffi/extra/structural_equal.cc b/src/ffi/extra/structural_equal.cc index 5f4db3c..5c1cb0d 100644 --- a/src/ffi/extra/structural_equal.cc +++ b/src/ffi/extra/structural_equal.cc @@ -174,6 +174,25 @@ class StructEqualHandler { static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn(reflection::type_attr::kSEqual); + // Non-recursive def boundary. When we enter a non-recursive def region we + // keep ``map_free_vars_`` on so that any FreeVar reachable through + // containers in the field's value (e.g. each ``Var`` in an ``Array<Var>``) + // can still bind. But once we are about to walk a FreeVar's OWN sub-fields + // (e.g. ``struct_info``, ``type_annotation``), we turn ``map_free_vars_`` + // off so that nested free vars do not rebind — they must instead resolve + // against a binding established by an outer def region. + // + // ``non_recursive_def_active_`` stays on for the entire field value + // subtree (saved/restored at the field walk site below). It is consulted + // here to decide whether to clamp ``map_free_vars_`` to false during the + // FreeVar's field walk. 
+ bool save_map_free_vars = map_free_vars_; + bool clamp_map_free_vars = + (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) && non_recursive_def_active_; + if (clamp_map_free_vars) { + map_free_vars_ = false; + } + bool success = true; if (custom_s_equal[type_info->type_index] == nullptr) { // We recursively compare the fields the object @@ -184,12 +203,27 @@ class StructEqualHandler { reflection::FieldGetter getter(field_info); Any lhs_value = getter(lhs); Any rhs_value = getter(rhs); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); + // Dispatch on the def-region flags. + // - Recursive : enable ``map_free_vars_`` for the whole subtree + // of this field's value, including nested FreeVars' + // sub-fields. + // - NonRecursive : enable ``map_free_vars_`` only for the immediate + // FreeVar(s) reachable through this field; their + // own sub-fields are walked with ``map_free_vars_`` + // clamped to false (the clamp lives in the + // ``CompareObject`` prologue above, gated by + // ``non_recursive_def_active_``). 
+ constexpr int64_t kSEqHashDefAny = kTVMFFIFieldFlagBitMaskSEqHashDefRecursive | + kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive; + if (field_info->flags & kSEqHashDefAny) { + bool save_allow_free_var = map_free_vars_; + bool save_non_recursive = non_recursive_def_active_; + map_free_vars_ = true; + non_recursive_def_active_ = + (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0; success = CompareAny(lhs_value, rhs_value); - std::swap(allow_free_var, map_free_vars_); + map_free_vars_ = save_allow_free_var; + non_recursive_def_active_ = save_non_recursive; } else { success = CompareAny(lhs_value, rhs_value); } @@ -212,16 +246,27 @@ class StructEqualHandler { // run custom equal function defined via __s_equal__ type attribute if (s_equal_callback_ == nullptr) { s_equal_callback_ = ffi::Function::FromTyped( - [this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) { + // The third parameter is a ``TVMFFIFieldDefKind`` (typed as plain + // ``int`` on the wire to keep the FFI signature stable; legacy + // callers passing ``bool`` continue to compile and preserve their + // meaning via the implicit bool->int coercion: false -> 0 + // (kTVMFFIFieldDefKindNone), true -> 1 (kTVMFFIFieldDefKindRecursive)). 
+ [this](AnyView lhs, AnyView rhs, int def_kind, AnyView field_name) { // NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially // and only cast to string if mismatch happens bool success = true; - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); + if (def_kind == kTVMFFIFieldDefKindRecursive || + def_kind == kTVMFFIFieldDefKindNonRecursive) { + bool save_allow_free_var = map_free_vars_; + bool save_non_recursive = non_recursive_def_active_; + map_free_vars_ = true; + non_recursive_def_active_ = (def_kind == kTVMFFIFieldDefKindNonRecursive); success = CompareAny(lhs, rhs); - std::swap(allow_free_var, map_free_vars_); + map_free_vars_ = save_allow_free_var; + non_recursive_def_active_ = save_non_recursive; } else { + // kTVMFFIFieldDefKindNone (or any unknown value treated as None): + // not in a def region, leave map_free_vars_ as-is. success = CompareAny(lhs, rhs); } if (!success) { @@ -241,6 +286,14 @@ class StructEqualHandler { .cast<bool>(); } + // Restore the pre-clamp value of map_free_vars_ before deciding whether + // to bind a FreeVar pair below. The binding decision must use the value + // that the caller of CompareObject set — the clamp only affects this + // FreeVar's OWN sub-field walk. + if (clamp_map_free_vars) { + map_free_vars_ = save_map_free_vars; + } + if (success) { if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { // we are in a free var case that is not yet mapped. @@ -415,6 +468,13 @@ class StructEqualHandler { } // whether we map free variables that are not defined bool map_free_vars_{false}; + // Whether we are currently inside a non-recursive def region. Set when a + // field flagged ``kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive`` is being + // walked (or the custom-callback caller passed ``kTVMFFIFieldDefKindNonRecursive``). 
+ // Consulted in CompareObject to clamp ``map_free_vars_`` to false when + // descending into a FreeVar's own sub-fields, while still allowing the + // FreeVar itself to bind in the post-pass. + bool non_recursive_def_active_{false}; // whether we compare tensor data bool skip_tensor_content_{false}; // the root lhs for result printing diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc index 8ab96f0..e4993f8 100644 --- a/src/ffi/extra/structural_hash.cc +++ b/src/ffi/extra/structural_hash.cc @@ -130,6 +130,21 @@ class StructuralHashHandler { static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn(reflection::type_attr::kSHash); + // Non-recursive def boundary (mirror of structural_equal.cc). When the + // current object is a FreeVar AND we are inside a non-recursive def + // region, clamp ``map_free_vars_`` to false during the FreeVar's own + // sub-field walk: nested free vars in those sub-fields then hash by + // pointer (matching use-semantics) instead of receiving fresh + // ``free_var_counter_`` slots. The clamp is restored before the + // FreeVar-level injection below so the FreeVar itself still gets its + // counter slot when ``map_free_vars_`` was on at the call site. 
+ bool save_map_free_vars = map_free_vars_; + bool clamp_map_free_vars = + (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) && non_recursive_def_active_; + if (clamp_map_free_vars) { + map_free_vars_ = false; + } + // compute the hash value uint64_t hash_value = obj->GetTypeKeyHash(); if (custom_s_hash[type_info->type_index] == nullptr) { @@ -140,12 +155,23 @@ class StructuralHashHandler { // get the field value from both side reflection::FieldGetter getter(field_info); Any field_value = getter(obj); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); + // Dispatch on the def-region flags (mirror of the equality side). + // - Recursive : map_free_vars_ stays on for the whole subtree. + // - NonRecursive : map_free_vars_ on for the immediate FreeVar(s); + // clamped off when descending into a FreeVar's + // own sub-fields (the clamp lives in HashObject's + // prologue above, gated by non_recursive_def_active_). 
+ constexpr int64_t kSEqHashDefAny = kTVMFFIFieldFlagBitMaskSEqHashDefRecursive | + kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive; + if (field_info->flags & kSEqHashDefAny) { + bool save_allow_free_var = map_free_vars_; + bool save_non_recursive = non_recursive_def_active_; + map_free_vars_ = true; + non_recursive_def_active_ = + (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0; hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); - std::swap(allow_free_var, map_free_vars_); + map_free_vars_ = save_allow_free_var; + non_recursive_def_active_ = save_non_recursive; } else { hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); } @@ -154,12 +180,21 @@ class StructuralHashHandler { } else { if (s_hash_callback_ == nullptr) { s_hash_callback_ = - ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) { - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); + // The third parameter is a ``TVMFFIFieldDefKind`` (typed as plain + // ``int`` on the wire to keep the FFI signature stable; legacy + // callers passing ``bool`` continue to compile and preserve their + // meaning via the implicit bool->int coercion: false -> 0 + // (kTVMFFIFieldDefKindNone), true -> 1 (kTVMFFIFieldDefKindRecursive)). 
+ ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, int def_kind) { + if (def_kind == kTVMFFIFieldDefKindRecursive || + def_kind == kTVMFFIFieldDefKindNonRecursive) { + bool save_allow_free_var = map_free_vars_; + bool save_non_recursive = non_recursive_def_active_; + map_free_vars_ = true; + non_recursive_def_active_ = (def_kind == kTVMFFIFieldDefKindNonRecursive); uint64_t hash_value = HashAny(val); - std::swap(allow_free_var, map_free_vars_); + map_free_vars_ = save_allow_free_var; + non_recursive_def_active_ = save_non_recursive; return static_cast<int64_t>(details::StableHashCombine(init_hash, hash_value)); } else { // we explicitly bitcast the result from `uint64_t` to `int64_t`. @@ -175,6 +210,13 @@ class StructuralHashHandler { .cast<uint64_t>(); } + // Restore the pre-clamp value of map_free_vars_ before deciding the + // FreeVar-level hash injection: the clamp only suppresses binding inside + // the FreeVar's own sub-fields, not the FreeVar slot itself. + if (clamp_map_free_vars) { + map_free_vars_ = save_map_free_vars; + } + if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { if (map_free_vars_) { // use lexical order of free var and its type @@ -318,6 +360,13 @@ class StructuralHashHandler { } bool map_free_vars_{false}; + // Whether we are currently inside a non-recursive def region. Set when a + // field flagged ``kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive`` is being + // hashed (or the custom-callback caller passed ``kTVMFFIFieldDefKindNonRecursive``). + // Consulted in HashObject to clamp ``map_free_vars_`` to false when + // descending into a FreeVar's own sub-fields, so nested free vars hash by + // pointer rather than receiving fresh ``free_var_counter_`` slots. + bool non_recursive_def_active_{false}; bool skip_tensor_content_{false}; // free var counter. 
uint32_t free_var_counter_{0}; diff --git a/tests/cpp/extra/test_structural_equal_hash.cc b/tests/cpp/extra/test_structural_equal_hash.cc index ad081e3..4649461 100644 --- a/tests/cpp/extra/test_structural_equal_hash.cc +++ b/tests/cpp/extra/test_structural_equal_hash.cc @@ -229,6 +229,83 @@ TEST(StructuralEqualHash, CustomTreeNode) { EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); } +// Regression tests for the SEqHashDefRecursive vs SEqHashDefNonRecursive +// distinction. ``TDefHolder`` has two sibling fields: +// - ``def_recursive`` tagged AttachFieldFlag::SEqHashDefRecursive() +// - ``def_non_recursive`` tagged AttachFieldFlag::SEqHashDefNonRecursive() +// each holding a ``TVarWithDep`` (a FreeVar with a sub-field ``dep`` that +// can itself reference another FreeVar). The tests below cover the four +// observable combinations of the two flags. +TEST(StructuralEqualHash, NonRecursiveDef_NestedFreeVarRebindsUnderRecursive) { + // Both fields receive a TVarWithDep whose ``dep`` contains a *fresh* + // FreeVar (TVar). Under the recursive flag, the nested ``dep`` rebinds + // transitively, so two holders that differ only in fresh names compare + // equal. + TVarWithDep a("a", TVar("m")); + TVarWithDep b("b", TVar("n")); + TDefHolder lhs(/*def_recursive=*/a, /*def_non_recursive=*/a); + TDefHolder rhs(/*def_recursive=*/b, /*def_non_recursive=*/b); + // ``def_non_recursive`` reuses the ``def_recursive`` object on each side (a on + // lhs, b on rhs), so it is expected to resolve via the pairing the recursive + // field already made; the recursive field alone exercises the nested rebinding. + EXPECT_TRUE(StructuralEqual()(lhs, rhs)); + EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true), + StructuralHash::Hash(rhs, /*map_free_vars=*/true)); +} + +TEST(StructuralEqualHash, NonRecursiveDef_NestedFreeVarDoesNotRebindUnderNonRecursive) { + // The ``def_non_recursive`` field's value (TVarWithDep "c"/"d") binds + // because the holder explicitly tags the field as a def region. 
But the + // nested ``dep`` (TVar "p" / "q") is in the FreeVar's sub-field, which + // the non-recursive flag CLAMPS out of the def region. With no enclosing + // def region for "p" / "q", they must hit the unmapped FreeVar path in + // CompareObject and equality fails. + // + // The recursive sibling is set to identical pointers on both sides so the + // test isolates the non-recursive field as the failure source. + TVarWithDep shared("shared", std::nullopt); + TVarWithDep c_with_dep("c", TVar("p")); + TVarWithDep d_with_dep("d", TVar("q")); + TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep); + TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep); + EXPECT_FALSE(StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false)); +} + +TEST(StructuralEqualHash, NonRecursiveDef_NestedFreeVarResolvesViaOuterBinding) { + // Now wire the same nested free var on both sides, so that even under the + // non-recursive clamp the FreeVars at the leaf compare equal by *pointer* + // (the ``same_as`` pointer-identity branch in CompareObject's FreeVar handling). This + // mirrors the let-style use case where the nested var has been bound by + // an outer scope (here we cheat by using the same pointer directly). + TVar shared_dep("dep"); + TVarWithDep c_with_dep("c", shared_dep); + TVarWithDep d_with_dep("d", shared_dep); + TVarWithDep shared("shared", std::nullopt); + TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep); + TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep); + // ``c`` and ``d`` bind via the non-recursive def region; the nested + // shared_dep is the same object and so passes pointer-equality without + // needing map_free_vars_ to be on inside its sub-field walk. 
+ EXPECT_TRUE(StructuralEqual()(lhs, rhs)); + EXPECT_EQ(StructuralHash()(lhs), StructuralHash()(rhs)); +} + +TEST(StructuralEqualHash, NonRecursiveDef_TopLevelFreeVarStillBinds) { + // Sanity: even with the non-recursive flag, the immediate FreeVar at the + // field's value MUST still bind. So a TVarWithDep with no nested ``dep`` + // (Optional set to nullopt) under the non-recursive flag should still + // compare equal across two fresh names — the binding step itself is not + // suppressed, only the descent into the FreeVar's sub-fields is. + TVarWithDep shared("shared", std::nullopt); + TVarWithDep c_no_dep("c", std::nullopt); + TVarWithDep d_no_dep("d", std::nullopt); + TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_no_dep); + TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_no_dep); + EXPECT_TRUE(StructuralEqual()(lhs, rhs)); + EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true), + StructuralHash::Hash(rhs, /*map_free_vars=*/true)); +} + TEST(StructuralEqualHash, List) { List<int> a = {1, 2, 3}; List<int> b = {1, 2, 3}; diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc index f9d567f..eef8ddb 100644 --- a/tests/cpp/test_reflection.cc +++ b/tests/cpp/test_reflection.cc @@ -66,6 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { TFloatObj::RegisterReflection(); TPrimExprObj::RegisterReflection(); TVarObj::RegisterReflection(); + TVarWithDepObj::RegisterReflection(); + TDefHolderObj::RegisterReflection(); TFuncObj::RegisterReflection(); TCustomFuncObj::RegisterReflection(); TAllFieldsObj::RegisterReflection(); diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h index 48b1a01..d1bf0a9 100644 --- a/tests/cpp/testing_object.h +++ b/tests/cpp/testing_object.h @@ -206,6 +206,81 @@ class TVar : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj); }; +// FreeVar test object that has a sub-field referencing another FreeVar. 
+// This models the "var with nested vars" case (analogous to a relax::Var +// whose struct_info contains tir shape vars). It is used to exercise the +// difference between SEqHashDefRecursive and SEqHashDefNonRecursive at the +// FFI layer: under recursive semantics the nested ``dep`` var rebinds +// transitively; under non-recursive semantics it is treated as a use of an +// outer-scope binding and equality fails when no such outer binding exists. +class TVarWithDepObj : public Object { + public: + std::string name; + // Optional dependency var; when null, this object behaves like a plain + // FreeVar with no nested free vars. + Optional<ObjectRef> dep; + + TVarWithDepObj(std::string name, Optional<ObjectRef> dep) + : name(std::move(name)), dep(std::move(dep)) {} + explicit TVarWithDepObj(UnsafeInit) {} + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef<TVarWithDepObj>() + .def_ro("name", &TVarWithDepObj::name, refl::AttachFieldFlag::SEqHashIgnore()) + // ``dep`` participates in structural equality without any def flag, + // so it is a USE position. Whether the FreeVar in ``dep`` may rebind + // is decided by the def flag on whichever outer field reaches this + // object. + .def_ro("dep", &TVarWithDepObj::dep); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.VarWithDep", TVarWithDepObj, Object); +}; + +class TVarWithDep : public ObjectRef { + public: + explicit TVarWithDep(std::string name, Optional<ObjectRef> dep = std::nullopt) { + data_ = make_object<TVarWithDepObj>(std::move(name), std::move(dep)); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVarWithDep, ObjectRef, TVarWithDepObj); +}; + +// Holder with one recursive-def field and one non-recursive-def field. +// Used by StructuralEqualHash.NonRecursiveDef tests below. 
+class TDefHolderObj : public Object { + public: + TVarWithDep def_recursive; + TVarWithDep def_non_recursive; + + TDefHolderObj(TVarWithDep def_recursive, TVarWithDep def_non_recursive) + : def_recursive(std::move(def_recursive)), def_non_recursive(std::move(def_non_recursive)) {} + explicit TDefHolderObj(UnsafeInit) {} + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef<TDefHolderObj>() + .def_ro("def_recursive", &TDefHolderObj::def_recursive, + refl::AttachFieldFlag::SEqHashDefRecursive()) + .def_ro("def_non_recursive", &TDefHolderObj::def_non_recursive, + refl::AttachFieldFlag::SEqHashDefNonRecursive()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.DefHolder", TDefHolderObj, Object); +}; + +class TDefHolder : public ObjectRef { + public: + explicit TDefHolder(TVarWithDep def_recursive, TVarWithDep def_non_recursive) { + data_ = make_object<TDefHolderObj>(std::move(def_recursive), std::move(def_non_recursive)); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TDefHolder, ObjectRef, TDefHolderObj); +}; + class TFuncObj : public Object { public: Array<TVar> params; @@ -220,7 +295,7 @@ class TFuncObj : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef<TFuncObj>() - .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDefRecursive()) .def_ro("body", &TFuncObj::body) .def_ro("comment", &TFuncObj::comment, refl::AttachFieldFlag::SEqHashIgnore()); }
