This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82cf9f72d6 [TVMScript] Simplify TIR Var Definition (#13970)
82cf9f72d6 is described below

commit 82cf9f72d68903f1f36921cd1e7ae4435eced5d3
Author: Junru Shao <[email protected]>
AuthorDate: Sun Feb 12 18:32:41 2023 -0800

    [TVMScript] Simplify TIR Var Definition (#13970)
    
    This PR introduces a small tweak to the TVMScript printer that simplifies
    variable definitions in TIR.
    
    Previously, a TIR var was defined with `T.var(dtype)`, e.g.:
    
    ```python
    a = T.var("int32")
    ```
    
    This PR encourages shortening the definition to:
    
    ```python
    a = T.int32()
    ```
    
    There is no breaking change in this PR: the legacy `T.var(dtype)` form
    still works, although it is now marked as deprecated in favor of
    `T.{dtype}()`.
---
 python/tvm/script/ir_builder/tir/ir.py             |  1 +
 python/tvm/script/parser/tir/parser.py             |  2 +-
 python/tvm/tir/tensor_intrin/cuda.py               | 40 ++++++------
 python/tvm/utils/roofline/cuda.py                  |  2 +-
 python/tvm/utils/roofline/x86.py                   |  2 +-
 src/script/printer/tir/expr.cc                     | 25 +++++--
 .../test_ethosu/test_copy_compute_reordering.py    | 76 +++++++++++-----------
 .../contrib/test_ethosu/test_merge_constants.py    | 40 ++++++------
 tests/python/integration/test_lower.py             | 12 ++--
 .../unittest/test_aot_legalize_packed_call.py      | 16 ++---
 tests/python/unittest/test_arith_domain_touched.py |  4 +-
 .../test_meta_schedule_postproc_verify_gpu_code.py | 12 ++--
 .../unittest/test_meta_schedule_trace_apply.py     | 40 ++++++------
 tests/python/unittest/test_te_create_primfunc.py   | 16 ++---
 tests/python/unittest/test_tir_analysis_oob.py     |  2 +-
 tests/python/unittest/test_tir_intrin.py           | 10 +--
 .../python/unittest/test_tir_lower_match_buffer.py | 26 ++++----
 tests/python/unittest/test_tir_renew_defs.py       |  6 +-
 tests/python/unittest/test_tir_schedule_rfactor.py |  2 +-
 .../python/unittest/test_tir_schedule_tensorize.py | 24 +++----
 tests/python/unittest/test_tir_specialize.py       | 18 ++---
 .../test_tir_transform_common_subexpr_elim.py      |  4 +-
 .../test_tir_transform_hoist_expression.py         |  4 +-
 .../python/unittest/test_tvmscript_error_report.py |  4 +-
 .../unittest/test_tvmscript_ir_builder_tir.py      | 32 +++++----
 .../python/unittest/test_tvmscript_printer_tir.py  | 52 +++++++--------
 tests/python/unittest/test_tvmscript_roundtrip.py  | 22 +++----
 .../python/unittest/test_tvmscript_syntax_sugar.py |  6 +-
 28 files changed, 257 insertions(+), 243 deletions(-)

diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index 25d16b56dc..2c5a848e4a 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1393,6 +1393,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
     return _ffi_api.Void(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
+@deprecated("T.var", "T.{dtype}")
 def var(dtype: str, name: str = "") -> Var:
     """Construct a new tir.Var.
 
diff --git a/python/tvm/script/parser/tir/parser.py 
b/python/tvm/script/parser/tir/parser.py
index 0e74114ba2..fbef1a9691 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -143,7 +143,7 @@ def bind_assign_value(self: Parser, node: doc.expr, 
var_name: str, value: Any) -
         IRBuilder.name(var_name, value)
         return value
     elif isinstance(value, PrimExpr):
-        var = T.var(value.dtype)
+        var = Var("", value.dtype)
         IRBuilder.name(var_name, var)
         frame = T.let(var, value)
         frame.add_callback(partial(frame.__exit__, None, None, None))
diff --git a/python/tvm/tir/tensor_intrin/cuda.py 
b/python/tvm/tir/tensor_intrin/cuda.py
index 0703811ea7..6483b99454 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -146,8 +146,8 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, 
shared_scope="shared"):
 
     @T.prim_func
     def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
-        s0 = T.var("int32")
-        s1 = T.var("int32")
+        s0 = T.int32()
+        s1 = T.int32()
         shared = T.match_buffer(
             shared_handle,
             shmem_shape,
@@ -385,8 +385,8 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
 
     @T.prim_func
     def mma_store_impl(a: T.handle, c: T.handle) -> None:
-        s0 = T.var("int32")
-        s1 = T.var("int32")
+        s0 = T.int32()
+        s1 = T.int32()
 
         C_warp = T.match_buffer(
             a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
@@ -530,10 +530,10 @@ def get_wmma_load_intrin(
 
     @T.prim_func
     def wmma_load_impl(a: T.handle, c: T.handle) -> None:
-        s1 = T.var("int32")
-        s0 = T.var("int32")
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        s1 = T.int32()
+        s0 = T.int32()
+        d1 = T.int32()
+        d0 = T.int32()
         A = T.match_buffer(
             a,
             (m_dim, n_dim),
@@ -593,8 +593,8 @@ def get_wmma_fill_intrin(
 
     @T.prim_func
     def wmma_fill_impl(c: T.handle) -> None:
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        d1 = T.int32()
+        d0 = T.int32()
         C = T.match_buffer(
             c,
             (m_dim, n_dim),
@@ -643,10 +643,10 @@ def get_wmma_store_intrin(
 
     @T.prim_func
     def wmma_store_impl(a: T.handle, c: T.handle) -> None:
-        s1 = T.var("int32")
-        s0 = T.var("int32")
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        s1 = T.int32()
+        s0 = T.int32()
+        d1 = T.int32()
+        d0 = T.int32()
         A = T.match_buffer(
             a,
             (m_dim, n_dim),
@@ -726,12 +726,12 @@ def get_wmma_sync_intrin(
 
     @T.prim_func
     def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
-        a1 = T.var("int32")
-        a0 = T.var("int32")
-        b1 = T.var("int32")
-        b0 = T.var("int32")
-        c1 = T.var("int32")
-        c0 = T.var("int32")
+        a1 = T.int32()
+        a0 = T.int32()
+        b1 = T.int32()
+        b0 = T.int32()
+        c1 = T.int32()
+        c0 = T.int32()
 
         A = T.match_buffer(
             a,
diff --git a/python/tvm/utils/roofline/cuda.py 
b/python/tvm/utils/roofline/cuda.py
index 5d80c80880..b83a902b7f 100644
--- a/python/tvm/utils/roofline/cuda.py
+++ b/python/tvm/utils/roofline/cuda.py
@@ -299,7 +299,7 @@ def estimate_peak_flops(
 @T.prim_func
 def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: 
T.int32) -> None:
     # pylint: disable=invalid-name, missing-function-docstring
-    N = T.var("int32")
+    N = T.int32()
     A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
     B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
     for i in T.thread_binding(blocks, "blockIdx.x"):
diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
index 37a666d252..5d2dd27e52 100644
--- a/python/tvm/utils/roofline/x86.py
+++ b/python/tvm/utils/roofline/x86.py
@@ -216,7 +216,7 @@ def estimate_peak_fma_flops(
 @T.prim_func
 def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: 
T.int32) -> None:
     # pylint: disable=invalid-name, missing-function-docstring
-    N = T.var("int32")
+    N = T.int32()
     A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
     B = T.match_buffer(b, [threads, 4, vec_width], "float32")
     # Parallelism is necessary to hit all cores/nodes
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index a5d5d492ff..d860eeb2a7 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -29,14 +29,29 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, 
const IRDocsifier& d)
     if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
       ExprDoc lhs = DefineVar(var, opt_f.value(), d);
       Type type = var->type_annotation;
+      ObjectPath type_p = var_p->Attr("type_annotation");
+      ExprDoc rhs{nullptr};
       if (const auto* ptr_type = type.as<PointerTypeNode>()) {
-        ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
-        ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation"));
-        opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+        const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+        ICHECK(prim_type);
+        ExprDoc element_type =
+            LiteralDoc::DataType(prim_type->dtype, 
type_p->Attr("element_type")->Attr("dtype"));
+        rhs = TIR(d, "handle");
+        rhs->source_paths.push_back(var_p->Attr("dtype"));
+        if (ptr_type->storage_scope == "") {
+          rhs = rhs->Call({element_type});
+        } else {
+          rhs = rhs->Call({element_type,
+                           LiteralDoc::Str(ptr_type->storage_scope,  //
+                                           type_p->Attr("storage_scope"))});
+        }
       } else {
-        ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, 
var_p->Attr("dtype"))});
-        opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+        rhs = TIR(d, DType2Str(var->dtype));
+        rhs->source_paths.push_back(var_p->Attr("dtype"));
+        rhs = rhs->Call({});
       }
+      rhs->source_paths.push_back(type_p);
+      opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
     } else {
       LOG(WARNING) << "Didn't find variable definition for: " << 
var->name_hint;
     }
diff --git a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py 
b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
index 99bd273115..1a00e01b60 100644
--- a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
+++ b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
@@ -476,16 +476,16 @@ def test_reordering_based_on_cycles():
         def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: 
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), 
placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: 
T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "main", 
"from_legacy_te_schedule": True})
-            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
-            nn = T.var("int32")
-            nn_1 = T.var("int32")
-            nn_2 = T.var("int32")
-            nn_3 = T.var("int32")
-            nn_4 = T.var("int32")
-            nn_5 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+            nn = T.int32()
+            nn_1 = T.int32()
+            nn_2 = T.int32()
+            nn_3 = T.int32()
+            nn_4 = T.int32()
+            nn_5 = T.int32()
             # body
             placeholder_d_global = T.decl_buffer([208], "uint8")
             placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -524,16 +524,16 @@ def test_reordering_based_on_cycles():
         def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: 
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), 
placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: 
T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "main", 
"from_legacy_te_schedule": True})
-            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
-            nn = T.var("int32")
-            nn_1 = T.var("int32")
-            nn_2 = T.var("int32")
-            nn_3 = T.var("int32")
-            nn_4 = T.var("int32")
-            nn_5 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+            nn = T.int32()
+            nn_1 = T.int32()
+            nn_2 = T.int32()
+            nn_3 = T.int32()
+            nn_4 = T.int32()
+            nn_5 = T.int32()
             # body
             placeholder_d_global = T.decl_buffer([208], "uint8")
             placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -579,15 +579,15 @@ def test_reordering_based_on_cycles_luts_present():
         def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: 
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), 
placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, 
"uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, 
"int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "main", 
"from_legacy_te_schedule": True})
-            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-            nn = T.var("int32")
-            nn_1 = T.var("int32")
-            nn_2 = T.var("int32")
-            nn_3 = T.var("int32")
-            nn_4 = T.var("int32")
-            nn_5 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+            nn = T.int32()
+            nn_1 = T.int32()
+            nn_2 = T.int32()
+            nn_3 = T.int32()
+            nn_4 = T.int32()
+            nn_5 = T.int32()
             # body
             placeholder_d_d_global = T.decl_buffer([208], "uint8")
             placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
@@ -629,15 +629,15 @@ def test_reordering_based_on_cycles_luts_present():
         def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: 
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), 
placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, 
"uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, 
"int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "main", 
"from_legacy_te_schedule": True})
-            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-            nn = T.var("int32")
-            nn_1 = T.var("int32")
-            nn_2 = T.var("int32")
-            nn_3 = T.var("int32")
-            nn_4 = T.var("int32")
-            nn_5 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+            nn = T.int32()
+            nn_1 = T.int32()
+            nn_2 = T.int32()
+            nn_3 = T.int32()
+            nn_4 = T.int32()
+            nn_5 = T.int32()
             # body
             placeholder_d_d_global = T.decl_buffer([208], "uint8")
             placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py 
b/tests/python/contrib/test_ethosu/test_merge_constants.py
index 909f9fe673..624bef00c7 100644
--- a/tests/python/contrib/test_ethosu/test_merge_constants.py
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -650,18 +650,18 @@ def test_cycle_count():
         def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), 
"uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,), 
"uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,), 
"uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,), 
"uint8")) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": 
"main", "tir.noalias": True})
-            v1a = T.var("int32")
-            v1b = T.var("int32")
-            v1c = T.var("int32")
-            v2a = T.var("int32")
-            v2b = T.var("int32")
-            v2c = T.var("int32")
-            v3a = T.var("int32")
-            v3b = T.var("int32")
-            v3c = T.var("int32")
-            v4a = T.var("int32")
-            v4b = T.var("int32")
-            v4c = T.var("int32")
+            v1a = T.int32()
+            v1b = T.int32()
+            v1c = T.int32()
+            v2a = T.int32()
+            v2b = T.int32()
+            v2c = T.int32()
+            v3a = T.int32()
+            v3b = T.int32()
+            v3c = T.int32()
+            v4a = T.int32()
+            v4b = T.int32()
+            v4c = T.int32()
             buffer1 = T.Buffer([8192], "int8")
             buffer10 = T.Buffer([2048], "int8")
             # body
@@ -713,14 +713,14 @@ def test_cycle_count():
         def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), 
"uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,), 
"uint8")) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": 
"main", "tir.noalias": True})
-            v1a = T.var("int32")
-            v1c = T.var("int32")
-            v2a = T.var("int32")
-            v2c = T.var("int32")
-            v3a = T.var("int32")
-            v3c = T.var("int32")
-            v4a = T.var("int32")
-            v4c = T.var("int32")
+            v1a = T.int32()
+            v1c = T.int32()
+            v2a = T.int32()
+            v2c = T.int32()
+            v3a = T.int32()
+            v3c = T.int32()
+            v4a = T.int32()
+            v4c = T.int32()
             buffer1 = T.Buffer([8192], "int8")
             buffer10 = T.Buffer([2048], "int8")
             # body
diff --git a/tests/python/integration/test_lower.py 
b/tests/python/integration/test_lower.py
index 1ccdde8b13..965ab80beb 100644
--- a/tests/python/integration/test_lower.py
+++ b/tests/python/integration/test_lower.py
@@ -136,8 +136,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, 
handle_c: T.handle)
                                                 axis_vk * 16 : axis_vk * 16 + 
16,
                                             ]
                                         )
-                                        stride0 = T.var("int32")
-                                        stride1 = T.var("int32")
+                                        stride0 = T.int32()
+                                        stride1 = T.int32()
                                         match_buffer_a0 = T.match_buffer(
                                             shared_a[
                                                 new_axis_vi * 16 : new_axis_vi 
* 16 + 16,
@@ -198,8 +198,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, 
handle_c: T.handle)
                                                 axis_vk * 16 : axis_vk * 16 + 
16,
                                             ]
                                         )
-                                        stride0 = T.var("int32")
-                                        stride1 = T.var("int32")
+                                        stride0 = T.int32()
+                                        stride1 = T.int32()
                                         match_buffer_b0 = T.match_buffer(
                                             shared_b[
                                                 new_axis_vj * 16 : new_axis_vj 
* 16 + 16,
@@ -335,8 +335,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, 
handle_c: T.handle)
                                         new_axis_vj * 16 : new_axis_vj * 16 + 
16,
                                     ]
                                 )
-                                stride0 = T.var("int32")
-                                stride1 = T.var("int32")
+                                stride0 = T.int32()
+                                stride1 = T.int32()
                                 wmma_c2 = T.match_buffer(
                                     wmma_c[
                                         new_axis_vi * 16 : new_axis_vi * 16 + 
16,
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py 
b/tests/python/unittest/test_aot_legalize_packed_call.py
index ad970d52c0..6f66f3a432 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -35,10 +35,10 @@ class Module:
 
     @T.prim_func
     def tir_packed_call() -> None:
-        A = T.var("handle")
-        B = T.var("handle")
-        C = T.var("handle")
-        device_context = T.var("handle")
+        A = T.handle()
+        B = T.handle()
+        C = T.handle()
+        device_context = T.handle()
         # body
         T.evaluate(
             T.tvm_call_cpacked(
@@ -65,10 +65,10 @@ class Expected:
 
     @T.prim_func
     def tir_packed_call() -> None:
-        A = T.var("handle")
-        B = T.var("handle")
-        C = T.var("handle")
-        device_context = T.var("handle")
+        A = T.handle()
+        B = T.handle()
+        C = T.handle()
+        device_context = T.handle()
 
         # body
         T.evaluate(
diff --git a/tests/python/unittest/test_arith_domain_touched.py 
b/tests/python/unittest/test_arith_domain_touched.py
index 9f7eee0963..e19991b3b8 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -21,7 +21,7 @@ from tvm.script import tir as T
 
 @T.prim_func
 def scalar_func(a: T.handle, b: T.handle):
-    m = T.var("int32")
+    m = T.int32()
     n = 100
     A = T.match_buffer(a, (n, m))
     B = T.match_buffer(b, (n, m))
@@ -73,7 +73,7 @@ def test_domain_touched_vector():
 
     @T.prim_func
     def func(a: T.handle, b: T.handle):
-        n = T.var("int32")
+        n = T.int32()
         A = T.match_buffer(a, (n * m,))
         B = T.match_buffer(b, (n * m,))
 
diff --git 
a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py 
b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
index 59de0b0c57..0facc9b961 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -399,12 +399,12 @@ def GMMCUDATensorCore(
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    s0 = T.var("int32")
-    s0_1 = T.var("int32")
-    s0_2 = T.var("int32")
-    s1 = T.var("int32")
-    s1_1 = T.var("int32")
-    s1_2 = T.var("int32")
+    s0 = T.int32()
+    s0_1 = T.int32()
+    s0_2 = T.int32()
+    s1 = T.int32()
+    s1_1 = T.int32()
+    s1_2 = T.int32()
     # body
     # with T.block("root")
     Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="wmma.accumulator")
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py 
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index ae65cc1a81..d09f2a226c 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -637,26 +637,26 @@ class Conv2dInt8_tensorcore_scheduled:
     def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 
64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 
256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 
256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), 
"int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), 
compute: T.Buffer((16, 56, 56, 256), "uint8")) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        A_s0 = T.var("int32")
-        A_s0_1 = T.var("int32")
-        A_s0_2 = T.var("int32")
-        A_s0_3 = T.var("int32")
-        A_s1 = T.var("int32")
-        A_s1_1 = T.var("int32")
-        A_s1_2 = T.var("int32")
-        A_s1_3 = T.var("int32")
-        B_s0 = T.var("int32")
-        B_s1 = T.var("int32")
-        C_s0 = T.var("int32")
-        C_s0_1 = T.var("int32")
-        C_s0_2 = T.var("int32")
-        C_s0_3 = T.var("int32")
-        C_s0_4 = T.var("int32")
-        C_s1 = T.var("int32")
-        C_s1_1 = T.var("int32")
-        C_s1_2 = T.var("int32")
-        C_s1_3 = T.var("int32")
-        C_s1_4 = T.var("int32")
+        A_s0 = T.int32()
+        A_s0_1 = T.int32()
+        A_s0_2 = T.int32()
+        A_s0_3 = T.int32()
+        A_s1 = T.int32()
+        A_s1_1 = T.int32()
+        A_s1_2 = T.int32()
+        A_s1_3 = T.int32()
+        B_s0 = T.int32()
+        B_s1 = T.int32()
+        C_s0 = T.int32()
+        C_s0_1 = T.int32()
+        C_s0_2 = T.int32()
+        C_s0_3 = T.int32()
+        C_s0_4 = T.int32()
+        C_s1 = T.int32()
+        C_s1_1 = T.int32()
+        C_s1_2 = T.int32()
+        C_s1_3 = T.int32()
+        C_s1_4 = T.int32()
         # body
         # with T.block("root")
         conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], 
dtype="int32", scope="shared")
diff --git a/tests/python/unittest/test_te_create_primfunc.py 
b/tests/python/unittest/test_te_create_primfunc.py
index 0b6f87b833..2598d620ba 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -199,8 +199,8 @@ def te_multi_output():
 @T.prim_func
 def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) 
-> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    m = T.var("int32")
-    n = T.var("int32")
+    m = T.int32()
+    n = T.int32()
     A0 = T.match_buffer(a0, (m, n))
     A1 = T.match_buffer(a1, (m, n))
     B0 = T.match_buffer(b0, (m, n))
@@ -491,8 +491,8 @@ def tir_argmax_idx_val(
     var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, 
var_argmax_v1: T.handle
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    m = T.var("int32")
-    n = T.var("int32")
+    m = T.int32()
+    n = T.int32()
     idx = T.match_buffer(var_idx, [m, n], dtype="int32")
     val = T.match_buffer(var_val, [m, n], dtype="float32")
     argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32")
@@ -538,8 +538,8 @@ def tir_argmax_val_idx(
     var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, 
var_argmax_v1: T.handle
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    m = T.var("int32")
-    n = T.var("int32")
+    m = T.int32()
+    n = T.int32()
     val = T.match_buffer(var_val, [m, n], dtype="float32")
     idx = T.match_buffer(var_idx, [m, n], dtype="int32")
     argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
@@ -711,8 +711,8 @@ def tir_resize2d_symbolic(
     var_resize: T.handle,
 ):
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    oh = T.var("int64")
-    ow = T.var("int64")
+    oh = T.int64()
+    ow = T.int64()
     resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], 
dtype="float32")
     for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow):
         with T.block("resize"):
diff --git a/tests/python/unittest/test_tir_analysis_oob.py 
b/tests/python/unittest/test_tir_analysis_oob.py
index 83c0294176..7c8ceed36e 100644
--- a/tests/python/unittest/test_tir_analysis_oob.py
+++ b/tests/python/unittest/test_tir_analysis_oob.py
@@ -44,7 +44,7 @@ def bad_store_loop(A: T.Buffer((2, 3), "float32"), B: 
T.Buffer((3, 2), "float32"
 
 @T.prim_func
 def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), 
"float32")):
-    N = T.var("int32")
+    N = T.int32()
     for i in range(3):
         B[0, N] = A[1, i]
 
diff --git a/tests/python/unittest/test_tir_intrin.py 
b/tests/python/unittest/test_tir_intrin.py
index f887f8877a..1ee709191c 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -193,11 +193,11 @@ class Module:
     def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> 
None:
         # function attr dict
         T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
-        n = T.var("int32")
-        stride = T.var("int32")
-        stride_1 = T.var("int32")
-        stride_2 = T.var("int32")
-        stride_3 = T.var("int32")
+        n = T.int32()
+        stride = T.int32()
+        stride_1 = T.int32()
+        stride_2 = T.int32()
+        stride_3 = T.int32()
         A_1 = T.match_buffer(
             A,
             [n],
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py 
b/tests/python/unittest/test_tir_lower_match_buffer.py
index 535e0bb329..5bea77ffe3 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -93,8 +93,8 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
             )
     for i, j, k in T.grid(64, 2, 8):
         with T.block():
-            Bs_0 = T.var("int32")
-            Bs_1 = T.var("int32")
+            Bs_0 = T.int32()
+            Bs_1 = T.int32()
             T.reads([])
             T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
             sub_B = T.match_buffer(
@@ -154,8 +154,8 @@ def high_dim_opaque_access(a: T.handle) -> None:
     A = T.match_buffer(a, (16, 32, 64))
     for i, j, k in T.grid(16, 2, 4):
         with T.block():
-            As_0 = T.var("int32")
-            As_1 = T.var("int32")
+            As_0 = T.int32()
+            As_1 = T.int32()
             T.reads([])
             T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
             sub_A = T.match_buffer(
@@ -200,8 +200,8 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) 
-> None:
     A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
     for i, j, k in T.grid(16, 2, 4):
         with T.block():
-            As_0 = T.var("int32")
-            As_1 = T.var("int32")
+            As_0 = T.int32()
+            As_1 = T.int32()
             T.reads([])
             T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
             sub_A = T.match_buffer(
@@ -254,8 +254,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
                     B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                 ]
             )
-            As_0 = T.var("int32")
-            As_1 = T.var("int32")
+            As_0 = T.int32()
+            As_1 = T.int32()
             sub_A = T.match_buffer(
                 A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                 (16, 16),
@@ -276,8 +276,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
                             sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                         ]
                     )
-                    Ass_0 = T.var("int32")
-                    Ass_1 = T.var("int32")
+                    Ass_0 = T.int32()
+                    Ass_1 = T.int32()
                     sub_sub_A = T.match_buffer(
                         sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                         (4, 4),
@@ -355,8 +355,8 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: 
T.int32) -> None:
         with T.block():
             T.reads([])
             T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 
4]])
-            Bs_0 = T.var("int32")
-            Bs_1 = T.var("int32")
+            Bs_0 = T.int32()
+            Bs_1 = T.int32()
             sub_A = T.match_buffer(A[i * m : i * m + m, 0:m], (m, m), 
offset_factor=1)
             sub_B = T.match_buffer(
                 B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0, 
Bs_1], offset_factor=1
@@ -470,7 +470,7 @@ def fail_buffer_bind(a: T.handle) -> None:
     A = T.match_buffer(a, (8, 8))
     for i, j in T.grid(8, 2):
         with T.block():
-            stride = T.var("int32")
+            stride = T.int32()
             sub_A = T.match_buffer(
                 A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], 
offset_factor=1
             )
diff --git a/tests/python/unittest/test_tir_renew_defs.py 
b/tests/python/unittest/test_tir_renew_defs.py
index e14cd5a898..e01f5ecb12 100644
--- a/tests/python/unittest/test_tir_renew_defs.py
+++ b/tests/python/unittest/test_tir_renew_defs.py
@@ -88,8 +88,8 @@ def test_match_buffer():
     # A and B should be remapped
     def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: 
T.Buffer((128, 128), "float32")):
         with T.block("root"):
-            s = T.var("int32")
-            e = T.var("int32")
+            s = T.int32()
+            e = T.int32()
             # A0 should be remapped
             A0 = T.match_buffer(
                 A[0:128, 0:128],
@@ -157,7 +157,7 @@ def test_undefined_buffer():
 def test_symbolic_func():
     @T.prim_func
     def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
-        m = T.var("int32")
+        m = T.int32()
         A = T.match_buffer(a, (n, m))
         B = T.match_buffer(b, (n, m * 2))
         for i, j in T.grid(n, m):
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py 
b/tests/python/unittest/test_tir_schedule_rfactor.py
index 766cc3f867..199e822e84 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -954,7 +954,7 @@ def argmax_split_body_bufferstore_value_unbound_var(
     argmax_v0: T.Buffer((128,), "int32"),
     argmax_v1: T.Buffer((128,), "float32"),
 ) -> None:
-    v_unbound = T.var("int32")
+    v_unbound = T.int32()
     for i0, i1_0, i1_1 in T.grid(128, 4, 32):
         with T.block("argmax"):
             i = T.axis.spatial(128, i0)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py 
b/tests/python/unittest/test_tir_schedule_tensorize.py
index 143cf87d04..fcb4bacbba 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -195,9 +195,9 @@ def tensorized_matmul(a: T.handle, b: T.handle, c: 
T.handle) -> None:
                     ]
                 )
                 T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
-                A_elem_offset = T.var("int32")
-                B_elem_offset = T.var("int32")
-                C_elem_offset = T.var("int32")
+                A_elem_offset = T.int32()
+                B_elem_offset = T.int32()
+                C_elem_offset = T.int32()
                 A_sub = T.match_buffer(
                     A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                     [16, 16],
@@ -267,9 +267,9 @@ def tensorized_batch_matmul_mma(
                     B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 
16],
                 )
                 T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 
16 + 16])
-                A_elem_offset = T.var("int32")
-                B_elem_offset = T.var("int32")
-                C_elem_offset = T.var("int32")
+                A_elem_offset = T.int32()
+                B_elem_offset = T.int32()
+                C_elem_offset = T.int32()
                 A_sub = T.match_buffer(
                     A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 
16],
                     (16, 16),
@@ -429,9 +429,9 @@ def annotated_tensorized_matmul(a: T.handle, b: T.handle, 
c: T.handle) -> None:
                     ]
                 )
                 T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
-                A_elem_offset = T.var("int32")
-                B_elem_offset = T.var("int32")
-                C_elem_offset = T.var("int32")
+                A_elem_offset = T.int32()
+                B_elem_offset = T.int32()
+                C_elem_offset = T.int32()
                 A_sub = T.match_buffer(
                     A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                     [16, 16],
@@ -745,9 +745,9 @@ def test_tensorize_matmul_mixed_dtype():
                         ]
                     )
                     T.writes(C[vi * T.int64(16) : vi * T.int64(16) + 
T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)])
-                    A_elem_offset = T.var("int64")
-                    B_elem_offset = T.var("int64")
-                    C_elem_offset = T.var("int64")
+                    A_elem_offset = T.int64()
+                    B_elem_offset = T.int64()
+                    C_elem_offset = T.int64()
                     A_sub = T.match_buffer(
                         A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), 
vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
                         [T.int64(16), T.int64(16)],
diff --git a/tests/python/unittest/test_tir_specialize.py 
b/tests/python/unittest/test_tir_specialize.py
index 72666a89eb..ebae827ef5 100644
--- a/tests/python/unittest/test_tir_specialize.py
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -22,7 +22,7 @@ from tvm.script import tir as T
 
 @T.prim_func
 def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None:
-    m = T.var("int32")
+    m = T.int32()
     A = T.match_buffer(a, [m, n])
     B = T.match_buffer(b, [m, n])
     C = T.match_buffer(c, [m, m])
@@ -51,7 +51,7 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 @T.prim_func
 def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None:
-    m = T.var("int32")
+    m = T.int32()
     A = T.match_buffer(a, [m, 128])
     B = T.match_buffer(b, [m, 128])
     C = T.match_buffer(c, [m, m])
@@ -66,8 +66,8 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> 
None:
 
 @T.prim_func
 def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None:
-    x = T.var("int32")
-    m = T.var("int32")
+    x = T.int32()
+    m = T.int32()
     A = T.match_buffer(a, [m, x * 8])
     B = T.match_buffer(b, [m, x * 8])
     C = T.match_buffer(c, [m, m])
@@ -82,8 +82,8 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> 
None:
 
 @T.prim_func
 def element_wise(a: T.handle, c: T.handle) -> None:
-    m = T.var("int32")
-    n = T.var("int32")
+    m = T.int32()
+    n = T.int32()
     A = T.match_buffer(a, (m, n), "float32")
     C = T.match_buffer(c, (m, n), "float32")
 
@@ -119,7 +119,7 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None:
 
 @T.prim_func
 def element_wise_128_n(a: T.handle, c: T.handle) -> None:
-    n = T.var("int32")
+    n = T.int32()
     A = T.match_buffer(a, (128, n), "float32")
     C = T.match_buffer(c, (128, n), "float32")
     B = T.alloc_buffer((128, n), "float32")
@@ -170,7 +170,7 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, 
n: T.int32, p: T.int3
 
 @T.prim_func
 def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
-    n = T.var("int32")
+    n = T.int32()
     A = T.match_buffer(a, [n // 8, 8], "int32")
     B = T.match_buffer(b, [n], "int32")
     for i in range(n - 1):
@@ -181,7 +181,7 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
 
 @T.prim_func
 def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None:
-    n = T.var("int32")
+    n = T.int32()
     A = T.match_buffer(a, [2, 8], "int32")
     B = T.match_buffer(b, [16], "int32")
     for i in range(15):
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py 
b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index 113d9f0474..1755a66ec9 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -359,7 +359,7 @@ def func_distributivity_expected(
     i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
 ) -> None:
     B = T.Buffer((50,), "int32")
-    cse_var_1 = T.var("int32")
+    cse_var_1 = T.int32()
     with T.let(cse_var_1, x * y + x * z):
         B[i1] = cse_var_1
         B[i2] = cse_var_1
@@ -377,7 +377,7 @@ def func_associativity_expected(
     i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
 ) -> None:
     B = T.Buffer((50,), "int32")
-    cse_var_1 = T.var("int32")
+    cse_var_1 = T.int32()
     with T.let(cse_var_1, (x + y) + z):
         B[i1] = cse_var_1
         B[i2] = cse_var_1
diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py 
b/tests/python/unittest/test_tir_transform_hoist_expression.py
index 77862f64d6..ca37915597 100644
--- a/tests/python/unittest/test_tir_transform_hoist_expression.py
+++ b/tests/python/unittest/test_tir_transform_hoist_expression.py
@@ -447,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter):
     @T.prim_func
     def before(A: T.Buffer((4, 4), "float32")):
         for i, j in T.grid(4, 4):
-            x = T.var("float32")
+            x = T.float32()
             A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, 
"float32"))
 
     @T.prim_func
@@ -466,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter):
     @T.prim_func
     def before(A: T.Buffer((4, 4), "float32")):
         for i, j in T.grid(4, 4):
-            x = T.var("float32")
+            x = T.float32()
             A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, 
"float32"))
 
     expected = before
diff --git a/tests/python/unittest/test_tvmscript_error_report.py 
b/tests/python/unittest/test_tvmscript_error_report.py
index d2f275ac3d..2713669bd3 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -511,7 +511,7 @@ def test_report_error_root_block():
 
 def test_load_var():
     def load_var_multiple() -> None:
-        d = T.var("float32")
+        d = T.float32()
         d[2] = d[2, 1]  # error cannot provide two indices to load
 
     check_error(load_var_multiple, 3)
@@ -519,7 +519,7 @@ def test_load_var():
 
 def test_store_var():
     def store_var_multiple() -> None:
-        d = T.var("float32")
+        d = T.float32()
         d[2, 1] = d[1]  # error cannot provide two indices to store
 
     check_error(store_var_multiple, 3)
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py 
b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 85d2e808b3..889f0c9eda 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -52,7 +52,7 @@ def test_ir_builder_tir_primfunc_complete():
     with IRBuilder() as ib:
         with T.prim_func():
             T.arg("a", T.handle())
-            T.arg("b", T.var("int64"))
+            T.arg("b", T.int64())
             T.arg("c", T.Buffer((128, 128), "float32"))
             d = T.arg("d", T.handle())
             e = T.arg("e", T.Buffer((1024,), "int8"))
@@ -119,12 +119,12 @@ def test_ir_builder_tir_block_base():
 
 def test_ir_builder_tir_block_complete():
     with IRBuilder() as ib:
-        a = T.var("int64", "a")
+        a = T.int64()
         b = T.Buffer((128, 128), "float32")
         c = T.Buffer((128, 128), "float32")
-        d = T.var("int32", "d")
+        d = T.int32()
         e = T.Buffer((128, 128), "float32")
-        f = T.var("int32", "f")
+        f = T.int32()
         with T.block("block"):
             T.where(a > 1)
             T.reads(b[0:16, 0:16])
@@ -169,10 +169,10 @@ def test_ir_builder_tir_block_complete():
 
 def test_ir_builder_tir_axis():
     with IRBuilder() as ib:
-        a = T.var("int32", "a")
-        b = T.var("int32", "b")
-        c = T.var("int32", "c")
-        d = T.var("int32", "d")
+        a = T.int32()
+        b = T.int32()
+        c = T.int32()
+        d = T.int32()
         with T.block("block"):
             T.axis.spatial(8, a)
             T.axis.reduce(16, b)
@@ -269,15 +269,13 @@ def test_ir_builder_tir_for():
 
 def test_ir_builder_tir_assert():
     with IRBuilder() as ib:
-        with T.Assert(T.var("int32", name="a") == 0, message="a is 0"):
+        with T.Assert(T.int32() == 0, message="a is 0"):
             T.evaluate(0)
     # the assert generated by IRBuilder
     assert_actual = ib.get()
 
     # the expected assert statement
-    assert_expected = tir.AssertStmt(
-        T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
-    )
+    assert_expected = tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0"), 
tir.Evaluate(0))
 
     # Check if the generated ir is expected
     assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
@@ -285,13 +283,13 @@ def test_ir_builder_tir_assert():
 
 def test_ir_builder_tir_let():
     with IRBuilder() as ib:
-        with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)):
+        with T.let(T.int32(), tir.IntImm("int32", 2)):
             T.evaluate(0)
     # the let binding generated by IRBuilder
     let_actual = ib.get()
 
     # the expected Let statement
-    let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 
2), tir.Evaluate(0))
+    let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), 
tir.Evaluate(0))
 
     # Check if the generated ir is expected
     assert_structural_equal(let_actual, let_expected, map_free_vars=True)
@@ -381,7 +379,7 @@ def test_ir_builder_tir_allocate_const():
 
 def test_ir_builder_tir_while():
     with IRBuilder() as ib:
-        with T.While(T.var("int32", "x") > 0):
+        with T.While(T.int32() > 0):
             T.evaluate(0)
 
     # the while generated by IRBuilder
@@ -396,7 +394,7 @@ def test_ir_builder_tir_while():
 
 def test_ir_builder_tir_if_then_else():
     with IRBuilder() as ib:
-        with T.If(T.var("int32", "c") < 12):
+        with T.If(T.int32() < 12):
             with T.Then():
                 T.evaluate(T.int32(0))
             with T.Else():
@@ -418,7 +416,7 @@ def test_ir_builder_tir_if_then_else():
 
 def test_ir_builder_tir_buffer_store():
     buffer_a = T.Buffer((10, 10), "float32")
-    i = T.var("int32", "x")
+    i = T.int32()
     with IRBuilder() as ib:
         T.buffer_store(buffer_a, 0.1, [0, i])
 
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py 
b/tests/python/unittest/test_tvmscript_printer_tir.py
index a04544152e..13aaacb3b7 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -117,9 +117,9 @@ def test_block_realize():
     _assert_print(
         obj,
         """
-i = T.var("int32")
-j = T.var("int32")
-k = T.var("int32")
+i = T.int32()
+j = T.int32()
+k = T.int32()
 with T.block("block"):
     vi = T.axis.spatial(128, i)
     vj = T.axis.spatial(64, j)
@@ -248,13 +248,13 @@ for i, j, k in T.grid(128, 128, 128):
 
 def test_let_stmt():
     with IRBuilder() as ib:
-        with T.let(T.var("float32"), T.float32(10)):
+        with T.let(T.float32(), T.float32(10)):
             T.evaluate(0)
     obj = ib.get()
     _assert_print(
         obj,
         """
-v = T.var("float32")
+v = T.float32()
 with T.let(v, T.float32(10)):
     T.evaluate(0)
 """,
@@ -291,14 +291,14 @@ with T.Assert(1, "assertion"):
 
 def test_while():
     with IRBuilder() as ib:
-        x = T.var("int32")
+        x = T.int32()
         with T.While(x < 10):
             T.evaluate(0)
     obj = ib.get()
     _assert_print(
         obj,
         """
-v = T.var("int32")
+v = T.int32()
 while v < 10:
     T.evaluate(0)
 """,
@@ -410,7 +410,7 @@ T.evaluate(1)
 
 def test_if_then_else():
     with IRBuilder() as ib:
-        with T.If(T.var("int32") == 1):
+        with T.If(T.int32() == 1):
             with T.Then():
                 T.evaluate(0)
 
@@ -418,7 +418,7 @@ def test_if_then_else():
     _assert_print(
         obj,
         """
-v = T.var("int32")
+v = T.int32()
 if v == 1:
     T.evaluate(0)
 """,
@@ -458,7 +458,7 @@ def test_var():
     _assert_print(
         a,
         """
-a = T.var("float32")
+a = T.float32()
 a""",
     )
 
@@ -468,7 +468,7 @@ def test_size_var():
     _assert_print(
         a,
         """
-a = T.var("float32")
+a = T.float32()
 a""",
     )
 
@@ -478,7 +478,7 @@ def test_iter_var():
     _assert_print(
         a,
         """
-a = T.var("int32")
+a = T.int32()
 T.iter_var(a, T.Range(0, 8), "DataPar", "")
 """,
     )
@@ -494,7 +494,7 @@ def test_cast():
     _assert_print(
         obj,
         """
-a = T.var("float32")
+a = T.float32()
 T.Cast("float64", a)
 """,
     )
@@ -521,15 +521,15 @@ def test_binary_arith():
         obj = op(a, b)
         if sign.isalpha():
             expected = """
-a = T.var("float32")
-b = T.var("float32")
+a = T.float32()
+b = T.float32()
 T.{}(a, b)""".format(
                 sign
             )
         else:
             expected = """
-a = T.var("float32")
-b = T.var("float32")
+a = T.float32()
+b = T.float32()
 a {} b""".format(
                 sign
             )
@@ -537,28 +537,28 @@ a {} b""".format(
 
 
 def test_logical():
-    a = T.var("bool", "a")
-    b = T.var("bool", "b")
+    a = tir.Var("a", "bool")
+    b = tir.Var("b", "bool")
     _assert_print(
         tir.And(a, b),
         """
-a = T.var("bool")
-b = T.var("bool")
+a = T.bool()
+b = T.bool()
 a and b
 """,
     )
     _assert_print(
         tir.Or(a, b),
         """
-a = T.var("bool")
-b = T.var("bool")
+a = T.bool()
+b = T.bool()
 a or b
 """,
     )
     _assert_print(
         tir.Not(a),
         """
-a = T.var("bool")
+a = T.bool()
 not a
 """,
     )
@@ -579,7 +579,7 @@ def test_ramp():
     _assert_print(
         obj,
         """
-a = T.var("int32")
+a = T.int32()
 T.Ramp(a, 1, 32)
 """,
     )
@@ -601,7 +601,7 @@ def test_let_expr():
     _assert_print(
         obj,
         """
-x = T.var("int32")
+x = T.int32()
 T.let(x, 1, x + 1)
 """,
     )
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index db21223366..48a5999469 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -2904,10 +2904,10 @@ def constant_folding():
 def simplify_bracket():
     @T.prim_func
     def simplify_bracket() -> None:
-        a = T.var("int32")
-        b = T.var("int32")
-        c = T.var("int32")
-        d = T.var("int32")
+        a = T.int32()
+        b = T.int32()
+        c = T.int32()
+        d = T.int32()
         T.evaluate(a + b * (c + d))
 
     return simplify_bracket
@@ -3039,8 +3039,8 @@ def multiple_commreducer():
 def func_div_mod():
     @T.prim_func
     def func_div_mod():
-        a = T.var("int32")
-        b = T.var("int32")
+        a = T.int32()
+        b = T.int32()
         T.evaluate(a // b)
         T.evaluate(a % b)
         T.evaluate(T.truncmod(a, b))
@@ -3316,7 +3316,7 @@ def buffer_ramp_access_as_slice_index():
 def let_expression():
     @T.prim_func
     def func():
-        x = T.var("int32")
+        x = T.int32()
         T.evaluate(T.let(x, 1, x + 1))
 
     return func
@@ -3542,8 +3542,8 @@ def intrinsic_pow():
 def let_stmt_var():
     @T.prim_func
     def func():
-        x = T.var("int32")
-        y = T.var("int32")
+        x = T.int32()
+        y = T.int32()
         with T.let(x, 0):
             with T.let(y, 0):
                 T.evaluate(0)
@@ -3555,8 +3555,8 @@ def let_stmt_var():
 def let_stmt_value():
     @T.prim_func
     def func():
-        x = T.var("int32")
-        y = T.var("int32")
+        x = T.int32()
+        y = T.int32()
         with T.let(x, y):
             with T.let(y, 0):
                 T.evaluate(0)
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py 
b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index a840722bea..e4ba1f7950 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -155,9 +155,9 @@ def test_match_buffer_1d():
 # dynamic shape gemm
 @T.prim_func
 def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
-    N = T.var("int32")
-    M = T.var("int32")
-    K = T.var("int32")
+    N = T.int32()
+    M = T.int32()
+    K = T.int32()
     A = T.match_buffer(a, (N, K), "float32")
     B = T.match_buffer(b, (K, M), "float32")
     C = T.match_buffer(c, (N, M), "float32")

Reply via email to