This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3a909e2b95e7c0c2025c118d75bdd575e9d4928b
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Sat Feb 18 16:56:36 2023 -0500

    [Unity] Update tests to adapt to latest TVMScript syntax (#14039)
    
    Now that some recent changes to the TVMScript syntax have been merged,
    several test files contain deprecated uses of TVMScript syntax. This
    PR updates those test files to the latest TVMScript syntax so that
    running the tests does not trigger deprecation warnings.
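
    For reference, the two syntax migrations applied throughout the tests
    are sketched below. This is only an illustrative sketch, not part of
    the diff: symbolic shape variables are now declared with T.int64() /
    T.int32() instead of the deprecated T.var("int64") / T.var("int32"),
    and T.Buffer annotations are written as calls rather than subscripts.

        from tvm.script import tir as T

        @T.prim_func
        def addone(A: T.Buffer((16, 16), "float32"),
                   B: T.Buffer((16, 16), "float32")):
            # New form: T.Buffer((16, 16), "float32") replaces the
            # deprecated T.Buffer[(16, 16), "float32"].
            for i, j in T.grid(16, 16):
                with T.block("addone"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + T.float32(1)

        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle):
            # New form: T.int64() replaces the deprecated T.var("int64").
            m = T.int64()
            n = T.int64()
            k = T.int64()
            A = T.match_buffer(x, (m, n))
            B = T.match_buffer(y, (n, k))
            C = T.match_buffer(z, (m, k))
            for i, j, p in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vp = T.axis.remap("SSR", [i, j, p])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vp] * B[vp, vj]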
    
    Co-authored-by: Tianqi Chen <tqc...@users.noreply.github.com>
---
 python/tvm/relax/frontend/torch/fx_translator.py   |  11 +-
 tests/python/relax/test_frontend_dynamo.py         |   8 +-
 tests/python/relax/test_transform.py               |   6 +-
 .../test_transform_annotate_tir_op_pattern.py      |   6 +-
 .../relax/test_transform_attach_global_symbol.py   |  16 +-
 tests/python/relax/test_transform_fold_constant.py |  24 +-
 tests/python/relax/test_transform_lambda_lift.py   |  10 +-
 .../relax/test_transform_legalize_ops_binary.py    | 264 ++++++++++-----------
 .../test_transform_legalize_ops_create_datatype.py | 120 +++++-----
 .../relax/test_transform_legalize_ops_image.py     |  28 +--
 ..._transform_legalize_ops_index_linear_algebra.py |  54 ++---
 .../test_transform_legalize_ops_manipulate.py      | 142 +++++------
 .../python/relax/test_transform_legalize_ops_nn.py | 154 ++++++------
 ...st_transform_legalize_ops_search_statistical.py | 106 ++++-----
 .../relax/test_transform_legalize_ops_unary.py     | 120 +++++-----
 .../relax/test_transform_meta_schedule_tuning.py   |   2 +-
 tests/python/relax/test_transform_normalize.py     |   2 +-
 .../test_transform_static_plan_block_memory.py     |   8 +-
 tests/python/relax/test_tuning_api.py              |   2 +-
 tests/python/relax/test_tvmscript_ir_builder.py    |   4 +-
 tests/python/relax/test_tvmscript_parser.py        |  42 ++--
 tests/python/relax/test_vm_build.py                |  22 +-
 22 files changed, 574 insertions(+), 577 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 582f2edbcf..a762b0a0fb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -76,9 +76,8 @@ class TorchFXImporter:
     @staticmethod
     def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
         tensor = tensor.detach().cpu()
-        shape = tensor.data.shape
         dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
-        return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, dtype))
+        return relax.const(tensor.data.numpy(), dtype)
 
     @staticmethod
     def shape_of(tensor):
@@ -444,8 +443,8 @@ class TorchFXImporter:
             gamma = self.params[module.weight]
             beta = self.params[module.bias]
         else:
-            gamma = relax.const(torch.ones_like(module.normalized_shape), x.checked_type)
-            beta = relax.const(torch.zeros_like(module.normalized_shape), x.checked_type)
+            gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype)
+            beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype)
         dim_num = len(module.normalized_shape)
         axes = list(range(-dim_num, 0))
 
@@ -702,9 +701,7 @@ class TorchFXImporter:
                     shape = param.data.shape
                     dtype = self._convert_data_type(str(param.data.dtype))
                     if dtype in ("float32", "float16"):
-                        self.params[param] = relax.const(
-                            param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype)
-                        )
+                        self.params[param] = relax.const(param.data.cpu().numpy(), dtype)
                     else:
                         raise ValueError("Unsupported data type for model parameters: %s" % dtype)
                 # Translate the model.
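
Note on the frontend change above: relax.const now takes a plain dtype
string in place of a full relax.TensorStructInfo, and dtypes are read
from struct_info rather than the deprecated checked_type. A minimal
sketch of the new convention (illustrative only, not part of the diff;
assumes a unity-branch TVM build):

    import numpy as np
    from tvm import relax

    # relax.const infers the tensor shape from the numpy value, so only
    # the dtype string is passed explicitly.
    c = relax.const(np.ones((2, 2), dtype="float32"), "float32")
    # Read the dtype from struct_info (checked_type is deprecated).
    assert c.struct_info.dtype == "float32"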
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 370df2103d..b47e3e22bd 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -43,10 +43,10 @@ def test_relax_dynamo():
     class Input1_ir:
         @T.prim_func
         def main(
-            inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"],
-            param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"],
-            param_1: T.Buffer[T.int64(10), "float32"],
-            compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"],
+            inp_0: T.Buffer((T.int64(10), T.int64(100)), "float32"),
+            param_0: T.Buffer((T.int64(100), T.int64(10)), "float32"),
+            param_1: T.Buffer(T.int64(10), "float32"),
+            compute: T.Buffer((T.int64(10), T.int64(10)), "float32"),
         ):
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "main"})
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 12dd095c6b..85de4f912e 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -30,7 +30,7 @@ def test_to_non_dataflow():
     class TestToNonDataflow:
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")):
-            m, n = T.var("int64"), T.var("int64")
+            m, n = T.int64(), T.int64()
             with R.dataflow():
                 lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
                 gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32"))
@@ -75,7 +75,7 @@ def test_call_tir_rewrite():
     class TestCallTIRRewrite:
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")):
-            m, n = T.var("int64"), T.var("int64")
+            m, n = T.int64(), T.int64()
             gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
             return gv0
 
@@ -108,7 +108,7 @@ def test_vm_builtin_lower():
     class TestVMBuiltinLower:
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
-            m, n = T.var("int64"), T.var("int64")
+            m, n = T.int64(), T.int64()
             alloc = R.builtin.alloc_tensor(R.shape([m, n]), 
runtime_device_index=0, dtype="float32")
             _ = R.call_packed(
                 "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, 
dtype="float32"))
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
index 73c6537869..23ce49a7c2 100644
--- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -39,9 +39,9 @@ def test_annotate_opkind_outewisefusable():
         @T.prim_func
         def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
             T.func_attr({"global_symbol": "tir_matmul"})
-            m = T.var("int32")
-            n = T.var("int32")
-            k = T.var("int32")
+            m = T.int32()
+            n = T.int32()
+            k = T.int32()
             A = T.match_buffer(x, (m, n))
             B = T.match_buffer(y, (n, k))
             C = T.match_buffer(z, (m, k))
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py
index edfc646e21..cef3842e3e 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -28,9 +28,9 @@ from tvm.script import tir as T, relax as R
 class Before:
     @T.prim_func
     def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
-        m = T.var("int64")
-        n = T.var("int64")
-        k = T.var("int64")
+        m = T.int64()
+        n = T.int64()
+        k = T.int64()
         A = T.match_buffer(x, (m, n))
         B = T.match_buffer(y, (n, k))
         C = T.match_buffer(z, (m, k))
@@ -44,7 +44,7 @@ class Before:
 
     @R.function
     def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")) -> R.Tensor:
-        m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+        m, n, k = T.int64(), T.int64(), T.int64()
         gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
         return gv0
 
@@ -55,9 +55,9 @@ def test_basic():
         @T.prim_func
         def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
             T.func_attr({"global_symbol": "tir_matmul"})
-            m = T.var("int64")
-            n = T.var("int64")
-            k = T.var("int64")
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
             A = T.match_buffer(x, (m, n))
             B = T.match_buffer(y, (n, k))
             C = T.match_buffer(z, (m, k))
@@ -74,7 +74,7 @@ def test_basic():
             x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
         ) -> R.Tensor:
             R.func_attr({"global_symbol": "main"})
-            m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+            m, n, k = T.int64(), T.int64(), T.int64()
             gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
             return gv0
 
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index 32ee3e7000..95542dd4e6 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -59,7 +59,7 @@ def test_one_fold_addone():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), 
"float32"]) -> None:
+        def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), 
"float32")) -> None:
             for i, j in T.grid(16, 16):
                 with T.block("addone"):
                     vi, vj = T.axis.remap("SS", [i, j])
@@ -89,7 +89,7 @@ def test_one_fold_transpose():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), 
"float32"]) -> None:
+        def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), 
"float32")) -> None:
             for i, j in T.grid(3, 2):
                 with T.block("transpose"):
                     vi, vj = T.axis.remap("SS", [i, j])
@@ -118,7 +118,7 @@ def test_two_hop_addone():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), 
"float32"]) -> None:
+        def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), 
"float32")) -> None:
             for i, j in T.grid(2, 2):
                 with T.block("addone"):
                     vi, vj = T.axis.remap("SS", [i, j])
@@ -150,7 +150,7 @@ def test_dataflow_fold():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), 
"float32"]) -> None:
+        def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), 
"float32")) -> None:
             for i, j in T.grid(16, 16):
                 with T.block("identity"):
                     vi, vj = T.axis.remap("SS", [i, j])
@@ -184,8 +184,8 @@ def test_fold_mixed_case():
         # TIR function can handle different cases.
         @T.prim_func
         def addone(a: T.handle, b: T.handle) -> None:
-            n = T.var("int32")
-            m = T.var("int32")
+            n = T.int32()
+            m = T.int32()
             A = T.match_buffer(a, (n, m))
             B = T.match_buffer(b, (n, m))
             for i, j in T.grid(n, m):
@@ -195,9 +195,9 @@ def test_fold_mixed_case():
 
         @T.prim_func
         def sub(
-            A: T.Buffer[(16, 16), "float32"],
-            B: T.Buffer[(16, 16), "float32"],
-            C: T.Buffer[(16, 16), "float32"],
+            A: T.Buffer((16, 16), "float32"),
+            B: T.Buffer((16, 16), "float32"),
+            C: T.Buffer((16, 16), "float32"),
         ) -> None:
             for i, j in T.grid(16, 16):
                 with T.block("sub"):
@@ -206,7 +206,7 @@ def test_fold_mixed_case():
 
         @R.function
         def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", 
ndim=2)):
-            n, m = T.var("int64"), T.var("int64")
+            n, m = T.int64(), T.int64()
             x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
             # this line cannot be folded because n is unknown
             lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), 
dtype="float32"))
@@ -225,7 +225,7 @@ def test_fold_mixed_case():
             c2: R.Tensor((16, 16), "float32"),
             x: R.Tensor("float32", ndim=2),
         ) -> R.Tensor:
-            n, m = T.var("int64"), T.var("int64")
+            n, m = T.int64(), T.int64()
             x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
             # this line cannot be folded because n is unknown
             lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), 
dtype="float32"))
@@ -251,7 +251,7 @@ def test_int32_fold():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), 
"int32"]) -> None:
+        def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), 
"int32")) -> None:
             for i, j in T.grid(16, 16):
                 with T.block("addone"):
                     vi, vj = T.axis.remap("SS", [i, j])
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index fbdb1fbdce..c9bbc0fb91 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -190,7 +190,7 @@ def test_recursive():
 
     before = Before
     expected = Expected
-    # Perform Lamda Lifting
+    # Perform Lambda Lifting
     after = transform.LambdaLift()(before)
     assert len(after.functions) == 2
 
@@ -266,7 +266,7 @@ def test_multi_func():
 
     before = Before
     expected = Expected
-    # Perform Lamda Lifting
+    # Perform Lambda Lifting
     after = transform.LambdaLift()(before)
     assert len(after.functions) == 4
     assert_structural_equal(after, expected, map_free_vars=True)
@@ -278,9 +278,9 @@ def test_no_local_func():
     class Before:
         @T.prim_func
         def sub(
-            A: T.Buffer[(16, 16), "float32"],
-            B: T.Buffer[(16, 16), "float32"],
-            C: T.Buffer[(16, 16), "float32"],
+            A: T.Buffer((16, 16), "float32"),
+            B: T.Buffer((16, 16), "float32"),
+            C: T.Buffer((16, 16), "float32"),
         ) -> None:
             for i, j in T.grid(16, 16):
                 with T.block("sub"):
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py
index c2db7e9ba1..c99fb885c4 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -124,10 +124,10 @@ def test_add_symbolic():
     class Add:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.add(x, y)
             return gv
 
@@ -135,20 +135,20 @@ def test_add_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(add, (x, y), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_add: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_add = T.match_buffer(var_T_add, [a, b, c, d], dtype="float32")
@@ -263,10 +263,10 @@ def test_divide_symbolic():
     class Divide:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.divide(x, y)
             return gv
 
@@ -274,20 +274,20 @@ def test_divide_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(divide, (x, y), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_divide: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_divide = T.match_buffer(var_T_divide, [a, b, c, d], 
dtype="float32")
@@ -402,10 +402,10 @@ def test_floor_divide_symbolic():
     class FloorDivide:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.floor_divide(x, y)
             return gv
 
@@ -413,20 +413,20 @@ def test_floor_divide_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(floor_divide, (x, y), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_floor_divide: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_floor_divide = T.match_buffer(var_T_floor_divide, [a, b, c, d], 
dtype="float32")
@@ -479,10 +479,10 @@ def test_multiply_symbolic():
     class Multiply:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.multiply(x, y)
             return gv
 
@@ -490,20 +490,20 @@ def test_multiply_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(multiply, (x, y), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_multiply: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_multiply = T.match_buffer(var_T_multiply, [a, b, c, d], 
dtype="float32")
@@ -556,10 +556,10 @@ def test_subtract_symbolic():
     class Subtract:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.subtract(x, y)
             return gv
 
@@ -567,20 +567,20 @@ def test_subtract_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(subtract, (x, y), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_subtract: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_subtract = T.match_buffer(var_T_subtract, [a, b, c, d], 
dtype="float32")
@@ -698,10 +698,10 @@ def test_equal_symbolic():
     class Equal:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.equal(x, y)
             return gv
 
@@ -709,20 +709,20 @@ def test_equal_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(equal, (x, y), R.Tensor((a, b, c, d), 
dtype="bool"))
             return gv
 
         @T.prim_func
         def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_equal: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_equal = T.match_buffer(var_T_equal, [a, b, c, d], dtype="bool")
@@ -837,10 +837,10 @@ def test_greater_symbolic():
     class Greater:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.greater(x, y)
             return gv
 
@@ -848,20 +848,20 @@ def test_greater_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(greater, (x, y), R.Tensor((a, b, c, d), 
dtype="bool"))
             return gv
 
         @T.prim_func
         def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_greater: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_greater = T.match_buffer(var_T_greater, [a, b, c, d], 
dtype="bool")
@@ -914,10 +914,10 @@ def test_greater_equal_symbolic():
     class GreaterEqual:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.greater_equal(x, y)
             return gv
 
@@ -925,20 +925,20 @@ def test_greater_equal_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(greater_equal, (x, y), R.Tensor((a, b, c, d), 
dtype="bool"))
             return gv
 
         @T.prim_func
         def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_greater_equal: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_greater_equal = T.match_buffer(var_T_greater_equal, [a, b, c, 
d], dtype="bool")
@@ -991,10 +991,10 @@ def test_less_symbolic():
     class Less:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.less(x, y)
             return gv
 
@@ -1002,20 +1002,20 @@ def test_less_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(less, (x, y), R.Tensor((a, b, c, d), dtype="bool"))
             return gv
 
         @T.prim_func
         def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_less: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_less = T.match_buffer(var_T_less, [a, b, c, d], dtype="bool")
@@ -1130,10 +1130,10 @@ def test_less_equal_symbolic():
     class LessEqual:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.less_equal(x, y)
             return gv
 
@@ -1141,20 +1141,20 @@ def test_less_equal_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(less_equal, (x, y), R.Tensor((a, b, c, d), 
dtype="bool"))
             return gv
 
         @T.prim_func
         def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_less_equal: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_less_equal = T.match_buffer(var_T_less_equal, [a, b, c, d], 
dtype="bool")
@@ -1207,10 +1207,10 @@ def test_not_equal_symbolic():
     class NotEqual:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "bool") = R.not_equal(x, y)
             return gv
 
@@ -1218,20 +1218,20 @@ def test_not_equal_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(not_equal, (x, y), R.Tensor((a, b, c, d), 
dtype="bool"))
             return gv
 
         @T.prim_func
         def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_T_not_equal: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, 
d], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, 
T.int64(1)], dtype="float32")
             T_not_equal = T.match_buffer(var_T_not_equal, [a, b, c, d], 
dtype="bool")
diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
index 2506e96634..6082f74102 100644
--- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py
+++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
@@ -123,8 +123,8 @@ def test_full_symbolic():
     class Full:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) 
-> R.Tensor(("m", "n"), "int32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "int32") = R.full((m, n), v, dtype="int32")
             return gv
 
@@ -132,16 +132,16 @@ def test_full_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) 
-> R.Tensor(("m", "n"), "int32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="int32"))
             return gv
 
         @T.prim_func
         def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="int32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -254,8 +254,8 @@ def test_full_like_symbolic():
     class FullLike:
         @R.function
         def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
             return gv
 
@@ -263,16 +263,16 @@ def test_full_like_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -323,8 +323,8 @@ def test_ones_symbolic():
     class Ones:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.ones((m, n), "float32")
             return gv
 
@@ -332,16 +332,16 @@ def test_ones_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def ones(var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -392,8 +392,8 @@ def test_ones_like_symbolic():
     class OnesLike:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.ones_like(x)
             return gv
 
@@ -401,16 +401,16 @@ def test_ones_like_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def ones(var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -461,8 +461,8 @@ def test_zeros_symbolic():
     class Zeros:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.zeros((m, n), "float32")
             return gv
 
@@ -470,16 +470,16 @@ def test_zeros_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def zeros(var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -530,8 +530,8 @@ def test_zeros_like_symbolic():
     class ZerosLike:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.zeros_like(x)
             return gv
 
@@ -539,16 +539,16 @@ def test_zeros_like_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def zeros(var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
@@ -599,9 +599,9 @@ def test_tril_symbolic():
     class Tril:
         @R.function
         def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", 
"k"), "int8"):
-            m = T.var("int64")
-            n = T.var("int64")
-            k = T.var("int64")
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
             gv: R.Tensor((m, n, k), "int8") = R.tril(x, k=-2)
             return gv
 
@@ -609,18 +609,18 @@ def test_tril_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", 
"k"), "int8"):
-            m = T.var("int64")
-            n = T.var("int64")
-            k = T.var("int64")
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
             gv = R.call_tir(tril, (x,), R.Tensor((m, n, k), dtype="int8"))
             return gv
 
         @T.prim_func
         def tril(var_rxplaceholder: T.handle, var_trilu: T.handle):
             T.func_attr({"tir.noalias": True})
-            k = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            k = T.int64()
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], 
dtype="int8")
             trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8")
             for i0, i1, i2 in T.grid(m, n, k):
@@ -672,9 +672,9 @@ def test_triu_symbolic():
     class Triu:
         @R.function
         def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", 
"k"), "int8"):
-            m = T.var("int64")
-            n = T.var("int64")
-            k = T.var("int64")
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
             gv: R.Tensor((m, n, k), "int8") = R.triu(x, k=-2)
             return gv
 
@@ -682,18 +682,18 @@ def test_triu_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", 
"k"), "int8"):
-            m = T.var("int64")
-            n = T.var("int64")
-            k = T.var("int64")
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
             gv = R.call_tir(triu, (x,), R.Tensor((m, n, k), dtype="int8"))
             return gv
 
         @T.prim_func
         def triu(var_rxplaceholder: T.handle, var_trilu: T.handle):
             T.func_attr({"tir.noalias": True})
-            k = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            k = T.int64()
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], 
dtype="int8")
             trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8")
             for i0, i1, i2 in T.grid(m, n, k):
@@ -769,8 +769,8 @@ def test_astype_symbolic():
     class Astype:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"int32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "int32") = R.astype(x, "int32")
             return gv
 
@@ -778,16 +778,16 @@ def test_astype_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"int32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(cast, (x,), R.Tensor((m, n), dtype="int32"))
             return gv
 
         @T.prim_func
         def cast(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="int32")
             for i0, i1 in T.grid(m, n):
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py
index 36c8ecdd7b..5860fea0bf 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -58,10 +58,10 @@ def test_image_resize2d_symbolic():
     class Resize2D:
         @R.function
         def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", 
"h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"):
-            n = T.var("int64")
-            c = T.var("int64")
-            oh = T.var("int64")
-            ow = T.var("int64")
+            n = T.int64()
+            c = T.int64()
+            oh = T.int64()
+            ow = T.int64()
             gv: R.Tensor((n, c, oh, ow, 16), "float32") = R.image.resize2d(x, 
size=(oh, ow), layout="NCHW16c", method="nearest_neighbor", 
coordinate_transformation_mode="asymmetric")
             return gv
 
@@ -69,22 +69,22 @@ def test_image_resize2d_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", 
"h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"):
-            n = T.var("int64")
-            c = T.var("int64")
-            oh = T.var("int64")
-            ow = T.var("int64")
+            n = T.int64()
+            c = T.int64()
+            oh = T.int64()
+            ow = T.int64()
             gv = R.call_tir(resize2d, (x,), R.Tensor((n, c, oh, ow, 16), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle):
             T.func_attr({"tir.noalias": True})
-            c = T.var("int64")
-            h = T.var("int64")
-            n = T.var("int64")
-            oh = T.var("int64")
-            ow = T.var("int64")
-            w = T.var("int64")
+            c = T.int64()
+            h = T.int64()
+            n = T.int64()
+            oh = T.int64()
+            ow = T.int64()
+            w = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w, 
T.int64(16)], dtype="float32")
             resize = T.match_buffer(var_resize, [n, c, oh, ow, T.int64(16)], 
dtype="float32")
             for i0, i1, i2, i3, i4 in T.grid(n, c, oh, ow, T.int64(16)):
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 8b6f9de981..5dd9728918 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -61,8 +61,8 @@ def test_take_symbolic():
     class Take:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), 
"int64")) -> R.Tensor(("m", "i"), "float32"):
-            m = T.var("int64")
-            i = T.var("int64")
+            m = T.int64()
+            i = T.int64()
             gv: R.Tensor((m, i), "float32") = R.take(x, indices, axis=1)
             return gv
 
@@ -70,17 +70,17 @@ def test_take_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), 
"int64")) -> R.Tensor(("m", "i"), "float32"):
-            m = T.var("int64")
-            i = T.var("int64")
+            m = T.int64()
+            i = T.int64()
             gv = R.call_tir(take, (x, indices), R.Tensor((m, i), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_take: T.handle):
             T.func_attr({"tir.noalias": True})
-            i = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            i = T.int64()
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [i], 
dtype="int64")
             T_take = T.match_buffer(var_T_take, [m, i], dtype="float32")
@@ -165,7 +165,7 @@ def test_strided_slice_symbolic_sliced_axis():
     class StridedSlice:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), 
"float32"):
-            n = T.var("int64")
+            n = T.int64()
             gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], 
begin=[1], end=[8], strides=[3])
             return gv
     # fmt: on
@@ -180,7 +180,7 @@ def test_strided_slice_symbolic():
     class StridedSlice:
         @R.function
         def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), 
"float32"):
-            n = T.var("int64")
+            n = T.int64()
             gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], 
begin=[1], end=[8], strides=[3])
             return gv
 
@@ -188,14 +188,14 @@ def test_strided_slice_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, 
"n"), dtype="float32"):
-            n = T.var("int64")
+            n = T.int64()
             gv = R.call_tir(strided_slice, (x,), R.Tensor((3, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def strided_slice(var_rxplaceholder: T.handle, 
var_T_strided_slice_with_axes: T.handle):
             T.func_attr({"tir.noalias": True})
-            n = T.var("int64")
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), 
n], dtype="float32")
             T_strided_slice_with_axes = 
T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32")
             for i0, i1 in T.grid(T.int64(3), n):
@@ -351,11 +351,11 @@ def test_matmul_4_5_symbolic():
     class Matmul:
         @R.function
         def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 
1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), 
"float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((a, b, c, m, n), "float32") = R.matmul(x, y)
             return gv
 
@@ -363,23 +363,23 @@ def test_matmul_4_5_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 
1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), 
"float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(matmul, (x, y), R.Tensor((a, b, c, m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_matmul: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            k = T.var("int64")
-            m = T.var("int64")
-            n = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            k = T.int64()
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), 
m, k], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, 
T.int64(1), c, k, n], dtype="float32")
             matmul = T.match_buffer(var_matmul, [a, b, c, m, n], 
dtype="float32")
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 53aa868ffe..2a30994b83 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -62,10 +62,10 @@ def test_broadcast_to_symbolic():
     class BroadcastTo:
         @R.function
         def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), 
"float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, b, c, d), "float32") = R.broadcast_to(x, (a, b, 
c, d))
             return gv
 
@@ -73,20 +73,20 @@ def test_broadcast_to_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), 
"float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv = R.call_tir(broadcast_to, (x,), R.Tensor((a, b, c, d), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: 
T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), 
d], dtype="float32")
             T_broadcast_to = T.match_buffer(var_T_broadcast_to, [a, b, c, d], 
dtype="float32")
             for i0, i1, i2, i3 in T.grid(a, b, c, d):
@@ -171,10 +171,10 @@ def test_concat_input_tuple_var_symbolic():
     class Concat:
         @R.function
         def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"):
-            a = T.var("int64")
-            b0 = T.var("int64")
-            b1 = T.var("int64")
-            b2 = T.var("int64")
+            a = T.int64()
+            b0 = T.int64()
+            b1 = T.int64()
+            b2 = T.int64()
             gv: R.Tensor((a, b0 + b1 + b2), "float32") = R.concat(t, axis=1)
             return gv
 
@@ -182,10 +182,10 @@ def test_concat_input_tuple_var_symbolic():
     class Expected:
         @R.function
         def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"):
-            a = T.var("int64")
-            b0 = T.var("int64")
-            b1 = T.var("int64")
-            b2 = T.var("int64")
+            a = T.int64()
+            b0 = T.int64()
+            b1 = T.int64()
+            b2 = T.int64()
             gv: R.Tensor((a, b0), dtype="float32") = t[0]
             gv1: R.Tensor((a, b1), dtype="float32") = t[1]
             gv2: R.Tensor((a, b2), dtype="float32") = t[2]
@@ -195,10 +195,10 @@ def test_concat_input_tuple_var_symbolic():
         @T.prim_func
         def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_T_concat: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b0 = T.var("int64")
-            b1 = T.var("int64")
-            b2 = T.var("int64")
+            a = T.int64()
+            b0 = T.int64()
+            b1 = T.int64()
+            b2 = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b0], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b1], 
dtype="float32")
             rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [a, b2], 
dtype="float32")
@@ -252,9 +252,9 @@ def test_expand_dims_symbolic():
     class ExpandDims:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, 
"b", 1, "c", 1), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((a, 1, b, 1, c, 1), "float32") = R.expand_dims(x, 
axis=[1, 3, 5])
             return gv
 
@@ -262,18 +262,18 @@ def test_expand_dims_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, 
"b", 1, "c", 1), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(expand_dims, (x,), R.Tensor((a, 1, b, 1, c, 1), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: 
T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], 
dtype="float32")
             expand_dims = T.match_buffer(var_expand_dims, [a, T.int64(1), b, 
T.int64(1), c, T.int64(1)], dtype="float32")
             for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), b, T.int64(1), 
c, T.int64(1)):
@@ -356,9 +356,9 @@ def test_flatten_symbolic():
     class Flatten:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b 
* c",), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((a * b * c,), "float32") = R.flatten(x)
             return gv
 
@@ -366,18 +366,18 @@ def test_flatten_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b 
* c",), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(reshape, (x,), R.Tensor((((a * b) * c),), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], 
dtype="float32")
             T_reshape = T.match_buffer(var_T_reshape, [a * b * c], 
dtype="float32")
             for i0 in T.serial(a * b * c):
@@ -429,10 +429,10 @@ def test_permute_dims_symbolic():
     class PermuteDims:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("b", "d", "c", "a"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             gv: R.Tensor((b, d, c, a), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4])
             return gv
 
@@ -440,20 +440,20 @@ def test_permute_dims_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> 
R.Tensor(("b", "d", "c", "a"), dtype="float32"):
-            b = T.var("int64")
-            d = T.var("int64")
-            c = T.var("int64")
-            a = T.var("int64")
+            b = T.int64()
+            d = T.int64()
+            c = T.int64()
+            a = T.int64()
             gv = R.call_tir(transpose, (x,), R.Tensor((b, d, c, a), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             T_transpose = T.match_buffer(var_T_transpose, [b, d, c, a], 
dtype="float32")
             for i0, i1, i2, i3 in T.grid(b, d, c, a):
@@ -505,8 +505,8 @@ def test_reshape_symbolic():
     class Reshape:
         @R.function
         def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b 
* 2"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, (a // 2, b 
* 2))
             return gv
 
@@ -514,16 +514,16 @@ def test_reshape_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b 
* 2"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             gv = R.call_tir(reshape, (x,), R.Tensor(((a // 2), (b * 2)), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], 
dtype="float32")
             T_reshape = T.match_buffer(var_T_reshape, [a // T.int64(2), b * 
T.int64(2)], dtype="float32")
             for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)):
@@ -638,8 +638,8 @@ def test_split_by_indices_n_section_divisible_symbolic():
     class Split:
         @R.function
         def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), 
"float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1)
             return gv
 
@@ -647,15 +647,15 @@ def test_split_by_indices_n_section_divisible_symbolic():
     class Expected:
         @R.function
         def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,))
             return gv
 
         @T.prim_func
         def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
+            m = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32")
             T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32")
             T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32")
@@ -752,8 +752,8 @@ def test_squeeze_symbolic():
     class Squeeze:
         @R.function
         def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", 
"b", 1), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             gv: R.Tensor((a, b, 1), "float32") = R.squeeze(x, [1])
             return gv
 
@@ -761,16 +761,16 @@ def test_squeeze_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", 
"b", 1), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             gv = R.call_tir(squeeze, (x,), R.Tensor((a, b, 1), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
+            a = T.int64()
+            b = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), 
b, T.int64(1)], dtype="float32")
             T_squeeze = T.match_buffer(var_T_squeeze, [a, b, T.int64(1)], 
dtype="float32")
             for i0, i1, i2 in T.grid(a, b, T.int64(1)):
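
Note: the substitution is identical inside T.prim_func bodies, where the symbolic extents are then bound by T.match_buffer, as in the squeeze kernel above. A hedged sketch of that pattern (kernel, buffer, and block names here are illustrative, not from this commit):

    from tvm.script import tir as T

    @T.prim_func
    def copy(var_a: T.handle, var_b: T.handle):
        T.func_attr({"tir.noalias": True})
        m = T.int64()  # was: m = T.var("int64")
        n = T.int64()  # was: n = T.var("int64")
        A = T.match_buffer(var_a, [m, n], dtype="float32")
        B = T.match_buffer(var_b, [m, n], dtype="float32")
        for i0, i1 in T.grid(m, n):
            with T.block("copy"):
                i, j = T.axis.remap("SS", [i0, i1])
                B[i, j] = A[i, j]
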
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 3f9f02c410..729368b82a 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -151,12 +151,12 @@ def test_conv2d_symbolic():
     class Conv2d:
         @R.function
         def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"):
-            n = T.var("int64")
-            h = T.var("int64")
-            w = T.var("int64")
-            f = T.var("int64")
-            kh = T.var("int64")
-            kw = T.var("int64")
+            n = T.int64()
+            h = T.int64()
+            w = T.int64()
+            f = T.int64()
+            kh = T.int64()
+            kw = T.int64()
             gv: R.Tensor((n, f, h - kh + 1, w - kw + 1), "float32") = 
R.nn.conv2d(x, kernel)
             return gv
 
@@ -164,25 +164,25 @@ def test_conv2d_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"):
-            n = T.var("int64")
-            f = T.var("int64")
-            h = T.var("int64")
-            kh = T.var("int64")
-            w = T.var("int64")
-            kw = T.var("int64")
+            n = T.int64()
+            f = T.int64()
+            h = T.int64()
+            kh = T.int64()
+            w = T.int64()
+            kw = T.int64()
             gv = R.call_tir(conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 
1), ((w - kw) + 1)), dtype="float32"))
             return gv
 
         @T.prim_func
         def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_conv2d_nchw: T.handle):
             T.func_attr({"tir.noalias": True})
-            c = T.var("int64")
-            f = T.var("int64")
-            h = T.var("int64")
-            kh = T.var("int64")
-            kw = T.var("int64")
-            n = T.var("int64")
-            w = T.var("int64")
+            c = T.int64()
+            f = T.int64()
+            h = T.int64()
+            kh = T.int64()
+            kw = T.int64()
+            n = T.int64()
+            w = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, 
kw], dtype="float32")
             conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + 
T.int64(1), w - kw + T.int64(1)], dtype="float32")
@@ -330,12 +330,12 @@ def test_max_pool2d_symbolic():
     class MaxPool2D:
         @R.function
         def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"):
-            n = T.var("int64")
-            c = T.var("int64")
-            h = T.var("int64")
-            w = T.var("int64")
-            kh = T.var("int64")
-            kw = T.var("int64")
+            n = T.int64()
+            c = T.int64()
+            h = T.int64()
+            w = T.int64()
+            kh = T.int64()
+            kw = T.int64()
             gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = 
R.nn.max_pool2d(x, pool_size=[kh, kw])
             return gv
 
@@ -434,10 +434,10 @@ def test_adaptive_avg_pool2d_symbolic():
     class AdaptiveAvgPool2D:
         @R.function
         def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", 
"h", "w"), "float32")) -> R.Tensor(("n", "c", "oh", "ow"), "float32"):
-            n = T.var("int64")
-            c = T.var("int64")
-            oh = T.var("int64")
-            ow = T.var("int64")
+            n = T.int64()
+            c = T.int64()
+            oh = T.int64()
+            ow = T.int64()
             gv: R.Tensor((n, c, oh, ow), "float32") = 
R.nn.adaptive_avg_pool2d(x, (oh, ow))
             return gv
     # fmt: on
@@ -483,8 +483,8 @@ def test_relu_symbolic():
     class Relu:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.nn.relu(x)
             return gv
 
@@ -492,16 +492,16 @@ def test_relu_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(relu, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def relu(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -581,8 +581,8 @@ def test_gelu_symbolic():
     class Gelu:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.nn.gelu(x)
             return gv
 
@@ -590,16 +590,16 @@ def test_gelu_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(gelu, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             T_multiply = T.match_buffer(var_T_multiply, [m, n], 
dtype="float32")
             T_multiply_1 = T.alloc_buffer([m, n], dtype="float32")
@@ -686,8 +686,8 @@ def test_silu_symbolic():
     class Silu:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.nn.silu(x)
             return gv
 
@@ -695,16 +695,16 @@ def test_silu_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(silu, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             T_multiply = T.match_buffer(var_T_multiply, [m, n], 
dtype="float32")
             compute = T.alloc_buffer([m, n], dtype="float32")
@@ -789,9 +789,9 @@ def test_softmax_symbolic():
     class Softmax:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 
"b", "c"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((a, b, c), "float32") = R.nn.softmax(x)
             return gv
 
@@ -799,18 +799,18 @@ def test_softmax_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 
"b", "c"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(softmax, (x,), R.Tensor((a, b, c), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], 
dtype="float32")
             T_softmax_norm = T.match_buffer(var_T_softmax_norm, [a, b, c], 
dtype="float32")
             T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32")
@@ -963,10 +963,10 @@ def test_batch_norm_symbolic():
     class BatchNorm:
         @R.function
         def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")):
-            n = T.var("int64")
-            h = T.var("int64")
-            w = T.var("int64")
-            c = T.var("int64")
+            n = T.int64()
+            h = T.int64()
+            w = T.int64()
+            c = T.int64()
             gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-1)
             return gv
 
@@ -974,20 +974,20 @@ def test_batch_norm_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")):
-            n = T.var("int64")
-            h = T.var("int64")
-            w = T.var("int64")
-            c = T.var("int64")
+            n = T.int64()
+            h = T.int64()
+            w = T.int64()
+            c = T.int64()
             gv = R.call_tir(batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")])
             return gv
 
         @T.prim_func
         def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_multiply: T.handle, var_T_multiply_1: T.handle):
             T.func_attr({"tir.noalias": True})
-            c = T.var("int64")
-            h = T.var("int64")
-            n = T.var("int64")
-            w = T.var("int64")
+            c = T.int64()
+            h = T.int64()
+            n = T.int64()
+            w = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, h, w, c], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [c], 
dtype="float32")
             rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [c], 
dtype="float32")
@@ -1133,9 +1133,9 @@ def test_layer_norm_symbolic():
     class LayerNorm:
         @R.function
         def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
-            n = T.var("int64")
-            s = T.var("int64")
-            f = T.var("int64")
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
             gv: R.Tensor((n, s, f), "float32") = R.nn.layer_norm(x, gamma, 
beta, axes=[1, 2])
             return gv
 
@@ -1143,18 +1143,18 @@ def test_layer_norm_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
-            n = T.var("int64")
-            s = T.var("int64")
-            f = T.var("int64")
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
             gv = R.call_tir(layer_norm, (x, gamma, beta), R.Tensor((n, s, f), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
             T.func_attr({"tir.noalias": True})
-            f = T.var("int64")
-            n = T.var("int64")
-            s = T.var("int64")
+            f = T.int64()
+            n = T.int64()
+            s = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], 
dtype="float32")
             rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], 
dtype="float32")
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 4c31077d9c..5bdfb1774c 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -61,9 +61,9 @@ def test_where_symbolic():
     class Where:
         @R.function
         def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((a, b, c), "float32") = R.where(condition, x, y)
             return gv
 
@@ -71,18 +71,18 @@ def test_where_symbolic():
     class Expected:
         @R.function
         def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"):
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(where, (condition, x, y), R.Tensor((a, b, c), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_rxplaceholder_2: T.handle, var_T_where: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, 
T.int64(1)], dtype="bool")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [b, c], 
dtype="float32")
             rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [b, 
T.int64(1)], dtype="float32")
@@ -141,8 +141,8 @@ def test_max_symbolic():
     class Max:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("a", "d"), "float32"):
-            a = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, d), "float32") = R.max(x, axis=[1, 2])
             return gv
 
@@ -150,18 +150,18 @@ def test_max_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("a", "d"), "float32"):
-            a = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            d = T.int64()
             gv = R.call_tir(max, (x,), R.Tensor((a, d), dtype="float32"))
             return gv
 
         @T.prim_func
         def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, d], 
dtype="float32")
             for i0, i1, i2, i3 in T.grid(a, d, b, c):
@@ -217,8 +217,8 @@ def test_min_symbolic():
     class Min:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("a", 1, 1, "d"), "float32"):
-            a = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            d = T.int64()
             gv: R.Tensor((a, 1, 1, d), "float32") = R.min(x, axis=[1, 2], 
keepdims=True)
             return gv
 
@@ -226,18 +226,18 @@ def test_min_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("a", 1, 1, "d"), "float32"):
-            a = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            d = T.int64()
             gv = R.call_tir(min, (x,), R.Tensor((a, 1, 1, d), dtype="float32"))
             return gv
 
         @T.prim_func
         def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, 
T.int64(1), T.int64(1), d], dtype="float32")
             for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), T.int64(1), d, 
b, c):
@@ -306,10 +306,10 @@ def test_sum_symbolic():
         @T.prim_func
         def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), 
"float32")):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             for i0, i1, i2, i3 in T.grid(a, b, c, d):
                 with T.block("rxplaceholder_red"):
@@ -377,10 +377,10 @@ def test_prod_symbolic():
         @T.prim_func
         def prod(var_rxplaceholder: T.handle, rxplaceholder_red: 
T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), T.int64(1), a, b, c, d):
                 with T.block("rxplaceholder_red"):
@@ -442,8 +442,8 @@ def test_mean_symbolic():
     class Mean:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> 
R.Tensor(("b", "c"), "float32"):
-            b = T.var("int64")
-            c = T.var("int64")
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((b, c), "float32") = R.mean(x, [0, 3])
             return gv
 
@@ -451,18 +451,18 @@ def test_mean_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> 
R.Tensor(("b", "c"), dtype="float32"):
-            b = T.var("int64")
-            c = T.var("int64")
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(mean, (x,), R.Tensor((b, c), dtype="float32"))
             return gv
 
         @T.prim_func
         def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             T_divide = T.match_buffer(var_T_divide, [b, c], dtype="float32")
             rxplaceholder_red = T.alloc_buffer([b, c], dtype="float32")
@@ -579,10 +579,10 @@ def test_std_symbolic():
         @T.prim_func
         def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             rxplaceholder_red = T.alloc_buffer([], dtype="float32")
             T_divide = T.alloc_buffer([], dtype="float32")
@@ -715,8 +715,8 @@ def test_variance_symbolic():
     class Variance:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 
"b", "c", 1), "float32"):
-            b = T.var("int64")
-            c = T.var("int64")
+            b = T.int64()
+            c = T.int64()
             gv: R.Tensor((1, b, c, 1), "float32") = R.variance(x, [0, 3], 
keepdims=True)
             return gv
 
@@ -724,18 +724,18 @@ def test_variance_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 
"b", "c", 1), "float32"):
-            b = T.var("int64")
-            c = T.var("int64")
+            b = T.int64()
+            c = T.int64()
             gv = R.call_tir(variance, (x,), R.Tensor((1, b, c, 1), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle):
             T.func_attr({"tir.noalias": True})
-            a = T.var("int64")
-            b = T.var("int64")
-            c = T.var("int64")
-            d = T.var("int64")
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], 
dtype="float32")
             T_divide = T.match_buffer(var_T_divide, [T.int64(1), b, c, 
T.int64(1)], dtype="float32")
             rxplaceholder_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], 
dtype="float32")
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py
index 12ae366dcc..7250e711be 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -59,8 +59,8 @@ def test_abs_symbolic():
     class Abs:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.abs(x)
             return gv
 
@@ -68,16 +68,16 @@ def test_abs_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_abs, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -129,8 +129,8 @@ def test_cos_symbolic():
     class Cos:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.cos(x)
             return gv
 
@@ -138,16 +138,16 @@ def test_cos_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_cos, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -199,8 +199,8 @@ def test_exp_symbolic():
     class Exp:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.exp(x)
             return gv
 
@@ -208,16 +208,16 @@ def test_exp_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_exp, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -269,8 +269,8 @@ def test_log_symbolic():
     class Log:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.log(x)
             return gv
 
@@ -278,16 +278,16 @@ def test_log_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_log, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -339,8 +339,8 @@ def test_negative_symbolic():
     class Negative:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.negative(x)
             return gv
 
@@ -348,16 +348,16 @@ def test_negative_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_negative, (x,), R.Tensor((m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -409,8 +409,8 @@ def test_sigmoid_symbolic():
     class Sigmoid:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.sigmoid(x)
             return gv
 
@@ -418,16 +418,16 @@ def test_sigmoid_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_sigmoid, (x,), R.Tensor((m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -479,8 +479,8 @@ def test_sin_symbolic():
     class Sin:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.sin(x)
             return gv
 
@@ -488,16 +488,16 @@ def test_sin_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_sin, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -549,8 +549,8 @@ def test_sqrt_symbolic():
     class Sqrt:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.sqrt(x)
             return gv
 
@@ -558,16 +558,16 @@ def test_sqrt_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_sqrt, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -619,8 +619,8 @@ def test_tanh_symbolic():
     class Tanh:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.tanh(x)
             return gv
 
@@ -628,16 +628,16 @@ def test_tanh_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_tanh, (x,), R.Tensor((m, n), dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
@@ -657,8 +657,8 @@ def test_clip_symbolic():
     class Clip:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8)
             return gv
 
@@ -666,16 +666,16 @@ def test_clip_symbolic():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             gv = R.call_tir(tir_clip, (x,), out_sinfo=R.Tensor((m, n), 
dtype="float32"))
             return gv
 
         @T.prim_func
         def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
             compute = T.match_buffer(var_compute, [m, n], dtype="float32")
             for i0, i1 in T.grid(m, n):
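
Note: because the rewrite is mechanical, the simple cases can be migrated with a one-off script. A rough, purely illustrative sketch (regex-based, so it handles only the unnamed single-line T.var("<dtype>") form; the named variant T.var("int64", "m") still needs manual attention):

    import re

    DEPRECATED = re.compile(r'T\.var\("((?:u?int|float)\d+)"\)')

    def modernize(source: str) -> str:
        # T.var("int64") -> T.int64(); T.var("int32") -> T.int32(); ...
        return DEPRECATED.sub(lambda m: f'T.{m.group(1)}()', source)

    assert modernize('k = T.var("int32")') == "k = T.int32()"
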
diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py
index ff695b9436..d87ea5cec7 100644
--- a/tests/python/relax/test_transform_meta_schedule_tuning.py
+++ b/tests/python/relax/test_transform_meta_schedule_tuning.py
@@ -36,7 +36,7 @@ class InputModule:
     @T.prim_func
     def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
         T.func_attr({"global_symbol": "tir_matmul"})
-        k = T.var("int32")
+        k = T.int32()
         A = T.match_buffer(x, (32, 32))
         B = T.match_buffer(y, (32, 32))
         C = T.match_buffer(z, (32, 32))
diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py
index 9e9533a5ed..da123f956d 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -122,7 +122,7 @@ def test_normalize_no_op():
     class ANFMod2:
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")):
-            m, n = T.var("int64"), T.var("int64")
+            m, n = T.int64(), T.int64()
             with R.dataflow():
                 lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
                 gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32"))
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index f11df58b26..1b556139cc 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -530,16 +530,16 @@ def test_symbolic_shape():
     class Module:
         @T.prim_func
         def exp(var_A: T.handle, var_B: T.handle):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             A = T.match_buffer(var_A, (m, n), "float32")
             B = T.match_buffer(var_B, (m, n), "float32")
             T.evaluate(0)
 
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")):
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor(
                 R.shape([m, n]), dtype="float32", runtime_device_index=0
             )
diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py
index b12ff01670..3fc2d41618 100644
--- a/tests/python/relax/test_tuning_api.py
+++ b/tests/python/relax/test_tuning_api.py
@@ -47,7 +47,7 @@ from tvm.relax.transform.tuning_api import (
 @tvm.script.ir_module
 class TestModule:
     @T.prim_func
-    def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None:
+    def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
         T.func_attr(({"global_symbol": "addone"}))
         for i, j in T.grid(16, 16):
             with T.block("addone"):
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py
index 12d8b114b8..eb0aaf5604 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -61,8 +61,8 @@ def test_match_cast():
     """
     @R.function
     def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
-        m = T.var("int64")
-        n = T.var("int64")
+        m = T.int64()
+        n = T.int64()
         _ = R.match_cast(x, R.Tensor((m,), "float32"))
         y1 = R.match_cast(x, R.Tensor((n,), "float32"))
         return (m, n * 2)
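
Note: R.match_cast is untouched by the migration; it still binds the fresh symbolic vars declared above it (m and n here) by structurally matching its argument against the annotated struct info. A hedged, simplified sketch of the pattern in this hunk:

    @R.function
    def foo(x: R.Tensor(None, "float32")):
        m = T.int64()  # fresh var, bound by the match below
        x0 = R.match_cast(x, R.Tensor((m,), "float32"))
        return R.shape([m * 2])
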
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 507ce72c06..8df125ac72 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -105,7 +105,7 @@ def test_unexpected_tir_cast_args():
 
         @R.function
         def f(x: R.Tensor(("m",), "float32")):
-            m = T.var("int64")
+            m = T.int64()
             # tir.cast expects 2 arguments, but got 3
             return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), 
dtype="float32"))
 
@@ -116,7 +116,7 @@ def test_unexpected_tir_max_args():
 
         @R.function
         def f(x: R.Tensor(("m", "n"), "float32")):
-            m = T.var("int64")
+            m = T.int64()
             # tir.max expects 2 arguments, but got 1
             return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), 
dtype="float32"))
 
@@ -220,15 +220,15 @@ def test_relax_base_op():
 def test_symbolic_shape():
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-        m = T.var("int64", "m")
-        n = T.var("int64", "n")
+        m = T.int64()
+        n = T.int64()
         gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
         return gv0
 
     @R.function
     def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-        m = T.var("int64")
-        n = T.var("int64")
+        m = T.int64()
+        n = T.int64()
         gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
         return gv0
 
@@ -236,8 +236,8 @@ def test_symbolic_shape():
 
         @R.function
         def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> 
R.Tensor(None, "float32", ndim=2):
-            m = T.var("int64")
-            n = T.var("int32")  # The shape dtype should be int64
+            m = T.int64()
+            n = T.int32()  # The shape dtype should be int64
             gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
             return gv0
 
@@ -282,8 +282,8 @@ def test_shadowing():
 def test_match_cast():
     @R.function
     def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
-        m = T.var("int64")
-        n = T.var("int64")
+        m = T.int64()
+        n = T.int64()
         x0 = R.match_cast(x, R.Tensor([m], "float32"))
         with R.dataflow():
             y0 = R.match_cast(y, R.Tensor([n], "float32"))
@@ -327,7 +327,7 @@ def test_tuple_return():
 def test_tuple_return_2():
     @R.function
     def foo(x: R.Tensor("float32", ndim=2)):
-        n, m = T.var("int64"), T.var("int64")
+        n, m = T.int64(), T.int64()
         x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
         return (x0, R.shape([n + 1, m, 1]))
 
@@ -344,7 +344,7 @@ def test_tuple_return_2():
 def test_tuple_binding():
     @R.function
     def foo(x: R.Tensor("float32", ndim=2)):
-        n, m = T.var("int64"), T.var("int64")
+        n, m = T.int64(), T.int64()
         x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
         t0 = (x, x0)
         t1 = (x, R.shape([n, m]), t0)
@@ -414,8 +414,8 @@ def test_dataflow_block_advanced():
         gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
         gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), 
dtype="float32"))
         with R.dataflow():
-            m = T.var("int64")
-            n = T.var("int64")
+            m = T.int64()
+            n = T.int64()
             lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))
             lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
             gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))
@@ -601,7 +601,7 @@ def test_annotation():
         y: R.Tensor(("m",), "float32"),
         r: R.Tensor(dtype="int64"),
     ) -> R.Object:
-        m = T.var("int64", "m")
+        m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
         w: R.Tensor = R.multiply(z, z)
         q: R.Tensor(ndim=2) = R.add(w, w)
@@ -690,7 +690,7 @@ def test_call_tir_with_tir_var():
         def main(
             dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32")
         ) -> R.Tensor(("n * 2",), "float32"):
-            n = T.var("int64")
+            n = T.int64()
             y = R.call_tir(copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,))
             return y
 
@@ -884,7 +884,7 @@ def test_erase_to_well_defined():
     @R.function
     def foo(x: R.Tensor):
         q = x
-        m, n = T.var("int64"), T.var("int64")
+        m, n = T.int64(), T.int64()
         z = R.match_cast(q, R.Tensor((m, n)))
         w = z
         return w
@@ -930,7 +930,7 @@ def test_symbolic_shape_computing():
     def bar(
         x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")
     ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"):
-        m = T.var("int64")
+        m = T.int64()
         z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32"))
         return z
 
@@ -949,7 +949,7 @@ def test_symbolic_shape_computing():
     # Shape Case
     @R.function
     def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
-        m = T.var("int64")
+        m = T.int64()
         z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
         return z
 
@@ -977,8 +977,8 @@ def test_symbolic_shape_computing():
 def test_vm_ops():
     @R.function
     def foo(x: R.Tensor(("m", "n"), dtype="float32")):
-        m = T.var("int64")
-        n = T.var("int64")
+        m = T.int64()
+        n = T.int64()
         storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0)
         alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32")
         tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
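
Note: the hunks above all apply the same mechanical substitution, replacing the
deprecated T.var("int64") spelling for declaring a symbolic shape variable with
the dtype-named constructor T.int64(). A minimal sketch of the updated pattern,
mirroring the foo function in test_symbolic_shape above ("extern_func" is a
placeholder symbol from the tests, not a real kernel):

    from tvm.script import relax as R
    from tvm.script import tir as T

    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")):
        # Bind the symbolic dimensions named in the annotation above.
        # T.int64() creates an int64 tir.Var, the dtype Relax requires
        # for shape variables; T.var("int64") now emits a warning.
        m = T.int64()
        n = T.int64()
        gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
        return gv0
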
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index d57efd8b99..e78e926dcb 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -86,7 +86,7 @@ def test_vm_compile_stage2(exec_mode):
     class TestVMCompileStage2:
         @R.function
         def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
-            n, m = T.var("int64"), T.var("int64")
+            n, m = T.int64(), T.int64()
             _ = R.match_cast(x, R.Tensor((n, m), "float32"))
             return R.shape([n * 2, m * 3])
 
@@ -143,7 +143,7 @@ def test_vm_compile_e2e(exec_mode):
         @R.function
         def foo(x: R.Tensor(dtype="float32")) -> R.Tensor:
             with R.dataflow():
-                n, m = T.var("int64"), T.var("int64")
+                n, m = T.int64(), T.int64()
                 _ = R.match_cast(x, R.Tensor((n, m), "float32"))
                 y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32"))
                 R.output(y)
@@ -168,9 +168,9 @@ def test_vm_compile_e2e_func_param_with_shape(exec_mode):
         @T.prim_func
         def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
             T.func_attr({"global_symbol": "tir_matmul"})
-            m = T.var("int32")
-            n = T.var("int32")
-            k = T.var("int32")
+            m = T.int32()
+            n = T.int32()
+            k = T.int32()
             A = T.match_buffer(x, (m, n))
             B = T.match_buffer(y, (n, k))
             C = T.match_buffer(z, (m, k))
@@ -186,7 +186,7 @@ def test_vm_compile_e2e_func_param_with_shape(exec_mode):
         def func(
             x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")
         ) -> R.Tensor:
-            m, k = T.var("int64"), T.var("int64")
+            m, k = T.int64(), T.int64()
             gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32"))
             return gv0
 
@@ -540,9 +540,9 @@ def test_sub_func_call(exec_mode):
         @T.prim_func
         def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
             T.func_attr({"global_symbol": "tir_matmul"})
-            m = T.var("int32")
-            n = T.var("int32")
-            k = T.var("int32")
+            m = T.int32()
+            n = T.int32()
+            k = T.int32()
             A = T.match_buffer(x, (m, n))
             B = T.match_buffer(y, (n, k))
             C = T.match_buffer(z, (m, k))
@@ -680,8 +680,8 @@ class TestVMSetInput:
     @T.prim_func
     def test_vm_mul(x: T.handle, y: T.handle, z: T.handle):
         T.func_attr({"global_symbol": "test_vm_mul"})
-        m = T.var("int32")
-        n = T.var("int32")
+        m = T.int32()
+        n = T.int32()
         A = T.match_buffer(x, (m, n))
         B = T.match_buffer(y, (m, n))
         C = T.match_buffer(z, (m, n))
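
Note: the same substitution applies inside TIR prim funcs, where T.int32() now
declares the symbolic buffer extents consumed by T.match_buffer. A minimal
sketch of the updated form (the match_buffer preamble mirrors test_vm_mul
above; the loop body is illustrative, since the diff cuts off before it):

    from tvm.script import tir as T

    @T.prim_func
    def elementwise_mul(x: T.handle, y: T.handle, z: T.handle):
        T.func_attr({"global_symbol": "elementwise_mul"})
        # T.int32() replaces the deprecated T.var("int32") for the
        # symbolic extents bound by T.match_buffer below.
        m = T.int32()
        n = T.int32()
        A = T.match_buffer(x, (m, n))
        B = T.match_buffer(y, (m, n))
        C = T.match_buffer(z, (m, n))
        for i, j in T.grid(m, n):
            with T.block("mul"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] * B[vi, vj]
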
