This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0828bc8ad [REFACTOR][TEST] Migrate tir-transform tests from TE to TVMScript (#18805)
c0828bc8ad is described below

commit c0828bc8ad0afe4ff637e59114700efcf283ca14
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 21 18:23:48 2026 -0500

    [REFACTOR][TEST] Migrate tir-transform tests from TE to TVMScript (#18805)
    
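    This PR migrates te.var/te.compute usage to direct tvm.tir.Var
    and PrimFunc construction in the arith, ir, tir-base, tir-analysis,
    and tir-transform test files. It also replaces te-namespace helpers
    (e.g. te.size_var, te.floordiv, te.floormod, te.min, te.max,
    te.min_value) with their tvm.tir equivalents.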
---
 .../python/arith/test_arith_canonical_simplify.py  | 106 ++---
 tests/python/arith/test_arith_const_int_bound.py   |  65 ++-
 tests/python/arith/test_arith_deduce_bound.py      |  38 +-
 tests/python/arith/test_arith_detect_clip_bound.py |  11 +-
 .../arith/test_arith_detect_linear_equation.py     |  11 +-
 tests/python/arith/test_arith_intset.py            |  32 +-
 tests/python/arith/test_arith_modular_set.py       |  41 +-
 tests/python/arith/test_arith_rewrite_simplify.py  | 462 +++++++++++----------
 .../arith/test_arith_solve_linear_equations.py     |  14 +-
 .../arith/test_arith_solve_linear_inequality.py    |  20 +-
 tests/python/ir/test_ir_container.py               |  18 +-
 .../test_tir_analysis_expr_deep_equal.py           |   7 +-
 .../tir-analysis/test_tir_analysis_verify_ssa.py   |   7 +-
 tests/python/tir-base/test_tir_buffer.py           |  45 +-
 tests/python/tir-base/test_tir_constructor.py      |  13 +-
 tests/python/tir-base/test_tir_intrin.py           |   2 +-
 tests/python/tir-base/test_tir_nodes.py            |  74 ++--
 tests/python/tir-base/test_tir_ops.py              |  31 +-
 .../tir-base/test_tir_structural_equal_hash.py     |  50 +--
 .../test_tir_transform_common_subexpr_elim.py      |  67 ++-
 .../test_tir_transform_lower_intrin.py             |  75 ++--
 .../test_tir_transform_prim_func_pass.py           |   7 +-
 .../test_tir_transform_remove_no_op.py             |  11 +-
 23 files changed, 616 insertions(+), 591 deletions(-)

diff --git a/tests/python/arith/test_arith_canonical_simplify.py 
b/tests/python/arith/test_arith_canonical_simplify.py
index 733d1d13b3..ca40bd2f25 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -44,7 +44,7 @@ class CanonicalChecker:
 
 def test_mul_sum_simplify():
     ck = CanonicalChecker()
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     ck.verify(2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6)
     ck.verify(x * 3 - 4 * x + 1, 1 - x)
@@ -56,8 +56,8 @@ def test_mul_sum_simplify():
     ck.verify(tmod(x + y + x + y * 3, 2), 0)
 
     # floordiv
-    fld = tvm.te.floordiv
-    flm = tvm.te.floormod
+    fld = tvm.tir.floordiv
+    flm = tvm.tir.floormod
     ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))
     ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)
     ck.verify(flm(x + y + x + y * 3, 2), 0)
@@ -66,7 +66,7 @@ def test_mul_sum_simplify():
 
 def test_split_index_simplify():
     ck = CanonicalChecker()
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     # trucdiv
     tdiv = tvm.tir.truncdiv
@@ -96,8 +96,8 @@ def test_split_index_simplify():
     ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
 
     # floordiv
-    fld = tvm.te.floordiv
-    flm = tvm.te.floormod
+    fld = tvm.tir.floordiv
+    flm = tvm.tir.floormod
     ck.verify(fld(x * 5, 2), fld(x * 5, 2))
     ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
     ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
@@ -114,7 +114,7 @@ def test_split_index_simplify():
 
 def test_div_simplify():
     ck = CanonicalChecker()
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     tdiv = tvm.tir.truncdiv
 
     # truc div
@@ -129,7 +129,7 @@ def test_div_simplify():
     ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
 
     # floordiv
-    fld = tvm.te.floordiv
+    fld = tvm.tir.floordiv
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
     ck.verify(fld(16 + 48 * x, 16), x * 3 + 1)
     ck.verify(fld(17 + 48 * x, 16), x * 3 + 1)
@@ -154,8 +154,8 @@ def test_fp16_const_fold():
 
 def test_floormod_simplify():
     ck = CanonicalChecker()
-    flm = tvm.te.floormod
-    x, y = te.var("x"), te.var("y")
+    flm = tvm.tir.floormod
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + 
y + 12, 16))
     ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4)
 
@@ -164,26 +164,26 @@ def test_floormod_simplify():
 
 def test_canonical_mixed():
     ck = CanonicalChecker()
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     z = tvm.tir.const(3, "int32")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
     ck.verify(tdiv(x, (z * z)) - tdiv(x, (z * z)), 0)
     ck.verify(tdiv(x, (z + z)) - tdiv(x, (z + z)), 0)
     ck.verify(x - 2 < 3, x < 5)
-    ck.verify(tvm.te.max(x, 1) - tvm.te.max(x, 1), 0)
-    ck.verify(tvm.te.min(x, 1) - tvm.te.min(x, 1), 0)
+    ck.verify(tvm.tir.max(x, 1) - tvm.tir.max(x, 1), 0)
+    ck.verify(tvm.tir.min(x, 1) - tvm.tir.min(x, 1), 0)
     ck.verify(x * x - x * x, 0)
     ck.verify(tmod(tdiv(tmod(x, 20), 2) * 2, 4), tdiv(tmod(x, 4), 2) * 2)
 
-    fld = tvm.te.floordiv
+    fld = tvm.tir.floordiv
     ck.verify(fld(x, (z * z)) - fld(x, (z * z)), 0)
     ck.verify(fld(x, (z + z)) - fld(x, (z + z)), 0)
 
 
 def test_reduce_combiner_simplify():
     ck = CanonicalChecker()
-    dummy = te.var("dummy")
+    dummy = tvm.tir.Var("dummy", "int32")
     comm_reducer = te.comm_reducer
     prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))
 
@@ -262,13 +262,13 @@ def test_reduce_simplify():
     ck.verify(te.sum(A[3], []), A[3])
     ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, 
dtype="float32"))
     # The rule below is not typical, removed for now
-    ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
+    ck.verify(te.sum(tvm.tir.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), 
k))
 
 
 def test_simplify_if_then_else():
     ck = CanonicalChecker()
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
     # simplification that takes condition into account.
@@ -315,8 +315,8 @@ def test_simplify_if_then_else():
 
 def test_complex_cases():
     ck = CanonicalChecker()
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
     res2 = (
@@ -346,30 +346,30 @@ def test_complex_cases():
 def test_simplify_cast():
     ck = CanonicalChecker()
     tcast = tvm.tir.Cast
-    fld = tvm.te.floordiv
-    flm = tvm.te.floormod
+    fld = tvm.tir.floordiv
+    flm = tvm.tir.floormod
     # cast(i64, i + j + 1) - cast(i64, i)
-    i = te.var("i", dtype="int32")
-    j = te.var("j", dtype="int32")
+    i = tvm.tir.Var("i", "int32")
+    j = tvm.tir.Var("j", "int32")
     res = tcast("int64", i + j + 1) - tcast("int64", i)
     ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
     # cast(i32, i + j + 1) - cast(i32, i)
-    i = te.var("i", dtype="int64")
-    j = te.var("j", dtype="int64")
+    i = tvm.tir.Var("i", "int64")
+    j = tvm.tir.Var("j", "int64")
     ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
     ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
     res = tcast("int32", i + j + 1) - tcast("int32", i)
     ck.verify(res, tcast("int32", j) + 1)
     # cast(i32, i + j - 100)
-    i = te.var("i", dtype="int64")
-    j = te.var("j", dtype="int64")
+    i = tvm.tir.Var("i", "int64")
+    j = tvm.tir.Var("j", "int64")
     ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1))
     ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
     res = tcast("int32", i + j - 100)
     ck.verify(res, res)
     # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
     # - cast(i32, flm(axis, 7i64) * 2i64)
-    axis = te.var("axis", dtype="int64")
+    axis = tvm.tir.Var("axis", "int64")
     ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
     res = (
         tcast(
@@ -385,26 +385,26 @@ def test_simplify_cast():
 
 def test_simplify_normalize_min_value_expr():
     ck = CanonicalChecker()
-    x = te.var("x", "int32")
+    x = tvm.tir.Var("x", "int32")
 
-    ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32"))
-    ck.verify(te.min_value("int32") + x == 0, tir.const(False))
-    ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32"))
-    ck.verify(0 == te.min_value("int32") + x, tir.const(False))
-    ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32"))
-    ck.verify(x + te.min_value("int32") == 0, tir.const(False))
-    ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32"))
-    ck.verify(0 == x + te.min_value("int32"), tir.const(False))
+    ck.verify(tvm.tir.min_value("int32") - x == 0, x == 
tvm.tir.min_value("int32"))
+    ck.verify(tvm.tir.min_value("int32") + x == 0, tir.const(False))
+    ck.verify(0 == tvm.tir.min_value("int32") - x, x == 
tvm.tir.min_value("int32"))
+    ck.verify(0 == tvm.tir.min_value("int32") + x, tir.const(False))
+    ck.verify(-x + tvm.tir.min_value("int32") == 0, x == 
tvm.tir.min_value("int32"))
+    ck.verify(x + tvm.tir.min_value("int32") == 0, tir.const(False))
+    ck.verify(0 == -x + tvm.tir.min_value("int32"), x == 
tvm.tir.min_value("int32"))
+    ck.verify(0 == x + tvm.tir.min_value("int32"), tir.const(False))
 
 
 def test_proddiv_simplify():
     ck = CanonicalChecker()
-    flm = tvm.te.floormod
-    fld = tvm.te.floordiv
-    tdiv = tvm.te.truncdiv
-    tmod = tvm.te.truncmod
+    flm = tvm.tir.floormod
+    fld = tvm.tir.floordiv
+    tdiv = tvm.tir.truncdiv
+    tmod = tvm.tir.truncmod
 
-    x, y, z = te.var("x"), te.var("y"), te.var("y")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("y", "int32")
 
     ck.verify(flm(x * 32 * x, x), 0)
     ck.verify(flm(z * x * 32 * x * y, x * z), 0)
@@ -428,15 +428,15 @@ def test_proddiv_simplify():
 
 def test_floormod_two():
     ck = CanonicalChecker()
-    flm = tvm.te.floormod
-    x, y = te.var("x"), te.var("y")
+    flm = tvm.tir.floormod
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
 
 
 def test_simplify_le():
     ck = CanonicalChecker()
     # Case 1. Ignore the extra expr if it's small than the division number
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     ck.analyzer.bind(y, tvm.ir.Range(0, 8))
     ck.analyzer.bind(z, tvm.ir.Range(0, 2))
     ck.verify(x * 8 + y < 16, x < 2)
@@ -449,16 +449,16 @@ def test_simplify_le():
 
     ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
 
-    n = te.size_var("n")
+    n = tvm.tir.SizeVar("n", "int32")
     ck.verify(x * 8 + y < n, x * 8 + y < n)
 
     # Case 2. Simplify the extra expr
     x1, x2, ty, tx, vec = (
-        tvm.te.var("x1"),
-        tvm.te.var("x2"),
-        tvm.te.var("ty"),
-        tvm.te.var("tx"),
-        tvm.te.var("vec"),
+        tvm.tir.Var("x1", "int32"),
+        tvm.tir.Var("x2", "int32"),
+        tvm.tir.Var("ty", "int32"),
+        tvm.tir.Var("tx", "int32"),
+        tvm.tir.Var("vec", "int32"),
     )
     ck.analyzer.bind(x1, tvm.ir.Range(0, 2))
     ck.analyzer.bind(x2, tvm.ir.Range(0, 3))
@@ -472,7 +472,7 @@ def test_simplify_le():
     ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)
 
     # Case 3. No failure
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     ck.analyzer.bind(y, tvm.ir.Range(0, 1024))
     ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
 
diff --git a/tests/python/arith/test_arith_const_int_bound.py 
b/tests/python/arith/test_arith_const_int_bound.py
index 8728df7e3f..bd2a39b4f4 100644
--- a/tests/python/arith/test_arith_const_int_bound.py
+++ b/tests/python/arith/test_arith_const_int_bound.py
@@ -22,7 +22,6 @@ import pytest
 import tvm
 import tvm.testing
 
-from tvm import te
 from tvm.arith import ConstIntBound
 
 NEG_INF = ConstIntBound.NEG_INF
@@ -69,15 +68,15 @@ class BaseCompare:
 
 class TestDataType(BaseCompare):
     test_case = tvm.testing.parameter(
-        TestCase(te.var("x", dtype="int64"), (NEG_INF, POS_INF)),
-        TestCase(te.var("x", dtype="int8"), (-128, 127)),
-        TestCase(te.var("x", dtype="uint8"), (0, 255)),
-        TestCase(te.size_var("x", dtype="int32"), (0, POS_INF)),
+        TestCase(tvm.tir.Var("x", "int64"), (NEG_INF, POS_INF)),
+        TestCase(tvm.tir.Var("x", "int8"), (-128, 127)),
+        TestCase(tvm.tir.Var("x", "uint8"), (0, 255)),
+        TestCase(tvm.tir.SizeVar("x", "int32"), (0, POS_INF)),
     )
 
 
 class TestCastBound(BaseCompare):
-    x = te.var("x", dtype="int8")
+    x = tvm.tir.Var("x", "int8")
     tmod = tvm.tir.truncmod
 
     test_case = tvm.testing.parameter(
@@ -87,8 +86,8 @@ class TestCastBound(BaseCompare):
 
 
 class TestAddSubBound(BaseCompare):
-    x = te.var("x", "int64")
-    y = te.var("y", "int64")
+    x = tvm.tir.Var("x", "int64")
+    y = tvm.tir.Var("y", "int64")
 
     test_case = tvm.testing.parameter(
         TestCase(x + y, (NEG_INF, POS_INF)),
@@ -119,7 +118,7 @@ class TestBoundsUsingReciprocals(BaseCompare):
     achieve its minimum while `A*B` simultaneously achieves its maximum.
     """
 
-    A, B, C = [te.var(letter, "int64") for letter in "ABC"]
+    A, B, C = [tvm.tir.Var(letter, "int64") for letter in "ABC"]
 
     symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)}
     asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)}
@@ -137,7 +136,7 @@ class TestBoundsUsingReciprocals(BaseCompare):
 
 
 class TestMulBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(x * y + 20, (0, 60), {x: (-2, 4), y: (4, 10)}),
@@ -147,7 +146,7 @@ class TestMulBound(BaseCompare):
 
 
 class TestTruncDivBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     expr = tvm.tir.truncdiv(x, y)
 
@@ -160,7 +159,7 @@ class TestTruncDivBound(BaseCompare):
 
 
 class TestTruncModBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     expr = tvm.tir.truncmod(x, y)
 
@@ -172,9 +171,9 @@ class TestTruncModBound(BaseCompare):
 
 
 class TestFloorDivBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
-    ux = te.var("x", dtype="uint32")
-    uy = te.var("y", dtype="uint32")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+    ux = tvm.tir.Var("x", "uint32")
+    uy = tvm.tir.Var("y", "uint32")
 
     test_case = tvm.testing.parameter(
         TestCase(x // y, (-9 // 4, None), {x: (-9, 4), y: (4, 10)}),
@@ -186,7 +185,7 @@ class TestFloorDivBound(BaseCompare):
 
 
 class TestFloorModBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(x % y, (0, 9), {x: (-9, 4), y: (4, 10)}),
@@ -196,18 +195,18 @@ class TestFloorModBound(BaseCompare):
 
 
 class TestMinMaxBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     test_case = tvm.testing.parameter(
-        TestCase(tvm.te.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}),
-        TestCase(tvm.te.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: 
(4, 10)}),
-        TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: 
(4, 10)}),
-        TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 
10)}),
+        TestCase(tvm.tir.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}),
+        TestCase(tvm.tir.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: 
(4, 10)}),
+        TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: 
(4, 10)}),
+        TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 
10)}),
     )
 
 
 class TestSelectBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(
@@ -219,7 +218,7 @@ class TestSelectBound(BaseCompare):
 
 
 class TestShiftAndBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(x >> y, (-3, 2), {x: (-9, 11), y: (2, 10)}),
@@ -229,7 +228,7 @@ class TestShiftAndBound(BaseCompare):
 
 
 class TestMixIndexBound(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
 
@@ -243,15 +242,15 @@ class TestMixIndexBound(BaseCompare):
 
 
 class TestLetBound(BaseCompare):
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     test_case = tvm.testing.parameter(
         TestCase(tvm.tir.Let(x, 1, x + 1), (2, 2)),
     )
 
 
 class TestFloorModNegativeDivisor(BaseCompare):
-    flm, fld = tvm.te.floormod, tvm.te.floordiv
-    a, b = te.var("a"), te.var("b")
+    flm, fld = tvm.tir.floormod, tvm.tir.floordiv
+    a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(a % b, (-4, 6), {a: (0, 6), b: (-5, 7)}),
@@ -264,7 +263,7 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare):
     from symbolic shape programs
     """
 
-    a, b = te.var("a"), te.var("b")
+    a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(a // b, (0, 6), {a: (0, 6), b: (0, POS_INF)}),
@@ -273,7 +272,7 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare):
 
 
 class TestMultipleCondition(BaseCompare):
-    a = te.var("a")
+    a = tvm.tir.Var("a", "int32")
     test_case = tvm.testing.parameter(
         TestCase(
             a % 58 - 1,
@@ -285,14 +284,14 @@ class TestMultipleCondition(BaseCompare):
 
 
 class TestBroadcastBound(BaseCompare):
-    a = te.var("a")
+    a = tvm.tir.Var("a", "int32")
     test_case = tvm.testing.parameter(
         TestCase(tvm.tir.Broadcast(a, 4), (0, 128), {a: (0, 128)}),
     )
 
 
 class TestRampBound(BaseCompare):
-    a = te.var("a")
+    a = tvm.tir.Var("a", "int32")
     test_case = tvm.testing.parameter(
         TestCase(tvm.tir.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0, 
128)}),
     )
@@ -300,8 +299,8 @@ class TestRampBound(BaseCompare):
 
 class TestModularSetBound(BaseCompare):
     analyzer = tvm.arith.Analyzer()
-    tx = tvm.te.var("tx", dtype="int32")
-    bx = tvm.te.var("bx", dtype="int32")
+    tx = tvm.tir.Var("tx", "int32")
+    bx = tvm.tir.Var("bx", "int32")
 
     expr = (bx * 2048 + tx * 16) % 7168
 
diff --git a/tests/python/arith/test_arith_deduce_bound.py 
b/tests/python/arith/test_arith_deduce_bound.py
index a36fd21479..1a6bfb4925 100644
--- a/tests/python/arith/test_arith_deduce_bound.py
+++ b/tests/python/arith/test_arith_deduce_bound.py
@@ -17,22 +17,22 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import te
+
 from tvm.tir.buffer import decl_buffer
 
 
 def test_deduce():
-    a = te.var("a")
-    b = te.var("b")
-    c = te.var("c")
-    d = te.var("d")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
+    c = tvm.tir.Var("c", "int32")
+    d = tvm.tir.Var("d", "int32")
 
     b_s = tvm.arith.IntervalSet(2, 3)
     c_s = tvm.arith.IntervalSet(10, 15)
     d_s = tvm.arith.IntervalSet(-3, -1)
     zero = tvm.tir.const(0, "int32")
 
-    fdiv = tvm.te.floordiv
+    fdiv = tvm.tir.floordiv
 
     e0 = (-b) * a + c - d
     res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
@@ -62,13 +62,13 @@ def test_deduce():
     res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
     tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
 
-    e2 = tvm.te.max(5, a * 4) < 0
+    e2 = tvm.tir.max(5, a * 4) < 0
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf"
     assert str(res2.min_value) == "pos_inf"
 
     # expression containing variable a is on rhs
-    e2 = zero < tvm.te.max(5, a * 4)
+    e2 = zero < tvm.tir.max(5, a * 4)
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf"
     assert str(res2.min_value) == "pos_inf"
@@ -121,10 +121,10 @@ def test_deduce():
 
 
 def test_check():
-    a = te.var("a")
-    b = te.var("b")
-    c = te.var("c")
-    d = te.var("d")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
+    c = tvm.tir.Var("c", "int32")
+    d = tvm.tir.Var("d", "int32")
 
     b_s = tvm.arith.IntervalSet(2, 3)
     c_s = tvm.arith.IntervalSet(5, 7)
@@ -145,8 +145,8 @@ def test_check():
 
 def test_deduce_basic():
     def test_basic(a1, a2, coff):
-        a = te.var("a")
-        b = te.var("b")
+        a = tvm.tir.Var("a", "int32")
+        b = tvm.tir.Var("b", "int32")
         b_s = tvm.arith.IntervalSet(a1, a2)
         e0 = b + a * coff + 3
 
@@ -179,8 +179,8 @@ def test_deduce_basic():
 
 def test_deduce_complex():
     def test_complex(a1, a2, coff):
-        a = te.var("a")
-        b = te.var("b")
+        a = tvm.tir.Var("a", "int32")
+        b = tvm.tir.Var("b", "int32")
         b_s = tvm.arith.IntervalSet(a1, a2)
         e0 = (b * 3 + a * coff) * 4
 
@@ -211,7 +211,7 @@ def test_deduce_complex():
 
 
 def test_deduce_non_support():
-    a = te.var("a")
+    a = tvm.tir.Var("a", "int32")
 
     def test_non_support(lhs):
         res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
@@ -232,7 +232,7 @@ def test_deduce_non_support():
 
 def test_deduce_floordiv():
     def do_test(gen_expr, dom_map, expect_min, expect_max):
-        a = te.var("a")
+        a = tvm.tir.Var("a", "int32")
         expr = gen_expr(a)
         res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
         if isinstance(expect_min, str):
@@ -259,7 +259,7 @@ def test_deduce_floordiv():
     do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")
 
     # test nested cases
-    b = te.var("b")
+    b = tvm.tir.Var("b", "int32")
     bs = {b: tvm.arith.IntervalSet(2, 6)}
     do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
     do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
diff --git a/tests/python/arith/test_arith_detect_clip_bound.py 
b/tests/python/arith/test_arith_detect_clip_bound.py
index 03fff11f77..830c6d4811 100644
--- a/tests/python/arith/test_arith_detect_clip_bound.py
+++ b/tests/python/arith/test_arith_detect_clip_bound.py
@@ -16,13 +16,12 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 
 
 def test_basic():
-    a = te.var("a")
-    b = te.var("b")
-    c = te.var("c")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
+    c = tvm.tir.Var("c", "int32")
     m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a])
     tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1)
     assert m[0].value == 2
@@ -40,8 +39,8 @@ def test_basic():
 
 
 def test_trivial_eq():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     m = tvm.arith.detect_clip_bound(b == 3, [a, b])
     tvm.testing.assert_prim_expr_equal(m[2], 3)
     tvm.testing.assert_prim_expr_equal(m[3], 3)
diff --git a/tests/python/arith/test_arith_detect_linear_equation.py 
b/tests/python/arith/test_arith_detect_linear_equation.py
index 829b101af3..d7902047ca 100644
--- a/tests/python/arith/test_arith_detect_linear_equation.py
+++ b/tests/python/arith/test_arith_detect_linear_equation.py
@@ -16,12 +16,11 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 
 
 def test_basic():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a])
     assert m[0].value == 4
     tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7)
@@ -43,14 +42,14 @@ def test_basic():
     assert len(m) == 1
     tvm.testing.assert_prim_expr_equal(m[0], b * 7)
 
-    c = te.var("c", "uint32")
+    c = tvm.tir.Var("c", "uint32")
     m = tvm.arith.detect_linear_equation(128 - c, [c])
     assert m[0].value == -1
 
 
 def test_multivariate():
-    v = [te.var("v%d" % i) for i in range(4)]
-    b = te.var("b")
+    v = [tvm.tir.Var("v%d" % i, "int32") for i in range(4)]
+    b = tvm.tir.Var("b", "int32")
     m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
 
     tvm.testing.assert_prim_expr_equal(m[0], b + 5)
diff --git a/tests/python/arith/test_arith_intset.py 
b/tests/python/arith/test_arith_intset.py
index 04014ca300..e71b211836 100644
--- a/tests/python/arith/test_arith_intset.py
+++ b/tests/python/arith/test_arith_intset.py
@@ -16,7 +16,7 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
+
 from tvm import tir
 from tvm.arith.analyzer import Analyzer
 
@@ -64,7 +64,7 @@ def test_scalable_vector():
 
 def test_add_sub():
     ck = IntSetChecker()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
     ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: 
tvm.arith.IntervalSet(1, 11)}, (1, 21))
     ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: 
tvm.arith.IntervalSet(1, 11)}, (-11, 9))
@@ -72,7 +72,7 @@ def test_add_sub():
 
 def test_mul_div():
     ck = IntSetChecker()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     tdiv = tvm.tir.truncdiv
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
@@ -83,20 +83,20 @@ def test_mul_div():
     ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
     ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))
 
-    fld = tvm.te.floordiv
+    fld = tvm.tir.floordiv
     ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
     ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
 
 
 def test_mod():
     ck = IntSetChecker()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     tmod = tvm.tir.truncmod
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
     ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
     ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))
 
-    flm = tvm.te.floormod
+    flm = tvm.tir.floormod
     ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))
     ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5))
     ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5))
@@ -104,8 +104,8 @@ def test_mod():
     ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9))
     ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9))
 
-    fld = tvm.te.floordiv
-    z = te.var("z")
+    fld = tvm.tir.floordiv
+    z = tvm.tir.Var("z", "int32")
     ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
     ck.verify(
         flm(y, 8),
@@ -124,16 +124,16 @@ def test_mod():
 
 def test_max_min():
     ck = IntSetChecker()
-    x, y = te.var("x"), te.var("y")
-    ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
-    ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, 
(-1, 9))
-    ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
-    ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+    ck.verify(tvm.tir.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 
11))
+    ck.verify(tvm.tir.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, 
(-1, 9))
+    ck.verify(tvm.tir.min(x, y), {}, (tvm.tir.min(x, y), tvm.tir.min(x, y)))
+    ck.verify(tvm.tir.max(x, y), {}, (tvm.tir.max(x, y), tvm.tir.max(x, y)))
 
 
 def test_select():
     ck = IntSetChecker()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: 
tvm.arith.IntervalSet(0, 10)}, (-1, 11))
 
 
@@ -389,8 +389,8 @@ def test_union_lower_bound():
 
 def test_modular_set():
     ck = IntSetChecker()
-    x = tvm.te.var("x", dtype="int32")
-    y = tvm.te.var("y", dtype="int32")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     expr = (x * 2048 + y * 16) % 7168
     ck.verify(
         expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 
3584)}, (0, 7152)
diff --git a/tests/python/arith/test_arith_modular_set.py 
b/tests/python/arith/test_arith_modular_set.py
index 914402fb62..ef7b0593d0 100644
--- a/tests/python/arith/test_arith_modular_set.py
+++ b/tests/python/arith/test_arith_modular_set.py
@@ -16,12 +16,11 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 
 
 def test_cast():
     analyzer = tvm.arith.Analyzer()
-    x = te.var("x", dtype="int8")
+    x = tvm.tir.Var("x", "int8")
     m = analyzer.modular_set((x * 3).astype("uint32"))
     assert m.coeff == 3
     assert m.base == 0
@@ -32,7 +31,7 @@ def test_cast():
 
 def test_add_sub():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x", "int64"), te.var("y", "int64")
+    x, y = tvm.tir.Var("x", "int64"), tvm.tir.Var("y", "int64")
     m = analyzer.modular_set(x * 6 + y * 4)
     assert m.coeff == 2
     assert m.base == 0
@@ -45,7 +44,7 @@ def test_add_sub():
 
 def test_mul():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     m = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1))
     assert m.coeff == 4
     assert m.base == 2
@@ -53,7 +52,7 @@ def test_mul():
 
 def test_floormod():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256))
     assert m.coeff == 4
     assert m.base == 0
@@ -61,7 +60,7 @@ def test_floormod():
 
 def test_div_shift():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     # not sure if x is non-negative
     tdiv = tvm.tir.truncdiv
     m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
@@ -71,7 +70,7 @@ def test_div_shift():
     m = analyzer.modular_set((x * 4 + 2) >> 1)
     assert m.coeff == 2
     assert m.base == 1
-    fld = tvm.te.floordiv
+    fld = tvm.tir.floordiv
     m = analyzer.modular_set(fld(x * 4 + 2, 2))
     assert m.coeff == 2
     assert m.base == 1
@@ -84,7 +83,7 @@ def test_div_shift():
 
 def test_mod():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     tmod = tvm.tir.truncmod
     fmod = tvm.tir.floormod
     # not sure if x is non-negative
@@ -111,12 +110,12 @@ def test_mod():
 
 def test_min_max_select():
     analyzer = tvm.arith.Analyzer()
-    x, y = te.var("x"), te.var("y")
-    m = analyzer.modular_set(tvm.te.min(x * 3, y * 9))
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
+    m = analyzer.modular_set(tvm.tir.min(x * 3, y * 9))
     assert m.coeff == 3
     assert m.base == 0
 
-    m = analyzer.modular_set(tvm.te.max(x * 3 + 1, y * 9 + 4))
+    m = analyzer.modular_set(tvm.tir.max(x * 3 + 1, y * 9 + 4))
     assert m.coeff == 3
     assert m.base == 1
 
@@ -126,8 +125,8 @@ def test_min_max_select():
 
 
 def test_mix_index():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     analyzer = tvm.arith.Analyzer()
     tdiv = tvm.tir.truncdiv
     m = analyzer.modular_set(a * 4 + b * 6 + 7)
@@ -150,14 +149,14 @@ def test_mix_index():
     assert m.coeff == 3
     assert m.base == 2
 
-    m = analyzer.modular_set(a * 12 + tvm.te.min(b * 3 * 7, 2))
+    m = analyzer.modular_set(a * 12 + tvm.tir.min(b * 3 * 7, 2))
     assert m.coeff == 1
     assert m.base == 0
 
 
 def test_constraint_scope():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     analyzer = tvm.arith.Analyzer()
     tmod = tvm.tir.truncmod
 
@@ -179,7 +178,7 @@ def test_constraint_scope():
 
 
 def test_intersect():
-    a = te.var("a")
+    a = tvm.tir.Var("a", "int32")
     analyzer = tvm.arith.Analyzer()
     tmod = tvm.tir.truncmod
     with analyzer.constraint_scope(tmod(a, 4) == 1):
@@ -198,8 +197,8 @@ def test_intersect():
 
 def test_let():
     analyzer = tvm.arith.Analyzer()
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1))
     assert m.coeff == 10
     assert m.base == 1
@@ -207,8 +206,8 @@ def test_let():
 
 def test_bitwise_and():
     analyzer = tvm.arith.Analyzer()
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
 
     # RHS of bitwise_and is 2^p - 1
     m = analyzer.modular_set((x * 16 + y * 4) & 31)
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index b85a6b758c..a87d81a7b7 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -21,7 +21,7 @@ import pytest
 
 import tvm
 import tvm.testing
-from tvm import te, tir
+from tvm import tir
 from tvm.tir import floordiv as fld
 from tvm.tir import floormod as flm
 from tvm.tir import truncdiv as tdiv
@@ -90,10 +90,10 @@ class BaseCompare:
 
 
 class TestVector(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
-    x64 = te.var("x", dtype="int64")
-    vx = te.var("vx", dtype="int32x2")
-    vc = te.var("vc", dtype="bool")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
+    x64 = tvm.tir.Var("x", "int64")
+    vx = tvm.tir.Var("vx", "int32x2")
+    vc = tvm.tir.Var("vc", "bool")
     test_case = tvm.testing.parameter(
         # Add rules
         TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x 
+ y, 3, 4)),
@@ -271,18 +271,20 @@ class TestVector(BaseCompare):
         ),  # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 
13, 20]
         # Min/Max rules
         TestCase(
-            tvm.te.min(y.astype("int32x2"), x.astype("int32x2")), 
tvm.te.min(y, x).astype("int32x2")
+            tvm.tir.min(y.astype("int32x2"), x.astype("int32x2")),
+            tvm.tir.min(y, x).astype("int32x2"),
         ),
         TestCase(
-            tvm.te.min(tvm.te.min(vx, y.astype("int32x2")), 
x.astype("int32x2")),
-            tvm.te.min(vx, tvm.te.min(y, x).astype("int32x2")),
+            tvm.tir.min(tvm.tir.min(vx, y.astype("int32x2")), 
x.astype("int32x2")),
+            tvm.tir.min(vx, tvm.tir.min(y, x).astype("int32x2")),
         ),
         TestCase(
-            tvm.te.max(y.astype("int32x2"), x.astype("int32x2")), 
tvm.te.max(y, x).astype("int32x2")
+            tvm.tir.max(y.astype("int32x2"), x.astype("int32x2")),
+            tvm.tir.max(y, x).astype("int32x2"),
         ),
         TestCase(
-            tvm.te.max(tvm.te.max(vx, y.astype("int32x2")), 
x.astype("int32x2")),
-            tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
+            tvm.tir.max(tvm.tir.max(vx, y.astype("int32x2")), 
x.astype("int32x2")),
+            tvm.tir.max(vx, tvm.tir.max(y, x).astype("int32x2")),
         ),
         ## Logical rules
         TestCase(y.astype("int32x2").equal(x.astype("int32x2")), 
(y.equal(x)).astype("boolx2")),
@@ -306,7 +308,7 @@ class TestVector(BaseCompare):
 
 
 class TestSelect(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         # Add rules
@@ -321,12 +323,12 @@ class TestSelect(BaseCompare):
         TestCase(tvm.tir.Select(x < 0, y, z) - y, tvm.tir.Select(x < 0, 0, z - 
y)),
         TestCase(tvm.tir.Select(x < 0, y, z) - z, tvm.tir.Select(x < 0, y - z, 
0)),
         TestCase(
-            tvm.te.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, 
z)),
-            tvm.tir.Select(x < 0, tvm.te.min(y, 1), tvm.te.min(0, z)),
+            tvm.tir.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, 
z)),
+            tvm.tir.Select(x < 0, tvm.tir.min(y, 1), tvm.tir.min(0, z)),
         ),
         TestCase(
-            tvm.te.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, 
z)),
-            tvm.tir.Select(x < 0, tvm.te.max(y, 1), tvm.te.max(0, z)),
+            tvm.tir.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, 
z)),
+            tvm.tir.Select(x < 0, tvm.tir.max(y, 1), tvm.tir.max(0, z)),
         ),
         TestCase(tvm.tir.Select(x * 3 + 1 != 0, y, z), y),
         TestCase(tvm.tir.Select(x * 3 + 1 == 0, y, z), z),
@@ -371,31 +373,31 @@ class TestCancellation(BaseCompare):
 
 
 class TestAddIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(x + (y - x), y),
         TestCase(x - (y + 1) + (y + 1), x),
         TestCase((x - 10) + (10 - z), x - z),
         TestCase((x - y) + (z - x), z - y),
-        TestCase(tvm.te.min(x, y - z) + z, tvm.te.min(x + z, y)),
-        TestCase(tvm.te.min(x - z, y) + z, tvm.te.min(x, y + z)),
-        TestCase(tvm.te.max(x, y - 10) + 10, tvm.te.max(x + 10, y)),
-        TestCase(tvm.te.max(x - 11, y) + 11, tvm.te.max(x, y + 11)),
-        TestCase(tvm.te.max(x, y * 2) + tvm.te.min(x, y * 2), x + y * 2),
-        TestCase(tvm.te.min(x, y * 2) + tvm.te.max(x, y * 2), x + y * 2),
-        TestCase(tvm.te.max(x, y + 2) + (-2), tvm.te.max(x + (-2), y)),
-        TestCase(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y)),
-        TestCase(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1)),
-        TestCase(tvm.te.max(0, 1 - x * 4) + x * 4, tvm.te.max(x * 4, 1)),
-        TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)),
-        TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)),
-        TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)),
+        TestCase(tvm.tir.min(x, y - z) + z, tvm.tir.min(x + z, y)),
+        TestCase(tvm.tir.min(x - z, y) + z, tvm.tir.min(x, y + z)),
+        TestCase(tvm.tir.max(x, y - 10) + 10, tvm.tir.max(x + 10, y)),
+        TestCase(tvm.tir.max(x - 11, y) + 11, tvm.tir.max(x, y + 11)),
+        TestCase(tvm.tir.max(x, y * 2) + tvm.tir.min(x, y * 2), x + y * 2),
+        TestCase(tvm.tir.min(x, y * 2) + tvm.tir.max(x, y * 2), x + y * 2),
+        TestCase(tvm.tir.max(x, y + 2) + (-2), tvm.tir.max(x + (-2), y)),
+        TestCase(tvm.tir.min(x, y + 2) + (-2), tvm.tir.min(x + (-2), y)),
+        TestCase(tvm.tir.min(x + 2, y + 3) + (-2), tvm.tir.min(x, y + 1)),
+        TestCase(tvm.tir.max(0, 1 - x * 4) + x * 4, tvm.tir.max(x * 4, 1)),
+        TestCase(tvm.tir.max(2 - x * 4, 0) + x * 4, tvm.tir.max(x * 4, 2)),
+        TestCase(tvm.tir.min(0, 1 - x * 4) + x * 4, tvm.tir.min(x * 4, 1)),
+        TestCase(tvm.tir.min(2 - x * 4, 0) + x * 4, tvm.tir.min(x * 4, 2)),
         TestCase(x * y + x * 10, (y + 10) * x),
         TestCase(y * x + x * 10, (y + 10) * x),
         TestCase(y * x + 10 * x, (y + 10) * x),
         TestCase(x * y + 10 * x, (y + 10) * x),
-        TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), 
y)),
+        TestCase((2 * z) + tvm.tir.min(x, y - (2 * z)), tvm.tir.min(x + (z * 
2), y)),
         TestCase(y * x + x, (y + 1) * x),
         TestCase(x * y + x, (y + 1) * x),
         TestCase((x + 10) + 13, x + 23),
@@ -419,21 +421,21 @@ class TestAddIndex(BaseCompare):
 
 
 class TestSubIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(x + y - y, x),
         TestCase(x + y - x, y),
         TestCase(x - (y + x), 0 - y),
         TestCase(x - (x + y), 0 - y),
-        TestCase(tvm.te.min(x, y) - x, tvm.te.min(0, y - x)),
-        TestCase(tvm.te.min(x, y) - y, tvm.te.min(x - y, 0)),
-        TestCase(tvm.te.max(x, y) - x, tvm.te.max(0, y - x)),
-        TestCase(tvm.te.max(x, y) - y, tvm.te.max(x - y, 0)),
-        TestCase(x - tvm.te.min(x, y), tvm.te.max(0, x - y)),
-        TestCase(y - tvm.te.min(x, y), tvm.te.max(y - x, 0)),
-        TestCase(x - tvm.te.max(x, y), tvm.te.min(0, x - y)),
-        TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)),
+        TestCase(tvm.tir.min(x, y) - x, tvm.tir.min(0, y - x)),
+        TestCase(tvm.tir.min(x, y) - y, tvm.tir.min(x - y, 0)),
+        TestCase(tvm.tir.max(x, y) - x, tvm.tir.max(0, y - x)),
+        TestCase(tvm.tir.max(x, y) - y, tvm.tir.max(x - y, 0)),
+        TestCase(x - tvm.tir.min(x, y), tvm.tir.max(0, x - y)),
+        TestCase(y - tvm.tir.min(x, y), tvm.tir.max(y - x, 0)),
+        TestCase(x - tvm.tir.max(x, y), tvm.tir.min(0, x - y)),
+        TestCase(y - tvm.tir.max(x, y), tvm.tir.min(y - x, 0)),
         # mul co-efficient foldng
         TestCase(x - x, 0),
         TestCase(x * y - x, (y + (-1)) * x),
@@ -446,26 +448,26 @@ class TestSubIndex(BaseCompare):
         TestCase((y + x) - (x + z), y - z),
         TestCase((x + y) - (z + x), y - z),
         TestCase((y + x) - (z + x), y - z),
-        TestCase(tvm.te.min(x + y, z) - x, tvm.te.min(y, z - x)),
-        TestCase(tvm.te.min(y + x, z) - x, tvm.te.min(y, z - x)),
-        TestCase(tvm.te.min(z, x + y) - x, tvm.te.min(z - x, y)),
-        TestCase(tvm.te.min(z, y + x) - x, tvm.te.min(z - x, y)),
-        TestCase(tvm.te.max(x + y, z) - x, tvm.te.max(y, z - x)),
-        TestCase(tvm.te.max(y + x, z) - x, tvm.te.max(y, z - x)),
-        TestCase(tvm.te.max(z, x + y) - x, tvm.te.max(z - x, y)),
-        TestCase(tvm.te.max(z, y + x) - x, tvm.te.max(z - x, y)),
-        TestCase(x - tvm.te.min(x + y, z), tvm.te.max(0 - y, x - z)),
-        TestCase(x - tvm.te.min(y + x, z), tvm.te.max(0 - y, x - z)),
-        TestCase(x - tvm.te.min(z, x + y), tvm.te.max(x - z, 0 - y)),
-        TestCase(x - tvm.te.min(z, y + x), tvm.te.max(x - z, 0 - y)),
-        TestCase(tvm.te.min(x, y) - tvm.te.min(y, x), 0),
-        TestCase(tvm.te.max(x, y) - tvm.te.max(y, x), 0),
-        TestCase(tvm.te.min(x, y) - tvm.te.min(x + 10, y + 10), -10),
-        TestCase(tvm.te.min(x + 10, y + 1) - tvm.te.min(x, y - 9), 10),
-        TestCase(x - tvm.te.max(x + y, 0), tvm.te.min(0 - y, x)),
-        TestCase(x - tvm.te.max(0, x + y), tvm.te.min(x, 0 - y)),
-        TestCase(x - tvm.te.min(x + y, 0), tvm.te.max(0 - y, x)),
-        TestCase(x - tvm.te.min(0, x + y), tvm.te.max(x, 0 - y)),
+        TestCase(tvm.tir.min(x + y, z) - x, tvm.tir.min(y, z - x)),
+        TestCase(tvm.tir.min(y + x, z) - x, tvm.tir.min(y, z - x)),
+        TestCase(tvm.tir.min(z, x + y) - x, tvm.tir.min(z - x, y)),
+        TestCase(tvm.tir.min(z, y + x) - x, tvm.tir.min(z - x, y)),
+        TestCase(tvm.tir.max(x + y, z) - x, tvm.tir.max(y, z - x)),
+        TestCase(tvm.tir.max(y + x, z) - x, tvm.tir.max(y, z - x)),
+        TestCase(tvm.tir.max(z, x + y) - x, tvm.tir.max(z - x, y)),
+        TestCase(tvm.tir.max(z, y + x) - x, tvm.tir.max(z - x, y)),
+        TestCase(x - tvm.tir.min(x + y, z), tvm.tir.max(0 - y, x - z)),
+        TestCase(x - tvm.tir.min(y + x, z), tvm.tir.max(0 - y, x - z)),
+        TestCase(x - tvm.tir.min(z, x + y), tvm.tir.max(x - z, 0 - y)),
+        TestCase(x - tvm.tir.min(z, y + x), tvm.tir.max(x - z, 0 - y)),
+        TestCase(tvm.tir.min(x, y) - tvm.tir.min(y, x), 0),
+        TestCase(tvm.tir.max(x, y) - tvm.tir.max(y, x), 0),
+        TestCase(tvm.tir.min(x, y) - tvm.tir.min(x + 10, y + 10), -10),
+        TestCase(tvm.tir.min(x + 10, y + 1) - tvm.tir.min(x, y - 9), 10),
+        TestCase(x - tvm.tir.max(x + y, 0), tvm.tir.min(0 - y, x)),
+        TestCase(x - tvm.tir.max(0, x + y), tvm.tir.min(x, 0 - y)),
+        TestCase(x - tvm.tir.min(x + y, 0), tvm.tir.max(0 - y, x)),
+        TestCase(x - tvm.tir.min(0, x + y), tvm.tir.max(x, 0 - y)),
         # DivMod patterns
         # truc div
         TestCase(x - tdiv(x, 3) * 3, tmod(x, 3)),
@@ -514,18 +516,18 @@ class TestSubIndex(BaseCompare):
 
 
 class TestMulIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     test_case = tvm.testing.parameter(
         TestCase((x + 2) * 3, x * 3 + 6),
         TestCase((x * 2) * 3, x * 6),
-        TestCase(tvm.te.min(x, y) * tvm.te.max(x, y), x * y),
-        TestCase(tvm.te.max(x, y) * tvm.te.min(x, y), x * y),
+        TestCase(tvm.tir.min(x, y) * tvm.tir.max(x, y), x * y),
+        TestCase(tvm.tir.max(x, y) * tvm.tir.min(x, y), x * y),
         TestCase((x - y) * (-2), (y - x) * 2),
     )
 
 
 class TestDivIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     non_negative = [x >= 0, y >= 0, z >= 0]
 
     test_case = tvm.testing.parameter(
@@ -535,11 +537,11 @@ class TestDivIndex(BaseCompare):
         TestCase(tdiv(x * 2, 4), tdiv(x, 2)),
         TestCase(tdiv(x * 4, 2), x * 2),
         TestCase(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2), non_negative),
-        TestCase(tdiv(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, tdiv(y, 2)), 
non_negative),
-        TestCase(tdiv(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, tdiv(y, 2)), 
non_negative),
+        TestCase(tdiv(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, tdiv(y, 
2)), non_negative),
+        TestCase(tdiv(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, tdiv(y, 
2)), non_negative),
         TestCase(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2, non_negative),
-        TestCase(tdiv(tvm.te.min(y, x * 6), 2), tvm.te.min(tdiv(y, 2), x * 3), 
non_negative),
-        TestCase(tdiv(tvm.te.max(y, x * 6), 2), tvm.te.max(tdiv(y, 2), x * 3), 
non_negative),
+        TestCase(tdiv(tvm.tir.min(y, x * 6), 2), tvm.tir.min(tdiv(y, 2), x * 
3), non_negative),
+        TestCase(tdiv(tvm.tir.max(y, x * 6), 2), tvm.tir.max(tdiv(y, 2), x * 
3), non_negative),
         # 3-operands
         TestCase(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2), non_negative),
         TestCase(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1, non_negative),
@@ -562,7 +564,7 @@ class TestDivIndex(BaseCompare):
 
 
 class TestFloordivIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(fld(fld(x, 2), 3), fld(x, 6)),
@@ -581,11 +583,11 @@ class TestFloordivIndex(BaseCompare):
         TestCase(fld(x * 360 + y, 25), x * 14, [x >= 0, x < 2, y >= 0, y < 7]),
         TestCase(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)),
         TestCase(fld(x * 4 + y, 2), x * 2 + fld(y, 2)),
-        TestCase(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))),
-        TestCase(fld(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, fld(y, 2))),
+        TestCase(fld(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, fld(y, 2))),
+        TestCase(fld(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, fld(y, 2))),
         TestCase(fld(y + x * 4, 2), x * 2 + fld(y, 2)),
-        TestCase(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)),
-        TestCase(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)),
+        TestCase(fld(tvm.tir.min(y, x * 6), 2), tvm.tir.min(fld(y, 2), x * 3)),
+        TestCase(fld(tvm.tir.max(y, x * 6), 2), tvm.tir.max(fld(y, 2), x * 3)),
         # 3-operands
         #
         # TODO(Lunderberg): Remove the necessity for the preconditions
@@ -615,7 +617,13 @@ class TestFloordivIndex(BaseCompare):
 
 
 class TestModIndex(BaseCompare):
-    x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), 
te.var("z")
+    x, y, nx, ny, z = (
+        tvm.tir.Var("x", "int32"),
+        tvm.tir.Var("y", "int32"),
+        tvm.tir.Var("nx", "int32"),
+        tvm.tir.Var("ny", "int32"),
+        tvm.tir.Var("z", "int32"),
+    )
 
     test_case = tvm.testing.parameter(
         # TODO(Lunderberg): Loosen these preconditions.  When there's
@@ -647,7 +655,7 @@ class TestModIndex(BaseCompare):
 
 
 class TestFloormodIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(flm(x * 10, 2), 0),
@@ -685,7 +693,7 @@ class TestFloorModTwo(BaseCompare):
     however during simplification
     """
 
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     test_case = tvm.testing.parameter(
         # Removing offsets from floormod
         TestCase(flm(x, 2) + flm(x + 1, 2), 1),
@@ -714,7 +722,7 @@ class TestFloorModPadded(BaseCompare):
     such that (x - x % k) must be divisible by k
     """
 
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     test_case = tvm.testing.parameter(
         TestCase(flm(x - flm(x, 9), 9), 0),
         TestCase(flm(x - flm(x, -9), 9), 0),
@@ -727,159 +735,183 @@ class TestFloorModPadded(BaseCompare):
 
 
 class TestMinIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     test_case = tvm.testing.parameter(
         # const int bound
-        TestCase(tvm.te.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)),
-        TestCase(tvm.te.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)),
-        TestCase(tvm.te.min(x + 1, x + 10), x + 1),
-        TestCase(tvm.te.min(x + 111, x + 10), x + 10),
-        TestCase(tvm.te.min(x + 1, x), x),
-        TestCase(tvm.te.min(x, x + 2), x),
-        TestCase(tvm.te.min(1 - x, 2 - x), 1 - x),
-        TestCase(tvm.te.min(3 - x, 2 - x), 2 - x),
-        TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.min(x, y)), tvm.te.min(x, 
y)),
-        TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.min(y, x)), tvm.te.min(x, 
y)),
-        TestCase(tvm.te.min(tvm.te.max(x, y), x), x),
-        TestCase(tvm.te.min(tvm.te.max(y, x), x), x),
-        TestCase(tvm.te.min(tvm.te.min(x, y), x), tvm.te.min(x, y)),
-        TestCase(tvm.te.min(tvm.te.min(x, y), y), tvm.te.min(x, y)),
-        TestCase(tvm.te.min(x, tvm.te.max(x, y)), x),
-        TestCase(tvm.te.min(x, tvm.te.max(y, x)), x),
-        TestCase(tvm.te.min(x, tvm.te.min(x, y)), tvm.te.min(x, y)),
-        TestCase(tvm.te.min(y, tvm.te.min(x, y)), tvm.te.min(x, y)),
-        TestCase(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), y), 
tvm.te.min(tvm.te.min(x, y), z)),
-        TestCase(
-            tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), y),
-            tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2),
-        ),
-        TestCase(
-            tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), 
x * 2), z * 2), y),
-            tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z * 
2),
-        ),
-        TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.max(x, z)), 
tvm.te.max(tvm.te.min(y, z), x)),
-        TestCase(tvm.te.min(tvm.te.max(x, y), tvm.te.max(z, x)), 
tvm.te.max(tvm.te.min(y, z), x)),
-        TestCase(tvm.te.min(tvm.te.max(y, x), tvm.te.max(x, z)), 
tvm.te.max(tvm.te.min(y, z), x)),
-        TestCase(tvm.te.min(tvm.te.max(y, x), tvm.te.max(z, x)), 
tvm.te.max(tvm.te.min(y, z), x)),
-        TestCase(tvm.te.min(y + x, z + x), tvm.te.min(y, z) + x),
-        TestCase(tvm.te.min(y + x, x + z), tvm.te.min(y, z) + x),
-        TestCase(tvm.te.min(x + y, z + x), tvm.te.min(y, z) + x),
-        TestCase(tvm.te.min(x + y, x + z), tvm.te.min(y, z) + x),
-        TestCase(tvm.te.min(x - y, x - z), x - tvm.te.max(y, z)),
-        TestCase(tvm.te.min(y - x, z - x), tvm.te.min(y, z) - x),
-        TestCase(tvm.te.min(tvm.te.min(x, 1), 10), tvm.te.min(x, 1)),
-        TestCase(tvm.te.min(tvm.te.min(x, 11), 10), tvm.te.min(x, 10)),
-        TestCase(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3),
-        TestCase(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2),
-        TestCase(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2),
-        TestCase(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x, 1)),
-        TestCase(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2),
-        TestCase(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2),
-        TestCase(tvm.te.min(x * (0), 4), 0),
-        TestCase(tvm.te.min(x * (0), -4), -4),
+        TestCase(tvm.tir.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)),
+        TestCase(tvm.tir.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)),
+        TestCase(tvm.tir.min(x + 1, x + 10), x + 1),
+        TestCase(tvm.tir.min(x + 111, x + 10), x + 10),
+        TestCase(tvm.tir.min(x + 1, x), x),
+        TestCase(tvm.tir.min(x, x + 2), x),
+        TestCase(tvm.tir.min(1 - x, 2 - x), 1 - x),
+        TestCase(tvm.tir.min(3 - x, 2 - x), 2 - x),
+        TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(x, y)), 
tvm.tir.min(x, y)),
+        TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(y, x)), 
tvm.tir.min(x, y)),
+        TestCase(tvm.tir.min(tvm.tir.max(x, y), x), x),
+        TestCase(tvm.tir.min(tvm.tir.max(y, x), x), x),
+        TestCase(tvm.tir.min(tvm.tir.min(x, y), x), tvm.tir.min(x, y)),
+        TestCase(tvm.tir.min(tvm.tir.min(x, y), y), tvm.tir.min(x, y)),
+        TestCase(tvm.tir.min(x, tvm.tir.max(x, y)), x),
+        TestCase(tvm.tir.min(x, tvm.tir.max(y, x)), x),
+        TestCase(tvm.tir.min(x, tvm.tir.min(x, y)), tvm.tir.min(x, y)),
+        TestCase(tvm.tir.min(y, tvm.tir.min(x, y)), tvm.tir.min(x, y)),
+        TestCase(
+            tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), y), 
tvm.tir.min(tvm.tir.min(x, y), z)
+        ),
+        TestCase(
+            tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), 
y),
+            tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2),
+        ),
+        TestCase(
+            tvm.tir.min(
+                tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 
2), z * 2), y
+            ),
+            tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), 
z * 2),
+        ),
+        TestCase(
+            tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(x, z)), 
tvm.tir.max(tvm.tir.min(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(z, x)), 
tvm.tir.max(tvm.tir.min(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(x, z)), 
tvm.tir.max(tvm.tir.min(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(z, x)), 
tvm.tir.max(tvm.tir.min(y, z), x)
+        ),
+        TestCase(tvm.tir.min(y + x, z + x), tvm.tir.min(y, z) + x),
+        TestCase(tvm.tir.min(y + x, x + z), tvm.tir.min(y, z) + x),
+        TestCase(tvm.tir.min(x + y, z + x), tvm.tir.min(y, z) + x),
+        TestCase(tvm.tir.min(x + y, x + z), tvm.tir.min(y, z) + x),
+        TestCase(tvm.tir.min(x - y, x - z), x - tvm.tir.max(y, z)),
+        TestCase(tvm.tir.min(y - x, z - x), tvm.tir.min(y, z) - x),
+        TestCase(tvm.tir.min(tvm.tir.min(x, 1), 10), tvm.tir.min(x, 1)),
+        TestCase(tvm.tir.min(tvm.tir.min(x, 11), 10), tvm.tir.min(x, 10)),
+        TestCase(tvm.tir.min(x * 3, 9), tvm.tir.min(x, 3) * 3),
+        TestCase(tvm.tir.min(x * 2, 0), tvm.tir.min(x, 0) * 2),
+        TestCase(tvm.tir.min(0 - x * 2, 0), tvm.tir.max(x, 0) * -2),
+        TestCase(tvm.tir.min(3 - x, 2), 3 - tvm.tir.max(x, 1)),
+        TestCase(tvm.tir.min(x * (-2), -4), tvm.tir.max(x, 2) * -2),
+        TestCase(tvm.tir.min(x * (-2), 4), tvm.tir.max(x, -2) * -2),
+        TestCase(tvm.tir.min(x * (0), 4), 0),
+        TestCase(tvm.tir.min(x * (0), -4), -4),
         # DivMod rules
         # truc div
-        TestCase(tvm.te.min(tdiv(x + 3, 4) * 4, x), x),
-        TestCase(tvm.te.min(x, tdiv(x + 3, 4) * 4), x),
-        TestCase(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)), 
tvm.te.max(x, 4), x > 0),
-        TestCase(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), 
tvm.te.max(x, 4), x > 0),
-        TestCase(tvm.te.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.min(x, y), 
10)),
-        TestCase(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))), 
tdiv(tvm.te.max(x, y), (-10))),
+        TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, x), x),
+        TestCase(tvm.tir.min(x, tdiv(x + 3, 4) * 4), x),
+        TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, tvm.tir.max(x, 4)), 
tvm.tir.max(x, 4), x > 0),
+        TestCase(tvm.tir.min(tvm.tir.max(x, 4), tdiv(x + 3, 4) * 4), 
tvm.tir.max(x, 4), x > 0),
+        TestCase(tvm.tir.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.min(x, 
y), 10)),
+        TestCase(tvm.tir.min(tdiv(x, (-10)), tdiv(y, (-10))), 
tdiv(tvm.tir.max(x, y), (-10))),
         # floor div
-        TestCase(tvm.te.min(fld(x + 3, 4) * 4, x), x),
-        TestCase(tvm.te.min(x, fld(x + 3, 4) * 4), x),
-        TestCase(tvm.te.min(x, fld(x, 4) * 4), fld(x, 4) * 4),
-        TestCase(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)), 
tvm.te.max(x, 4), x > 0),
-        TestCase(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4), 
tvm.te.max(x, 4), x > 0),
-        TestCase(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y), 
10)),
-        TestCase(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x, 
y), (-10))),
+        TestCase(tvm.tir.min(fld(x + 3, 4) * 4, x), x),
+        TestCase(tvm.tir.min(x, fld(x + 3, 4) * 4), x),
+        TestCase(tvm.tir.min(x, fld(x, 4) * 4), fld(x, 4) * 4),
+        TestCase(tvm.tir.min(fld(x + 3, 4) * 4, tvm.tir.max(x, 4)), 
tvm.tir.max(x, 4), x > 0),
+        TestCase(tvm.tir.min(tvm.tir.max(x, 4), fld(x + 3, 4) * 4), 
tvm.tir.max(x, 4), x > 0),
+        TestCase(tvm.tir.min(fld(x, 10), fld(y, 10)), fld(tvm.tir.min(x, y), 
10)),
+        TestCase(tvm.tir.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.max(x, 
y), (-10))),
     )
 
 
 class TestMaxIndex(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         # const int bound
-        TestCase(tvm.te.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10),
-        TestCase(tvm.te.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10),
-        TestCase(tvm.te.max(x + 1, x + 10), x + 10),
-        TestCase(tvm.te.max(x + 111, x + 10), x + 111),
-        TestCase(tvm.te.max(x + 1, x), x + 1),
-        TestCase(tvm.te.max(x, x + 2), x + 2),
-        TestCase(tvm.te.max(1 - x, 2 - x), 2 - x),
-        TestCase(tvm.te.max(3 - x, 2 - x), 3 - x),
-        TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.max(x, y)), tvm.te.max(x, 
y)),
-        TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.max(y, x)), tvm.te.max(x, 
y)),
-        TestCase(tvm.te.max(tvm.te.min(x, y), x), x),
-        TestCase(tvm.te.max(tvm.te.min(y, x), x), x),
-        TestCase(tvm.te.max(tvm.te.max(x, y), x), tvm.te.max(x, y)),
-        TestCase(tvm.te.max(tvm.te.max(x, y), y), tvm.te.max(x, y)),
-        TestCase(tvm.te.max(x, tvm.te.min(x, y)), x),
-        TestCase(tvm.te.max(x, tvm.te.min(y, x)), x),
-        TestCase(tvm.te.max(x, tvm.te.max(x, y)), tvm.te.max(x, y)),
-        TestCase(tvm.te.max(y, tvm.te.max(x, y)), tvm.te.max(x, y)),
-        TestCase(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), y), 
tvm.te.max(tvm.te.max(x, y), z)),
-        TestCase(
-            tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), y),
-            tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2),
-        ),
-        TestCase(
-            tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), 
x * 2), z * 2), y),
-            tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z * 
2),
-        ),
-        TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.min(x, z)), 
tvm.te.min(tvm.te.max(y, z), x)),
-        TestCase(tvm.te.max(tvm.te.min(x, y), tvm.te.min(z, x)), 
tvm.te.min(tvm.te.max(y, z), x)),
-        TestCase(tvm.te.max(tvm.te.min(y, x), tvm.te.min(x, z)), 
tvm.te.min(tvm.te.max(y, z), x)),
-        TestCase(tvm.te.max(tvm.te.min(y, x), tvm.te.min(z, x)), 
tvm.te.min(tvm.te.max(y, z), x)),
-        TestCase(tvm.te.max(y + x, z + x), tvm.te.max(y, z) + x),
-        TestCase(tvm.te.max(y + x, x + z), tvm.te.max(y, z) + x),
-        TestCase(tvm.te.max(x + y, z + x), tvm.te.max(y, z) + x),
-        TestCase(tvm.te.max(x + y, x + z), tvm.te.max(y, z) + x),
-        TestCase(tvm.te.max(x - y, x - z), x - tvm.te.min(y, z)),
-        TestCase(tvm.te.max(y - x, z - x), tvm.te.max(y, z) - x),
-        TestCase(tvm.te.max(tvm.te.max(x, 1), 10), tvm.te.max(x, 10)),
-        TestCase(tvm.te.max(tvm.te.max(x, 11), 10), tvm.te.max(x, 11)),
-        TestCase(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3),
-        TestCase(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x, 2)),
-        TestCase(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2),
-        TestCase(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2),
-        TestCase(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2),
-        TestCase(tvm.te.max(x * (-2), 4), tvm.te.min(x, -2) * -2),
-        TestCase(tvm.te.max(x * (0), 4), 4),
-        TestCase(tvm.te.max(x * (0), -4), 0),
+        TestCase(tvm.tir.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10),
+        TestCase(tvm.tir.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10),
+        TestCase(tvm.tir.max(x + 1, x + 10), x + 10),
+        TestCase(tvm.tir.max(x + 111, x + 10), x + 111),
+        TestCase(tvm.tir.max(x + 1, x), x + 1),
+        TestCase(tvm.tir.max(x, x + 2), x + 2),
+        TestCase(tvm.tir.max(1 - x, 2 - x), 2 - x),
+        TestCase(tvm.tir.max(3 - x, 2 - x), 3 - x),
+        TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(x, y)), 
tvm.tir.max(x, y)),
+        TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(y, x)), 
tvm.tir.max(x, y)),
+        TestCase(tvm.tir.max(tvm.tir.min(x, y), x), x),
+        TestCase(tvm.tir.max(tvm.tir.min(y, x), x), x),
+        TestCase(tvm.tir.max(tvm.tir.max(x, y), x), tvm.tir.max(x, y)),
+        TestCase(tvm.tir.max(tvm.tir.max(x, y), y), tvm.tir.max(x, y)),
+        TestCase(tvm.tir.max(x, tvm.tir.min(x, y)), x),
+        TestCase(tvm.tir.max(x, tvm.tir.min(y, x)), x),
+        TestCase(tvm.tir.max(x, tvm.tir.max(x, y)), tvm.tir.max(x, y)),
+        TestCase(tvm.tir.max(y, tvm.tir.max(x, y)), tvm.tir.max(x, y)),
+        TestCase(
+            tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), y), 
tvm.tir.max(tvm.tir.max(x, y), z)
+        ),
+        TestCase(
+            tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), 
y),
+            tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2),
+        ),
+        TestCase(
+            tvm.tir.max(
+                tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 
2), z * 2), y
+            ),
+            tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), 
z * 2),
+        ),
+        TestCase(
+            tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(x, z)), 
tvm.tir.min(tvm.tir.max(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(z, x)), 
tvm.tir.min(tvm.tir.max(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(x, z)), 
tvm.tir.min(tvm.tir.max(y, z), x)
+        ),
+        TestCase(
+            tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(z, x)), 
tvm.tir.min(tvm.tir.max(y, z), x)
+        ),
+        TestCase(tvm.tir.max(y + x, z + x), tvm.tir.max(y, z) + x),
+        TestCase(tvm.tir.max(y + x, x + z), tvm.tir.max(y, z) + x),
+        TestCase(tvm.tir.max(x + y, z + x), tvm.tir.max(y, z) + x),
+        TestCase(tvm.tir.max(x + y, x + z), tvm.tir.max(y, z) + x),
+        TestCase(tvm.tir.max(x - y, x - z), x - tvm.tir.min(y, z)),
+        TestCase(tvm.tir.max(y - x, z - x), tvm.tir.max(y, z) - x),
+        TestCase(tvm.tir.max(tvm.tir.max(x, 1), 10), tvm.tir.max(x, 10)),
+        TestCase(tvm.tir.max(tvm.tir.max(x, 11), 10), tvm.tir.max(x, 11)),
+        TestCase(tvm.tir.max(x * 3, 9), tvm.tir.max(x, 3) * 3),
+        TestCase(tvm.tir.max(3 - x, 1), 3 - tvm.tir.min(x, 2)),
+        TestCase(tvm.tir.max(x * 2, 0), tvm.tir.max(x, 0) * 2),
+        TestCase(tvm.tir.max(0 - x * 2, 0), tvm.tir.min(x, 0) * -2),
+        TestCase(tvm.tir.max(x * (-2), -4), tvm.tir.min(x, 2) * -2),
+        TestCase(tvm.tir.max(x * (-2), 4), tvm.tir.min(x, -2) * -2),
+        TestCase(tvm.tir.max(x * (0), 4), 4),
+        TestCase(tvm.tir.max(x * (0), -4), 0),
         # DivMod rules
         # truc div
-        TestCase(tvm.te.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.max(x, y), 
10)),
-        TestCase(tvm.te.max(tdiv(x, (-10)), tdiv(y, (-10))), 
tdiv(tvm.te.min(x, y), (-10))),
-        TestCase(tvm.te.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4),
+        TestCase(tvm.tir.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.max(x, 
y), 10)),
+        TestCase(tvm.tir.max(tdiv(x, (-10)), tdiv(y, (-10))), 
tdiv(tvm.tir.min(x, y), (-10))),
+        TestCase(tvm.tir.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4),
         # floordiv
-        TestCase(tvm.te.max(fld(x, 10), fld(y, 10)), fld(tvm.te.max(x, y), 
10)),
-        TestCase(tvm.te.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.min(x, 
y), (-10))),
-        TestCase(tvm.te.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4),
-        TestCase(tvm.te.max(fld(x, 4) * 4, x), x),
-        TestCase(tvm.te.max(x, fld(x, 4) * 4), x),
+        TestCase(tvm.tir.max(fld(x, 10), fld(y, 10)), fld(tvm.tir.max(x, y), 
10)),
+        TestCase(tvm.tir.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.min(x, 
y), (-10))),
+        TestCase(tvm.tir.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4),
+        TestCase(tvm.tir.max(fld(x, 4) * 4, x), x),
+        TestCase(tvm.tir.max(x, fld(x, 4) * 4), x),
     )
 
 
 class TestScalableIndex(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     test_case = tvm.testing.parameter(
         # MinNode
-        TestCase(tvm.te.min(x + tir.vscale() * 4, x), x),
-        TestCase(tvm.te.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4),
-        TestCase(tvm.te.min(x + tir.vscale() * 4, x + tir.vscale() * 8), 
tir.vscale() * 4 + x),
-        TestCase(tvm.te.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), 
x), x),
-        TestCase(tvm.te.min(tir.vscale() * x, tir.vscale() * y), tir.vscale() 
* x, x < y),
+        TestCase(tvm.tir.min(x + tir.vscale() * 4, x), x),
+        TestCase(tvm.tir.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4),
+        TestCase(tvm.tir.min(x + tir.vscale() * 4, x + tir.vscale() * 8), 
tir.vscale() * 4 + x),
+        TestCase(tvm.tir.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), 
x), x),
+        TestCase(tvm.tir.min(tir.vscale() * x, tir.vscale() * y), tir.vscale() 
* x, x < y),
         # MaxNode
-        TestCase(tvm.te.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4),
-        TestCase(tvm.te.max(x - tir.vscale() * 4, x), x),
-        TestCase(tvm.te.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x + 
tir.vscale() * 4),
+        TestCase(tvm.tir.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4),
+        TestCase(tvm.tir.max(x - tir.vscale() * 4, x), x),
+        TestCase(tvm.tir.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x + 
tir.vscale() * 4),
         TestCase(
-            tvm.te.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x),
+            tvm.tir.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x),
             x + tir.vscale() * 4 - flm(4, tir.vscale() * 4),
         ),
-        TestCase(tvm.te.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() 
* x, x > y),
+        TestCase(tvm.tir.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() 
* x, x > y),
         # FloorDiv
         TestCase(fld(x * tir.vscale() * 4 + y, tir.vscale() * 4), x + fld(y, 
tir.vscale() * 4)),
         TestCase(fld(x, tir.vscale() * 4), 0, [x >= 0, x < tir.vscale() * 4]),
@@ -894,7 +926,7 @@ class TestScalableIndex(BaseCompare):
 
 
 class TestComparisons(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         # const int bound
@@ -1066,10 +1098,10 @@ class TestComparisons(BaseCompare):
             tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)),
             tir.all(T.int32(-43) <= x, x < -40),
         ),
-        TestCase(tvm.te.min(x, 11) < 10, x < 10),
-        TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")),
-        TestCase(tvm.te.max(8, x) > 10, tvm.tir.LT(10, x)),
-        TestCase(x + 1 < tvm.te.max(8, x), x < 7),
+        TestCase(tvm.tir.min(x, 11) < 10, x < 10),
+        TestCase(tvm.tir.min(x, 8) < 10, tvm.tir.const(1, "bool")),
+        TestCase(tvm.tir.max(8, x) > 10, tvm.tir.LT(10, x)),
+        TestCase(x + 1 < tvm.tir.max(8, x), x < 7),
         TestCase(x < 11, tvm.tir.const(1, "bool"), x <= 10),
         TestCase(x <= 10, tvm.tir.const(1, "bool"), x <= 10),
         TestCase(z <= 5, tvm.tir.const(1, "bool"), z <= 5),
@@ -1088,7 +1120,7 @@ class TestComparisons(BaseCompare):
 class TestComparisonOfProductAndSum(BaseCompare):
     extensions = tvm.arith.Extension.ComparisonOfProductAndSum
 
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         # Special inequality cases
@@ -1119,7 +1151,7 @@ class TestComparisonOfProductAndSum(BaseCompare):
 
 
 class TestLogical(BaseCompare):
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), 
tvm.tir.const(False, "bool")),
@@ -1164,7 +1196,7 @@ class TestLogical(BaseCompare):
 
 
 class TestLet(BaseCompare):
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     z = tvm.tir.Let(x, 1, x + 1)
 
     test_case = tvm.testing.parameter(
@@ -1174,7 +1206,7 @@ class TestLet(BaseCompare):
 
 class TestCast(BaseCompare):
     def _generate_tests():
-        x = te.var("x")
+        x = tvm.tir.Var("x", "int32")
         dtypes = ["float32", "float16", "int32", "int8", "bool"]
         for dtype1 in dtypes:
             yield TestCase(tvm.tir.Cast(dtype1, x - x), tvm.tir.const(0, 
dtype1))
@@ -1218,7 +1250,7 @@ class TestSubBufferload(BaseCompare):
 
 
 class TestIfThenElse(BaseCompare):
-    x = te.var("x", "int32")
+    x = tvm.tir.Var("x", "int32")
 
     test_case = tvm.testing.parameter(
         TestCase(
diff --git a/tests/python/arith/test_arith_solve_linear_equations.py 
b/tests/python/arith/test_arith_solve_linear_equations.py
index 61c107915e..ddb0b36db8 100644
--- a/tests/python/arith/test_arith_solve_linear_equations.py
+++ b/tests/python/arith/test_arith_solve_linear_equations.py
@@ -18,7 +18,7 @@ import random
 import sys
 import pytest
 import tvm
-from tvm import te, arith, ir, tir, testing
+from tvm import arith, ir, tir, testing
 from tvm.script import tir as T
 
 
@@ -31,7 +31,7 @@ def test_solution_consistency():
     random.seed(seed)
 
     def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
-        variables = [te.var("x" + str(i)) for i in range(num_vars)]
+        variables = [tvm.tir.Var("x" + str(i), "int32") for i in 
range(num_vars)]
 
         relations = []
         for i in range(num_formulas):
@@ -85,7 +85,7 @@ def test_solution_consistency():
 
 
 def test_empty_var_to_solve():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     equations = [
         tvm.tir.EQ(x + y, 20),
         tvm.tir.EQ(x - y, 10),
@@ -100,7 +100,7 @@ def test_empty_var_to_solve():
 
 
 def test_unique_solution():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     solution = arith.solve_linear_equations(
         [
@@ -115,7 +115,7 @@ def test_unique_solution():
 
 
 def test_low_rank():
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     ranges = {}
 
     solution = arith.solve_linear_equations(
@@ -133,7 +133,7 @@ def test_low_rank():
 
 
 def test_infer_range():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     ranges = {
         x: tvm.ir.Range.from_min_extent(-5, 10),
         y: tvm.ir.Range.from_min_extent(0, 10),
@@ -160,7 +160,7 @@ def test_infer_range():
 
 
 def test_ill_formed():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     solution = arith.solve_linear_equations(
         [
diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py 
b/tests/python/arith/test_arith_solve_linear_inequality.py
index 192c46b56c..17557fd952 100644
--- a/tests/python/arith/test_arith_solve_linear_inequality.py
+++ b/tests/python/arith/test_arith_solve_linear_inequality.py
@@ -18,7 +18,7 @@ import random
 import sys
 import pytest
 import tvm
-from tvm import te, arith, ir, tir, testing
+from tvm import arith, ir, tir, testing
 from tvm.script import tir as T
 
 
@@ -32,7 +32,7 @@ def test_solution_consistency():
     random.seed(seed)
 
     def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):
-        vs = [te.var("x" + str(i)) for i in range(variables)]
+        vs = [tvm.tir.Var("x" + str(i), "int32") for i in range(variables)]
 
         fs = []
         for i in range(formulas):
@@ -44,9 +44,9 @@ def test_solution_consistency():
             fs.append(op(s1, s2))
 
         vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs}
-        before = te.all(tir.const(1, "bool"), *fs)
+        before = tvm.tir.all(tir.const(1, "bool"), *fs)
         after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs)
-        after = te.all(tir.const(1, "bool"), *after)
+        after = tvm.tir.all(tir.const(1, "bool"), *after)
         testing.check_bool_expr_is_true(before == after, vranges)
 
         solution = arith.solve_linear_inequalities(fs, vs, vranges, 
deskew_range=True)
@@ -81,7 +81,7 @@ def test_solution_consistency():
 
 
 def test_dual_variable():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
 
     variables = [x, y]
     ranges = {
@@ -125,7 +125,7 @@ def test_dual_variable():
 
 
 def test_equal():
-    x, y = te.var("x"), te.var("y")
+    x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32")
     problem = [
         tvm.tir.GE(x + y, 10),
         tvm.tir.GE(x - y, 2),
@@ -147,7 +147,7 @@ def test_equal():
 
 
 def test_multi_equal():
-    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), 
tvm.tir.Var("z", "int32")
     problem = [
         tvm.tir.LE(x, 6),
         tvm.tir.GE(x, 6),
@@ -179,7 +179,7 @@ def test_multi_equal():
 
 
 def test_no_solution():
-    x = te.var("x0")
+    x = tvm.tir.Var("x0", "int32")
     vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)}
     problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5]
 
@@ -198,8 +198,8 @@ def test_no_solution():
 
 
 def test_unbound_var_range():
-    x = te.var("x0")
-    free_var = te.var("fv")
+    x = tvm.tir.Var("x0", "int32")
+    free_var = tvm.tir.Var("fv", "int32")
     vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 + 
tvm.tir.log(free_var)))}
     problem = [x > 3]
     solution = arith.solve_linear_inequalities(
diff --git a/tests/python/ir/test_ir_container.py 
b/tests/python/ir/test_ir_container.py
index cd47766159..af820acb0a 100644
--- a/tests/python/ir/test_ir_container.py
+++ b/tests/python/ir/test_ir_container.py
@@ -17,7 +17,7 @@
 import pytest
 import tvm_ffi
 import tvm
-from tvm import te
+
 import numpy as np
 
 
@@ -45,8 +45,8 @@ def test_dir_array():
 
 
 def test_map():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     amap = tvm.runtime.convert({a: 2, b: 3})
     assert a in amap
     assert len(amap) == 2
@@ -70,8 +70,8 @@ def test_str_map():
 
 
 def test_map_save_load_json():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     amap = tvm.runtime.convert({a: 2, b: 3})
     json_str = tvm.ir.save_json(amap)
     amap = tvm.ir.load_json(json_str)
@@ -81,15 +81,15 @@ def test_map_save_load_json():
 
 
 def test_dir_map():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     amap = tvm.runtime.convert({a: 2, b: 3})
     assert dir(amap)
 
 
 def test_getattr_map():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     amap = tvm.runtime.convert({a: 2, b: 3})
     assert isinstance(amap, tvm_ffi.Map)
 
diff --git a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py 
b/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
index c3ae417dcd..5ed8314d1a 100644
--- a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
+++ b/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py
@@ -15,18 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
 
 
 def test_equal_expr():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
 
     def func1():
         return x + y + 1
 
     def func2():
-        return te.exp(tvm.tir.truncdiv((x + y + 1) * y, 4))
+        return tvm.tir.exp(tvm.tir.truncdiv((x + y + 1) * y, 4))
 
     assert tvm.tir.analysis.expr_deep_equal(func1(), func1())
     assert tvm.tir.analysis.expr_deep_equal(func2(), func2())
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py 
b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
index a6db37f5a2..6d559a81d2 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py
@@ -15,12 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
 
 
 def test_verify_ssa():
-    x = te.var("x")
-    y = te.var()
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("tindex", "int32")
     z = tvm.tir.Evaluate(x + y)
     assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z))
 
@@ -28,7 +27,7 @@ def test_verify_ssa():
 
 
 def test_verify_weak_let_ssa():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     z1 = tvm.tir.Let(x, 1, x + 1)
     z2 = tvm.tir.Let(x, 2, x + 2)
 
diff --git a/tests/python/tir-base/test_tir_buffer.py 
b/tests/python/tir-base/test_tir_buffer.py
index 791de76995..df60d43061 100644
--- a/tests/python/tir-base/test_tir_buffer.py
+++ b/tests/python/tir-base/test_tir_buffer.py
@@ -17,7 +17,6 @@
 
 import tvm
 import tvm.testing
-from tvm import te
 from tvm.tir import Buffer
 from tvm.script import tir as T
 
@@ -26,9 +25,9 @@ import pytest
 
 
 def test_buffer():
-    m = te.size_var("m")
-    n = te.size_var("n")
-    l = te.size_var("l")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
+    l = tvm.tir.SizeVar("l", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     Bb = tvm.tir.decl_buffer((n, l), "float32")
 
@@ -38,8 +37,8 @@ def test_buffer():
 
 
 def test_buffer_access_ptr():
-    m = te.size_var("m")
-    n = te.size_var("n")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
     aptr = Ab.access_ptr("rw")
     tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m)
@@ -50,13 +49,13 @@ def test_buffer_access_ptr():
 
 
 def test_buffer_access_ptr_offset():
-    m = te.size_var("m")
-    n = te.size_var("n")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     aptr = Ab.access_ptr("rw", offset=100)
     tvm.testing.assert_prim_expr_equal(aptr.args[2], 100)
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
-    v = te.size_var("int32")
+    v = tvm.tir.SizeVar("int32", "int32")
     aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
     tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v)
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
@@ -68,8 +67,8 @@ def test_buffer_access_ptr_offset():
 
 
 def test_buffer_access_ptr_extent():
-    m = te.size_var("m")
-    n = te.size_var("n")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     aptr = Ab.access_ptr("rw")
     tvm.ir.assert_structural_equal(aptr.args[3], m * n)
@@ -87,27 +86,27 @@ def test_buffer_access_ptr_extent():
 
 
 def test_buffer_vload():
-    m = te.size_var("m")
-    n = te.size_var("n")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
     load = Ab.vload([2, 3])
     tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)])
 
 
 def test_buffer_offset_of():
-    m = te.size_var("m")
-    n = te.size_var("n")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
     Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
     offset = Ab.offset_of([2, 3])
     tvm.ir.assert_structural_equal(offset, [n * 2 + 103])
 
 
 def test_buffer_index_merge_mult_mod():
-    m = te.size_var("m")
-    n = te.size_var("n")
-    s = te.size_var("s")
-    k0 = te.size_var("k0")
-    k1 = te.size_var("k1")
+    m = tvm.tir.SizeVar("m", "int32")
+    n = tvm.tir.SizeVar("n", "int32")
+    s = tvm.tir.SizeVar("s", "int32")
+    k0 = tvm.tir.SizeVar("k0", "int32")
+    k1 = tvm.tir.SizeVar("k1", "int32")
     A = tvm.tir.decl_buffer((m, n), "float32")
     A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
 
@@ -153,9 +152,9 @@ def test_buffer_index_merge_mult_mod():
 
     # Test Case5
     B = tvm.tir.decl_buffer((1, 14, 14, 1024))
-    i = te.size_var("i")
-    j = te.size_var("j")
-    k = te.size_var("k")
+    i = tvm.tir.SizeVar("i", "int32")
+    j = tvm.tir.SizeVar("j", "int32")
+    k = tvm.tir.SizeVar("k", "int32")
 
     index_simplified1 = B.offset_of(
         (
diff --git a/tests/python/tir-base/test_tir_constructor.py 
b/tests/python/tir-base/test_tir_constructor.py
index 4076070557..25f62c7096 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -17,7 +17,6 @@
 
 import pytest
 import tvm
-from tvm import te
 
 
 def test_expr_constructor():
@@ -50,7 +49,7 @@ def test_expr_constructor():
     assert x.value.value == 1
 
     a = tvm.tir.const(1.0, dtype="float32")
-    b = te.var("x", dtype="float32")
+    b = tvm.tir.Var("x", "float32")
 
     for cls in [
         tvm.tir.Add,
@@ -70,8 +69,8 @@ def test_expr_constructor():
         assert x.a == a
         assert x.b.same_as(b)
 
-    a = tvm.runtime.convert(te.var("x") > 1)
-    b = tvm.runtime.convert(te.var("x") == 1)
+    a = tvm.runtime.convert(tvm.tir.Var("x", "int32") > 1)
+    b = tvm.runtime.convert(tvm.tir.Var("x", "int32") == 1)
 
     for cls in [tvm.tir.And, tvm.tir.Or]:
         x = cls(a, b)
@@ -120,7 +119,7 @@ def test_expr_constructor():
     assert x.op.name == "tir.call_extern"
     assert x.args[1] == a
 
-    v = te.var("aa")
+    v = tvm.tir.Var("aa", "int32")
     x = tvm.tir.Let(v, 1, v)
     assert x.var == v
     assert x.value.value == 1
@@ -128,7 +127,7 @@ def test_expr_constructor():
 
 
 def test_stmt_constructor():
-    v = te.var("aa")
+    v = tvm.tir.Var("aa", "int32")
     nop = tvm.tir.Evaluate(1)
     x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
     assert isinstance(x, tvm.tir.LetStmt)
@@ -144,7 +143,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
-    x = tvm.tir.For(te.var("x"), 0, 10, tvm.tir.ForKind.SERIAL, nop)
+    x = tvm.tir.For(tvm.tir.Var("x", "int32"), 0, 10, tvm.tir.ForKind.SERIAL, 
nop)
     assert isinstance(x, tvm.tir.For)
     assert x.min.value == 0
     assert x.extent.value == 10
diff --git a/tests/python/tir-base/test_tir_intrin.py 
b/tests/python/tir-base/test_tir_intrin.py
index 9f199a2b63..f85c06c5b2 100644
--- a/tests/python/tir-base/test_tir_intrin.py
+++ b/tests/python/tir-base/test_tir_intrin.py
@@ -55,7 +55,7 @@ def test_nearbyint():
 
 
 def test_round_intrinsics_on_int():
-    i = tvm.te.var("i", "int32")
+    i = tvm.tir.Var("i", "int32")
     for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil, tvm.tir.floor, 
tvm.tir.nearbyint]:
         assert op(tvm.tir.const(10, "int32")).value == 10
         assert op(tvm.tir.const(True, "bool")).value == True
diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py
index b0bd5ac891..7f857a3203 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -17,7 +17,7 @@
 import numpy as np
 import pytest
 import tvm
-from tvm import ir, te
+from tvm import ir
 
 
 def test_const():
@@ -27,7 +27,7 @@ def test_const():
 
 
 def test_te_const():
-    x = tvm.te.const(1, "int32")
+    x = tvm.tir.const(1, "int32")
     assert x.dtype == "int32"
     assert isinstance(x, tvm.tir.IntImm)
 
@@ -57,7 +57,7 @@ def test_tir_const_dtype_inference():
 
 def test_make():
     x = tvm.tir.const(1, "int32")
-    y = te.var("x")
+    y = tvm.tir.Var("x", "int32")
     z = x + y
     assert isinstance(tvm.tir.max(x, y), tvm.tir.Max)
     assert isinstance(tvm.tir.min(x, y), tvm.tir.Min)
@@ -72,12 +72,12 @@ def test_ir():
 
 
 def test_ir2():
-    buf_size = te.var("size")
-    x = te.var("n")
+    buf_size = tvm.tir.Var("size", "int32")
+    x = tvm.tir.Var("n", "int32")
 
     storage_type = ir.PrimType("int32")
     handle_type = ir.PointerType(storage_type)
-    array = te.var("array", handle_type)
+    array = tvm.tir.Var("array", handle_type)
     buf = tvm.tir.decl_buffer([buf_size], "int32", data=array)
 
     st = tvm.tir.BufferStore(buf, x + 1, [1])
@@ -87,13 +87,13 @@ def test_ir2():
 
 
 def test_let():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
 
 def test_cast():
-    x = te.var("x", dtype="float32")
+    x = tvm.tir.Var("x", "float32")
     y = x.astype("int32")
     z = x.astype("float32x4")
     assert isinstance(y, tvm.tir.Cast)
@@ -110,8 +110,8 @@ def test_cast():
 
 
 def test_attr():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     stmt = tvm.tir.AttrStmt(y, "stride", 10, tvm.tir.Evaluate(x + 1))
     assert stmt.node == y
 
@@ -125,34 +125,34 @@ def test_attr():
 
 
 def test_basic():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     c = a + b
     assert str(c) == "%s + %s" % (a.name, b.name)
 
 
 def test_stmt():
     x = tvm.tir.Evaluate(0)
-    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x)
-    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2)
+    tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.SERIAL, x)
+    tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.UNROLLED, x, 
step=2)
 
 
 def test_dir():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     dir(x)
 
 
 def test_dtype():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     assert x.dtype == "int32"
-    y = te.var("y")
+    y = tvm.tir.Var("y", "int32")
     assert (x > y).dtype == "bool"
 
 
 def test_any():
-    x = te.var("x")
-    y = te.var("y")
-    z = te.var("z")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     try:
         t = x or x
         assert False
@@ -183,9 +183,9 @@ def test_any():
 
 
 def test_all():
-    x = te.var("x")
-    y = te.var("y")
-    z = te.var("z")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     try:
         t = x and x
         assert False
@@ -216,8 +216,8 @@ def test_all():
 
 
 def test_bitwise():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     assert str(x << y) == "T.shift_left(x, y)"
     assert str(x >> y) == "T.shift_right(x, y)"
     assert str(x & y) == "T.bitwise_and(x, y)"
@@ -233,7 +233,7 @@ def test_bitwise():
     assert str(~x) == "T.bitwise_not(x)"
     assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
     assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
-    assert (te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == 
"int8x2"
+    assert (tvm.tir.Var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == 
"int8x2"
 
 
 def test_float_bitwise():
@@ -258,7 +258,7 @@ def test_float_bitwise():
 
 
 def test_shift_bounds():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     for test in [lambda lhs, rhs: lhs << rhs, lambda lhs, rhs: lhs >> rhs]:
         # negative case
         for testcase in [(x, -1), (x, 32)]:
@@ -295,20 +295,20 @@ def test_infinity():
 
 
 def test_isnan():
-    x = te.var("x", "float32")
+    x = tvm.tir.Var("x", "float32")
     assert str(tvm.tir.isnan(x)) == "T.isnan(x)"
     assert str(tvm.tir.isnan(x).dtype) == "bool"
-    y = te.var("y", "float16")
+    y = tvm.tir.Var("y", "float16")
     assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))'
-    z = te.var("z", "int32")
+    z = tvm.tir.Var("z", "int32")
     assert str(tvm.tir.isnan(z)) == "T.bool(False)"
-    k = te.var("k", "int8x2")
+    k = tvm.tir.Var("k", "int8x2")
     assert str(tvm.tir.isnan(k).dtype) == "boolx2"
 
 
 def test_equality():
-    a = te.var("a")
-    b = te.var("b")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     c = a == b
     assert not c
     d = c != c
@@ -323,8 +323,8 @@ def test_equality_string_imm():
 
 
 def test_prim_func():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     b = tvm.tir.decl_buffer((x,), "float32")
     stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
@@ -390,7 +390,7 @@ def _create_broadcast(lanes):
     return tvm.tir.Broadcast(0, lanes)
 
 
[email protected]("lanes", [(tvm.tir.IntImm(dtype="int64", value=11))])
[email protected]("lanes", [tvm.tir.IntImm(dtype="int64", value=11)])
 @pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
 def test_lane_types(lanes, node_func):
     def _check_dtype(node):
diff --git a/tests/python/tir-base/test_tir_ops.py 
b/tests/python/tir-base/test_tir_ops.py
index cb7d8c597a..6614f557f4 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 
 import pytest
 
@@ -52,7 +51,7 @@ def test_const_fold():
 
 
 def test_const_fold2():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     tmod = tvm.tir.truncmod
     tdiv = tvm.tir.truncdiv
     assert (x + 0).same_as(x)
@@ -66,7 +65,7 @@ def test_const_fold2():
 
 def test_const_fold3():
     # Test that using ints with logic operations is forbidden
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
             check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
@@ -84,7 +83,7 @@ def test_const_fold3():
                     tvm.tir.const(py_func(v1, v2), "bool"),
                 )
 
-    x = te.var("x", "bool")
+    x = tvm.tir.Var("x", "bool")
     true = tvm.tir.const(1, "bool")
     false = tvm.tir.const(0, "bool")
 
@@ -108,11 +107,11 @@ def test_const_fold4():
     assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
     x4 = x3 + 0.55
     assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
-    x5 = te.ceil(x4)
+    x5 = tvm.tir.ceil(x4)
     assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
     x6 = x5.astype("int")
     assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
-    y = (te.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) + 2).astype("int")
+    y = (tvm.tir.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) + 
2).astype("int")
     assert isinstance(y, tvm.tir.IntImm) and y.value == 6
 
 
@@ -126,8 +125,8 @@ def test_binary_dtype_match():
             [("uint32", "int32"), "uint32"],
         ]
         for (lhs_dtype, rhs_dtype), out_dtype in rules:
-            lhs = te.var("lhs", dtype=lhs_dtype)
-            rhs = te.var("rhs", dtype=rhs_dtype)
+            lhs = tvm.tir.Var("lhs", lhs_dtype)
+            rhs = tvm.tir.Var("rhs", rhs_dtype)
             out = f(lhs, rhs)
             if not is_conditional:
                 assert out.dtype == out_dtype
@@ -146,8 +145,8 @@ def test_binary_dtype_match():
     def verify_callop_float_only(f):
         for lhs_dtype in ["int32", "float32", "float64"]:
             for rhs_dtype in ["int32", "float32", "float64"]:
-                lhs = te.var("lhs", dtype=lhs_dtype)
-                rhs = te.var("rhs", dtype=rhs_dtype)
+                lhs = tvm.tir.Var("lhs", lhs_dtype)
+                rhs = tvm.tir.Var("rhs", rhs_dtype)
                 if "float" not in lhs_dtype and "float" not in rhs_dtype:
                     check_throws(lambda: f(lhs, rhs))
                 elif "float" in lhs_dtype:
@@ -176,7 +175,7 @@ def test_binary_dtype_match():
     verify_general_dtype_support(lambda a, b: a * b)
     verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True)
     verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
-    verify_callop_float_only(lambda a, b: te.power(a, b))
+    verify_callop_float_only(lambda a, b: tvm.tir.power(a, b))
 
     # verify bool & int32 constant folding
     assert tvm.tir.const(1) == tvm.tir.const(True)
@@ -185,15 +184,15 @@ def test_binary_dtype_match():
 
 def test_if_then_else():
     cases = [
-        [(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
+        [(tvm.tir.Var("cond", "bool"), "bool", "int32"), "int32"],
         [(True, "int32", "float32"), "float32"],
         [(False, "int32", "int64"), "int64"],
-        [(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"],
-        [(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"],
+        [(tvm.tir.Var("cond", "bool"), "uint32", "int32"), "uint32"],
+        [(tvm.tir.Var("cond", "int32"), "uint32", "int32"), "uint32"],
     ]
     for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
-        lhs = te.var("lhs", dtype=lhs_dtype)
-        rhs = te.var("rhs", dtype=rhs_dtype)
+        lhs = tvm.tir.Var("lhs", lhs_dtype)
+        rhs = tvm.tir.Var("rhs", rhs_dtype)
         if cond is True or cond is False:
             out = tvm.tir.if_then_else(cond, lhs, rhs)
             out2 = tvm.tir.if_then_else(not cond, rhs, lhs)
diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py 
b/tests/python/tir-base/test_tir_structural_equal_hash.py
index 296450b9a2..e7113acdad 100644
--- a/tests/python/tir-base/test_tir_structural_equal_hash.py
+++ b/tests/python/tir-base/test_tir_structural_equal_hash.py
@@ -17,7 +17,7 @@
 import tvm
 import numpy as np
 import pytest
-from tvm import te
+
 from tvm_ffi.access_path import AccessPath
 from tvm.script import tir as T, ir as I
 
@@ -73,9 +73,9 @@ def test_exprs():
     # save load json
     x = tvm.tir.const(1, "int32")
     y = tvm.tir.const(10, "int32")
-    vx = te.var("x")
-    vy = te.var("y")
-    vz = te.var("z")
+    vx = tvm.tir.Var("x", "int32")
+    vy = tvm.tir.Var("y", "int32")
+    vz = tvm.tir.Var("z", "int32")
     zx = vx + vx
     zy = vy + vy
 
@@ -105,8 +105,8 @@ def test_exprs():
 
 
 def test_prim_func():
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     # counter example of same equality
     func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y))
     func1 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(y + x))
@@ -132,9 +132,9 @@ def test_prim_func():
 
 
 def test_prim_func_param_count_mismatch():
-    x = te.var("x")
-    y = te.var("y")
-    z = te.var("z")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     # counter example of same equality
     func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
     func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))
@@ -146,9 +146,9 @@ def test_prim_func_param_count_mismatch():
 
 
 def test_prim_func_param_dtype_mismatch():
-    x = te.var("x")
-    y_0 = te.var("y", dtype="int32")
-    y_1 = te.var("z", dtype="float32")
+    x = tvm.tir.Var("x", "int32")
+    y_0 = tvm.tir.Var("y", "int32")
+    y_1 = tvm.tir.Var("z", "float32")
     # counter example of same equality
     func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x))
     func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x))
@@ -159,10 +159,10 @@ def test_prim_func_param_dtype_mismatch():
 
 
 def test_prim_func_body_mismatch():
-    x_0 = te.var("x")
-    y_0 = te.var("y")
-    x_1 = te.var("x")
-    y_1 = te.var("y")
+    x_0 = tvm.tir.Var("x", "int32")
+    y_0 = tvm.tir.Var("y", "int32")
+    x_1 = tvm.tir.Var("x", "int32")
+    y_1 = tvm.tir.Var("y", "int32")
     # counter example of same equality
     func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0))
     func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1))
@@ -206,16 +206,6 @@ def test_attrs():
 
 
 def test_stmt():
-    x = te.var("x")
-    y = te.var("y")
-    n = 128
-    A = te.placeholder((n, n), name="A")
-    B = te.placeholder((n, n), name="B")
-    ii = te.var("i")
-    jj = te.var("j")
-
-    n = te.var("n")
-
     @T.prim_func(private=True, check_well_formed=False)
     def func2(A: T.handle, n_param: T.int32):
         n_var = T.var("int32")
@@ -230,7 +220,7 @@ def test_stmt():
 
 
 def test_buffer_storage_scope():
-    x = te.var("x", dtype="handle")
+    x = tvm.tir.Var("x", "handle")
 
     buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
     buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
@@ -248,7 +238,7 @@ def test_buffer_storage_scope():
 
 
 def test_buffer_map_mismatch():
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     buffer_0 = tvm.tir.decl_buffer((10, 10))
     buffer_0_clone = tvm.tir.decl_buffer((10, 10))
     buffer_1 = tvm.tir.decl_buffer((10, 20))
@@ -268,8 +258,8 @@ def test_buffer_map_mismatch():
 
 
 def test_buffer_map_length_mismatch():
-    x = te.var("x")
-    y = te.var("x")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("x", "int32")
 
     buffer_0 = tvm.tir.decl_buffer((10, 10))
     buffer_1 = tvm.tir.decl_buffer((10, 20))
diff --git 
a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py 
b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
index 1be5e57ba1..8b8e187341 100644
--- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
@@ -17,7 +17,6 @@
 import hashlib
 
 import tvm
-from tvm import te, topi
 from tvm.ir.base import save_json
 from tvm.ir.module import IRModule
 from tvm.script import tir as T
@@ -29,15 +28,15 @@ from tvm.script import tir as T
 # A test program which gives the opportunity for the CSE pass to introduce two 
new variables,
 # at two different levels
 def test_cse():
-    z1 = te.var("z1")
-    z2 = te.var("z2")
-    z3 = te.var("z3")
-    i1 = te.var("i1")
-    i2 = te.var("i2")
-    x = te.var("x")
-    y = te.var("y")
-    a = te.var("a")
-    b = te.var("b")
+    z1 = tvm.tir.Var("z1", "int32")
+    z2 = tvm.tir.Var("z2", "int32")
+    z3 = tvm.tir.Var("z3", "int32")
+    i1 = tvm.tir.Var("i1", "int32")
+    i2 = tvm.tir.Var("i2", "int32")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    a = tvm.tir.Var("a", "int32")
+    b = tvm.tir.Var("b", "int32")
     dtype = "int32"
     buffer = tvm.tir.decl_buffer((50,), dtype)
     # Test prog :
@@ -152,12 +151,12 @@ def test_cse():
 # branch, not before the whole If (otherwise that would lead to some 
computations being computed
 # for nothing when it is the Else branch that is executed).
 def test_cse_ifNode_1():
-    b = te.var("b")
-    i1 = te.var("i1")
-    i2 = te.var("i2")
-    i3 = te.var("i3")
-    y = te.var("y")
-    z = te.var("z")
+    b = tvm.tir.Var("b", "int32")
+    i1 = tvm.tir.Var("i1", "int32")
+    i2 = tvm.tir.Var("i2", "int32")
+    i3 = tvm.tir.Var("i3", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     dtype = "int32"
     buffer = tvm.tir.decl_buffer((50,), dtype)
     # Test prog :
@@ -208,12 +207,12 @@ def test_cse_ifNode_1():
 # In this case, the CSE pass should introduce the redundant computation before 
the whole If node,
 # because regardless of the execution path, it is going to be computed.
 def test_cse_ifNode_2():
-    b = te.var("b")
-    i1 = te.var("i1")
-    i2 = te.var("i2")
-    i3 = te.var("i3")
-    y = te.var("y")
-    z = te.var("z")
+    b = tvm.tir.Var("b", "int32")
+    i1 = tvm.tir.Var("i1", "int32")
+    i2 = tvm.tir.Var("i2", "int32")
+    i3 = tvm.tir.Var("i3", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     dtype = "int32"
     buffer = tvm.tir.decl_buffer((50,), dtype)
     # Test prog :
@@ -261,12 +260,12 @@ def test_cse_ifNode_2():
 # and in the rest of the program.
 # 
-------------------------------------------------------------------------------------------------
 def test_cse_cascade():
-    i1 = te.var("i1")
-    i2 = te.var("i2")
-    i3 = te.var("i3")
-    x = te.var("x")
-    y = te.var("y")
-    z = te.var("z")
+    i1 = tvm.tir.Var("i1", "int32")
+    i2 = tvm.tir.Var("i2", "int32")
+    i3 = tvm.tir.Var("i3", "int32")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
     dtype = "int32"
     buffer = tvm.tir.decl_buffer((50,), dtype)
     # Test prog :
@@ -326,10 +325,10 @@ def test_cse_cascade():
 # A test which ensures that we don't perform normalizations outside of 
introduced variables
 # 
-----------------------------------------------------------------------------------------
 def test_no_normalization_without_commoning():
-    x = te.var("x")
-    y = te.var("y")
-    z = te.var("z")
-    a = te.var("a")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
+    a = tvm.tir.Var("a", "int32")
     # Test prog :
     # let a = x + (y + z) in a
     body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a))
@@ -418,8 +417,8 @@ def test_deterministic_cse():
     NUM_TERMS = 10
     REPEATS = 10
 
-    x = te.var("x")
-    result = te.var("result")
+    x = tvm.tir.Var("x", "int32")
+    result = tvm.tir.Var("result", "int32")
 
     offsets = sorted([i + 1 for i in range(NUM_TERMS)])
     inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)]
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py 
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index a0a6ab2508..32367e0aa0 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 import numpy as np
 
 
@@ -45,18 +44,34 @@ def check_value(expr, variables, data, fref):
     num_vars = len(variables)
     assert num_vars >= 1 and all(len(row) == num_vars for row in data)
 
-    placeholders = [
-        te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in 
range(num_vars)
+    # Build input and output buffers
+    input_bufs = [
+        tvm.tir.decl_buffer((n,), dtype=variables[i].dtype, name=f"v{i}") for 
i in range(num_vars)
     ]
+    out_buf = tvm.tir.decl_buffer((n,), dtype=expr.dtype, name="C")
 
-    def make_binds(i):
-        x = expr
+    # Build loop body: for each i, bind variables[j] = input_bufs[j][i], then 
store expr to out
+    loop_var = tvm.tir.Var("i", "int32")
+
+    def make_store(i_var):
+        # Build the expression with each variable bound to the corresponding 
buffer load
+        result = expr
         for j in range(num_vars - 1, -1, -1):
-            x = tvm.tir.Let(variables[j], placeholders[j][i], x)
-        return x
+            result = tvm.tir.Let(variables[j], 
tvm.tir.BufferLoad(input_bufs[j], [i_var]), result)
+        return tvm.tir.BufferStore(out_buf, result, [i_var])
+
+    loop = tvm.tir.For(
+        loop_var,
+        tvm.tir.const(0, "int32"),
+        tvm.tir.const(n, "int32"),
+        tvm.tir.ForKind.SERIAL,
+        make_store(loop_var),
+    )
+
+    prim_func = tvm.tir.PrimFunc(input_bufs + [out_buf], loop)
+    prim_func = prim_func.with_attr({"tir.noalias": True, "global_symbol": 
"main"})
+    f = tvm.compile(prim_func, "llvm")
 
-    C = te.compute((n,), make_binds)
-    f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm")
     arrays = [
         tvm.runtime.tensor(np.array([row[j] for row in data], 
dtype=variables[j].dtype))
         for j in range(num_vars)
@@ -81,32 +96,32 @@ def get_ref_data():
 def test_lower_floordiv():
     data = get_ref_data()
     for dtype in ["int32", "int64", "int16"]:
-        x = te.var("x", dtype=dtype)
-        y = te.var("y", dtype=dtype)
+        x = tvm.tir.Var("x", dtype)
+        y = tvm.tir.Var("y", dtype)
         zero = tvm.tir.const(0, dtype)
         # no constraints
-        res = lower_intrin([x, y], tvm.te.floordiv(x, y))
+        res = lower_intrin([x, y], tvm.tir.floordiv(x, y))
         check_value(res, [x, y], data, lambda a, b: a // b)
         # rhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, 
y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floordiv(x, 
y), zero))
         check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0)
         # involves max
         res = lower_intrin(
-            [x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), 
zero), zero)
+            [x, y], tvm.tir.Select(y >= 0, tvm.tir.max(tvm.tir.floordiv(x, y), 
zero), zero)
         )
         check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0 
else 0)
         # lhs >= 0
         res = lower_intrin(
-            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.te.floordiv(x, y), zero)
+            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.tir.floordiv(x, y), zero)
         )
         check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0 
else 0)
         # const power of two
-        res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, 
dtype=dtype)))
+        res = lower_intrin([x, y], tvm.tir.floordiv(x, tvm.tir.const(8, 
dtype=dtype)))
         check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda 
a, b: a // b)
         # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
         res = lower_intrin(
             [x, y],
-            tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
+            tvm.tir.floordiv(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
         )
         check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda 
a, b: (a + 4) // b)
 
@@ -115,27 +130,27 @@ def test_lower_floordiv():
 def test_lower_floormod():
     data = get_ref_data()
     for dtype in ["int32", "int64", "int16"]:
-        x = te.var("x", dtype=dtype)
-        y = te.var("y", dtype=dtype)
+        x = tvm.tir.Var("x", dtype)
+        y = tvm.tir.Var("y", dtype)
         zero = tvm.tir.const(0, dtype)
         # no constraints
-        res = lower_intrin([x, y], tvm.te.floormod(x, y))
+        res = lower_intrin([x, y], tvm.tir.floormod(x, y))
         check_value(res, [x, y], data, lambda a, b: a % b)
         # rhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, 
y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floormod(x, 
y), zero))
         check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0)
         # lhs >= 0
         res = lower_intrin(
-            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.te.floormod(x, y), zero)
+            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.tir.floormod(x, y), zero)
         )
         check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0 
else 0)
         # const power of two
-        res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, 
dtype=dtype)))
+        res = lower_intrin([x, y], tvm.tir.floormod(x, tvm.tir.const(8, 
dtype=dtype)))
         check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda 
a, b: a % b)
         # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
         res = lower_intrin(
             [x, y],
-            tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
+            tvm.tir.floormod(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
         )
         check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda 
a, b: (a + 4) % b)
 
@@ -149,26 +164,26 @@ def test_lower_floordiv_overflow_checks():
     """
     # Check 3: (b-1) - a_min must not overflow (numerator and C++ int64).
     # x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 - 
(-2^63) > LLONG_MAX.
-    x = te.var("x", dtype="int64")
-    res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64")))
+    x = tvm.tir.Var("x", "int64")
+    res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int64")))
     data_check3 = [(-(2**63),), (0,), (100,)]
     check_value(res, [x], data_check3, lambda a: a // 3)
 
     # Check 4: c_value * b_value must not overflow dtype.
     # x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923; 
10923*3 > 32767.
-    x = te.var("x", dtype="int16")
-    res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16")))
+    x = tvm.tir.Var("x", "int16")
+    res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int16")))
     data_check4 = [(-32768,), (0,), (100,)]
     check_value(res, [x], data_check4, lambda a: a // 3)
 
     # Check 5: a_max + b*c must not overflow (offset numerator).
     # tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; 
a_max + 12 > 32767.
     # In practice this path may not be triggered. This test still validates 
correct lowering.
-    x = te.var("x", dtype="int16")
+    x = tvm.tir.Var("x", "int16")
     clamped = tvm.tir.min(
         tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758, 
"int16")
     )
-    res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3, 
"int16")))
+    res = lower_intrin([x], tvm.tir.floordiv(clamped, tvm.tir.const(3, 
"int16")))
     data_check5 = [(-10,), (0,), (32758,), (32757,)]
     check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) // 
3)
 
diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py 
b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
index 553c745770..8bee82f02c 100644
--- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
+++ b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
 
 
 def test_prim_func_pass():
@@ -30,8 +29,8 @@ def test_prim_func_pass():
         def transform_function(self, func, mod, ctx):
             return self.new_func
 
-    x = te.var("x")
-    y = te.var("y")
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
     b = tvm.tir.decl_buffer((x,), "float32")
     stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
@@ -51,7 +50,7 @@ def test_cow_pass():
         return f
 
     pidentity = tvm.tir.transform.Apply(fapply)
-    x = te.var("x")
+    x = tvm.tir.Var("x", "int32")
     func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x)).with_attr("target_bits", 
32)
     func_hash = func.__hash__()
     mod = tvm.IRModule({"main": func})
diff --git a/tests/python/tir-transform/test_tir_transform_remove_no_op.py 
b/tests/python/tir-transform/test_tir_transform_remove_no_op.py
index 5d712a21f7..91902ea400 100644
--- a/tests/python/tir-transform/test_tir_transform_remove_no_op.py
+++ b/tests/python/tir-transform/test_tir_transform_remove_no_op.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
 from tvm.script import tir as T
 import tvm.testing
 
@@ -27,11 +26,11 @@ def nop():
 
 
 def test_remove_no_op():
-    i = te.var("i")
-    j = te.var("j")
-    k = te.var("k")
-    m = te.var("m")
-    n = te.var("n")
+    i = tvm.tir.Var("i", "int32")
+    j = tvm.tir.Var("j", "int32")
+    k = tvm.tir.Var("k", "int32")
+    m = tvm.tir.Var("m", "int32")
+    n = tvm.tir.Var("n", "int32")
     dtype = "int64"
     Ab = tvm.tir.decl_buffer((n,), dtype)
     stmt = tvm.tir.For(

Reply via email to