This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fffed0f  [TensorIR] TVMScript Parser/Printer (#7630)
fffed0f is described below

commit fffed0ff91c46f5c45070b52794f4f2bf4d1b8a5
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Sun Mar 21 04:22:53 2021 +0800

    [TensorIR] TVMScript Parser/Printer (#7630)
    
    
    Co-authored-by: Bohan Hou 
<32121147+spectrometer...@users.noreply.github.com>
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-authored-by: Tianqi Chen <tqc...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
    Co-authored-by: Hongyi Jin <3231950...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Tristan Konolige <tristan.konol...@gmail.com>
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
---
 include/tvm/tir/analysis.h                         |  15 +
 python/tvm/script/context_maintainer.py            | 210 +++++++--
 python/tvm/script/intrin.py                        |  20 +-
 python/tvm/script/node.py                          | 150 +++++++
 python/tvm/script/parser.py                        | 179 +++++---
 python/tvm/script/registry.py                      |  20 +-
 python/tvm/script/scope_handler.py                 | 473 ++++++++++++++++++---
 python/tvm/script/special_stmt.py                  | 380 +++++++++++++++--
 python/tvm/script/utils.py                         |  95 ++++-
 python/tvm/tir/analysis/analysis.py                |  23 +
 src/printer/tir_text_printer.cc                    |   3 +-
 src/printer/tvmscript_printer.cc                   | 232 +++++++++-
 src/tir/analysis/block_access_region_detector.cc   | 246 +++++++++++
 src/tir/ir/script/script_complete.cc               | 122 ++++++
 .../test_tir_analysis_get_block_access_region.py   |  57 +++
 .../python/unittest/test_tvmscript_error_report.py | 205 +++++++++
 tests/python/unittest/test_tvmscript_roundtrip.py  | 170 ++++++++
 tests/scripts/task_ci_python_setup.sh              |   2 +-
 tests/scripts/task_ci_setup.sh                     |   2 +-
 19 files changed, 2395 insertions(+), 209 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 1ad7859..1692a8c 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -157,6 +157,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
  */
 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> 
constraints);
 
+/*!
+ * \brief Auto detect the block read/write region according to body stmt.
+ *        It will detect the read/write region as an array in order of appearance in the AST.
+ * \param block The block to be detected
+ * \param buffer_var_map The outside buffers which may be accessed by the block.
+ *                       It is a map from buffer var to the buffer.
+ * \return Array of access regions.
+ *         There are three arrays of BufferRegion:
+ *           - first: read regions
+ *           - second: write regions
+ *           - third: opaque regions
+ */
+Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
+                                                const Map<Var, Buffer>& 
buffer_var_map);
+
 // Pass variants of verification analysis
 // directly throws RuntimeError when verification fails.
 namespace transform {
diff --git a/python/tvm/script/context_maintainer.py 
b/python/tvm/script/context_maintainer.py
index 955266c..ae3e9d8 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/context_maintainer.py
@@ -16,59 +16,217 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature
+
+    Examples
+    ----------
+    .. code-block:: python
+
+        @tvm.script.tir
+        def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+            A = tir.match_buffer(a, (16, 16), "float32")
+            B = tir.match_buffer(b, (16, 16), "float32")
+            C = tir.match_buffer(c, (16, 16), "float32")
+
+            for i, j, k in tir.grid(16, 16, 16):
+                with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as [vi, vj, vk]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j)
+                    tir.bind(vk, k)         # iter_bindings = {vi: i, vj: j, vk: k}
+
+                    tir.where(True)         # predicate of the block_realize
+
+                    tir.reads(A[0:16, 0:16], B[0: 16, 0: 16])      # reads 
region of the block
+                    tir.writes(C[0: 16, 0: 16])                    # writes 
region of the block
+                    tir.block_attr({"attr_key": "attr_value"})     # block 
annotations
+
+                    # alloc_buffers inside the block
+                    CC = tir.alloc_buffer((1, 1), dtype="float32")
+
+                    # match_buffers of the block,
+                    # which bind a sub-region of a source buffer to a new buffer
+                    D = tir.match_buffer_region(C[vi, vj])
+
+                    # init part of the block, executed when all reduce axes are at their initial values
+                    with tir.init():
+                        C[vi, vj] = tir.float32(0)
+
+                    # block body
+                    CC[0, 0] = A[vi, vk] * B[vj, vk]
+                    D[0, 0] += CC[0, 0]         # The same as C[vi, vj] += 
CC[0, 0]
+    """
+
+    alloc_buffers: List[Buffer] = []
+    """List[Buffer]: list of tir.alloc_buffer statements in the block 
signature"""
+    match_buffers: List[MatchBufferRegion] = []
+    """List[MatchBufferRegion]: list of tir.match_buffer_region statements in 
the block signature"""
+    iter_bindings: Mapping[Var, PrimExpr] = {}
+    """Mapping[Var, PrimExpr]: map of block iter var to its values"""
+    reads: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]:
+    list of tir.reads statements in the block signature, None for 
not-visited"""
+    writes: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]:
+    list of tir.writes statements in the block signature, None for 
not-visited"""
+    annotations: Optional[Mapping[str, Object]] = None
+    """Optional[Mapping[str, Object]]:
+    list of tir.block_attr statements in the block signature, None for 
not-visited"""
+    predicate: Optional[PrimExpr] = None
+    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
+    init: Optional[Stmt] = None
+    """Optional[Stmt]: init part of the block, None for not-visited"""
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
-    """Maintain all the necessary context info"""
+    """Maintain all the necessary context info
+    Parameters
+    ----------
+    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
+        The report error function handle
+    """
+
+    # scope context
+    node_stack: List[List[synr.ast.Node]] = []
+    """List[List[synr.ast.Node]]: The ast nodes inside the current scope"""
+    block_info_stack: List[BlockInfo] = []
+    """List[BlockInfo]: The block info for the current block scope"""
+    loop_stack: List[List[Var]] = []
+    """List[List[Var]]: List of loop vars inside the current block scope"""
+    symbols: List[Dict[str, Union[Var, Buffer]]] = []
+    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for 
the current scope"""
 
-    def __init__(self, parser):
+    # function context
+    func_params: List[Var] = []
+    """List[Var]: The function parameters"""
+    func_buffer_map: Mapping[Var, Buffer] = {}
+    """Mapping[Var, Buffer]: The function buffer map"""
+    func_dict_attr: Mapping[str, Object] = {}
+    """Mapping[str, Object]: The function attrs"""
+    func_var_env_dict: Mapping[Var, str] = {}
+    """Mapping[Var, str]: The map from var to env thread"""
+
+    # parser and analyzer
+    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
+    """tvm.arith.Analyzer: The analyzer for simplifying"""
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error 
function handle"""
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, 
synr.ast.Span]], None]):
         # scope context
-        self.node_stack = []  # AST nodes of scopes
-        self.symbols = []  # symbols of scopes
+        self.node_stack = []
+        self.block_info_stack = []
+        self.loop_stack = []
+        self.symbols = []
         # function context
-        self.func_params = []  # parameter list of function
-        self.func_buffer_map = {}  # buffer_map of function
-        self.func_dict_attr = {}  # func_attr of function
-        self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
-
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+        self.func_params = []
+        self.func_buffer_map = {}
+        self.func_dict_attr = {}
+        self.func_var_env_dict = {}
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
+
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new scope
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+        Note
+        ----
+        This function is used for normal scopes that do not involve
+        a `with block` scope. Use `enter_block_scope`
+        for block scope cases.
+
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new block scope. This function calls `enter_scope` implicitly.
+        Besides the behavior of `enter_scope`, it also updates loop_stack and block_info_stack
+        to maintain block info.
+
+        Note
+        ----
+        This function should be used to handle a block scope,
+        aka the blocks that involve a `with block` scope.
+
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
+        self.enter_scope(nodes)
+        # Create a new loop stack for the new block
+        self.loop_stack.append([])
+        # Create a new BlockInfo for the new block
+        self.block_info_stack.append(BlockInfo())
+
+    def exit_scope(self):
+        """Pop the innermost scope"""
+        self.symbols.pop()
+        self.node_stack.pop()
+
+    def exit_block_scope(self):
+        """Pop the innermost block scope. This function calls `exit_scope` implicitly."""
+        self.exit_scope()
+        # Pop loop stack
+        self.loop_stack.pop()
+        # Pop block_info
+        self.block_info_stack.pop()
+
+    def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: 
synr.ast.Node):
         """Append a symbol into current scope"""
-        if isinstance(symbol, schedule.Buffer):
+        if isinstance(symbol, Buffer):
             if name in self.symbols[0]:
-                self.parser.report_error("Duplicate Buffer name")
+                self.report_error("Duplicate Buffer name: " + symbol.name, 
node.span)
             self.symbols[0][name] = symbol
         else:
             self.symbols[-1][name] = symbol
 
-    def remove_symbol(self, name):
+    def remove_symbol(self, name: str):
         """Remove a symbol"""
         for symbols in reversed(self.symbols):
             if name in symbols:
                 symbols.pop(name)
                 return
-        raise RuntimeError("Internal error of tvm script parser: no symbol 
named" + name)
+        raise RuntimeError("Internal error of tvm script parser: no symbol 
named " + name)
 
-    def lookup_symbol(self, name):
+    def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
         """Look up symbol by name"""
         for symbols in reversed(self.symbols):
             if name in symbols:
                 return symbols[name]
         return None
 
-    def report_error(self, message, span):
-        self.parser.report_error(message, span)
+    def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
+        self._report_error(message, span)
+
+    def current_block_scope(self) -> BlockInfo:
+        return self.block_info_stack[-1]
diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py
index 053cd4a..48f50a2 100644
--- a/python/tvm/script/intrin.py
+++ b/python/tvm/script/intrin.py
@@ -16,9 +16,11 @@
 # under the License.
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
+from typing import List, Any
+
 import tvm.tir
 from .registry import register
-from .utils import get_param_list, from_synr_span
+from .utils import get_param_list, tvm_span_from_synr
 
 
 class Intrin:
@@ -29,8 +31,8 @@ class Intrin:
     def signature(self):
         return "tir." + self.intrin.__name__, get_param_list(self.intrin)
 
-    def handle(self, arg_list, span):
-        return self.intrin(*arg_list, span=from_synr_span(span))
+    def handle(self, arg_list: List[Any], span: tvm.ir.Span):
+        return self.intrin(*arg_list, span=tvm_span_from_synr(span))
 
 
 @register
@@ -99,6 +101,16 @@ def float64(imm, span):
 
 
 @register
+def min_value(dtype, span):
+    return tvm.tir.min_value(dtype, span)
+
+
+@register
+def max_value(dtype, span):
+    return tvm.tir.max_value(dtype, span)
+
+
+@register
 def floordiv(x, y, span):
     return tvm.tir.floordiv(x, y, span)
 
@@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
     block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
 
     iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
-    return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], 
span)
+    return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], 
span=span)
 
 
 @register
diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py
new file mode 100644
index 0000000..039eeb4
--- /dev/null
+++ b/python/tvm/script/node.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""TVM Script nodes."""
+
+from typing import Optional, Union, List, Callable
+import synr
+
+from tvm.runtime import ObjectGeneric
+from tvm.tir import PrimExpr, Buffer, BufferLoad
+from tvm.ir import Span
+
+
+class Slice:
+    """A helper class to present slice information for BufferSlice
+
+    Parameters
+    ----------
+    start : Union[PrimExpr, int]
+        The start index.
+
+    stop : Optional[Union[PrimExpr, int]]
+        The stop index, None means the Slice is an element-wise index
+
+    span : Optional[Span]
+        The location of the slice in the source.
+    """
+
+    start: Union[PrimExpr, int]
+    stop: Optional[Union[PrimExpr, int]]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        start: Union[PrimExpr, int],
+        stop: Optional[Union[PrimExpr, int]] = None,
+        span: Optional[Span] = None,
+    ):
+        self.start = start
+        self.stop = stop
+        self.span = span
+
+
+class BufferSlice(ObjectGeneric):
+    """A generic object for representing general buffer access. The following cases are supported:
+        - element-wise access buffer[i, j], which can be converted to BufferLoad if necessary
+        - slice access buffer[i: i + 1, j : j + 2]
+        - a mix of element and slice access buffer[i, j: j + 2]
+
+        This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer.
+
+    indices : List[Union[Slice, PrimExpr, int]]
+        The access indexes can be slice, PrimExpr or int.
+
+    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+        The error report func
+
+    span : Optional[Span]
+        The location of the buffer access in the source.
+    """
+
+    buffer: Buffer
+    slices: List[Slice]
+    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        buffer: Buffer,
+        indices: List[Union[Slice, PrimExpr, int]],
+        report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
+        span: Optional[Span] = None,
+    ):
+        def check_index(index: Union[int, PrimExpr]):
+            """ Check input index is non-negative integer or PrimExpr"""
+            if isinstance(index, int):
+                if index < 0:
+                    report_error("Negative index is not allowed during buffer 
access", span)
+            elif isinstance(index, PrimExpr):
+                if index.dtype != "int32":
+                    report_error(
+                        "index expected an int32 type PrimExpr but got " + 
str(index.dtype),
+                        index.span,
+                    )
+            else:
+                report_error(
+                    "Unsupported index type, expected int or tvm.tir.PrimExpr, 
but got "
+                    + str(type(index)),
+                    span,
+                )
+
+        slices: List[Slice] = []
+        for index in indices:
+            if isinstance(index, Slice):
+                check_index(index.start)
+                check_index(index.stop)
+                slices.append(index)
+            elif isinstance(index, (PrimExpr, int)):
+                check_index(index)
+                slices.append(Slice(index))
+            else:
+                report_error(
+                    "Unsupported index type for BufferSlice, "
+                    + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
+                    + str(type(index)),
+                    span,
+                )
+
+        self.buffer = buffer
+        self.slices = slices
+        self.report_error = report_error
+        self.span = span
+
+    def __str__(self):
+        regions: List[str] = []
+        for s in self.slices:
+            if s.stop is None:
+                regions.append(str(s.start))
+            else:
+                regions.append(str(s.start) + ": " + str(s.stop))
+
+        return self.buffer.name + "[" + ", ".join(regions) + "]"
+
+    def asobject(self) -> BufferLoad:
+        """Convert object."""
+        for s in self.slices:
+            if s.stop is not None:
+                self.report_error("BufferLoad only accepts elementwise 
access", self.span)
+
+        indices = [s.start for s in self.slices]
+        return BufferLoad(self.buffer, indices, span=self.span)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 33b0bab..8f6d338 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -24,6 +24,7 @@ use for error reporting.
 import json
 import operator
 import inspect
+from typing import Union
 from synr import ast, Transformer, to_ast
 
 import tvm
@@ -32,6 +33,7 @@ from tvm._ffi.base import TVMError
 from tvm.ir import GlobalVar
 
 from . import context_maintainer, ty
+from .context_maintainer import BlockInfo
 from .meta_unparser import MetaUnparser
 from .registry import Registry
 from .intrin import Intrin
@@ -39,7 +41,8 @@ from .special_stmt import SpecialStmt
 from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
 from . import _ffi_api
 from .diagnostics import TVMDiagnosticCtx
-from .utils import from_synr_span
+from .utils import tvm_span_from_synr, synr_span_from_tvm, 
call_with_error_reporting
+from .node import Slice, BufferSlice
 
 
 class CallArgumentReader(object):
@@ -158,7 +161,7 @@ class TVMScriptParser(Transformer):
 
     def init_function_parsing_env(self):
         """Initialize function parsing environment"""
-        self.context = context_maintainer.ContextMaintainer(self)  # scope 
emitter
+        self.context = context_maintainer.ContextMaintainer(self.report_error) 
 # scope emitter
 
     def init_meta(self, meta_dict):
         if meta_dict is not None:
@@ -182,7 +185,7 @@ class TVMScriptParser(Transformer):
 
         return transform_res
 
-    def report_error(self, message, span):
+    def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
         """Report an error occuring at a location.
 
         This just dispatches to synr's DiagnosticContext.
@@ -191,9 +194,11 @@ class TVMScriptParser(Transformer):
         ----------
         message : str
             Error message
-        span : synr.ast.Span
+        span : Union[synr.ast.Span, tvm.ir.Span]
             Location of the error
         """
+        if isinstance(span, tvm.ir.Span):
+            span = synr_span_from_tvm(span)
         self.error(message, span)
 
     def parse_body(self, parent):
@@ -221,7 +226,7 @@ class TVMScriptParser(Transformer):
             )
         else:
             return (
-                tvm.tir.SeqStmt(body, from_synr_span(ast.Span.union(spans)))
+                tvm.tir.SeqStmt(body, 
tvm_span_from_synr(ast.Span.union(spans)))
                 if len(body) > 1
                 else body[0]
             )
@@ -270,6 +275,13 @@ class TVMScriptParser(Transformer):
             internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), 
arg_name, default=default))
         if varargs is not None:
             internal_args.extend(reader.get_varargs(len(pos_only) + 
len(kwargs) + 1))
+        elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
+            self.report_error(
+                "Arguments mismatched. "
+                + f"Expected {len(pos_only) + len(kwargs)} args but got "
+                + f"{len(args) + len(kw_args)}",
+                node_call.span,
+            )
         return internal_args
 
     def parse_type(self, type_node, parent):
@@ -401,25 +413,52 @@ class TVMScriptParser(Transformer):
         """
 
         self.init_function_parsing_env()
-        self.context.new_scope(nodes=node.body.stmts)
+        self.context.enter_scope(nodes=node.body.stmts)
 
         # add parameters of function
         for arg in node.params:
             arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-            self.context.update_symbol(arg.name, arg_var)
+            self.context.update_symbol(arg.name, arg_var, node)
             self.context.func_params.append(arg_var)
 
-        # fetch the body and return a tir.PrimFunc
+        # New Scope : Implicit root block
+        # Each function contains an implicit root block in TensorIR,
+        # so here we need a block scope for it. Please note that `enter_block_scope`
+        # will not create a block directly but only stores some information.
+        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
+        # the root block will not be added. The logic to add the root block is in `_ffi_api.Complete`.
+        self.context.enter_block_scope(nodes=node.body.stmts)
+
+        # fetch the body of root block
+        body = self.parse_body(node.body)
+        # Emit Scope : Implicit root block
+        root_info: BlockInfo = self.context.current_block_scope()
+        self.context.exit_block_scope()
+
+        # return a tir.PrimFunc
+        dict_attr = self.context.func_dict_attr
         func = tvm.tir.PrimFunc(
             self.context.func_params,
-            self.parse_body(node.body),
+            body,
             ret_type=self.parse_type(node.ret_type, node),
             buffer_map=self.context.func_buffer_map,
-            attrs=tvm.ir.make_node("DictAttrs", **self.context.func_dict_attr),
-            span=from_synr_span(node.span),
+            attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else 
None,
+            span=tvm_span_from_synr(node.span),
+        )
+
+        # Fix the PrimFunc
+        # 1. generate root block if necessary
+        # 2. generate surrounding loops for blocks if necessary
+
+        func = call_with_error_reporting(
+            self.report_error,
+            node.span,
+            _ffi_api.Complete,
+            func,
+            root_info.alloc_buffers,
         )
 
-        self.context.pop_scope()
+        self.context.exit_scope()
         return func
 
     def transform_Assign(self, node):
@@ -470,12 +509,12 @@ class TVMScriptParser(Transformer):
                 var = tvm.te.var(
                     node.lhs.id.name,
                     self.parse_type(node.ty, node.lhs),
-                    span=from_synr_span(node.lhs.span),
+                    span=tvm_span_from_synr(node.lhs.span),
                 )
-                self.context.update_symbol(var.name, var)
+                self.context.update_symbol(var.name, var, node)
                 body = self.parse_body(node)
                 self.context.remove_symbol(var.name)
-                return tvm.tir.LetStmt(var, value, body, 
span=from_synr_span(node.span))
+                return tvm.tir.LetStmt(var, value, body, 
span=tvm_span_from_synr(node.span))
 
         self.report_error("Unsupported Assign stmt", node.span)
 
@@ -484,28 +523,28 @@ class TVMScriptParser(Transformer):
         symbol = self.transform(node.params[0])
         indexes = self.transform(node.params[1])
         rhs = self.transform(node.params[2])
-        rhs_span = from_synr_span(node.params[2].span)
+        rhs_span = tvm_span_from_synr(node.params[2].span)
         if isinstance(symbol, tvm.tir.Buffer):
             # BufferStore
             return tvm.tir.BufferStore(
                 symbol,
                 tvm.runtime.convert(rhs, span=rhs_span),
                 indexes,
-                span=from_synr_span(node.span),
+                span=tvm_span_from_synr(node.span),
             )
         else:
             if len(indexes) != 1:
                 self.report_error(
                     f"Store is only allowed with one index, but {len(indexes)} 
were provided.",
-                    Span.union([x.span for x in indexes]),
+                    tvm.ir.Span.union([x.span for x in indexes]),
                 )
             # Store
             return tvm.tir.Store(
                 symbol,
                 tvm.runtime.convert(rhs, span=rhs_span),
                 indexes[0],
-                tvm.runtime.convert(True, span=from_synr_span(node.span)),
-                span=from_synr_span(node.span),
+                tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)),
+                span=tvm_span_from_synr(node.span),
             )
 
     def transform_Assert(self, node):
@@ -520,7 +559,7 @@ class TVMScriptParser(Transformer):
         message = self.transform(node.msg)
         body = self.parse_body(node)
         return tvm.tir.AssertStmt(
-            condition, tvm.runtime.convert(message), body, 
span=from_synr_span(node.span)
+            condition, tvm.runtime.convert(message), body, 
span=tvm_span_from_synr(node.span)
         )
 
     def transform_For(self, node):
@@ -529,7 +568,8 @@ class TVMScriptParser(Transformer):
             For(expr target, expr iter, stmt* body, stmt* orelse, string? 
type_comment)
         By now 1 pattern of For is supported:
             1. for scope handler
-                for name in 
tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()
+                for name in 
tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()/tir.range()/
+                            tir.grid()/tir.thread_binding()
         """
 
         if not isinstance(node.rhs, ast.Call):
@@ -543,14 +583,14 @@ class TVMScriptParser(Transformer):
         old_lineno, old_col_offset = self.current_lineno, 
self.current_col_offset
         self.current_lineno = node.span.start_line
         self.current_col_offset = node.span.start_column
-        self.context.new_scope(nodes=node.body.stmts)
+        self.context.enter_scope(nodes=node.body.stmts)
         # for scope handler process the scope
         arg_list = self.parse_arg_list(func, node.rhs)
         func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
         func.body = self.parse_body(node)
         res = func.exit_scope(node, self.context, arg_list, 
node.rhs.func_name.span)
         # exit the scope
-        self.context.pop_scope()
+        self.context.exit_scope()
         self.current_lineno, self.current_col_offset = old_lineno, 
old_col_offset
         return res
 
@@ -561,9 +601,9 @@ class TVMScriptParser(Transformer):
             withitem = (expr context_expr, expr? optional_vars)
         By now 2 patterns of With is supported:
             1. with scope handler with symbol def
-                with tir.allocate() as targets:
+                with tir.block(*axes)/tir.allocate() as targets:
             2. with scope handler without symbol def
-                with tir.let()/tir.Assert()/tir.attr()//tir.realize()
+                with tir.let()/tir.Assert()/tir.attr()/tir.realize()
         """
 
         if not isinstance(node.rhs, ast.Call):
@@ -582,14 +622,14 @@ class TVMScriptParser(Transformer):
         old_lineno, old_col_offset = self.current_lineno, 
self.current_col_offset
         self.current_lineno = node.body.span.start_line
         self.current_col_offset = node.body.span.start_column
-        self.context.new_scope(nodes=node.body.stmts)
+        self.context.enter_block_scope(nodes=node.body.stmts)
         # with scope handler process the scope
         arg_list = self.parse_arg_list(func, node.rhs)
         func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
         func.body = self.parse_body(node)
         res = func.exit_scope(node, self.context, arg_list, 
node.rhs.func_name.span)
         # exit the scope
-        self.context.pop_scope()
+        self.context.exit_block_scope()
         self.current_lineno, self.current_col_offset = old_lineno, 
old_col_offset
         return res
 
@@ -601,19 +641,21 @@ class TVMScriptParser(Transformer):
 
         condition = self.transform(node.condition)
         # then body
-        self.context.new_scope(nodes=node.true.stmts)
+        self.context.enter_scope(nodes=node.true.stmts)
         then_body = self.parse_body(node)
-        self.context.pop_scope()
+        self.context.exit_scope()
 
         # else body
         if len(node.false.stmts) > 0:
-            self.context.new_scope(nodes=node.false.stmts)
+            self.context.enter_scope(nodes=node.false.stmts)
             else_body = self.parse_body(node)
-            self.context.pop_scope()
+            self.context.exit_scope()
         else:
             else_body = None
 
-        return tvm.tir.IfThenElse(condition, then_body, else_body, 
span=from_synr_span(node.span))
+        return tvm.tir.IfThenElse(
+            condition, then_body, else_body, span=tvm_span_from_synr(node.span)
+        )
 
     def transform_Call(self, node):
         """Call visitor
@@ -633,18 +675,26 @@ class TVMScriptParser(Transformer):
                 lhs = self.transform(node.params[0])
                 rhs = self.transform(node.params[1])
                 return self._binop_maker[node.func_name.name](
-                    lhs, rhs, span=from_synr_span(node.span)
+                    lhs, rhs, span=tvm_span_from_synr(node.span)
                 )
             if node.func_name.name in self._unaryop_maker:
                 rhs = self.transform(node.params[0])
-                return self._unaryop_maker[node.func_name.name](rhs, 
span=from_synr_span(node.span))
+                return self._unaryop_maker[node.func_name.name](
+                    rhs, span=tvm_span_from_synr(node.span)
+                )
             self.report_error(f"Unsupported operator {node.func_name.name}.", 
node.func_name.span)
         else:
             func = self.transform(node.func_name)
             if isinstance(func, Intrin) and not func.stmt:
                 # pattern 1
                 arg_list = self.parse_arg_list(func, node)
-                return func.handle(arg_list, node.func_name.span)
+                return call_with_error_reporting(
+                    self.report_error,
+                    node.func_name.span,
+                    func.handle,
+                    arg_list,
+                    node.func_name.span,
+                )
             else:
                 args = [self.transform(arg) for arg in node.params]
                 kw_args = {
@@ -653,7 +703,7 @@ class TVMScriptParser(Transformer):
                 if isinstance(func, tvm.tir.op.Op):
                     # pattern 2
                     return tvm.tir.Call(
-                        kw_args["dtype"], func, args, 
span=from_synr_span(node.span)
+                        kw_args["dtype"], func, args, 
span=tvm_span_from_synr(node.span)
                     )
                 elif callable(func):
                     # pattern 3
@@ -700,7 +750,13 @@ class TVMScriptParser(Transformer):
             )
 
         if isinstance(func, Intrin) and func.stmt:
-            return func.handle(arg_list, node.call.func_name.span)
+            return call_with_error_reporting(
+                self.report_error,
+                node.call.func_name.span,
+                func.handle,
+                arg_list,
+                node.call.func_name.span,
+            )
         elif isinstance(func, WithScopeHandler) and func.concise_scope and not 
func.def_symbol:
             func.enter_scope(node, self.context, arg_list, 
node.call.func_name.span)
             func.body = self.parse_body(node)
@@ -716,11 +772,7 @@ class TVMScriptParser(Transformer):
         end = self.transform(node.end)
         if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
             self.report_error("Only step size 1 is supported for slices.", 
node.step.span)
-        extent = end - start
-        if isinstance(extent, tvm.tir.PrimExpr):
-            ana = tvm.arith.Analyzer()
-            extent = ana.simplify(extent)
-        return tvm.ir.Range.from_min_extent(start, extent, 
span=from_synr_span(node.span))
+        return Slice(start, end)
 
     def transform_Subscript(self, node):
         """Array access visitor.
@@ -728,7 +780,7 @@ class TVMScriptParser(Transformer):
         By now only 2 types of Subscript are supported:
             1. Buffer[index, index, ...], Buffer element access(BufferLoad & 
BufferStore)
                Var[index] Buffer element access()
-            2. meta[type_key][index], Meta info access
+            2. Buffer[start: stop, start: stop, ...], 
BufferRealize(realize(buffer[...]))
         """
 
         symbol = self.transform(node.params[0])
@@ -736,19 +788,27 @@ class TVMScriptParser(Transformer):
             self.report_error(f"Variable {node.value.id} is not defined.", 
node.params[0].span)
 
         indexes = [self.transform(x) for x in node.params[1].values]
-        if isinstance(indexes[0], tvm.ir.Range):
-            return symbol, indexes
-
         if isinstance(symbol, tvm.tir.expr.Var):
-            return tvm.tir.Load("float32", symbol, indexes, True, 
span=from_synr_span(node.span))
-        if isinstance(symbol, tvm.tir.Buffer):
-            return tvm.tir.BufferLoad(symbol, indexes, 
span=from_synr_span(node.span))
-
-        self.report_error(
-            f"Cannot subscript from a {type(symbol).__name__}. Only variables 
and "
-            "buffers are supported.",
-            node.params[0].span,
-        )
+            for index in indexes:
+                if not isinstance(index, (tvm.tir.PrimExpr, int)):
+                    self.report_error(
+                        "Buffer load indexes should be int or PrimExpr, but 
they are "
+                        + type(index),
+                        node.span,
+                    )
+            return tvm.tir.Load(
+                "float32", symbol, indexes, True, 
span=tvm_span_from_synr(node.span)
+            )
+        elif isinstance(symbol, tvm.tir.Buffer):
+            return BufferSlice(
+                symbol, indexes, self.report_error, 
span=tvm_span_from_synr(node.span)
+            )
+        else:
+            self.report_error(
+                f"Cannot subscript from a {type(symbol).__name__}. Only 
variables and "
+                "buffers are supported.",
+                node.params[0].span,
+            )
 
     def transform_Attr(self, node):
         """Visitor for field access of the form `x.y`.
@@ -756,7 +816,7 @@ class TVMScriptParser(Transformer):
         This visitor is used to lookup function and symbol names. We have two
         cases to handle here:
         1. If we have a statement of the form `tir.something`, then we lookup
-           `tir.somthing` in the `Registry`. If the function is not in the
+           `tir.something` in the `Registry`. If the function is not in the
            registry, then we try to find a `tvm.ir.op.Op` with the same name.
         2. All other names `tvm.something` are lookup up in this current python
            namespace.
@@ -875,7 +935,7 @@ class TVMScriptParser(Transformer):
         Constant values include `None`, `"strings"`, `2` (integers), `4.2`
         (floats), and `true` (booleans).
         """
-        return tvm.runtime.convert(node.value, span=from_synr_span(node.span))
+        return tvm.runtime.convert(node.value, 
span=tvm_span_from_synr(node.span))
 
     def transform_TypeConstant(self, node):
         """Constant value visitor for types.
@@ -902,8 +962,7 @@ def from_source(src):
     ----------
     src : [str, function, class]
         Pruned source of original script
-    func_lineno : Optional[int]
-        The line number of the first line of the script to be parsed
+
     Returns
     -------
     functions : PrimFunc or IRModule
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
index 3895701..245cc01 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/registry.py
@@ -16,7 +16,8 @@
 # under the License.
 """TVM Script Parser Function Registry """
 # pylint: disable=inconsistent-return-statements, relative-beyond-top-level, 
import-outside-toplevel
-import inspect
+import types
+from typing import Union, Callable, Dict, Optional, Any
 
 
 class Registry(object):
@@ -24,10 +25,10 @@ class Registry(object):
     All these maps are static
     """
 
-    registrations = dict()
+    registrations: Dict[str, type] = dict()
 
     @staticmethod
-    def lookup(name):
+    def lookup(name: str) -> Optional[Any]:
         if name in Registry.registrations:
             # every time we create a new handler
             # since we may want to keep some local info inside it
@@ -35,12 +36,14 @@ class Registry(object):
         return None
 
 
-def register(inputs):
+def register(inputs: Union[Callable, type]) -> type:
     """Register Intrin/ScopeHandler/SpecialStmt"""
-    if inspect.isfunction(inputs):
+    registration: type
+    if isinstance(inputs, types.FunctionType):
+        # is function
         from .intrin import Intrin
 
-        def create_new_intrin(func):
+        def create_new_intrin(func) -> type:
             class NewIntrin(Intrin):
                 def __init__(self):
                     super().__init__(func)
@@ -48,11 +51,12 @@ def register(inputs):
             return NewIntrin
 
         registration = create_new_intrin(inputs)
-    elif inspect.isclass(inputs):
+    elif isinstance(inputs, type):
+        # is class
         registration = inputs
     else:
         raise ValueError()
 
-    key = registration().signature()[0]
+    key: str = registration().signature()[0]
     Registry.registrations[key] = registration
     return registration
diff --git a/python/tvm/script/scope_handler.py 
b/python/tvm/script/scope_handler.py
index 9449cbd..c7d841a 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -16,32 +16,59 @@
 # under the License.
 """TVM Script Parser Scope Handler Classes"""
 # pylint: disable=redefined-builtin, unused-argument, invalid-name, 
relative-beyond-top-level
+from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
 
+import synr
 from synr import ast
 import tvm.tir
-from .utils import get_param_list, from_synr_span
+from tvm.runtime import Object
+from tvm.ir import Span, Range
+from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
+
+from .context_maintainer import ContextMaintainer
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .node import BufferSlice
 
 
 class ScopeHandler:
     """Base class for all scope handlers"""
 
-    def __init__(self, func):
-        self.func = func
-        self.body = None
-        self.node = None
-        self.context = None
+    def __init__(self, func: Callable):
+        self.func: Callable = func
+        self.body: Optional[Stmt] = None
+        self.node: Optional[synr.ast.Node] = None
+        self.context: Optional[ContextMaintainer] = None
 
-    def signature(self):
+    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
         return "tir." + self.func.__name__, get_param_list(self.func)
 
-    def enter_scope(self, node, context, arg_list, span):
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
         pass
 
-    def exit_scope(self, node, context, arg_list, span):
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
         self.node = node
         self.context = context
-        return self.func(*arg_list, span=from_synr_span(span))
+        return call_with_error_reporting(
+            context.report_error, span, self.func, *arg_list, 
span=tvm_span_from_synr(span)
+        )
 
 
 class WithScopeHandler(ScopeHandler):
@@ -55,24 +82,29 @@ class WithScopeHandler(ScopeHandler):
     @staticmethod
     def get_optional_var_names(node, context):
         """Get list of names from ast.With's optional_vars"""
-        assert isinstance(node, ast.With)
-
-        var_names = None
-        if isinstance(node.items[0].optional_vars, ast.Name):
-            var_names = [node.items[0].optional_vars.id]
-        elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
-            for var in node.items[0].optional_vars.elts:
-                if not isinstance(var, ast.Name):
-                    context.report_error("Invalid optional var definition")
-            var_names = [var.id for var in node.items[0].optional_vars.elts]
+        assert isinstance(
+            node, ast.With
+        ), f"WithScopeHandler expected ast.With but got {type(node)}"
+
+        if isinstance(node.lhs, list):
+            for var in node.lhs:
+                if not isinstance(var, ast.Var):
+                    context.report_error(
+                        f"Invalid optional var definition, expected Var but 
got {type(var)}",
+                        node.span,
+                    )
+            var_names = [var.id.name for var in node.lhs]
         else:
-            context.report_error("Invalid optional var definition")
+            context.report_error(
+                f"Invalid optional var definition, expected list of Var but 
got {type(node.lhs)}",
+                node.span,
+            )
         return var_names
 
 
 @register
 class Allocate(WithScopeHandler):
-    """ With scope handler tir.alloc_with_scope(var, extents, dtype, scope, 
condition) """
+    """ With scope handler tir.allocate(extents, dtype, scope, condition) """
 
     def __init__(self):
         def allocate(extents, dtype, scope, condition=True, span=None):
@@ -86,7 +118,13 @@ class Allocate(WithScopeHandler):
         super().__init__(allocate, concise_scope=True, def_symbol=True)
         self.buffer_var = None
 
-    def enter_scope(self, node, context, arg_list, span):
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
         # define buffer vars in symbol table
         if isinstance(node, ast.With):
             names = WithScopeHandler.get_optional_var_names(node, context)
@@ -98,13 +136,13 @@ class Allocate(WithScopeHandler):
         else:
             raise Exception("Internal Bug")
 
-        def setup_buffer_var(extents, dtype, scope, condition=True, span=None):
+        def setup_buffer_var(extents, dtype, scope, condition=True, span: Span 
= None):
             """Setup buffer var for a given type."""
             buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
             self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
 
-        setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span))
-        context.update_symbol(name, self.buffer_var)
+        setup_buffer_var(*arg_list, span=tvm_span_from_synr(node.lhs.id.span))
+        context.update_symbol(name, self.buffer_var, node)
 
 
 @register
@@ -115,10 +153,10 @@ class LaunchThread(WithScopeHandler):
         def launch_thread(env_var, extent, span):
             extent = tvm.runtime.convert(extent, span=span)
             return tvm.tir.AttrStmt(
-                tvm.tir.IterVar(
+                IterVar(
                     None,
                     env_var,
-                    getattr(tvm.tir.IterVar, "ThreadIndex"),
+                    getattr(IterVar, "ThreadIndex"),
                     self.context.func_var_env_dict[env_var],
                     span=span,
                 ),
@@ -136,8 +174,19 @@ class Realize(WithScopeHandler):
     """ With scope handler tir.realize(buffer_bounds, scope, condition) """
 
     def __init__(self):
-        def realize(buffer_bounds, scope, condition=True, span=None):
-            buffer, bounds = buffer_bounds
+        def realize(
+            buffer_slice: BufferSlice, scope: str, condition: bool = True, 
span: bool = None
+        ):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            buffer: Buffer = buffer_slice.buffer
+            bounds: List[Range] = []
+            for s in buffer_slice.slices:
+                min: Union[PrimExpr, int] = s.start
+                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop 
- s.start
+                if isinstance(extent, PrimExpr):
+                    extent = self.context.analyzer.simplify(extent)
+                bounds.append(Range.from_min_extent(min, extent, span=s.span))
+
             scope = tvm.runtime.convert(scope, span=span)
             return tvm.tir.AttrStmt(
                 buffer,
@@ -185,92 +234,380 @@ class Let(WithScopeHandler):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(axes, name_hint) as block_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
+        self.loop_vars: Optional[List[Var]] = None
 
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(node, ast.For), f"ForScopeHandler expected ast.For 
but got {type(node)}"
 
         loop_var_names = list()
         spans = list()
         if isinstance(node.lhs, ast.Var):
             loop_var_names.append(node.lhs.id.name)
-            spans.append(from_synr_span(node.lhs.id.span))
-        elif isinstance(node.lhs, ast.Tuple):
-            for elt in node.lhs.values:
+            spans.append(tvm_span_from_synr(node.lhs.id.span))
+        elif isinstance(node.lhs, list):
+            for elt in node.lhs:
                 if not isinstance(elt, ast.Var):
-                    context.report_error("Invalid loop var", elt.span)
+                    context.report_error(
+                        f"Invalid loop var. Expected a var, but got 
{type(elt)}", elt.span
+                    )
                 loop_var_names.append(elt.id.name)
-                spans.append(from_synr_span(elt.id.span))
+                spans.append(tvm_span_from_synr(elt.id.span))
         else:
-            context.report_error("Invalid loop var", node.lhs.span)
+            context.report_error(
+                f"Invalid loop var. Expected var or list of vars as lhs, but 
got {type(node.lhs)}",
+                span,
+            )
 
         self.loop_vars = [
             tvm.te.var(name, dtype="int32", span=span) for name, span in 
zip(loop_var_names, spans)
         ]
         for loop_var in self.loop_vars:
-            context.update_symbol(loop_var.name, loop_var)
+            context.update_symbol(loop_var.name, loop_var, node)
+            context.loop_stack[-1].append(loop_var)
+
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+        for _ in self.loop_vars:
+            context.loop_stack[-1].pop()
+        return super().exit_scope(node, context, arg_list, span)
+
+    def create_loop(
+        self,
+        begin: PrimExpr,
+        end: PrimExpr,
+        kind: ForKind,
+        thread_binding: Optional[str] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> tvm.tir.For:
+        """
+        Helper function for creating For in TVM Script parser.
+
+        Parameters
+        ----------
+        begin : PrimExpr
+            The beginning value.
+
+        end : PrimExpr
+            The ending value.
+
+        kind : ForKind
+            The type of the for.
+
+        thread_binding: Optional[str]
+            The thread this loop binds to.
+
+        annotations : Optional[Mapping[str, Object]]
+            Additional annotation hints.
+
+        span : Optional[Span]
+            The location of this for in the source code.
+
+        Returns
+        -------
+        for : For
+            The constructed For.
+        """
+        assert (
+            self.loop_vars and self.context and self.node
+        ), "call 'exit_scope' before 'enter_scope'"
+        if len(self.loop_vars) != 1:
+            self.context.report_error(
+                f"Expected exactly one loop var, but got {self.loop_vars}", 
self.node.span
+            )
+        extent = end if begin == 0 else self.context.analyzer.simplify(end - 
begin)
+        annos: Mapping[str, Object] = {}
+        if annotations is not None:
+            annos = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in annotations.items()
+            }
+        return tvm.tir.For(
+            self.loop_vars[0],
+            begin,
+            extent,
+            kind,
+            self.body,
+            thread_binding=thread_binding,
+            annotations=annos,
+            span=span,
+        )
 
 
 @register
 class Serial(ForScopeHandler):
-    """ For scope handler tir.serial(begin, end)"""
+    """ For scope handler tir.serial(begin, end, annotations)"""
 
     def __init__(self):
-        def serial(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var", span)
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, 
span=span)
+        def serial(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(begin, end, ForKind.SERIAL, 
annotations=annotations, span=span)
 
         super().__init__(serial)
 
 
 @register
 class Parallel(ForScopeHandler):
-    """ For scope handler tir.parallel(begin, end)"""
+    """ For scope handler tir.parallel(begin, end, annotations)"""
 
     def __init__(self):
-        def parallel(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, 
span=span)
+        def parallel(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.PARALLEL, annotations=annotations, 
span=span
+            )
 
         super().__init__(parallel)
 
 
 @register
 class Vectorized(ForScopeHandler):
-    """ For scope handler tir.vectorized(begin, end)"""
+    """ For scope handler tir.vectorized(begin, end, annotations)"""
 
     def __init__(self):
-        def vectorized(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, 
span=span)
+        def vectorized(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.VECTORIZED, annotations=annotations, 
span=span
+            )
 
         super().__init__(vectorized)
 
 
 @register
 class Unroll(ForScopeHandler):
-    """ For scope handler tir.unroll(begin, end)"""
+    """ For scope handler tir.unroll(begin, end, annotations)"""
 
     def __init__(self):
-        def unroll(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, 
span=span)
+        def unroll(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.UNROLLED, annotations=annotations, 
span=span
+            )
 
         super().__init__(unroll)
+
+
+@register
+class ThreadBinding(ForScopeHandler):
+    """ For scope handler tir.thread_binding(begin, end, thread, 
annotations)"""
+
+    def __init__(self):
+        def thread_binding(
+            begin: PrimExpr,
+            end: PrimExpr,
+            thread: str,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread, 
span=span)
+            return self.create_loop(
+                begin,
+                end,
+                ForKind.THREAD_BINDING,
+                thread_binding=thread_iter_var,
+                annotations=annotations,
+                span=span,
+            )
+
+        super().__init__(thread_binding)
+
+
+@register
+class RangeHandler(ForScopeHandler):
+    """For scope handler range(begin, end, annotations)
+    Note that range is equivalent to tir.serial: both create a serial (ForKind.SERIAL) loop
+    """
+
+    def __init__(self):
+        def for_range(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(begin, end, ForKind.SERIAL, 
annotations=annotations, span=span)
+
+        super().__init__(for_range)
+
+    def signature(self):
+        return "range", get_param_list(self.func)
+
+
+@register
+class Grid(ForScopeHandler):
+    """ For scope handler tir.grid(extents)"""
+
+    def __init__(self):
+        def grid(*extents: List[PrimExpr], span: Span):
+            assert (
+                self.node and self.context and self.loop_vars
+            ), "call 'exit_scope' before 'enter_scope'"
+            if len(self.loop_vars) != len(extents):
+                self.context.report_error(
+                    "Inconsistent number of loop vars and extents, "
+                    + f"got {len(self.loop_vars)} vs {len(extents)}",
+                    self.node.span,
+                )
+            body = self.body
+            for loop_var, extent in zip(reversed(self.loop_vars), 
reversed(extents)):
+                body = tvm.tir.For(loop_var, 0, extent, ForKind.SERIAL, body, 
span=span)
+            return body
+
+        super().__init__(grid)
diff --git a/python/tvm/script/special_stmt.py 
b/python/tvm/script/special_stmt.py
index 62ce1ea..6aa1239 100644
--- a/python/tvm/script/special_stmt.py
+++ b/python/tvm/script/special_stmt.py
@@ -17,30 +17,81 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, 
inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
+from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+
+import synr
 from synr import ast
 
 import tvm.tir
+from tvm.runtime import Object
 from tvm import te
-from .utils import get_param_list, from_synr_span
+from tvm.ir import Span
+from tvm.tir import IntImm
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .context_maintainer import ContextMaintainer
+from .node import BufferSlice
+
+
+def convert_to_int(
+    value: Union[IntImm, int],
+    arg_name: str,
+    report_error: Callable,
+    span: Union[Span, synr.ast.Span],
+) -> int:
+    """Return ``value`` as a plain Python int.
+
+    Python ints are returned unchanged and ``tvm.tir.IntImm`` constants are
+    unwrapped through their ``.value``. Any other input is passed to
+    ``report_error``.
+
+    Parameters
+    ----------
+    value : Union[tvm.tir.IntImm, int]
+        The value to convert.
+    arg_name : str
+        Name of the argument being converted; used in the error message.
+    report_error : Callable
+        Error reporting callback, invoked as ``report_error(message, span)``.
+    span : Union[synr.ast.Span, tvm.ir.Span]
+        Source location attached to the error.
+
+    Returns
+    -------
+    result : int
+        The converted integer.
+    """
+    if isinstance(value, int):
+        return value
+    if isinstance(value, IntImm):
+        return value.value
+    report_error(
+        f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
+        span,
+    )
 
 
 class SpecialStmt:
     """Base class for all Special Stmts"""
 
-    def __init__(self, func, def_symbol):
-        self.func = func
-        self.def_symbol = def_symbol
-        self.node = None
-        self.context = None
+    def __init__(self, func: Callable, def_symbol: bool):
+        # func: Python function implementing the stmt; its parameters define the
+        #   script-level signature (see `signature`).
+        # def_symbol: True for stmts that bind their assignment target as a new
+        #   symbol (the subclasses passing True call context.update_symbol).
+        self.func: Callable = func
+        self.def_symbol: bool = def_symbol
+        # node/context are filled in by `handle` right before `func` is invoked.
+        self.node: Optional[synr.ast.Node] = None
+        self.context: Optional[ContextMaintainer] = None
 
-    def signature(self):
+    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
+        # Stmts are exposed to scripts as "tir.<func name>".
         return "tir." + self.func.__name__, get_param_list(self.func)
 
-    def handle(self, node, context, arg_list, span):
+    def handle(
+        self,
+        node: ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        """Invoke the stmt function for the parsed call ``node``.
+
+        Non-diagnostic exceptions raised by ``func`` are reported through
+        ``context.report_error`` at ``span`` (see ``call_with_error_reporting``).
+        """
         self.node = node
         self.context = context
-        return self.func(*arg_list, span=from_synr_span(span))
+        return call_with_error_reporting(
+            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
+        )
 
 
 @register
@@ -67,17 +118,20 @@ class MatchBuffer(SpecialStmt):
             buffer_type="default",
             span=None,
         ):
-            assert isinstance(self.node, ast.Assign)
-
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "match_buffer must be assigned to a buffer, e.g. A = 
match_buffer(...)",
+                    self.node.span,
+                )
             if param not in self.context.func_params:
                 self.context.report_error(
                     "Can not bind non-input param to buffer", 
self.node.rhs.params[0].span
                 )
             if strides is None:
                 strides = []
-            align = align.value if not isinstance(align, int) else align
-            offset_factor = (
-                offset_factor.value if not isinstance(offset_factor, int) else 
offset_factor
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
             )
             buffer = tvm.tir.decl_buffer(
                 shape,
@@ -93,7 +147,7 @@ class MatchBuffer(SpecialStmt):
                 span=span,
             )
             self.context.func_buffer_map[param] = buffer
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
 
         super().__init__(match_buffer, def_symbol=True)
 
@@ -121,13 +175,17 @@ class BufferDeclare(SpecialStmt):
             buffer_type="default",
             span=None,
         ):
-            assert isinstance(self.node, ast.Assign)
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "buffer_decl must be assigned to a buffer, e.g. A = 
buffer_decl(...)",
+                    self.node.span,
+                )
 
             if strides is None:
                 strides = []
-            align = align.value if not isinstance(align, int) else align
-            offset_factor = (
-                offset_factor.value if not isinstance(offset_factor, int) else 
offset_factor
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
             )
             buffer = tvm.tir.decl_buffer(
                 shape,
@@ -142,21 +200,293 @@ class BufferDeclare(SpecialStmt):
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
 @register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
+                                     offset_factor, buffer_type)
+
+    Declares a buffer named after the assignment target and appends it to the
+    ``alloc_buffers`` of the current block scope.
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            # The buffer name comes from the assignment target, so a bare call is an error.
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            # Register the allocation with the enclosing block, then bind the script name.
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Records the value bound to a block iter var in the current block scope.
+    Binding the same iter var twice is reported as an error.
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Declares the read regions of the current block; at most one declaration
+    per block. A single BufferSlice is accepted and wrapped into a list.
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate read region declaration, "
+                    + "previous one is "
+                    + ", ".join(str(x) for x in block_scope.reads),
+                    span,
+                )
+            if isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            if not isinstance(read_regions, list):
+                self.context.report_error(
+                    "Incorrect input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got {type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Declares the write regions of the current block; at most one declaration
+    per block. A single BufferSlice is accepted and wrapped into a list.
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj]])
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + ", ".join(str(x) for x in block_scope.writes),
+                    span,
+                )
+            if isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            if not isinstance(write_region, list):
+                self.context.report_error(
+                    "Incorrect input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got {type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Sets the annotations of the current block; at most one declaration per
+    block. Python string values are wrapped as tvm.tir.StringImm.
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            scope = self.context.current_block_scope()
+            previous = scope.annotations
+            if previous is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, previous one is " + str(previous),
+                    span,
+                )
+            converted = {}
+            for key, val in attrs.items():
+                converted[key] = tvm.tir.StringImm(val) if isinstance(val, str) else val
+            scope.annotations = converted
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Sets the predicate of the current block; at most one declaration per block.
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            scope = self.context.current_block_scope()
+            existing = scope.predicate
+            if existing is not None:
+                message = "Duplicate block predicate declaration, previous one is " + str(existing)
+                self.context.report_error(message, span)
+            scope.predicate = predicate
+
+        super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+    """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor)
+
+    Declares a new buffer that matches the region ``source`` of an existing
+    buffer. The match is recorded in the current block scope's ``match_buffers``.
+
+    Example
+    -------
+    .. code-block:: python
+
+        B = tir.match_buffer_region(A[0: 4])
+    """
+
+    def __init__(self):
+        def match_buffer_region(
+            source,
+            strides=None,
+            elem_offset=None,
+            align=-1,
+            offset_factor=0,
+            span=None,
+        ):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "match_buffer_region must be assigned to a buffer, "
+                    + "e.g. A = match_buffer_region(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+
+            if not isinstance(source, BufferSlice):
+                self.context.report_error(
+                    "match_buffer_region needs a buffer region as source",
+                    span=span,
+                )
+            # The new buffer's shape is the per-dimension extent of the matched region;
+            # dtype and scope are inherited from the source buffer.
+            buffer_region = buffer_slice_to_region(source)
+            shape = [r.extent for r in buffer_region.region]
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                buffer_region.buffer.dtype,
+                self.node.lhs.id.name,
+                data=None,
+                strides=strides,
+                elem_offset=elem_offset,
+                scope=buffer_region.buffer.scope,
+                data_alignment=align,
+                offset_factor=offset_factor,
+                span=span,
+            )
+            self.context.current_block_scope().match_buffers.append(
+                tvm.tir.MatchBufferRegion(buffer, buffer_region)
+            )
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
+
+        super().__init__(match_buffer_region, def_symbol=True)
+
+
+@register
 class VarDef(SpecialStmt):
     """ Special function for defining a Var"""
 
     def __init__(self):
         def var(dtype, span):
-            assert isinstance(self.node, ast.Assign)
+            # Handles `x = tir.var(dtype)`: creates a Var named after the assignment
+            # target and registers it under that name via context.update_symbol.
+            assert isinstance(
+                self.node, ast.Assign
+            ), f"VarDef expected ast.Assign but got {type(self.node)}"
             v = te.var(self.node.lhs.id.name, dtype, span=span)
-            self.context.update_symbol(v.name, v)
+            self.context.update_symbol(v.name, v, self.node)
 
         super().__init__(var, def_symbol=True)
 
@@ -167,10 +497,12 @@ class EnvThread(SpecialStmt):
 
     def __init__(self):
         def env_thread(env_name, span):
-            assert isinstance(self.node, ast.Assign)
+            # Handles `tx = tir.env_thread(env_name)`: creates a Var named after the
+            # assignment target, records env_name for it in func_var_env_dict, and
+            # registers it via context.update_symbol.
+            assert isinstance(
+                self.node, ast.Assign
+            ), f"EnvThread expected ast.Assign but got {type(self.node)}"
             v = te.var(self.node.lhs.id.name, span=span)
             self.context.func_var_env_dict[v] = env_name
-            self.context.update_symbol(v.name, v)
+            self.context.update_symbol(v.name, v, self.node)
 
         super().__init__(env_thread, def_symbol=True)
 
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
index a6ba9d0..f8a0f61 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/utils.py
@@ -16,15 +16,32 @@
 # under the License.
 """Helper functions in TVM Script Parser"""
 
+from typing import Callable, List, Any, Optional, Tuple, Union
+
 import inspect
-from ..ir import Span, SourceName
+import synr
+
+from tvm.arith import Analyzer
+from tvm.ir import Range, Span, SourceName
+from tvm.tir import PrimExpr, BufferRegion
+from tvm.error import DiagnosticError
+from .node import BufferSlice
 
 
-def get_param_list(func):
+def get_param_list(
+    func: Callable,
+) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
     """Get the parameter list from definition of function"""
-    full_arg_spec = inspect.getfullargspec(func)
+    full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
 
-    args, defaults = full_arg_spec.args, full_arg_spec.defaults
+    args: List[str]
+    defaults: Optional[Tuple[Any, ...]]
+    kwonlyargs: List[str]
+    args, defaults, kwonlyargs = (
+        full_arg_spec.args,
+        full_arg_spec.defaults,
+        full_arg_spec.kwonlyargs,
+    )
 
     if defaults is None:
         defaults = tuple()
@@ -33,14 +50,17 @@ def get_param_list(func):
         raise RuntimeError(
             "TVM Script register error : variable keyword argument is not 
supported now"
         )
-    if not len(full_arg_spec.kwonlyargs) == 0:
+
+    if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
+        pass
+    elif not len(kwonlyargs) == 0:
         raise RuntimeError("TVM Script register error : keyword only argument 
is not supported now")
 
-    pos_only = list()
+    pos_only: List[str] = list()
     for arg in args[: len(args) - len(defaults)]:
         if arg != "span":
             pos_only.append(arg)
-    kwargs = list()
+    kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
     for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
         if arg != "span":
             kwargs.append((arg, default))
@@ -48,7 +68,37 @@ def get_param_list(func):
     return pos_only, kwargs, full_arg_spec.varargs
 
 
-def from_synr_span(span):
+def buffer_slice_to_region(
+    buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None
+) -> BufferRegion:
+    """Construct BufferRegion from BufferSlice
+
+    Each slice ``start:stop`` becomes ``Range.from_min_extent(start, stop - start)``.
+    A slice whose ``stop`` is None becomes a range of extent 1. Symbolic extents
+    are simplified with ``analyzer``.
+
+    Parameters
+    ----------
+    buffer_slice : BufferSlice
+        The input BufferSlice
+
+    analyzer : Optional[tvm.arith.Analyzer]
+        The analyzer for simplifying. If not provided, the method will construct a new one
+
+    Returns
+    -------
+    buffer_region : BufferRegion
+        The constructed BufferRegion.
+    """
+    # Create the analyzer once up front rather than re-checking inside the loop.
+    if analyzer is None:
+        analyzer = Analyzer()
+    region: List[Range] = []
+    for s in buffer_slice.slices:
+        start: Union[PrimExpr, int] = s.start
+        extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
+        if isinstance(extent, PrimExpr):
+            extent = analyzer.simplify(extent)
+        region.append(Range.from_min_extent(start, extent, span=s.span))
+    return BufferRegion(buffer_slice.buffer, region)
+
+
+def tvm_span_from_synr(span: synr.ast.Span) -> Span:
     """Convert a synr span to a TVM span"""
     return Span(
         SourceName(span.filename),
@@ -57,3 +107,32 @@ def from_synr_span(span):
         span.start_column,
         span.end_column,
     )
+
+
+def synr_span_from_tvm(span: Span) -> synr.ast.Span:
+    """Convert a TVM span to a synr span.
+
+    Copies the source file name plus the start and end (line, column) positions.
+    """
+    filename = span.source_name.name
+    start = (span.line, span.column)
+    end = (span.end_line, span.end_column)
+    return synr.ast.Span(filename, *start, *end)
+
+
+def call_with_error_reporting(
+    report_error: Callable,
+    node_span: Any,
+    func: Callable,
+    *args,
+    **kwargs,
+):
+    """Call ``func(*args, **kwargs)`` and report failures through ``report_error``.
+
+    DiagnosticError propagates unchanged. Any other exception is turned into
+    ``report_error(message, node_span)``, where ``message`` is the last non-empty
+    line of the exception text, or the exception's type name if the text is empty.
+    """
+    try:
+        return func(*args, **kwargs)
+    except DiagnosticError:
+        # Presumably already produced by report_error with its own location; don't re-wrap.
+        raise
+    except Exception as err:  # pylint: disable=broad-except
+        # Report the last non-empty line of the message. An empty message (e.g. a bare
+        # `raise ValueError()`) previously caused an IndexError that hid the real error.
+        lines = [line for line in str(err).split("\n") if line]
+        error_msg = lines[-1] if lines else type(err).__name__
+        report_error(error_msg, node_span)
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 1a3eb48..829eb8b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
         The result of verification.
     """
     return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+    """Detect which regions of tensors in this block are read or written to.
+
+    Regions are sorted by order of appearance in the AST.
+
+    Parameters
+    ----------
+    block: tvm.tir.Block
+        The block in which we are detecting read/write regions.
+
+    buffer_var_map : Dict[Var, Buffer]
+        The buffers defined outside the block that the block may access,
+        as a mapping from buffer var to the buffer.
+
+    Returns
+    -------
+    result : List[List[BufferRegion]]
+        Array of access regions. There are three arrays of BufferRegion:
+            - first: read regions
+            - second: write regions
+            - third: opaque regions
+    """
+    return _ffi_api.get_block_access_region(block, buffer_var_map)
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 8d5bba5..7880740 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -476,8 +476,7 @@ inline const char* ForKind2String(ForKind t) {
     case ForKind::kUnrolled:
       return "unroll";
     case ForKind::kThreadBinding:
-      LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
-                 << "not yet supported in TIR";
+      return "thread_binding";
   }
   LOG(FATAL) << "Unknown ForKind";
   return "Unknown";
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 86b175e..4380795 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -22,6 +22,7 @@
  * \brief Printer class to print Tensor IR to python syntax script
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/module.h>
 #include <tvm/node/serialization.h>
 #include <tvm/runtime/registry.h>
@@ -66,7 +67,10 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
   /*! \brief var collector (var defined by For/Loop/Block) */
   std::unordered_set<const VarNode*> var_not_in_headers;
-  /*! \brief buffer collector (buffer defined in BufferMap and 
BufferAllocation)*/
+  /*!
+   * \brief buffer collector
+   *        (buffer defined in BufferMap, BufferAllocation and 
MatchBufferRegion)
+   */
   std::unordered_set<const BufferNode*> buf_not_in_headers;
   /*! \brief Map from Var to thread env name */
   std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
@@ -84,6 +88,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   int num_child_;
   /*! \brief the number of current node */
   int current_num_;
+  /*! \brief loop stack without annotations */
+  std::vector<For> loop_stack_;
 
   Doc VisitExpr_(const CastNode* op) override;
   Doc VisitExpr_(const VarNode* op) override;
@@ -131,6 +137,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitStmt_(const ForNode* op) override;
   Doc VisitStmt_(const PrefetchNode* op) override;
   Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const BlockRealizeNode* op) override;
   Doc VisitStmtDefault_(const Object* op) override;
 
   Doc VisitType_(const PrimTypeNode* node) override;
@@ -145,12 +152,24 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc PrintArray(const ArrayNode* op);
   Doc PrintBuffer(const BufferNode* op);
   Doc AllocBufferDeclaration(const Buffer& buf);
+  Doc PrintBufferRegion(const BufferRegionNode* op);
+  Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op);
+  Doc PrintAnnotations(const Map<String, ObjectRef>& annotations);
   static Doc PrintString(const StringObj* op) { return 
Doc::StrLiteral(op->data); }
 
   Doc GetUniqueName(std::string prefix);
   Doc AllocVar(const Var& var);
   Doc AllocBuf(const Buffer& buffer);
 
+  /*! Helper functions for loop printing. */
+  /*!
+   * \brief Print a single for loop
+   * \param loop The for loop to be printed
+   */
+  Doc PrintLoop(const For& loop);
+  /*! \brief Print all simple loops in stack into one line using tir.grid(). */
+  Doc PrintLoopStack();
+
   /*!
    * \brief Print additional info about expr in comment.
    * \param expr The expression.
@@ -308,6 +327,36 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
   return val;
 }
 
+Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
+  // Prints `B = tir.match_buffer_region(A[...], <non-default buffer params>)`.
+  const Buffer& buf = op->buffer;
+  // The buffer is defined by this statement, so keep it out of the function header.
+  buf_not_in_headers.insert(buf.get());
+
+  Doc doc = Print(buf) << " = tir.match_buffer_region(" << Print(op->source);
+  if (!buf->strides.empty()) {
+    doc << ", strides=" << Print(buf->strides);
+  }
+  if (buf->offset_factor != 0 && buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      // elem_offset is an already-printed var: pass it explicitly.
+      doc << ", elem_offset=" << Print(buf->elem_offset);
+    } else {
+      // implicitly define elem_offset as `<buffer>.elem_offset`
+      memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
+      var_not_in_headers.insert(elem_offset.get());
+    }
+  } else {
+    doc << ", elem_offset=" << Print(buf->elem_offset);
+  }
+  // align / offset_factor are printed only when they differ from the parser defaults.
+  if (buf->data_alignment != -1) {
+    doc << ", align=" << buf->data_alignment;
+  }
+  if (buf->offset_factor != 0) {
+    doc << ", offset_factor=" << buf->offset_factor;
+  }
+  doc << ")";
+  return doc;
+}
+
 Doc TVMScriptPrinter::Print(const ObjectRef& node) {
   if (!node.defined()) return Doc::Text("None");
   if (node->IsInstance<StmtNode>()) {
@@ -330,6 +379,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) {
     return PrintIterVar(node.as<IterVarNode>());
   } else if (node->IsInstance<RangeNode>()) {
     return PrintRange(node.as<RangeNode>());
+  } else if (node->IsInstance<BufferRegionNode>()) {
+    return PrintBufferRegion(node.as<BufferRegionNode>());
+  } else if (node->IsInstance<MatchBufferRegionNode>()) {
+    return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>());
   } else {
     meta_collector_.Collect(node);
     return this->meta_.GetMetaNode(node);
@@ -660,9 +713,7 @@ inline const char* ForKind2String(ForKind t) {
     case ForKind::kUnrolled:
       return "unroll";
     case ForKind::kThreadBinding:
-      LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
-                 << "not yet supported in TIR";
-      return "threadbinding";
+      return "thread_binding";
   }
   LOG(FATAL) << "Unknown ForKind";
   return "Unknown";
@@ -671,9 +722,27 @@ inline const char* ForKind2String(ForKind t) {
 Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
   Doc doc;
   var_not_in_headers.insert(op->loop_var.get());
-  doc << "for " << Print(op->loop_var) << " in tir." + std::string(ForKind2String(op->kind)) + "("
-      << Print(op->min) << ", " << Print(op->min + op->extent)
-      << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+  // "Simple" loops (serial, no annotations, starting at 0) are pushed onto loop_stack_
+  // and emitted together by the innermost loop of the nest via PrintLoopStack(), which
+  // compresses several of them into one `tir.grid(...)` header.
+  const auto* body = op->body.as<ForNode>();
+  bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min);
+  if (simple_loop) loop_stack_.push_back(GetRef<For>(op));
+  // It is a loop that can be compressed, let the loops below print it out
+  if (simple_loop && body != nullptr) return Print(GetRef<For>(body));
+  // It is a loop that can not be compressed
+  bool print_above = !loop_stack_.empty();
+  // print loops above if needed
+  if (print_above) {
+    doc << PrintLoopStack();
+    loop_stack_.clear();
+  }
+  if (!simple_loop) {
+    // print current loop if needed
+    Doc current_loop;
+    current_loop << PrintLoop(GetRef<For>(op));
+    current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+    doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop);
+  } else {
+    // this loop was already printed as part of the stack; only its body remains
+    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+  }
   return doc;
 }
 
@@ -713,6 +782,88 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
   return doc;
 }
 
+Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
+  // Emits:
+  //   with tir.block([<iter doms>], "<name>") as [<iter vars>]:
+  //       tir.where / tir.bind / tir.reads / tir.writes / tir.block_attr
+  //       <alloc_buffer and match_buffer_region decls>
+  //       [with tir.init(): ...]
+  //       <body>
+  const auto* block_op = op->block.as<BlockNode>();
+  // print block name and block vars
+  Doc doc;
+  doc << "with tir.block([";
+  std::vector<Doc> block_var_docs;
+  for (const auto& iter_var : block_op->iter_vars) {
+    Doc block_var_doc;
+    if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
+      // Shorthand: a data-parallel iter starting at 0 is written as its extent alone.
+      block_var_doc << Print(iter_var->dom->extent);
+    } else {
+      block_var_doc << "tir.";
+      switch (iter_var->iter_type) {
+        case kDataPar:
+          block_var_doc << "range";
+          break;
+        case kCommReduce:
+          block_var_doc << "reduce_axis";
+          break;
+        case kOrdered:
+          block_var_doc << "scan_axis";
+          break;
+        case kOpaque:
+          block_var_doc << "opaque_axis";
+          break;
+        default:
+          LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type;
+          break;
+      }
+      block_var_doc << "(" << Print(iter_var->dom->min) << ", "
+                    << Print(iter_var->dom->min + iter_var->dom->extent) << ")";
+    }
+    block_var_docs.push_back(block_var_doc);
+  }
+  doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], ";
+  doc << Doc::StrLiteral(block_op->name_hint) << ")";
+  std::vector<Doc> block_var_names;
+  for (const auto& iter_var : block_op->iter_vars) {
+    var_not_in_headers.insert(iter_var->var.get());
+    block_var_names.push_back(Print(iter_var->var));
+  }
+  if (!block_var_names.empty()) {
+    doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]";
+  }
+  doc << ":";
+  Doc block_attr_doc;
+  // print predicate, binding, read/write tensor region, annotations
+  if (!is_one(op->predicate)) {
+    block_attr_doc << Doc::NewLine() << "tir.where(" << Print(op->predicate) << ")";
+  }
+  for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
+    block_attr_doc << Doc::NewLine() << "tir.bind(" << Print(block_op->iter_vars[i]->var) << ", "
+                   << Print(op->iter_values[i]) << ")";
+  }
+  block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
+  block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
+  if (!block_op->annotations.empty()) {
+    block_attr_doc << Doc::NewLine() << "tir.block_attr({";
+    block_attr_doc << PrintAnnotations(block_op->annotations);
+    block_attr_doc << "})";
+  }
+  // print body
+  Doc body;
+  body << Doc::NewLine();
+  for (const auto& alloc_buf : block_op->alloc_buffers) {
+    // Buffers allocated by the block are declared inside it, not in the function header.
+    buf_not_in_headers.insert(alloc_buf.get());
+    body << Print(alloc_buf) << " = tir.alloc_buffer(" << memo_buf_decl_[alloc_buf] << ")"
+         << Doc::NewLine();
+  }
+  for (const auto& match_buf : block_op->match_buffers) {
+    body << Print(match_buf) << Doc::NewLine();
+  }
+  if (block_op->init.defined()) {
+    Doc init_block;
+    init_block << "with tir.init():";
+    init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(block_op->init.value()));
+    body << init_block << Doc::NewLine();
+  }
+  body << PrintBody(block_op->body);
+  doc << Doc::Indent(4, block_attr_doc << body);
+  return doc;
+}
+
 Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
   int memo_num_child, memo_current_num;
   std::swap(memo_num_child, num_child_);
@@ -890,6 +1041,73 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
   return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
 }
 
+Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
+  // Emits `buffer[d0, d1, ...]`; a dimension prints as `min:min+extent`, or as the bare `min`
+  // when its extent is one.
+  Doc doc;
+  doc << Print(op->buffer) << "[";
+  std::vector<Doc> dims;
+  for (const Range& range : op->region) {
+    Doc dim;
+    dim << Print(range->min);
+    if (!is_one(range->extent)) {
+      dim << ":" << Print(range->min + range->extent);
+    }
+    dims.push_back(dim);
+  }
+  doc << PrintSep(dims, Doc::Text(", ")) << "]";
+  return doc;
+}
+
+Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations) {
+  // Prints `"key":value` entries separated by ", ", ordered by key so the output does not
+  // depend on the Map's iteration order.
+  std::vector<std::pair<String, ObjectRef>> anno_list;
+  anno_list.reserve(annotations.size());
+  for (const auto& pair : annotations) {
+    anno_list.emplace_back(pair);
+  }
+  // Map keys are unique, so ordering by key alone is total; this needs no ordering on the
+  // ObjectRef values and does not rely on ADL to find `sort`.
+  std::sort(anno_list.begin(), anno_list.end(),
+            [](const std::pair<String, ObjectRef>& lhs, const std::pair<String, ObjectRef>& rhs) {
+              return lhs.first < rhs.first;
+            });
+  Doc res;
+  for (size_t i = 0; i < anno_list.size(); ++i) {
+    if (i != 0) {
+      res << ", ";
+    }
+    res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second);
+  }
+  return res;
+}
+
+Doc TVMScriptPrinter::PrintLoop(const For& loop) {
+  // Header shape:
+  //   for v in tir.<kind>(min, min + extent[, thread = tag][, annotation = {...}]):
+  const std::string kind_call = " in tir." + std::string(ForKind2String(loop->kind)) + "(";
+  Doc res;
+  res << "for " << Print(loop->loop_var) << kind_call << Print(loop->min) << ", "
+      << Print(loop->min + loop->extent);
+  if (loop->thread_binding.defined()) {
+    res << ", thread = " << Print(loop->thread_binding.value()->thread_tag);
+  }
+  if (!loop->annotations.empty()) {
+    res << ", annotation = {" << PrintAnnotations(loop->annotations) << "}";
+  }
+  res << "):";
+  return res;
+}
+
+Doc TVMScriptPrinter::PrintLoopStack() {
+  Doc res;
+  if (loop_stack_.empty()) {
+    return res;
+  }
+  if (loop_stack_.size() == 1) {
+    // A lone loop keeps its full `tir.<kind>(min, max)` header.
+    res << PrintLoop(loop_stack_[0]);
+    return res;
+  }
+  // Several stacked loops merge into one `for a, b, ... in tir.grid(ext_a, ext_b, ...):`.
+  // NOTE(review): the grid form lists extents only and drops each loop's min, so this
+  // presumably relies on the caller stacking only zero-based loops -- confirm.
+  std::vector<Doc> loop_vars;
+  std::vector<Doc> loop_extents;
+  for (const auto& loop : loop_stack_) {
+    loop_vars.push_back(Print(loop->loop_var));
+    loop_extents.push_back(Print(loop->extent));
+  }
+  res << "for " << PrintSep(loop_vars, Doc::Text(", ")) << " in tir.grid("
+      << PrintSep(loop_extents, Doc::Text(", ")) << "):";
+  return res;
+}
+
 TVM_REGISTER_GLOBAL("script.AsTVMScript")
     .set_body_typed<std::string(const ObjectRef&, bool)>([](const ObjectRef& 
functions,
                                                             bool show_meta) {
diff --git a/src/tir/analysis/block_access_region_detector.cc 
b/src/tir/analysis/block_access_region_detector.cc
new file mode 100644
index 0000000..b1da536
--- /dev/null
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_access_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Detect which regions of the tracked buffers a block reads or writes, by visiting the
+ *        block's body.
+ *
+ * Only buffers whose data var appears in the given buffer_var_map are reported. Each buffer gets
+ * one region per access kind (the union of all its accesses), and buffers are listed in order of
+ * first appearance in the AST.
+ *
+ * \note This detector only accepts a Block as its entry point and does not visit child blocks
+ *       recursively: for a nested BlockRealize it uses the child's declared read/write regions.
+ */
+class BlockReadWriteDetector : public StmtExprVisitor {
+ public:
+  explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
+      : buffer_var_map_(buffer_var_map) {}
+
+  /*! \brief Return read regions of the block */
+  Array<BufferRegion> CollectReads();
+  /*! \brief Return write regions of the block */
+  Array<BufferRegion> CollectWrites();
+  /*!
+   * \brief Return opaque buffer regions of the block
+   * \note A tracked buffer accessed through its data var (tir.Load / tir.Store, or any other
+   *       direct reference to buffer.data, e.g. as a call argument) is marked as opaque and
+   *       reported with its full region.
+   */
+  Array<BufferRegion> CollectOpaques();
+  /*! \brief Visit \p stmt, which must be a Block (checked at runtime). */
+  void operator()(const Stmt& stmt);
+
+ private:
+  /*! \brief Iteration domain of each enclosing loop var, used to relax access indices */
+  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+  /*! \brief The buffers that the current block reads; parallel to read_regions_ */
+  std::vector<Buffer> read_buffers_;
+  /*! \brief The buffers that the current block writes; parallel to write_regions_ */
+  std::vector<Buffer> writes_buffers_;
+  /*! \brief The opaque buffers, i.e. tracked buffers accessed through buffer.data */
+  std::vector<Buffer> opaque_buffers_;
+  /*! \brief The read regions of the current block; entry i belongs to read_buffers_[i] */
+  std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
+  /*! \brief The write regions of the current block; entry i belongs to writes_buffers_[i] */
+  std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
+  /*! \brief Map from the data var of each tracked (outside-visible) buffer to that buffer */
+  Map<Var, Buffer> buffer_var_map_;
+  /*!
+   * \brief The analyzer for simplifying
+   * \note NOTE(review): not referenced by any member definition in this file.
+   */
+  arith::Analyzer analyzer_;
+
+  /*!
+   * \brief Record an access of \p buffer with \p region into the given buffer/region lists.
+   *        Accesses to buffers not in buffer_var_map_ are ignored; a buffer already in the list
+   *        has its region widened to the per-dimension union with \p region.
+   * \param buffers The buffer list to update
+   * \param regions The region list to update, parallel to \p buffers
+   * \param buffer The accessed buffer
+   * \param region The relaxed access region, one IntSet per buffer dimension
+   */
+  void Update(std::vector<Buffer>* buffers, std::vector<std::vector<arith::IntSet>>* regions,
+              const Buffer& buffer, const std::vector<arith::IntSet>& region);
+
+  /*!
+   * \brief Convert accumulated per-dimension IntSets into BufferRegions, passing each
+   *        dimension's full range [0, shape[j]) to CoverRange.
+   */
+  Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
+                                     const std::vector<std::vector<tvm::arith::IntSet>>& regions);
+
+  /*! \brief Mark the tracked buffer owning \p buffer_var (if any) as opaque, once. */
+  void AddOpaque(const Var& buffer_var);
+
+  // Binds the loop var to its [min, min + extent) domain while visiting the loop.
+  void VisitStmt_(const ForNode* op) override;
+  // Uses the child block's declared regions instead of visiting the child block.
+  void VisitStmt_(const BlockRealizeNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
+  // Store / Load through a buffer var, and bare references to a buffer var, are opaque.
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitExpr_(const BufferLoadNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+};
+
+void BlockReadWriteDetector::operator()(const Stmt& stmt) {
+  // The detector is only meaningful when started on a Block.
+  const auto* block = stmt.as<BlockNode>();
+  ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
+  StmtExprVisitor::operator()(stmt);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectReads() {
+  // One region per read buffer, in order of first access.
+  return CollectRegions(read_buffers_, read_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectWrites() {
+  // One region per written buffer, in order of first access.
+  return CollectRegions(writes_buffers_, write_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
+  // Opaque accesses cannot be analyzed, so the whole buffer is reported.
+  Array<BufferRegion> full_regions;
+  full_regions.reserve(opaque_buffers_.size());
+  for (size_t i = 0; i < opaque_buffers_.size(); ++i) {
+    full_regions.push_back(BufferRegion::FullRegion(opaque_buffers_[i]));
+  }
+  return full_regions;
+}
+
+void BlockReadWriteDetector::VisitExpr_(const VarNode* op) {
+  // A bare reference to a tracked buffer's data var counts as an opaque access.
+  AddOpaque(GetRef<Var>(op));
+}
+
+void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
+  // Load addresses memory through buffer_var, so its region cannot be tracked.
+  AddOpaque(op->buffer_var);
+  ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
+  // Relax every index over the enclosing loop domains to get the read region.
+  std::vector<arith::IntSet> relaxed_region;
+  relaxed_region.reserve(op->indices.size());
+  for (size_t i = 0; i < op->indices.size(); ++i) {
+    relaxed_region.push_back(arith::EvalSet(op->indices[i], dom_map_));
+  }
+  Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
+  ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
+  // The loop var ranges over [min, min + extent) while the loop is being visited.
+  const VarNode* loop_var = op->loop_var.get();
+  dom_map_[loop_var] = arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
+  StmtVisitor::VisitStmt_(op);
+  dom_map_.erase(loop_var);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
+  // Store writes through buffer_var, so its region cannot be tracked.
+  AddOpaque(op->buffer_var);
+  StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
+  // Relax every written index, then keep visiting: the value and indices may contain reads.
+  std::vector<arith::IntSet> relaxed_region;
+  relaxed_region.reserve(op->indices.size());
+  for (size_t i = 0; i < op->indices.size(); ++i) {
+    relaxed_region.push_back(arith::EvalSet(op->indices[i], dom_map_));
+  }
+  Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
+  StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
+  /*!
+   * \note The child block is not visited recursively. Its declared read/write regions are
+   *       rewritten through the block's iter_var bindings and relaxed over enclosing loops.
+   */
+  const Block& child = op->block;
+  std::unordered_map<const VarNode*, PrimExpr> binding;
+  for (size_t i = 0; i < child->iter_vars.size(); ++i) {
+    binding[child->iter_vars[i]->var.get()] = op->iter_values[i];
+  }
+  auto relax = [&](const BufferRegion& buffer_region) {
+    std::vector<arith::IntSet> relaxed;
+    for (const Range& range : buffer_region->region) {
+      Range bound =
+          Range::FromMinExtent(Substitute(range->min, binding), Substitute(range->extent, binding));
+      relaxed.push_back(arith::EvalSet(arith::IntSet::FromRange(bound), dom_map_));
+    }
+    return relaxed;
+  };
+  for (const BufferRegion& read : child->reads) {
+    Update(&read_buffers_, &read_regions_, read->buffer, relax(read));
+  }
+  for (const BufferRegion& write : child->writes) {
+    Update(&writes_buffers_, &write_regions_, write->buffer, relax(write));
+  }
+}
+
+void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
+                                    std::vector<std::vector<arith::IntSet>>* regions,
+                                    const Buffer& buffer,
+                                    const std::vector<arith::IntSet>& region) {
+  // Only buffers registered in buffer_var_map_ are tracked.
+  if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return;
+  ICHECK_EQ(buffers->size(), regions->size())
+      << " Expected the buffer and regions to have the same size ";
+  size_t idx = 0;
+  while (idx < regions->size() && !(*buffers)[idx].same_as(buffer)) {
+    ++idx;
+  }
+  if (idx == regions->size()) {
+    // First access to this buffer: append it together with its region.
+    buffers->push_back(buffer);
+    regions->push_back(region);
+    return;
+  }
+  // Seen before: widen each dimension to the union of the old and new sets.
+  std::vector<arith::IntSet>& merged = (*regions)[idx];
+  ICHECK_EQ(merged.size(), region.size()) << "Inconsistent buffer dimension";
+  for (size_t dim = 0; dim < region.size(); ++dim) {
+    merged[dim] = arith::Union({merged[dim], region[dim]});
+  }
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
+    const std::vector<Buffer>& buffers,
+    const std::vector<std::vector<tvm::arith::IntSet>>& regions) {
+  ICHECK_EQ(buffers.size(), regions.size());
+  Array<BufferRegion> res;
+  res.reserve(buffers.size());
+  for (size_t i = 0; i < regions.size(); ++i) {
+    const Buffer& buffer = buffers[i];
+    const std::vector<tvm::arith::IntSet>& sets = regions[i];
+    Array<Range> ranges;
+    ranges.reserve(sets.size());
+    for (size_t dim = 0; dim < sets.size(); dim++) {
+      // The full dimension [0, shape[dim]) is handed to CoverRange as its bounding range.
+      ranges.push_back(sets[dim].CoverRange(Range::FromMinExtent(0, buffer->shape[dim])));
+    }
+    res.push_back(BufferRegion(buffer, ranges));
+  }
+  return res;
+}
+
+void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) {
+  auto it = buffer_var_map_.find(buffer_var);
+  if (it == buffer_var_map_.end()) return;  // not the data var of a tracked buffer
+  const Buffer buffer = (*it).second;
+  for (const Buffer& seen : opaque_buffers_) {
+    if (seen.same_as(buffer)) return;  // already recorded once
+  }
+  opaque_buffers_.push_back(buffer);
+}
+
+Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
+                                                const Map<Var, Buffer>& buffer_var_map) {
+  BlockReadWriteDetector detector(buffer_var_map);
+  detector(block);
+  // Result layout: [read regions, write regions, opaque regions].
+  Array<BufferRegion> reads = detector.CollectReads();
+  Array<BufferRegion> writes = detector.CollectWrites();
+  Array<BufferRegion> opaques = detector.CollectOpaques();
+  return {reads, writes, opaques};
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/ir/script/script_complete.cc 
b/src/tir/ir/script/script_complete.cc
new file mode 100644
index 0000000..7c9fff7
--- /dev/null
+++ b/src/tir/ir/script/script_complete.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.cc
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+
+#include <tvm/arith/int_set.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Completes TIR produced by the TVM Script parser: generates the loops surrounding blocks
+ *        whose bindings were not given, and fills in block read/write regions left unspecified.
+ */
+class ScriptCompleter : public StmtMutator {
+ public:
+  explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
+  /*! \brief Whether the stmt contains at least one block. */
+  bool contains_block = false;
+
+ private:
+  /*! \brief Buffers visible at the current point, keyed by data var (owned by the caller). */
+  Map<Var, Buffer>* buffer_var_map_;
+
+  Stmt VisitStmt_(const BlockRealizeNode* op) override {
+    contains_block = true;
+    Stmt body = StmtMutator::VisitStmt_(op);
+    // NOTE(review): a non-integer first iter_value appears to be how the parser marks a block
+    // written without tir.bind -- confirm against the parser. Such a block gets one fresh loop
+    // var per block var, and is wrapped in loops spanning the block vars' domains.
+    if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) {
+      auto block_with_binding = CopyOnWrite(Downcast<BlockRealize>(body).get());
+      std::vector<PrimExpr> bindings;
+      bindings.reserve(op->iter_values.size());
+      for (size_t i = 0; i < op->iter_values.size(); ++i) {
+        bindings.push_back(Var("i" + std::to_string(i)));
+      }
+      block_with_binding->iter_values = bindings;
+      body = BlockRealize(block_with_binding);
+      // Wrap from the innermost loop outwards so that loop i ends up enclosing loop i + 1.
+      for (int i = static_cast<int>(op->iter_values.size()) - 1; i >= 0; --i) {
+        body = For(Downcast<Var>(bindings[i]), op->block->iter_vars[i]->dom->min,
+                   op->block->iter_vars[i]->dom->extent, {}, body);
+      }
+    }
+    return body;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) override {
+    // Buffers allocated in the block can be accessed by its body.
+    for (const auto& alloc_buffer : op->alloc_buffers) {
+      buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
+    }
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+    // Remove buffers allocated inside block to detect its access region
+    for (const auto& alloc_buffer : op->alloc_buffers) {
+      buffer_var_map_->erase(alloc_buffer->data);
+    }
+    // Detect regions only when reads or writes were left unspecified (undefined). This matches
+    // the assignments below, which never overwrite an explicitly given region, even an empty
+    // one. Testing empty() here would run the detector -- and its opaque-access CHECK -- on
+    // blocks whose regions were all annotated explicitly, although nothing would be assigned.
+    if (!block->reads.defined() || !block->writes.defined()) {
+      auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
+      const Array<BufferRegion>& reads = access_region[0];
+      const Array<BufferRegion>& writes = access_region[1];
+      const Array<BufferRegion>& opaque = access_region[2];
+      CHECK(opaque.empty())
+          << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
+             "direct access by buffer data. Please annotation the access region manually";
+      auto n = CopyOnWrite(block.operator->());
+      if (!n->reads.defined()) n->reads = reads;
+      if (!n->writes.defined()) n->writes = writes;
+      return Block(n);
+    } else {
+      return std::move(block);
+    }
+  }
+};
+
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
+  // Seed the lookup with every buffer visible at function scope: the parameter buffers and the
+  // buffers allocated at the root.
+  Map<Var, Buffer> buffer_var_map;
+  for (const auto& kv : func->buffer_map) {
+    buffer_var_map.Set(kv.second->data, kv.second);
+  }
+  for (const Buffer& root_buffer : root_allocates) {
+    buffer_var_map.Set(root_buffer->data, root_buffer);
+  }
+  ScriptCompleter completer(&buffer_var_map);
+  // Generate surrounding loops and missing access regions.
+  Stmt new_body = completer(func->body);
+  // Add an implicit "root" block when the body has blocks but is not already a single
+  // BlockRealize, or when root-level allocations need an owning block.
+  const bool needs_root = completer.contains_block &&
+                          (!new_body->IsInstance<BlockRealizeNode>() || !root_allocates.empty());
+  if (needs_root) {
+    Block root({}, {}, {}, "root", new_body, NullOpt, root_allocates);
+    new_body = BlockRealize({}, Bool(true), root);
+  }
+  if (!func->body.same_as(new_body)) {
+    func.CopyOnWrite()->body = new_body;
+  }
+  return func;
+}
+
+TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py 
b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
new file mode 100644
index 0000000..7e4d7d8
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir, script
+from tvm.ir import Range
+
+
+# Fixture for the detector test: the first non-root block accesses buffers directly, through a
+# nested block, and opaquely through D.data.
+@tvm.script.tir
+def func() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+    C = tir.alloc_buffer((128, 128), "float32")
+    D = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([]):
+        # Annotate read/write regions by hand; otherwise script completion would run the block
+        # access region detector on this block and reject the opaque D.data access. The test
+        # also uses these annotations as the expected detector output.
+        tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
+        tir.writes([A[0:12, 0:12]])
+        for i, j in tir.grid(8, 8):
+            A[i, j] = B[0, 0] + C[0, 0]
+        with tir.block([2, 2]) as [vi, vj]:
+            tir.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]])
+            tir.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]])
+            for i, j in tir.grid(4, 4):
+                A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
+        tir.evaluate(D.data)
+
+
+def test_block_access_region_detector():
+    root_block = func.body.block
+    block = root_block.body.block
+    buffer_var_map = {}
+    for buf in root_block.alloc_buffers:
+        buffer_var_map[buf.data] = buf
+    ret = tir.analysis.get_block_access_region(block, buffer_var_map)
+    reads, writes, opaques = ret[0], ret[1], ret[2]
+
+    # The block's hand-written annotations are the expected detector output.
+    tvm.ir.assert_structural_equal(block.reads, reads)
+    tvm.ir.assert_structural_equal(block.writes, writes)
+    # D is only touched through D.data, so it is reported as a full-buffer opaque region.
+    d_buffer = root_block.alloc_buffers[-1]
+    expected_opaques = [tvm.tir.BufferRegion(d_buffer, [Range(0, 128), Range(0, 128)])]
+    tvm.ir.assert_structural_equal(expected_opaques, opaques)
+
+
+if __name__ == "__main__":
+    test_block_access_region_detector()
diff --git a/tests/python/unittest/test_tvmscript_error_report.py 
b/tests/python/unittest/test_tvmscript_error_report.py
index 048a954..052217b 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -144,6 +144,197 @@ def test_no_body():
     check_error(no_body, 3)
 
 
+# Each `check_error(fn, n)` below expects parsing `fn` to report an error on line `n`, counted
+# from the `def` line as line 1; the offending line is marked `# error`. Comments are kept
+# outside the function bodies so they do not shift those relative line numbers.
+
+
+# tir.allocate with a two-name target `as [A, B]`.
+def allocate_with_buffers() -> None:
+    with tir.allocate([1], "float32", "") as [A, B]:  # error
+        tir.evaluate(1.0)
+
+
+def test_allocate_with_buffers():
+    check_error(allocate_with_buffers, 2)
+
+
+# Two block axes ([128, 128]) but only one block var ([vi]).
+def inconsistent_binding() -> None:
+    with tir.block([128, 128]) as [vi]:  # error
+        tir.evaluate(1.0)
+
+
+def test_inconsistent_binding():
+    check_error(inconsistent_binding, 2)
+
+
+# A buffer (A) passed where a block axis is expected.
+def invalid_block_axes(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    with tir.block([A]) as [vi]:  # error
+        tir.evaluate(1.0)
+
+
+def test_invalid_block_axes():
+    check_error(invalid_block_axes, 3)
+
+
+# Only vi gets a tir.bind; the error is reported on the block header (line 2).
+def miss_block_bind() -> None:
+    with tir.block([16, 16]) as [vi, vj]:  # error
+        tir.bind(vi, 1)
+        tir.evaluate(1.0)
+
+
+def test_miss_block_bind():
+    check_error(miss_block_bind, 2)
+
+
+# Two loop vars unpacked from a single range().
+def invalid_loop_var() -> None:
+    for i, j in range(0, 16):  # error
+        tir.evaluate(1.0)
+
+
+def test_invalid_loop_var():
+    check_error(invalid_loop_var, 2)
+
+
+# One loop var for a two-dimensional tir.grid.
+def inconsistent_grid() -> None:
+    for i in tir.grid(16, 16):  # error
+        tir.evaluate(1.0)
+
+
+def test_inconsistent_grid():
+    check_error(inconsistent_grid, 2)
+
+
+# tir.match_buffer_region given a block var (vi) rather than a buffer region.
+def invalid_match_buffer_region() -> None:
+    with tir.block([16, 16]) as [vi, vj]:
+        A = tir.match_buffer_region(vi)  # error
+        tir.evaluate(1.0)
+
+
+def test_invalid_match_buffer_region():
+    check_error(invalid_match_buffer_region, 3)
+
+
+# The name A is bound again by a second tir.alloc_buffer inside the block.
+def duplicate_buffer() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        A = tir.alloc_buffer((128, 128), "float32")  # error
+        tir.evaluate(1.0)
+
+
+def test_duplicate_buffer():
+    check_error(duplicate_buffer, 4)
+
+
+# In each duplicate_* function below, the same block-level statement appears twice in one
+# block; the second occurrence is the expected error.
+def duplicate_reads() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.reads(A[0:8, 0:8])
+        tir.reads(A[0:16, 0:16])  # error
+        tir.evaluate(1.0)
+
+
+def duplicate_writes() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.writes(A[0:8, 0:8])
+        tir.writes(A[0:16, 0:16])  # error
+        tir.evaluate(1.0)
+
+
+def duplicate_predicate() -> None:
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.where(1)
+        tir.where(0)  # error
+
+
+def duplicate_annotations() -> None:
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.block_attr({})
+        tir.block_attr({})  # error
+
+
+def duplicate_init() -> None:
+    with tir.block([16, 16]) as [vi, vj]:
+        with tir.init():
+            tir.evaluate(1.0)
+        with tir.init():  # error
+            tir.evaluate(1.0)
+
+
+def test_duplicate_block_signature():
+    check_error(duplicate_reads, 5)
+    check_error(duplicate_writes, 5)
+    check_error(duplicate_predicate, 4)
+    check_error(duplicate_annotations, 4)
+    check_error(duplicate_init, 5)
+
+
+# The block reads A only through tir.load on A.data and gives no tir.reads / tir.writes, so
+# automatic access-region detection during script completion rejects it. The error is
+# attributed to the function definition (line 1).
+def opaque_access_during_complete(a: ty.handle) -> None:  # error
+    A = tir.match_buffer(a, (16, 16), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.evaluate(tir.load("float32", A.data, vi * 16 + vj))
+
+
+def test_opaque_access_during_complete():
+    check_error(opaque_access_during_complete, 1)
+
+
+# A slice (vi : vi + 2) used where a single-element buffer load is required.
+def convert_slice_to_bufferload() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        A[vi, vj] = A[vi : vi + 2, vj] + 1  # error
+
+
+def test_convert_slice_to_bufferload():
+    check_error(convert_slice_to_bufferload, 4)
+
+
+# A floating-point index (0.0).
+def error_index_type() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        A[vi, vj] = A[vi, 0.0] + 1  # error
+
+
+def test_error_index_type():
+    check_error(error_index_type, 4)
+
+
+# tir.reads called with two positional arguments.
+def mismatch_args() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.reads(A[0, 0], A[1, 1])  # error
+        tir.evaluate(1.0)
+
+
+def test_mismatch_args():
+    check_error(mismatch_args, 4)
+
+
+# The functions below pass wrongly typed arguments that are rejected on the C++ side; the test
+# checks that the parser still reports the error on the offending line.
+# Shape given as a string.
+def special_stmt_except() -> None:
+    A = tir.alloc_buffer("(128, 128)", "float32")  # error
+    with tir.block([16, 16]) as [vi, vj]:
+        tir.evaluate(1.0)
+
+
+# Loop bounds given as strings.
+def scope_handler_except() -> None:
+    for i in tir.serial("1", "1"):  # error
+        tir.evaluate(1)
+
+
+# A buffer passed to tir.evaluate.
+def intrin_except_unassign(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    tir.evaluate(A)  # error
+
+
+# A buffer passed for every argument of tir.load.
+def intrin_except_assign(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    A[0, 0] = tir.load(A, A, A)  # error
+
+
+def test_tvm_exception_catch():
+    # test catching c++ side exception
+    check_error(special_stmt_except, 2)
+    check_error(scope_handler_except, 2)
+    check_error(intrin_except_unassign, 3)
+    check_error(intrin_except_assign, 3)
+
+
 def check_error(module, rel_lineno):
     # Override the default renderer to accumulate errors
     _, start_line = inspect.getsourcelines(module)
@@ -180,3 +371,17 @@ if __name__ == "__main__":
     test_return_not_allowed()
     test_tir_assert()
     test_no_body()
+    test_allocate_with_buffers()
+    test_inconsistent_binding()
+    test_invalid_block_axes()
+    test_miss_block_bind()
+    test_invalid_loop_var()
+    test_inconsistent_grid()
+    test_invalid_match_buffer_region()
+    test_duplicate_buffer()
+    test_duplicate_block_signature()
+    test_opaque_access_during_complete()
+    test_convert_slice_to_bufferload()
+    test_error_index_type()
+    test_mismatch_args()
+    test_tvm_exception_catch()
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index c7a38cc..a295908 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -2662,6 +2662,169 @@ def test_opt_conv_tensorcore_mod_host():
     tvm.ir.assert_structural_equal(mod, rt_mod, True)
 
 
+# Matmul as a single reduction block with a tir.init statement and no explicit loops or
+# bindings.
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+
+    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = tir.float32(0)
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# The same computation with explicit i/j/k loops and a separate "init" block.
+@tvm.script.tir
+def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "init") as [vi, vj]:
+            C[vi, vj] = tir.float32(0)
+
+        for k in range(0, 128):
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# Two element-wise blocks connected through the intermediate buffer B.
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * tir.float32(2)
+
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + tir.float32(1)
+
+
+# Block vars bound explicitly with tir.bind, plus a tir.where predicate.
+# NOTE(review): vj = jo * 4 + ji with ji in [0, 5) makes consecutive jo ranges overlap; fine for
+# a print/parse roundtrip, but it is not a valid split of a 16-wide axis.
+@tvm.script.tir
+def predicate(b: ty.handle, c: ty.handle) -> None:
+    B = tir.match_buffer(b, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+
+    for i, jo, ji in tir.grid(16, 4, 5):
+        with tir.block([16, 16], "update") as [vi, vj]:
+            tir.bind(vi, i)
+            tir.bind(vj, jo * 4 + ji)
+            tir.where(jo * 4 + ji < 16)
+            C[vi, vj] = B[vi, vj] + tir.float32(1)
+
+def test_module_define():
+    # Functions taken out of single-function modules must build the same module as the
+    # original script functions.
+    named_scripts = [("matmul", matmul), ("element_wise", element_wise), ("predicate", predicate)]
+    extracted = {}
+    originals = {}
+    for idx, (name, script_func) in enumerate(named_scripts, start=1):
+        key = "func{}".format(idx)
+        extracted[key] = tvm.script.create_module({name: script_func})[name]
+        originals[key] = script_func
+    mod1 = tvm.script.create_module(extracted)
+    mod2 = tvm.script.create_module(originals)
+    tvm.ir.assert_structural_equal(mod1, mod2)
+
+
+def test_matmul():
+    printed = tvm.script.asscript(matmul, True)
+    tvm.ir.assert_structural_equal(matmul, tvm.script.from_source(printed))
+
+
+def test_matmul_original():
+    rt_func = tvm.script.from_source(tvm.script.asscript(matmul_original, True))
+    tvm.ir.assert_structural_equal(matmul_original, rt_func)
+
+    # Expected: root block -> i loop -> j loop -> [init block, k loop -> update block].
+    root = rt_func.body.block
+    assert isinstance(root, tir.stmt.Block)
+    i_loop = root.body
+    assert isinstance(i_loop, tir.stmt.For)
+    j_loop = i_loop.body
+    assert isinstance(j_loop, tir.stmt.For)
+    seq = j_loop.body
+    assert isinstance(seq, tir.stmt.SeqStmt)
+    assert isinstance(seq[0].block, tir.stmt.Block)
+    assert isinstance(seq[1], tir.stmt.For)
+    assert isinstance(seq[1].body.block, tir.stmt.Block)
+
+
+def test_element_wise():
+    rt_func = tvm.script.from_source(tvm.script.asscript(element_wise, True))
+    tvm.ir.assert_structural_equal(element_wise, rt_func)
+
+    root = rt_func.body.block
+    assert isinstance(root, tir.stmt.Block)
+    assert isinstance(root.body, tir.stmt.SeqStmt)
+    # After the roundtrip, each of the two blocks sits inside a two-level loop nest.
+    for loop_nest in (root.body[0], root.body[1]):
+        assert isinstance(loop_nest, tir.stmt.For)
+        assert isinstance(loop_nest.body, tir.stmt.For)
+        assert isinstance(loop_nest.body.body.block, tir.stmt.Block)
+
+
+def test_predicate():
+    rt_func = tvm.script.from_source(tvm.script.asscript(predicate, True))
+    tvm.ir.assert_structural_equal(predicate, rt_func)
+
+    node = rt_func.body.block
+    assert isinstance(node, tir.stmt.Block)
+    node = node.body
+    for _ in range(3):  # the i, jo and ji loops of tir.grid(16, 4, 5)
+        assert isinstance(node, tir.stmt.For)
+        node = node.body
+    assert isinstance(node.block, tir.stmt.Block)
+
+
+# Two nested loops bound to threadIdx.x / threadIdx.y via tir.thread_binding; no blocks.
+@tvm.script.tir
+def for_thread_binding(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    B = tir.match_buffer(b, (16, 16), "float32")
+
+    for i in tir.thread_binding(0, 16, thread="threadIdx.x"):
+        for j in tir.thread_binding(0, 16, thread="threadIdx.y"):
+            A[i, j] = B[i, j] + tir.float32(1)
+
+
+def test_for_thread_binding():
+    rt_func = tvm.script.from_source(tvm.script.asscript(for_thread_binding, True))
+    tvm.ir.assert_structural_equal(for_thread_binding, rt_func)
+
+    # NOTE(review): kind 4 presumably is the thread-binding ForKind -- confirm.
+    loop = rt_func.body
+    for tag in ("threadIdx.x", "threadIdx.y"):
+        assert isinstance(loop, tir.stmt.For)
+        assert loop.kind == 4
+        assert loop.thread_binding.thread_tag == tag
+        loop = loop.body
+
+
+# A single block exercising tir.bind, tir.where, tir.reads, tir.writes, tir.block_attr,
+# tir.alloc_buffer, tir.match_buffer_region and tir.init.
+@tvm.script.tir
+def block_elements(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    B = tir.match_buffer(b, (1, 1), "float32")
+
+    with tir.block([1], "update") as [vi]:
+        tir.bind(vi, 0)
+        tir.where(True)
+        tir.reads(A[0:16, 0:16])
+        tir.writes(B[0, 0])
+        tir.block_attr({"attr_key": "attr_value"})
+        C = tir.alloc_buffer((4, 4), dtype="float32")
+        D = tir.match_buffer_region(A[0:4, 0])
+        with tir.init():
+            B[0, 0] = tir.float32(0)
+        B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0]
+
+
+def test_block_elements():
+    rt_func = tvm.script.from_source(tvm.script.asscript(block_elements, True))
+    tvm.ir.assert_structural_equal(block_elements, rt_func)
+
+    block = rt_func.body.block
+    assert isinstance(block, tir.stmt.Block)
+    # Both the main body and the init statement are single buffer stores.
+    for stmt in (block.body, block.init):
+        assert isinstance(stmt, tir.stmt.BufferStore)
+    annotations = block.annotations
+    assert len(annotations) == 1
+    assert annotations["attr_key"] == "attr_value"
+
+
 if __name__ == "__main__":
     test_opt_gemm_normalize()
     test_opt_gemm_mod_host()
@@ -2669,3 +2832,10 @@ if __name__ == "__main__":
     test_opt_conv_tensorcore_normalize()
     test_opt_conv_tensorcore_lower()
     test_opt_conv_tensorcore_mod_host()
+    test_module_define()
+    test_matmul()
+    test_matmul_original()
+    test_element_wise()
+    test_predicate()
+    test_for_thread_binding()
+    test_block_elements()
diff --git a/tests/scripts/task_ci_python_setup.sh 
b/tests/scripts/task_ci_python_setup.sh
index f48ed49..b880cb9 100755
--- a/tests/scripts/task_ci_python_setup.sh
+++ b/tests/scripts/task_ci_python_setup.sh
@@ -30,4 +30,4 @@ set -o pipefail
 #
 echo "Addtiional setup in" ${CI_IMAGE_NAME}
 
-python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.2.1
+python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.3.0
diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh
index 17838c5..9dda54e 100755
--- a/tests/scripts/task_ci_setup.sh
+++ b/tests/scripts/task_ci_setup.sh
@@ -30,7 +30,7 @@ set -o pipefail
 #
 echo "Addtiional setup in" ${CI_IMAGE_NAME}
 
-python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.2.1
+python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.3.0
 
 # Rebuild standalone_crt in build/ tree. This file is not currently archived 
by pack_lib() in
 # Jenkinsfile. We expect config.cmake to be present from pack_lib().

Reply via email to