This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1240649257 [FFI][REFACTOR] Direct structural APIs to tvm-ffi (#19661)
1240649257 is described below
commit 12406492577818a329464fade45d068102173c7d
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Jun 3 18:57:05 2026 -0400
[FFI][REFACTOR] Direct structural APIs to tvm-ffi (#19661)
## Summary
Python callers should reach the canonical tvm-ffi structural helpers
directly instead of going through a TVM-side redirect layer. This removes
the structural_equal, structural_hash, and get_first_structural_mismatch
wrappers from tvm.ir, so callers now use the tvm_ffi APIs directly.
Main changes:
- Import structural_equal, get_first_structural_mismatch, and
structural_hash directly from tvm_ffi
- Remove the pure wrappers from tvm.ir.base while keeping
assert_structural_equal's TVM-specific formatting
- Remove the TVM-side structural mismatch and structural equal/hash test
files, and update the remaining tests to call tvm_ffi directly
---
docs/deep_dive/tensor_ir/tutorials/tir_creation.py | 5 +-
docs/reference/security.rst | 2 +-
python/tvm/ir/__init__.py | 2 -
python/tvm/ir/base.py | 119 +-----
python/tvm/ir/expr.py | 2 +-
python/tvm/ir/global_info.py | 2 +-
python/tvm/ir/type.py | 3 +-
python/tvm/relax/expr.py | 2 +-
python/tvm/relax/frontend/nn/subroutine.py | 9 +-
python/tvm/relax/frontend/onnx/onnx_frontend.py | 3 +-
.../frontend/torch/base_fx_graph_translator.py | 5 +-
python/tvm/relax/script/parser/parser.py | 6 +-
.../relax/transform/optimize_layout_transform.py | 5 +-
.../relax/transform/remove_redundant_reshape.py | 7 +-
python/tvm/s_tir/dlight/analysis/gemv.py | 6 +-
python/tvm/s_tir/dlight/benchmark/extract.py | 3 +-
python/tvm/s_tir/dlight/gpu/low_batch_gemv.py | 6 +-
python/tvm/s_tir/dlight/gpu/reduction.py | 6 +-
.../meta_schedule/testing/space_generation.py | 6 +-
.../meta_schedule/testing/validate_database.py | 5 +-
python/tvm/tirx/analysis/analysis.py | 6 +-
.../python/arith/test_arith_canonical_simplify.py | 4 +-
tests/python/arith/test_arith_rewrite_simplify.py | 3 +-
.../arith/test_arith_solve_linear_equations.py | 27 +-
.../arith/test_arith_solve_linear_inequality.py | 37 +-
tests/python/contrib/test_hexagon/test_take.py | 3 +-
tests/python/ir/test_container_structural_equal.py | 177 ---------
tests/python/ir/test_ir_attrs.py | 22 +-
.../distributed/test_distributed_dtensor_sinfo.py | 7 +-
.../relax/test_analysis_struct_info_analysis.py | 3 +-
tests/python/relax/test_dataflow_pattern.py | 9 +-
...eliminate_pad_branch_using_buffer_assumption.py | 8 +-
tests/python/relax/test_expr.py | 5 +-
.../python/relax/test_frontend_nn_extern_module.py | 3 +-
tests/python/relax/test_struct_info.py | 7 +-
tests/python/relax/test_transform_lambda_lift.py | 6 +-
.../relax/test_transform_meta_schedule_tuning.py | 6 +-
...st_transform_operator_specific_normalization.py | 13 +-
.../relax/test_transform_rewrite_cuda_graph.py | 3 +-
.../meta_schedule/test_meta_schedule_database.py | 29 +-
.../test_meta_schedule_post_order_apply.py | 7 +-
.../test_meta_schedule_tune_context.py | 3 +-
.../s_tir/schedule/test_tir_schedule_reduction.py | 5 +-
.../s_tir/schedule/test_tir_schedule_state.py | 19 +-
.../transform/test_s_tir_transform_hoist_if.py | 9 +-
tests/python/tirx-base/test_tir_constructor.py | 3 +-
.../tirx-base/test_tir_structural_equal_hash.py | 440 ---------------------
tests/python/tvmscript/test_tvmscript_roundtrip.py | 3 +-
48 files changed, 201 insertions(+), 870 deletions(-)
diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py
b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py
index ca59f7a8db..2f02c62dde 100644
--- a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py
+++ b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py
@@ -53,6 +53,7 @@ If not already acquainted, please refer to
:ref:`tirx-learning` initially.
# format of the ir_module and in TVMScript:
import numpy as np
+import tvm_ffi
import tvm
from tvm.script import ir as I
@@ -126,7 +127,7 @@ class ConciseModule:
######################################################################
# We can use the following code to verify that the two modules are equivalent:
-print(tvm.ir.structural_equal(MyModule, ConciseModule))
+print(tvm_ffi.structural_equal(MyModule, ConciseModule))
######################################################################
# Interactive with Python Variables
@@ -165,7 +166,7 @@ class ConciseModuleFromPython:
######################################################################
# Check the equivalence:
-print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython))
+print(tvm_ffi.structural_equal(ConciseModule, ConciseModuleFromPython))
######################################################################
diff --git a/docs/reference/security.rst b/docs/reference/security.rst
index de3ebf464d..c044c36bb1 100644
--- a/docs/reference/security.rst
+++ b/docs/reference/security.rst
@@ -58,7 +58,7 @@ Subroutine Cache Hash Collision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``SubroutineMixin._get_subroutine()`` in
``python/tvm/relax/frontend/nn/subroutine.py``
-used ``ir.structural_hash`` as the sole cache lookup key without a subsequent
+used ``tvm_ffi.structural_hash`` as the sole cache lookup key without a
subsequent
``structural_equal`` verification. If two different ``arg_sinfo`` values
produced the
same 64-bit hash, the cache would return a previously compiled function with
mismatched parameter shapes, leading to silently incorrect compiled output.
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 50073a942a..60e11fbf56 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -29,8 +29,6 @@ from .base import (
assert_structural_equal,
load_json,
save_json,
- structural_equal,
- structural_hash,
)
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelaxExpr
from .function import BaseFunc, CallingConv
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index ceccb401f4..6bae30f791 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -162,81 +162,6 @@ def save_json(node) -> str:
return to_json_graph_str(node, {"tvm_version": __version__})
-def structural_equal(lhs, rhs, map_free_vars=False):
- """Check structural equality of lhs and rhs.
-
- The structural equality is recursively defined in the DAG of IRNodes.
- There are two kinds of nodes:
-
- - Graph node: a graph node in lhs can only be mapped as equal to
- one and only one graph node in rhs.
- - Normal node: equality is recursively defined without the restriction
- of graph nodes.
-
- Vars(tirx::Var, relax::Var) are graph nodes.
-
- A var-type node(e.g. tirx::Var) can be mapped as equal to another var
- with the same type if one of the following condition holds:
-
- - They appear in a same definition point(e.g. function argument).
- - They points to the same VarNode via the same_as relation.
- - They appear in a same usage point, and map_free_vars is set to be True.
-
- The rules for var are used to remap variables occurs in function
- arguments and let-bindings.
-
- Parameters
- ----------
- lhs : Object
- The left operand.
-
- rhs : Object
- The left operand.
-
- map_free_vars : bool
- Whether free variables (i.e. variables without a definition site)
should be mapped
- as equal to each other.
-
- Return
- ------
- result : bool
- The comparison result.
-
- See Also
- --------
- structural_hash
- assert_strucural_equal
- """
- return tvm_ffi.structural_equal(lhs, rhs, map_free_vars)
-
-
-def get_first_structural_mismatch(lhs, rhs, map_free_vars=False,
skip_tensor_content=False):
- """Like structural_equal(), but returns the AccessPath pair of the first
detected mismatch.
-
- Parameters
- ----------
- lhs : Object
- The left operand.
-
- rhs : Object
- The left operand.
-
- map_free_vars : bool
- Whether free variables (i.e. variables without a definition site)
should be mapped
- as equal to each other.
-
- skip_tensor_content : bool
- Whether to skip the content of ndarray.
-
- Returns
- -------
- mismatch: Optional[Tuple[AccessPath, AccessPath]]
- `None` if `lhs` and `rhs` are structurally equal.
- Otherwise, a tuple of two AccessPath objects that point to the first
detected mismtach.
- """
- return tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars,
skip_tensor_content)
-
-
def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
@@ -258,7 +183,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
See Also
--------
- structural_equal
+ tvm_ffi.structural_equal
"""
first_mismatch = tvm_ffi.get_first_structural_mismatch(lhs, rhs,
map_free_vars)
if first_mismatch is not None:
@@ -278,48 +203,6 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
)
-def structural_hash(node, map_free_vars=False):
- """Compute structural hash of node
-
- The structural hash value is recursively defined in the DAG of IRNodes.
- There are two kinds of nodes:
-
- - Normal node: the hash value is defined by its content and type only.
- - Graph node: each graph node will be assigned a unique index ordered by
the
- first occurrence during the visit. The hash value of a graph node is
- combined from the hash values of its contents and the index.
-
- structural_hash is made to be concistent with structural_equal.
- If two nodes are structurally equal to each other,
- then their structural hash (with the same map_free_vars option)
- should be equal to each other as well.
-
- If the structural hash of two nodes equals to each other,
- then it is highly likely(except for rare hash value collison cases)
- that the two nodes are structurally equal to each other.
-
- Parameters
- ----------
- node : Object
- The input to be hashed.
-
- map_free_vars : bool
- If map_free_vars is set to true, we will hash free variables
- by the order of their occurrences. Otherwise, we will hash by
- their in-memory pointer address.
-
- Return
- ------
- result : int
- The hash result
-
- See Also
- --------
- structrual_equal
- """
- return tvm_ffi.structural_hash(node, map_free_vars)
-
-
def deprecated(
method_name: str,
new_method_name: str,
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 3dab9d02d5..c2107c9a8f 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -167,7 +167,7 @@ class Range(Node, Scriptable):
return _ffi_api.Range_from_min_extent(min_value, extent, span)
def __eq__(self, other: Object) -> bool:
- return tvm.ir.structural_equal(self, other)
+ return tvm_ffi.structural_equal(self, other)
def __ne__(self, other: Object) -> bool:
return not self.__eq__(other)
diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py
index 14bdb76b08..c6754e747a 100644
--- a/python/tvm/ir/global_info.py
+++ b/python/tvm/ir/global_info.py
@@ -30,7 +30,7 @@ class GlobalInfo(Object):
def __eq__(self, other):
"""Compare two struct info for structural equivalence."""
- return tvm.ir.structural_equal(self, other)
+ return tvm_ffi.structural_equal(self, other)
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index 8d2a71d0a9..3ade4b80fc 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -18,7 +18,6 @@
import tvm_ffi
-import tvm
from tvm.runtime import Scriptable
from . import _ffi_api
@@ -31,7 +30,7 @@ class Type(Node, Scriptable):
def __eq__(self, other):
"""Compare two types for structural equivalence."""
- return bool(tvm.ir.structural_equal(self, other))
+ return bool(tvm_ffi.structural_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 6dffaab8f4..730febc51c 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -69,7 +69,7 @@ class StructInfo(Node, Scriptable):
def __eq__(self, other):
"""Compare two struct info for structural equivalence."""
- return tvm.ir.structural_equal(self, other)
+ return tvm_ffi.structural_equal(self, other)
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/python/tvm/relax/frontend/nn/subroutine.py
b/python/tvm/relax/frontend/nn/subroutine.py
index abd94b19cb..e821756e8d 100644
--- a/python/tvm/relax/frontend/nn/subroutine.py
+++ b/python/tvm/relax/frontend/nn/subroutine.py
@@ -27,7 +27,6 @@ import typing
import tvm_ffi
from tvm import ir, relax
-from tvm.ir import structural_equal
from tvm.relax.frontend import nn
@@ -144,10 +143,14 @@ class SubroutineMixin:
arg_sinfo = _get_struct_info([*func_args.values(), *model_params])
is_dataflow = block_builder.current_block_is_dataflow()
- lookup_key = (old_forward, ir.structural_hash(arg_sinfo,
map_free_vars=True), is_dataflow)
+ lookup_key = (
+ old_forward,
+ tvm_ffi.structural_hash(arg_sinfo, map_free_vars=True),
+ is_dataflow,
+ )
for cached_sinfo, cached_result in cls._gvar.get(lookup_key, []):
- if structural_equal(cached_sinfo, arg_sinfo, map_free_vars=True):
+ if tvm_ffi.structural_equal(cached_sinfo, arg_sinfo,
map_free_vars=True):
return cached_result
func_name = _camel_to_snake(cls.__name__)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 1a224e431b..b82fceff1d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -47,6 +47,7 @@ from typing import Any
import numpy as _np
import onnx.onnx_ml_pb2
+import tvm_ffi
import tvm
from tvm import TVMError, relax, tirx, topi
@@ -661,7 +662,7 @@ class Equal(OnnxOpConverter):
rhs = get_prim_expr_list(inputs[1])
if len(lhs) != len(rhs):
raise ValueError("Cannot compare two tensors with different
shapes")
- output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)]
+ output = [tvm_ffi.structural_equal(l, r) for l, r in zip(lhs, rhs)]
return relax.const(output, "bool")
return relax.op.equal(inputs[0], inputs[1])
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 581475ebd8..91b6a3a171 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -26,7 +26,8 @@ import operator
from collections.abc import Callable
from functools import reduce
-import tvm
+import tvm_ffi
+
from tvm import relax, tirx
@@ -2537,7 +2538,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
data_shape = self.shape_of(data)
mask_shape = self.shape_of(mask)
- shapes_equal = tvm.ir.structural_equal(data_shape, mask_shape)
+ shapes_equal = tvm_ffi.structural_equal(data_shape, mask_shape)
if not shapes_equal:
mask = self.block_builder.emit(relax.op.broadcast_to(mask,
data_shape))
diff --git a/python/tvm/relax/script/parser/parser.py
b/python/tvm/relax/script/parser/parser.py
index 47daf17d35..256b867a10 100644
--- a/python/tvm/relax/script/parser/parser.py
+++ b/python/tvm/relax/script/parser/parser.py
@@ -20,8 +20,10 @@ import functools
import numbers
from typing import Any
+import tvm_ffi
+
from tvm import relax, tirx
-from tvm.ir import GlobalVar, structural_equal
+from tvm.ir import GlobalVar
from tvm.relax import Expr, StructInfo
from tvm.relax.script import builder as R
from tvm.relax.script.builder.frame import BindingBlockFrame
@@ -87,7 +89,7 @@ def bind_assign_value(
if isinstance(value, relax.Expr):
var = R.emit(value, anno_sinfo)
elif isinstance(value, MatchCastPair):
- if anno_sinfo is not None and not structural_equal(anno_sinfo,
value.struct_info):
+ if anno_sinfo is not None and not tvm_ffi.structural_equal(anno_sinfo,
value.struct_info):
self.report_error(
node, "Cannot specify inconsistent annotation for a match cast
pair. "
)
diff --git a/python/tvm/relax/transform/optimize_layout_transform.py
b/python/tvm/relax/transform/optimize_layout_transform.py
index 53e87f5713..7dd071dd7e 100644
--- a/python/tvm/relax/transform/optimize_layout_transform.py
+++ b/python/tvm/relax/transform/optimize_layout_transform.py
@@ -17,7 +17,8 @@
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Relax Optimize Layout Transform pass."""
-from tvm.ir import structural_equal
+import tvm_ffi
+
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.relax import Expr
@@ -78,7 +79,7 @@ class OptimizeLayoutTransform:
if "remove_pad" == self.mod[arg2].attrs["operator_name"]:
arg2 = matches[self.input]
if hasattr(arg1.struct_info, "shape") and
hasattr(arg2.struct_info, "shape"):
- if structural_equal(arg1.struct_info.shape,
arg2.struct_info.shape):
+ if tvm_ffi.structural_equal(arg1.struct_info.shape,
arg2.struct_info.shape):
return arg2
return expr
diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py
b/python/tvm/relax/transform/remove_redundant_reshape.py
index 5ceb4bf852..11119f5e8c 100644
--- a/python/tvm/relax/transform/remove_redundant_reshape.py
+++ b/python/tvm/relax/transform/remove_redundant_reshape.py
@@ -17,8 +17,9 @@
# pylint: disable=invalid-name, unused-argument, missing-function-docstring,
abstract-method
"""Relax Remove Redundant Reshape ops"""
+import tvm_ffi
+
from tvm import IRModule, relax
-from tvm.ir import structural_equal
from tvm.ir.transform import PassContext
from tvm.relax import Expr
from tvm.relax.dpl import is_op, rewrite_call, wildcard
@@ -73,7 +74,9 @@ class RemoveRedundantReshape:
elif self.no_op_reshape in matches:
output_shape = matches[self.no_op_reshape].args[1]
- if arg.struct_info.shape and
structural_equal(arg.struct_info.shape, output_shape):
+ if arg.struct_info.shape and tvm_ffi.structural_equal(
+ arg.struct_info.shape, output_shape
+ ):
return arg
return expr
diff --git a/python/tvm/s_tir/dlight/analysis/gemv.py
b/python/tvm/s_tir/dlight/analysis/gemv.py
index 75d5b17dfd..a9c8cb82e6 100644
--- a/python/tvm/s_tir/dlight/analysis/gemv.py
+++ b/python/tvm/s_tir/dlight/analysis/gemv.py
@@ -16,7 +16,9 @@
# under the License.
"""Analysis for GEMV."""
-from tvm import arith, ir, s_tir, tirx
+import tvm_ffi
+
+from tvm import arith, s_tir, tirx
from .common_analysis import (
SBlockInfo,
@@ -48,7 +50,7 @@ def get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr |
None:
return None
if not isinstance(buffer_store.value, tirx.Add):
return None
- if not ir.structural_equal(
+ if not tvm_ffi.structural_equal(
buffer_store.value.a,
tirx.BufferLoad(buffer_store.buffer, block.body.indices),
map_free_vars=True,
diff --git a/python/tvm/s_tir/dlight/benchmark/extract.py
b/python/tvm/s_tir/dlight/benchmark/extract.py
index 33d9b4402b..ee7358efcf 100644
--- a/python/tvm/s_tir/dlight/benchmark/extract.py
+++ b/python/tvm/s_tir/dlight/benchmark/extract.py
@@ -19,6 +19,7 @@
from pathlib import Path
import cloudpickle
+import tvm_ffi
import tvm
from tvm import relax
@@ -294,7 +295,7 @@ def extract_prim_func( # pylint: disable=too-many-arguments
"model_name": model_name,
"relax_func_name": relax_func_name,
"prim_func_name": prim_func_name,
- "func_hash": tvm.ir.structural_hash(func),
+ "func_hash": tvm_ffi.structural_hash(func),
"weight": weight,
"sample_number": sample_number,
"dym_var_dict": f"pickle.loads({cloudpickle.dumps(dym_var_dict)})"
diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
index 15a9f8f506..74a1d8923a 100644
--- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
@@ -20,7 +20,9 @@
from functools import reduce
from typing import Literal
-from tvm import arith, ir, s_tir, tirx
+import tvm_ffi
+
+from tvm import arith, s_tir, tirx
from tvm.target import Target
from ..analysis import (
@@ -42,7 +44,7 @@ def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr
| None:
return None
if not isinstance(buffer_store.value, tirx.Add):
return None
- if not ir.structural_equal(
+ if not tvm_ffi.structural_equal(
buffer_store.value.a,
tirx.BufferLoad(buffer_store.buffer, block.body.indices),
map_free_vars=True,
diff --git a/python/tvm/s_tir/dlight/gpu/reduction.py
b/python/tvm/s_tir/dlight/gpu/reduction.py
index af310c25c5..cb96651347 100644
--- a/python/tvm/s_tir/dlight/gpu/reduction.py
+++ b/python/tvm/s_tir/dlight/gpu/reduction.py
@@ -19,7 +19,9 @@
# TODO: combine reduction rule and general reduction rule into one file.
from collections.abc import Mapping
-from tvm import arith, ir, s_tir, tirx
+import tvm_ffi
+
+from tvm import arith, s_tir, tirx
from tvm.target import Target
from ..analysis import (
@@ -39,7 +41,7 @@ def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr
| None:
return None
if not isinstance(buffer_store.value, tirx.Add):
return None
- if not ir.structural_equal(
+ if not tvm_ffi.structural_equal(
buffer_store.value.a,
tirx.BufferLoad(buffer_store.buffer, block.body.indices),
map_free_vars=True,
diff --git a/python/tvm/s_tir/meta_schedule/testing/space_generation.py
b/python/tvm/s_tir/meta_schedule/testing/space_generation.py
index b2a5046f65..877daf423b 100644
--- a/python/tvm/s_tir/meta_schedule/testing/space_generation.py
+++ b/python/tvm/s_tir/meta_schedule/testing/space_generation.py
@@ -21,7 +21,9 @@ from typing import Literal
# isort: on
-from tvm.ir import IRModule, structural_equal
+import tvm_ffi
+
+from tvm.ir import IRModule
from tvm.s_tir import Schedule
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.schedule import Trace
@@ -51,7 +53,7 @@ def structural_equal_no_gs(mod1: IRModule, mod2: IRModule) ->
bool:
stripped_mod[global_var] = func.without_attr("global_symbol")
return stripped_mod
- return structural_equal(remove_global_symbols(mod1),
remove_global_symbols(mod2))
+ return tvm_ffi.structural_equal(remove_global_symbols(mod1),
remove_global_symbols(mod2))
def generate_design_space(
diff --git a/python/tvm/s_tir/meta_schedule/testing/validate_database.py
b/python/tvm/s_tir/meta_schedule/testing/validate_database.py
index f266e6ac3a..125c88be51 100644
--- a/python/tvm/s_tir/meta_schedule/testing/validate_database.py
+++ b/python/tvm/s_tir/meta_schedule/testing/validate_database.py
@@ -26,6 +26,7 @@ from statistics import mean
from typing import Any
import numpy as np # type: ignore
+import tvm_ffi
from tvm_ffi import get_global_func, register_global_func
import tvm
@@ -197,10 +198,10 @@ class OriginalModule:
self.mod = mod
def __eq__(self, __o: "OriginalModule") -> bool: # type: ignore
- return tvm.ir.structural_equal(self.mod, __o.mod)
+ return tvm_ffi.structural_equal(self.mod, __o.mod)
def __hash__(self) -> int:
- return tvm.ir.structural_hash(self.mod)
+ return tvm_ffi.structural_hash(self.mod)
def initializer() -> None:
diff --git a/python/tvm/tirx/analysis/analysis.py
b/python/tvm/tirx/analysis/analysis.py
index 6350eee7b5..bb89e9845d 100644
--- a/python/tvm/tirx/analysis/analysis.py
+++ b/python/tvm/tirx/analysis/analysis.py
@@ -48,18 +48,18 @@ def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
This function does not remap variable bindings, it will not
return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless
x.same_as(y).
- Use py:func:`tvm.ir.structural_equal` to handle structural variable
remapping.
+ Use py:func:`tvm_ffi.structural_equal` to handle structural variable
remapping.
Due to the restriction of not remapping variables, this function can run
faster than StructuralEqual and can be used as a utility function during
arithmetic
simplifications.
- Always consider py:func:`tvm.ir.structural_equal` first, which handles
+ Always consider py:func:`tvm_ffi.structural_equal` first, which handles
the structural remapping.
See Also
--------
- tvm.ir.structural_equal
+ tvm_ffi.structural_equal
"""
return _ffi_api.expr_deep_equal(lhs, rhs) # type: ignore
diff --git a/tests/python/arith/test_arith_canonical_simplify.py
b/tests/python/arith/test_arith_canonical_simplify.py
index ce89db9c99..35ecf3b700 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E731, F841
+import tvm_ffi
+
import tvm
import tvm.testing
from tvm import te, tirx
@@ -38,7 +40,7 @@ class CanonicalChecker:
def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
expected = self._convert(expected)
- assert tvm.ir.structural_equal(res, expected), (
+ assert tvm_ffi.structural_equal(res, expected), (
f"\ndata={data}\nres={res}\nexpected={expected}"
)
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index 17a8397ce9..31c944e179 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -19,6 +19,7 @@
import inspect
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -81,7 +82,7 @@ class BaseCompare:
with analyzer.constraint_scope(test_case.constraint):
after = analyzer.rewrite_simplify(test_case.before)
- assert tvm.ir.structural_equal(after, test_case.expected), (
+ assert tvm_ffi.structural_equal(after, test_case.expected), (
f"Rewrite didn't match expected.\n"
f"Before = {test_case.before}\n"
f"After = {after}\n"
diff --git a/tests/python/arith/test_arith_solve_linear_equations.py
b/tests/python/arith/test_arith_solve_linear_equations.py
index d1218fc351..b9550d570d 100644
--- a/tests/python/arith/test_arith_solve_linear_equations.py
+++ b/tests/python/arith/test_arith_solve_linear_equations.py
@@ -19,6 +19,7 @@ import random
import sys
import pytest
+import tvm_ffi
import tvm
from tvm import arith, ir, testing, tirx
@@ -98,8 +99,8 @@ def test_empty_var_to_solve():
assert len(solution.dst_to_src) == 0
assert len(solution.src.variables) == 0
assert len(solution.src.ranges) == 0
- assert ir.structural_equal(solution.src.relations, equations)
- assert ir.structural_equal(solution.src, solution.dst)
+ assert tvm_ffi.structural_equal(solution.src.relations, equations)
+ assert tvm_ffi.structural_equal(solution.src, solution.dst)
def test_unique_solution():
@@ -113,8 +114,8 @@ def test_unique_solution():
[x, y],
)
assert list(solution.dst.variables) == []
- assert ir.structural_equal(solution.src_to_dst[x], T.int32(15))
- assert ir.structural_equal(solution.src_to_dst[y], T.int32(5))
+ assert tvm_ffi.structural_equal(solution.src_to_dst[x], T.int32(15))
+ assert tvm_ffi.structural_equal(solution.src_to_dst[y], T.int32(5))
def test_low_rank():
@@ -130,9 +131,9 @@ def test_low_rank():
ranges,
)
[n0] = solution.dst.variables
- assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
- assert ir.structural_equal(solution.src_to_dst[y], -n0)
- assert ir.structural_equal(solution.src_to_dst[z], T.int32(5))
+ assert tvm_ffi.structural_equal(solution.src_to_dst[x], n0 + 10)
+ assert tvm_ffi.structural_equal(solution.src_to_dst[y], -n0)
+ assert tvm_ffi.structural_equal(solution.src_to_dst[z], T.int32(5))
def test_infer_range():
@@ -150,16 +151,16 @@ def test_infer_range():
ranges,
)
[n0] = solution.dst.variables
- assert ir.structural_equal(solution.src_to_dst[x], n0)
- assert ir.structural_equal(solution.src_to_dst[y], -n0)
+ assert tvm_ffi.structural_equal(solution.src_to_dst[x], n0)
+ assert tvm_ffi.structural_equal(solution.src_to_dst[y], -n0)
# inferred from y's range
- assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9))
- assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10))
+ assert tvm_ffi.structural_equal(solution.dst.ranges[n0].min, T.int32(-9))
+ assert tvm_ffi.structural_equal(solution.dst.ranges[n0].extent,
T.int32(10))
# additional inequality is added into the system for x
[ineq] = solution.dst.relations
assert isinstance(ineq, tvm.tirx.LE)
- assert ir.structural_equal(ineq.a, T.int32(-5))
- assert ir.structural_equal(ineq.b, n0)
+ assert tvm_ffi.structural_equal(ineq.a, T.int32(-5))
+ assert tvm_ffi.structural_equal(ineq.b, n0)
def test_ill_formed():
diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py
b/tests/python/arith/test_arith_solve_linear_inequality.py
index 8050f73b44..04b109fef3 100644
--- a/tests/python/arith/test_arith_solve_linear_inequality.py
+++ b/tests/python/arith/test_arith_solve_linear_inequality.py
@@ -18,6 +18,7 @@ import random
import sys
import pytest
+import tvm_ffi
import tvm
from tvm import arith, ir, testing, tirx
@@ -99,10 +100,10 @@ def test_dual_variable():
# solution as conditions
solution = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges,
problem)
- assert ir.structural_equal(solution[0], x >= (y + 10))
- assert ir.structural_equal(solution[1], x <= (20 - y))
- assert ir.structural_equal(solution[2], y >= 0)
- assert ir.structural_equal(solution[3], y <= 5)
+ assert tvm_ffi.structural_equal(solution[0], x >= (y + 10))
+ assert tvm_ffi.structural_equal(solution[1], x <= (20 - y))
+ assert tvm_ffi.structural_equal(solution[2], y >= 0)
+ assert tvm_ffi.structural_equal(solution[3], y <= 5)
# solve and get the ranges
solution = arith.solve_linear_inequalities(problem, variables, ranges)
@@ -110,22 +111,22 @@ def test_dual_variable():
assert solution.ranges[y].min == 0
assert solution.ranges[y].extent == 6
# y + 10 <= x <= 20 - y
- assert ir.structural_equal(solution.ranges[x].min, y + 10)
+ assert tvm_ffi.structural_equal(solution.ranges[x].min, y + 10)
assert solution.ranges[x].extent == 11 # max(10 - 2y)
# deskew the solved ranges to be starting from zero
solution = arith.solve_linear_inequalities(problem, variables, ranges,
deskew_range=True)
[x_new, y_new] = solution.dst.variables
[rel] = solution.dst.relations
- assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10)
- assert ir.structural_equal(solution.dst.ranges[x_new].min, T.int32(0))
- assert ir.structural_equal(solution.dst.ranges[x_new].extent, T.int32(11))
- assert ir.structural_equal(solution.dst.ranges[y_new].min, T.int32(0))
- assert ir.structural_equal(solution.dst.ranges[y_new].extent, T.int32(6))
- assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10))
- assert ir.structural_equal(solution.src_to_dst[y], y_new)
- assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10)
- assert ir.structural_equal(solution.dst_to_src[y_new], y)
+ assert tvm_ffi.structural_equal(rel, (y_new * 2) + x_new <= 10)
+ assert tvm_ffi.structural_equal(solution.dst.ranges[x_new].min, T.int32(0))
+ assert tvm_ffi.structural_equal(solution.dst.ranges[x_new].extent,
T.int32(11))
+ assert tvm_ffi.structural_equal(solution.dst.ranges[y_new].min, T.int32(0))
+ assert tvm_ffi.structural_equal(solution.dst.ranges[y_new].extent,
T.int32(6))
+ assert tvm_ffi.structural_equal(solution.src_to_dst[x], x_new + (y_new +
10))
+ assert tvm_ffi.structural_equal(solution.src_to_dst[y], y_new)
+ assert tvm_ffi.structural_equal(solution.dst_to_src[x_new], x - y - 10)
+ assert tvm_ffi.structural_equal(solution.dst_to_src[y_new], y)
def test_equal():
@@ -163,7 +164,7 @@ def test_multi_equal():
assert solution.ranges[x].min == 6
assert solution.ranges[x].extent == 1
assert len(solution.relations) == 3
- assert ir.structural_equal(solution.relations[0], x == z * y)
+ assert tvm_ffi.structural_equal(solution.relations[0], x == z * y)
assert isinstance(solution.relations[1], tvm.tirx.LE)
assert solution.relations[1].b == 0
@@ -172,9 +173,9 @@ def test_multi_equal():
# (z*y - 6) <= 0 && (6 - z*y) <= 0
ana = tvm.arith.Analyzer()
assert ana.simplify(solution.relations[1].a + solution.relations[2].a) == 0
- assert ir.structural_equal(solution.relations[1].a, (z * y - 6)) or
ir.structural_equal(
- solution.relations[2].a, (z * y - 6)
- )
+ assert tvm_ffi.structural_equal(
+ solution.relations[1].a, (z * y - 6)
+ ) or tvm_ffi.structural_equal(solution.relations[2].a, (z * y - 6))
solution = arith.solve_linear_inequalities(problem, [x, y, z],
deskew_range=True)
assert solution.src_to_dst[y] == y
diff --git a/tests/python/contrib/test_hexagon/test_take.py
b/tests/python/contrib/test_hexagon/test_take.py
index 04debadacc..4d54c89ce7 100644
--- a/tests/python/contrib/test_hexagon/test_take.py
+++ b/tests/python/contrib/test_hexagon/test_take.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-docstring, invalid-name, unused-argument,
not-callable
import numpy as np
+import tvm_ffi
from scipy import special
import tvm
@@ -388,5 +389,5 @@ def test_structural():
]
for mod in Modules:
after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(mod)
- assert not tvm.ir.structural_equal(after["main"], mod["main"])
+ assert not tvm_ffi.structural_equal(after["main"], mod["main"])
print("Passed Structural")
diff --git a/tests/python/ir/test_container_structural_equal.py
b/tests/python/ir/test_container_structural_equal.py
deleted file mode 100644
index 1d9d575af8..0000000000
--- a/tests/python/ir/test_container_structural_equal.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-import tvm_ffi
-from tvm_ffi.access_path import AccessPath
-
-import tvm
-import tvm.testing
-from tvm.ir.base import get_first_structural_mismatch
-
-
-def get_first_mismatch_ensure_symmetry(a, b):
- mismatch = get_first_structural_mismatch(a, b)
- mismatch_swapped = get_first_structural_mismatch(b, a)
-
- if mismatch is None and mismatch_swapped is None:
- return None
-
- if (
- mismatch is None
- or mismatch_swapped is None
- or mismatch[0] != mismatch_swapped[1]
- or mismatch[1] != mismatch_swapped[0]
- ):
- raise AssertionError(
- "get_first_structural_mismatch(a, b) and
get_first_structural_mismatch(b, a) returned"
- f" inconsistent results '{mismatch}' and '{mismatch_swapped}' for
a='{a}', b='{b}'"
- )
-
- a_path, b_path = mismatch
- b_path_swapped, a_path_swapped = mismatch_swapped
- assert a_path == a_path_swapped
- assert b_path == b_path_swapped
-
- return mismatch
-
-
-@pytest.mark.parametrize(
- "a, b, expected_a_path, expected_b_path",
- [
- (
- [1, 2, 3],
- [1, 4, 3],
- AccessPath.root().array_item(1),
- AccessPath.root().array_item(1),
- ),
- (
- [1, 2, 3],
- [10, 2, 30],
- AccessPath.root().array_item(0),
- AccessPath.root().array_item(0),
- ),
- (
- [1, 3, 4],
- [1, 2, 3, 4],
- AccessPath.root().array_item(1),
- AccessPath.root().array_item(1),
- ),
- (
- [1, 2, 3],
- [1, 2, 3, 4],
- AccessPath.root().array_item_missing(3),
- AccessPath.root().array_item(3),
- ),
- (
- [],
- [1],
- AccessPath.root().array_item_missing(0),
- AccessPath.root().array_item(0),
- ),
- ],
-)
-def test_array_structural_mismatch(a, b, expected_a_path, expected_b_path):
- a = tvm.runtime.convert(a)
- b = tvm.runtime.convert(b)
- a_path, b_path = get_first_mismatch_ensure_symmetry(a, b)
- assert a_path == expected_a_path
- assert b_path == expected_b_path
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- [],
- [1],
- [1, 2, 3],
- ],
-)
-def test_array_structural_equal_to_self(contents):
- a = tvm.runtime.convert(list(contents))
- b = tvm.runtime.convert(list(contents))
- assert get_first_mismatch_ensure_symmetry(a, b) is None
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- [],
- [1],
- [1, 2, 3],
- ],
-)
-def test_shape_tuple_structural_equal_to_self(contents):
- a = tvm_ffi.Shape(list(contents))
- b = tvm_ffi.Shape(list(contents))
- assert get_first_mismatch_ensure_symmetry(a, b) is None
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- {},
- {"a": 1, "b": 2},
- {"a": True, "b": False},
- ],
-)
-def test_string_map_structural_equal_to_self(contents):
- a = tvm.runtime.convert({**contents})
- b = tvm.runtime.convert({**contents})
- assert get_first_mismatch_ensure_symmetry(a, b) is None
-
-
-@pytest.mark.parametrize(
- "a, b, expected_a_path, expected_b_path",
- [
- (
- dict(a=3, b=4),
- dict(a=3, b=5),
- AccessPath.root().map_item("b"),
- AccessPath.root().map_item("b"),
- ),
- (
- dict(a=3, b=4),
- dict(a=3, b=4, c=5),
- AccessPath.root().map_item_missing("c"),
- AccessPath.root().map_item("c"),
- ),
- ],
-)
-def test_string_map_structural_mismatch(a, b, expected_a_path,
expected_b_path):
- a = tvm.runtime.convert(a)
- b = tvm.runtime.convert(b)
- a_path, b_path = get_first_mismatch_ensure_symmetry(a, b)
- assert a_path == expected_a_path
- assert b_path == expected_b_path
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- dict(),
- dict(a=1),
- dict(a=3, b=4, c=5),
- ],
-)
-def test_string_structural_equal_to_self(contents):
- a = tvm.runtime.convert(dict(contents))
- b = tvm.runtime.convert(dict(contents))
- assert get_first_mismatch_ensure_symmetry(a, b) is None
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py
index 7074215505..25480f7265 100644
--- a/tests/python/ir/test_ir_attrs.py
+++ b/tests/python/ir/test_ir_attrs.py
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: F841
+import pytest
+import tvm_ffi
+
import tvm
@@ -36,9 +39,22 @@ def test_attrs_equal():
dattr1 = tvm.ir.make_node("ir.DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("ir.DictAttrs", x=1, y=None)
tvm.ir.assert_structural_equal(dattr0, dattr1)
- assert not tvm.ir.structural_equal(dattr0, dattr2)
- assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
- assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
+ assert not tvm_ffi.structural_equal(dattr0, dattr2)
+ assert not tvm_ffi.structural_equal({"x": 1}, tvm.runtime.convert(1))
+ assert not tvm_ffi.structural_equal([1, 2], tvm.runtime.convert(1))
+
+
+def test_assert_structural_equal_reports_mismatch():
+ dattr0 = tvm.ir.make_node("ir.DictAttrs", x=1, y=[10, 20])
+ dattr1 = tvm.ir.make_node("ir.DictAttrs", x=1, y=[10, 30])
+
+ with pytest.raises(ValueError) as err:
+ tvm.ir.assert_structural_equal(dattr0, dattr1)
+
+ message = str(err.value)
+ assert "StructuralEqual check failed" in message
+ assert "caused by lhs at" in message
+ assert "and rhs at" in message
if __name__ == "__main__":
diff --git a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
index c9b38aeaa6..1bac08412a 100644
--- a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
+++ b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
@@ -17,20 +17,21 @@
# ruff: noqa: F401
import pytest
+import tvm_ffi
import tvm
import tvm.testing
from tvm import TVMError, tirx
from tvm import relax as rx
-from tvm.ir import Range, structural_equal
+from tvm.ir import Range
def _check_equal(x, y, map_free_vars=False):
tvm.ir.assert_structural_equal(x, y, map_free_vars)
tvm.ir.assert_structural_equal(y, x, map_free_vars)
- xhash = tvm.ir.structural_hash(x, map_free_vars)
- yhash = tvm.ir.structural_hash(y, map_free_vars)
+ xhash = tvm_ffi.structural_hash(x, map_free_vars)
+ yhash = tvm_ffi.structural_hash(y, map_free_vars)
assert xhash == yhash
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py
b/tests/python/relax/test_analysis_struct_info_analysis.py
index dbcc94db83..e2141bf94d 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -19,6 +19,7 @@
"""Tests analysis functions of struct info"""
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -708,7 +709,7 @@ def test_prim_struct_info_lca(test_case):
lhs, rhs, expected = map(_normalize_sinfo, test_case)
lca = rx.analysis.struct_info_lca(lhs, rhs)
- assert tvm.ir.structural_equal(lca, expected), (
+ assert tvm_ffi.structural_equal(lca, expected), (
f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead
found {lca}"
)
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index a647100cae..303557e81a 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -20,6 +20,7 @@ import functools
import math
import pytest
+import tvm_ffi
import tvm.testing
from tvm import relax as rx
@@ -239,7 +240,7 @@ def test_shape_pattern():
shape = [32, 32]
pattern = wildcard().has_shape(shape)
assert isinstance(pattern, ShapePattern)
- tvm.ir.structural_equal(pattern.shape, shape)
+ tvm_ffi.structural_equal(pattern.shape, shape)
assert pattern.match(bindings[0].var)
assert wildcard().has_shape([32, 32]).match(bindings[0].var)
n, m = tirx.Var("n", dtype="int64"), tirx.Var("m", dtype="int64")
@@ -1478,7 +1479,7 @@ def
test_rewrite_without_trivial_binding(bind_to_dataflow_var):
arg = matches[pattern_arg]
shape_expr = matches[pattern_shape_expr]
- if tvm.ir.structural_equal(arg.struct_info.shape, shape_expr):
+ if tvm_ffi.structural_equal(arg.struct_info.shape, shape_expr):
return arg
else:
return expr
@@ -1755,7 +1756,9 @@ def test_iterative_rewrite_with_removed_intermediates():
if pat_unwrap_concat_split in matches:
args = matches[pat_args]
- if len(args) == 2 and tvm.ir.structural_equal(args[0].struct_info,
args[1].struct_info):
+ if len(args) == 2 and tvm_ffi.structural_equal(
+ args[0].struct_info, args[1].struct_info
+ ):
return args
elif pat_add_self in matches:
diff --git
a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
index 2c0d22bd3f..0666919799 100644
--- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
+++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
@@ -20,6 +20,8 @@
# The test attempts to eliminate redundant pad branch and overcompute the
value for elementwise ops.
# This helps to expose more opportunities to vectorize the code.
+import tvm_ffi
+
import tvm
import tvm.script
import tvm.testing
@@ -626,17 +628,17 @@ class MulExpected:
def test_add_primfunc_overcompute():
add_after = tvm.s_tir.transform.UseAssumeToReduceBranches()(AddBefore)
- tvm.ir.structural_equal(add_after["add"], AddExpected["add"],
map_free_vars=True)
+ tvm_ffi.structural_equal(add_after["add"], AddExpected["add"],
map_free_vars=True)
def test_sub_primfunc_overcompute():
sub_after = tvm.s_tir.transform.UseAssumeToReduceBranches()(SubBefore)
- tvm.ir.structural_equal(sub_after["sub"], SubExpected["sub"],
map_free_vars=True)
+ tvm_ffi.structural_equal(sub_after["sub"], SubExpected["sub"],
map_free_vars=True)
def test_mul_primfunc_overcompute():
mul_after = tvm.s_tir.transform.UseAssumeToReduceBranches()(MulBefore)
- tvm.ir.structural_equal(mul_after["mul"], MulExpected["mul"],
map_free_vars=True)
+ tvm_ffi.structural_equal(mul_after["mul"], MulExpected["mul"],
map_free_vars=True)
if __name__ == "__main__":
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index b9f12c2f4a..30a59ae30d 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -17,6 +17,7 @@
# ruff: noqa: F811
import numpy as np
import pytest
+import tvm_ffi
import tvm
from tvm import relax as rx
@@ -29,8 +30,8 @@ def _check_equal(x, y, map_free_vars=False):
tvm.ir.assert_structural_equal(x, y, map_free_vars)
tvm.ir.assert_structural_equal(y, x, map_free_vars)
- xhash = tvm.ir.structural_hash(x, map_free_vars)
- yhash = tvm.ir.structural_hash(y, map_free_vars)
+ xhash = tvm_ffi.structural_hash(x, map_free_vars)
+ yhash = tvm_ffi.structural_hash(y, map_free_vars)
assert xhash == yhash
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py
b/tests/python/relax/test_frontend_nn_extern_module.py
index 7f884f0dfb..a304766399 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -21,6 +21,7 @@ import tempfile
from pathlib import Path
import numpy as np
+import tvm_ffi
import tvm
import tvm.testing
@@ -40,7 +41,7 @@ def _infer_scalar_add(x, y): # pylint: disable=invalid-name
def _infer_test_sym(a, b): # pylint: disable=invalid-name
def _var_equal(a, b): # pylint: disable=invalid-name
- return tvm.ir.structural_equal(a, b, map_free_vars=True)
+ return tvm_ffi.structural_equal(a, b, map_free_vars=True)
assert isinstance(a, nn.Tensor)
assert isinstance(b, nn.Tensor)
diff --git a/tests/python/relax/test_struct_info.py
b/tests/python/relax/test_struct_info.py
index 31060f11b3..622f1e369b 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -16,6 +16,7 @@
# under the License.
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -27,8 +28,8 @@ def _check_equal(x, y, map_free_vars=False):
tvm.ir.assert_structural_equal(x, y, map_free_vars)
tvm.ir.assert_structural_equal(y, x, map_free_vars)
- xhash = tvm.ir.structural_hash(x, map_free_vars)
- yhash = tvm.ir.structural_hash(y, map_free_vars)
+ xhash = tvm_ffi.structural_hash(x, map_free_vars)
+ yhash = tvm_ffi.structural_hash(y, map_free_vars)
assert xhash == yhash
@@ -95,7 +96,7 @@ def test_prim_struct_info_with_expr():
sinfo = rx.PrimStructInfo(value=n + 1)
_check_equal(sinfo, rx.PrimStructInfo(value=n + 1))
- assert not tvm.ir.structural_equal(sinfo, rx.PrimStructInfo(dtype=n.dtype))
+ assert not tvm_ffi.structural_equal(sinfo,
rx.PrimStructInfo(dtype=n.dtype))
# can turn into str
str(sinfo)
diff --git a/tests/python/relax/test_transform_lambda_lift.py
b/tests/python/relax/test_transform_lambda_lift.py
index 2d3b91ec01..113bab4525 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -16,6 +16,8 @@
# under the License.
# ruff: noqa: F841
+import tvm_ffi
+
import tvm
import tvm.script
import tvm.testing
@@ -31,8 +33,8 @@ def _check_equal(x, y):
tvm.ir.assert_structural_equal(x, y)
tvm.ir.assert_structural_equal(y, x)
- xhash = tvm.ir.structural_hash(x, map_free_vars=True)
- yhash = tvm.ir.structural_hash(y, map_free_vars=True)
+ xhash = tvm_ffi.structural_hash(x, map_free_vars=True)
+ yhash = tvm_ffi.structural_hash(y, map_free_vars=True)
assert xhash == yhash
diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py
b/tests/python/relax/test_transform_meta_schedule_tuning.py
index d3d0992f47..65f04f2dc7 100644
--- a/tests/python/relax/test_transform_meta_schedule_tuning.py
+++ b/tests/python/relax/test_transform_meta_schedule_tuning.py
@@ -34,6 +34,8 @@
import tempfile
+import tvm_ffi
+
import tvm
import tvm.s_tir.meta_schedule as ms
import tvm.testing
@@ -114,7 +116,7 @@ def test_ms_tuning_irmodule():
application_pass =
relax.transform.MetaScheduleApplyDatabase(work_dir)
out_mod = application_pass(mod)
- assert not tvm.ir.structural_equal(mod, out_mod)
+ assert not tvm_ffi.structural_equal(mod, out_mod)
def test_ms_tuning_primfunc():
@@ -141,7 +143,7 @@ def test_ms_tuning_primfunc():
application_pass =
relax.transform.MetaScheduleApplyDatabase(work_dir)
out_mod = application_pass(mod)
- assert not tvm.ir.structural_equal(mod, out_mod)
+ assert not tvm_ffi.structural_equal(mod, out_mod)
with tempfile.TemporaryDirectory() as work_dir:
with target, PassContext(opt_level=0):
diff --git
a/tests/python/relax/test_transform_operator_specific_normalization.py
b/tests/python/relax/test_transform_operator_specific_normalization.py
index 8fd1c15f06..aa3c44d08c 100644
--- a/tests/python/relax/test_transform_operator_specific_normalization.py
+++ b/tests/python/relax/test_transform_operator_specific_normalization.py
@@ -19,6 +19,7 @@
"""Test FNormalize usage"""
import pytest
+import tvm_ffi
import tvm
import tvm.relax.testing.transform
@@ -112,7 +113,7 @@ def
test_normalization_applied_during_cpp_mutator(custom_op):
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
- assert not tvm.ir.structural_equal(Before, After)
+ assert not tvm_ffi.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
@@ -133,7 +134,7 @@ def
test_normalization_applied_during_python_mutator(custom_op):
after = EmptyPyExprMutator().visit_expr(before)
- assert not tvm.ir.structural_equal(before, after)
+ assert not tvm_ffi.structural_equal(before, after)
tvm.ir.assert_structural_equal(expected, after)
@@ -210,7 +211,7 @@ def test_normalize_to_inline_tuple_for_call_tir(custom_op):
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
- assert not tvm.ir.structural_equal(Before, After)
+ assert not tvm_ffi.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
@@ -256,7 +257,7 @@ def
test_normalize_argument_to_inline_tuple_for_call_tir(custom_op):
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
- assert not tvm.ir.structural_equal(Before, After)
+ assert not tvm_ffi.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
@@ -307,7 +308,7 @@ def
test_normalize_to_inline_tuple_for_call_tir_inplace(custom_op):
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
- assert not tvm.ir.structural_equal(Before, After)
+ assert not tvm_ffi.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
@@ -372,7 +373,7 @@ def
test_normalize_to_inline_tuple_for_call_tir_with_grad(custom_op):
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
- assert not tvm.ir.structural_equal(Before, After)
+ assert not tvm_ffi.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 341ba66025..80637edcc0 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -17,6 +17,7 @@
# ruff: noqa: E501, F841
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -703,7 +704,7 @@ def test_transform_is_no_op_when_disabled():
with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph":
False}):
AfterWhenDisabled = relax.transform.RewriteCUDAGraph()(Before)
- assert not tvm.ir.structural_equal(Before, AfterWhenEnabled)
+ assert not tvm_ffi.structural_equal(Before, AfterWhenEnabled)
tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
index ffe4945f68..645ae49966 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
@@ -24,6 +24,7 @@ from collections.abc import Callable
from typing import Optional
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -123,13 +124,13 @@ class PyMemoryDatabaseDefault(ms.database.PyDatabase):
def has_workload(self, mod: IRModule) -> bool:
for workload in self.workloads_:
- if tvm.ir.structural_equal(mod, workload.mod):
+ if tvm_ffi.structural_equal(mod, workload.mod):
return True
def commit_workload(self, mod: IRModule) -> ms.database.Workload:
if self.has_workload(mod):
for workload in self.workloads_:
- if tvm.ir.structural_equal(mod, workload.mod):
+ if tvm_ffi.structural_equal(mod, workload.mod):
return workload
else:
workload = ms.database.Workload(mod)
@@ -146,7 +147,7 @@ class PyMemoryDatabaseDefault(ms.database.PyDatabase):
return sorted(
list(
filter(
- lambda x: tvm.ir.structural_equal(workload.mod,
x.workload.mod),
+ lambda x: tvm_ffi.structural_equal(workload.mod,
x.workload.mod),
self.tuning_records_,
)
),
@@ -166,13 +167,13 @@ class PyMemoryDatabaseOverride(ms.database.PyDatabase):
def has_workload(self, mod: IRModule) -> bool:
for workload in self.workloads_:
- if tvm.ir.structural_equal(mod, workload.mod):
+ if tvm_ffi.structural_equal(mod, workload.mod):
return True
def commit_workload(self, mod: IRModule) -> ms.database.Workload:
if self.has_workload(mod):
for workload in self.workloads_:
- if tvm.ir.structural_equal(mod, workload.mod):
+ if tvm_ffi.structural_equal(mod, workload.mod):
return workload
else:
workload = ms.database.Workload(mod)
@@ -189,7 +190,7 @@ class PyMemoryDatabaseOverride(ms.database.PyDatabase):
return sorted(
list(
filter(
- lambda x: tvm.ir.structural_equal(workload.mod,
x.workload.mod),
+ lambda x: tvm_ffi.structural_equal(workload.mod,
x.workload.mod),
self.tuning_records_,
)
),
@@ -482,17 +483,17 @@ def test_meta_schedule_pydatabase_default_query():
record = query(db, mod, target, "record")
assert record is not None and record.run_secs[0].value == 1.0
sch_res = query(db, mod, target, "schedule")
- assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod,
sch.mod)
+ assert sch_res is not None and tvm_ffi.structural_equal(sch_res.mod,
sch.mod)
mod_res = query(db, mod, target, "ir_module")
- assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod)
+ assert mod_res is not None and tvm_ffi.structural_equal(mod_res, sch.mod)
commit_record(Schedule(mod).trace, db, 0.2) # Empty Trace
record = query(db, mod, target, "record")
assert record is not None and record.run_secs[0].value == 0.2
sch_res = query(db, mod, target, "schedule")
- assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, mod)
+ assert sch_res is not None and tvm_ffi.structural_equal(sch_res.mod, mod)
mod_res = query(db, mod, target, "ir_module")
- assert mod_res is not None and tvm.ir.structural_equal(mod_res, mod)
+ assert mod_res is not None and tvm_ffi.structural_equal(mod_res, mod)
def test_meta_schedule_pydatabase_override_query():
@@ -521,17 +522,17 @@ def test_meta_schedule_pydatabase_override_query():
record = query(db, mod, target, "record")
assert record is not None and record.run_secs[0].value == 1.14
sch_res = query(db, mod, target, "schedule")
- assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod,
sch.mod)
+ assert sch_res is not None and tvm_ffi.structural_equal(sch_res.mod,
sch.mod)
mod_res = query(db, mod, target, "ir_module")
- assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod)
+ assert mod_res is not None and tvm_ffi.structural_equal(mod_res, sch.mod)
commit_record(Schedule(mod).trace, db, 0.514) # Empty Trace
record = query(db, mod, target, "record")
assert record is not None and record.run_secs[0].value == 1.14 # Override
to 2nd best
sch_res = query(db, mod, target, "schedule")
- assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod,
sch.mod)
+ assert sch_res is not None and tvm_ffi.structural_equal(sch_res.mod,
sch.mod)
mod_res = query(db, mod, target, "ir_module")
- assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod)
+ assert mod_res is not None and tvm_ffi.structural_equal(mod_res, sch.mod)
def test_meta_schedule_pydatabase_current():
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
index 46d71ca6e7..1dee52eb57 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
@@ -21,6 +21,7 @@ import math
import sys
import pytest
+import tvm_ffi
from tvm_ffi import register_global_func
import tvm
@@ -255,7 +256,7 @@ def test_meta_schedule_post_order_apply():
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 1
- assert not tvm.ir.structural_equal(schs[0].mod, mod)
+ assert not tvm_ffi.structural_equal(schs[0].mod, mod)
_check_correct(schs[0])
@@ -275,7 +276,7 @@ def test_meta_schedule_post_order_apply_double():
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 2
for sch in schs:
- assert not tvm.ir.structural_equal(sch.mod, mod)
+ assert not tvm_ffi.structural_equal(sch.mod, mod)
_check_correct(sch)
@@ -295,7 +296,7 @@ def test_meta_schedule_post_order_apply_multiple():
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 4
for sch in schs:
- assert not tvm.ir.structural_equal(sch.mod, mod)
+ assert not tvm_ffi.structural_equal(sch.mod, mod)
_check_correct(sch)
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py
index 35d56a5fc9..d5ff427aa4 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py
@@ -20,6 +20,7 @@
import sys
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -55,7 +56,7 @@ def test_tune_context_create():
assert context.num_threads > 0
assert context.rand_state != -1
assert context.task_name == "Test Task"
- assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod)
+ assert context.mod == mod or tvm_ffi.structural_equal(context.mod, mod)
if __name__ == "__main__":
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_reduction.py
b/tests/python/s_tir/schedule/test_tir_schedule_reduction.py
index 4311d5785e..1643b13df0 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_reduction.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_reduction.py
@@ -19,6 +19,7 @@
import sys
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -293,12 +294,12 @@ def test_reduction_decompose_with_different_for_kind():
def test_decompose_reduction_ref_hash_check():
mod = tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main"))
mod_bak = mod
- hash_before = tvm.ir.structural_hash(mod_bak)
+ hash_before = tvm_ffi.structural_hash(mod_bak)
s = tvm.s_tir.Schedule(mod["main"], debug_mask="all")
C = s.get_sblock("update")
i, j, k = s.get_loops(C)
s.decompose_reduction(C, k)
- hash_after = tvm.ir.structural_hash(mod_bak)
+ hash_after = tvm_ffi.structural_hash(mod_bak)
assert hash_before == hash_after
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_state.py
b/tests/python/s_tir/schedule/test_tir_schedule_state.py
index 1738523468..43a4a84f5e 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_state.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_state.py
@@ -20,6 +20,7 @@ import gc
import sys
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -173,7 +174,7 @@ def test_replace_direct_write1():
s.replace(sref, target)
# There is no other reference so the AST node can be written directly
assert old_hash == s.mod["main"].body.block.body.__hash__()
- assert not tvm.ir.structural_equal(hold_ref.body, target)
+ assert not tvm_ffi.structural_equal(hold_ref.body, target)
# Check the replaced part is equal to the target
tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
# The target reuse `sref.stmt`, so the sref won't be None
@@ -189,7 +190,7 @@ def test_replace_copy():
s.replace(sref, target)
# We need to copy the whole func to remain the old_func unchanged
assert old_hash != s.mod["main"].__hash__()
- assert not tvm.ir.structural_equal(old_func.body, s.mod["main"].body)
+ assert not tvm_ffi.structural_equal(old_func.body, s.mod["main"].body)
assert old_hash == old_func.__hash__()
# Check the replaced part is equal to the target
tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
@@ -208,7 +209,7 @@ def test_replace_partial_copy0():
# The stmt is held by `hold_sref`, so it will be coped in copy-on-write
# because the ref count is not unique
assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__()
- assert not tvm.ir.structural_equal(hold_ref.body, target)
+ assert not tvm_ffi.structural_equal(hold_ref.body, target)
# The function and the other part stmt can be directly written
assert func_old_hash == s.mod["main"].__hash__()
assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
@@ -228,7 +229,7 @@ def test_replace_partial_copy1():
s.replace(sref, target)
# The parent stmt will change since there is only one reference
assert stmt_old_hash == s.mod["main"].body.block.body[0].__hash__()
- assert not tvm.ir.structural_equal(hold_ref.body, target)
+ assert not tvm_ffi.structural_equal(hold_ref.body, target)
# The function and the other part stmt can be directly written
assert func_old_hash == s.mod["main"].__hash__()
assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
@@ -259,7 +260,7 @@ def test_replace_root_copy0():
tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
# Check the original func remains unchanged
assert old_hash == func_ref.__hash__()
- assert not tvm.ir.structural_equal(func_ref.body, target)
+ assert not tvm_ffi.structural_equal(func_ref.body, target)
def test_replace_root_copy1():
@@ -273,7 +274,7 @@ def test_replace_root_copy1():
tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
# Check the original func remains unchanged
assert old_hash == func_ref.__hash__()
- assert not tvm.ir.structural_equal(func_ref.body, target)
+ assert not tvm_ffi.structural_equal(func_ref.body, target)
def test_replace_root_copy2():
@@ -288,7 +289,7 @@ def test_replace_root_copy2():
# Check the original func remains unchanged
assert old_hash == func_ref.__hash__()
for _, v in func_ref.items():
- assert not tvm.ir.structural_equal(v.body.block, target)
+ assert not tvm_ffi.structural_equal(v.body.block, target)
def test_replace_root_copy3():
@@ -302,7 +303,7 @@ def test_replace_root_copy3():
tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
# Check the original func remains unchanged
assert old_hash == func_ref.__hash__()
- assert not tvm.ir.structural_equal(func_ref["main"].body.block, target)
+ assert not tvm_ffi.structural_equal(func_ref["main"].body.block, target)
def test_replace_block_remap():
@@ -349,7 +350,7 @@ def test_replace_ir_module():
tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
# Check the original func remains unchanged
assert old_hash == func_ref.__hash__()
- assert not tvm.ir.structural_equal(func_ref.body, target)
+ assert not tvm_ffi.structural_equal(func_ref.body, target)
assert other_func_hash == s.mod["other"].__hash__()
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
index 66fb3d9a5d..d59aedb4d2 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
@@ -17,6 +17,7 @@
# ruff: noqa: E741, F401
import numpy as np
import pytest
+import tvm_ffi
import tvm
from tvm import s_tir
@@ -478,7 +479,7 @@ def test_hoisting_block_scope_2():
config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
- assert not tvm.ir.structural_equal(new_stmt, stmt)
+ assert not tvm_ffi.structural_equal(new_stmt, stmt)
def test_hoisting_block_scope_5():
@@ -496,7 +497,7 @@ def test_hoisting_block_scope_5():
stmt = Module["main"].body
new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
- assert not tvm.ir.structural_equal(new_stmt, stmt)
+ assert not tvm_ffi.structural_equal(new_stmt, stmt)
mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([], new_stmt))
stmt = new_stmt
@@ -533,7 +534,7 @@ def test_hoisting_block_scope_6():
config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
- assert not tvm.ir.structural_equal(new_stmt, stmt)
+ assert not tvm_ffi.structural_equal(new_stmt, stmt)
def test_hoisting_block_scope_7():
@@ -561,7 +562,7 @@ def test_hoisting_block_scope_7():
config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
- assert not tvm.ir.structural_equal(new_stmt, stmt)
+ assert not tvm_ffi.structural_equal(new_stmt, stmt)
if __name__ == "__main__":
diff --git a/tests/python/tirx-base/test_tir_constructor.py
b/tests/python/tirx-base/test_tir_constructor.py
index eda7fd9ebf..b2628e1b00 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -16,6 +16,7 @@
# under the License.
import pytest
+import tvm_ffi
import tvm
from tvm import te, topi
@@ -144,7 +145,7 @@ def test_expr_constructor():
attrs={"disable_tma": True},
)
assert x_with_attrs.attrs["disable_tma"] is True
- assert not tvm.ir.structural_equal(x, x_with_attrs)
+ assert not tvm_ffi.structural_equal(x, x_with_attrs)
script = tvm.tirx.Evaluate(x_with_attrs).script()
assert "attrs" in script
assert "disable_tma" in script
diff --git a/tests/python/tirx-base/test_tir_structural_equal_hash.py
b/tests/python/tirx-base/test_tir_structural_equal_hash.py
deleted file mode 100644
index 1efef38e3f..0000000000
--- a/tests/python/tirx-base/test_tir_structural_equal_hash.py
+++ /dev/null
@@ -1,440 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import numpy as np
-import pytest
-from tvm_ffi.access_path import AccessPath
-
-import tvm
-from tvm.script import ir as I
-from tvm.script import tirx as T
-
-
-def consistent_equal(x, y, map_free_vars=False):
- struct_equal0 = tvm.ir.structural_equal(x, y, map_free_vars)
- struct_equal1 = tvm.ir.structural_equal(y, x, map_free_vars)
-
- xhash = tvm.ir.structural_hash(x, map_free_vars)
- yhash = tvm.ir.structural_hash(y, map_free_vars)
-
- if struct_equal0 != struct_equal1:
- raise ValueError(
- f"Non-commutative {x} vs {y}, sequal0={struct_equal0},
sequal1={struct_equal1}"
- )
-
- # NOTE: hash colision can happen but should be rare.
- # we can confirm that hash colison doesn't happen for our testcases
- if struct_equal0 != (xhash == yhash):
- raise ValueError(
- f"Inconsistent {x} vs {y}, sequal={struct_equal0}, xhash={xhash},
yhash={yhash}"
- )
- return struct_equal0
-
-
-def get_sequal_mismatch(x, y, map_free_vars=False):
- mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars)
- mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars)
-
- if mismatch_0 is None and mismatch_1 is None:
- return None
-
- if (
- mismatch_0 is None
- or mismatch_1 is None
- or mismatch_0[0] != mismatch_1[1]
- or mismatch_0[1] != mismatch_1[0]
- ):
- raise ValueError(
- f"Non-commutative {x} vs {y}, mismatch_0={mismatch_0},
mismatch_1={mismatch_1}"
- )
-
- return mismatch_0
-
-
-def test_exprs():
- # save load json
- x = tvm.tirx.const(1, "int32")
- y = tvm.tirx.const(10, "int32")
- vx = tvm.tirx.Var("x", "int32")
- vy = tvm.tirx.Var("y", "int32")
- vz = tvm.tirx.Var("z", "int32")
- zx = vx + vx
- zy = vy + vy
-
- assert consistent_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
-
- # test assert trigger.
- with pytest.raises(ValueError):
- tvm.ir.assert_structural_equal(x, y)
-
- assert not consistent_equal(vx, vy)
- assert consistent_equal(vx, vy, map_free_vars=True)
- # corner case lhs:vx == rhs:vy, but cannot map it iteslf
- assert not consistent_equal(vx + vx, vy + vx, map_free_vars=True)
- # corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
- assert consistent_equal(vx + vy, vy + vx, map_free_vars=True)
- # corner case2: rolling remap.
- assert consistent_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
- assert not consistent_equal(vx + 1, vy + 1, map_free_vars=False)
- # Defintition remap
- assert consistent_equal(tvm.tirx.Let(vx, 1, vx - 1), tvm.tirx.Let(vy, 1,
vy - 1))
- # Default same address free var remap
- assert consistent_equal(tvm.tirx.Let(vx, 1, vx // vz), tvm.tirx.Let(vy, 1,
vy // vz))
-
- assert consistent_equal(zx * zx, zx * zx)
- assert consistent_equal(zx * zx, zy * zy, map_free_vars=True)
- assert not consistent_equal(zx * zx, zy * zy, map_free_vars=False)
-
-
-def test_prim_func():
- x = tvm.tirx.Var("x", "int32")
- y = tvm.tirx.Var("y", "int32")
- # counter example of same equality
- func0 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(x + y))
- func1 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(y + x))
- assert not consistent_equal(func0, func1)
-
- # new cases
- b = tvm.tirx.decl_buffer((x,), "float32")
- stmt = tvm.tirx.SeqStmt([tvm.tirx.Bind(x, 10), tvm.tirx.Evaluate(x + 1)])
- func0 = tvm.tirx.PrimFunc([x, y, b], stmt)
- # easiest way to deep copy is via save/load
- func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
- tvm.ir.assert_structural_equal(func0, func1)
-
- data0 = tvm.runtime.tensor([1, 2, 3])
- data1 = tvm.runtime.tensor([1, 2, 3])
- # attributes and ndarrays
- func0 = func0.with_attr("data", data0)
- func1 = func1.with_attr("data", data1)
- # IRModules
- mod0 = tvm.IRModule.from_expr(func0)
- mod1 = tvm.IRModule.from_expr(func1)
- tvm.ir.assert_structural_equal(mod0, mod1)
-
-
-def test_prim_func_param_count_mismatch():
- x = tvm.tirx.Var("x", "int32")
- y = tvm.tirx.Var("y", "int32")
- z = tvm.tirx.Var("z", "int32")
- # counter example of same equality
- func0 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(x))
- func1 = tvm.tirx.PrimFunc([x, y, z], tvm.tirx.Evaluate(x))
- lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
- expected_lhs_path = AccessPath.root().attr("params").array_item_missing(2)
- expected_rhs_path = AccessPath.root().attr("params").array_item(2)
- assert lhs_path == expected_lhs_path
- assert rhs_path == expected_rhs_path
-
-
-def test_prim_func_param_dtype_mismatch():
- x = tvm.tirx.Var("x", "int32")
- y_0 = tvm.tirx.Var("y", "int32")
- y_1 = tvm.tirx.Var("z", "float32")
- # counter example of same equality
- func0 = tvm.tirx.PrimFunc([x, y_0], tvm.tirx.Evaluate(x))
- func1 = tvm.tirx.PrimFunc([x, y_1], tvm.tirx.Evaluate(x))
- lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
- expected_path =
AccessPath.root().attr("params").array_item(1).attr("dtype")
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_prim_func_body_mismatch():
- x_0 = tvm.tirx.Var("x", "int32")
- y_0 = tvm.tirx.Var("y", "int32")
- x_1 = tvm.tirx.Var("x", "int32")
- y_1 = tvm.tirx.Var("y", "int32")
- # counter example of same equality
- func0 = tvm.tirx.PrimFunc([x_0, y_0], tvm.tirx.Evaluate(x_0 + x_0))
- func1 = tvm.tirx.PrimFunc([x_1, y_1], tvm.tirx.Evaluate(x_1 + y_1))
- lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
- expected_path = AccessPath.root().attr("body").attr("value").attr("b")
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_array():
- x = np.arange(10)
- nx = tvm.runtime.tensor(x)
- ny = tvm.runtime.tensor(x)
- nz = tvm.runtime.tensor(x.reshape(2, 5))
- assert consistent_equal(nx, ny)
- assert not consistent_equal(nx, nz)
-
-
-def test_env_func():
- @tvm.register_global_func("test.sequal.env_func")
- def test(x):
- return x + 1
-
- x = tvm.ir.EnvFunc.get("test.sequal.env_func")
- y = tvm.ir.EnvFunc.get("test.sequal.env_func")
- assert consistent_equal(y, x)
-
-
-def test_stmt():
- @T.prim_func(private=True, check_well_formed=False, s_tir=True)
- def func2(A: T.handle, n_param: T.int32):
- n_var = T.var("int32")
- Ab = T.match_buffer(A, (n_var,))
- for i in T.serial(n_var):
- Ab[i] = Ab[i] + T.float32(1)
- for j in T.serial(10):
- Ab[j] = Ab[j] + T.float32(2)
- Ab[j] = Ab[j] + T.float32(2)
-
- assert consistent_equal(func2.body, func2.body)
-
-
-def test_buffer_storage_scope():
- x = tvm.tirx.Var("x", "handle")
-
- buffer_local_0 = tvm.tirx.decl_buffer((10, 10), "float32", scope="local")
- buffer_local_1 = tvm.tirx.decl_buffer((10, 10), "float32", scope="local")
- buffer_global = tvm.tirx.decl_buffer((10, 10), "float32")
- buffer_empty = tvm.tirx.decl_buffer((10, 10), "float32", scope="")
-
- func0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_local_0})
- func1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_local_1})
- func2 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_global})
- func3 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_empty})
-
- assert consistent_equal(func0, func1)
- assert consistent_equal(func2, func3)
- assert not consistent_equal(func0, func2)
-
-
-def test_buffer_map_mismatch():
- x = tvm.tirx.Var("x", "int32")
- buffer_0 = tvm.tirx.decl_buffer((10, 10))
- buffer_0_clone = tvm.tirx.decl_buffer((10, 10))
- buffer_1 = tvm.tirx.decl_buffer((10, 20))
-
- func_0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_0})
- func_0_clone = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_0_clone})
- func_1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_1})
-
- lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
- expected_path = (
-
AccessPath.root().attr("buffer_map").map_item(x).attr("shape").array_item(1).attr("value")
- )
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
- assert get_sequal_mismatch(func_0, func_0_clone) is None
-
-
-def test_buffer_map_length_mismatch():
- x = tvm.tirx.Var("x", "int32")
- y = tvm.tirx.Var("x", "int32")
-
- buffer_0 = tvm.tirx.decl_buffer((10, 10))
- buffer_1 = tvm.tirx.decl_buffer((10, 20))
-
- func_0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_0})
- func_1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x:
buffer_0, y: buffer_1})
-
- lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
-
- expected_lhs_path =
AccessPath.root().attr("buffer_map").map_item_missing(y)
- assert lhs_path == expected_lhs_path
- expected_rhs_path = AccessPath.root().attr("buffer_map").map_item(y)
- assert rhs_path == expected_rhs_path
-
-
-def test_buffer_load_store():
- b = tvm.tirx.decl_buffer((10, 10), "float32")
- x = tvm.tirx.BufferLoad(b, [0, 1])
- y = tvm.tirx.BufferLoad(b, [0, 1])
- z = tvm.tirx.BufferLoad(b, [1, 2])
- assert consistent_equal(y, x)
- assert not consistent_equal(y, z)
-
- i = tvm.tirx.Var("x", "int32")
- sx = tvm.tirx.BufferStore(b, 0.1, [0, i])
- sy = tvm.tirx.BufferStore(b, 0.1, [0, i])
- sz = tvm.tirx.BufferStore(b, 0.1, [1, i])
- assert consistent_equal(sy, sx)
- assert not consistent_equal(sy, sz)
-
-
-def test_while():
- x = tvm.tirx.Var("x", "int32")
- y = tvm.tirx.Var("y", "int32")
- wx = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x))
- wy = tvm.tirx.While(y > 0, tvm.tirx.Evaluate(y))
- assert not consistent_equal(wx, wy)
- assert consistent_equal(wx, wy, map_free_vars=True)
-
-
-def test_while_condition_mismatch():
- x = tvm.tirx.Var("x", "int32")
- w_0 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x))
- w_1 = tvm.tirx.While(x < 0, tvm.tirx.Evaluate(x))
- lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
- expected_path = AccessPath.root().attr("condition")
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_while_body_mismatch():
- x = tvm.tirx.Var("x", "int32")
- w_0 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x))
- w_1 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x + 1))
- lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
- expected_path = AccessPath.root().attr("body").attr("value")
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_seq_mismatch():
- x = tvm.tirx.Var("x", "int32")
- seq_0 = tvm.tirx.SeqStmt(
- [
- tvm.tirx.Evaluate(x),
- tvm.tirx.Evaluate(x + 1),
- tvm.tirx.Evaluate(x + 2),
- tvm.tirx.Evaluate(x + 3),
- ]
- )
- seq_1 = tvm.tirx.SeqStmt(
- [
- tvm.tirx.Evaluate(x),
- tvm.tirx.Evaluate(x + 1),
- tvm.tirx.Evaluate(x + 99),
- tvm.tirx.Evaluate(x + 3),
- ]
- )
- lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
- expected_path = (
-
AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
- )
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_seq_mismatch_different_lengths():
- # Make sure we report a difference inside the array first, rather than the
difference in length
- x = tvm.tirx.Var("x", "int32")
- seq_0 = tvm.tirx.SeqStmt(
- [
- tvm.tirx.Evaluate(x),
- tvm.tirx.Evaluate(x + 1),
- tvm.tirx.Evaluate(x + 2),
- tvm.tirx.Evaluate(x + 3),
- ]
- )
- seq_1 = tvm.tirx.SeqStmt(
- [tvm.tirx.Evaluate(x), tvm.tirx.Evaluate(x + 1), tvm.tirx.Evaluate(x +
3)]
- )
- lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
- expected_path = (
-
AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
- )
- assert lhs_path == expected_path
- assert rhs_path == expected_path
-
-
-def test_seq_length_mismatch():
- x = tvm.tirx.Var("x", "int32")
- seq_0 = tvm.tirx.SeqStmt(
- [
- tvm.tirx.Evaluate(x),
- tvm.tirx.Evaluate(x + 1),
- tvm.tirx.Evaluate(x + 2),
- tvm.tirx.Evaluate(x + 3),
- ]
- )
- seq_1 = tvm.tirx.SeqStmt(
- [tvm.tirx.Evaluate(x), tvm.tirx.Evaluate(x + 1), tvm.tirx.Evaluate(x +
2)]
- )
- lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
- expected_lhs_path = AccessPath.root().attr("seq").array_item(3)
- expected_rhs_path = AccessPath.root().attr("seq").array_item_missing(3)
- assert lhs_path == expected_lhs_path
- assert rhs_path == expected_rhs_path
-
-
-def test_ir_module_equal():
- def generate(n: int):
- @I.ir_module
- class module:
- @T.prim_func(s_tir=True)
- def func(A: T.Buffer(1, "int32")):
- for i in range(n):
- A[0] = A[0] + 1
-
- return module
-
- # Equivalent IRModules should compare as equivalent, even though
- # they have distinct GlobalVars, and GlobalVars usually compare by
- # reference equality.
- tvm.ir.assert_structural_equal(generate(16), generate(16))
-
- # When there is a difference, the location should include the
- # function name that caused the failure.
- with pytest.raises(ValueError) as err:
- tvm.ir.assert_structural_equal(generate(16), generate(32))
-
- assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in
err.value.args[0]
-
-
-def test_nan_values_are_equivalent():
- """Structural equality treats two NaN values as equivalent.
-
- By IEEE, a check of `NaN == NaN` returns false, as does
- `abs(NaN - NaN) < tolerance`. However, for the purpose of
- comparing IR representations, both NaN values are equivalent.
-
- """
-
- @T.prim_func(private=True, s_tir=True)
- def func_1():
- return T.float32("nan")
-
- @T.prim_func(private=True, s_tir=True)
- def func_2():
- return T.float32("nan")
-
- tvm.ir.assert_structural_equal(func_1, func_2)
- assert tvm.ir.structural_hash(func_1) == tvm.ir.structural_hash(func_2)
-
-
-def test_all_nan_values_are_equivalent():
- """Structural equality treats two NaN values as equivalent.
-
- IEEE defines NaN as any value that has all exponent bits set,
- and has a non-zero mantissa. For the purposes of comparing IR
- representations, all NaN values are considered equivalent.
-
- """
-
- # A NaN with the first payload bit set.
- nan_all_zeros = np.int32(0x7FC00000).view("float32")
-
- # A NaN with the last payload bit set.
- nan_with_payload = np.int32(0x7F800001).view("float32")
-
- float_1 = T.float32(nan_all_zeros)
- float_2 = T.float32(nan_with_payload)
-
- tvm.ir.assert_structural_equal(float_1, float_2)
- assert tvm.ir.structural_hash(float_1) == tvm.ir.structural_hash(float_2)
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 81c63a58b7..9dac2b3f0a 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -20,6 +20,7 @@ import sys
import numpy as np
import pytest
+import tvm_ffi
import tvm
import tvm.testing
@@ -2491,7 +2492,7 @@ def test_void_ptr_vs_handle():
def handle(out_ret_value: T.handle):
T.evaluate(out_ret_value)
- assert not tvm.ir.structural_equal(void_ptr, handle)
+ assert not tvm_ffi.structural_equal(void_ptr, handle)
def void_ptr():