This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0828bc8ad [REFACTOR][TEST] Migrate tir-transform tests from TE to TVMScript (#18805)
c0828bc8ad is described below
commit c0828bc8ad0afe4ff637e59114700efcf283ca14
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 21 18:23:48 2026 -0500
[REFACTOR][TEST] Migrate tir-transform tests from TE to TVMScript (#18805)
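This PR migrates te.var/te.compute usage to direct tvm.tir.Var
and PrimFunc construction in the arith, ir, tir-analysis, tir-base,
and tir-transform test files, and replaces tvm.te expression helpers
(e.g. te.size_var, floordiv, floormod, min, max, min_value) with their
tvm.tir equivalents (tvm.tir.SizeVar, tvm.tir.floordiv, ...).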
---
.../python/arith/test_arith_canonical_simplify.py | 106 ++---
tests/python/arith/test_arith_const_int_bound.py | 65 ++-
tests/python/arith/test_arith_deduce_bound.py | 38 +-
tests/python/arith/test_arith_detect_clip_bound.py | 11 +-
.../arith/test_arith_detect_linear_equation.py | 11 +-
tests/python/arith/test_arith_intset.py | 32 +-
tests/python/arith/test_arith_modular_set.py | 41 +-
tests/python/arith/test_arith_rewrite_simplify.py | 462 +++++++++++----------
.../arith/test_arith_solve_linear_equations.py | 14 +-
.../arith/test_arith_solve_linear_inequality.py | 20 +-
tests/python/ir/test_ir_container.py | 18 +-
.../test_tir_analysis_expr_deep_equal.py | 7 +-
.../tir-analysis/test_tir_analysis_verify_ssa.py | 7 +-
tests/python/tir-base/test_tir_buffer.py | 45 +-
tests/python/tir-base/test_tir_constructor.py | 13 +-
tests/python/tir-base/test_tir_intrin.py | 2 +-
tests/python/tir-base/test_tir_nodes.py | 74 ++--
tests/python/tir-base/test_tir_ops.py | 31 +-
.../tir-base/test_tir_structural_equal_hash.py | 50 +--
.../test_tir_transform_common_subexpr_elim.py | 67 ++-
.../test_tir_transform_lower_intrin.py | 75 ++--
.../test_tir_transform_prim_func_pass.py | 7 +-
.../test_tir_transform_remove_no_op.py | 11 +-
23 files changed, 616 insertions(+), 591 deletions(-)
diff --git a/tests/python/arith/test_arith_canonical_simplify.py
b/tests/python/arith/test_arith_canonical_simplify.py
index 733d1d13b3..ca40bd2f25 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -44,7 +44,7 @@ class CanonicalChecker:
def test_mul_sum_simplify():
ck = CanonicalChecker()
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
@@ -56,8 +56,8 @@ def test_mul_sum_simplify():
ck.verify(tmod(x + y + x + y * 3, 2), 0)
# floordiv
- fld = tvm.te.floordiv
- flm = tvm.te.floormod
+ fld = tvm.tir.floordiv
+ flm = tvm.tir.floormod
ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))
ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)
ck.verify(flm(x + y + x + y * 3, 2), 0)
@@ -66,7 +66,7 @@ def test_mul_sum_simplify():
def test_split_index_simplify():
ck = CanonicalChecker()
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
# trucdiv
tdiv = tvm.tir.truncdiv
@@ -96,8 +96,8 @@ def test_split_index_simplify():
ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# floordiv
- fld = tvm.te.floordiv
- flm = tvm.te.floormod
+ fld = tvm.tir.floordiv
+ flm = tvm.tir.floormod
ck.verify(fld(x * 5, 2), fld(x * 5, 2))
ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
@@ -114,7 +114,7 @@ def test_split_index_simplify():
def test_div_simplify():
ck = CanonicalChecker()
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
tdiv = tvm.tir.truncdiv
# truc div
@@ -129,7 +129,7 @@ def test_div_simplify():
ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
# floordiv
- fld = tvm.te.floordiv
+ fld = tvm.tir.floordiv
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
ck.verify(fld(16 + 48 * x, 16), x * 3 + 1)
ck.verify(fld(17 + 48 * x, 16), x * 3 + 1)
@@ -154,8 +154,8 @@ def test_fp16_const_fold():
def test_floormod_simplify():
ck = CanonicalChecker()
- flm = tvm.te.floormod
- x, y = te.var("x"), te.var("y")
+ flm = tvm.tir.floormod
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) +
y + 12, 16))
ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4)
@@ -164,26 +164,26 @@ def test_floormod_simplify():
def test_canonical_mixed():
ck = CanonicalChecker()
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
z = tvm.tir.const(3, "int32")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
ck.verify(tdiv(x, (z * z)) - tdiv(x, (z * z)), 0)
ck.verify(tdiv(x, (z + z)) - tdiv(x, (z + z)), 0)
ck.verify(x - 2 < 3, x < 5)
- ck.verify(tvm.te.max(x, 1) - tvm.te.max(x, 1), 0)
- ck.verify(tvm.te.min(x, 1) - tvm.te.min(x, 1), 0)
+ ck.verify(tvm.tir.max(x, 1) - tvm.tir.max(x, 1), 0)
+ ck.verify(tvm.tir.min(x, 1) - tvm.tir.min(x, 1), 0)
ck.verify(x * x - x * x, 0)
ck.verify(tmod(tdiv(tmod(x, 20), 2) * 2, 4), tdiv(tmod(x, 4), 2) * 2)
- fld = tvm.te.floordiv
+ fld = tvm.tir.floordiv
ck.verify(fld(x, (z * z)) - fld(x, (z * z)), 0)
ck.verify(fld(x, (z + z)) - fld(x, (z + z)), 0)
def test_reduce_combiner_simplify():
ck = CanonicalChecker()
- dummy = te.var("dummy")
+ dummy = tvm.tir.Var("dummy", "int32")
comm_reducer = te.comm_reducer
prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))
@@ -262,13 +262,13 @@ def test_reduce_simplify():
ck.verify(te.sum(A[3], []), A[3])
ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0,
dtype="float32"))
# The rule below is not typical, removed for now
- ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
+ ck.verify(te.sum(tvm.tir.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"),
k))
def test_simplify_if_then_else():
ck = CanonicalChecker()
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
# simplification that takes condition into account.
@@ -315,8 +315,8 @@ def test_simplify_if_then_else():
def test_complex_cases():
ck = CanonicalChecker()
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
res2 = (
@@ -346,30 +346,30 @@ def test_complex_cases():
def test_simplify_cast():
ck = CanonicalChecker()
tcast = tvm.tir.Cast
- fld = tvm.te.floordiv
- flm = tvm.te.floormod
+ fld = tvm.tir.floordiv
+ flm = tvm.tir.floormod
# cast(i64, i + j + 1) - cast(i64, i)
- i = te.var("i", dtype="int32")
- j = te.var("j", dtype="int32")
+ i = tvm.tir.Var("i", "int32")
+ j = tvm.tir.Var("j", "int32")
res = tcast("int64", i + j + 1) - tcast("int64", i)
ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
# cast(i32, i + j + 1) - cast(i32, i)
- i = te.var("i", dtype="int64")
- j = te.var("j", dtype="int64")
+ i = tvm.tir.Var("i", "int64")
+ j = tvm.tir.Var("j", "int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j + 1) - tcast("int32", i)
ck.verify(res, tcast("int32", j) + 1)
# cast(i32, i + j - 100)
- i = te.var("i", dtype="int64")
- j = te.var("j", dtype="int64")
+ i = tvm.tir.Var("i", "int64")
+ j = tvm.tir.Var("j", "int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j - 100)
ck.verify(res, res)
# cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
# - cast(i32, flm(axis, 7i64) * 2i64)
- axis = te.var("axis", dtype="int64")
+ axis = tvm.tir.Var("axis", "int64")
ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
res = (
tcast(
@@ -385,26 +385,26 @@ def test_simplify_cast():
def test_simplify_normalize_min_value_expr():
ck = CanonicalChecker()
- x = te.var("x", "int32")
+ x = tvm.tir.Var("x", "int32")
- ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32"))
- ck.verify(te.min_value("int32") + x == 0, tir.const(False))
- ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32"))
- ck.verify(0 == te.min_value("int32") + x, tir.const(False))
- ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32"))
- ck.verify(x + te.min_value("int32") == 0, tir.const(False))
- ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32"))
- ck.verify(0 == x + te.min_value("int32"), tir.const(False))
+ ck.verify(tvm.tir.min_value("int32") - x == 0, x ==
tvm.tir.min_value("int32"))
+ ck.verify(tvm.tir.min_value("int32") + x == 0, tir.const(False))
+ ck.verify(0 == tvm.tir.min_value("int32") - x, x ==
tvm.tir.min_value("int32"))
+ ck.verify(0 == tvm.tir.min_value("int32") + x, tir.const(False))
+ ck.verify(-x + tvm.tir.min_value("int32") == 0, x ==
tvm.tir.min_value("int32"))
+ ck.verify(x + tvm.tir.min_value("int32") == 0, tir.const(False))
+ ck.verify(0 == -x + tvm.tir.min_value("int32"), x ==
tvm.tir.min_value("int32"))
+ ck.verify(0 == x + tvm.tir.min_value("int32"), tir.const(False))
def test_proddiv_simplify():
ck = CanonicalChecker()
- flm = tvm.te.floormod
- fld = tvm.te.floordiv
- tdiv = tvm.te.truncdiv
- tmod = tvm.te.truncmod
+ flm = tvm.tir.floormod
+ fld = tvm.tir.floordiv
+ tdiv = tvm.tir.truncdiv
+ tmod = tvm.tir.truncmod
- x, y, z = te.var("x"), te.var("y"), te.var("y")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("y", "int32")
ck.verify(flm(x * 32 * x, x), 0)
ck.verify(flm(z * x * 32 * x * y, x * z), 0)
@@ -428,15 +428,15 @@ def test_proddiv_simplify():
def test_floormod_two():
ck = CanonicalChecker()
- flm = tvm.te.floormod
- x, y = te.var("x"), te.var("y")
+ flm = tvm.tir.floormod
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
def test_simplify_le():
ck = CanonicalChecker()
# Case 1. Ignore the extra expr if it's small than the division number
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
ck.analyzer.bind(y, tvm.ir.Range(0, 8))
ck.analyzer.bind(z, tvm.ir.Range(0, 2))
ck.verify(x * 8 + y < 16, x < 2)
@@ -449,16 +449,16 @@ def test_simplify_le():
ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
- n = te.size_var("n")
+ n = tvm.tir.SizeVar("n", "int32")
ck.verify(x * 8 + y < n, x * 8 + y < n)
# Case 2. Simplify the extra expr
x1, x2, ty, tx, vec = (
- tvm.te.var("x1"),
- tvm.te.var("x2"),
- tvm.te.var("ty"),
- tvm.te.var("tx"),
- tvm.te.var("vec"),
+ tvm.tir.Var("x1", "int32"),
+ tvm.tir.Var("x2", "int32"),
+ tvm.tir.Var("ty", "int32"),
+ tvm.tir.Var("tx", "int32"),
+ tvm.tir.Var("vec", "int32"),
)
ck.analyzer.bind(x1, tvm.ir.Range(0, 2))
ck.analyzer.bind(x2, tvm.ir.Range(0, 3))
@@ -472,7 +472,7 @@ def test_simplify_le():
ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)
# Case 3. No failure
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
ck.analyzer.bind(y, tvm.ir.Range(0, 1024))
ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
diff --git a/tests/python/arith/test_arith_const_int_bound.py
b/tests/python/arith/test_arith_const_int_bound.py
index 8728df7e3f..bd2a39b4f4 100644
--- a/tests/python/arith/test_arith_const_int_bound.py
+++ b/tests/python/arith/test_arith_const_int_bound.py
@@ -22,7 +22,6 @@ import pytest
import tvm
import tvm.testing
-from tvm import te
from tvm.arith import ConstIntBound
NEG_INF = ConstIntBound.NEG_INF
@@ -69,15 +68,15 @@ class BaseCompare:
class TestDataType(BaseCompare):
test_case = tvm.testing.parameter(
- TestCase(te.var("x", dtype="int64"), (NEG_INF, POS_INF)),
- TestCase(te.var("x", dtype="int8"), (-128, 127)),
- TestCase(te.var("x", dtype="uint8"), (0, 255)),
- TestCase(te.size_var("x", dtype="int32"), (0, POS_INF)),
+ TestCase(tvm.tir.Var("x", "int64"), (NEG_INF, POS_INF)),
+ TestCase(tvm.tir.Var("x", "int8"), (-128, 127)),
+ TestCase(tvm.tir.Var("x", "uint8"), (0, 255)),
+ TestCase(tvm.tir.SizeVar("x", "int32"), (0, POS_INF)),
)
class TestCastBound(BaseCompare):
- x = te.var("x", dtype="int8")
+ x = tvm.tir.Var("x", "int8")
tmod = tvm.tir.truncmod
test_case = tvm.testing.parameter(
@@ -87,8 +86,8 @@ class TestCastBound(BaseCompare):
class TestAddSubBound(BaseCompare):
- x = te.var("x", "int64")
- y = te.var("y", "int64")
+ x = tvm.tir.Var("x", "int64")
+ y = tvm.tir.Var("y", "int64")
test_case = tvm.testing.parameter(
TestCase(x + y, (NEG_INF, POS_INF)),
@@ -119,7 +118,7 @@ class TestBoundsUsingReciprocals(BaseCompare):
achieve its minimum while `A*B` simultaneously achieves its maximum.
"""
- A, B, C = [te.var(letter, "int64") for letter in "ABC"]
+ A, B, C = [tvm.tir.Var(letter, "int64") for letter in "ABC"]
symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)}
asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)}
@@ -137,7 +136,7 @@ class TestBoundsUsingReciprocals(BaseCompare):
class TestMulBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
TestCase(x * y + 20, (0, 60), {x: (-2, 4), y: (4, 10)}),
@@ -147,7 +146,7 @@ class TestMulBound(BaseCompare):
class TestTruncDivBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
expr = tvm.tir.truncdiv(x, y)
@@ -160,7 +159,7 @@ class TestTruncDivBound(BaseCompare):
class TestTruncModBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
expr = tvm.tir.truncmod(x, y)
@@ -172,9 +171,9 @@ class TestTruncModBound(BaseCompare):
class TestFloorDivBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
- ux = te.var("x", dtype="uint32")
- uy = te.var("y", dtype="uint32")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+ ux = tvm.tir.Var("x", "uint32")
+ uy = tvm.tir.Var("y", "uint32")
test_case = tvm.testing.parameter(
TestCase(x // y, (-9 // 4, None), {x: (-9, 4), y: (4, 10)}),
@@ -186,7 +185,7 @@ class TestFloorDivBound(BaseCompare):
class TestFloorModBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
TestCase(x % y, (0, 9), {x: (-9, 4), y: (4, 10)}),
@@ -196,18 +195,18 @@ class TestFloorModBound(BaseCompare):
class TestMinMaxBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
- TestCase(tvm.te.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}),
- TestCase(tvm.te.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y:
(4, 10)}),
- TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y:
(4, 10)}),
- TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4,
10)}),
+ TestCase(tvm.tir.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}),
+ TestCase(tvm.tir.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y:
(4, 10)}),
+ TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y:
(4, 10)}),
+ TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4,
10)}),
)
class TestSelectBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
TestCase(
@@ -219,7 +218,7 @@ class TestSelectBound(BaseCompare):
class TestShiftAndBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
TestCase(x >> y, (-3, 2), {x: (-9, 11), y: (2, 10)}),
@@ -229,7 +228,7 @@ class TestShiftAndBound(BaseCompare):
class TestMixIndexBound(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
@@ -243,15 +242,15 @@ class TestMixIndexBound(BaseCompare):
class TestLetBound(BaseCompare):
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
test_case = tvm.testing.parameter(
TestCase(tvm.tir.Let(x, 1, x + 1), (2, 2)),
)
class TestFloorModNegativeDivisor(BaseCompare):
- flm, fld = tvm.te.floormod, tvm.te.floordiv
- a, b = te.var("a"), te.var("b")
+ flm, fld = tvm.tir.floormod, tvm.tir.floordiv
+ a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32")
test_case = tvm.testing.parameter(
TestCase(a % b, (-4, 6), {a: (0, 6), b: (-5, 7)}),
@@ -264,7 +263,7 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare):
from symbolic shape programs
"""
- a, b = te.var("a"), te.var("b")
+ a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32")
test_case = tvm.testing.parameter(
TestCase(a // b, (0, 6), {a: (0, 6), b: (0, POS_INF)}),
@@ -273,7 +272,7 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare):
class TestMultipleCondition(BaseCompare):
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
test_case = tvm.testing.parameter(
TestCase(
a % 58 - 1,
@@ -285,14 +284,14 @@ class TestMultipleCondition(BaseCompare):
class TestBroadcastBound(BaseCompare):
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
test_case = tvm.testing.parameter(
TestCase(tvm.tir.Broadcast(a, 4), (0, 128), {a: (0, 128)}),
)
class TestRampBound(BaseCompare):
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
test_case = tvm.testing.parameter(
TestCase(tvm.tir.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0,
128)}),
)
@@ -300,8 +299,8 @@ class TestRampBound(BaseCompare):
class TestModularSetBound(BaseCompare):
analyzer = tvm.arith.Analyzer()
- tx = tvm.te.var("tx", dtype="int32")
- bx = tvm.te.var("bx", dtype="int32")
+ tx = tvm.tir.Var("tx", "int32")
+ bx = tvm.tir.Var("bx", "int32")
expr = (bx * 2048 + tx * 16) % 7168
diff --git a/tests/python/arith/test_arith_deduce_bound.py
b/tests/python/arith/test_arith_deduce_bound.py
index a36fd21479..1a6bfb4925 100644
--- a/tests/python/arith/test_arith_deduce_bound.py
+++ b/tests/python/arith/test_arith_deduce_bound.py
@@ -17,22 +17,22 @@
import pytest
import tvm
import tvm.testing
-from tvm import te
+
from tvm.tir.buffer import decl_buffer
def test_deduce():
- a = te.var("a")
- b = te.var("b")
- c = te.var("c")
- d = te.var("d")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
+ c = tvm.tir.Var("c", "int32")
+ d = tvm.tir.Var("d", "int32")
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(10, 15)
d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.tir.const(0, "int32")
- fdiv = tvm.te.floordiv
+ fdiv = tvm.tir.floordiv
e0 = (-b) * a + c - d
res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
@@ -62,13 +62,13 @@ def test_deduce():
res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
- e2 = tvm.te.max(5, a * 4) < 0
+ e2 = tvm.tir.max(5, a * 4) < 0
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
# expression containing variable a is on rhs
- e2 = zero < tvm.te.max(5, a * 4)
+ e2 = zero < tvm.tir.max(5, a * 4)
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
@@ -121,10 +121,10 @@ def test_deduce():
def test_check():
- a = te.var("a")
- b = te.var("b")
- c = te.var("c")
- d = te.var("d")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
+ c = tvm.tir.Var("c", "int32")
+ d = tvm.tir.Var("d", "int32")
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(5, 7)
@@ -145,8 +145,8 @@ def test_check():
def test_deduce_basic():
def test_basic(a1, a2, coff):
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = b + a * coff + 3
@@ -179,8 +179,8 @@ def test_deduce_basic():
def test_deduce_complex():
def test_complex(a1, a2, coff):
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = (b * 3 + a * coff) * 4
@@ -211,7 +211,7 @@ def test_deduce_complex():
def test_deduce_non_support():
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
def test_non_support(lhs):
res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
@@ -232,7 +232,7 @@ def test_deduce_non_support():
def test_deduce_floordiv():
def do_test(gen_expr, dom_map, expect_min, expect_max):
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
expr = gen_expr(a)
res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
if isinstance(expect_min, str):
@@ -259,7 +259,7 @@ def test_deduce_floordiv():
do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")
# test nested cases
- b = te.var("b")
+ b = tvm.tir.Var("b", "int32")
bs = {b: tvm.arith.IntervalSet(2, 6)}
do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
diff --git a/tests/python/arith/test_arith_detect_clip_bound.py
b/tests/python/arith/test_arith_detect_clip_bound.py
index 03fff11f77..830c6d4811 100644
--- a/tests/python/arith/test_arith_detect_clip_bound.py
+++ b/tests/python/arith/test_arith_detect_clip_bound.py
@@ -16,13 +16,12 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
def test_basic():
- a = te.var("a")
- b = te.var("b")
- c = te.var("c")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
+ c = tvm.tir.Var("c", "int32")
m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a])
tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1)
assert m[0].value == 2
@@ -40,8 +39,8 @@ def test_basic():
def test_trivial_eq():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
m = tvm.arith.detect_clip_bound(b == 3, [a, b])
tvm.testing.assert_prim_expr_equal(m[2], 3)
tvm.testing.assert_prim_expr_equal(m[3], 3)
diff --git a/tests/python/arith/test_arith_detect_linear_equation.py
b/tests/python/arith/test_arith_detect_linear_equation.py
index 829b101af3..d7902047ca 100644
--- a/tests/python/arith/test_arith_detect_linear_equation.py
+++ b/tests/python/arith/test_arith_detect_linear_equation.py
@@ -16,12 +16,11 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
def test_basic():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a])
assert m[0].value == 4
tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7)
@@ -43,14 +42,14 @@ def test_basic():
assert len(m) == 1
tvm.testing.assert_prim_expr_equal(m[0], b * 7)
- c = te.var("c", "uint32")
+ c = tvm.tir.Var("c", "uint32")
m = tvm.arith.detect_linear_equation(128 - c, [c])
assert m[0].value == -1
def test_multivariate():
- v = [te.var("v%d" % i) for i in range(4)]
- b = te.var("b")
+ v = [tvm.tir.Var("v%d" % i, "int32") for i in range(4)]
+ b = tvm.tir.Var("b", "int32")
m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
tvm.testing.assert_prim_expr_equal(m[0], b + 5)
diff --git a/tests/python/arith/test_arith_intset.py
b/tests/python/arith/test_arith_intset.py
index 04014ca300..e71b211836 100644
--- a/tests/python/arith/test_arith_intset.py
+++ b/tests/python/arith/test_arith_intset.py
@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
+
from tvm import tir
from tvm.arith.analyzer import Analyzer
@@ -64,7 +64,7 @@ def test_scalable_vector():
def test_add_sub():
ck = IntSetChecker()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y:
tvm.arith.IntervalSet(1, 11)}, (1, 21))
ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y:
tvm.arith.IntervalSet(1, 11)}, (-11, 9))
@@ -72,7 +72,7 @@ def test_add_sub():
def test_mul_div():
ck = IntSetChecker()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
tdiv = tvm.tir.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
@@ -83,20 +83,20 @@ def test_mul_div():
ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))
- fld = tvm.te.floordiv
+ fld = tvm.tir.floordiv
ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod():
ck = IntSetChecker()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
tmod = tvm.tir.truncmod
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))
- flm = tvm.te.floormod
+ flm = tvm.tir.floormod
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5))
@@ -104,8 +104,8 @@ def test_mod():
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9))
- fld = tvm.te.floordiv
- z = te.var("z")
+ fld = tvm.tir.floordiv
+ z = tvm.tir.Var("z", "int32")
ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
ck.verify(
flm(y, 8),
@@ -124,16 +124,16 @@ def test_mod():
def test_max_min():
ck = IntSetChecker()
- x, y = te.var("x"), te.var("y")
- ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
- ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)},
(-1, 9))
- ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
- ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+ ck.verify(tvm.tir.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1,
11))
+ ck.verify(tvm.tir.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)},
(-1, 9))
+ ck.verify(tvm.tir.min(x, y), {}, (tvm.tir.min(x, y), tvm.tir.min(x, y)))
+ ck.verify(tvm.tir.max(x, y), {}, (tvm.tir.max(x, y), tvm.tir.max(x, y)))
def test_select():
ck = IntSetChecker()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x:
tvm.arith.IntervalSet(0, 10)}, (-1, 11))
@@ -389,8 +389,8 @@ def test_union_lower_bound():
def test_modular_set():
ck = IntSetChecker()
- x = tvm.te.var("x", dtype="int32")
- y = tvm.te.var("y", dtype="int32")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
expr = (x * 2048 + y * 16) % 7168
ck.verify(
expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0,
3584)}, (0, 7152)
diff --git a/tests/python/arith/test_arith_modular_set.py
b/tests/python/arith/test_arith_modular_set.py
index 914402fb62..ef7b0593d0 100644
--- a/tests/python/arith/test_arith_modular_set.py
+++ b/tests/python/arith/test_arith_modular_set.py
@@ -16,12 +16,11 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
def test_cast():
analyzer = tvm.arith.Analyzer()
- x = te.var("x", dtype="int8")
+ x = tvm.tir.Var("x", "int8")
m = analyzer.modular_set((x * 3).astype("uint32"))
assert m.coeff == 3
assert m.base == 0
@@ -32,7 +31,7 @@ def test_cast():
def test_add_sub():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x", "int64"), te.var("y", "int64")
+ x, y = tvm.tir.Var("x", "int64"), tvm.tir.Var("y", "int64")
m = analyzer.modular_set(x * 6 + y * 4)
assert m.coeff == 2
assert m.base == 0
@@ -45,7 +44,7 @@ def test_add_sub():
def test_mul():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
m = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1))
assert m.coeff == 4
assert m.base == 2
@@ -53,7 +52,7 @@ def test_mul():
def test_floormod():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256))
assert m.coeff == 4
assert m.base == 0
@@ -61,7 +60,7 @@ def test_floormod():
def test_div_shift():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
# not sure if x is non-negative
tdiv = tvm.tir.truncdiv
m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
@@ -71,7 +70,7 @@ def test_div_shift():
m = analyzer.modular_set((x * 4 + 2) >> 1)
assert m.coeff == 2
assert m.base == 1
- fld = tvm.te.floordiv
+ fld = tvm.tir.floordiv
m = analyzer.modular_set(fld(x * 4 + 2, 2))
assert m.coeff == 2
assert m.base == 1
@@ -84,7 +83,7 @@ def test_div_shift():
def test_mod():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
tmod = tvm.tir.truncmod
fmod = tvm.tir.floormod
# not sure if x is non-negative
@@ -111,12 +110,12 @@ def test_mod():
def test_min_max_select():
analyzer = tvm.arith.Analyzer()
- x, y = te.var("x"), te.var("y")
- m = analyzer.modular_set(tvm.te.min(x * 3, y * 9))
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+ m = analyzer.modular_set(tvm.tir.min(x * 3, y * 9))
assert m.coeff == 3
assert m.base == 0
- m = analyzer.modular_set(tvm.te.max(x * 3 + 1, y * 9 + 4))
+ m = analyzer.modular_set(tvm.tir.max(x * 3 + 1, y * 9 + 4))
assert m.coeff == 3
assert m.base == 1
@@ -126,8 +125,8 @@ def test_min_max_select():
def test_mix_index():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
analyzer = tvm.arith.Analyzer()
tdiv = tvm.tir.truncdiv
m = analyzer.modular_set(a * 4 + b * 6 + 7)
@@ -150,14 +149,14 @@ def test_mix_index():
assert m.coeff == 3
assert m.base == 2
- m = analyzer.modular_set(a * 12 + tvm.te.min(b * 3 * 7, 2))
+ m = analyzer.modular_set(a * 12 + tvm.tir.min(b * 3 * 7, 2))
assert m.coeff == 1
assert m.base == 0
def test_constraint_scope():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
analyzer = tvm.arith.Analyzer()
tmod = tvm.tir.truncmod
@@ -179,7 +178,7 @@ def test_constraint_scope():
def test_intersect():
- a = te.var("a")
+ a = tvm.tir.Var("a", "int32")
analyzer = tvm.arith.Analyzer()
tmod = tvm.tir.truncmod
with analyzer.constraint_scope(tmod(a, 4) == 1):
@@ -198,8 +197,8 @@ def test_intersect():
def test_let():
analyzer = tvm.arith.Analyzer()
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1))
assert m.coeff == 10
assert m.base == 1
@@ -207,8 +206,8 @@ def test_let():
def test_bitwise_and():
analyzer = tvm.arith.Analyzer()
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
# RHS of bitwise_and is 2^p - 1
m = analyzer.modular_set((x * 16 + y * 4) & 31)
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index b85a6b758c..a87d81a7b7 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -21,7 +21,7 @@ import pytest
import tvm
import tvm.testing
-from tvm import te, tir
+from tvm import tir
from tvm.tir import floordiv as fld
from tvm.tir import floormod as flm
from tvm.tir import truncdiv as tdiv
@@ -90,10 +90,10 @@ class BaseCompare:
class TestVector(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
- x64 = te.var("x", dtype="int64")
- vx = te.var("vx", dtype="int32x2")
- vc = te.var("vc", dtype="bool")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
+ x64 = tvm.tir.Var("x", "int64")
+ vx = tvm.tir.Var("vx", "int32x2")
+ vc = tvm.tir.Var("vc", "bool")
test_case = tvm.testing.parameter(
# Add rules
TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x
+ y, 3, 4)),
@@ -271,18 +271,20 @@ class TestVector(BaseCompare):
), # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6,
13, 20]
# Min/Max rules
TestCase(
- tvm.te.min(y.astype("int32x2"), x.astype("int32x2")),
tvm.te.min(y, x).astype("int32x2")
+ tvm.tir.min(y.astype("int32x2"), x.astype("int32x2")),
+ tvm.tir.min(y, x).astype("int32x2"),
),
TestCase(
- tvm.te.min(tvm.te.min(vx, y.astype("int32x2")),
x.astype("int32x2")),
- tvm.te.min(vx, tvm.te.min(y, x).astype("int32x2")),
+ tvm.tir.min(tvm.tir.min(vx, y.astype("int32x2")),
x.astype("int32x2")),
+ tvm.tir.min(vx, tvm.tir.min(y, x).astype("int32x2")),
),
TestCase(
- tvm.te.max(y.astype("int32x2"), x.astype("int32x2")),
tvm.te.max(y, x).astype("int32x2")
+ tvm.tir.max(y.astype("int32x2"), x.astype("int32x2")),
+ tvm.tir.max(y, x).astype("int32x2"),
),
TestCase(
- tvm.te.max(tvm.te.max(vx, y.astype("int32x2")),
x.astype("int32x2")),
- tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
+ tvm.tir.max(tvm.tir.max(vx, y.astype("int32x2")),
x.astype("int32x2")),
+ tvm.tir.max(vx, tvm.tir.max(y, x).astype("int32x2")),
),
## Logical rules
TestCase(y.astype("int32x2").equal(x.astype("int32x2")),
(y.equal(x)).astype("boolx2")),
@@ -306,7 +308,7 @@ class TestVector(BaseCompare):
class TestSelect(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# Add rules
@@ -321,12 +323,12 @@ class TestSelect(BaseCompare):
TestCase(tvm.tir.Select(x < 0, y, z) - y, tvm.tir.Select(x < 0, 0, z -
y)),
TestCase(tvm.tir.Select(x < 0, y, z) - z, tvm.tir.Select(x < 0, y - z,
0)),
TestCase(
- tvm.te.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1,
z)),
- tvm.tir.Select(x < 0, tvm.te.min(y, 1), tvm.te.min(0, z)),
+ tvm.tir.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1,
z)),
+ tvm.tir.Select(x < 0, tvm.tir.min(y, 1), tvm.tir.min(0, z)),
),
TestCase(
- tvm.te.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1,
z)),
- tvm.tir.Select(x < 0, tvm.te.max(y, 1), tvm.te.max(0, z)),
+ tvm.tir.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1,
z)),
+ tvm.tir.Select(x < 0, tvm.tir.max(y, 1), tvm.tir.max(0, z)),
),
TestCase(tvm.tir.Select(x * 3 + 1 != 0, y, z), y),
TestCase(tvm.tir.Select(x * 3 + 1 == 0, y, z), z),
@@ -371,31 +373,31 @@ class TestCancellation(BaseCompare):
class TestAddIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase(x + (y - x), y),
TestCase(x - (y + 1) + (y + 1), x),
TestCase((x - 10) + (10 - z), x - z),
TestCase((x - y) + (z - x), z - y),
- TestCase(tvm.te.min(x, y - z) + z, tvm.te.min(x + z, y)),
- TestCase(tvm.te.min(x - z, y) + z, tvm.te.min(x, y + z)),
- TestCase(tvm.te.max(x, y - 10) + 10, tvm.te.max(x + 10, y)),
- TestCase(tvm.te.max(x - 11, y) + 11, tvm.te.max(x, y + 11)),
- TestCase(tvm.te.max(x, y * 2) + tvm.te.min(x, y * 2), x + y * 2),
- TestCase(tvm.te.min(x, y * 2) + tvm.te.max(x, y * 2), x + y * 2),
- TestCase(tvm.te.max(x, y + 2) + (-2), tvm.te.max(x + (-2), y)),
- TestCase(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y)),
- TestCase(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1)),
- TestCase(tvm.te.max(0, 1 - x * 4) + x * 4, tvm.te.max(x * 4, 1)),
- TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)),
- TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)),
- TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)),
+ TestCase(tvm.tir.min(x, y - z) + z, tvm.tir.min(x + z, y)),
+ TestCase(tvm.tir.min(x - z, y) + z, tvm.tir.min(x, y + z)),
+ TestCase(tvm.tir.max(x, y - 10) + 10, tvm.tir.max(x + 10, y)),
+ TestCase(tvm.tir.max(x - 11, y) + 11, tvm.tir.max(x, y + 11)),
+ TestCase(tvm.tir.max(x, y * 2) + tvm.tir.min(x, y * 2), x + y * 2),
+ TestCase(tvm.tir.min(x, y * 2) + tvm.tir.max(x, y * 2), x + y * 2),
+ TestCase(tvm.tir.max(x, y + 2) + (-2), tvm.tir.max(x + (-2), y)),
+ TestCase(tvm.tir.min(x, y + 2) + (-2), tvm.tir.min(x + (-2), y)),
+ TestCase(tvm.tir.min(x + 2, y + 3) + (-2), tvm.tir.min(x, y + 1)),
+ TestCase(tvm.tir.max(0, 1 - x * 4) + x * 4, tvm.tir.max(x * 4, 1)),
+ TestCase(tvm.tir.max(2 - x * 4, 0) + x * 4, tvm.tir.max(x * 4, 2)),
+ TestCase(tvm.tir.min(0, 1 - x * 4) + x * 4, tvm.tir.min(x * 4, 1)),
+ TestCase(tvm.tir.min(2 - x * 4, 0) + x * 4, tvm.tir.min(x * 4, 2)),
TestCase(x * y + x * 10, (y + 10) * x),
TestCase(y * x + x * 10, (y + 10) * x),
TestCase(y * x + 10 * x, (y + 10) * x),
TestCase(x * y + 10 * x, (y + 10) * x),
- TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2),
y)),
+ TestCase((2 * z) + tvm.tir.min(x, y - (2 * z)), tvm.tir.min(x + (z *
2), y)),
TestCase(y * x + x, (y + 1) * x),
TestCase(x * y + x, (y + 1) * x),
TestCase((x + 10) + 13, x + 23),
@@ -419,21 +421,21 @@ class TestAddIndex(BaseCompare):
class TestSubIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase(x + y - y, x),
TestCase(x + y - x, y),
TestCase(x - (y + x), 0 - y),
TestCase(x - (x + y), 0 - y),
- TestCase(tvm.te.min(x, y) - x, tvm.te.min(0, y - x)),
- TestCase(tvm.te.min(x, y) - y, tvm.te.min(x - y, 0)),
- TestCase(tvm.te.max(x, y) - x, tvm.te.max(0, y - x)),
- TestCase(tvm.te.max(x, y) - y, tvm.te.max(x - y, 0)),
- TestCase(x - tvm.te.min(x, y), tvm.te.max(0, x - y)),
- TestCase(y - tvm.te.min(x, y), tvm.te.max(y - x, 0)),
- TestCase(x - tvm.te.max(x, y), tvm.te.min(0, x - y)),
- TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)),
+ TestCase(tvm.tir.min(x, y) - x, tvm.tir.min(0, y - x)),
+ TestCase(tvm.tir.min(x, y) - y, tvm.tir.min(x - y, 0)),
+ TestCase(tvm.tir.max(x, y) - x, tvm.tir.max(0, y - x)),
+ TestCase(tvm.tir.max(x, y) - y, tvm.tir.max(x - y, 0)),
+ TestCase(x - tvm.tir.min(x, y), tvm.tir.max(0, x - y)),
+ TestCase(y - tvm.tir.min(x, y), tvm.tir.max(y - x, 0)),
+ TestCase(x - tvm.tir.max(x, y), tvm.tir.min(0, x - y)),
+ TestCase(y - tvm.tir.max(x, y), tvm.tir.min(y - x, 0)),
# mul co-efficient foldng
TestCase(x - x, 0),
TestCase(x * y - x, (y + (-1)) * x),
@@ -446,26 +448,26 @@ class TestSubIndex(BaseCompare):
TestCase((y + x) - (x + z), y - z),
TestCase((x + y) - (z + x), y - z),
TestCase((y + x) - (z + x), y - z),
- TestCase(tvm.te.min(x + y, z) - x, tvm.te.min(y, z - x)),
- TestCase(tvm.te.min(y + x, z) - x, tvm.te.min(y, z - x)),
- TestCase(tvm.te.min(z, x + y) - x, tvm.te.min(z - x, y)),
- TestCase(tvm.te.min(z, y + x) - x, tvm.te.min(z - x, y)),
- TestCase(tvm.te.max(x + y, z) - x, tvm.te.max(y, z - x)),
- TestCase(tvm.te.max(y + x, z) - x, tvm.te.max(y, z - x)),
- TestCase(tvm.te.max(z, x + y) - x, tvm.te.max(z - x, y)),
- TestCase(tvm.te.max(z, y + x) - x, tvm.te.max(z - x, y)),
- TestCase(x - tvm.te.min(x + y, z), tvm.te.max(0 - y, x - z)),
- TestCase(x - tvm.te.min(y + x, z), tvm.te.max(0 - y, x - z)),
- TestCase(x - tvm.te.min(z, x + y), tvm.te.max(x - z, 0 - y)),
- TestCase(x - tvm.te.min(z, y + x), tvm.te.max(x - z, 0 - y)),
- TestCase(tvm.te.min(x, y) - tvm.te.min(y, x), 0),
- TestCase(tvm.te.max(x, y) - tvm.te.max(y, x), 0),
- TestCase(tvm.te.min(x, y) - tvm.te.min(x + 10, y + 10), -10),
- TestCase(tvm.te.min(x + 10, y + 1) - tvm.te.min(x, y - 9), 10),
- TestCase(x - tvm.te.max(x + y, 0), tvm.te.min(0 - y, x)),
- TestCase(x - tvm.te.max(0, x + y), tvm.te.min(x, 0 - y)),
- TestCase(x - tvm.te.min(x + y, 0), tvm.te.max(0 - y, x)),
- TestCase(x - tvm.te.min(0, x + y), tvm.te.max(x, 0 - y)),
+ TestCase(tvm.tir.min(x + y, z) - x, tvm.tir.min(y, z - x)),
+ TestCase(tvm.tir.min(y + x, z) - x, tvm.tir.min(y, z - x)),
+ TestCase(tvm.tir.min(z, x + y) - x, tvm.tir.min(z - x, y)),
+ TestCase(tvm.tir.min(z, y + x) - x, tvm.tir.min(z - x, y)),
+ TestCase(tvm.tir.max(x + y, z) - x, tvm.tir.max(y, z - x)),
+ TestCase(tvm.tir.max(y + x, z) - x, tvm.tir.max(y, z - x)),
+ TestCase(tvm.tir.max(z, x + y) - x, tvm.tir.max(z - x, y)),
+ TestCase(tvm.tir.max(z, y + x) - x, tvm.tir.max(z - x, y)),
+ TestCase(x - tvm.tir.min(x + y, z), tvm.tir.max(0 - y, x - z)),
+ TestCase(x - tvm.tir.min(y + x, z), tvm.tir.max(0 - y, x - z)),
+ TestCase(x - tvm.tir.min(z, x + y), tvm.tir.max(x - z, 0 - y)),
+ TestCase(x - tvm.tir.min(z, y + x), tvm.tir.max(x - z, 0 - y)),
+ TestCase(tvm.tir.min(x, y) - tvm.tir.min(y, x), 0),
+ TestCase(tvm.tir.max(x, y) - tvm.tir.max(y, x), 0),
+ TestCase(tvm.tir.min(x, y) - tvm.tir.min(x + 10, y + 10), -10),
+ TestCase(tvm.tir.min(x + 10, y + 1) - tvm.tir.min(x, y - 9), 10),
+ TestCase(x - tvm.tir.max(x + y, 0), tvm.tir.min(0 - y, x)),
+ TestCase(x - tvm.tir.max(0, x + y), tvm.tir.min(x, 0 - y)),
+ TestCase(x - tvm.tir.min(x + y, 0), tvm.tir.max(0 - y, x)),
+ TestCase(x - tvm.tir.min(0, x + y), tvm.tir.max(x, 0 - y)),
# DivMod patterns
# truc div
TestCase(x - tdiv(x, 3) * 3, tmod(x, 3)),
@@ -514,18 +516,18 @@ class TestSubIndex(BaseCompare):
class TestMulIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase((x + 2) * 3, x * 3 + 6),
TestCase((x * 2) * 3, x * 6),
- TestCase(tvm.te.min(x, y) * tvm.te.max(x, y), x * y),
- TestCase(tvm.te.max(x, y) * tvm.te.min(x, y), x * y),
+ TestCase(tvm.tir.min(x, y) * tvm.tir.max(x, y), x * y),
+ TestCase(tvm.tir.max(x, y) * tvm.tir.min(x, y), x * y),
TestCase((x - y) * (-2), (y - x) * 2),
)
class TestDivIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
non_negative = [x >= 0, y >= 0, z >= 0]
test_case = tvm.testing.parameter(
@@ -535,11 +537,11 @@ class TestDivIndex(BaseCompare):
TestCase(tdiv(x * 2, 4), tdiv(x, 2)),
TestCase(tdiv(x * 4, 2), x * 2),
TestCase(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2), non_negative),
- TestCase(tdiv(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, tdiv(y, 2)),
non_negative),
- TestCase(tdiv(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, tdiv(y, 2)),
non_negative),
+ TestCase(tdiv(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, tdiv(y,
2)), non_negative),
+ TestCase(tdiv(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, tdiv(y,
2)), non_negative),
TestCase(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2, non_negative),
- TestCase(tdiv(tvm.te.min(y, x * 6), 2), tvm.te.min(tdiv(y, 2), x * 3),
non_negative),
- TestCase(tdiv(tvm.te.max(y, x * 6), 2), tvm.te.max(tdiv(y, 2), x * 3),
non_negative),
+ TestCase(tdiv(tvm.tir.min(y, x * 6), 2), tvm.tir.min(tdiv(y, 2), x *
3), non_negative),
+ TestCase(tdiv(tvm.tir.max(y, x * 6), 2), tvm.tir.max(tdiv(y, 2), x *
3), non_negative),
# 3-operands
TestCase(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2), non_negative),
TestCase(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1, non_negative),
@@ -562,7 +564,7 @@ class TestDivIndex(BaseCompare):
class TestFloordivIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase(fld(fld(x, 2), 3), fld(x, 6)),
@@ -581,11 +583,11 @@ class TestFloordivIndex(BaseCompare):
TestCase(fld(x * 360 + y, 25), x * 14, [x >= 0, x < 2, y >= 0, y < 7]),
TestCase(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)),
TestCase(fld(x * 4 + y, 2), x * 2 + fld(y, 2)),
- TestCase(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))),
- TestCase(fld(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, fld(y, 2))),
+ TestCase(fld(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, fld(y, 2))),
+ TestCase(fld(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, fld(y, 2))),
TestCase(fld(y + x * 4, 2), x * 2 + fld(y, 2)),
- TestCase(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)),
- TestCase(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)),
+ TestCase(fld(tvm.tir.min(y, x * 6), 2), tvm.tir.min(fld(y, 2), x * 3)),
+ TestCase(fld(tvm.tir.max(y, x * 6), 2), tvm.tir.max(fld(y, 2), x * 3)),
# 3-operands
#
# TODO(Lunderberg): Remove the necessity for the preconditions
@@ -615,7 +617,13 @@ class TestFloordivIndex(BaseCompare):
class TestModIndex(BaseCompare):
- x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"),
te.var("z")
+ x, y, nx, ny, z = (
+ tvm.tir.Var("x", "int32"),
+ tvm.tir.Var("y", "int32"),
+ tvm.tir.Var("nx", "int32"),
+ tvm.tir.Var("ny", "int32"),
+ tvm.tir.Var("z", "int32"),
+ )
test_case = tvm.testing.parameter(
# TODO(Lunderberg): Loosen these preconditions. When there's
@@ -647,7 +655,7 @@ class TestModIndex(BaseCompare):
class TestFloormodIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase(flm(x * 10, 2), 0),
@@ -685,7 +693,7 @@ class TestFloorModTwo(BaseCompare):
however during simplification
"""
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# Removing offsets from floormod
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
@@ -714,7 +722,7 @@ class TestFloorModPadded(BaseCompare):
such that (x - x % k) must be divisible by k
"""
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
TestCase(flm(x - flm(x, 9), 9), 0),
TestCase(flm(x - flm(x, -9), 9), 0),
@@ -727,159 +735,183 @@ class TestFloorModPadded(BaseCompare):
class TestMinIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# const int bound
- TestCase(tvm.te.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)),
- TestCase(tvm.te.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)),
- TestCase(tvm.te.min(x + 1, x + 10), x + 1),
- TestCase(tvm.te.min(x + 111, x + 10), x + 10),
- TestCase(tvm.te.min(x + 1, x), x),
- TestCase(tvm.te.min(x, x + 2), x),
- TestCase(tvm.te.min(1 - x, 2 - x), 1 - x),
- TestCase(tvm.te.min(3 - x, 2 - x), 2 - x),
- TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.min(x, y)), tvm.te.min(x,
y)),
- TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.min(y, x)), tvm.te.min(x,
y)),
- TestCase(tvm.te.min(tvm.te.max(x, y), x), x),
- TestCase(tvm.te.min(tvm.te.max(y, x), x), x),
- TestCase(tvm.te.min(tvm.te.min(x, y), x), tvm.te.min(x, y)),
- TestCase(tvm.te.min(tvm.te.min(x, y), y), tvm.te.min(x, y)),
- TestCase(tvm.te.min(x, tvm.te.max(x, y)), x),
- TestCase(tvm.te.min(x, tvm.te.max(y, x)), x),
- TestCase(tvm.te.min(x, tvm.te.min(x, y)), tvm.te.min(x, y)),
- TestCase(tvm.te.min(y, tvm.te.min(x, y)), tvm.te.min(x, y)),
- TestCase(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), y),
tvm.te.min(tvm.te.min(x, y), z)),
- TestCase(
- tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), y),
- tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2),
- ),
- TestCase(
- tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z),
x * 2), z * 2), y),
- tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z *
2),
- ),
- TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.max(x, z)),
tvm.te.max(tvm.te.min(y, z), x)),
- TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.max(z, x)),
tvm.te.max(tvm.te.min(y, z), x)),
- TestCase(tvm.te.min(tvm.te.max(y, x), tvm.te.max(x, z)),
tvm.te.max(tvm.te.min(y, z), x)),
- TestCase(tvm.te.min(tvm.te.max(y, x), tvm.te.max(z, x)),
tvm.te.max(tvm.te.min(y, z), x)),
- TestCase(tvm.te.min(y + x, z + x), tvm.te.min(y, z) + x),
- TestCase(tvm.te.min(y + x, x + z), tvm.te.min(y, z) + x),
- TestCase(tvm.te.min(x + y, z + x), tvm.te.min(y, z) + x),
- TestCase(tvm.te.min(x + y, x + z), tvm.te.min(y, z) + x),
- TestCase(tvm.te.min(x - y, x - z), x - tvm.te.max(y, z)),
- TestCase(tvm.te.min(y - x, z - x), tvm.te.min(y, z) - x),
- TestCase(tvm.te.min(tvm.te.min(x, 1), 10), tvm.te.min(x, 1)),
- TestCase(tvm.te.min(tvm.te.min(x, 11), 10), tvm.te.min(x, 10)),
- TestCase(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3),
- TestCase(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2),
- TestCase(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2),
- TestCase(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x, 1)),
- TestCase(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2),
- TestCase(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2),
- TestCase(tvm.te.min(x * (0), 4), 0),
- TestCase(tvm.te.min(x * (0), -4), -4),
+ TestCase(tvm.tir.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)),
+ TestCase(tvm.tir.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)),
+ TestCase(tvm.tir.min(x + 1, x + 10), x + 1),
+ TestCase(tvm.tir.min(x + 111, x + 10), x + 10),
+ TestCase(tvm.tir.min(x + 1, x), x),
+ TestCase(tvm.tir.min(x, x + 2), x),
+ TestCase(tvm.tir.min(1 - x, 2 - x), 1 - x),
+ TestCase(tvm.tir.min(3 - x, 2 - x), 2 - x),
+ TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(x, y)),
tvm.tir.min(x, y)),
+ TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(y, x)),
tvm.tir.min(x, y)),
+ TestCase(tvm.tir.min(tvm.tir.max(x, y), x), x),
+ TestCase(tvm.tir.min(tvm.tir.max(y, x), x), x),
+ TestCase(tvm.tir.min(tvm.tir.min(x, y), x), tvm.tir.min(x, y)),
+ TestCase(tvm.tir.min(tvm.tir.min(x, y), y), tvm.tir.min(x, y)),
+ TestCase(tvm.tir.min(x, tvm.tir.max(x, y)), x),
+ TestCase(tvm.tir.min(x, tvm.tir.max(y, x)), x),
+ TestCase(tvm.tir.min(x, tvm.tir.min(x, y)), tvm.tir.min(x, y)),
+ TestCase(tvm.tir.min(y, tvm.tir.min(x, y)), tvm.tir.min(x, y)),
+ TestCase(
+ tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), y),
tvm.tir.min(tvm.tir.min(x, y), z)
+ ),
+ TestCase(
+ tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2),
y),
+ tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2),
+ ),
+ TestCase(
+ tvm.tir.min(
+ tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x *
2), z * 2), y
+ ),
+ tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2),
z * 2),
+ ),
+ TestCase(
+ tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(x, z)),
tvm.tir.max(tvm.tir.min(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(z, x)),
tvm.tir.max(tvm.tir.min(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(x, z)),
tvm.tir.max(tvm.tir.min(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(z, x)),
tvm.tir.max(tvm.tir.min(y, z), x)
+ ),
+ TestCase(tvm.tir.min(y + x, z + x), tvm.tir.min(y, z) + x),
+ TestCase(tvm.tir.min(y + x, x + z), tvm.tir.min(y, z) + x),
+ TestCase(tvm.tir.min(x + y, z + x), tvm.tir.min(y, z) + x),
+ TestCase(tvm.tir.min(x + y, x + z), tvm.tir.min(y, z) + x),
+ TestCase(tvm.tir.min(x - y, x - z), x - tvm.tir.max(y, z)),
+ TestCase(tvm.tir.min(y - x, z - x), tvm.tir.min(y, z) - x),
+ TestCase(tvm.tir.min(tvm.tir.min(x, 1), 10), tvm.tir.min(x, 1)),
+ TestCase(tvm.tir.min(tvm.tir.min(x, 11), 10), tvm.tir.min(x, 10)),
+ TestCase(tvm.tir.min(x * 3, 9), tvm.tir.min(x, 3) * 3),
+ TestCase(tvm.tir.min(x * 2, 0), tvm.tir.min(x, 0) * 2),
+ TestCase(tvm.tir.min(0 - x * 2, 0), tvm.tir.max(x, 0) * -2),
+ TestCase(tvm.tir.min(3 - x, 2), 3 - tvm.tir.max(x, 1)),
+ TestCase(tvm.tir.min(x * (-2), -4), tvm.tir.max(x, 2) * -2),
+ TestCase(tvm.tir.min(x * (-2), 4), tvm.tir.max(x, -2) * -2),
+ TestCase(tvm.tir.min(x * (0), 4), 0),
+ TestCase(tvm.tir.min(x * (0), -4), -4),
# DivMod rules
# truc div
- TestCase(tvm.te.min(tdiv(x + 3, 4) * 4, x), x),
- TestCase(tvm.te.min(x, tdiv(x + 3, 4) * 4), x),
- TestCase(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)),
tvm.te.max(x, 4), x > 0),
- TestCase(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4),
tvm.te.max(x, 4), x > 0),
- TestCase(tvm.te.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.min(x, y),
10)),
- TestCase(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))),
tdiv(tvm.te.max(x, y), (-10))),
+ TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, x), x),
+ TestCase(tvm.tir.min(x, tdiv(x + 3, 4) * 4), x),
+ TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, tvm.tir.max(x, 4)),
tvm.tir.max(x, 4), x > 0),
+ TestCase(tvm.tir.min(tvm.tir.max(x, 4), tdiv(x + 3, 4) * 4),
tvm.tir.max(x, 4), x > 0),
+ TestCase(tvm.tir.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.min(x,
y), 10)),
+ TestCase(tvm.tir.min(tdiv(x, (-10)), tdiv(y, (-10))),
tdiv(tvm.tir.max(x, y), (-10))),
# floor div
- TestCase(tvm.te.min(fld(x + 3, 4) * 4, x), x),
- TestCase(tvm.te.min(x, fld(x + 3, 4) * 4), x),
- TestCase(tvm.te.min(x, fld(x, 4) * 4), fld(x, 4) * 4),
- TestCase(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)),
tvm.te.max(x, 4), x > 0),
- TestCase(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4),
tvm.te.max(x, 4), x > 0),
- TestCase(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y),
10)),
- TestCase(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x,
y), (-10))),
+ TestCase(tvm.tir.min(fld(x + 3, 4) * 4, x), x),
+ TestCase(tvm.tir.min(x, fld(x + 3, 4) * 4), x),
+ TestCase(tvm.tir.min(x, fld(x, 4) * 4), fld(x, 4) * 4),
+ TestCase(tvm.tir.min(fld(x + 3, 4) * 4, tvm.tir.max(x, 4)),
tvm.tir.max(x, 4), x > 0),
+ TestCase(tvm.tir.min(tvm.tir.max(x, 4), fld(x + 3, 4) * 4),
tvm.tir.max(x, 4), x > 0),
+ TestCase(tvm.tir.min(fld(x, 10), fld(y, 10)), fld(tvm.tir.min(x, y),
10)),
+ TestCase(tvm.tir.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.max(x,
y), (-10))),
)
class TestMaxIndex(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# const int bound
- TestCase(tvm.te.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10),
- TestCase(tvm.te.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10),
- TestCase(tvm.te.max(x + 1, x + 10), x + 10),
- TestCase(tvm.te.max(x + 111, x + 10), x + 111),
- TestCase(tvm.te.max(x + 1, x), x + 1),
- TestCase(tvm.te.max(x, x + 2), x + 2),
- TestCase(tvm.te.max(1 - x, 2 - x), 2 - x),
- TestCase(tvm.te.max(3 - x, 2 - x), 3 - x),
- TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.max(x, y)), tvm.te.max(x,
y)),
- TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.max(y, x)), tvm.te.max(x,
y)),
- TestCase(tvm.te.max(tvm.te.min(x, y), x), x),
- TestCase(tvm.te.max(tvm.te.min(y, x), x), x),
- TestCase(tvm.te.max(tvm.te.max(x, y), x), tvm.te.max(x, y)),
- TestCase(tvm.te.max(tvm.te.max(x, y), y), tvm.te.max(x, y)),
- TestCase(tvm.te.max(x, tvm.te.min(x, y)), x),
- TestCase(tvm.te.max(x, tvm.te.min(y, x)), x),
- TestCase(tvm.te.max(x, tvm.te.max(x, y)), tvm.te.max(x, y)),
- TestCase(tvm.te.max(y, tvm.te.max(x, y)), tvm.te.max(x, y)),
- TestCase(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), y),
tvm.te.max(tvm.te.max(x, y), z)),
- TestCase(
- tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), y),
- tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2),
- ),
- TestCase(
- tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z),
x * 2), z * 2), y),
- tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z *
2),
- ),
- TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.min(x, z)),
tvm.te.min(tvm.te.max(y, z), x)),
- TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.min(z, x)),
tvm.te.min(tvm.te.max(y, z), x)),
- TestCase(tvm.te.max(tvm.te.min(y, x), tvm.te.min(x, z)),
tvm.te.min(tvm.te.max(y, z), x)),
- TestCase(tvm.te.max(tvm.te.min(y, x), tvm.te.min(z, x)),
tvm.te.min(tvm.te.max(y, z), x)),
- TestCase(tvm.te.max(y + x, z + x), tvm.te.max(y, z) + x),
- TestCase(tvm.te.max(y + x, x + z), tvm.te.max(y, z) + x),
- TestCase(tvm.te.max(x + y, z + x), tvm.te.max(y, z) + x),
- TestCase(tvm.te.max(x + y, x + z), tvm.te.max(y, z) + x),
- TestCase(tvm.te.max(x - y, x - z), x - tvm.te.min(y, z)),
- TestCase(tvm.te.max(y - x, z - x), tvm.te.max(y, z) - x),
- TestCase(tvm.te.max(tvm.te.max(x, 1), 10), tvm.te.max(x, 10)),
- TestCase(tvm.te.max(tvm.te.max(x, 11), 10), tvm.te.max(x, 11)),
- TestCase(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3),
- TestCase(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x, 2)),
- TestCase(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2),
- TestCase(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2),
- TestCase(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2),
- TestCase(tvm.te.max(x * (-2), 4), tvm.te.min(x, -2) * -2),
- TestCase(tvm.te.max(x * (0), 4), 4),
- TestCase(tvm.te.max(x * (0), -4), 0),
+ TestCase(tvm.tir.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10),
+ TestCase(tvm.tir.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10),
+ TestCase(tvm.tir.max(x + 1, x + 10), x + 10),
+ TestCase(tvm.tir.max(x + 111, x + 10), x + 111),
+ TestCase(tvm.tir.max(x + 1, x), x + 1),
+ TestCase(tvm.tir.max(x, x + 2), x + 2),
+ TestCase(tvm.tir.max(1 - x, 2 - x), 2 - x),
+ TestCase(tvm.tir.max(3 - x, 2 - x), 3 - x),
+ TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(x, y)),
tvm.tir.max(x, y)),
+ TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(y, x)),
tvm.tir.max(x, y)),
+ TestCase(tvm.tir.max(tvm.tir.min(x, y), x), x),
+ TestCase(tvm.tir.max(tvm.tir.min(y, x), x), x),
+ TestCase(tvm.tir.max(tvm.tir.max(x, y), x), tvm.tir.max(x, y)),
+ TestCase(tvm.tir.max(tvm.tir.max(x, y), y), tvm.tir.max(x, y)),
+ TestCase(tvm.tir.max(x, tvm.tir.min(x, y)), x),
+ TestCase(tvm.tir.max(x, tvm.tir.min(y, x)), x),
+ TestCase(tvm.tir.max(x, tvm.tir.max(x, y)), tvm.tir.max(x, y)),
+ TestCase(tvm.tir.max(y, tvm.tir.max(x, y)), tvm.tir.max(x, y)),
+ TestCase(
+ tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), y),
tvm.tir.max(tvm.tir.max(x, y), z)
+ ),
+ TestCase(
+ tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2),
y),
+ tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2),
+ ),
+ TestCase(
+ tvm.tir.max(
+ tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x *
2), z * 2), y
+ ),
+ tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2),
z * 2),
+ ),
+ TestCase(
+ tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(x, z)),
tvm.tir.min(tvm.tir.max(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(z, x)),
tvm.tir.min(tvm.tir.max(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(x, z)),
tvm.tir.min(tvm.tir.max(y, z), x)
+ ),
+ TestCase(
+ tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(z, x)),
tvm.tir.min(tvm.tir.max(y, z), x)
+ ),
+ TestCase(tvm.tir.max(y + x, z + x), tvm.tir.max(y, z) + x),
+ TestCase(tvm.tir.max(y + x, x + z), tvm.tir.max(y, z) + x),
+ TestCase(tvm.tir.max(x + y, z + x), tvm.tir.max(y, z) + x),
+ TestCase(tvm.tir.max(x + y, x + z), tvm.tir.max(y, z) + x),
+ TestCase(tvm.tir.max(x - y, x - z), x - tvm.tir.min(y, z)),
+ TestCase(tvm.tir.max(y - x, z - x), tvm.tir.max(y, z) - x),
+ TestCase(tvm.tir.max(tvm.tir.max(x, 1), 10), tvm.tir.max(x, 10)),
+ TestCase(tvm.tir.max(tvm.tir.max(x, 11), 10), tvm.tir.max(x, 11)),
+ TestCase(tvm.tir.max(x * 3, 9), tvm.tir.max(x, 3) * 3),
+ TestCase(tvm.tir.max(3 - x, 1), 3 - tvm.tir.min(x, 2)),
+ TestCase(tvm.tir.max(x * 2, 0), tvm.tir.max(x, 0) * 2),
+ TestCase(tvm.tir.max(0 - x * 2, 0), tvm.tir.min(x, 0) * -2),
+ TestCase(tvm.tir.max(x * (-2), -4), tvm.tir.min(x, 2) * -2),
+ TestCase(tvm.tir.max(x * (-2), 4), tvm.tir.min(x, -2) * -2),
+ TestCase(tvm.tir.max(x * (0), 4), 4),
+ TestCase(tvm.tir.max(x * (0), -4), 0),
# DivMod rules
# truc div
- TestCase(tvm.te.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.max(x, y),
10)),
- TestCase(tvm.te.max(tdiv(x, (-10)), tdiv(y, (-10))),
tdiv(tvm.te.min(x, y), (-10))),
- TestCase(tvm.te.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4),
+ TestCase(tvm.tir.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.max(x,
y), 10)),
+ TestCase(tvm.tir.max(tdiv(x, (-10)), tdiv(y, (-10))),
tdiv(tvm.tir.min(x, y), (-10))),
+ TestCase(tvm.tir.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4),
# floordiv
- TestCase(tvm.te.max(fld(x, 10), fld(y, 10)), fld(tvm.te.max(x, y),
10)),
- TestCase(tvm.te.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.min(x,
y), (-10))),
- TestCase(tvm.te.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4),
- TestCase(tvm.te.max(fld(x, 4) * 4, x), x),
- TestCase(tvm.te.max(x, fld(x, 4) * 4), x),
+ TestCase(tvm.tir.max(fld(x, 10), fld(y, 10)), fld(tvm.tir.max(x, y),
10)),
+ TestCase(tvm.tir.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.min(x,
y), (-10))),
+ TestCase(tvm.tir.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4),
+ TestCase(tvm.tir.max(fld(x, 4) * 4, x), x),
+ TestCase(tvm.tir.max(x, fld(x, 4) * 4), x),
)
class TestScalableIndex(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
test_case = tvm.testing.parameter(
# MinNode
- TestCase(tvm.te.min(x + tir.vscale() * 4, x), x),
- TestCase(tvm.te.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4),
- TestCase(tvm.te.min(x + tir.vscale() * 4, x + tir.vscale() * 8),
tir.vscale() * 4 + x),
- TestCase(tvm.te.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4),
x), x),
- TestCase(tvm.te.min(tir.vscale() * x, tir.vscale() * y), tir.vscale()
* x, x < y),
+ TestCase(tvm.tir.min(x + tir.vscale() * 4, x), x),
+ TestCase(tvm.tir.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4),
+ TestCase(tvm.tir.min(x + tir.vscale() * 4, x + tir.vscale() * 8),
tir.vscale() * 4 + x),
+ TestCase(tvm.tir.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4),
x), x),
+ TestCase(tvm.tir.min(tir.vscale() * x, tir.vscale() * y), tir.vscale()
* x, x < y),
# MaxNode
- TestCase(tvm.te.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4),
- TestCase(tvm.te.max(x - tir.vscale() * 4, x), x),
- TestCase(tvm.te.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x +
tir.vscale() * 4),
+ TestCase(tvm.tir.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4),
+ TestCase(tvm.tir.max(x - tir.vscale() * 4, x), x),
+ TestCase(tvm.tir.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x +
tir.vscale() * 4),
TestCase(
- tvm.te.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x),
+ tvm.tir.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x),
x + tir.vscale() * 4 - flm(4, tir.vscale() * 4),
),
- TestCase(tvm.te.max(tir.vscale() * x, tir.vscale() * y), tir.vscale()
* x, x > y),
+ TestCase(tvm.tir.max(tir.vscale() * x, tir.vscale() * y), tir.vscale()
* x, x > y),
# FloorDiv
TestCase(fld(x * tir.vscale() * 4 + y, tir.vscale() * 4), x + fld(y,
tir.vscale() * 4)),
TestCase(fld(x, tir.vscale() * 4), 0, [x >= 0, x < tir.vscale() * 4]),
@@ -894,7 +926,7 @@ class TestScalableIndex(BaseCompare):
class TestComparisons(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# const int bound
@@ -1066,10 +1098,10 @@ class TestComparisons(BaseCompare):
tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)),
tir.all(T.int32(-43) <= x, x < -40),
),
- TestCase(tvm.te.min(x, 11) < 10, x < 10),
- TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")),
- TestCase(tvm.te.max(8, x) > 10, tvm.tir.LT(10, x)),
- TestCase(x + 1 < tvm.te.max(8, x), x < 7),
+ TestCase(tvm.tir.min(x, 11) < 10, x < 10),
+ TestCase(tvm.tir.min(x, 8) < 10, tvm.tir.const(1, "bool")),
+ TestCase(tvm.tir.max(8, x) > 10, tvm.tir.LT(10, x)),
+ TestCase(x + 1 < tvm.tir.max(8, x), x < 7),
TestCase(x < 11, tvm.tir.const(1, "bool"), x <= 10),
TestCase(x <= 10, tvm.tir.const(1, "bool"), x <= 10),
TestCase(z <= 5, tvm.tir.const(1, "bool"), z <= 5),
@@ -1088,7 +1120,7 @@ class TestComparisons(BaseCompare):
class TestComparisonOfProductAndSum(BaseCompare):
extensions = tvm.arith.Extension.ComparisonOfProductAndSum
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
# Special inequality cases
@@ -1119,7 +1151,7 @@ class TestComparisonOfProductAndSum(BaseCompare):
class TestLogical(BaseCompare):
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
test_case = tvm.testing.parameter(
TestCase(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
tvm.tir.const(False, "bool")),
@@ -1164,7 +1196,7 @@ class TestLogical(BaseCompare):
class TestLet(BaseCompare):
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
z = tvm.tir.Let(x, 1, x + 1)
test_case = tvm.testing.parameter(
@@ -1174,7 +1206,7 @@ class TestLet(BaseCompare):
class TestCast(BaseCompare):
def _generate_tests():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
dtypes = ["float32", "float16", "int32", "int8", "bool"]
for dtype1 in dtypes:
yield TestCase(tvm.tir.Cast(dtype1, x - x), tvm.tir.const(0,
dtype1))
@@ -1218,7 +1250,7 @@ class TestSubBufferload(BaseCompare):
class TestIfThenElse(BaseCompare):
- x = te.var("x", "int32")
+ x = tvm.tir.Var("x", "int32")
test_case = tvm.testing.parameter(
TestCase(
diff --git a/tests/python/arith/test_arith_solve_linear_equations.py
b/tests/python/arith/test_arith_solve_linear_equations.py
index 61c107915e..ddb0b36db8 100644
--- a/tests/python/arith/test_arith_solve_linear_equations.py
+++ b/tests/python/arith/test_arith_solve_linear_equations.py
@@ -18,7 +18,7 @@ import random
import sys
import pytest
import tvm
-from tvm import te, arith, ir, tir, testing
+from tvm import arith, ir, tir, testing
from tvm.script import tir as T
@@ -31,7 +31,7 @@ def test_solution_consistency():
random.seed(seed)
def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
- variables = [te.var("x" + str(i)) for i in range(num_vars)]
+ variables = [tvm.tir.Var("x" + str(i), "int32") for i in
range(num_vars)]
relations = []
for i in range(num_formulas):
@@ -85,7 +85,7 @@ def test_solution_consistency():
def test_empty_var_to_solve():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
equations = [
tvm.tir.EQ(x + y, 20),
tvm.tir.EQ(x - y, 10),
@@ -100,7 +100,7 @@ def test_empty_var_to_solve():
def test_unique_solution():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
solution = arith.solve_linear_equations(
[
@@ -115,7 +115,7 @@ def test_unique_solution():
def test_low_rank():
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
ranges = {}
solution = arith.solve_linear_equations(
@@ -133,7 +133,7 @@ def test_low_rank():
def test_infer_range():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
ranges = {
x: tvm.ir.Range.from_min_extent(-5, 10),
y: tvm.ir.Range.from_min_extent(0, 10),
@@ -160,7 +160,7 @@ def test_infer_range():
def test_ill_formed():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
solution = arith.solve_linear_equations(
[
diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py
b/tests/python/arith/test_arith_solve_linear_inequality.py
index 192c46b56c..17557fd952 100644
--- a/tests/python/arith/test_arith_solve_linear_inequality.py
+++ b/tests/python/arith/test_arith_solve_linear_inequality.py
@@ -18,7 +18,7 @@ import random
import sys
import pytest
import tvm
-from tvm import te, arith, ir, tir, testing
+from tvm import arith, ir, tir, testing
from tvm.script import tir as T
@@ -32,7 +32,7 @@ def test_solution_consistency():
random.seed(seed)
def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):
- vs = [te.var("x" + str(i)) for i in range(variables)]
+ vs = [tvm.tir.Var("x" + str(i), "int32") for i in range(variables)]
fs = []
for i in range(formulas):
@@ -44,9 +44,9 @@ def test_solution_consistency():
fs.append(op(s1, s2))
vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs}
- before = te.all(tir.const(1, "bool"), *fs)
+ before = tvm.tir.all(tir.const(1, "bool"), *fs)
after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs)
- after = te.all(tir.const(1, "bool"), *after)
+ after = tvm.tir.all(tir.const(1, "bool"), *after)
testing.check_bool_expr_is_true(before == after, vranges)
solution = arith.solve_linear_inequalities(fs, vs, vranges,
deskew_range=True)
@@ -81,7 +81,7 @@ def test_solution_consistency():
def test_dual_variable():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
variables = [x, y]
ranges = {
@@ -125,7 +125,7 @@ def test_dual_variable():
def test_equal():
- x, y = te.var("x"), te.var("y")
+ x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
problem = [
tvm.tir.GE(x + y, 10),
tvm.tir.GE(x - y, 2),
@@ -147,7 +147,7 @@ def test_equal():
def test_multi_equal():
- x, y, z = te.var("x"), te.var("y"), te.var("z")
+ x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"),
tvm.tir.Var("z", "int32")
problem = [
tvm.tir.LE(x, 6),
tvm.tir.GE(x, 6),
@@ -179,7 +179,7 @@ def test_multi_equal():
def test_no_solution():
- x = te.var("x0")
+ x = tvm.tir.Var("x0", "int32")
vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)}
problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5]
@@ -198,8 +198,8 @@ def test_no_solution():
def test_unbound_var_range():
- x = te.var("x0")
- free_var = te.var("fv")
+ x = tvm.tir.Var("x0", "int32")
+ free_var = tvm.tir.Var("fv", "int32")
vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 +
tvm.tir.log(free_var)))}
problem = [x > 3]
solution = arith.solve_linear_inequalities(
diff --git a/tests/python/ir/test_ir_container.py
b/tests/python/ir/test_ir_container.py
index cd47766159..af820acb0a 100644
--- a/tests/python/ir/test_ir_container.py
+++ b/tests/python/ir/test_ir_container.py
@@ -17,7 +17,7 @@
import pytest
import tvm_ffi
import tvm
-from tvm import te
+
import numpy as np
@@ -45,8 +45,8 @@ def test_dir_array():
def test_map():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
amap = tvm.runtime.convert({a: 2, b: 3})
assert a in amap
assert len(amap) == 2
@@ -70,8 +70,8 @@ def test_str_map():
def test_map_save_load_json():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
amap = tvm.runtime.convert({a: 2, b: 3})
json_str = tvm.ir.save_json(amap)
amap = tvm.ir.load_json(json_str)
@@ -81,15 +81,15 @@ def test_map_save_load_json():
def test_dir_map():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
amap = tvm.runtime.convert({a: 2, b: 3})
assert dir(amap)
def test_getattr_map():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
amap = tvm.runtime.convert({a: 2, b: 3})
assert isinstance(amap, tvm_ffi.Map)
diff --git a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
b/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
index c3ae417dcd..5ed8314d1a 100644
--- a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
+++ b/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
@@ -15,18 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
def test_equal_expr():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
def func1():
return x + y + 1
def func2():
- return te.exp(tvm.tir.truncdiv((x + y + 1) * y, 4))
+ return tvm.tir.exp(tvm.tir.truncdiv((x + y + 1) * y, 4))
assert tvm.tir.analysis.expr_deep_equal(func1(), func1())
assert tvm.tir.analysis.expr_deep_equal(func2(), func2())
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
index a6db37f5a2..6d559a81d2 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
@@ -15,12 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
def test_verify_ssa():
- x = te.var("x")
- y = te.var()
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("tindex", "int32")
z = tvm.tir.Evaluate(x + y)
assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z))
@@ -28,7 +27,7 @@ def test_verify_ssa():
def test_verify_weak_let_ssa():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
z1 = tvm.tir.Let(x, 1, x + 1)
z2 = tvm.tir.Let(x, 2, x + 2)
diff --git a/tests/python/tir-base/test_tir_buffer.py
b/tests/python/tir-base/test_tir_buffer.py
index 791de76995..df60d43061 100644
--- a/tests/python/tir-base/test_tir_buffer.py
+++ b/tests/python/tir-base/test_tir_buffer.py
@@ -17,7 +17,6 @@
import tvm
import tvm.testing
-from tvm import te
from tvm.tir import Buffer
from tvm.script import tir as T
@@ -26,9 +25,9 @@ import pytest
def test_buffer():
- m = te.size_var("m")
- n = te.size_var("n")
- l = te.size_var("l")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
+ l = tvm.tir.SizeVar("l", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32")
Bb = tvm.tir.decl_buffer((n, l), "float32")
@@ -38,8 +37,8 @@ def test_buffer():
def test_buffer_access_ptr():
- m = te.size_var("m")
- n = te.size_var("n")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
aptr = Ab.access_ptr("rw")
tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m)
@@ -50,13 +49,13 @@ def test_buffer_access_ptr():
def test_buffer_access_ptr_offset():
- m = te.size_var("m")
- n = te.size_var("n")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32")
aptr = Ab.access_ptr("rw", offset=100)
tvm.testing.assert_prim_expr_equal(aptr.args[2], 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
- v = te.size_var("int32")
+ v = tvm.tir.SizeVar("int32", "int32")
aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
@@ -68,8 +67,8 @@ def test_buffer_access_ptr_offset():
def test_buffer_access_ptr_extent():
- m = te.size_var("m")
- n = te.size_var("n")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32")
aptr = Ab.access_ptr("rw")
tvm.ir.assert_structural_equal(aptr.args[3], m * n)
@@ -87,27 +86,27 @@ def test_buffer_access_ptr_extent():
def test_buffer_vload():
- m = te.size_var("m")
- n = te.size_var("n")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
load = Ab.vload([2, 3])
tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)])
def test_buffer_offset_of():
- m = te.size_var("m")
- n = te.size_var("n")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
offset = Ab.offset_of([2, 3])
tvm.ir.assert_structural_equal(offset, [n * 2 + 103])
def test_buffer_index_merge_mult_mod():
- m = te.size_var("m")
- n = te.size_var("n")
- s = te.size_var("s")
- k0 = te.size_var("k0")
- k1 = te.size_var("k1")
+ m = tvm.tir.SizeVar("m", "int32")
+ n = tvm.tir.SizeVar("n", "int32")
+ s = tvm.tir.SizeVar("s", "int32")
+ k0 = tvm.tir.SizeVar("k0", "int32")
+ k1 = tvm.tir.SizeVar("k1", "int32")
A = tvm.tir.decl_buffer((m, n), "float32")
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
@@ -153,9 +152,9 @@ def test_buffer_index_merge_mult_mod():
# Test Case5
B = tvm.tir.decl_buffer((1, 14, 14, 1024))
- i = te.size_var("i")
- j = te.size_var("j")
- k = te.size_var("k")
+ i = tvm.tir.SizeVar("i", "int32")
+ j = tvm.tir.SizeVar("j", "int32")
+ k = tvm.tir.SizeVar("k", "int32")
index_simplified1 = B.offset_of(
(
diff --git a/tests/python/tir-base/test_tir_constructor.py
b/tests/python/tir-base/test_tir_constructor.py
index 4076070557..25f62c7096 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -17,7 +17,6 @@
import pytest
import tvm
-from tvm import te
def test_expr_constructor():
@@ -50,7 +49,7 @@ def test_expr_constructor():
assert x.value.value == 1
a = tvm.tir.const(1.0, dtype="float32")
- b = te.var("x", dtype="float32")
+ b = tvm.tir.Var("x", "float32")
for cls in [
tvm.tir.Add,
@@ -70,8 +69,8 @@ def test_expr_constructor():
assert x.a == a
assert x.b.same_as(b)
- a = tvm.runtime.convert(te.var("x") > 1)
- b = tvm.runtime.convert(te.var("x") == 1)
+ a = tvm.runtime.convert(tvm.tir.Var("x", "int32") > 1)
+ b = tvm.runtime.convert(tvm.tir.Var("x", "int32") == 1)
for cls in [tvm.tir.And, tvm.tir.Or]:
x = cls(a, b)
@@ -120,7 +119,7 @@ def test_expr_constructor():
assert x.op.name == "tir.call_extern"
assert x.args[1] == a
- v = te.var("aa")
+ v = tvm.tir.Var("aa", "int32")
x = tvm.tir.Let(v, 1, v)
assert x.var == v
assert x.value.value == 1
@@ -128,7 +127,7 @@ def test_expr_constructor():
def test_stmt_constructor():
- v = te.var("aa")
+ v = tvm.tir.Var("aa", "int32")
nop = tvm.tir.Evaluate(1)
x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
assert isinstance(x, tvm.tir.LetStmt)
@@ -144,7 +143,7 @@ def test_stmt_constructor():
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
- x = tvm.tir.For(te.var("x"), 0, 10, tvm.tir.ForKind.SERIAL, nop)
+ x = tvm.tir.For(tvm.tir.Var("x", "int32"), 0, 10, tvm.tir.ForKind.SERIAL,
nop)
assert isinstance(x, tvm.tir.For)
assert x.min.value == 0
assert x.extent.value == 10
diff --git a/tests/python/tir-base/test_tir_intrin.py
b/tests/python/tir-base/test_tir_intrin.py
index 9f199a2b63..f85c06c5b2 100644
--- a/tests/python/tir-base/test_tir_intrin.py
+++ b/tests/python/tir-base/test_tir_intrin.py
@@ -55,7 +55,7 @@ def test_nearbyint():
def test_round_intrinsics_on_int():
- i = tvm.te.var("i", "int32")
+ i = tvm.tir.Var("i", "int32")
for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil, tvm.tir.floor,
tvm.tir.nearbyint]:
assert op(tvm.tir.const(10, "int32")).value == 10
assert op(tvm.tir.const(True, "bool")).value == True
diff --git a/tests/python/tir-base/test_tir_nodes.py
b/tests/python/tir-base/test_tir_nodes.py
index b0bd5ac891..7f857a3203 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -17,7 +17,7 @@
import numpy as np
import pytest
import tvm
-from tvm import ir, te
+from tvm import ir
def test_const():
@@ -27,7 +27,7 @@ def test_const():
def test_te_const():
- x = tvm.te.const(1, "int32")
+ x = tvm.tir.const(1, "int32")
assert x.dtype == "int32"
assert isinstance(x, tvm.tir.IntImm)
@@ -57,7 +57,7 @@ def test_tir_const_dtype_inference():
def test_make():
x = tvm.tir.const(1, "int32")
- y = te.var("x")
+ y = tvm.tir.Var("x", "int32")
z = x + y
assert isinstance(tvm.tir.max(x, y), tvm.tir.Max)
assert isinstance(tvm.tir.min(x, y), tvm.tir.Min)
@@ -72,12 +72,12 @@ def test_ir():
def test_ir2():
- buf_size = te.var("size")
- x = te.var("n")
+ buf_size = tvm.tir.Var("size", "int32")
+ x = tvm.tir.Var("n", "int32")
storage_type = ir.PrimType("int32")
handle_type = ir.PointerType(storage_type)
- array = te.var("array", handle_type)
+ array = tvm.tir.Var("array", handle_type)
buf = tvm.tir.decl_buffer([buf_size], "int32", data=array)
st = tvm.tir.BufferStore(buf, x + 1, [1])
@@ -87,13 +87,13 @@ def test_ir2():
def test_let():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
def test_cast():
- x = te.var("x", dtype="float32")
+ x = tvm.tir.Var("x", "float32")
y = x.astype("int32")
z = x.astype("float32x4")
assert isinstance(y, tvm.tir.Cast)
@@ -110,8 +110,8 @@ def test_cast():
def test_attr():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
stmt = tvm.tir.AttrStmt(y, "stride", 10, tvm.tir.Evaluate(x + 1))
assert stmt.node == y
@@ -125,34 +125,34 @@ def test_attr():
def test_basic():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
c = a + b
assert str(c) == "%s + %s" % (a.name, b.name)
def test_stmt():
x = tvm.tir.Evaluate(0)
- tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x)
- tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2)
+ tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.SERIAL, x)
+ tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.UNROLLED, x,
step=2)
def test_dir():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
dir(x)
def test_dtype():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
assert x.dtype == "int32"
- y = te.var("y")
+ y = tvm.tir.Var("y", "int32")
assert (x > y).dtype == "bool"
def test_any():
- x = te.var("x")
- y = te.var("y")
- z = te.var("z")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
try:
t = x or x
assert False
@@ -183,9 +183,9 @@ def test_any():
def test_all():
- x = te.var("x")
- y = te.var("y")
- z = te.var("z")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
try:
t = x and x
assert False
@@ -216,8 +216,8 @@ def test_all():
def test_bitwise():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
assert str(x << y) == "T.shift_left(x, y)"
assert str(x >> y) == "T.shift_right(x, y)"
assert str(x & y) == "T.bitwise_and(x, y)"
@@ -233,7 +233,7 @@ def test_bitwise():
assert str(~x) == "T.bitwise_not(x)"
assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
- assert (te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype ==
"int8x2"
+ assert (tvm.tir.Var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype ==
"int8x2"
def test_float_bitwise():
@@ -258,7 +258,7 @@ def test_float_bitwise():
def test_shift_bounds():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
for test in [lambda lhs, rhs: lhs << rhs, lambda lhs, rhs: lhs >> rhs]:
# negative case
for testcase in [(x, -1), (x, 32)]:
@@ -295,20 +295,20 @@ def test_infinity():
def test_isnan():
- x = te.var("x", "float32")
+ x = tvm.tir.Var("x", "float32")
assert str(tvm.tir.isnan(x)) == "T.isnan(x)"
assert str(tvm.tir.isnan(x).dtype) == "bool"
- y = te.var("y", "float16")
+ y = tvm.tir.Var("y", "float16")
assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))'
- z = te.var("z", "int32")
+ z = tvm.tir.Var("z", "int32")
assert str(tvm.tir.isnan(z)) == "T.bool(False)"
- k = te.var("k", "int8x2")
+ k = tvm.tir.Var("k", "int8x2")
assert str(tvm.tir.isnan(k).dtype) == "boolx2"
def test_equality():
- a = te.var("a")
- b = te.var("b")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
c = a == b
assert not c
d = c != c
@@ -323,8 +323,8 @@ def test_equality_string_imm():
def test_prim_func():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
@@ -390,7 +390,7 @@ def _create_broadcast(lanes):
return tvm.tir.Broadcast(0, lanes)
-@pytest.mark.parametrize("lanes", [(tvm.tir.IntImm(dtype="int64", value=11))])
+@pytest.mark.parametrize("lanes", [tvm.tir.IntImm(dtype="int64", value=11)])
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
def test_lane_types(lanes, node_func):
def _check_dtype(node):
diff --git a/tests/python/tir-base/test_tir_ops.py
b/tests/python/tir-base/test_tir_ops.py
index cb7d8c597a..6614f557f4 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -16,7 +16,6 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
import pytest
@@ -52,7 +51,7 @@ def test_const_fold():
def test_const_fold2():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
tmod = tvm.tir.truncmod
tdiv = tvm.tir.truncdiv
assert (x + 0).same_as(x)
@@ -66,7 +65,7 @@ def test_const_fold2():
def test_const_fold3():
# Test that using ints with logic operations is forbidden
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
for val in [0, 1]:
for func in [tvm.tir.all, tvm.tir.any]:
check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
@@ -84,7 +83,7 @@ def test_const_fold3():
tvm.tir.const(py_func(v1, v2), "bool"),
)
- x = te.var("x", "bool")
+ x = tvm.tir.Var("x", "bool")
true = tvm.tir.const(1, "bool")
false = tvm.tir.const(0, "bool")
@@ -108,11 +107,11 @@ def test_const_fold4():
assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
x4 = x3 + 0.55
assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
- x5 = te.ceil(x4)
+ x5 = tvm.tir.ceil(x4)
assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
x6 = x5.astype("int")
assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
- y = (te.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) + 2).astype("int")
+ y = (tvm.tir.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) +
2).astype("int")
assert isinstance(y, tvm.tir.IntImm) and y.value == 6
@@ -126,8 +125,8 @@ def test_binary_dtype_match():
[("uint32", "int32"), "uint32"],
]
for (lhs_dtype, rhs_dtype), out_dtype in rules:
- lhs = te.var("lhs", dtype=lhs_dtype)
- rhs = te.var("rhs", dtype=rhs_dtype)
+ lhs = tvm.tir.Var("lhs", lhs_dtype)
+ rhs = tvm.tir.Var("rhs", rhs_dtype)
out = f(lhs, rhs)
if not is_conditional:
assert out.dtype == out_dtype
@@ -146,8 +145,8 @@ def test_binary_dtype_match():
def verify_callop_float_only(f):
for lhs_dtype in ["int32", "float32", "float64"]:
for rhs_dtype in ["int32", "float32", "float64"]:
- lhs = te.var("lhs", dtype=lhs_dtype)
- rhs = te.var("rhs", dtype=rhs_dtype)
+ lhs = tvm.tir.Var("lhs", lhs_dtype)
+ rhs = tvm.tir.Var("rhs", rhs_dtype)
if "float" not in lhs_dtype and "float" not in rhs_dtype:
check_throws(lambda: f(lhs, rhs))
elif "float" in lhs_dtype:
@@ -176,7 +175,7 @@ def test_binary_dtype_match():
verify_general_dtype_support(lambda a, b: a * b)
verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True)
verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
- verify_callop_float_only(lambda a, b: te.power(a, b))
+ verify_callop_float_only(lambda a, b: tvm.tir.power(a, b))
# verify bool & int32 constant folding
assert tvm.tir.const(1) == tvm.tir.const(True)
@@ -185,15 +184,15 @@ def test_binary_dtype_match():
def test_if_then_else():
cases = [
- [(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
+ [(tvm.tir.Var("cond", "bool"), "bool", "int32"), "int32"],
[(True, "int32", "float32"), "float32"],
[(False, "int32", "int64"), "int64"],
- [(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"],
- [(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"],
+ [(tvm.tir.Var("cond", "bool"), "uint32", "int32"), "uint32"],
+ [(tvm.tir.Var("cond", "int32"), "uint32", "int32"), "uint32"],
]
for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
- lhs = te.var("lhs", dtype=lhs_dtype)
- rhs = te.var("rhs", dtype=rhs_dtype)
+ lhs = tvm.tir.Var("lhs", lhs_dtype)
+ rhs = tvm.tir.Var("rhs", rhs_dtype)
if cond is True or cond is False:
out = tvm.tir.if_then_else(cond, lhs, rhs)
out2 = tvm.tir.if_then_else(not cond, rhs, lhs)
diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py
b/tests/python/tir-base/test_tir_structural_equal_hash.py
index 296450b9a2..e7113acdad 100644
--- a/tests/python/tir-base/test_tir_structural_equal_hash.py
+++ b/tests/python/tir-base/test_tir_structural_equal_hash.py
@@ -17,7 +17,7 @@
import tvm
import numpy as np
import pytest
-from tvm import te
+
from tvm_ffi.access_path import AccessPath
from tvm.script import tir as T, ir as I
@@ -73,9 +73,9 @@ def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
y = tvm.tir.const(10, "int32")
- vx = te.var("x")
- vy = te.var("y")
- vz = te.var("z")
+ vx = tvm.tir.Var("x", "int32")
+ vy = tvm.tir.Var("y", "int32")
+ vz = tvm.tir.Var("z", "int32")
zx = vx + vx
zy = vy + vy
@@ -105,8 +105,8 @@ def test_exprs():
def test_prim_func():
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
# counter example of same equality
func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y))
func1 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(y + x))
@@ -132,9 +132,9 @@ def test_prim_func():
def test_prim_func_param_count_mismatch():
- x = te.var("x")
- y = te.var("y")
- z = te.var("z")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
# counter example of same equality
func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))
@@ -146,9 +146,9 @@ def test_prim_func_param_count_mismatch():
def test_prim_func_param_dtype_mismatch():
- x = te.var("x")
- y_0 = te.var("y", dtype="int32")
- y_1 = te.var("z", dtype="float32")
+ x = tvm.tir.Var("x", "int32")
+ y_0 = tvm.tir.Var("y", "int32")
+ y_1 = tvm.tir.Var("z", "float32")
# counter example of same equality
func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x))
func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x))
@@ -159,10 +159,10 @@ def test_prim_func_param_dtype_mismatch():
def test_prim_func_body_mismatch():
- x_0 = te.var("x")
- y_0 = te.var("y")
- x_1 = te.var("x")
- y_1 = te.var("y")
+ x_0 = tvm.tir.Var("x", "int32")
+ y_0 = tvm.tir.Var("y", "int32")
+ x_1 = tvm.tir.Var("x", "int32")
+ y_1 = tvm.tir.Var("y", "int32")
# counter example of same equality
func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0))
func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1))
@@ -206,16 +206,6 @@ def test_attrs():
def test_stmt():
- x = te.var("x")
- y = te.var("y")
- n = 128
- A = te.placeholder((n, n), name="A")
- B = te.placeholder((n, n), name="B")
- ii = te.var("i")
- jj = te.var("j")
-
- n = te.var("n")
-
@T.prim_func(private=True, check_well_formed=False)
def func2(A: T.handle, n_param: T.int32):
n_var = T.var("int32")
@@ -230,7 +220,7 @@ def test_stmt():
def test_buffer_storage_scope():
- x = te.var("x", dtype="handle")
+ x = tvm.tir.Var("x", "handle")
buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
@@ -248,7 +238,7 @@ def test_buffer_storage_scope():
def test_buffer_map_mismatch():
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
buffer_0 = tvm.tir.decl_buffer((10, 10))
buffer_0_clone = tvm.tir.decl_buffer((10, 10))
buffer_1 = tvm.tir.decl_buffer((10, 20))
@@ -268,8 +258,8 @@ def test_buffer_map_mismatch():
def test_buffer_map_length_mismatch():
- x = te.var("x")
- y = te.var("x")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("x", "int32")
buffer_0 = tvm.tir.decl_buffer((10, 10))
buffer_1 = tvm.tir.decl_buffer((10, 20))
diff --git
a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
index 1be5e57ba1..8b8e187341 100644
--- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
@@ -17,7 +17,6 @@
import hashlib
import tvm
-from tvm import te, topi
from tvm.ir.base import save_json
from tvm.ir.module import IRModule
from tvm.script import tir as T
@@ -29,15 +28,15 @@ from tvm.script import tir as T
# A test program which gives the opportunity for the CSE pass to introduce two
new variables,
# at two different levels
def test_cse():
- z1 = te.var("z1")
- z2 = te.var("z2")
- z3 = te.var("z3")
- i1 = te.var("i1")
- i2 = te.var("i2")
- x = te.var("x")
- y = te.var("y")
- a = te.var("a")
- b = te.var("b")
+ z1 = tvm.tir.Var("z1", "int32")
+ z2 = tvm.tir.Var("z2", "int32")
+ z3 = tvm.tir.Var("z3", "int32")
+ i1 = tvm.tir.Var("i1", "int32")
+ i2 = tvm.tir.Var("i2", "int32")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ a = tvm.tir.Var("a", "int32")
+ b = tvm.tir.Var("b", "int32")
dtype = "int32"
buffer = tvm.tir.decl_buffer((50,), dtype)
# Test prog :
@@ -152,12 +151,12 @@ def test_cse():
# branch, not before the whole If (otherwise that would lead to some
computations being computed
# for nothing when it is the Else branch that is executed).
def test_cse_ifNode_1():
- b = te.var("b")
- i1 = te.var("i1")
- i2 = te.var("i2")
- i3 = te.var("i3")
- y = te.var("y")
- z = te.var("z")
+ b = tvm.tir.Var("b", "int32")
+ i1 = tvm.tir.Var("i1", "int32")
+ i2 = tvm.tir.Var("i2", "int32")
+ i3 = tvm.tir.Var("i3", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
dtype = "int32"
buffer = tvm.tir.decl_buffer((50,), dtype)
# Test prog :
@@ -208,12 +207,12 @@ def test_cse_ifNode_1():
# In this case, the CSE pass should introduce the redundant computation before
the whole If node,
# because regardless of the execution path, it is going to be computed.
def test_cse_ifNode_2():
- b = te.var("b")
- i1 = te.var("i1")
- i2 = te.var("i2")
- i3 = te.var("i3")
- y = te.var("y")
- z = te.var("z")
+ b = tvm.tir.Var("b", "int32")
+ i1 = tvm.tir.Var("i1", "int32")
+ i2 = tvm.tir.Var("i2", "int32")
+ i3 = tvm.tir.Var("i3", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
dtype = "int32"
buffer = tvm.tir.decl_buffer((50,), dtype)
# Test prog :
@@ -261,12 +260,12 @@ def test_cse_ifNode_2():
# and in the rest of the program.
#
-------------------------------------------------------------------------------------------------
def test_cse_cascade():
- i1 = te.var("i1")
- i2 = te.var("i2")
- i3 = te.var("i3")
- x = te.var("x")
- y = te.var("y")
- z = te.var("z")
+ i1 = tvm.tir.Var("i1", "int32")
+ i2 = tvm.tir.Var("i2", "int32")
+ i3 = tvm.tir.Var("i3", "int32")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
dtype = "int32"
buffer = tvm.tir.decl_buffer((50,), dtype)
# Test prog :
@@ -326,10 +325,10 @@ def test_cse_cascade():
# A test which ensures that we don't perform normalizations outside of
introduced variables
#
-----------------------------------------------------------------------------------------
def test_no_normalization_without_commoning():
- x = te.var("x")
- y = te.var("y")
- z = te.var("z")
- a = te.var("a")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
+ a = tvm.tir.Var("a", "int32")
# Test prog :
# let a = x + (y + z) in a
body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a))
@@ -418,8 +417,8 @@ def test_deterministic_cse():
NUM_TERMS = 10
REPEATS = 10
- x = te.var("x")
- result = te.var("result")
+ x = tvm.tir.Var("x", "int32")
+ result = tvm.tir.Var("result", "int32")
offsets = sorted([i + 1 for i in range(NUM_TERMS)])
inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)]
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index a0a6ab2508..32367e0aa0 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -16,7 +16,6 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
import numpy as np
@@ -45,18 +44,34 @@ def check_value(expr, variables, data, fref):
num_vars = len(variables)
assert num_vars >= 1 and all(len(row) == num_vars for row in data)
- placeholders = [
- te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in
range(num_vars)
+ # Build input and output buffers
+ input_bufs = [
+ tvm.tir.decl_buffer((n,), dtype=variables[i].dtype, name=f"v{i}") for
i in range(num_vars)
]
+ out_buf = tvm.tir.decl_buffer((n,), dtype=expr.dtype, name="C")
- def make_binds(i):
- x = expr
+ # Build loop body: for each i, bind variables[j] = input_bufs[j][i], then
store expr to out
+ loop_var = tvm.tir.Var("i", "int32")
+
+ def make_store(i_var):
+ # Build the expression with each variable bound to the corresponding
buffer load
+ result = expr
for j in range(num_vars - 1, -1, -1):
- x = tvm.tir.Let(variables[j], placeholders[j][i], x)
- return x
+ result = tvm.tir.Let(variables[j],
tvm.tir.BufferLoad(input_bufs[j], [i_var]), result)
+ return tvm.tir.BufferStore(out_buf, result, [i_var])
+
+ loop = tvm.tir.For(
+ loop_var,
+ tvm.tir.const(0, "int32"),
+ tvm.tir.const(n, "int32"),
+ tvm.tir.ForKind.SERIAL,
+ make_store(loop_var),
+ )
+
+ prim_func = tvm.tir.PrimFunc(input_bufs + [out_buf], loop)
+ prim_func = prim_func.with_attr({"tir.noalias": True, "global_symbol":
"main"})
+ f = tvm.compile(prim_func, "llvm")
- C = te.compute((n,), make_binds)
- f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm")
arrays = [
tvm.runtime.tensor(np.array([row[j] for row in data],
dtype=variables[j].dtype))
for j in range(num_vars)
@@ -81,32 +96,32 @@ def get_ref_data():
def test_lower_floordiv():
data = get_ref_data()
for dtype in ["int32", "int64", "int16"]:
- x = te.var("x", dtype=dtype)
- y = te.var("y", dtype=dtype)
+ x = tvm.tir.Var("x", dtype)
+ y = tvm.tir.Var("y", dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
- res = lower_intrin([x, y], tvm.te.floordiv(x, y))
+ res = lower_intrin([x, y], tvm.tir.floordiv(x, y))
check_value(res, [x, y], data, lambda a, b: a // b)
# rhs >= 0
- res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x,
y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floordiv(x,
y), zero))
check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0)
# involves max
res = lower_intrin(
- [x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y),
zero), zero)
+ [x, y], tvm.tir.Select(y >= 0, tvm.tir.max(tvm.tir.floordiv(x, y),
zero), zero)
)
check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0
else 0)
# lhs >= 0
res = lower_intrin(
- [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.te.floordiv(x, y), zero)
+ [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.tir.floordiv(x, y), zero)
)
check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0
else 0)
# const power of two
- res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8,
dtype=dtype)))
+ res = lower_intrin([x, y], tvm.tir.floordiv(x, tvm.tir.const(8,
dtype=dtype)))
check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda
a, b: a // b)
# floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
- tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
+ tvm.tir.floordiv(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
)
check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda
a, b: (a + 4) // b)
@@ -115,27 +130,27 @@ def test_lower_floordiv():
def test_lower_floormod():
data = get_ref_data()
for dtype in ["int32", "int64", "int16"]:
- x = te.var("x", dtype=dtype)
- y = te.var("y", dtype=dtype)
+ x = tvm.tir.Var("x", dtype)
+ y = tvm.tir.Var("y", dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
- res = lower_intrin([x, y], tvm.te.floormod(x, y))
+ res = lower_intrin([x, y], tvm.tir.floormod(x, y))
check_value(res, [x, y], data, lambda a, b: a % b)
# rhs >= 0
- res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x,
y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floormod(x,
y), zero))
check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0)
# lhs >= 0
res = lower_intrin(
- [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.te.floormod(x, y), zero)
+ [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.tir.floormod(x, y), zero)
)
check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0
else 0)
# const power of two
- res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8,
dtype=dtype)))
+ res = lower_intrin([x, y], tvm.tir.floormod(x, tvm.tir.const(8,
dtype=dtype)))
check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda
a, b: a % b)
# floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
- tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
+ tvm.tir.floormod(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
)
check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda
a, b: (a + 4) % b)
@@ -149,26 +164,26 @@ def test_lower_floordiv_overflow_checks():
"""
# Check 3: (b-1) - a_min must not overflow (numerator and C++ int64).
# x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 -
(-2^63) > LLONG_MAX.
- x = te.var("x", dtype="int64")
- res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64")))
+ x = tvm.tir.Var("x", "int64")
+ res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int64")))
data_check3 = [(-(2**63),), (0,), (100,)]
check_value(res, [x], data_check3, lambda a: a // 3)
# Check 4: c_value * b_value must not overflow dtype.
# x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923;
10923*3 > 32767.
- x = te.var("x", dtype="int16")
- res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16")))
+ x = tvm.tir.Var("x", "int16")
+ res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int16")))
data_check4 = [(-32768,), (0,), (100,)]
check_value(res, [x], data_check4, lambda a: a // 3)
# Check 5: a_max + b*c must not overflow (offset numerator).
# tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4;
a_max + 12 > 32767.
# In practice this path may not be triggered. This test still validates
correct lowering.
- x = te.var("x", dtype="int16")
+ x = tvm.tir.Var("x", "int16")
clamped = tvm.tir.min(
tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758,
"int16")
)
- res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3,
"int16")))
+ res = lower_intrin([x], tvm.tir.floordiv(clamped, tvm.tir.const(3,
"int16")))
data_check5 = [(-10,), (0,), (32758,), (32757,)]
check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) //
3)
diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
index 553c745770..8bee82f02c 100644
--- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
+++ b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
@@ -16,7 +16,6 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
def test_prim_func_pass():
@@ -30,8 +29,8 @@ def test_prim_func_pass():
def transform_function(self, func, mod, ctx):
return self.new_func
- x = te.var("x")
- y = te.var("y")
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
@@ -51,7 +50,7 @@ def test_cow_pass():
return f
pidentity = tvm.tir.transform.Apply(fapply)
- x = te.var("x")
+ x = tvm.tir.Var("x", "int32")
func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x)).with_attr("target_bits",
32)
func_hash = func.__hash__()
mod = tvm.IRModule({"main": func})
diff --git a/tests/python/tir-transform/test_tir_transform_remove_no_op.py
b/tests/python/tir-transform/test_tir_transform_remove_no_op.py
index 5d712a21f7..91902ea400 100644
--- a/tests/python/tir-transform/test_tir_transform_remove_no_op.py
+++ b/tests/python/tir-transform/test_tir_transform_remove_no_op.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
from tvm.script import tir as T
import tvm.testing
@@ -27,11 +26,11 @@ def nop():
def test_remove_no_op():
- i = te.var("i")
- j = te.var("j")
- k = te.var("k")
- m = te.var("m")
- n = te.var("n")
+ i = tvm.tir.Var("i", "int32")
+ j = tvm.tir.Var("j", "int32")
+ k = tvm.tir.Var("k", "int32")
+ m = tvm.tir.Var("m", "int32")
+ n = tvm.tir.Var("n", "int32")
dtype = "int64"
Ab = tvm.tir.decl_buffer((n,), dtype)
stmt = tvm.tir.For(