This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 2309578 [FEAT] Add structural visitor and typed structural walk APIs (#601)
2309578 is described below
commit 2309578ec76bf7d26062c43b06e44714d6b33dd0
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Wed Jun 10 05:38:26 2026 -0700
[FEAT] Add structural visitor and typed structural walk APIs (#601)
## Summary
This PR adds structural traversal APIs for TVM FFI values in C++ and
Python.
The main user-facing addition is `StructuralWalk` /
`tvm_ffi.structural_walk`, a pytree-style walk that can visit both
object-backed nodes and POD leaves (`int`, `float`, `bool`, `str`,
`bytes`). It supports typed callbacks, pre/post-order traversal, child
skipping, early interruption, and def-region-aware callbacks.
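The control flow described above (pre/post-order callbacks, child skipping, and early interruption) can be modeled in a few lines of plain Python. This is a hypothetical sketch, not the `tvm_ffi` implementation: `mini_walk`, `Action`, and `Interrupt` are illustrative stand-ins, and nested Python lists play the role of FFI containers.

```python
import enum
from dataclasses import dataclass
from typing import Any, Callable, Optional


class Action(enum.Enum):
    """Hypothetical stand-in for tvm_ffi.WalkResult."""
    ADVANCE = 0  # continue into children
    SKIP = 1     # do not visit this value's children


@dataclass
class Interrupt:
    """Hypothetical stand-in for tvm_ffi.VisitInterrupt."""
    value: Any = None


def mini_walk(value: Any, callback: Callable[[Any], Any],
              order: str = "pre") -> Optional[Interrupt]:
    """Visit `value` and, if it is a list, its children recursively.

    Returns None if the walk finishes, or the Interrupt that stopped it.
    A callback returning None is treated like Action.ADVANCE.
    """
    if order == "pre":
        ret = callback(value)
        if isinstance(ret, Interrupt):
            return ret
        if ret is Action.SKIP:
            return None  # prune children; siblings are still visited
    if isinstance(value, list):
        for child in value:
            stop = mini_walk(child, callback, order)
            if stop is not None:
                return stop  # propagate the interrupt up to the caller
    if order == "post":
        # Children were already visited, so SKIP has no effect here.
        ret = callback(value)
        if isinstance(ret, Interrupt):
            return ret
    return None


# Post-order: children are reported before their enclosing list.
trace = []
mini_walk([[1], 2],
          lambda v: trace.append(f"list:{len(v)}" if isinstance(v, list) else f"int:{v}"),
          order="post")
print(trace)  # ['int:1', 'list:1', 'int:2', 'list:2']

# Pre-order with an early stop at the first zero.
stop = mini_walk([3, [0, 5]], lambda v: Interrupt(v) if v == 0 else None)
print(stop)  # Interrupt(value=0)
```

Returning the interrupt object up the recursion, instead of raising, mirrors the exception-free style of the C++ side, where interrupts and errors travel through `Expected` return values.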
## Key Changes
- Added C++ `StructuralVisitor`, `VisitInterrupt`, `WalkResult`, and
`StructuralWalk`.
- Added Python `tvm_ffi.structural_walk`.
- Traversal operates on `AnyView`, allowing callbacks to match objects,
containers, and POD values.
- C++ callbacks dispatch on their first argument and may optionally
accept a second `TVMFFIDefRegionKind` argument.
- Python callbacks are passed as `(type, callback)` entries, with
support for grouped types like `((int, float), callback)`, a single
callback catch-all, or a sequence of entries.
- Python def-region-aware callbacks are passed via
`with_def_region_kind=` and receive `(value, def_region_kind)`.
- Added runtime type-index matching for Python callbacks, including
subclass matching, `Any` catch-all, `Object` matching, and string/bytes
storage variants.
- Custom structural visit hooks now use `AnyView` for the visited value.
- Added visit error-context propagation for object-backed nodes when
traversal fails.
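The error-context bullet can be pictured with the following hypothetical Python model. `visit_error_context.h` is not shown in this email, so the sketch only assumes what `DispatchVisit` in the new header does: when visiting a value fails and that value is object-backed, the value is recorded on the error before it propagates further. Lists stand in for object-backed nodes and ints for POD leaves; `WalkError` and `walk_with_context` are made-up names.

```python
from typing import Any, Callable, List


class WalkError(Exception):
    """Hypothetical error that records the object-backed nodes it unwinds through."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.context: List[Any] = []  # innermost node first


def walk_with_context(value: Any, callback: Callable[[Any], None]) -> None:
    try:
        callback(value)
        if isinstance(value, list):  # lists model object-backed nodes
            for child in value:
                walk_with_context(child, callback)
    except WalkError as err:
        # Only object-backed values are recorded; POD leaves are not.
        if isinstance(value, list):
            err.context.append(value)
        raise


def reject_zero(v: Any) -> None:
    if v == 0:
        raise WalkError("zero is not allowed")


try:
    walk_with_context([1, [2, 0]], reject_zero)
except WalkError as err:
    print(err.context)  # [[2, 0], [1, [2, 0]]]
```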
## Examples
Python:
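```python
import tvm_ffi

visited = []

def on_int(value):
    visited.append(value)
    if value == 0:
        # Stop the whole walk; the interrupt carries `value` as its payload.
        return tvm_ffi.VisitInterrupt(value)
    return tvm_ffi.WalkResult.ADVANCE

# `root` is the FFI value being traversed.
result = tvm_ffi.structural_walk(root, (int, on_int), order="pre")
```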
Grouped callback types:
```python
tvm_ffi.structural_walk(
    root,
    [
        ((int, float), on_number),
        (MyNode, on_node),
        (object, on_any),
    ],
)
```
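The accepted callback forms can be thought of as being flattened into one ordered list of `(type, callback)` pairs and then tried first-match-wins. The hypothetical sketch below (`normalize_entries` and `dispatch` are illustrative names) models that rule with Python's `isinstance`; the real implementation matches FFI runtime type indices, so details such as whether `True` matches an `int` entry may differ.

```python
import typing
from typing import Any, Callable, List, Tuple

Entry = Tuple[type, Callable[[Any], Any]]


def _is_type(t: Any) -> bool:
    return t is typing.Any or isinstance(t, type)


def _is_type_spec(x: Any) -> bool:
    # A single type, or a non-empty tuple of types such as (int, float).
    if isinstance(x, tuple):
        return len(x) > 0 and all(_is_type(t) for t in x)
    return _is_type(x)


def normalize_entries(spec: Any) -> List[Entry]:
    """Flatten the supported callback forms into ordered (type, callback) pairs."""
    if callable(spec) and not isinstance(spec, type):
        return [(object, spec)]  # a bare callback is a catch-all
    if (isinstance(spec, tuple) and len(spec) == 2
            and _is_type_spec(spec[0]) and callable(spec[1])):
        spec = [spec]  # a single (type, callback) entry
    entries: List[Entry] = []
    for types, callback in spec:
        group = types if isinstance(types, tuple) else (types,)
        for t in group:
            # typing.Any behaves like `object`: it matches everything.
            entries.append((object if t is typing.Any else t, callback))
    return entries


def dispatch(value: Any, entries: List[Entry]) -> Any:
    """Run the first matching callback; with no match, return None (advance)."""
    for t, callback in entries:
        if isinstance(value, t):
            return callback(value)
    return None


hits = []
entries = normalize_entries([
    ((int, float), lambda v: hits.append(("number", v))),
    (object, lambda v: hits.append(("other", v))),
])
for v in [1, 2.5, "s"]:
    dispatch(v, entries)
print(hits)  # [('number', 1), ('number', 2.5), ('other', 's')]
```

Because matching stops at the first hit, registering `object` before `int` would shadow the `int` callback entirely, which is why broad entries belong at the end of the list.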
Def-region-aware callbacks:
```python
uses = []
tvm_ffi.structural_walk(
    func,
    with_def_region_kind=(
        Var,
        lambda var, kind: (
            uses.append(var) if kind == tvm_ffi.DefRegionKind.NONE else None
        ),
    ),
)
```
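The def-region kind is traversal state that the visitor sets while descending into binder fields and restores afterwards; the C++ `WithDefRegionKind` helper restores the previous value when its callback returns or throws. The sketch below models that with a Python context manager on a toy IR. `MiniVisitor`, `Var`, `Func`, and the enum values are hypothetical, and treating function parameters as a non-recursive definition region is an assumption made for illustration.

```python
import contextlib
import enum
from dataclasses import dataclass
from typing import Any, Callable, List


class DefRegionKind(enum.Enum):
    """Hypothetical stand-in for tvm_ffi.DefRegionKind."""
    NONE = 0
    NON_RECURSIVE = 1
    RECURSIVE = 2


@dataclass(frozen=True)
class Var:
    name: str


@dataclass
class Func:
    params: List[Var]  # binders, visited inside a definition region
    body: List[Any]    # uses, visited with DefRegionKind.NONE


class MiniVisitor:
    def __init__(self, callback: Callable[[Any, DefRegionKind], None]) -> None:
        self.callback = callback
        self.kind = DefRegionKind.NONE

    @contextlib.contextmanager
    def def_region(self, kind: DefRegionKind):
        old = self.kind
        self.kind = kind
        try:
            yield
        finally:
            self.kind = old  # restored even if the nested visit raises

    def visit(self, value: Any) -> None:
        self.callback(value, self.kind)
        if isinstance(value, Func):
            with self.def_region(DefRegionKind.NON_RECURSIVE):
                for param in value.params:
                    self.visit(param)
            for item in value.body:
                self.visit(item)
        elif isinstance(value, list):
            for item in value:
                self.visit(item)


x, y = Var("x"), Var("y")
uses = []


def collect_use(value: Any, kind: DefRegionKind) -> None:
    # Record variable uses only; binders arrive with a non-NONE kind.
    if isinstance(value, Var) and kind is DefRegionKind.NONE:
        uses.append(value.name)


MiniVisitor(collect_use).visit(Func(params=[x], body=[x, [y]]))
print(uses)  # ['x', 'y']
```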
C++:
```cpp
auto result = StructuralWalk<WalkOrder::kPreOrder>(
    root,
    [&](const Add& add) -> Expected<WalkResult> {
      ++num_adds;
      return WalkResult::Advance();
    },
    [&](const Mul& mul) -> Expected<WalkResult> {
      return WalkResult::Skip();
    });
```
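The C++ header stores `WalkResult` as `Variant<VisitInterrupt, int32_t>`, with tag `0` for `Advance()` and `1` for `Skip()`, and its `TypeTraits` decode an FFI `None` as `Advance()`, which is consistent with the Python rule that returning `None` means `ADVANCE`. The hypothetical Python sketch below mirrors that decoding rule; `decode_walk_result` and `Interrupt` are illustrative names, and raising `TypeError` stands in for the C++ `TryCastFromAnyView` returning `std::nullopt`.

```python
from dataclasses import dataclass
from typing import Any, Union

ADVANCE_TAG = 0  # mirrors WalkResult::kAdvanceTag
SKIP_TAG = 1     # mirrors WalkResult::kSkipTag


@dataclass
class Interrupt:
    """Hypothetical stand-in for VisitInterrupt."""
    value: Any = None


def decode_walk_result(raw: Any) -> Union[int, Interrupt]:
    """Decode a raw callback return value the way TypeTraits<WalkResult> does."""
    if raw is None:
        return ADVANCE_TAG  # FFI None decodes to WalkResult::Advance()
    if isinstance(raw, Interrupt):
        return raw  # interrupt alternative of the Variant
    if type(raw) is int:
        return raw  # int32_t tag alternative (0 = advance, 1 = skip)
    raise TypeError(f"cannot convert {raw!r} to WalkResult")


print(decode_walk_result(None))          # 0
print(decode_walk_result(SKIP_TAG))      # 1
print(decode_walk_result(Interrupt(7)))  # Interrupt(value=7)
```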
C++ def-region-aware walk:
```cpp
List<Var> uses;
auto result = StructuralWalk<WalkOrder::kPreOrder>(
    func,
    [&](const VarObj* var, TVMFFIDefRegionKind kind) -> Expected<WalkResult> {
      if (kind == kTVMFFIDefRegionKindNone) {
        uses.push_back(Var(GetObjectPtr<VarObj>(var)));
      }
      return WalkResult::Advance();
    });
```
## Testing
- Added C++ tests for structural visitor traversal, def-region hooks,
callback dispatch, POD leaves, object pointers, skip/interrupt behavior,
def-region-aware callbacks, and error propagation.
- Added Python tests for typed callbacks, grouped callback types, mixed
callback forms, def-region-aware callbacks, nested containers,
`Any`/`Object` behavior, post-order traversal, skip, and interrupts.
- Verified C++ structural visitor and visit error-context tests.
---
CMakeLists.txt | 1 +
docs/concepts/structural_eq_hash.rst | 188 ++++++-
docs/reference/python/index.rst | 16 +
include/tvm/ffi/c_api.h | 2 +
include/tvm/ffi/expected.h | 78 +++
include/tvm/ffi/extra/structural_visit.h | 771 ++++++++++++++++++++++++++++
include/tvm/ffi/extra/visit_error_context.h | 16 +
include/tvm/ffi/object.h | 45 ++
include/tvm/ffi/reflection/accessor.h | 21 +
python/tvm_ffi/__init__.py | 12 +
python/tvm_ffi/_ffi_api.py | 14 +-
python/tvm_ffi/dataclasses/py_class.py | 1 +
python/tvm_ffi/structural.py | 324 +++++++++++-
src/ffi/extra/structural_visit.cc | 175 +++++++
tests/cpp/extra/test_structural_visit.cc | 480 +++++++++++++++++
tests/cpp/test_expected.cc | 24 +
tests/cpp/test_reflection.cc | 1 +
tests/cpp/testing_object.h | 40 ++
tests/python/test_structural.py | 206 ++++++++
19 files changed, 2411 insertions(+), 4 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 018f8c3..bbf511c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,6 +75,7 @@ set(_tvm_ffi_objs_sources
set(_tvm_ffi_extra_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_visit.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/visit_error_context.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc"
diff --git a/docs/concepts/structural_eq_hash.rst b/docs/concepts/structural_eq_hash.rst
index 8dff838..cf6efa8 100644
--- a/docs/concepts/structural_eq_hash.rst
+++ b/docs/concepts/structural_eq_hash.rst
@@ -15,8 +15,8 @@
specific language governing permissions and limitations
under the License.
-Structural Equality and Hashing
-===============================
+Structural Equality, Hashing, and Walk
+======================================
TVM FFI provides ``structural_equal`` and ``structural_hash`` for the
object graph. These compare objects by **content** — recursively walking
@@ -955,3 +955,187 @@ And in Python:
assert structural_equal(f1, f2) # alpha-equivalent
assert structural_hash(f1) == structural_hash(f2) # same hash
+
+
+Structural Walk and Visit
+-------------------------
+
+``structural_equal`` and ``structural_hash`` are built on a structural traversal
+of the value graph. ``structural_walk`` exposes that traversal directly: it
+visits containers, object fields, and POD leaves, and invokes user callbacks for
+values whose runtime type matches a callback entry.
+
+It is useful when you want to collect information, validate a tree, find a node, or
+stop traversal early without writing a custom equality/hash hook.
+
+Basic Walk
+~~~~~~~~~~
+
+Pass callbacks as ordered ``(type, callback)`` entries. The first matching
+entry runs for each visited value. Normal Python callbacks receive one
+argument, ``value``.
+
+.. code-block:: python
+
+ import tvm_ffi
+
+ visited = []
+
+ def on_int(value):
+     visited.append(value)
+     if value == 0:
+         return tvm_ffi.VisitInterrupt(value)
+     return tvm_ffi.WalkResult.ADVANCE
+
+ result = tvm_ffi.structural_walk(root, (int, on_int), order="pre")
+
+ if result is not None:
+     print("stopped at", result.value)
+
+Callbacks may return:
+
+- ``WalkResult.ADVANCE`` to continue into children.
+- ``WalkResult.SKIP`` to skip the current value's children.
+- ``VisitInterrupt(payload)`` to stop the entire walk and return an interrupt
+ carrying ``payload``.
+- ``None`` as shorthand for ``WalkResult.ADVANCE``.
+
+Grouped Types
+~~~~~~~~~~~~~
+
+Several types can share one callback by passing a tuple of types:
+
+.. code-block:: python
+
+ numbers = []
+ strings = []
+
+ tvm_ffi.structural_walk(
+ root,
+ [
+ ((int, float), lambda value: numbers.append(value)),
+ (str, lambda value: strings.append(value)),
+ ],
+ )
+
+This is normalized as if the same callback had been registered separately for
+``int`` and ``float``. Callback entries are still tried in order, so broad
+callbacks should usually come after more specific ones.
+
+Catch-All and Object Callbacks
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``object`` and ``typing.Any`` are catch-all types: callbacks registered for them
+match both POD leaves and object-backed values.
+
+.. code-block:: python
+
+ from typing import Any
+
+ seen = []
+ tvm_ffi.structural_walk(root, (Any, lambda value: seen.append(value)))
+
+``tvm_ffi.Object`` is different: it matches only object-backed FFI values, such
+as ``Array``, ``Map``, ``Function``, ``String`` objects, or registered object
+classes. It does not match POD leaves such as ``int`` or ``float``.
+
+.. code-block:: python
+
+ objects = []
+ leaves = []
+
+ tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Object, lambda value: objects.append(value)),
+ (object, lambda value: leaves.append(value)),
+ ],
+ )
+
+Def-Region Aware Walk
+~~~~~~~~~~~~~~~~~~~~~
+
+Callbacks passed to ``with_def_region_kind`` receive a second argument that
+reports whether the current value is visited as a definition or a use. This is
+useful for analyses such as collecting variable uses while skipping binders:
+
+.. code-block:: python
+
+ uses = []
+
+ tvm_ffi.structural_walk(
+ func,
+ with_def_region_kind=(
+ Var,
+ lambda var, kind: (
+ uses.append(var) if kind == tvm_ffi.DefRegionKind.NONE else None
+ ),
+ ),
+ )
+
+For a function node, parameters are visited in a definition region, while
+occurrences in the body are visited with ``DefRegionKind.NONE``.
+
+Traversal Order
+~~~~~~~~~~~~~~~
+
+The default order is pre-order: callbacks run before visiting children.
+Post-order callbacks run after children.
+
+.. code-block:: python
+
+ trace = []
+
+ tvm_ffi.structural_walk(
+ tvm_ffi.Array([tvm_ffi.Array([1]), 2]),
+ [
+ (tvm_ffi.Array, lambda value: trace.append(f"array:{len(value)}")),
+ (int, lambda value: trace.append(f"int:{value}")),
+ ],
+ order="post",
+ )
+
+ assert trace == ["int:1", "array:1", "int:2", "array:2"]
+
+C++ Walk
+~~~~~~~~
+
+C++ code can use ``StructuralWalk`` with typed callbacks. Callbacks are tried
+in order and dispatch on the first argument type:
+
+.. code-block:: cpp
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root,
+ [&](const Add& add) -> Expected<WalkResult> {
+ ++num_adds;
+ return WalkResult::Advance();
+ },
+ [&](const Mul& mul) -> Expected<WalkResult> {
+ return WalkResult::Skip();
+ });
+
+C++ callbacks dispatch on their first argument, which may be ``AnyView``,
+``Any``, an ``ObjectRef`` subclass, an ``Object`` pointer type, or another
+FFI-convertible POD type. They may also take an optional second
+``TVMFFIDefRegionKind`` argument to distinguish definition sites from uses.
+Errors should be returned as ``Expected<WalkResult>``.
+
+Low-Level ``StructuralVisitor``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``StructuralVisitor`` is the lower-level traversal object. It is mainly useful
+inside structural visit hooks or C++ integrations that need to participate in
+the same recursive traversal protocol.
+
+Python users normally call ``structural_walk`` instead. The low-level visitor
+API exposes:
+
+- ``visitor.visit(value)`` to recursively visit a child value.
+- ``visitor.def_region_kind()`` to inspect the current definition-region mode.
+- ``visitor.with_def_region_kind(kind, callback)`` to run a recursive visit
+ under a temporary definition-region mode.
+
+Custom visit hooks are registered as the ``__s_visit__`` type attribute. They
+receive the active visitor and the current object, and are responsible for
+calling ``visitor.visit(child)`` on structural children.
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 13f9d5a..b5227f9 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -76,6 +76,22 @@ Containers
Map
+Structural
+----------
+.. autosummary::
+ :toctree: generated/
+
+ StructuralKey
+ StructuralVisitor
+ VisitInterrupt
+ WalkOrder
+ WalkResult
+ get_first_structural_mismatch
+ structural_equal
+ structural_hash
+ structural_walk
+
+
Global Registry
---------------
.. autosummary::
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 20fd4fe..8af33b7 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -177,6 +177,8 @@ typedef enum {
kTVMFFIList = 75,
/*! \brief Dict object. */
kTVMFFIDict = 76,
+ /*! \brief Structural visit interrupt object. */
+ kTVMFFIVisitInterrupt = 77,
//----------------------------------------------------------------
// more complex objects
//----------------------------------------------------------------
diff --git a/include/tvm/ffi/expected.h b/include/tvm/ffi/expected.h
index 6e822df..ee5f7d3 100644
--- a/include/tvm/ffi/expected.h
+++ b/include/tvm/ffi/expected.h
@@ -64,6 +64,12 @@ template <typename E>
Unexpected(E) -> Unexpected<E>;
#endif
+namespace details {
+
+struct ExpectedUnsafe;
+
+} // namespace details
+
/*!
* \brief Expected<T> provides exception-free error handling for FFI functions.
*
@@ -113,6 +119,9 @@ class Expected {
// NOLINTNEXTLINE(google-explicit-constructor,runtime/explicit)
Expected(Unexpected<E> unexpected) : data_(Any(std::move(unexpected).error())) {}
+ /*! \brief Return the raw stored type index. */
+ TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index(); }
+
/*! \brief Returns true if the Expected contains a success value. */
TVM_FFI_INLINE bool is_ok() const noexcept {
return data_.type_index() != TypeIndex::kTVMFFIError;
@@ -186,9 +195,78 @@ class Expected {
}
private:
+ Expected() = default;
+
+ friend struct details::ExpectedUnsafe;
+
Any data_; // Invariant: holds a T (type_index != kTVMFFIError) or an Error.
};
+namespace details {
+
+/*!
+ * \brief Unsafe raw-storage helpers for Expected.
+ *
+ * These helpers bypass normal value checking and are intended for ABI boundaries
+ * that already know the underlying Any storage holds either a valid T or Error.
+ */
+struct ExpectedUnsafe {
+ /*!
+ * \brief Move a raw TVMFFIAny into Expected storage.
+ * \tparam T The Expected success type.
+ * \param raw The raw FFI value to move.
+ * \return Expected backed by moved Any storage.
+ */
+ template <typename T>
+ TVM_FFI_INLINE static Expected<T> MoveFromTVMFFIAny(TVMFFIAny raw) {
+ Expected<T> result;
+ result.data_ = AnyUnsafe::MoveTVMFFIAnyToAny(&raw);
+ return result;
+ }
+
+ /*!
+ * \brief Move Expected storage to a raw TVMFFIAny.
+ * \tparam T The Expected success type.
+ * \param result The Expected value to move from.
+ * \return Raw FFI value containing moved underlying Any storage.
+ */
+ template <typename T>
+ TVM_FFI_INLINE static TVMFFIAny MoveToTVMFFIAny(Expected<T>&& result) {
+ return AnyUnsafe::MoveAnyToTVMFFIAny(std::move(result.data_));
+ }
+
+ /*!
+ * \brief Return the underlying Any storage.
+ * \tparam T The Expected success type.
+ * \param result The Expected value to inspect.
+ * \return Const reference to the raw Any storage.
+ */
+ template <typename T>
+ TVM_FFI_INLINE static const Any& GetData(const Expected<T>& result) noexcept {
+ return result.data_;
+ }
+
+ /*!
+ * \brief Read an Expected success value as a compatible raw storage type.
+ * \tparam T The type to read from the underlying Any storage.
+ * \tparam U The Expected success type.
+ * \param result The Expected value to read from.
+ * \return The stored value decoded as T.
+ *
+ * \note This assumes \p result stores T-compatible Any storage, or Error.
+ */
+ template <typename T, typename U>
+ TVM_FFI_INLINE static T ValueAs(const Expected<U>& result) {
+ const Any& data = result.data_;
+ if (TVM_FFI_PREDICT_TRUE(data.type_index() != TypeIndex::kTVMFFIError)) {
+ return AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data);
+ }
+ throw AnyUnsafe::CopyFromAnyViewAfterCheck<Error>(data);
+ }
+};
+
+} // namespace details
+
// TypeTraits specialization for Expected<T>
template <typename T>
inline constexpr bool use_default_type_traits_v<Expected<T>> = false;
diff --git a/include/tvm/ffi/extra/structural_visit.h b/include/tvm/ffi/extra/structural_visit.h
new file mode 100644
index 0000000..e1c0c56
--- /dev/null
+++ b/include/tvm/ffi/extra/structural_visit.h
@@ -0,0 +1,771 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/structural_visit.h
+ * \brief Structural visit API.
+ */
+#ifndef TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
+#define TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/cast.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/expected.h>
+#include <tvm/ffi/extra/visit_error_context.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/accessor.h>
+
+#include <cstddef>
+#include <exception>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Object node carrying the optional payload for an interrupted structural visit.
+ */
+class VisitInterruptObj : public Object {
+ public:
+ /*! \brief Payload returned with the interrupt, or FFI None for no payload. */
+ Any value;
+
+ VisitInterruptObj() = default;
+ /*!
+ * \brief Construct a VisitInterruptObj with a payload.
+ * \param value The payload carried by the interrupt.
+ */
+ explicit VisitInterruptObj(Any value) : value(std::move(value)) {}
+
+ /// \cond Doxygen_Suppress
+ static constexpr const int32_t _type_index = TypeIndex::kTVMFFIVisitInterrupt;
+ static const constexpr bool _type_final = true;
+ TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIVisitInterrupt, VisitInterruptObj,
+ Object);
+ /// \endcond
+};
+
+/*!
+ * \brief ObjectRef wrapper for VisitInterruptObj.
+ */
+class VisitInterrupt : public ObjectRef {
+ public:
+ /*! \brief Construct an interrupt with no payload. */
+ VisitInterrupt() : VisitInterrupt(Any(nullptr)) {}
+ /*!
+ * \brief Construct an interrupt with a user-defined payload.
+ * \param value The payload carried by the interrupt.
+ */
+ explicit VisitInterrupt(Any value)
+ : ObjectRef(make_object<VisitInterruptObj>(std::move(value))) {}
+
+ /// \cond Doxygen_Suppress
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(VisitInterrupt, ObjectRef, VisitInterruptObj);
+ /// \endcond
+};
+
+class StructuralVisitorObj;
+
+/*!
+ * \brief ABI of structural visit for ``kStructuralVisit`` type attribute and
+ * ``StructuralVisitorVTable`` function pointer signature.
+ *
+ * The callback receives the visitor and the value being visited as an
+ * ``AnyView``. It returns a raw ``TVMFFIAny`` storing
+ * ``Expected<Optional<VisitInterrupt>>``.
+ */
+using FStructuralVisit = TVMFFIAny (*)(StructuralVisitorObj* visitor, AnyView value) noexcept;
+
+namespace details {
+
+// Visit reflected structural fields of an object-backed value.
+TVM_FFI_INLINE static Expected<Optional<VisitInterrupt>> VisitReflectedFieldsExpected(
+ StructuralVisitorObj* visitor, const Object* obj) noexcept;
+
+} // namespace details
+
+/*!
+ * \brief VTable ABI for \ref StructuralVisitor dispatch. This function table provides a stable ABI
+ * for the visit method.
+ */
+struct StructuralVisitorVTable {
+ /*!
+ * \brief Visit callback.
+ * \param visitor The active structural visitor.
+ * \param value The value to visit.
+ * \return TVMFFIAny carrying Expected<Optional<VisitInterrupt>>.
+ *
+ * \note The raw ``visitor`` pointer and ``value`` view are non-owning. On
+ * failure, the returned ``TVMFFIAny`` stores ``Error``; on success, it stores
+ * either None or ``VisitInterrupt``.
+ */
+ FStructuralVisit visit = nullptr;
+};
+
+/*!
+ * \brief Object node of a structural visitor.
+ *
+ * A structural visitor is an active traversal context. It carries the dispatch
+ * table used to visit each object and the current def-region state used by
+ * structural equality/hash semantics. The visitor is ref-counted so it can
+ * cross FFI boundaries, but one underlying visitor object should not be shared
+ * by overlapping top-level traversals.
+ */
+class StructuralVisitorObj : public Object {
+ public:
+ /*! \brief Construct the default structural visitor. */
+ StructuralVisitorObj() : StructuralVisitorObj(VTable()) {}
+
+ /*!
+ * \brief Visit a value, dispatching through this visitor's vtable.
+ *
+ * \param value The value to visit.
+ * \return ``std::nullopt`` to continue traversal, or a \ref VisitInterrupt
+ * to halt the entire visit.
+ */
+ TVM_FFI_INLINE Optional<VisitInterrupt> Visit(AnyView value) {
+ return VisitExpected(value).value();
+ }
+
+ /*!
+ * \brief Visit a value, propagating error through expected return.
+ *
+ * \param value The value to visit.
+ * \return Expected interrupt state. An error means traversal failed.
+ */
+ TVM_FFI_INLINE Expected<Optional<VisitInterrupt>> VisitExpected(AnyView value) noexcept {
+ return details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<VisitInterrupt>>(
+ (*vtable_->visit)(this, value));
+ }
+
+ /*!
+ * \brief Visit using the structural visit behavior registered by kStructuralVisit for each type,
+ * or reflected structural fields when no custom behavior is registered.
+ */
+ TVM_FFI_INLINE Optional<VisitInterrupt> DefaultVisit(AnyView value) {
+ return DefaultVisitExpected(value).value();
+ }
+
+ /*!
+ * \brief Visit using the registered structural visit behavior by kStructuralVisit, propagating
+ * errors by Expected.
+ *
+ * \param value The value to visit.
+ * \return Expected interrupt state. An error means traversal failed.
+ */
+ TVM_FFI_INLINE Expected<Optional<VisitInterrupt>> DefaultVisitExpected(AnyView value) noexcept {
+ int32_t type_index = value.type_index();
+ static reflection::TypeAttrColumn column(reflection::type_attr::kStructuralVisit);
+ AnyView attr = column[type_index];
+
+ // case 1: Type-specific override registered as an opaque ABI visit function pointer.
+ if (attr.type_index() == TypeIndex::kTVMFFIOpaquePtr) {
+ auto* visit_fn = reinterpret_cast<FStructuralVisit>(attr.cast<void*>());
+ return details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<VisitInterrupt>>(
+ (*visit_fn)(this, value));
+ }
+
+ // case 2: Type-specific override registered as an ffi::Function.
+ if (attr.type_index() == TypeIndex::kTVMFFIFunction) {
+ return attr.cast<Function>().CallExpected<Optional<VisitInterrupt>>(this, value);
+ }
+
+ if (TVM_FFI_PREDICT_FALSE(attr.type_index() != TypeIndex::kTVMFFINone)) {
+ return Unexpected(Error("TypeError",
+ std::string(reflection::type_attr::kStructuralVisit) +
+ " must be an opaque function pointer or ffi.Function",
+ ""));
+ }
+
+ if (type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ return Optional<VisitInterrupt>(std::nullopt);
+ }
+
+ return details::VisitReflectedFieldsExpected(this, value.cast<const Object*>());
+ }
+
+ /*!
+ * \brief Return the current def-region context.
+ * \return The active def-region kind.
+ */
+ TVM_FFI_INLINE TVMFFIDefRegionKind def_region_kind() const { return def_region_mode_; }
+
+ /*!
+ * \brief Temporarily switch the def-region context while invoking \p callback.
+ *
+ * This helper scopes updates to the traversal state used by def/use-region
+ * aware visitors. The previous state is restored when the callback returns
+ * or throws.
+ *
+ * \param kind The def-region kind to set during the callback.
+ * \param callback A nullary callable that performs recursive visiting.
+ * \return The value returned by \p callback.
+ */
+ template <typename Callback>
+ TVM_FFI_INLINE auto WithDefRegionKind(TVMFFIDefRegionKind kind, Callback&& callback) {
+ class Scope {
+ public:
+ Scope(StructuralVisitorObj* visitor, TVMFFIDefRegionKind kind)
+ : visitor_(visitor), old_kind_(visitor->def_region_mode_) {
+ visitor_->def_region_mode_ = kind;
+ }
+ ~Scope() { visitor_->def_region_mode_ = old_kind_; }
+ Scope(const Scope&) = delete;
+ Scope& operator=(const Scope&) = delete;
+
+ private:
+ StructuralVisitorObj* visitor_;
+ TVMFFIDefRegionKind old_kind_;
+ };
+ Scope scope(this, kind);
+ return std::forward<Callback>(callback)();
+ }
+
+ /// \cond Doxygen_Suppress
+ static constexpr const bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO("ffi.StructuralVisitor", StructuralVisitorObj, Object);
+ /// \endcond
+
+ protected:
+ /*!
+ * \brief Construct a structural visitor subclass with a custom dispatch vtable.
+ *
+ * \param vtable The non-null dispatch table for this visitor.
+ *
+ * \note This constructor is for internal subclasses. The vtable and its
+ * ``visit`` callback must be valid for the lifetime of the visitor.
+ */
+ explicit StructuralVisitorObj(const StructuralVisitorVTable* vtable) : vtable_(vtable) {}
+
+ /*!
+ * \brief Required ABI dispatch table (see \ref StructuralVisitorVTable).
+ * It must never be null on a constructed visitor.
+ */
+ const StructuralVisitorVTable* vtable_ = nullptr;
+
+ /*!
+ * \brief Current def-region context for structural equality/hash semantics.
+ *
+ * This is shared mutable traversal state. Be careful when mutating it through
+ * multiple references to the same visitor object. Use \ref WithDefRegionKind
+ * to scope temporary changes.
+ */
+ TVMFFIDefRegionKind def_region_mode_ = kTVMFFIDefRegionKindNone;
+
+ private:
+ /*!
+ * \brief Return the vtable used by the default visitor.
+ * \return Pointer to the static structural visitor vtable.
+ */
+ static const StructuralVisitorVTable* VTable() {
+ static const StructuralVisitorVTable vtable{&StructuralVisitorObj::DispatchVisit};
+ return &vtable;
+ }
+
+ /*!
+ * \brief Dispatch from the vtable to the default visitor.
+ * \param visitor The structural visitor object.
+ * \param value The value to visit.
+ * \return Interrupt state, or an error if traversal failed.
+ */
+ static TVMFFIAny DispatchVisit(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ auto interrupt = visitor->DefaultVisitExpected(value);
+ if (TVM_FFI_PREDICT_FALSE(interrupt.type_index() == TypeIndex::kTVMFFIError)) {
+ if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ Error err = interrupt.error();
+ details::UpdateVisitErrorContext(err, value.cast<ObjectRef>());
+ }
+ }
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(interrupt));
+ }
+};
+
+/*!
+ * \brief ObjectRef wrapper of \ref StructuralVisitorObj.
+ *
+ * \sa StructuralVisitorObj
+ */
+class StructuralVisitor : public ObjectRef {
+ public:
+ /*!
+ * \brief Construct the default structural visitor.
+ */
+ StructuralVisitor() : ObjectRef(make_object<StructuralVisitorObj>()) {}
+ /*!
+ * \brief Construct from an existing object pointer.
+ * \param n The object pointer to wrap.
+ */
+ explicit StructuralVisitor(ObjectPtr<StructuralVisitorObj> n) : ObjectRef(std::move(n)) {}
+
+ /// \cond Doxygen_Suppress
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StructuralVisitor, ObjectRef, StructuralVisitorObj);
+ /// \endcond
+};
+
+namespace details {
+
+/*!
+ * \brief Return true when \p result already carries a traversal-stopping state.
+ * \tparam T The Expected success type.
+ * \param result The Expected value to inspect.
+ * \return Whether \p result stores an Error or VisitInterrupt.
+ */
+template <typename T>
+TVM_FFI_INLINE bool StructuralVisitNeedEarlyReturn(const Expected<T>& result) noexcept {
+ int32_t type_index = result.type_index();
+ return type_index == TypeIndex::kTVMFFIError || type_index == TypeIndex::kTVMFFIVisitInterrupt;
+}
+
+/*!
+ * \brief Walk reflected structural fields of object-backed \p obj.
+ *
+ * Fields marked with ``kTVMFFIFieldFlagBitMaskSEqHashIgnore`` are skipped.
+ * Def-region field flags are scoped around recursive child visits.
+ *
+ * \param visitor The active visitor.
+ * \param obj The object whose reflected fields should be visited.
+ * \return Expected interrupt state. An error means traversal failed.
+ */
+TVM_FFI_INLINE static Expected<Optional<VisitInterrupt>> VisitReflectedFieldsExpected(
+ StructuralVisitorObj* visitor, const Object* obj) noexcept {
+ int32_t type_index = obj->type_index();
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+
+ Expected<Optional<VisitInterrupt>> result = Optional<VisitInterrupt>(std::nullopt);
+ reflection::ForEachFieldInfoWithEarlyStop(
+ type_info, [&](const TVMFFIFieldInfo* field_info) -> bool {
+ if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) {
+ return false;
+ }
+
+ Any field_value;
+ const void* field_addr = reinterpret_cast<const char*>(obj) + field_info->offset;
+ int ret_code = field_info->getter(const_cast<void*>(field_addr),
+ reinterpret_cast<TVMFFIAny*>(&field_value));
+ if (TVM_FFI_PREDICT_FALSE(ret_code != 0)) {
+ result = Unexpected(details::MoveFromSafeCallRaised());
+ return true;
+ }
+
+ TVMFFIDefRegionKind kind = kTVMFFIDefRegionKindNone;
+ if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) {
+ kind = kTVMFFIDefRegionKindNonRecursive;
+ } else if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive) {
+ kind = kTVMFFIDefRegionKindRecursive;
+ }
+
+ if (kind != kTVMFFIDefRegionKindNone) {
+ result = visitor->WithDefRegionKind(
+ kind, [&]() { return visitor->VisitExpected(field_value); });
+ } else {
+ result = visitor->VisitExpected(field_value);
+ }
+ return StructuralVisitNeedEarlyReturn(result);
+ });
+ return result;
+}
+
+} // namespace details
+
+// ---------------------------------------------------------------------------
+// Structural Walk API.
+// ---------------------------------------------------------------------------
+
+/*!
+ * \brief Per-node control signal returned by structural walk callbacks.
+ *
+ * Walk control result with one of three actions:
+ * - ``WalkResult::Advance()``: continue traversal, including this node's children.
+ * - ``WalkResult::Skip()``: continue traversal but skip this node's children.
+ * - ``WalkResult::Interrupt()``: halt the entire walk, optionally carrying a payload.
+ */
+class WalkResult : public Variant<VisitInterrupt, int32_t> {
+ public:
+ /*! \brief Internal tag value carried by ``WalkResult::Advance()``. */
+ static constexpr int32_t kAdvanceTag = 0;
+ /*! \brief Internal tag value carried by ``WalkResult::Skip()``. */
+ static constexpr int32_t kSkipTag = 1;
+
+ /*! \brief The underlying ``Variant`` used as storage. */
+ using Storage = Variant<VisitInterrupt, int32_t>;
+
+ /*! \brief Continue traversal and visit this node's children. */
+ static WalkResult Advance() { return WalkResult(kAdvanceTag); }
+
+ /*! \brief Continue traversal but skip this node's children. */
+ static WalkResult Skip() { return WalkResult(kSkipTag); }
+
+ /*!
+ * \brief Halt the walk and propagate an interrupt.
+ * \param signal The interrupt to propagate. Defaults to an interrupt with
+ * FFI None payload.
+ */
+ static WalkResult Interrupt(VisitInterrupt signal = VisitInterrupt()) {
+ return WalkResult(Storage(std::move(signal)));
+ }
+
+ private:
+ // Keep raw storage construction behind the named factories.
+ explicit WalkResult(int32_t tag) : Storage(tag) {}
+ explicit WalkResult(Storage storage) : Storage(std::move(storage)) {}
+
+ friend struct TypeTraits<WalkResult>;
+};
+
+/// \cond Doxygen_Suppress
+template <>
+inline constexpr bool use_default_type_traits_v<WalkResult> = false;
+
+// Allow WalkResult to round-trip through Any / Expected while reusing Variant storage.
+template <>
+struct TypeTraits<WalkResult> : public TypeTraits<WalkResult::Storage> {
+ using Base = TypeTraits<WalkResult::Storage>;
+
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ return src->type_index == TypeIndex::kTVMFFINone || Base::CheckAnyStrict(src);
+ }
+ // Decode from borrowed Any storage after a strict type check.
+ TVM_FFI_INLINE static WalkResult CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFINone) {
+ return WalkResult::Advance();
+ }
+ return WalkResult(Base::CopyFromAnyViewAfterCheck(src));
+ }
+ // Decode by moving from owned Any storage after a strict type check.
+ TVM_FFI_INLINE static WalkResult MoveFromAnyAfterCheck(TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFINone) {
+ return WalkResult::Advance();
+ }
+ return WalkResult(Base::MoveFromAnyAfterCheck(src));
+ }
+ // Try all conversions supported by the underlying Variant storage.
+ TVM_FFI_INLINE static std::optional<WalkResult> TryCastFromAnyView(const TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFINone) {
+ return WalkResult::Advance();
+ }
+ if (auto opt = Base::TryCastFromAnyView(src)) {
+ return WalkResult(*std::move(opt));
+ }
+ return std::nullopt;
+ }
+ TVM_FFI_INLINE static std::string TypeStr() { return "WalkResult"; }
+};
+/// \endcond
+
+/*!
+ * \brief Callback order for \ref tvm::ffi::StructuralWalk.
+ */
+enum class WalkOrder : int32_t {
+ /*! \brief Invoke the callback before visiting children. */
+ kPreOrder = 0,
+ /*! \brief Invoke the callback after visiting children. */
+ kPostOrder = 1,
+};
+
+namespace details {
+
+/// \cond Doxygen_Suppress
+// Return from the current ABI visit function if Result stops traversal.
+// Result must evaluate to Expected whose raw storage can be moved to TVMFFIAny.
+#define TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(Result) \
+ do { \
+ auto&& tvm_ffi_res_ = (Result); \
+ if (TVM_FFI_PREDICT_FALSE( \
+ ::tvm::ffi::details::StructuralVisitNeedEarlyReturn(tvm_ffi_res_))) { \
+ return ::tvm::ffi::details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(tvm_ffi_res_)); \
+ } \
+ } while (0)
+
+// Return from the current ABI visit function if Result stops traversal.
+// If Result is an Error, append Node to the visit error context before returning.
+#define TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(Result, Node) \
+ do { \
+ auto&& tvm_ffi_res_ = (Result); \
+ if (TVM_FFI_PREDICT_FALSE( \
+ ::tvm::ffi::details::StructuralVisitNeedEarlyReturn(tvm_ffi_res_))) { \
+ if (TVM_FFI_PREDICT_FALSE(tvm_ffi_res_.type_index() == \
+ ::tvm::ffi::TypeIndex::kTVMFFIError)) { \
+ if ((Node).type_index() >= ::tvm::ffi::TypeIndex::kTVMFFIStaticObjectBegin) { \
+ ::tvm::ffi::Error tvm_ffi_visit_err_ = tvm_ffi_res_.error(); \
+ ::tvm::ffi::details::UpdateVisitErrorContext(tvm_ffi_visit_err_, \
+ (Node).cast<::tvm::ffi::ObjectRef>()); \
+ } \
+ } \
+ return ::tvm::ffi::details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(tvm_ffi_res_)); \
+ } \
+ } while (0)
+/// \endcond
+
+/*!
+ * \brief Visitor used by callback-dispatched ``StructuralWalk``.
+ *
+ * \tparam order Callback placement relative to child traversal.
+ * \tparam Dispatch Callable returning ``Expected<WalkResult>`` when invoked with ``AnyView`` and
+ * the active def-region kind. User callbacks wrapped by this dispatcher may
+ * accept either ``(value)`` or ``(value, def_region_kind)``.
+ */
+template <WalkOrder order, typename Dispatch>
+class StructuralWalkCallbackVisitorObj : public StructuralVisitorObj {
+ public:
+ /*!
+ * \brief Construct a structural walk visitor.
+ * \param dispatch The composed dispatcher invoked on each visited node.
+ */
+ explicit StructuralWalkCallbackVisitorObj(Dispatch dispatch)
+ : StructuralVisitorObj(VTable()), dispatch_(std::move(dispatch)) {}
+
+ private:
+ /*!
+ * \brief Return the vtable used by this visitor.
+ * \return Pointer to the static structural visitor vtable.
+ */
+ static const StructuralVisitorVTable* VTable() {
+ static const StructuralVisitorVTable vtable{&StructuralWalkCallbackVisitorObj::DispatchVisit};
+ return &vtable;
+ }
+
+ /*!
+ * \brief Dispatch from the erased visitor pointer to the concrete walk visitor.
+ * \param self The erased structural visitor object.
+ * \param value The value to visit.
+ * \return Interrupt state, or an error if traversal failed.
+ */
+ static TVMFFIAny DispatchVisit(StructuralVisitorObj* self, AnyView value) noexcept {
+ return static_cast<StructuralWalkCallbackVisitorObj*>(self)->VisitImpl(value);
+ }
+
+ /*!
+ * \brief Visit one value according to the configured walk order.
+ * \param value The value to visit.
+ * \return Interrupt state, or an error if traversal failed.
+ */
+ TVMFFIAny VisitImpl(AnyView value) noexcept {
+ if (TVM_FFI_PREDICT_FALSE(value.type_index() == TypeIndex::kTVMFFINone)) {
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(
+ Expected<Optional<VisitInterrupt>>(std::nullopt));
+ }
+ if constexpr (order == WalkOrder::kPreOrder) {
+ auto result = dispatch_(value, this->def_region_kind());
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(result, value);
+ int32_t type_index = result.type_index();
+ TVM_FFI_UNSAFE_ASSUME(type_index == TypeIndex::kTVMFFIInt);
+ if (TVM_FFI_PREDICT_FALSE(details::ExpectedUnsafe::ValueAs<int32_t>(result) ==
+ WalkResult::kSkipTag)) {
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(
+ Expected<Optional<VisitInterrupt>>(std::nullopt));
+ }
+ }
+
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(DefaultVisitExpected(value), value);
+
+ if constexpr (order == WalkOrder::kPostOrder) {
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(
+ dispatch_(value, this->def_region_kind()), value);
+ }
+
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(
+ Expected<Optional<VisitInterrupt>>(std::nullopt));
+ }
+
+ /*! \brief Composed dispatch closure invoked once per visited node. */
+ Dispatch dispatch_;
+};
+
+/*!
+ * \brief Compose typed callbacks into a single per-node dispatcher.
+ *
+ * Each callback dispatches on its first parameter's type; callbacks are tested
+ * in declaration order and the first match runs. Callbacks may take an optional
+ * second ``TVMFFIDefRegionKind`` argument. Nodes that match no callback fall
+ * through and traversal continues normally.
+ */
+struct StructuralWalkCallbackChain {
+ /*!
+ * \brief Build a dispatcher closure over a chain of typed callbacks.
+ * \tparam Callbacks Callable types whose first parameter selects the dispatched
+ * value type.
+ * \param callbacks Callbacks to be tested in order.
+ * \return A dispatcher closure of type ``Expected<WalkResult>(AnyView,
+ * TVMFFIDefRegionKind)``. Each user callback may take either
+ * ``(value)`` or ``(value, def_region_kind)``.
+ */
+ template <typename... Callbacks>
+ static auto FromChain(Callbacks... callbacks) {
+ return [=](AnyView x, TVMFFIDefRegionKind kind) mutable -> Expected<WalkResult> {
+ try {
+ Optional<Expected<WalkResult>> result;
+ // Fold expression: each TryCallLink returns empty Optional on no-match
+ // (falsy) or a result on match (truthy); || short-circuits on first match.
+ (... || (result = TryCallLink(callbacks, x, kind)));
+ if (result.has_value()) {
+ return std::move(result).value();
+ }
+ return WalkResult::Advance();
+ } catch (const Error& err) {
+ return Unexpected(err);
+ }
+ };
+ }
+
+ private:
+ /*!
+ * \brief Invoke ``callback`` when ``x`` matches its first parameter type.
+ * \tparam Callback Callable whose first parameter selects the value type and
+ * whose optional second parameter receives the active def-region kind.
+ * \param callback The callback under test.
+ * \param x The value to dispatch on.
+ * \param kind The active def-region kind.
+ * \return The callback result if it matched, empty ``Optional`` otherwise.
+ */
+ template <typename Callback>
+ static Optional<Expected<WalkResult>> TryCallLink(Callback& callback, AnyView x,
+ TVMFFIDefRegionKind kind) {
+ using FuncInfo = FunctionInfo<std::decay_t<Callback>>;
+ static_assert(FuncInfo::num_args == 1 || FuncInfo::num_args == 2,
+ "StructuralWalk callbacks must take one argument (value) or two arguments "
+ "(value, def-region kind)");
+ using FirstArg = std::tuple_element_t<0, typename FuncInfo::ArgType>;
+ using TSub = std::remove_cv_t<std::remove_reference_t<FirstArg>>;
+ if constexpr (std::is_same_v<TSub, AnyView>) {
+ // callback on AnyView
+ return InvokeCallback(callback, x, kind);
+ } else if constexpr (std::is_same_v<TSub, Any>) {
+ // callback on Any
+ return InvokeCallback(callback, Any(x), kind);
+ } else {
+ if (auto opt = x.template as<TSub>()) {
+ return InvokeCallback(callback, *std::move(opt), kind);
+ }
+ }
+ return std::nullopt;
+ }
+
+ /*!
+ * \brief Invoke a matched callback with optional def-region context.
+ * \tparam Callback Callable returning ``Expected<WalkResult>``.
+ * \tparam Value Type of the converted value passed to the callback.
+ * \param callback The matched callback to invoke.
+ * \param value The converted value.
+ * \param kind The active def-region kind.
+ * \return The callback result.
+ */
+ template <typename Callback, typename Value>
+ static Expected<WalkResult> InvokeCallback(Callback& callback, Value&& value,
+ TVMFFIDefRegionKind kind) {
+ using FuncInfo = FunctionInfo<std::decay_t<Callback>>;
+ if constexpr (FuncInfo::num_args == 1) {
+ return callback(std::forward<Value>(value));
+ } else {
+ return callback(std::forward<Value>(value), kind);
+ }
+ }
+};
+
+} // namespace details
+
+/*!
+ * \brief Walk a structured value graph and invoke typed callbacks on selected values.
+ *
+ * The callbacks are invoked only for values matching the first argument type of
+ * one of the callbacks. The first callback argument may be ``AnyView``, ``Any``,
+ * an object reference type, an object pointer type, or another FFI-convertible
+ * POD type. A callback may also optionally take a second ``TVMFFIDefRegionKind`` argument
+ * to inspect whether the value is being visited in a definition region.
+ * Callbacks are tested in order, and the first match is used. ``None`` values
+ * are never passed to callbacks.
+ *
+ * Each callback should return ``Expected<WalkResult>``; see ``WalkResult``.
+ * - ``WalkResult::Interrupt(...)`` halts traversal.
+ * - ``WalkResult::Advance()`` continues traversal.
+ * - ``WalkResult::Skip()`` skips the node's children. It only has an effect with
+ * ``WalkOrder::kPreOrder``; in post-order the children have already been visited.
+ * - ``Error`` indicates traversal failure.
+ *
+ * \sa WalkOrder, WalkResult
+ *
+ * Example:
+ *
+ * \code
+ * int num_adds = 0;
+ *
+ * Expected<Optional<VisitInterrupt>> result = StructuralWalkExpected<WalkOrder::kPreOrder>(
+ * root,
+ * [&](const Add& add) -> Expected<WalkResult> {
+ * ++num_adds;
+ * return WalkResult::Advance();
+ * },
+ * [&](const Mul& mul) -> Expected<WalkResult> {
+ * return WalkResult::Skip();
+ * });
+ * \endcode
+ *
+ * \tparam order Whether to invoke the callback before or after visiting children.
+ * \tparam Callbacks Callback types.
+ * \param root The root value to visit.
+ * \param callbacks Callbacks invoked for matching nodes. Each callback may take
+ * either ``(value)`` or ``(value, def_region_kind)`` and should return
+ * ``Expected<WalkResult>``.
+ * \return ``std::nullopt`` if traversal completed, the interrupt returned by a
+ * callback, or an ``Error`` if traversal or a callback failed.
+ *
+ * \note Return type of each callback should be ``Expected<WalkResult>``.
+ */
+template <WalkOrder order, typename... Callbacks>
+Expected<Optional<VisitInterrupt>> StructuralWalkExpected(AnyView root,
+ Callbacks&&... callbacks) noexcept {
+ static_assert(sizeof...(Callbacks) != 0, "StructuralWalk requires at least one callback");
+ auto dispatch =
+ details::StructuralWalkCallbackChain::FromChain(std::forward<Callbacks>(callbacks)...);
+ using Visitor = details::StructuralWalkCallbackVisitorObj<order, decltype(dispatch)>;
+ StructuralVisitor visitor(make_object<Visitor>(std::move(dispatch)));
+ return visitor->VisitExpected(root);
+}
+
+/*!
+ * \brief Throwing wrapper over \ref tvm::ffi::StructuralWalkExpected.
+ *
+ * See \ref tvm::ffi::StructuralWalkExpected for callback semantics and traversal behavior.
+ *
+ * \tparam order Whether to invoke the callback before or after visiting children.
+ * \tparam Callbacks Callback types.
+ * \param root The root value to visit.
+ * \param callbacks Callbacks invoked for matching nodes. Each callback may take
+ * either ``(value)`` or ``(value, def_region_kind)`` and should return
+ * ``Expected<WalkResult>``.
+ * \return ``std::nullopt`` if traversal completed, or the interrupt returned by
+ * a callback.
+ * \throws Error if traversal or a callback returned an error.
+ *
+ * \note Return type of each callback should be ``Expected<WalkResult>``.
+ */
+template <WalkOrder order, typename... Callbacks>
+Optional<VisitInterrupt> StructuralWalk(AnyView root, Callbacks&&... callbacks) {
+ return StructuralWalkExpected<order>(root, std::forward<Callbacks>(callbacks)...).value();
+}
+
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
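The control flow of ``VisitImpl`` above can be restated as a small pure-Python model, which may help when reasoning about what ``Skip`` and interrupts do in each order. This is a hypothetical sketch, not part of ``tvm_ffi``: nested lists stand in for FFI containers, and ``toy_walk``, ``Interrupt``, ``ADVANCE`` and ``SKIP`` are illustrative names (the two integers mirror ``kAdvanceTag`` and ``kSkipTag``). It keeps three rules visible in the C++: ``None`` is never dispatched, ``Skip`` only prunes children in pre-order, and an interrupt ends the whole walk.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

# Hypothetical stand-ins for WalkResult::Advance()/Skip(); the values mirror
# kAdvanceTag/kSkipTag in the header above.
ADVANCE = 0
SKIP = 1


@dataclass
class Interrupt:
    """Toy stand-in for VisitInterrupt: carries an optional payload."""

    value: Any = None


def toy_walk(
    value: Any, callback: Callable[[Any], Any], order: str = "pre"
) -> Optional[Interrupt]:
    """Model of the pre/post-order walk; Python lists play the role of containers."""
    if value is None:
        return None  # None values are never passed to the callback
    if order == "pre":
        res = callback(value)
        if isinstance(res, Interrupt):
            return res  # halt the entire walk
        if res == SKIP:
            return None  # keep walking siblings, but do not descend here
    if isinstance(value, list):
        for child in value:
            res = toy_walk(child, callback, order)
            if res is not None:
                return res  # propagate the interrupt upwards
    if order == "post":
        res = callback(value)
        if isinstance(res, Interrupt):
            return res
        # SKIP is ignored here: the children were already visited.
    return None
```

Running it with ``order="post"`` shows that returning ``SKIP`` changes nothing, because the children have already been visited by the time the callback fires.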
diff --git a/include/tvm/ffi/extra/visit_error_context.h b/include/tvm/ffi/extra/visit_error_context.h
index fd9310c..e8b5c48 100644
--- a/include/tvm/ffi/extra/visit_error_context.h
+++ b/include/tvm/ffi/extra/visit_error_context.h
@@ -157,6 +157,22 @@ class VisitErrorContext : public ObjectRef {
throw; \
}
+/*!
+ * \brief End a visit try block and catch any Error as an Expected error,
+ * appending node to the VisitErrorContext on the way up.
+ *
+ * Must be paired with TVM_FFI_VISIT_BEGIN() above the visit body.
+ *
+ * \param node The current visit value. Object-backed values are appended to
the
+ * error context's reverse_visit_pattern on exception.
+ */
+#define TVM_FFI_VISIT_END_RETURN_EXPECTED(node)
\
+ }
\
+ catch (::tvm::ffi::Error & _tvm_ffi_visit_err_) {
\
+ ::tvm::ffi::details::UpdateVisitErrorContext(_tvm_ffi_visit_err_, (node));
\
+ return ::tvm::ffi::Unexpected(_tvm_ffi_visit_err_);
\
+ }
+
/*!
* \brief Throw an error from inside a visit, with `node` recorded
* as the innermost frame of the resulting VisitErrorContext.
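Taken together, ``TVM_FFI_VISIT_END_RETURN_EXPECTED`` here and ``TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT`` in the header above turn a failure deep in the graph into an error value whose context lists every object-backed node on the way back up, innermost first. Below is a minimal Python sketch of that accumulation, assuming lists play the role of object-backed nodes and plain ints the role of POD leaves. ``VisitError``, ``visit`` and the ``reverse_visit_pattern`` attribute are hypothetical stand-ins, not the real ``VisitErrorContext`` API.

```python
from typing import Any, Callable, List, Optional


class VisitError:
    """Toy error value; reverse_visit_pattern collects nodes innermost-first."""

    def __init__(self, message: str) -> None:
        self.message = message
        self.reverse_visit_pattern: List[Any] = []


def is_object_backed(node: Any) -> bool:
    # Assumption for this sketch: lists stand in for object-backed nodes;
    # everything else is a POD leaf and is never recorded.
    return isinstance(node, list)


def visit(node: Any, check: Callable[[Any], None]) -> Optional[VisitError]:
    """Return None on success, or an error value instead of raising."""
    try:
        check(node)  # may raise, like a failing user callback
    except ValueError as exc:
        err = VisitError(str(exc))
        if is_object_backed(node):
            err.reverse_visit_pattern.append(node)
        return err
    if isinstance(node, list):
        for child in node:
            err = visit(child, check)
            if err is not None:
                # Containers are object-backed in this sketch, so each
                # ancestor appends itself as the error travels upwards.
                err.reverse_visit_pattern.append(node)
                return err
    return None
```

Because every frame appends after its child has failed, the first entry is the innermost node. This is why the context is a *reverse* visit pattern.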
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 9b22636..262c570 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -118,6 +118,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIModule = "ffi.Module";
/*! \brief The type key for Dict */
static constexpr const char* kTVMFFIDict = "ffi.Dict";
+ /*! \brief The type key for VisitInterrupt */
+ static constexpr const char* kTVMFFIVisitInterrupt = "ffi.VisitInterrupt";
/*! \brief The type key for OpaquePyObject */
static constexpr const char* kTVMFFIOpaquePyObject = "ffi.OpaquePyObject";
};
@@ -1085,6 +1087,49 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
}
}
+/*!
+ * \brief Return whether a runtime type index matches a target type index.
+ *
+ * \param actual_type_index Runtime type index of the value being checked.
+ * \param target_type_index Target type index to match against.
+ * \return Whether \p actual_type_index matches \p target_type_index: by exact equality,
+ * the ``Any`` catch-all, a small str/bytes variant, or object subclassing.
+ */
+TVM_FFI_INLINE bool RuntimeTypeIndexMatch(int32_t actual_type_index, int32_t target_type_index) {
+ if (actual_type_index == target_type_index) {
+ return true;
+ }
+ // Any target matches all runtime values.
+ if (target_type_index == TypeIndex::kTVMFFIAny) {
+ return true;
+ }
+ // str/bytes targets also match their small inline variants.
+ if (target_type_index == TypeIndex::kTVMFFIStr) {
+ return actual_type_index == TypeIndex::kTVMFFISmallStr;
+ }
+ if (target_type_index == TypeIndex::kTVMFFIBytes) {
+ return actual_type_index == TypeIndex::kTVMFFISmallBytes;
+ }
+ // Every object-backed type is a subclass of Object.
+ if (target_type_index == TypeIndex::kTVMFFIObject) {
+ return actual_type_index >= TypeIndex::kTVMFFIStaticObjectBegin;
+ }
+ // Non-object type indices can only match through exact equality handled above.
+ if (actual_type_index < TypeIndex::kTVMFFIStaticObjectBegin ||
+ target_type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ return false;
+ }
+ // Invariant: a parent type index is always smaller than its children's.
+ if (actual_type_index < target_type_index) {
+ return false;
+ }
+ // Fall back to runtime ancestry metadata.
+ const TypeInfo* actual_type_info = TVMFFIGetTypeInfo(actual_type_index);
+ const TypeInfo* target_type_info = TVMFFIGetTypeInfo(target_type_index);
+ return actual_type_info->type_depth > target_type_info->type_depth &&
+ actual_type_info->type_ancestors[target_type_info->type_depth]->type_index ==
+ target_type_index;
+}
+
/*!
* \brief Namespace to internally manipulate object class.
* \note These functions are only supposed to be used by internal
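The matching rules in ``RuntimeTypeIndexMatch`` can be restated as a short Python function. This may help when reasoning about which callback entry a value will hit. It is a sketch under stated assumptions. The numeric type indices are placeholders: only ``ANY == -1`` reflects a convention stated in the Python binding in this change, and for the rest only their ordering relative to ``STATIC_OBJECT_BEGIN`` matters. The hypothetical ``registry`` dict replaces ``TVMFFIGetTypeInfo``.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

# Illustrative type indices (placeholders, not authoritative values).
ANY = -1
INT = 1
SMALL_STR = 11
SMALL_BYTES = 12
STATIC_OBJECT_BEGIN = 64
OBJECT = 64
STR = 65
BYTES = 66


@dataclass(frozen=True)
class TypeInfo:
    type_depth: int
    type_ancestors: Tuple[int, ...]  # ancestor type index at each depth


def runtime_type_index_match(actual: int, target: int, registry: Dict[int, TypeInfo]) -> bool:
    if actual == target or target == ANY:
        return True  # exact match, or the Any catch-all
    if target == STR:
        return actual == SMALL_STR  # str also matches its small inline variant
    if target == BYTES:
        return actual == SMALL_BYTES
    if target == OBJECT:
        return actual >= STATIC_OBJECT_BEGIN  # every object type derives from Object
    if actual < STATIC_OBJECT_BEGIN or target < STATIC_OBJECT_BEGIN:
        return False  # POD indices only match exactly
    if actual < target:
        return False  # parent indices are always smaller than child indices
    a, t = registry[actual], registry[target]
    # The target is an ancestor iff it sits at its own depth in actual's chain.
    return a.type_depth > t.type_depth and a.type_ancestors[t.type_depth] == target
```

The index-ordering check is a cheap early exit. The ancestor lookup at ``t.type_depth`` is a single array access, so subclass matching does not need to walk the whole hierarchy.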
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index ee43379..80de493 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -466,6 +466,27 @@ inline constexpr const char* kSHash = "__s_hash__";
* mapping; ``field_name`` is used for mismatch path reporting.
*/
inline constexpr const char* kSEqual = "__s_equal__";
+/*!
+ * \brief Custom structural visitor hook (used by ``StructuralVisitor``).
+ *
+ * The hook receives the active visitor and the current object value. Opaque
+ * pointer hooks return a ``TVMFFIAny`` that stores an
+ * ``Expected<Optional<VisitInterrupt>>`` result.
+ *
+ * Value type: either an opaque function pointer to a C++ structural visit hook
+ *
+ * ``TVMFFIAny (*)(StructuralVisitorObj* visitor, AnyView value) noexcept``
+ *
+ * or an ``ffi::Function`` with signature
+ *
+ * ``(StructuralVisitor visitor, Any value) -> Optional<VisitInterrupt>``.
+ *
+ * On success, the hook returns ``None`` for no interrupt, or a
+ * ``VisitInterrupt`` to halt traversal. On failure, it returns an ``Error``.
+ * The ``visitor`` parameter is the active traversal context used to recursively
+ * visit structural children.
+ */
+inline constexpr const char* kStructuralVisit = "__s_visit__";
/*!
* \brief Serialize object data to a JSON-compatible ``Map``.
*
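The ``kStructuralVisit`` hook contract has three parts: recurse through the active visitor, return ``None`` to continue, or return an interrupt to stop. It can be illustrated with a self-contained Python toy. Everything here is hypothetical. ``ToyVisitor`` only imitates the ``visit``/``def_region_kind``/``with_def_region_kind`` methods that the Python ``StructuralVisitor`` in this change exposes. ``Let`` is an invented node type. Marking the bound variable with ``DEF_NON_RECURSIVE`` is an arbitrary choice for illustration, not a documented rule.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, List, Optional, Tuple


class DefRegionKind(IntEnum):
    # Values copied from the Python DefRegionKind enum in this change.
    NONE = 0
    DEF_RECURSIVE = 1
    DEF_NON_RECURSIVE = 2


@dataclass
class Let:
    """Hypothetical IR node: binds ``var`` to ``value`` inside ``body``."""

    var: str
    value: Any
    body: Any


class ToyVisitor:
    """Imitates the visitor surface a hook sees; records (leaf, kind) pairs."""

    def __init__(self) -> None:
        self._kind = DefRegionKind.NONE
        self.log: List[Tuple[Any, DefRegionKind]] = []

    def def_region_kind(self) -> DefRegionKind:
        return self._kind

    def with_def_region_kind(self, kind: DefRegionKind, callback: Callable[[], Any]) -> Any:
        saved, self._kind = self._kind, kind
        try:
            return callback()
        finally:
            self._kind = saved  # restore the enclosing region on exit

    def visit(self, value: Any) -> Optional[Any]:
        if isinstance(value, Let):
            return let_s_visit(self, value)  # dispatch to the custom hook
        self.log.append((value, self._kind))
        return None


def let_s_visit(visitor: ToyVisitor, node: Let) -> Optional[Any]:
    """Custom hook: visit the binding in a def region, then value and body."""
    interrupt = visitor.with_def_region_kind(
        DefRegionKind.DEF_NON_RECURSIVE, lambda: visitor.visit(node.var)
    )
    if interrupt is not None:
        return interrupt  # propagate an early exit unchanged
    for child in (node.value, node.body):
        interrupt = visitor.visit(child)
        if interrupt is not None:
            return interrupt
    return None
```

The hook never walks children itself. It always goes back through the visitor, so dispatch and def-region state stay consistent across nested nodes.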
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 012de31..9676c71 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -76,10 +76,16 @@ if TYPE_CHECKING or not _is_config_mode():
from .module import Module, system_lib, load_module
from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream
from .structural import (
+ DefRegionKind,
StructuralKey,
+ StructuralVisitor,
+ VisitInterrupt,
+ WalkOrder,
+ WalkResult,
get_first_structural_mismatch,
structural_equal,
structural_hash,
+ structural_walk,
)
from . import serialization
from . import access_path
@@ -130,6 +136,7 @@ __all__ = [
"LIB",
"Array",
"DLDeviceType",
+ "DefRegionKind",
"Device",
"Dict",
"Function",
@@ -141,7 +148,11 @@ __all__ = [
"Shape",
"StreamContext",
"StructuralKey",
+ "StructuralVisitor",
"Tensor",
+ "VisitInterrupt",
+ "WalkOrder",
+ "WalkResult",
"__version__",
"__version_tuple__",
"access_path",
@@ -167,6 +178,7 @@ __all__ = [
"structural",
"structural_equal",
"structural_hash",
+ "structural_walk",
"system_lib",
"use_raw_stream",
"use_torch_stream",
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index a0e176d..eaddf0f 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence
from ctypes import c_void_p
- from tvm_ffi import Device, Module, Object, StructuralKey as _StructuralKey
+ from tvm_ffi import Device, Module, Object, StructuralKey as _StructuralKey, StructuralVisitor as _StructuralVisitor, VisitInterrupt as _VisitInterrupt
from tvm_ffi.access_path import AccessPath
from typing import Any, Callable
# isort: on
@@ -109,8 +109,14 @@ if TYPE_CHECKING:
def String(_0: str, /) -> str: ...
def StructuralEqual(_0: Any, _1: Any, _2: bool, _3: bool, /) -> bool: ...
def StructuralHash(_0: Any, _1: bool, _2: bool, /) -> int: ...
+ def VisitInterrupt(_0: Any, /) -> _VisitInterrupt: ...
def StructuralKey(_0: Any, /) -> _StructuralKey: ...
def StructuralKeyEqual(_0: Any, _1: Any, /) -> bool: ...
+ def StructuralWalk(_0: Any, _1: Sequence[tuple[int, Callable[[Any], Any]]], _2: Sequence[tuple[int, Callable[[Any, int], Any]]], _3: int, /) -> _VisitInterrupt | None: ...
+ def StructuralVisitor() -> _StructuralVisitor: ...
+ def StructuralVisitorDefRegionKind(_0: _StructuralVisitor, /) -> int: ...
+ def StructuralVisitorWithDefRegionKind(_0: _StructuralVisitor, _1: int, _2: Callable[[], Any], /) -> Any: ...
+ def StructuralVisitorVisit(_0: _StructuralVisitor, _1: Any, /) -> _VisitInterrupt | None: ...
def SystemLib(*args: Any) -> Any: ...
def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...
def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
@@ -197,9 +203,15 @@ __all__ = [
"StructuralHash",
"StructuralKey",
"StructuralKeyEqual",
+ "StructuralVisitor",
+ "StructuralVisitorDefRegionKind",
+ "StructuralVisitorVisit",
+ "StructuralVisitorWithDefRegionKind",
+ "StructuralWalk",
"SystemLib",
"ToJSONGraph",
"ToJSONGraphString",
+ "VisitInterrupt",
"_PyClassRegisterTypeAttrColumns",
"_RegisterFFIInit",
# tvm-ffi-stubgen(end)
diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index b9cacd3..bb2c8a1 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -607,6 +607,7 @@ _FFI_TYPE_ATTR_NAMES: frozenset[str] = frozenset(
"__any_equal__",
"__s_equal__",
"__s_hash__",
+ "__s_visit__",
"__data_to_json__",
"__data_from_json__",
}
diff --git a/python/tvm_ffi/structural.py b/python/tvm_ffi/structural.py
index 54434d8..975a934 100644
--- a/python/tvm_ffi/structural.py
+++ b/python/tvm_ffi/structural.py
@@ -19,23 +19,82 @@
from __future__ import annotations
+from collections.abc import Callable, Sequence
+from enum import IntEnum
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .access_path import AccessPath
-from . import _ffi_api
+from . import _ffi_api, core
from .core import Object
from .registry import register_object
__all__ = [
+ "DefRegionKind",
"StructuralKey",
+ "StructuralVisitor",
+ "VisitInterrupt",
+ "WalkOrder",
+ "WalkResult",
"get_first_structural_mismatch",
"structural_equal",
"structural_hash",
+ "structural_walk",
]
+class WalkOrder(IntEnum):
+ """Callback placement (before or after visiting children) for structural walks.
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.structural_walk`
+ Walk an object graph, invoke matching callbacks.
+
+ """
+
+ PREORDER = 0
+ POSTORDER = 1
+
+
+class WalkResult(IntEnum):
+ """Control signal for structural walks.
+
+ ``ADVANCE`` continues into the current node's children; ``SKIP`` continues
+ traversal but skips the current node's children (pre-order only).
+
+ Use :class:`VisitInterrupt` when traversal should stop entirely.
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.structural_walk`
+ Walk an object graph, invoke matching callbacks.
+
+ """
+
+ ADVANCE = 0
+ SKIP = 1
+
+
+class DefRegionKind(IntEnum):
+ """Def-region state active during structural visiting.
+
+ The values mirror ``TVMFFIDefRegionKind`` in the C ABI.
+
+ See Also
+ --------
+ :py:class:`tvm_ffi.StructuralVisitor`
+ Structural traversal visitor that carries object dispatch and def-region
+ state across recursive visits.
+
+ """
+
+ NONE = 0
+ DEF_RECURSIVE = 1
+ DEF_NON_RECURSIVE = 2
+
+
def structural_equal(
lhs: Any, rhs: Any, map_free_vars: bool = False, skip_tensor_content: bool = False
) -> bool:
@@ -231,3 +290,266 @@ class StructuralKey(Object):
def __eq__(self, other: Any) -> bool:
"""Compare by structural equality."""
return isinstance(other, StructuralKey) and _ffi_api.StructuralKeyEqual(self, other)
+
+
+@register_object("ffi.VisitInterrupt")
+class VisitInterrupt(Object):
+ """Payload-carrying signal that stops a structural visit.
+
+ This object can be returned from structural walk callbacks and structural
+ visit hooks to halt traversal early. The optional payload is preserved in
+ :py:attr:`value` and returned to the caller.
+
+ Examples
+ --------
+ Use ``VisitInterrupt`` to stop a structural walk when a target node is found:
+
+ .. code-block:: python
+
+ import tvm_ffi
+
+
+ def on_node(node):
+ if is_target(node):
+ return tvm_ffi.VisitInterrupt(node)
+ return None
+
+
+ result = tvm_ffi.structural_walk(root, (object, on_node))
+ if result is not None:
+ found = result.value
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.structural_walk`
+ Structural walk API whose callbacks may return ``VisitInterrupt``.
+
+ """
+
+ # tvm-ffi-stubgen(begin): object/ffi.VisitInterrupt
+ # fmt: off
+ value: Any
+ if TYPE_CHECKING:
+ def __init__(self, value: Any = ...) -> None: ...
+ def __ffi_init__(self, _0: Any, /) -> None: ... # ty: ignore[invalid-method-override]
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+ def __init__(self, value: Any = None) -> None:
+ """Create an interrupt with an optional payload.
+
+ Parameters
+ ----------
+ value
+ Payload returned to the caller when traversal stops.
+
+ """
+ self.__init_handle_by_constructor__(_ffi_api.VisitInterrupt, value)
+
+
+@register_object("ffi.StructuralVisitor")
+class StructuralVisitor(Object):
+ """Low-level structural traversal visitor.
+
+ This class exposes the low-level visitor object used by structural
+ traversal hooks.
+ """
+
+ # tvm-ffi-stubgen(begin): object/ffi.StructuralVisitor
+ # fmt: off
+ if TYPE_CHECKING:
+ def __init__(self) -> None: ...
+ def __ffi_init__(self) -> None: ... # ty: ignore[invalid-method-override]
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+ def __init__(self) -> None:
+ """Create a default structural visitor."""
+ self.__init_handle_by_constructor__(_ffi_api.StructuralVisitor)
+
+ def visit(self, value: Any) -> VisitInterrupt | None:
+ """Low-level API to visit ``value`` using this visitor's dispatch behavior.
+
+ Parameters
+ ----------
+ value
+ Value to visit.
+
+ Returns
+ -------
+ result
+ ``None`` if traversal should continue, otherwise a
+ :class:`VisitInterrupt` carrying the early-exit payload.
+
+ """
+ return _ffi_api.StructuralVisitorVisit(self, value)
+
+ def def_region_kind(self) -> DefRegionKind:
+ """Low-level API to return the currently active structural def-region kind.
+
+ Returns
+ -------
+ kind
+ The active :class:`DefRegionKind`.
+
+ """
+ return DefRegionKind(_ffi_api.StructuralVisitorDefRegionKind(self))
+
+ def with_def_region_kind(
+ self,
+ kind: int,
+ callback: Callable[[], Any],
+ ) -> Any:
+ """Low-level API to run ``callback`` with a temporarily active def-region kind.
+
+ Parameters
+ ----------
+ kind
+ Def region kind to use while running ``callback``.
+
+ callback
+ Nullary callable to execute inside the scoped region.
+
+ Returns
+ -------
+ result
+ The value returned by ``callback``.
+
+ """
+ return _ffi_api.StructuralVisitorWithDefRegionKind(self, kind, callback)
+
+
+def structural_walk(
+ root: Any,
+ callbacks: tuple | Sequence | Callable = (),
+ with_def_region_kind: tuple | Sequence | Callable = (),
+ order: str | WalkOrder = "pre",
+) -> VisitInterrupt | None:
+ """Walk a value structurally and invoke the first matching typed callback.
+
+ Parameters
+ ----------
+ root
+ Root value to traverse.
+
+ callbacks
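+ Normal callbacks. These callbacks receive one argument: ``value``.
+ Callback entries are tried in order. A callback may return ``None`` or
+ ``WalkResult.ADVANCE`` to continue, ``WalkResult.SKIP`` to skip the
+ node's children, or a :class:`VisitInterrupt` to stop the walk.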
+
+ May be one of:
+
+ - A single callback, used as a ``typing.Any`` catch-all.
+ - A ``(type, callback)`` entry.
+ - A grouped ``((type1, type2, ...), callback)`` entry.
+ - A sequence of entries.
+
+ Types may be builtins, registered FFI object classes, or
+ ``typing.Any``/``object`` as a catch-all.
+
+ with_def_region_kind
+ Def-region-aware callbacks. These callbacks receive two arguments:
+ ``(value, def_region_kind)``. They accept the same callback entry forms
+ as ``callbacks``. Every entry in ``callbacks`` is tried before any entry
+ here, so a catch-all in ``callbacks`` shadows these callbacks.
+
+ order
+ ``"pre"``/``WalkOrder.PREORDER`` to invoke callbacks before children, or
+ ``"post"``/``WalkOrder.POSTORDER`` to invoke callbacks after children.
+
+ Returns
+ -------
+ result
+ ``None`` if traversal completed, otherwise a :class:`VisitInterrupt`
+ returned by a callback.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import tvm_ffi
+
+ # ``node`` (the root value) and ``Var`` (an FFI object class) are placeholders.
+ visited = []
+ uses = []
+ result = tvm_ffi.structural_walk(
+ node,
+ ((int, float), lambda value: visited.append(("leaf", value))),
+ with_def_region_kind=(
+ Var,
+ lambda var, kind: (
+ uses.append(var) if kind == tvm_ffi.DefRegionKind.NONE else None
+ ),
+ ),
+ )
+
+ """
+ if isinstance(order, WalkOrder):
+ order_int = int(order)
+ elif order in ("pre", "post"):
+ order_int = int(WalkOrder.PREORDER if order == "pre" else WalkOrder.POSTORDER)
+ else:
+ raise ValueError(f"Unknown structural walk order: {order!r}")
+
+ def normalize_callbacks(
+ callbacks: tuple | Sequence | Callable,
+ ) -> list[tuple[object, Callable]]:
+ callback_entries = []
+
+ def add_callback_entry(callback_entry: tuple) -> None:
+ callback_type, fn = callback_entry
+ callback_types = callback_type if isinstance(callback_type, tuple) else (callback_type,)
+ callback_entries.extend((t, fn) for t in callback_types)
+
+ if callable(callbacks):
+ callback_entries.append((Any, callbacks))
+ elif isinstance(callbacks, tuple) and len(callbacks) == 2 and callable(callbacks[1]):
+ add_callback_entry(callbacks)
+ elif isinstance(callbacks, Sequence) and not isinstance(callbacks, (str, bytes)):
+ for callback in callbacks:
+ if (
+ not isinstance(callback, tuple)
+ or len(callback) != 2
+ or not callable(callback[1])
+ ):
+ raise TypeError(
+ "structural_walk callbacks within a sequence must be "
+ "(type, callback) tuples"
+ )
+ add_callback_entry(callback)
+ else:
+ raise TypeError(
+ "structural_walk callbacks must be a callback, a (type, callback) entry, "
+ "a ((type1, type2, ...), callback) entry, or a sequence of such entries"
+ )
+ return callback_entries
+
+ def wrap_callback_with_def_region_kind(fn: Callable[..., Any]) -> Callable[[Any, int], Any]:
+ return lambda value, kind: fn(value, DefRegionKind(kind))
+
+ callback_entries = normalize_callbacks(callbacks)
+ callback_entries_with_def_region_kind = normalize_callbacks(with_def_region_kind)
+
+ entries: list[tuple[int, Callable[[Any], Any]]] = [
+ (_callback_type_to_type_index(t), fn) for t, fn in callback_entries
+ ]
+ entries_with_def_region_kind: list[tuple[int, Callable[[Any, int], Any]]] = [
+ (_callback_type_to_type_index(t), wrap_callback_with_def_region_kind(fn))
+ for t, fn in callback_entries_with_def_region_kind
+ ]
+ return _ffi_api.StructuralWalk(root, entries, entries_with_def_region_kind, order_int)
+
+
+def _callback_type_to_type_index(callback_type: type[Any] | Any) -> int:
+ """Convert a callback arg type to a type index."""
+ annotation = Any if callback_type is object else callback_type
+ try:
+ type_index = core.TypeSchema.from_annotation(annotation).origin_type_index
+ except TypeError as err:
+ raise TypeError(
+ "structural_walk callback type must be a supported builtin, "
+ "typing.Any/object, or an FFI-registered object class"
+ ) from err
+ if type_index < 0 and annotation is not Any:
+ raise TypeError(
+ "structural_walk callback type_index is negative; the only "
+ "acceptable negative type_index is -1 for Any"
+ )
+ return type_index
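The runtime dispatcher in ``structural_visit.cc`` loops over every plain ``callbacks`` entry before any ``with_def_region_kind`` entry. As a consequence, a plain catch-all such as ``(object, fn)`` shadows all def-region-aware callbacks. The sketch below restates that lookup in plain Python. ``dispatch`` is a hypothetical helper, and ``isinstance`` only approximates ``RuntimeTypeIndexMatch``. For instance, Python treats ``bool`` as an ``int`` subclass, whereas ``RuntimeTypeIndexMatch`` only matches non-object indices exactly.

```python
from typing import Any, Callable, Sequence, Tuple

ADVANCE = 0  # mirrors WalkResult.ADVANCE

Entry = Tuple[type, Callable[..., Any]]


def dispatch(
    value: Any,
    kind: int,
    callbacks: Sequence[Entry],
    callbacks_with_kind: Sequence[Entry],
) -> Any:
    """First match wins; plain entries are always tried before def-region-aware ones."""
    for typ, fn in callbacks:
        if isinstance(value, typ):
            return fn(value)
    for typ, fn in callbacks_with_kind:
        if isinstance(value, typ):
            return fn(value, kind)
    return ADVANCE  # no entry matched: keep walking
```

To combine a catch-all with def-region-aware handling, one option is to put the catch-all into ``with_def_region_kind`` as well, so that both sit in the same ordered list.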
diff --git a/src/ffi/extra/structural_visit.cc b/src/ffi/extra/structural_visit.cc
new file mode 100644
index 0000000..4002db7
--- /dev/null
+++ b/src/ffi/extra/structural_visit.cc
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/ffi/extra/structural_visit.cc
+ * \brief Structural visit implementation.
+ */
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/list.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/extra/structural_visit.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/registry.h>
+
+namespace tvm {
+namespace ffi {
+
+// ---------------------------------------------------------------------------
+// Runtime structural walk and built-in container structural visit.
+// ---------------------------------------------------------------------------
+
+namespace details {
+
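+/*!
+ * \brief Runtime structural walk driven by arrays of type-indexed callbacks.
+ *
+ * For each visited value, the first entry whose type index matches is invoked. All entries in
+ * \p callbacks are tried before any entry in \p callbacks_with_def_region_kind; if nothing
+ * matches, the walk advances.
+ *
+ * \param root The root value to visit.
+ * \param callbacks Runtime callback entries of ``(type_index, ffi::Function)``, invoked as
+ * ``callback(value)``.
+ * \param callbacks_with_def_region_kind Runtime callback entries of
+ * ``(type_index, ffi::Function)``, invoked as ``callback(value, def_region_kind)``.
+ * \param order Integer value of \ref WalkOrder; any value other than pre-order selects post-order.
+ * \return The interrupt that stopped the walk, or ``std::nullopt`` if the walk ran to completion.
+ * An error result means a callback or visit hook failed.
+ */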
+Expected<Optional<VisitInterrupt>> StructuralWalkExpected(
+ AnyView root, const Array<Tuple<int32_t, Function>>& callbacks,
+ const Array<Tuple<int32_t, Function>>& callbacks_with_def_region_kind, int order) noexcept {
+ auto dispatch = [callbacks, callbacks_with_def_region_kind](
+ AnyView x, TVMFFIDefRegionKind kind) -> Expected<WalkResult> {
+ for (const auto& entry : callbacks) {
+ int32_t type_index = entry.template get<0>();
+ if (!RuntimeTypeIndexMatch(x.type_index(), type_index)) {
+ continue;
+ }
+ Function fn = entry.template get<1>();
+ return fn.CallExpected<WalkResult>(x);
+ }
+ for (const auto& entry : callbacks_with_def_region_kind) {
+ int32_t type_index = entry.template get<0>();
+ if (!RuntimeTypeIndexMatch(x.type_index(), type_index)) {
+ continue;
+ }
+ Function fn = entry.template get<1>();
+ return fn.CallExpected<WalkResult>(x, kind);
+ }
+ return WalkResult::Advance();
+ };
+
+ if (order == static_cast<int>(WalkOrder::kPreOrder)) {
+ using Visitor = StructuralWalkCallbackVisitorObj<WalkOrder::kPreOrder, decltype(dispatch)>;
+ StructuralVisitor visitor(make_object<Visitor>(std::move(dispatch)));
+ return visitor->VisitExpected(root);
+ } else {
+ using Visitor = StructuralWalkCallbackVisitorObj<WalkOrder::kPostOrder, decltype(dispatch)>;
+ StructuralVisitor visitor(make_object<Visitor>(std::move(dispatch)));
+ return visitor->VisitExpected(root);
+ }
+}
+
+/*! \brief Visit entries in a sequence container. */
+TVMFFIAny VisitSeqContainer(StructuralVisitorObj* visitor, const SeqBaseObj* seq) noexcept {
+ for (const Any& item : *seq) {
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(visitor->VisitExpected(item));
+ }
+ return ExpectedUnsafe::MoveToTVMFFIAny(Expected<Optional<VisitInterrupt>>(std::nullopt));
+}
+
+/*! \brief Visit keys and values in a map container. */
+TVMFFIAny VisitMapContainer(StructuralVisitorObj* visitor, const MapBaseObj* map) noexcept {
+ for (const auto& kv : *map) {
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(visitor->VisitExpected(kv.first));
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(visitor->VisitExpected(kv.second));
+ }
+ return ExpectedUnsafe::MoveToTVMFFIAny(Expected<Optional<VisitInterrupt>>(std::nullopt));
+}
+
+/*! \brief Structural visit hook for ArrayObj. */
+TVMFFIAny VisitArray(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ const auto* array = value.cast<const ArrayObj*>();
+ return VisitSeqContainer(visitor, array);
+}
+
+/*! \brief Structural visit hook for ListObj. */
+TVMFFIAny VisitList(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ const auto* list = value.cast<const ListObj*>();
+ return VisitSeqContainer(visitor, list);
+}
+
+/*! \brief Structural visit hook for MapObj. */
+TVMFFIAny VisitMap(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ const auto* map = value.cast<const MapObj*>();
+ return VisitMapContainer(visitor, map);
+}
+
+/*! \brief Structural visit hook for DictObj. */
+TVMFFIAny VisitDict(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ const auto* dict = value.cast<const DictObj*>();
+ return VisitMapContainer(visitor, dict);
+}
+
+} // namespace details
+
+// ---------------------------------------------------------------------------
+// Static registration.
+// ---------------------------------------------------------------------------
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<VisitInterruptObj>()
+ .def(refl::init<Any>(), "Constructor that creates a structural visit interrupt")
+ .def_ro("value", &VisitInterruptObj::value);
+ refl::ObjectDef<StructuralVisitorObj>().def(
+ refl::init<>(), "Constructor that creates a default structural visitor");
+ refl::GlobalDef()
+ .def("ffi.VisitInterrupt", [](Any value) { return VisitInterrupt(std::move(value)); })
+ .def("ffi.StructuralVisitor", []() { return StructuralVisitor(); })
+ .def_method("ffi.StructuralVisitorVisit", &StructuralVisitorObj::Visit)
+ .def_method("ffi.StructuralVisitorDefRegionKind", &StructuralVisitorObj::def_region_kind)
+ .def_method(
+ "ffi.StructuralVisitorWithDefRegionKind",
+ [](const StructuralVisitor& visitor, TVMFFIDefRegionKind kind, const Function& callback) {
+ return visitor->WithDefRegionKind(kind, callback);
+ })
+ .def("ffi.StructuralWalk",
+ [](AnyView root, const Array<Tuple<int32_t, Function>>& callbacks,
+ const Array<Tuple<int32_t, Function>>& callbacks_with_def_region_kind,
+ int32_t order) -> Optional<VisitInterrupt> {
+ return details::StructuralWalkExpected(root, callbacks, callbacks_with_def_region_kind,
+ order)
+ .value();
+ });
+ refl::EnsureTypeAttrColumn(refl::type_attr::kStructuralVisit);
+ refl::TypeAttrDef<ArrayObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(static_cast<FStructuralVisit>(&details::VisitArray)));
+ refl::TypeAttrDef<ListObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(static_cast<FStructuralVisit>(&details::VisitList)));
+ refl::TypeAttrDef<MapObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(static_cast<FStructuralVisit>(&details::VisitMap)));
+ refl::TypeAttrDef<DictObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(static_cast<FStructuralVisit>(&details::VisitDict)));
+}
+
+} // namespace ffi
+} // namespace tvm
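The `dispatch` lambda in `StructuralWalkExpected` tries every entry in `callbacks` before any entry in `callbacks_with_def_region_kind`, calls the first one whose type index matches, and otherwise returns `WalkResult::Advance()`. A rough Python model of that control flow follows. `type_index_match` is only a stand-in for `RuntimeTypeIndexMatch`: it covers the `-1` catch-all and subclass matching through a made-up `PARENT` table, and it does not attempt the string/bytes storage-variant handling mentioned in the commit message. String return values stand in for `WalkResult`.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

ANY_INDEX = -1  # catch-all; the only negative index the Python layer accepts

# Hypothetical parent table standing in for the FFI type hierarchy:
# 64 plays the role of Object, 70 derives from 64, 71 derives from 70,
# and 1 stands in for a POD int with no parent.
PARENT: Dict[int, Optional[int]] = {64: None, 70: 64, 71: 70, 1: None}


def type_index_match(value_index: int, wanted: int) -> bool:
    """Toy model of runtime type-index matching: catch-all or ancestor match."""
    if wanted == ANY_INDEX:
        return True
    current: Optional[int] = value_index
    while current is not None:
        if current == wanted:
            return True
        current = PARENT.get(current)
    return False


Entry = Tuple[int, Callable[..., str]]


def dispatch(value_index: int, value: Any, kind: int,
             plain: List[Entry], with_kind: List[Entry]) -> str:
    """First match wins; plain callbacks are tried before def-region-aware ones."""
    for wanted, fn in plain:
        if type_index_match(value_index, wanted):
            return fn(value)
    for wanted, fn in with_kind:
        if type_index_match(value_index, wanted):
            return fn(value, kind)
    return "advance"  # stands in for WalkResult::Advance() when nothing matches
```

One consequence of this ordering is that a plain catch-all entry shadows every def-region-aware entry, since it always matches first.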
diff --git a/tests/cpp/extra/test_structural_visit.cc b/tests/cpp/extra/test_structural_visit.cc
new file mode 100644
index 0000000..df8ee11
--- /dev/null
+++ b/tests/cpp/extra/test_structural_visit.cc
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/structural_visit.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/string.h>
+
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../testing_object.h"
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::testing;
+
+class TestVisitorObj : public StructuralVisitorObj {
+ public:
+ TestVisitorObj() : StructuralVisitorObj(VTable()) {}
+
+ std::vector<ObjectRef> visited;
+ std::vector<TVMFFIDefRegionKind> modes;
+ ObjectRef interrupt_on;
+
+ private:
+ static const StructuralVisitorVTable* VTable() {
+ static const StructuralVisitorVTable vtable{&TestVisitorObj::DispatchVisit};
+ return &vtable;
+ }
+
+ static TVMFFIAny DispatchVisit(StructuralVisitorObj* self, AnyView value) noexcept {
+ return static_cast<TestVisitorObj*>(self)->VisitImpl(value);
+ }
+
+ // NOLINTNEXTLINE(bugprone-exception-escape)
+ TVMFFIAny VisitImpl(AnyView value) noexcept {
+ if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ ObjectRef value_ref = value.cast<ObjectRef>();
+ visited.push_back(value_ref);
+ modes.push_back(def_region_mode_);
+ if (interrupt_on.defined() && value_ref.same_as(interrupt_on)) {
+ Expected<Optional<VisitInterrupt>> interrupt =
+ Optional<VisitInterrupt>(VisitInterrupt(String("stop")));
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(interrupt));
+ }
+ }
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(DefaultVisitExpected(value));
+ }
+};
+
+StructuralVisitor MakeTestVisitor() { return StructuralVisitor(make_object<TestVisitorObj>()); }
+
+TestVisitorObj* AsTestVisitor(const StructuralVisitor& visitor) {
+ return static_cast<TestVisitorObj*>(visitor.get());
+}
+
+void SetInterrupt(const StructuralVisitor& visitor, const ObjectRef& value) {
+ TestVisitorObj* test_visitor = AsTestVisitor(visitor);
+ test_visitor->interrupt_on = value;
+}
+
+void ExpectTrace(const std::vector<std::string>& actual,
+ std::initializer_list<const char*> expected) {
+ ASSERT_EQ(actual.size(), expected.size());
+ size_t i = 0;
+ for (const char* item : expected) {
+ EXPECT_EQ(actual[i], item);
+ ++i;
+ }
+}
+
+// ---------------------------------------------------------------------------
+// StructuralVisitor behavior.
+// ---------------------------------------------------------------------------
+
+TEST(StructuralVisitor, RecordsNode) {
+ ObjectRef leaf = TVar("leaf");
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Expected<Optional<VisitInterrupt>> result = visitor->VisitExpected(leaf);
+
+ ASSERT_TRUE(result.is_ok());
+ EXPECT_FALSE(result.value().has_value());
+ ASSERT_EQ(AsTestVisitor(visitor)->visited.size(), 1U);
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[0].same_as(leaf));
+}
+
+TEST(StructuralVisitor, PropagatesInterrupt) {
+ ObjectRef leaf = TVar("leaf");
+ StructuralVisitor visitor = MakeTestVisitor();
+ SetInterrupt(visitor, leaf);
+
+ Expected<Optional<VisitInterrupt>> result = visitor->VisitExpected(leaf);
+
+ ASSERT_TRUE(result.is_ok());
+ ASSERT_TRUE(result.value().has_value());
+ EXPECT_EQ(result.value().value()->value.cast<String>(), "stop");
+}
+
+TEST(StructuralVisitor, TraversesPair) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Optional<VisitInterrupt> result = visitor->DefaultVisit(root);
+
+ EXPECT_FALSE(result.has_value());
+ ASSERT_EQ(AsTestVisitor(visitor)->visited.size(), 2U);
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[0].same_as(lhs));
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[1].same_as(rhs));
+}
+
+TEST(StructuralVisitor, TraversesFunction) {
+ TVar param("x");
+ ObjectRef body_value = TInt(1);
+ Array<TVar> params = {param};
+ Array<ObjectRef> body = {body_value};
+ ObjectRef root = TFunc(params, body, String("ignored function comment"));
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Optional<VisitInterrupt> result = visitor->DefaultVisit(root);
+
+ EXPECT_FALSE(result.has_value());
+ TestVisitorObj* test_visitor = AsTestVisitor(visitor);
+ ASSERT_EQ(test_visitor->visited.size(), 4U);
+ EXPECT_TRUE(test_visitor->visited[0].same_as(params));
+ EXPECT_EQ(test_visitor->modes[0], kTVMFFIDefRegionKindRecursive);
+ EXPECT_TRUE(test_visitor->visited[1].same_as(param));
+ EXPECT_EQ(test_visitor->modes[1], kTVMFFIDefRegionKindRecursive);
+ EXPECT_TRUE(test_visitor->visited[2].same_as(body));
+ EXPECT_EQ(test_visitor->modes[2], kTVMFFIDefRegionKindNone);
+ EXPECT_TRUE(test_visitor->visited[3].same_as(body_value));
+ EXPECT_EQ(test_visitor->modes[3], kTVMFFIDefRegionKindNone);
+}
+
+TEST(StructuralVisitor, StopsOnInterrupt) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+ StructuralVisitor visitor = MakeTestVisitor();
+ SetInterrupt(visitor, lhs);
+
+ Optional<VisitInterrupt> result = visitor->DefaultVisit(root);
+
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(result.value()->value.cast<String>(), "stop");
+ ASSERT_EQ(AsTestVisitor(visitor)->visited.size(), 1U);
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[0].same_as(lhs));
+}
+
+TEST(StructuralVisitor, TraversesArray) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ Array<ObjectRef> root = {lhs, rhs};
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Optional<VisitInterrupt> result = visitor->DefaultVisit(root);
+
+ EXPECT_FALSE(result.has_value());
+ ASSERT_EQ(AsTestVisitor(visitor)->visited.size(), 2U);
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[0].same_as(lhs));
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[1].same_as(rhs));
+}
+
+TEST(StructuralVisitor, TraversesMap) {
+ ObjectRef key = TVar("key");
+ ObjectRef value = TVar("value");
+ Map<Any, Any> root{{key, value}};
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Optional<VisitInterrupt> result = visitor->DefaultVisit(root);
+
+ EXPECT_FALSE(result.has_value());
+ ASSERT_EQ(AsTestVisitor(visitor)->visited.size(), 2U);
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[0].same_as(key));
+ EXPECT_TRUE(AsTestVisitor(visitor)->visited[1].same_as(value));
+}
+
+TEST(StructuralVisitor, UsesFuncHook) {
+ TVar param("x");
+ ObjectRef body_value = TInt(1);
+ Array<TVar> params = {param};
+ Array<ObjectRef> body = {body_value};
+ ObjectRef root = TFunc(params, body, String("ignored function comment"));
+ StructuralVisitor visitor = MakeTestVisitor();
+
+ Optional<VisitInterrupt> result = visitor->Visit(root);
+
+ EXPECT_FALSE(result.has_value());
+ TestVisitorObj* test_visitor = AsTestVisitor(visitor);
+ ASSERT_EQ(test_visitor->visited.size(), 5U);
+ EXPECT_TRUE(test_visitor->visited[0].same_as(root));
+ EXPECT_EQ(test_visitor->modes[0], kTVMFFIDefRegionKindNone);
+ EXPECT_TRUE(test_visitor->visited[1].same_as(params));
+ EXPECT_EQ(test_visitor->modes[1], kTVMFFIDefRegionKindRecursive);
+ EXPECT_TRUE(test_visitor->visited[2].same_as(param));
+ EXPECT_EQ(test_visitor->modes[2], kTVMFFIDefRegionKindRecursive);
+ EXPECT_TRUE(test_visitor->visited[3].same_as(body));
+ EXPECT_EQ(test_visitor->modes[3], kTVMFFIDefRegionKindNone);
+ EXPECT_TRUE(test_visitor->visited[4].same_as(body_value));
+ EXPECT_EQ(test_visitor->modes[4], kTVMFFIDefRegionKindNone);
+ EXPECT_EQ(test_visitor->def_region_kind(), kTVMFFIDefRegionKindNone);
+}
+
+TEST(StructuralVisitor, RestoresFuncDefRegion) {
+ TVar param("x");
+ Array<TVar> params = {param};
+ Array<ObjectRef> body = {TInt(1)};
+ ObjectRef root = TFunc(params, body, String("ignored function comment"));
+ StructuralVisitor visitor = MakeTestVisitor();
+ SetInterrupt(visitor, param);
+
+ Optional<VisitInterrupt> result = visitor->Visit(root);
+
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(result.value()->value.cast<String>(), "stop");
+ TestVisitorObj* test_visitor = AsTestVisitor(visitor);
+ ASSERT_EQ(test_visitor->visited.size(), 3U);
+ EXPECT_TRUE(test_visitor->visited[0].same_as(root));
+ EXPECT_EQ(test_visitor->modes[0], kTVMFFIDefRegionKindNone);
+ EXPECT_TRUE(test_visitor->visited[1].same_as(params));
+ EXPECT_EQ(test_visitor->modes[1], kTVMFFIDefRegionKindRecursive);
+ EXPECT_TRUE(test_visitor->visited[2].same_as(param));
+ EXPECT_EQ(test_visitor->modes[2], kTVMFFIDefRegionKindRecursive);
+ EXPECT_EQ(test_visitor->def_region_kind(), kTVMFFIDefRegionKindNone);
+}
+
+// ---------------------------------------------------------------------------
+// StructuralWalk behavior.
+// ---------------------------------------------------------------------------
+
+TEST(StructuralVisitor, WalkSkipsChildren) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+ std::vector<std::string> visited;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root, [&](const ObjectRef& node) -> Expected<WalkResult> {
+ if (node.same_as(root)) {
+ visited.emplace_back("pair");
+ return WalkResult::Skip();
+ }
+ visited.emplace_back("child");
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(visited, {"pair"});
+}
+
+TEST(StructuralVisitor, WalkPostOrder) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+ std::vector<std::string> visited;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPostOrder>(
+ root, [&](const ObjectRef& node) -> Expected<WalkResult> {
+ if (node.same_as(lhs)) {
+ visited.emplace_back("lhs");
+ } else if (node.same_as(rhs)) {
+ visited.emplace_back("rhs");
+ } else if (node.same_as(root)) {
+ visited.emplace_back("pair");
+ }
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(visited, {"lhs", "rhs", "pair"});
+}
+
+TEST(StructuralVisitor, WalkInterrupts) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+ std::vector<std::string> visited;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root, [&](const ObjectRef& node) -> Expected<WalkResult> {
+ if (node.same_as(lhs)) {
+ visited.emplace_back("lhs");
+ return WalkResult::Interrupt(VisitInterrupt(String("found lhs")));
+ }
+ if (node.same_as(rhs)) {
+ visited.emplace_back("rhs");
+ }
+ return WalkResult::Advance();
+ });
+
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(result.value()->value.cast<String>(), "found lhs");
+ ExpectTrace(visited, {"lhs"});
+}
+
+TEST(StructuralVisitor, WalkVisitsPOD) {
+ int64_t seen = 0;
+
+ Optional<VisitInterrupt> result =
+ StructuralWalk<WalkOrder::kPreOrder>(42, [&](int64_t value) -> Expected<WalkResult> {
+ seen = value;
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ EXPECT_EQ(seen, 42);
+}
+
+TEST(StructuralVisitor, WalkVisitsObjectPtr) {
+ TVar root("x");
+ std::vector<std::string> visited;
+
+ Optional<VisitInterrupt> result =
+ StructuralWalk<WalkOrder::kPreOrder>(root, [&](const TVarObj* var) -> Expected<WalkResult> {
+ visited.emplace_back(var->name);
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(visited, {"x"});
+}
+
+TEST(StructuralVisitor, WalkReceivesDefRegionKind) {
+ TVar x("x");
+ TVar y("y");
+ ObjectRef root = TFunc({x}, {x, y}, String("ignored function comment"));
+ std::vector<std::string> use_vars;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root, [&](const TVarObj* var, TVMFFIDefRegionKind kind) -> Expected<WalkResult> {
+ if (kind == kTVMFFIDefRegionKindNone) {
+ use_vars.emplace_back(var->name);
+ }
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(use_vars, {"x", "y"});
+}
+
+TEST(StructuralVisitor, WalkReturnsError) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ ObjectRef root = TPair(lhs, rhs);
+
+ Expected<Optional<VisitInterrupt>> result = StructuralWalkExpected<WalkOrder::kPreOrder>(
+ root, [&](const ObjectRef& node) -> Expected<WalkResult> {
+ if (node.same_as(lhs)) {
+ return Unexpected(Error("ValueError", "walk callback failed", ""));
+ }
+ return WalkResult::Advance();
+ });
+
+ ASSERT_TRUE(result.is_err());
+ EXPECT_EQ(result.error().kind(), "ValueError");
+ EXPECT_EQ(result.error().message(), "walk callback failed");
+}
+
+TEST(StructuralVisitor, WalkCatchesError) {
+ ObjectRef root = TVar("root");
+
+ Expected<Optional<VisitInterrupt>> result = StructuralWalkExpected<WalkOrder::kPreOrder>(
+ root, [&](const ObjectRef&) -> Expected<WalkResult> {
+ TVM_FFI_THROW(ValueError) << "walk callback threw";
+ });
+
+ ASSERT_TRUE(result.is_err());
+ EXPECT_EQ(result.error().kind(), "ValueError");
+ EXPECT_EQ(result.error().message(), "walk callback threw");
+}
+
+TEST(StructuralVisitor, WalkFirstMatch) {
+ ObjectRef lhs = TVar("lhs");
+ ObjectRef rhs = TVar("rhs");
+ List<ObjectRef> list = {lhs};
+ Array<ObjectRef> root = {list, rhs};
+ std::vector<std::string> trace;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root,
+ [&](const Array<ObjectRef>&) -> Expected<WalkResult> {
+ trace.emplace_back("array");
+ return WalkResult::Advance();
+ },
+ [&](const List<ObjectRef>&) -> Expected<WalkResult> {
+ trace.emplace_back("list");
+ return WalkResult::Advance();
+ },
+ [&](const ObjectRef&) -> Expected<WalkResult> {
+ trace.emplace_back("object");
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(trace, {"array", "list", "object", "object"});
+}
+
+TEST(StructuralVisitor, WalkFuncProgram) {
+ TVar m("m");
+ TVar n("n");
+ TVar acc("acc");
+ ObjectRef func =
+ TFunc({m, n}, {TInt(7), acc, TPair(m, TInt(1))}, String("ignored function comment"));
+ Array<Any> root = {func, String("metadata"), nullptr};
+ std::vector<std::string> trace;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root,
+ [&](const TFuncObj*) -> Expected<WalkResult> {
+ trace.emplace_back("func*");
+ return WalkResult::Advance();
+ },
+ [&](const TVarObj* var) -> Expected<WalkResult> {
+ trace.emplace_back("var*:" + var->name);
+ return WalkResult::Advance();
+ },
+ [&](int64_t value) -> Expected<WalkResult> {
+ trace.emplace_back("int:" + std::to_string(value));
+ return WalkResult::Advance();
+ },
+ [&](const ObjectRef&) -> Expected<WalkResult> {
+ trace.emplace_back("object-ref");
+ return WalkResult::Advance();
+ },
+ [&](AnyView) -> Expected<WalkResult> {
+ trace.emplace_back("any-view");
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(trace,
+ {"object-ref", "func*", "object-ref", "var*:m", "var*:n", "object-ref", "object-ref",
+ "int:7", "var*:acc", "object-ref", "var*:m", "object-ref", "int:1", "object-ref"});
+}
+
+TEST(StructuralVisitor, WalkAnyFallback) {
+ Array<Any> root = {String("metadata"), nullptr};
+ std::vector<std::string> trace;
+
+ Optional<VisitInterrupt> result = StructuralWalk<WalkOrder::kPreOrder>(
+ root,
+ [&](const ObjectRef&) -> Expected<WalkResult> {
+ trace.emplace_back("object-ref");
+ return WalkResult::Advance();
+ },
+ [&](const Any& value) -> Expected<WalkResult> {
+ trace.emplace_back(value == nullptr ? "any:none" : "any:value");
+ return WalkResult::Advance();
+ });
+
+ EXPECT_FALSE(result.has_value());
+ ExpectTrace(trace, {"object-ref", "object-ref"});
+}
+
+} // namespace
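The `StructuralWalk` tests above pin down three behaviors: in pre-order, a callback returning `WalkResult::Skip()` prunes that node's children (`WalkSkipsChildren`); post-order reports children before their parent (`WalkPostOrder`); and the first `WalkResult::Interrupt(...)` stops the walk and is returned to the caller (`WalkInterrupts`). The Python sketch below reproduces those traces on a toy pair. `Action`, `Interrupt`, `Pair`, `children`, and `walk` are hypothetical stand-ins, not tvm_ffi APIs. Maps are modeled as visiting each key and then its value, following the loop in `VisitMapContainer`.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Union


class Action(Enum):  # stand-in for WalkResult::Advance() / WalkResult::Skip()
    ADVANCE = 0
    SKIP = 1


@dataclass
class Interrupt:  # stand-in for VisitInterrupt; carries an arbitrary payload
    value: Any


@dataclass
class Pair:  # stand-in for the test-only TPair object
    lhs: Any
    rhs: Any


Callback = Callable[[Any], Union[Action, Interrupt]]


def children(node: Any) -> List[Any]:
    if isinstance(node, Pair):
        return [node.lhs, node.rhs]
    if isinstance(node, list):
        return list(node)
    if isinstance(node, dict):
        out: List[Any] = []
        for key, value in node.items():
            out += [key, value]  # each key is visited, then its value
        return out
    return []  # leaves such as str have no children


def walk(node: Any, callback: Callback, post_order: bool = False) -> Optional[Interrupt]:
    if not post_order:
        result = callback(node)
        if isinstance(result, Interrupt):
            return result
        if result is Action.SKIP:
            return None  # prune this node's children; siblings are still visited
    for child in children(node):
        result = walk(child, callback, post_order)
        if result is not None:
            return result  # propagate the first interrupt unchanged
    if post_order:
        result = callback(node)
        if isinstance(result, Interrupt):
            return result
    return None
```

As in the C++ `WalkPostOrder` test, `walk(Pair("lhs", "rhs"), cb, post_order=True)` calls `cb` on `"lhs"`, then `"rhs"`, then the pair itself.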
diff --git a/tests/cpp/test_expected.cc b/tests/cpp/test_expected.cc
index a1cdfb1..4f432a8 100644
--- a/tests/cpp/test_expected.cc
+++ b/tests/cpp/test_expected.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
@@ -343,6 +344,29 @@ TEST(Expected, TryCastIncompatible) {
EXPECT_FALSE(result.has_value()); // Cannot convert String to Expected<int>
}
+TEST(Expected, ExpectedUnsafeGetDataCompatibleStorageType) {
+ Expected<Variant<String, bool>> result = Variant<String, bool>(false);
+
+ EXPECT_EQ(result.type_index(), TypeIndex::kTVMFFIBool);
+ EXPECT_FALSE(details::AnyUnsafe::CopyFromAnyViewAfterCheck<bool>(
+ details::ExpectedUnsafe::GetData(result)));
+
+ Expected<Variant<String, bool>> true_result = Variant<String, bool>(true);
+ EXPECT_TRUE(details::AnyUnsafe::CopyFromAnyViewAfterCheck<bool>(
+ details::ExpectedUnsafe::GetData(true_result)));
+}
+
+TEST(Expected, ExpectedUnsafeMoveBetweenExpectedStorageTypes) {
+ Expected<String> src = String("hello");
+ TVMFFIAny raw = details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(src));
+ Expected<Optional<String>> dst =
+ details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<String>>(raw);
+
+ ASSERT_TRUE(dst.is_ok());
+ ASSERT_TRUE(dst.value().has_value());
+ EXPECT_EQ(dst.value().value(), "hello");
+}
+
// Test that Expected<DLDataType>::value() && compiles and runs correctly.
// Requires TypeTraits<DLDataType>::MoveFromAnyAfterCheck to be defined.
TEST(ExpectedRvalueMove, DLDataTypeMoveCompiles) {
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index 6f99a74..556700f 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -66,6 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TFloatObj::RegisterReflection();
TPrimExprObj::RegisterReflection();
TVarObj::RegisterReflection();
+ TPairObj::RegisterReflection();
TVarWithDepObj::RegisterReflection();
TDefHolderObj::RegisterReflection();
TFuncObj::RegisterReflection();
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index d1bf0a9..c3d5bfc 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -24,6 +24,7 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/extra/structural_visit.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>
@@ -206,6 +207,32 @@ class TVar : public ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj);
};
+class TPairObj : public Object {
+ public:
+ ObjectRef lhs;
+ ObjectRef rhs;
+
+ TPairObj(ObjectRef lhs, ObjectRef rhs) : lhs(std::move(lhs)), rhs(std::move(rhs)) {}
+ explicit TPairObj(UnsafeInit) {}
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TPairObj>().def_ro("lhs", &TPairObj::lhs).def_ro("rhs", &TPairObj::rhs);
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Pair", TPairObj, Object);
+};
+
+class TPair : public ObjectRef {
+ public:
+ TPair(ObjectRef lhs, ObjectRef rhs) {
+ data_ = make_object<TPairObj>(std::move(lhs), std::move(rhs));
+ }
+
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TPair, ObjectRef, TPairObj);
+};
+
// FreeVar test object that has a sub-field referencing another FreeVar.
// This models the "var with nested vars" case (analogous to a relax::Var
// whose struct_info contains tir shape vars). It is used to exercise the
@@ -292,12 +319,25 @@ class TFuncObj : public Object {
TFuncObj(Array<TVar> params, Array<ObjectRef> body, Optional<String> comment)
: params(params), body(body), comment(comment) {}
+ static TVMFFIAny StructuralVisit(StructuralVisitorObj* visitor, AnyView value) noexcept {
+ const auto* self = value.cast<const TFuncObj*>();
+
+ TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(visitor->WithDefRegionKind(
+ kTVMFFIDefRegionKindRecursive, [&]() { return visitor->VisitExpected(self->params); }));
+
+ return details::ExpectedUnsafe::MoveToTVMFFIAny(visitor->VisitExpected(self->body));
+ }
+
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TFuncObj>()
.def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDefRecursive())
.def_ro("body", &TFuncObj::body)
.def_ro("comment", &TFuncObj::comment, refl::AttachFieldFlag::SEqHashIgnore());
+ refl::EnsureTypeAttrColumn(refl::type_attr::kStructuralVisit);
+ refl::TypeAttrDef<TFuncObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(static_cast<FStructuralVisit>(&TFuncObj::StructuralVisit)));
}
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
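`TFuncObj::StructuralVisit` visits `params` inside `WithDefRegionKind(kTVMFFIDefRegionKindRecursive, ...)` and `body` under the caller's kind, and the `UsesFuncHook` and `RestoresFuncDefRegion` tests expect the previous kind to be back in place afterwards, even when the walk is interrupted inside `params`. Saving the old kind and restoring it in a `try`/`finally` is one way to get that guarantee; the Python sketch below models it. `DefRegionKind`, `Func`, and `Visitor` are hypothetical stand-ins, and the enum values are arbitrary rather than the C `TVMFFIDefRegionKind` values.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple


class DefRegionKind(Enum):  # stand-in; values are arbitrary
    NONE = 0
    RECURSIVE = 1


@dataclass
class Func:  # stand-in for the test-only TFuncObj
    params: List[str]
    body: List[str]


class Visitor:
    def __init__(self, interrupt_on: Optional[str] = None) -> None:
        self.kind = DefRegionKind.NONE
        self.trace: List[Tuple[Any, DefRegionKind]] = []
        self.interrupt_on = interrupt_on

    def with_def_region_kind(self, kind: DefRegionKind,
                             fn: Callable[[], Optional[str]]) -> Optional[str]:
        saved, self.kind = self.kind, kind
        try:
            return fn()
        finally:
            self.kind = saved  # restored on normal return, interrupt, or exception

    def visit(self, node: Any) -> Optional[str]:
        self.trace.append((node, self.kind))
        if isinstance(node, str) and node == self.interrupt_on:
            return "stop"  # stands in for a VisitInterrupt payload
        if isinstance(node, Func):
            # Like TFuncObj::StructuralVisit: params form a recursive def region.
            result = self.with_def_region_kind(
                DefRegionKind.RECURSIVE, lambda: self.visit(node.params))
            if result is not None:
                return result
            return self.visit(node.body)
        if isinstance(node, list):
            for item in node:
                result = self.visit(item)
                if result is not None:
                    return result
        return None
```

Visiting `Func(["x"], ["one"])` records the kinds NONE, RECURSIVE, RECURSIVE, NONE, NONE, which is the same shape the `UsesFuncHook` test checks.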
diff --git a/tests/python/test_structural.py b/tests/python/test_structural.py
index 5808c41..6877978 100644
--- a/tests/python/test_structural.py
+++ b/tests/python/test_structural.py
@@ -15,9 +15,14 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from typing import Any
+
import numpy as np
import tvm_ffi
import tvm_ffi.testing
+from tvm_ffi.dataclasses import Object, field, py_class
_recursive_eq = tvm_ffi.get_global_func("ffi.RecursiveEq")
@@ -158,3 +163,204 @@ def test_recursive_eq_mutual_cycle() -> None:
# Different content should not be equal.
c = make_cyclic(99)
assert not _recursive_eq(a, c)
+
+
+def test_visit_interrupt_payload() -> None:
+ payload = {"reason": "found", "path": [1, 2, 3]}
+ interrupt = tvm_ffi.VisitInterrupt(payload)
+
+ assert isinstance(interrupt, tvm_ffi.VisitInterrupt)
+ assert tvm_ffi.structural_equal(interrupt.value, payload)
+
+
+def test_structural_walk_typed_callbacks() -> None:
+ root = tvm_ffi.Array([1, 2.5, "tag"])
+ trace: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Array, lambda value: trace.append(f"array:{len(value)}")),
+ ((int, float), lambda value: trace.append(f"number:{value}")),
+ (str, lambda value: trace.append(f"str:{value}")),
+ ],
+ )
+
+ assert result is None
+ assert trace == ["array:3", "number:1", "number:2.5", "str:tag"]
+
+
+def test_structural_walk_callback_def_region_kind() -> None:
+ @py_class(structural_eq="var")
+ class PyWalkVar(Object):
+ name: str = field(structural_eq="ignore")
+
+ @py_class(structural_eq="tree")
+ class PyWalkFunc(Object):
+ params: tvm_ffi.Array[PyWalkVar] = field(structural_eq="def")
+ body: tvm_ffi.Array[PyWalkVar]
+
+ x = PyWalkVar("x")
+ y = PyWalkVar("y")
+ root = PyWalkFunc(tvm_ffi.Array([x]), tvm_ffi.Array([x, y]))
+ uses: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ with_def_region_kind=(
+ PyWalkVar,
+ lambda value, kind: (
+ uses.append(value.name) if kind == tvm_ffi.DefRegionKind.NONE else None
+ ),
+ ),
+ )
+
+ assert result is None
+ assert uses == ["x", "y"]
+
+
+def test_structural_walk_first_match_and_skip() -> None:
+ root = tvm_ffi.Array([1, 2])
+ trace: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (
+ tvm_ffi.Array,
+ lambda value: trace.append(f"array:{len(value)}") or tvm_ffi.WalkResult.SKIP,
+ ),
+ (object, lambda value: trace.append(type(value).__name__)),
+ ],
+ )
+
+ assert result is None
+ assert trace == ["array:2"]
+
+
+def test_structural_walk_interrupt() -> None:
+ root = tvm_ffi.Array([1, 2, 3])
+
+ def on_int(value: int) -> tvm_ffi.VisitInterrupt | None:
+ if value == 2:
+ return tvm_ffi.VisitInterrupt({"found": value})
+ return None
+
+ result = tvm_ffi.structural_walk(root, (int, on_int))
+
+ assert isinstance(result, tvm_ffi.VisitInterrupt)
+ assert tvm_ffi.structural_equal(result.value, {"found": 2})
+
+
+def test_structural_walk_nested_containers() -> None:
+ root = tvm_ffi.Array(
+ [
+ tvm_ffi.Map(
+ {
+ "numbers": tvm_ffi.Array([1, 2]),
+ "meta": tvm_ffi.Dict({"flag": True}),
+ }
+ ),
+ 3,
+ ]
+ )
+ containers: list[tuple[str, int]] = []
+ scalars: list[int] = []
+ strings: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Array, lambda value: containers.append(("array", len(value)))),
+ (tvm_ffi.Map, lambda value: containers.append(("map", len(value)))),
+ (tvm_ffi.Dict, lambda value: containers.append(("dict", len(value)))),
+ ((int, bool), lambda value: scalars.append(int(value))),
+ (str, lambda value: strings.append(value)),
+ ],
+ )
+
+ assert result is None
+ assert [kind for kind, _ in containers].count("array") == 2
+ assert ("map", 2) in containers
+ assert ("dict", 1) in containers
+ assert sorted(scalars) == [1, 1, 2, 3]
+ assert set(strings) == {"numbers", "meta", "flag"}
+
+
+def test_structural_walk_object_and_any_callbacks() -> None:
+ root = tvm_ffi.Array([1, tvm_ffi.Array([2])])
+ trace: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Object, lambda value: trace.append(f"object:{type(value).__name__}")),
+ (Any, lambda value: trace.append(f"any:{value}")),
+ ],
+ )
+
+ assert result is None
+ assert trace == ["object:Array", "any:1", "object:Array", "any:2"]
+
+ alias_trace: list[str] = []
+ result = tvm_ffi.structural_walk(
+ tvm_ffi.Array([1]),
+ (object, lambda value: alias_trace.append(type(value).__name__)),
+ )
+
+ assert result is None
+ assert alias_trace == ["Array", "int"]
+
+
+def test_structural_walk_post_order_enum() -> None:
+ root = tvm_ffi.Array([tvm_ffi.Array([1]), 2])
+ trace: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Array, lambda value: trace.append(f"array:{len(value)}")),
+ (int, lambda value: trace.append(f"int:{value}")),
+ ],
+ order=tvm_ffi.WalkOrder.POSTORDER,
+ )
+
+ assert result is None
+ assert trace == ["int:1", "array:1", "int:2", "array:2"]
+
+
+def test_structural_walk_mixed_callback_forms() -> None:
+ @py_class(structural_eq="var")
+ class PyWalkMixedVar(Object):
+ name: str = field(structural_eq="ignore")
+
+ @py_class(structural_eq="tree")
+ class PyWalkMixedFunc(Object):
+ params: tvm_ffi.Array[PyWalkMixedVar] = field(structural_eq="def")
+ body: tvm_ffi.Array[PyWalkMixedVar]
+
+ x = PyWalkMixedVar("x")
+ y = PyWalkMixedVar("y")
+ root = tvm_ffi.Array([PyWalkMixedFunc(tvm_ffi.Array([x]), tvm_ffi.Array([x, y])), "tag"])
+ trace: list[str] = []
+
+ result = tvm_ffi.structural_walk(
+ root,
+ [
+ (tvm_ffi.Array, lambda value: trace.append(f"array:{len(value)}")),
+ (str, lambda value: trace.append(f"str:{value}")),
+ ],
+ with_def_region_kind=[
+ (
+ PyWalkMixedVar,
+ lambda value, kind: (
+ trace.append(f"use:{value.name}")
+ if kind == tvm_ffi.DefRegionKind.NONE
+ else None
+ ),
+ ),
+ ],
+ )
+
+ assert result is None
+ assert trace == ["array:2", "array:1", "array:2", "use:x", "use:y", "str:tag"]