This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit c21a7ddfccb57909faaeac127a99784867571108
Author: Yuchen Jin <yuch...@cs.washington.edu>
AuthorDate: Sat Feb 11 09:26:36 2023 -0800

    [Unity] e2e Relax minimum build flow (#13961)
    
    This PR introduces the e2e Relax lowering flow (`relax.vm.build`). Tests for each pass in the flow are added.
    
    Co-Authored-by: Altan Haan <alt...@cs.washington.edu>
    Co-Authored-by: Andrew Liu <andrewl...@gmail.com>
    Co-Authored-by: Hongyi Jin <3231950...@qq.com>
    Co-Authored-by: Jiawei Liu <jaway....@gmail.com>
    Co-Authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-Authored-by: Prakalp Srivastava <prak...@octoml.ai>
    Co-Authored-by: Ruihang Lai <ruiha...@cs.cmu.edu>
    Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn>
    Co-Authored-by: Steven S. Lyubomirsky <slyubomir...@octoml.ai>
    Co-Authored-by: Sunghyun Park <49998730+sun...@users.noreply.github.com>
    Co-Authored-by: Tianqi Chen <tianqi.tc...@gmail.com>
    Co-Authored-by: Yong Wu <yongc...@gmail.com>
    Co-Authored-by: Ziheng Jiang <zih...@apache.org>
---
 CMakeLists.txt                                     |   1 +
 include/tvm/relax/analysis.h                       |  16 +
 include/tvm/relax/backend.h                        |   7 +
 include/tvm/relax/transform.h                      |  35 +
 python/tvm/relax/analysis/analysis.py              |  45 +
 python/tvm/relax/op/__init__.py                    |   3 +
 python/tvm/relax/op/{ => builtin}/__init__.py      |   6 +-
 .../relax/op/{__init__.py => builtin/_ffi_api.py}  |   9 +-
 python/tvm/relax/op/builtin/builtin.py             |  44 +
 python/tvm/relax/op/manipulate.py                  |  62 ++
 python/tvm/relax/op/{ => memory}/__init__.py       |   6 +-
 .../relax/op/{__init__.py => memory/_ffi_api.py}   |   9 +-
 python/tvm/relax/op/memory/memory.py               | 108 +++
 python/tvm/relax/{op => testing}/__init__.py       |   6 +-
 python/tvm/relax/testing/nn.py                     | 194 +++++
 python/tvm/relax/transform/transform.py            |  53 ++
 python/tvm/relax/vm.py                             |   4 +-
 python/tvm/script/ir_builder/relax/ir.py           |   6 +
 src/relax/analysis/tir_op_pattern_kind.cc          | 447 ++++++++++
 src/relax/backend/vm/vm_builtin_lower.cc           | 208 +++++
 src/relax/op/op.cc                                 |  81 ++
 src/relax/op/tensor/manipulate.cc                  | 163 ++++
 .../backend.h => src/relax/op/tensor/manipulate.h  |  25 +-
 src/relax/transform/attach_global_symbol.cc        |  68 ++
 src/relax/transform/call_tir_rewrite.cc            | 137 ++++
 src/relax/transform/rewrite_dataflow_reshape.cc    | 110 +++
 src/relax/transform/to_non_dataflow.cc             |  67 ++
 tests/python/relax/test_analysis.py                | 172 ++++
 tests/python/relax/test_transform.py               | 141 ++++
 .../relax/test_transform_attach_global_symbol.py   |  88 ++
 .../test_transform_rewrite_dataflow_reshape.py     | 166 ++++
 tests/python/relax/test_vm_build.py                | 908 +++++++++++++++++++++
 32 files changed, 3358 insertions(+), 37 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index eecd67be94..d0470677e1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -292,6 +292,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
     src/relax/ir/*.cc
     src/relax/op/*.cc
     src/relax/analysis/*.cc
+    src/relax/transform/*.cc
     src/relax/backend/vm/*.cc
     src/relax/utils.cc
     )
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index ad2bd19aa4..24cfe5b9bf 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -259,6 +259,22 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
  */
 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
                                  arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief Check if the given PrimFunc is essentially doing a reshape operation.
+ * The reshape operation also includes expand_dims, squeeze, flatten, etc.
+ * \details The allowed reshape pattern is as follows: assuming the operation is
+ * `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check whether we can prove that the
+ * flattened index of l_0, ..., l_b under buffer B equals the flattened index of r_0, ..., r_a
+ * under buffer A.
+ * \param func The function to be examined.
+ * \return A boolean indicating whether the given PrimFunc is doing a reshape.
+ * \note Per the description above, the result may be a false negative but never a false
+ * positive: whenever the equality cannot be proven, false is returned. This property is what
+ * makes it safe to rely on a `true` result.
+ */
+TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 4ebeacac0f..2fb11f5a6f 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -30,6 +30,13 @@ namespace tvm {
 namespace relax {
 namespace transform {
 
+/*!
+ * \brief Perform builtin lowering to map most of the ops to VM builtin functions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMBuiltinLower();
+
 /*!
  * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
  *
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index fa288a7f06..ff98b16d25 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -65,6 +65,41 @@ TVM_DLL Pass CreateDataflowBlockPass(
     const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, 
PassContext)>& pass_func,
     int opt_level, String name, tvm::Array<String> required);
 
+/*!
+ * \brief Convert every dataflow structure in the module into its non-dataflow counterpart.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass ToNonDataflow();
+
+/*!
+ * \brief Rewrite call_tir so that the allocation of its output tensors becomes explicit.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CallTIRRewrite();
+
+/*!
+ * \brief Convert all reshape-like call_tir whose corresponding binding
+ * vars are DataflowVars to relax.reshape operator calls. The relax.reshape
+ * calls will be lowered to an external builtin function call in a subsequent
+ * pass, where the external builtin function does a CreateView operation
+ * at runtime instead of doing a real data copy.
+ * Here "reshape-like" includes reshape, expand_dims, flatten, etc.
+ *
+ * \return The Pass.
+ * \note The pass is applied at the first stage of Relax VM build, before
+ * rewriting call_tir, as this pass requires dataflow information.
+ */
+TVM_DLL Pass RewriteDataflowReshape();
+
+/*!
+ * \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachGlobalSymbol();
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index d81c477145..27416c3a79 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -162,3 +162,48 @@ def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo:
         The corresponding lca result.
     """
     return _ffi_api.StructInfoLCA(lhs, rhs)  # type: ignore
+
+
+def post_order_visit(expr, fvisit):
+    """Recursively visit the IR in post-order DFS and apply fvisit to each node.
+    Each node is guaranteed to be visited only once.
+
+    Parameters
+    ----------
+    expr : tvm.relax.Expr
+        The input expression.
+
+    fvisit : function
+        The visitor function to be applied to each visited node.
+    """
+    return _ffi_api.post_order_visit(expr, fvisit)  # type: ignore
+
+
+def has_reshape_pattern(func: tir.PrimFunc) -> bool:
+    """Determine whether a PrimFunc is, in essence, a reshape.
+
+    Reshape-like operations such as expand_dims, squeeze and flatten are
+    covered as well. Concretely, for an operation of the form
+    `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, the analysis tries to
+    prove that the flattened index of (l_0, ..., l_b) in buffer B is equal to
+    the flattened index of (r_0, ..., r_a) in buffer A.
+
+    Parameters
+    ----------
+    func : tir.PrimFunc
+        The PrimFunc to analyze.
+
+    Returns
+    -------
+    ret : bool
+        True if the PrimFunc was proven to perform a reshape, False otherwise.
+
+    Notes
+    -----
+    The check is conservative: whenever the index equality cannot be proven,
+    False is returned. Hence it may give false negatives but never false
+    positives, which is what makes a True answer safe to act on.
+    """
+    is_reshape = _ffi_api.has_reshape_pattern(func)  # type: ignore
+    return is_reshape
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 101b0827d6..9a131cdf95 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -20,3 +20,6 @@
 # Operators
 from .base import *
 from .binary import *
+from .manipulate import *
+from . import builtin
+from . import memory
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/builtin/__init__.py
index 101b0827d6..04837724b1 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/builtin/__init__.py
@@ -15,8 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Relax builtin operators."""
 
-# Operators
-from .base import *
-from .binary import *
+from .builtin import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/_ffi_api.py
similarity index 83%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/builtin/_ffi_api.py
index 101b0827d6..42fe8cb652 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/builtin/_ffi_api.py
@@ -13,10 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""FFI APIs for tvm.relax.op.builtin"""
+import tvm._ffi
 
-# Operators
-from .base import *
-from .binary import *
+tvm._ffi._init_api("relax.op.builtin", __name__)
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
new file mode 100644
index 0000000000..0afe6a42d0
--- /dev/null
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The builtin Relax operators."""
+
+from ...expr import Call, Expr
+from ...utils import args_converter
+from . import _ffi_api
+
+
+@args_converter.auto
+def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
+    """Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index.
+
+    Parameters
+    ----------
+    shape : Expr
+        The shape of the tensor to be allocated.
+
+    dtype : str
+        The datatype of the tensor to be allocated.
+
+    runtime_device_index : int
+        The device index indicating on which device the tensor is to be allocated at runtime.
+        Index -1 is reserved for the host device.
+
+    Returns
+    -------
+    result : Call
+        A relax Call, which gets the allocated tensor.
+    """
+    return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index)  # type: ignore
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
new file mode 100644
index 0000000000..fa9c815225
--- /dev/null
+++ b/python/tvm/relax/op/manipulate.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Manipulation operators."""
+from typing import Tuple, Union
+
+from tvm.ir.expr import PrimExpr
+
+
+from . import _ffi_api
+from ..expr import Expr
+
+
+PrimExprLike = Union[int, PrimExpr]
+
+
+def reshape(x: Expr, shape: Union[Tuple[PrimExprLike, ...], Expr]) -> Expr:
+    """Reshape the input array.
+
+    ``-1`` infers the dimension of the output shape by using the remainder of
+    the input dimensions, keeping the size of the new array the same as that
+    of the input array. At most one dimension of shape can be -1.
+
+        .. code-block:: python
+
+            x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4)
+            x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8)
+            x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,)
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data to the operator.
+
+    shape : Union[Tuple[PrimExprLike, ...], Expr]
+        The new shape. Should be compatible with the original shape.
+
+    Returns
+    -------
+    result : relax.Expr
+        The reshaped result.
+
+    Notes
+    -----
+    The ``-1`` inference is only performed at compile-time.
+    That is to say, whenever the dimension length of ``-1`` cannot be
+    inferred at compile-time, an error will be thrown.
+    """
+    return _ffi_api.reshape(x, shape)  # type: ignore
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/memory/__init__.py
index 101b0827d6..e039590251 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -15,8 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Relax memory primitives."""
 
-# Operators
-from .base import *
-from .binary import *
+from .memory import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/_ffi_api.py
similarity index 83%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/memory/_ffi_api.py
index 101b0827d6..475de481b2 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/memory/_ffi_api.py
@@ -13,10 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""FFI APIs for tvm.relax.op.memory"""
+import tvm._ffi
 
-# Operators
-from .base import *
-from .binary import *
+tvm._ffi._init_api("relax.op.memory", __name__)
diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py
new file mode 100644
index 0000000000..b58b987d2a
--- /dev/null
+++ b/python/tvm/relax/op/memory/memory.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax memory primitives."""
+
+from . import _ffi_api
+from ...expr import Expr, Call
+from ...utils import args_converter
+
+
+@args_converter.auto
+def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, dtype: str) -> Call:
+    """Construct a Call to allocate a storage with specific size, virtual_device_index,
+    storage_scope and dtype.
+
+    Parameters
+    ----------
+    size : Expr
+        The size of the storage to be allocated.
+
+    virtual_device_index : int
+        The virtual device index indicating on which device the storage is to be allocated.
+        Index -1 is reserved for the host device.
+
+    storage_scope : str
+        The storage scope to allocate the storage to.
+
+    dtype : str
+        The datatype of the storage to be allocated.
+
+    Returns
+    -------
+    result : Call
+        A relax Call, which gets the allocated storage.
+    """
+    return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype)  # type: ignore
+
+
+@args_converter.auto
+def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call:
+    """Construct a Call to allocate a tensor on a certain storage starting from the given offset.
+
+    Parameters
+    ----------
+    storage : Expr
+        The storage to allocate the tensor to.
+
+    offset : int
+        The storage offset to allocate the tensor.
+
+    shape : Expr
+        The shape of the tensor to be allocated.
+
+    dtype : str
+        The datatype of the tensor to be allocated.
+
+    Returns
+    -------
+    result : Call
+        A relax Call, which gets the allocated tensor.
+    """
+    return _ffi_api.alloc_tensor(storage, offset, shape, dtype)  # type: ignore
+
+
+@args_converter.auto
+def kill_storage(storage: Expr) -> Call:
+    """Build a relax Call that kills the given storage.
+
+    Parameters
+    ----------
+    storage : Expr
+        The storage that the Call kills.
+
+    Returns
+    -------
+    result : Call
+        The relax Call performing the kill.
+    """
+    kill_call = _ffi_api.kill_storage(storage)  # type: ignore
+    return kill_call
+
+
+@args_converter.auto
+def kill_tensor(tensor: Expr) -> Call:
+    """Build a relax Call that kills the given tensor.
+
+    Parameters
+    ----------
+    tensor : Expr
+        The tensor that the Call kills.
+
+    Returns
+    -------
+    result : Call
+        The relax Call performing the kill.
+    """
+    kill_call = _ffi_api.kill_tensor(tensor)  # type: ignore
+    return kill_call
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/testing/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/testing/__init__.py
index 101b0827d6..ab1dd6f515 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -15,8 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""The Relax testing namespace containing nn and translator."""
 
-# Operators
-from .base import *
-from .binary import *
+from .nn import *
diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py
new file mode 100644
index 0000000000..830ddd779f
--- /dev/null
+++ b/python/tvm/relax/testing/nn.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""PyTorch-like nn.Module API for constructing workloads."""
+
+
+from typing import List, Any, Callable, Union
+import typing
+import numpy as np  # type: ignore
+
+import tvm
+from tvm import relax, topi, tir
+
+
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
+    """Emit a TE computation through the BlockBuilder currently in scope.
+
+    All positional and keyword arguments are forwarded unchanged to
+    ``BlockBuilder.emit_te``, and its result is returned.
+    """
+    builder = relax.BlockBuilder.current()
+    return builder.emit_te(func, *args, **kwargs)
+
+
+class Placeholder(relax.Var):
+    """A placeholder variable that can represent model input."""
+
+    def __init__(
+        self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="data"
+    ):
+        """Create the placeholder Var with a tensor struct info.
+
+        Parameters
+        ----------
+        shape : Union[List[Any], Tuple[Any, ...]]
+            The tensor shape; must be a list or a tuple.
+        dtype : str
+            The tensor dtype.
+        name : str
+            Name hint, made unique through the current BlockBuilder.
+
+        Raises
+        ------
+        TypeError
+            If ``shape`` is not a list or a tuple.
+        """
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Placeholder is expected to be a list or a tuple")
+        super().__init__(
+            relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+        )
+
+
+class Parameter(relax.Var):
+    """A special kind of relax Var that represents model parameter(weight)."""
+
+    def __init__(
+        self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="param"
+    ):
+        """Create the parameter Var with a tensor struct info.
+
+        Parameters
+        ----------
+        shape : Union[List[Any], Tuple[Any, ...]]
+            The tensor shape; must be a list or a tuple.
+        dtype : str
+            The tensor dtype.
+        name : str
+            Name hint, made unique through the current BlockBuilder.
+
+        Raises
+        ------
+        TypeError
+            If ``shape`` is not a list or a tuple.
+        """
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Parameter is expected to be a list or a tuple")
+        super().__init__(
+            relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+        )
+
+
+class Module:
+    """Base class for all model modules.
+
+    A neural network or a layer can subclass this class. Subclasses implement
+    ``forward``; calling a module instance forwards all arguments to it.
+
+    Example
+    -------
+    .. code-block:: python
+
+        # Define a linear layer
+        class Linear(Module):
+            def __init__(self, in_features, out_features, bias=True):
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = Parameter((in_features, out_features), name="linear_weight")
+                if bias:
+                    self.bias = Parameter((out_features,), name="linear_bias")
+                else:
+                    self.bias = None
+
+            # All submodules should implement forward.
+            # Defines the forward computation performed at every call.
+            def forward(self, input: relax.Expr) -> relax.Var:
+                y = emit_te(topi.matmul, input, self.weight)
+                if self.bias is not None:
+                    y = emit_te(topi.add, y, self.bias)
+                return y
+    """
+
+    def parameters(self) -> List[Parameter]:
+        """Return the list of parameters in the module.
+
+        The parameters are collected recursively from the instance attributes,
+        including nested modules and lists/tuples/dicts of them.
+        """
+        return _unpack_params(self.__dict__)
+
+    def forward(self, input: relax.Expr):
+        """Define the computation performed at every call.
+
+        Subclasses must override this; the base implementation raises NotImplementedError.
+        """
+        raise NotImplementedError()
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+
+def _unpack_params(value: object) -> List[relax.Var]:
+    """Recursively collect the Parameters reachable from ``value``.
+
+    Parameters are found directly, inside nested Modules, and inside dict
+    values, lists and tuples. ``None`` and ``int``/``float``/``str`` values
+    contribute nothing.
+
+    Raises
+    ------
+    TypeError
+        If ``value`` has a type that cannot be unpacked.
+    """
+    if isinstance(value, Parameter):
+        return [value]
+    if isinstance(value, Module):
+        return value.parameters()
+    if isinstance(value, dict):
+        params = []
+        for v in value.values():
+            params += _unpack_params(v)
+        return params
+    if isinstance(value, (list, tuple)):
+        params = []
+        for v in value:
+            params += _unpack_params(v)
+        return params
+    if value is None or isinstance(value, (int, float, str)):
+        return []
+    raise TypeError("not supported type when unpacking parameters: {}".format(type(value)))
+
+
+def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
+    """Utility function to initialize model's parameters.
+
+    Every parameter of ``mod["main"]`` whose name does not start with ``"data"``
+    is initialized to a float32 array of zeros with the parameter's static shape.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module whose ``main`` function parameters are initialized.
+
+    Returns
+    -------
+    params : List[tvm.nd.array]
+        The zero-initialized arrays, in the order of ``main``'s parameters.
+
+    Raises
+    ------
+    TypeError
+        If a parameter's shape is not a ShapeExpr made of integer constants.
+    """
+    shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params}
+    params = []
+    for k, v in shape_dict.items():
+        # Model inputs (e.g. Placeholder, whose default name is "data") are not parameters.
+        if k.startswith("data"):
+            continue
+        if isinstance(v, relax.ShapeExpr):
+            shape = []
+            for i in v:
+                if isinstance(i, tir.IntImm):
+                    shape.append(int(i))
+                else:
+                    raise TypeError("cannot initialize for unknown-shape parameters.")
+            # NOTE(review): always float32, regardless of the parameter's declared dtype.
+            params.append(tvm.nd.array(np.zeros(shape).astype(np.float32)))
+        else:
+            raise TypeError("cannot initialize for unknown-shape parameters.")
+    return params
+
+
+class Sequential(Module):
+    """A sequential container that concatenates modules in it.
+
+    The output of each module is fed as the input of the next one, in the
+    order the modules were given.
+
+    Example
+    -------
+    .. code-block:: python
+
+        model = nn.Sequential(
+                    nn.Linear(784, 128),
+                    nn.ReLU(),
+                    nn.Linear(128, 10),
+                    nn.LogSoftmax()
+                )
+    """
+
+    def __init__(self, *modules: Module):
+        self.modules = modules
+
+    def forward(self, input: relax.Expr) -> relax.Var:
+        for module in self.modules:
+            input = module(input)
+        return input
+
+
+class ReLU(Module):
+    """Rectified linear unit activation, emitted via ``topi.nn.relu``."""
+
+    def forward(self, input: relax.Expr) -> relax.Var:
+        activated = emit_te(topi.nn.relu, input)
+        return activated
+
+
+class LogSoftmax(Module):
+    """Log-softmax activation, emitted via ``topi.nn.log_softmax``."""
+
+    def forward(self, input: relax.Expr) -> relax.Var:
+        activated = emit_te(topi.nn.log_softmax, input)
+        return activated
+
+
+class Linear(Module):
+    """Applies a linear transformation to the input data: :math:`y = xA + b`.
+
+    Parameters
+    ----------
+    in_features
+        Number of input features; the first dimension of ``weight``.
+    out_features
+        Number of output features; the second dimension of ``weight``.
+    bias : bool
+        Whether to create and add a bias Parameter of shape ``(out_features,)``.
+    """
+
+    def __init__(self, in_features, out_features, bias=True):
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter((in_features, out_features), name="linear_weight")
+        if bias:
+            self.bias = Parameter((out_features,), name="linear_bias")
+        else:
+            self.bias = None
+
+    def forward(self, input: relax.Expr) -> relax.Var:
+        y = emit_te(topi.matmul, input, self.weight)
+        if self.bias is not None:
+            y = emit_te(topi.add, y, self.bias)
+        return y
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index f20f06c522..cab18797c6 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -37,6 +37,49 @@ class DataflowBlockPass(tvm.ir.transform.Pass):
     """A pass that works on each tvm.relax.DataflowBlock in a module."""
 
 
+def ToNonDataflow() -> tvm.ir.transform.Pass:
+    """Create a pass that turns every dataflow structure into its non-dataflow form.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The pass object.
+    """
+    return _ffi_api.ToNonDataflow()  # type: ignore
+
+
+def CallTIRRewrite() -> tvm.ir.transform.Pass:
+    """Create a pass that makes the output tensor allocation of call_tir explicit.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The pass object.
+    """
+    return _ffi_api.CallTIRRewrite()  # type: ignore
+
+
+def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
+    """Convert all reshape-like call_tir whose corresponding binding vars are
+    DataflowVars to relax.reshape operator calls. The relax.reshape calls will
+    be lowered to an external builtin function call in a subsequent pass, where
+    the external builtin function does a CreateView operation at runtime,
+    instead of doing a real data copy.
+    Here "reshape-like" includes reshape, expand_dims, flatten, etc.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+
+    Notes
+    -----
+    The pass is applied at the first stage of Relax VM build, before rewriting
+    call_tir, as this pass requires dataflow information.
+    """
+    return _ffi_api.RewriteDataflowReshape()  # type: ignore
+
+
+def VMBuiltinLower() -> tvm.ir.transform.Pass:
+    """Create a pass that lowers generic intrinsics to VM intrinsics.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The pass object.
+    """
+    return _ffi_api.VMBuiltinLower()  # type: ignore
+
+
 def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
     """Lower the symbolic shape and argument and match-cast structinfo matching.
 
@@ -52,6 +95,16 @@ def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
     return _ffi_api.VMShapeLower(emit_err_ctx)  # type: ignore
 
 
+def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
+    """Create a pass that attaches global_symbol to Relax functions and TIR PrimFuncs
+    so they can be referenced during codegen.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The pass object.
+    """
+    return _ffi_api.AttachGlobalSymbol()  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py
index ba16dfb079..ff6bf816b6 100644
--- a/python/tvm/relax/vm.py
+++ b/python/tvm/relax/vm.py
@@ -581,7 +581,9 @@ def build(
     if isinstance(target, str):
         target = tvm.target.Target(target)
 
-    passes = [relax.transform.ToNonDataflow()]
+    passes = []
+    passes.append(relax.transform.RewriteDataflowReshape())
+    passes.append(relax.transform.ToNonDataflow())
     passes.append(relax.transform.CallTIRRewrite())
     passes.append(relax.transform.VMBuiltinLower())
     passes.append(relax.transform.VMShapeLower())
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 647ef8f25a..0692ec5683 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -31,13 +31,16 @@ from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
 from tvm.relax.op import (
     add,
     assert_op,
+    builtin,
     call_builtin_with_ctx,
     call_tir,
     invoke_closure,
     make_closure,
+    memory,
     multiply,
     null_value,
     print,
+    reshape,
     shape_of,
 )
 from tvm.relax.struct_info import StructInfo
@@ -381,6 +384,7 @@ __all__ = [
     "add",
     "arg",
     "assert_op",
+    "builtin",
     "call_packed",
     "call_tir",
     "call_builtin_with_ctx",
@@ -396,11 +400,13 @@ __all__ = [
     "function",
     "invoke_closure",
     "make_closure",
+    "memory",
     "multiply",
     "null_value",
     "output",
     "prim_value",
     "print",
+    "reshape",
     "shape_of",
     "str",
     "tuple",
diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
new file mode 100644
index 0000000000..b7ac8faddd
--- /dev/null
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+
+/*!
+ * \brief Infers the relay::OpPatternKind of a TIR PrimFunc.
+ *
+ * Every non-root block is classified from the relation between the indices of its single
+ * BufferStore and the indices of each BufferLoad it contains (elemwise / broadcast /
+ * injective). If that relation is opaque, the block's reduction iterators are inspected
+ * instead (FMA / pooling-like / pure reduce). Per-block results are merged into `kind_`,
+ * which GetResult() returns.
+ */
+class PatternKindAnalyzer : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Collect the buffers bound to the PrimFunc parameters. A block writing to any of
+   * them is treated as an output block (see IsOutputBlock).
+   */
+  explicit PatternKindAnalyzer(const tir::PrimFunc& func) {
+    for (const tir::Var& param : func->params) {
+      Optional<Buffer> param_buf = func->buffer_map.Get(param);
+      if (param_buf.defined()) {
+        param_buffers_.insert(param_buf.value());
+      }
+    }
+  }
+
+  /*! \brief The op pattern kind accumulated over all visited blocks. */
+  relay::OpPatternKind GetResult() { return kind_; }
+
+ private:
+  /*! \brief Whether the block writes to any buffer bound to a PrimFunc parameter. */
+  bool IsOutputBlock(const BlockNode* block) {
+    for (const BufferRegion& write_region : block->writes) {
+      if (param_buffers_.count(write_region->buffer)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    // We only support one buffer store in a block (usually generated by TE compute).
+    // If we have already seen a buffer store in the current block, classify as Opaque.
+    if (store_.defined()) {
+      kind_ = relay::kOpaque;
+      return;
+    }
+    store_ = GetRef<BufferStore>(op);
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    loads_.push_back(GetRef<BufferLoad>(op));
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (op->name_hint == "root") {
+      // Skip the root block
+      StmtVisitor::VisitStmt(op->body);
+      return;
+    }
+
+    // Step 1. Clear loads and store
+    loads_.clear();
+    store_ = NullOpt;
+    // Step 2. Visit block body.
+    StmtVisitor::VisitStmt(op->body);
+    // A block without any BufferStore (e.g. one whose body only evaluates an extern call)
+    // has no store/load index relation to analyze. Classify it as Opaque instead of failing
+    // on `store_.value()`.
+    if (!store_.defined()) {
+      kind_ = relay::kOpaque;
+      return;
+    }
+    BufferStore store = store_.value();
+
+    // Step 3. Checking load store indices pattern
+    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
+    bool has_elem_wise = false;
+    for (const BufferLoad& load : loads_) {
+      // Elemwise is stricter than broadcast, and broadcast is stricter than injective,
+      // while the order among the enums is kElemWise < kBroadcast < kInjective.
+      // So we can simply use `std::max` to combine these three patterns.
+      // E.g. in C[i, j] = A[i, j] + B[i] there is one store node and two load nodes:
+      // C and A are elemwise while C and B are broadcast, so this loop yields kBroadcast
+      // (the has_elem_wise rule below then turns such a mixed block into kElemWise).
+      if (IsElemwisePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
+        has_elem_wise = true;
+      } else if (IsBroadcastPattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
+      } else if (IsInjectivePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
+      } else {
+        index_pair_pattern = relay::kOpaque;
+        break;
+      }
+    }
+    // If one index pair is kElemWise and the others are kBroadcast, we regard it as kElemWise,
+    // e.g. A[i, j] = B[i, j] + C[i]
+    if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
+      index_pair_pattern = relay::kElemWise;
+    }
+    // If the block index pattern is not opaque, update kind.
+    if (index_pair_pattern != relay::kOpaque) {
+      // This rule is for softmax: reduce + injective.
+      if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
+        kind_ = relay::kOutEWiseFusable;
+      } else {
+        kind_ = std::max(kind_, index_pair_pattern);
+      }
+      return;
+    }
+
+    // Step 4. Checking if the block contains a reduce axis by looking into block iterators.
+    bool has_reduction = false;
+    Array<tir::Var> reduce_vars;
+    for (const IterVar& it : op->iter_vars) {
+      if (it->iter_type == kCommReduce) {
+        has_reduction = true;
+        reduce_vars.push_back(it->var);
+      }
+    }
+
+    if (has_reduction) {
+      if (IsFMA(op->body)) {
+        // FMA is regarded as kOutEWiseFusable, e.g. Matmul or Conv.
+        kind_ = std::max(kind_, relay::kOutEWiseFusable);
+        return;
+      }
+      for (const BufferLoad& load : loads_) {
+        // If it's not a pure reduce, it is regarded as kOutEWiseFusable.
+        // This rule works for pooling for now.
+        if (!IsPureReducePattern(reduce_vars, load->indices)) {
+          kind_ = std::max(kind_, relay::kOutEWiseFusable);
+          return;
+        }
+      }
+      kind_ = std::max(kind_, relay::kCommReduce);
+    } else {
+      kind_ = relay::kOpaque;
+    }
+  }
+
+  /********** Helper Functions **********/
+
+  /*! \brief Checking if two arrays contain the same elements (by object identity). */
+  static bool IsSameArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!lhs[i].same_as(rhs[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Checking whether the load indices and store indices follow the elemwise pattern.
+   * It's the elemwise pattern iff load indices and store indices are the same.
+   * E.g. A[i, j] = B[i, j]
+   */
+  static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
+    return IsSameArray(store->indices, load->indices);
+  }
+
+  /*!
+   * \brief Checking whether the load indices and store indices follow the broadcast pattern.
+   * It's the broadcast pattern iff all load indices appear in the store indices, in order,
+   * each store index being used at most once.
+   * E.g. A[i, j] = B[i] is broadcast since all load indices (`i`) are in the store indices.
+   *      A[i, j] = B[i, k] is not broadcast since `k` is not in the store indices.
+   *      A[i, j] = B[j, i] is not broadcast since the load indices are not in the same order
+   *      as the store's.
+   *      A[i] = B[i, i] is not broadcast since the single store index `i` would be reused.
+   */
+  static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) {
+    size_t ndim_load_buf = load->buffer->shape.size();
+    size_t ndim_store_buf = store->buffer->shape.size();
+
+    for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) {
+      if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) {
+        // Skip unit load dimensions
+        // E.g. A[i, j] = B[0, j] (with B's first extent being 1) is still broadcast
+        continue;
+      }
+
+      // Try to find the i-th load index in the store indices.
+      while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) {
+        ++j;
+      }
+
+      // It's not broadcast if we cannot find load indices in the store indices in order.
+      if (j == ndim_store_buf) {
+        return false;
+      }
+      // Consume the matched store index so it cannot be matched by a later load index.
+      ++j;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Checking whether the load indices and store indices follow the injective pattern.
+   * It's the injective pattern iff all vars used in the load indices are in the store indices,
+   * no matter the order.
+   * Note that we only support store indices that are direct vars so far; this can be
+   * enhanced later.
+   * E.g. A[i, j] = B[j, i] is injective.
+   *      A[i, j] = B[i - j] is injective since the load index vars are only i, j
+   */
+  static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) {
+    std::unordered_set<const tir::VarNode*> vars;
+    for (const PrimExpr& store_index : store->indices) {
+      if (const auto* v = store_index.as<tir::VarNode>()) {
+        vars.insert(v);
+      } else {
+        return false;
+      }
+    }
+    for (const PrimExpr& load_index : load->indices) {
+      // return false if there are vars used in load indices but not in store indices.
+      if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Checking whether the load indices and store indices allow data reuse.
+   * Returns true iff every store index is a plain var and at least one store var is not
+   * used anywhere in the load indices.
+   * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse.
+   *      Store = A[i, j] and Load = B[i, j + k] allow data reuse.
+   */
+  static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) {
+    std::unordered_set<const tir::VarNode*> vars;
+    for (const PrimExpr& index : store->indices) {
+      if (const auto* v = index.as<tir::VarNode>()) {
+        vars.insert(v);
+      } else {
+        return false;
+      }
+    }
+    for (const PrimExpr& index : load->indices) {
+      PreOrderVisit(index, [&](const ObjectRef& node) {
+        if (const auto* v = node.as<tir::VarNode>()) {
+          vars.erase(v);
+        }
+        return true;
+      });
+    }
+    return !vars.empty();
+  }
+
+  /*! \brief Checking if the stmt is multiply add. E.g. C[i, j] += A[i, k] * B[j, k] */
+  static bool IsFMA(const Stmt& body) {
+    if (const auto* store = body.as<BufferStoreNode>()) {
+      if (const auto* add = store->value.as<AddNode>()) {
+        if (const auto* l = add->a.as<BufferLoadNode>()) {
+          if (const auto* r = add->b.as<MulNode>()) {
+            bool incremental =
+                store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices);
+            const auto* l_load = r->a.as<BufferLoadNode>();
+            const auto* r_load = r->b.as<BufferLoadNode>();
+            if (incremental && l_load && r_load) {
+              return IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(l_load)) &&
+                     IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(r_load));
+            }
+          }
+        }
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Checking if it is the pure reduce pattern.
+   * It's the pure reduce pattern iff every index that uses a reduce var is exactly that var.
+   * E.g. A[i] = sum(B[i, j]) is pure reduce
+   *      A[i] = sum(B[i, j + k]) is not pure reduce
+   *      pooling is not pure reduce
+   */
+  static bool IsPureReducePattern(const Array<tir::Var>& reduce_loops,
+                                  const Array<PrimExpr>& indices) {
+    for (const PrimExpr& e : indices) {
+      int id = -1;
+      if (UsesVar(e, [&](const tir::VarNode* var) {
+            for (size_t i = 0; i < reduce_loops.size(); ++i) {
+              if (reduce_loops[i].get() == var) {
+                id = static_cast<int>(i);
+                return true;
+              }
+            }
+            return false;
+          })) {
+        if (!reduce_loops[id].same_as(e)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief The BufferStore node in the current block.
+   * \note We only support one BufferStore node in a block (usually generated by TE compute)
+   */
+  Optional<BufferStore> store_;
+  /*! \brief The BufferLoad nodes in the current block. */
+  Array<BufferLoad> loads_;
+  /*! \brief The result of op pattern. */
+  relay::OpPatternKind kind_ = relay::kElemWise;
+  /*! \brief The buffers from function params. I.e. the input and output buffers. */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_;
+};
+
+/*!
+ * \brief Classify `func` by running PatternKindAnalyzer over its body.
+ * \return The op pattern kind accumulated by the analyzer.
+ */
+relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) {
+  PatternKindAnalyzer pattern_analyzer(func);
+  const Stmt& body = func->body;
+  pattern_analyzer(body);
+  relay::OpPatternKind result = pattern_analyzer.GetResult();
+  return result;
+}
+
+/*!
+ * \brief Check whether `func` copies its first parameter buffer into its last parameter
+ * buffer following a reshape pattern.
+ *
+ * The function is accepted only when all of the following hold:
+ * (1) it has at least two params, and the first and last are both bound to buffers;
+ * (2) the element counts of those two buffers are provably equal;
+ * (3) its body is a root BlockRealize leading, through a chain of For loops and blocks with
+ *     only data-parallel iterators, to a block whose body is `dst[...] = src[...]` with
+ *     provably equal row-major flattened indices.
+ * Any other structure is conservatively reported as "not a reshape".
+ */
+bool HasReshapePattern(const PrimFunc& func) {
+  class ReshapeDetector : public StmtVisitor {
+   public:
+    static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) {
+      ReshapeDetector detector(src_buffer, dst_buffer);
+      detector(stmt);
+      return detector.is_reshape_;
+    }
+
+   private:
+    explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer)
+        : src_buffer_(src_buffer), dst_buffer_(dst_buffer) {}
+
+    void VisitStmt_(const ForNode* loop) final {
+      ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+      // To detect the reshape pattern, we require each For to have
+      // either another For or a BlockRealize as body.
+      if (!(loop->body->IsInstance<ForNode>() || loop->body->IsInstance<BlockRealizeNode>())) {
+        return;
+      }
+      this->VisitStmt(loop->body);
+    }
+
+    void VisitStmt_(const BlockRealizeNode* block_realize) final {
+      // Constructing the mapping from block iterators to iterator
+      // binding values. The mapping will be used in the substitution of
+      // the flattened buffer access index.
+      const Block& block = block_realize->block;
+      const Array<IterVar>& block_iter = block->iter_vars;
+      const Array<PrimExpr>& iter_values = block_realize->iter_values;
+      ICHECK_EQ(block_iter.size(), iter_values.size());
+      int n_iter = block_iter.size();
+      for (int i = 0; i < n_iter; ++i) {
+        // To detect the reshape pattern, we require each block iter to be data-parallel.
+        if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) {
+          return;
+        }
+        var_map_.Set(block_iter[i]->var, iter_values[i]);
+      }
+
+      // Recurse into the block.
+      this->VisitStmt(block);
+    }
+
+    void VisitStmt_(const BlockNode* block) final {
+      // Step 1. If the block body is a ForNode, recurse into it.
+      if (block->body->IsInstance<ForNode>()) {
+        this->VisitStmt(block->body);
+        return;
+      }
+
+      // Step 2. Get the load/store pattern of the block body.
+      // To detect the reshape pattern, we require the block body to be a
+      // BufferStore, which has a BufferLoad as value.
+      const auto* buffer_store = block->body.as<BufferStoreNode>();
+      if (buffer_store == nullptr) {
+        return;
+      }
+      const auto* buffer_load = buffer_store->value.as<BufferLoadNode>();
+      if (buffer_load == nullptr) {
+        return;
+      }
+      // Further, we require the buffer being stored and being loaded to
+      // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`.
+      if (!(buffer_store->buffer.same_as(dst_buffer_) &&
+            buffer_load->buffer.same_as(src_buffer_))) {
+        return;
+      }
+
+      // Step 3. Calculate the row-major flattened access index of the load and the store.
+      auto f_calc_flattened_idx = [](const Buffer& buffer, const Array<PrimExpr>& indices) {
+        ICHECK_EQ(indices.size(), buffer->shape.size());
+        int ndim = indices.size();
+        PrimExpr idx = 0;
+        for (int i = 0; i < ndim; ++i) {
+          idx = idx * buffer->shape[i] + indices[i];
+        }
+        return idx;
+      };
+      PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices);
+      PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices);
+
+      // Step 4. Substitute the block iterators in the flattened index
+      // with loop variables, and check if we can prove their equality.
+      src_idx = tir::Substitute(std::move(src_idx), var_map_);
+      dst_idx = tir::Substitute(std::move(dst_idx), var_map_);
+      if (ana_.CanProveEqual(src_idx, dst_idx)) {
+        this->is_reshape_ = true;
+      }
+    }
+
+    bool is_reshape_{false};
+    /*! \brief The mapping from block vars to block binding values. */
+    Map<tir::Var, PrimExpr> var_map_;
+    // Held by value: Buffer is a reference-counted handle, so this is cheap and avoids
+    // tying the detector's lifetime to the caller's temporaries.
+    Buffer src_buffer_;
+    Buffer dst_buffer_;
+    arith::Analyzer ana_;
+  };
+
+  if (func->params.size() < 2) {
+    return false;
+  }
+  Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
+  Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
+  if (!(src_buffer.defined() && dst_buffer.defined())) {
+    return false;
+  }
+
+  // A reshape must preserve the number of elements. Without this check a prefix copy such as
+  // dst[i, j] = src[i, j] with src of shape (4, 4) and dst of shape (2, 4) has equal flattened
+  // offsets for every visited (i, j) and would be misreported as a reshape.
+  auto f_num_elements = [](const Buffer& buffer) {
+    PrimExpr num_elements = 1;
+    for (const PrimExpr& dim : buffer->shape) {
+      num_elements = num_elements * dim;
+    }
+    return num_elements;
+  };
+  arith::Analyzer analyzer;
+  if (!analyzer.CanProveEqual(f_num_elements(src_buffer.value()),
+                              f_num_elements(dst_buffer.value()))) {
+    return false;
+  }
+
+  // The detector starts from the root BlockRealize. PrimFuncs without one (e.g. already
+  // lowered TIR) are conservatively reported as not being a reshape instead of aborting.
+  if (!func->body->IsInstance<BlockRealizeNode>()) {
+    return false;
+  }
+  return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
new file mode 100644
index 0000000000..6613b39626
--- /dev/null
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \brief Lowers most builtin functions and packed calls.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/type.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+// This pass lowers most ops to VM specific builtins.
+// TODO(relax-team): revisit after PrimValue.
+class VMBuiltinLowerMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  // A workaround to remove the CallNodes of killing tensors and storages: a binding whose
+  // value calls relax.memory.kill_storage or relax.memory.kill_tensor is dropped entirely.
+  void VisitBinding_(const VarBindingNode* binding) final {
+    const auto* call = binding->value.as<CallNode>();
+    if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_)) {
+      return;
+    }
+    ExprMutator::VisitBinding_(binding);
+  }
+
+  /*!
+   * \brief Mutate the arguments first (post-order), then rewrite the call according to its
+   * callee op. Calls to ops not handled here are returned after argument mutation only.
+   */
+  Expr VisitExpr_(const CallNode* call_node) final {
+    // post-order mutation
+    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+
+    if (call->op == call_tir_dyn_op_) {
+      return CallTIRDyn(call);
+    } else if (call->op == reshape_op_) {
+      return Reshape(call);
+    } else if (call->op == make_closure_op_) {
+      return MakeClosure(call);
+    } else if (call->op == invoke_closure_op_) {
+      return InvokeClosure(call);
+    } else if (call->op == alloc_tensor_op_) {
+      return MakeAllocTensor(call);
+    } else if (call->op == mem_alloc_storage_op_) {
+      return MakeMemAllocStorage(call);
+    } else if (call->op == mem_alloc_tensor_op_) {
+      return MakeMemAllocTensor(call);
+    } else {
+      return call;
+    }
+  }
+
+  /*!
+   * \brief Compute the storage size needed for a tensor of `shape` and `dtype`.
+   *
+   * For a ShapeExpr the size is folded into a one-element ShapeExpr holding
+   * elem_bytes * prod(dims). Any other shape expression is deferred to the
+   * "vm.builtin.compute_alloc_shape" packed function.
+   */
+  Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
+    // TODO(relax-team): what if the dtype of tensor_type is unknown?
+    // Symbolic/static shape case
+    if (auto* shape_expr = shape.as<ShapeExprNode>()) {
+      int64_t elem_bytes = runtime::GetVectorBytes(dtype);
+      PrimExpr ret = IntImm(DataType::Int(64), elem_bytes);
+      for (PrimExpr dim : shape_expr->values) {
+        ret = ret * dim;
+      }
+      return ShapeExpr({ret});
+    } else {
+      return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(),
+                  {GetStructInfo(shape)});
+    }
+  }
+
+  /*!
+   * \brief Lower relax.builtin.alloc_tensor(shape, dtype, device_index) into a
+   * relax.vm.alloc_storage (emitted as a separate "storage" binding) followed by a
+   * relax.vm.alloc_tensor at offset 0 of that storage.
+   */
+  Expr MakeAllocTensor(const Call& call) {
+    // Keep the shape as a generic Expr: ComputeStorageSize handles both ShapeExpr and other
+    // shape-valued expressions, so requiring a ShapeExpr here (as a Downcast would) made its
+    // dynamic branch unreachable and aborted lowering for non-literal shapes.
+    Expr output_shape = call->args[0];
+    DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[1]);
+    DataType dtype = output_dtype->value;
+    Expr storage_size = ComputeStorageSize(output_shape, dtype);
+    PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
+    Var storage = builder_->Emit(
+        Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, output_dtype}, Attrs()),
+        "storage");
+    PrimValue offset = PrimValue::Int64(0);
+    return Call(vm_alloc_tensor_op_, {storage, offset, output_shape, DataTypeImm(dtype)},
+                Attrs());
+  }
+
+  /*!
+   * \brief Lower relax.memory.alloc_storage(size, device_index, scope, dtype) into
+   * relax.vm.alloc_storage(size, device_index, dtype); the scope argument is not forwarded.
+   */
+  Expr MakeMemAllocStorage(const Call& call) {
+    PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
+    DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[3]);
+    return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype},
+                Attrs());
+  }
+
+  /*! \brief Lower relax.memory.alloc_tensor into relax.vm.alloc_tensor with the same args. */
+  Expr MakeMemAllocTensor(const Call& call) {
+    PrimValue offset = Downcast<PrimValue>(call->args[1]);
+    DataTypeImm dtype = Downcast<DataTypeImm>(call->args[3]);
+    return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs());
+  }
+
+  /*!
+   * \brief Lower call_tir_dyn(func, (args...)) into a call to "vm.builtin.call_tir_dyn"
+   * with the tuple flattened: (func, args...).
+   */
+  Expr CallTIRDyn(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
+    ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+    Array<Expr> args;
+
+    auto tir_args = Downcast<Tuple>(call_node->args[1]);
+    args.push_back(call_node->args[0]);
+    for (Expr arg : tir_args->fields) {
+      args.push_back(arg);
+    }
+    return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_});
+  }
+
+  /*! \brief Lower relax.reshape into "vm.builtin.reshape", keeping the call's struct info. */
+  Expr Reshape(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->struct_info_.defined());
+    CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
+        << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr";
+    return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+  }
+
+  /*!
+   * \brief Lower make_closure(func, (captures...)) into "vm.builtin.make_closure" with the
+   * tuple flattened: (func, captures...).
+   */
+  Expr MakeClosure(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
+    ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+
+    Array<Expr> args;
+    auto func = call_node->args[0];
+    auto closure_args = Downcast<Tuple>(call_node->args[1]);
+
+    args.push_back(func);
+    for (Expr arg : closure_args->fields) {
+      args.push_back(arg);
+    }
+
+    return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_});
+  }
+
+  /*!
+   * \brief Lower invoke_closure(closure, (args...)) into
+   * call_builtin_with_ctx("vm.builtin.invoke_closure", (closure, args...)).
+   */
+  Expr InvokeClosure(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->args[0]->IsInstance<VarNode>());
+    ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+
+    Array<Expr> args;
+
+    args.push_back(call_node->args[0]);
+
+    // args for the invoke_closure
+    auto invoke_closure_args = Downcast<Tuple>(call_node->args[1]);
+    for (Expr arg : invoke_closure_args->fields) {
+      args.push_back(arg);
+    }
+    return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(),
+                {object_sinfo_});
+  }
+
+  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+  const StructInfo object_sinfo_ = ObjectStructInfo();
+  const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
+  // object to pattern match.
+  const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
+  const Op& reshape_op_ = Op::Get("relax.reshape");
+  const Op& make_closure_op_ = Op::Get("relax.make_closure");
+  const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
+  const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
+  const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage");
+  const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor");
+  const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage");
+  const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor");
+  // functions to lower to
+  const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+  const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  // Function to compute allocated shape.
+  const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
+  const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
+  const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
+  const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
+  const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
+};
+
+/*! \brief Apply VMBuiltinLowerMutator to `e` and return the rewritten expression. */
+Expr VMBuiltinLower(const Expr& e) {
+  VMBuiltinLowerMutator lowering;
+  return lowering.VisitExpr(e);
+}
+
+namespace transform {
+
+/*!
+ * \brief Function pass (opt level 0, no required passes) that applies relax::VMBuiltinLower
+ * to every Relax function of the module.
+ */
+Pass VMBuiltinLower() {
+  auto lower_function = [](Function f, IRModule m, PassContext pc) -> Function {
+    Expr lowered = relax::VMBuiltinLower(f);
+    return Downcast<Function>(lowered);
+  };
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = lower_function;
+  return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index ca66b0a9ef..ba167a45bc 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -226,6 +226,87 @@ Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) {
 
 
TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
 
+// memory planning alloc_storage
+
+RELAY_REGISTER_OP("relax.memory.alloc_storage")
+    .set_num_inputs(4)
+    .add_argument("total_space", "Expr", "The total space of the storage to allocate.")
+    .add_argument(
+        "virtual_device_index", "int64_t",
+        "The virtual device index indicating on which device the storage is to be allocated, "
+        "Index -1 is reserved for the host device.")
+    .add_argument("storage_scope", "string",
+                  "The storage scope of the storage to allocate. Default is global.")
+    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+/*!
+ * \brief Build a call to relax.memory.alloc_storage.
+ * \param size The total space of the storage.
+ * \param virtual_device_index Wrapped as an int64 PrimValue argument.
+ * \param storage_scope Wrapped as a StringImm argument.
+ * \param dtype Wrapped as a DataTypeImm argument.
+ */
+Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string storage_scope,
+                      DataType dtype) {
+  static const Op& op = Op::Get("relax.memory.alloc_storage");
+  PrimValue device_index = PrimValue::Int64(virtual_device_index);
+  StringImm scope(storage_scope);
+  DataTypeImm dtype_imm(dtype);
+  return Call(op, {size, device_index, scope, dtype_imm}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage);
+
+// memory planning alloc_tensor
+
+/*!
+ * \brief Infer the struct info of relax.memory.alloc_tensor(storage, offset, shape, dtype).
+ *
+ * The result is a TensorStructInfo whose shape is args[2]. Its dtype comes from args[3] when
+ * that argument is a DataTypeImm and is left as a default-constructed DataType otherwise.
+ */
+StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) {
+  // Report the argument that actually failed the check: the shape, args[2].
+  ICHECK(GetStructInfoAs<ShapeStructInfoNode>(call->args[2]))
+      << "The shape argument of relax.memory.alloc_tensor must be an Expr of ShapeStructInfo, "
+      << "but got " << call->args[2]->GetTypeKey();
+  DataType out_dtype;
+  if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
+    out_dtype = dtype_node->value;
+  }
+  return TensorStructInfo(call->args[2], out_dtype);
+}
+
+RELAY_REGISTER_OP("relax.memory.alloc_tensor")
+    .set_num_inputs(4)
+    .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
+    .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+    .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMemAllocTensor);
+
+// NOTE(review): `offset` is taken as `int` but stored as an int64 PrimValue; consider
+// int64_t if offsets may exceed INT_MAX (check for header declarations first).
+Expr MakeMemAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+  static const Op& op = Op::Get("relax.memory.alloc_tensor");
+  return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor);
+
+// memory planning kill_storage
+
+RELAY_REGISTER_OP("relax.memory.kill_storage")
+    .set_num_inputs(1)
+    .add_argument("storage", "Expr", "The storage to be killed.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+/*! \brief Build relax.memory.kill_storage(storage) with empty attrs and no sinfo args. */
+Expr MakeMemKillStorage(Expr storage) {
+  static const Op& op = Op::Get("relax.memory.kill_storage");
+  Array<Expr> call_args{storage};
+  return Call(op, call_args, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage);
+
+// memory planning kill_tensor
+
+RELAY_REGISTER_OP("relax.memory.kill_tensor")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Expr", "The tensor to be killed.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+/*! \brief Build relax.memory.kill_tensor(tensor) with empty attrs and no sinfo args. */
+Expr MakeMemKillTensor(Expr tensor) {
+  static const Op& op = Op::Get("relax.memory.kill_tensor");
+  Array<Expr> call_args{tensor};
+  return Call(op, call_args, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor);
+
 // vm alloc_storage
 
 RELAY_REGISTER_OP("relax.vm.alloc_storage")
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
new file mode 100644
index 0000000000..2088a8306e
--- /dev/null
+++ b/src/relax/op/tensor/manipulate.cc
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file manipulate.cc
+ * \brief Manipulation operators.
+ */
+
+#include "manipulate.h"
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Multiply all shape values together (helper for flatten and reshape).
+ * \param shape_values The dimension lengths to multiply.
+ * \return The product, seeded with an int64 `1`; an empty array therefore yields 1.
+ */
+PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
+  PrimExpr shape_prod = IntImm(DataType::Int(64), 1);
+  for (const PrimExpr& value : shape_values) {
+    shape_prod *= value;
+  }
+  return shape_prod;
+}
+
+/* relax.reshape */
+
+/*!
+ * \brief Normalize the user-provided new shape of reshape into a Relax expression.
+ * \param data The tensor being reshaped; its shape is consulted only when a `-1` is inferred.
+ * \param shape Either a Relax Expr (returned unchanged) or an Array of integer PrimExprs in which
+ * every constant is positive except for at most one `-1`.
+ * \return The new shape as an Expr; an Array input becomes a ShapeExpr with any `-1` replaced by
+ * (product of the input shape) floordiv (product of the other new-shape values).
+ */
+Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
+  if (const auto* e = shape.as<ExprNode>()) {
+    return GetRef<Expr>(e);
+  }
+
+  const auto* array = shape.as<ArrayNode>();
+  CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
+                             "Array of PrimExprs. However, the given new shape is "
+                          << shape;
+  int dim_to_infer = -1;
+  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  for (int i = 0; i < static_cast<int>(array->size()); ++i) {
+    const auto* _len = array->at(i).as<PrimExprNode>();
+    CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
+                              "Array of PrimExprs. However, the given new shape is "
+                           << shape;
+    PrimExpr len = GetRef<PrimExpr>(_len);
+    CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all "
+                                  "integers. However, the given new shape is "
+                               << shape;
+    const auto* int_len = len.as<IntImmNode>();
+    if (int_len != nullptr && int_len->value == -1) {
+      CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
+                                    "there are multiple \"-1\" in the given new shape "
+                                 << shape;
+      dim_to_infer = i;
+    } else {
+      CHECK(int_len == nullptr || int_len->value > 0)
+          << "Reshape requires all values in the new shape to be positive except a single \"-1\". "
+             "However, the given new shape is "
+          << shape;
+      // A symbolic length is never interpreted as the "-1" placeholder, so only constant
+      // lengths are validated above.
+      new_shape_prod = new_shape_prod * len;
+    }
+  }
+
+  Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
+  // When there is no dimension to infer, just return the input array as ShapeExpr.
+  if (dim_to_infer == -1) {
+    return ShapeExpr(array_ref);
+  }
+
+  // Otherwise, we require the input tensor to have known shape value for inference.
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(data);
+  CHECK(data_sinfo != nullptr)
+      << "Reshape expects the input data to be a Tensor. However, the given input is "
+      << data->struct_info_->GetTypeKey();
+  CHECK(data_sinfo->shape.defined())
+      << "Reshape expects the input tensor to have known shape when there is some dimension length "
+         "to infer. However, the given input has no shape.";
+  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+  CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined())
+      << "Reshape expects the input tensor to have known shape when there is some dimension length "
+         "to infer. However, the given input shape is "
+      << data_sinfo->shape << " whose shape value is unknown.";
+
+  arith::Analyzer analyzer;
+  PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
+  array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  return ShapeExpr(array_ref);
+}
+
+/*!
+ * \brief Construct a call to relax.reshape, normalizing `shape` via ConvertNewShapeToExpr.
+ */
+Expr reshape(Expr x, ObjectRef shape) {
+  Expr shape_in_expr = ConvertNewShapeToExpr(x, shape);
+  static const Op& op = Op::Get("relax.reshape");
+  return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape);
+
+/*!
+ * \brief Infer the struct info of a relax.reshape call.
+ * \return TensorStructInfo with the new-shape argument as shape and the input's dtype.
+ * Reports a fatal diagnostic for a wrong argument count, non-Tensor data, a non-Shape new shape,
+ * or shape products that the analyzer can prove to differ.
+ */
+StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments");
+  }
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* new_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << "Reshape requires the input data to be Tensor. However, the given one is "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (new_shape_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call->span)
+        << "Reshape requires the input new shape to be Shape. However, the given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  // The input shape values are only known when the tensor carries a shape expression.
+  Optional<Array<PrimExpr>> old_shape_values;
+  if (data_sinfo->shape.defined()) {
+    const auto* old_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+    ICHECK_NOTNULL(old_shape_sinfo);
+    old_shape_values = old_shape_sinfo->values;
+  }
+
+  // Only reject when the analyzer can prove the element counts differ; unknown or symbolic
+  // shapes that might match are accepted.
+  if (new_shape_sinfo->values.defined() && old_shape_values.defined()) {
+    PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value());
+    PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value());
+    if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) {
+      ctx->ReportFatal(Diagnostic::Error(call->span)
+                       << "Reshape expects the new shape to be convertible from the old shape. "
+                          "However, the old shape is "
+                       << data_sinfo->shape << ", with product " << old_shape_prod
+                       << ", while the new shape is " << call->args[1] << ", with product "
+                       << new_shape_prod);
+    }
+  }
+  return TensorStructInfo(call->args[1], data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.reshape")
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("shape", "Shape", "The input new shape.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/include/tvm/relax/backend.h b/src/relax/op/tensor/manipulate.h
similarity index 58%
copy from include/tvm/relax/backend.h
copy to src/relax/op/tensor/manipulate.h
index 4ebeacac0f..1a3eb0547d 100644
--- a/include/tvm/relax/backend.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -18,27 +18,28 @@
  */
 
 /*!
- * \file tvm/relax/backend.h
- * \brief Relax backend specific transformation passes.
+ * \file manipulate.h
+ * \brief The functions to make Relax tensor manipulation operator calls.
  */
-#ifndef TVM_RELAX_BACKEND_H_
-#define TVM_RELAX_BACKEND_H_
+#ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_
+#define TVM_RELAX_OP_TENSOR_MANIPULATE_H_
 
-#include <tvm/relax/transform.h>
+#include "../op_common.h"
 
 namespace tvm {
 namespace relax {
-namespace transform {
 
 /*!
- * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
- *
- * \return The Pass.
+ * \brief Reshape the input tensor, supporting inference of at most one `-1` dimension
+ * when the new shape is given as an Array of PrimExpr.
+ * \param x The input tensor to reshape.
+ * \param shape The new shape, whose element product must match that of `x`. It is either an
+ * Array of PrimExpr (positive values plus at most one `-1`) or a Shape expression in Relax.
+ * \return The reshaped result.
  */
-TVM_DLL Pass VMShapeLower();
+Expr reshape(Expr x, ObjectRef shape);
 
-}  // namespace transform
 }  // namespace relax
 }  // namespace tvm
 
-#endif  // TVM_RELAX_BACKEND_H_
+#endif  // TVM_RELAX_OP_TENSOR_MANIPULATE_H_
diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc
new file mode 100644
index 0000000000..be779e97bc
--- /dev/null
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/attach_global_symbol.cc
+ * \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.
+ */
+
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Sets the "global_symbol" attribute of every Relax function and TIR PrimFunc in a module
+ * to the name hint of the GlobalVar it is bound to.
+ */
+class GlobalSymbolAttacher {
+ public:
+  explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {}
+
+  /*!
+   * \brief Attach the symbols.
+   * \return The input module (copied on write) with every function replaced by its annotated
+   * version, so module-level fields such as attrs are preserved. Aborts via LOG(FATAL) on a
+   * function that is neither a TIR PrimFunc nor a Relax Function.
+   */
+  IRModule Attach() {
+    // Collect the rewritten functions first: updating the module while iterating over its
+    // function map would invalidate the iteration.
+    Map<GlobalVar, BaseFunc> updates;
+    for (const auto& p : mod_->functions) {
+      BaseFunc func = p.second;
+      if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
+        func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol", p.first->name_hint);
+      } else if (auto* relax_func = func.as<FunctionNode>()) {
+        func = WithAttr(GetRef<Function>(relax_func), "global_symbol", p.first->name_hint);
+      } else {
+        LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
+      }
+      updates.Set(p.first, func);
+    }
+    IRModuleNode* mod_node = mod_.CopyOnWrite();
+    for (const auto& p : updates) {
+      mod_node->Update(p.first, p.second);
+    }
+    return mod_;
+  }
+
+ private:
+  IRModule mod_;
+};
+
+namespace transform {
+
+Pass AttachGlobalSymbol() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return GlobalSymbolAttacher(mod).Attach(); };
+  return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
new file mode 100644
index 0000000000..2ea039e022
--- /dev/null
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/call_tir_rewrite.cc
+ * \brief Perform explicit tensor allocation for call_tir.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include "../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// CallTIRMutator
+// Perform explicit tensor allocation for call_tir.
+// Example:
+// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
+// -->
+// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32")
+// rx.call_packed(func, x, gv0)
+
+class CallTIRMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = VisitExprPostOrder_(call);
+    call = expr.as<CallNode>();
+
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
+
+    if (call->op != call_tir_op) {
+      return GetRef<Expr>(call);
+    }
+
+    // Emit one alloc_tensor binding per output described by the call's struct info.
+    Array<Expr> outs;
+    if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
+      // single output case
+      const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
+      ICHECK(tensor_sinfo->shape.defined())
+          << "the TensorStructInfo shape of call_tir has not populated";
+      outs.push_back(EmitAllocTensor(tensor_sinfo));
+    } else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) {
+      // multiple output case
+      const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value();
+      for (const StructInfo& field : tuple_sinfo->fields) {
+        ICHECK(field->IsInstance<TensorStructInfoNode>())
+            << "call_tir expects Tuple of TensorStructInfo, but got " << field
+            << " as an element of TupleStructInfo";
+        const auto& field_tensor = Downcast<TensorStructInfo>(field);
+        ICHECK(field_tensor->shape.defined())
+            << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor
+            << " as an element of TupleStructInfo";
+        outs.push_back(EmitAllocTensor(field_tensor));
+      }
+    } else {
+      LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or "
+                    "TupleStructInfo, but got "
+                 << expr->struct_info_;
+    }
+
+    // Call the PrimFunc in destination-passing style: inputs first, then the output buffers.
+    Array<Expr> args;
+    if (call->args[1].as<TupleNode>()) {
+      args = Downcast<Tuple>(call->args[1])->fields;
+      args.insert(args.end(), outs.begin(), outs.end());
+
+      if (call->args.size() == 2) {
+        builder_->Emit(Call(call->args[0], args), "_");
+      } else {
+        // unpack semantics
+        args.push_back(call->args[2]);
+        builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_");
+      }
+    } else {
+      // The optional third argument is only forwarded on the Tuple path above; fail loudly
+      // instead of silently dropping it here.
+      ICHECK(call->args.size() == 2)
+          << "call_tir with a non-Tuple argument does not support an extra third argument";
+      args = outs;
+      args.insert(args.begin(), call->args[1]);
+      builder_->Emit(Call(call->args[0], args), "_");
+    }
+
+    if (outs.size() == 1) {
+      return outs[0];
+    }
+    return Tuple(outs);
+  }
+
+ private:
+  /*! \brief Emit a relax.builtin.alloc_tensor binding for `sinfo` and return the bound var. */
+  Expr EmitAllocTensor(const TensorStructInfo& sinfo) {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    return builder_->Emit(Call(alloc_tensor_op,
+                               {Downcast<ShapeExpr>(sinfo->shape.value()),
+                                DataTypeImm(sinfo->dtype), PrimValue::Int64(0)},
+                               Attrs()),
+                          "alloc");
+  }
+};
+
+Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass CallTIRRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
+  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
new file mode 100644
index 0000000000..aec0911ecc
--- /dev/null
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/rewrite_dataflow_reshape.cc
+ * \brief Rewrite reshape-pattern call_tir calls in dataflow blocks to relax.reshape calls.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Mutator that turns call_tir invocations of reshape-pattern PrimFuncs inside dataflow
+ * blocks back into relax.reshape calls.
+ */
+class DataflowReshapeRewriter : public ExprMutator {
+ public:
+  explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using ExprMutator::VisitExpr_;
+
+  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
+    // We only rewrite the bindings inside dataflow blocks.
+    if (const auto* dataflow_block = block.as<DataflowBlockNode>()) {
+      return VisitBindingBlock_(dataflow_block);
+    } else {
+      return block;
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    // We only rewrite the bindings that are not dataflow output (which means they are not
+    // externally referenced)
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    } else {
+      ExprMutator::VisitBinding_(binding);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (!IsCallingTIRReshape(call)) {
+      return GetRef<Call>(call);
+    }
+
+    // We bring the calls of reshape PrimFunc back to calls of high-level
+    // relax.reshape op, which will be lowered to calls of the ExternFunc
+    // vm.builtin.reshape in the VMBuiltinLower pass.
+    Array<Expr> args = Downcast<Tuple>(call->args[1])->fields;
+    ICHECK_EQ(args.size(), 1);
+    TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
+    ICHECK(res_sinfo->shape.defined());
+    return reshape(args[0], res_sinfo->shape.value());
+  }
+
+  /*!
+   * \brief Whether `call` is a relax.call_tir whose callee is a module PrimFunc with a reshape
+   * pattern and whose inputs are given as a Tuple literal.
+   */
+  bool IsCallingTIRReshape(const CallNode* call) {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    if (call->op != call_tir_op) {
+      return false;
+    }
+    const auto* gv = call->args[0].as<GlobalVarNode>();
+    if (gv == nullptr) {
+      return false;
+    }
+    // The rewrite reads the input tensors out of a Tuple literal; leave other forms untouched.
+    if (!call->args[1]->IsInstance<TupleNode>()) {
+      return false;
+    }
+    const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
+    ICHECK_NOTNULL(func);
+    return HasReshapePattern(GetRef<tir::PrimFunc>(func));
+  }
+
+  /*! \brief The module whose PrimFuncs are inspected; held by value, not by reference. */
+  IRModule mod_;
+};
+
+Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) {
+  return DataflowReshapeRewriter(mod)(f);
+}
+
+namespace transform {
+
+Pass RewriteDataflowReshape() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(RewriteDataflowReshape(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape")
+    .set_body_typed(RewriteDataflowReshape);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc
new file mode 100644
index 0000000000..db2e9d7ee5
--- /dev/null
+++ b/src/relax/transform/to_non_dataflow.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_non_dataflow.cc
+ * \brief Transform all dataflow structure to non-dataflow version.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Mutator that turns every DataflowVar definition into an ordinary Var (same vid,
+ * struct info and span) and re-emits each DataflowBlock as a plain binding block.
+ */
+class ToNonDFMutator : public ExprMutator {
+ public:
+  Var VisitVarDef(const Var& var) final {
+    // Ordinary vars are kept as they are.
+    if (!var.as<DataflowVarNode>()) {
+      return var;
+    }
+    // Re-declare the dataflow var as a plain Var and record it in var_remap_ (keyed by vid).
+    Var plain_var(var->vid, GetStructInfo(var), var->span);
+    this->var_remap_[var->vid] = plain_var;
+    return plain_var;
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    // Emit the visited bindings into a regular (non-dataflow) block.
+    builder_->BeginBindingBlock();
+    for (const Binding& binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    return builder_->EndBlock();
+  }
+};
+
+Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass ToNonDataflow() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(ToNonDataflow(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
new file mode 100644
index 0000000000..5dd83f2da2
--- /dev/null
+++ b/tests/python/relax/test_analysis.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Set, Union
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm import relax as rx
+from tvm.relax.analysis import has_reshape_pattern
+from tvm.script import relax as R, tir as T
+
+
+def test_reshape_pattern_reshape():
+    """A (1, 2, 3, 4) -> (8, 3) reshape PrimFunc has the reshape pattern."""
+
+    @T.prim_func
+    def reshape(
+        rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"),
+        T_reshape: T.Buffer((8, 3), "float32"),
+    ):
+        for i0, i1 in T.grid(8, 3):
+            with T.block("T_reshape"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(
+                    rxplaceholder[
+                        (ax0 * 3 + ax1) // 24,
+                        (ax0 * 3 + ax1) % 24 // 12,
+                        (ax0 * 3 + ax1) % 12 // 4,
+                        (ax0 * 3 + ax1) % 4,
+                    ]
+                )
+                T.writes(T_reshape[ax0, ax1])
+                T_reshape[ax0, ax1] = rxplaceholder[
+                    (ax0 * 3 + ax1) // 24,
+                    (ax0 * 3 + ax1) % 24 // 12,
+                    (ax0 * 3 + ax1) % 12 // 4,
+                    (ax0 * 3 + ax1) % 4,
+                ]
+
+    assert has_reshape_pattern(reshape)
+
+
+def test_reshape_pattern_reshape_scheduled():
+    """The reshape pattern is still detected after loop fusion and thread binding."""
+
+    @T.prim_func
+    def reshape_scheduled(
+        rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"),
+        T_reshape: T.Buffer((8, 3), "float32"),
+    ):
+        for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+            for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"):
+                with T.block("T_reshape"):
+                    ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3)
+                    ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3)
+                    T.reads(
+                        rxplaceholder[
+                            (ax0 * 3 + ax1) // 24,
+                            (ax0 * 3 + ax1) % 24 // 12,
+                            (ax0 * 3 + ax1) % 12 // 4,
+                            (ax0 * 3 + ax1) % 4,
+                        ]
+                    )
+                    T.writes(T_reshape[ax0, ax1])
+                    T_reshape[ax0, ax1] = rxplaceholder[
+                        (ax0 * 3 + ax1) // 24,
+                        (ax0 * 3 + ax1) % 24 // 12,
+                        (ax0 * 3 + ax1) % 12 // 4,
+                        (ax0 * 3 + ax1) % 4,
+                    ]
+
+    assert has_reshape_pattern(reshape_scheduled)
+
+
+def test_reshape_pattern_expand_dims():
+    """expand_dims, which only inserts unit-length axes, has the reshape pattern."""
+
+    @T.prim_func
+    def expand_dims(
+        rxplaceholder: T.Buffer((2, 3, 4), "float32"),
+        expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"),
+    ):
+        T.func_attr({"tir.noalias": True})
+        for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1):
+            with T.block("expand_dims"):
+                i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap(
+                    "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]
+                )
+                T.reads(rxplaceholder[i0_1, i4_1, i6_1])
+                T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1])
+                expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[
+                    i0_1, i4_1, i6_1
+                ]
+
+    assert has_reshape_pattern(expand_dims)
+
+
+def test_reshape_pattern_with_raggedness():
+    """A reshape over ragged rows delimited by `src_indptr` has the reshape pattern."""
+
+    @T.prim_func
+    def reshape_raggedness(
+        A: T.Buffer((100, 768), "float32"),
+        src_indptr: T.Buffer((9,), "int32"),
+        B: T.Buffer((100, 12, 64), "float32"),
+    ):
+        for b in T.serial(8):
+            with T.block("block0"):
+                vb = T.axis.spatial(8, b)
+                for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]):
+                    for h in T.serial(12):
+                        for f in T.serial(64):
+                            with T.block("block1"):
+                                vi, vh, vf = T.axis.remap("SSS", [i, h, f])
+                                B[src_indptr[vb] + vi, vh, vf] = A[
+                                    src_indptr[vb] + vi, vh * 64 + vf
+                                ]
+
+    assert has_reshape_pattern(reshape_raggedness)
+
+
+def test_reshape_pattern_reject_seqstmt():
+    """PrimFuncs whose body is a sequence of two loop nests are rejected."""
+
+    @T.prim_func
+    def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
+        C = T.alloc_buffer((4, 4), "float32")
+        for i0, i1 in T.grid(4, 4):
+            with T.block("identity"):
+                vi0, vi1 = T.axis.remap("SS", [i0, i1])
+                C[vi0, vi1] = A[vi0, vi1]
+        for i0, i1 in T.grid(4, 4):
+            with T.block("identity"):
+                vi0, vi1 = T.axis.remap("SS", [i0, i1])
+                B[vi0, vi1] = C[vi0, vi1] + T.float32(1)
+
+    @T.prim_func
+    def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
+        C = T.alloc_buffer((4, 4), "float32")
+        for i0, i1 in T.grid(4, 4):
+            with T.block("identity"):
+                vi0, vi1 = T.axis.remap("SS", [i0, i1])
+                C[vi0, vi1] = A[vi0, vi1]
+        for i0, i1 in T.grid(4, 4):
+            with T.block("identity"):
+                vi0, vi1 = T.axis.remap("SS", [i0, i1])
+                B[vi0, vi1] = C[vi0, vi1]
+
+    assert not has_reshape_pattern(identity_bias)
+    assert not has_reshape_pattern(identity_identity)
+
+
+def test_reshape_pattern_reject_reduction():
+    """A block with a reduction axis and an init statement is rejected."""
+
+    @T.prim_func
+    def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
+        for i0, i1 in T.grid(4, 4):
+            with T.block("reduction"):
+                vi0, vi1 = T.axis.remap("SR", [i0, i1])
+                with T.init():
+                    B[vi0] = T.float32(0)
+                B[vi0] = B[vi0] + A[vi0, vi1]
+
+    assert not has_reshape_pattern(reduction)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
new file mode 100644
index 0000000000..624b7877cd
--- /dev/null
+++ b/tests/python/relax/test_transform.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import relax
+from tvm.ir import structural_equal
+from tvm.ir.base import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def test_to_non_dataflow():
+    @tvm.script.ir_module
+    class TestToNonDataflow:
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.var("int64"), T.var("int64")
+            with R.dataflow():
+                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
+                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32"))
+                R.output(gv0)
+            return gv0
+
+    mod = TestToNonDataflow
+
+    old_vars = []
+
+    def fvisit(e):
+        if isinstance(e, relax.Var):
+            nonlocal old_vars
+            old_vars.append(e)
+
+    relax.analysis.post_order_visit(mod["foo"], fvisit)
+    x, lv0, gv0 = old_vars
+
+    new_mod = relax.transform.ToNonDataflow()(mod)
+
+    new_vars = []
+
+    def fvisit(e):
+        if isinstance(e, relax.Var):
+            nonlocal new_vars
+            new_vars.append(e)
+
+    relax.analysis.post_order_visit(new_mod["foo"], fvisit)
+
+    assert x == new_vars[0]
+    assert lv0 != new_vars[1]
+    assert isinstance(lv0, relax.DataflowVar)
+    assert not isinstance(new_vars[1], relax.DataflowVar)
+
+    assert isinstance(gv0, relax.Var)
+    assert isinstance(new_vars[2], relax.Var)
+    assert gv0 == new_vars[2]
+
+
+def test_call_tir_rewrite():
+    @tvm.script.ir_module
+    class TestCallTIRRewrite:
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.var("int64"), T.var("int64")
+            gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
+            return gv0
+
+    mod = TestCallTIRRewrite
+
+    # before rewrite
+    v0 = mod["foo"].body.blocks[0].bindings[0].var
+    s0 = mod["foo"].body.blocks[0].bindings[0].value
+    assert isinstance(s0, relax.Call)
+    assert s0.op.name == "relax.call_tir"
+
+    # after rewrite
+    new_mod = relax.transform.CallTIRRewrite()(mod)
+    func = new_mod["foo"]
+
+    block = func.body.blocks[0]
+    assert not isinstance(block, relax.DataflowBlock)
+
+    s1 = block.bindings[0].value
+    assert isinstance(s1, relax.Call)
+    assert s1.op.name == "relax.builtin.alloc_tensor"
+    assert isinstance(s1.args[0], relax.ShapeExpr)
+    assert structural_equal(s1.args[0], s0.sinfo_args[0].shape)
+    s2 = block.bindings[1].value
+    assert s2.op.global_symbol == "test.op.identity"
+
+
+def test_vm_builtin_lower():
+    @tvm.script.ir_module
+    class TestVMBuiltinLower:
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+            m, n = T.var("int64"), T.var("int64")
+            alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, 
dtype="float32")
+            _ = R.call_packed(
+                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, 
dtype="float32"))
+            )
+            gv0 = alloc
+            return gv0
+
+    mod = TestVMBuiltinLower
+
+    # after vm builtin lowering
+    new_mod = relax.transform.VMBuiltinLower()(mod)
+    func = new_mod["foo"]
+
+    assert isinstance(new_mod, tvm.IRModule)
+    assert isinstance(func, tvm.relax.expr.Function)
+
+    block = func.body.blocks[0]
+    s1 = block.bindings[0].value
+    assert isinstance(s1, relax.Call)
+    assert s1.op.name == "relax.vm.alloc_storage"
+    s2 = block.bindings[1].value
+    assert isinstance(s2, relax.Call)
+    s3 = block.bindings[2].value
+    assert isinstance(s3, relax.Call)
+    assert isinstance(s3.op, relax.ExternFunc)
+    assert s3.op.global_symbol == "test.op.identity"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py 
b/tests/python/relax/test_transform_attach_global_symbol.py
new file mode 100644
index 0000000000..edfc646e21
--- /dev/null
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import tir, relax
+from tvm.ir import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+# Input module for AttachGlobalSymbol: a TIR matmul PrimFunc and a Relax function that
+# calls it through call_tir. Neither function carries a "global_symbol" attribute yet.
+@tvm.script.ir_module
+class Before:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        m = T.var("int64")
+        n = T.var("int64")
+        k = T.var("int64")
+        A = T.match_buffer(x, (m, n))
+        B = T.match_buffer(y, (n, k))
+        C = T.match_buffer(z, (m, k))
+
+        # NOTE(review): the loop variable `k` shadows the symbolic extent `k` declared
+        # above; the T.grid(...) extents are evaluated before the loop variables are bound.
+        for i, j, k in T.grid(m, k, n):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+    @R.function
+    def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")) -> R.Tensor:
+        m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+        gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
+        return gv0
+
+
+def test_basic():
+    """AttachGlobalSymbol sets each function's "global_symbol" attribute to its name."""
+    # Same as `Before`, except both functions now carry a "global_symbol" attribute.
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            m = T.var("int64")
+            n = T.var("int64")
+            k = T.var("int64")
+            A = T.match_buffer(x, (m, n))
+            B = T.match_buffer(y, (n, k))
+            C = T.match_buffer(z, (m, k))
+
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def main(
+            x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
+        ) -> R.Tensor:
+            R.func_attr({"global_symbol": "main"})
+            m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+            gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
+            return gv0
+
+    before = Before
+    expected = Expected
+    after = relax.transform.AttachGlobalSymbol()(before)
+    assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py 
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
new file mode 100644
index 0000000000..2c53d85c56
--- /dev/null
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R, tir as T
+
+
+def test_reshape_expand_dims():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
+            T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
+        ):
+            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(
+                        rxplaceholder[
+                            (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+                            (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+                        (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+                        (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+                    ]
+
+        @T.prim_func
+        def expand_dims(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
+            expand_dims: T.Buffer(
+                (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)),
+                "float32",
+            ),
+        ):
+            for i0, i1, i2, i3, i4 in T.grid(
+                T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
+            ):
+                with T.block("expand_dims"):
+                    i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, 
i1, i2, i3, i4])
+                    T.reads(rxplaceholder[i0_1, i2_1, i4_1])
+                    T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
+                    expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = 
rxplaceholder[i0_1, i2_1, i4_1]
+
+        @R.function
+        def main(
+            x: R.Tensor((8, 3), dtype="float32")
+        ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
+            with R.dataflow():
+                y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), 
dtype="float32"))
+                z = R.call_tir(expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 
1, 3), "float32"))
+                R.output(z)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
+            T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
+        ):
+            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(
+                        rxplaceholder[
+                            (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) 
// T.int64(3),
+                            (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) 
% T.int64(3),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+                        (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // 
T.int64(3),
+                        (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % 
T.int64(3),
+                    ]
+
+        @T.prim_func
+        def expand_dims(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
+            expand_dims: T.Buffer(
+                (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), 
"float32"
+            ),
+        ):
+            for i0, i1, i2, i3, i4 in T.grid(
+                T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
+            ):
+                with T.block("expand_dims"):
+                    i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, 
i1, i2, i3, i4])
+                    T.reads(rxplaceholder[i0_1, i2_1, i4_1])
+                    T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
+                    expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = 
rxplaceholder[i0_1, i2_1, i4_1]
+
+        @R.function
+        def main(
+            x: R.Tensor((8, 3), dtype="float32")
+        ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
+            with R.dataflow():
+                y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3))
+                # Note: `z` is the output var of the dataflow block, and is 
thus
+                # not expected to be rewritten.
+                z = R.call_tir(
+                    expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), 
dtype="float32")
+                )
+                R.output(z)
+            return z
+
+    assert relax.analysis.has_reshape_pattern(Module["expand_dims"])
+    mod = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_reshape_non_dataflow():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
+            T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
+        ):
+            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(
+                        rxplaceholder[
+                            (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+                            (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+                        (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+                        (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+                    ]
+
+        @R.function
+        def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), 
dtype="float32"):
+            y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), 
dtype="float32"))
+            return y
+
+    assert relax.analysis.has_reshape_pattern(Module["reshape"])
+    # The binding var of the call_tir is not a DataflowVar. So the pass does 
no change.
+    mod = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+# Allow running this test file directly as a script.
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_vm_build.py 
b/tests/python/relax/test_vm_build.py
new file mode 100644
index 0000000000..534d2308da
--- /dev/null
+++ b/tests/python/relax/test_vm_build.py
@@ -0,0 +1,908 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from typing import Tuple, Callable
+
+import sys
+import tempfile
+import numpy as np
+import pytest
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax, rpc, te, tir, topi
+from tvm.contrib import utils
+from tvm.relax.testing import nn
+from tvm.script import relax as R, tir as T
+from tvm.relax.testing.vm import check_saved_func
+
+# Executable modes the build tests below are parametrized over; only "bytecode" is listed.
+EXEC_MODE = ["bytecode"]
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_simple(exec_mode):
+    """Build and run a function whose only work is a packed call on its two inputs."""
+    @tvm.script.ir_module
+    class TestVMCompileStage0:
+        @R.function
+        def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), 
"float32")):
+            # The packed call's result `z` is unused; foo returns `y` itself.
+            z = R.call_packed(
+                "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, 
dtype="float32"))
+            )
+            return y
+
+    mod = TestVMCompileStage0
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    vm["foo"](inp1, inp2)
+    # inp2 is expected to now hold inp1's values, i.e. "test.vm.identity" presumably
+    # copies its first argument into its second (registered elsewhere; confirm).
+    tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, 
atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_match_check(exec_mode):
+    """The VM checks foo's returned object against the annotated return type.
+
+    With x of shape (n, m) = (1, 2), the return value must be a tensor of shape (m, n) = (2, 1).
+    """
+    @tvm.script.ir_module
+    class TestMatchCheck:
+        @R.function
+        def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> 
R.Tensor(["m", "n"], dtype=None):
+            return y
+
+    mod = TestMatchCheck
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32"))
+    # y0 has the expected (m, n) = (2, 1) shape; y1 has the right rank but swapped
+    # extents; y2 has the wrong rank.
+    y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32"))
+    y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32"))
+    y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32"))
+
+    vm["foo"](x0, y0)
+
+    # Wrong extents are reported as a RuntimeError mentioning the return value.
+    with pytest.raises(RuntimeError, match=".*return.*"):
+        vm["foo"](x0, y1)
+
+    # Wrong rank is reported as a ValueError mentioning the return value.
+    with pytest.raises(ValueError, match=".*return.*"):
+        vm["foo"](x0, y2)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_stage2(exec_mode):
+    """match_cast binds symbolic n, m from the input; foo returns the shape (n * 2, m * 3).
+
+    Also checks the errors raised when the input does not match the cast.
+    """
+    @tvm.script.ir_module
+    class TestVMCompileStage2:
+        @R.function
+        def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
+            n, m = T.var("int64"), T.var("int64")
+            _ = R.match_cast(x, R.Tensor((n, m), "float32"))
+            return (n * 2, m * 3)
+
+    mod = TestVMCompileStage2
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    shape = (32, 16)
+    arr = tvm.nd.array(np.random.rand(*shape).astype("float32"))
+    res = vm["foo"](arr)
+    assert res[0] == shape[0] * 2
+    assert res[1] == shape[1] * 3
+
+    # dtype mismatch: an int32 tensor where float32 is required
+    with pytest.raises(ValueError, match=".*dtype.*"):
+        vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32")))
+
+    # ndim mismatch: a 1-D tensor where match_cast expects 2-D
+    with pytest.raises(ValueError, match=".*match_cast.*ndim.*"):
+        vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32")))
+
+    # type mismatch: a Python list is not a tensor
+    with pytest.raises(TypeError):
+        vm["foo"]([])
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_stage3(exec_mode):
+    """A call_tir to a packed function inside a dataflow block returns its input unchanged."""
+    @tvm.script.ir_module
+    class TestVMCompileStage3:
+        @R.function
+        def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
+            with R.dataflow():
+                # NOTE(review): `(x)` is just `x`, not a one-element tuple; this relies on
+                # call_tir accepting a bare argument (other tests spell it `(x,)`).
+                y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), 
dtype="float32"))
+                R.output(y)
+            return y
+
+    mod = TestVMCompileStage3
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    shape = (32, 16)
+    inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+    res = vm["foo"](inp)
+    tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_e2e(exec_mode):
+    """match_cast binds (n, m) from x; "test.vm.tile" produces an (n, m * 2) output.
+
+    The result is compared with np.tile(x, (1, 2)).
+    """
+    @tvm.script.ir_module
+    class TestVMCompileE2E:
+        @R.function
+        def foo(x: R.Tensor(dtype="float32")) -> R.Tensor:
+            with R.dataflow():
+                n, m = T.var("int64"), T.var("int64")
+                _ = R.match_cast(x, R.Tensor((n, m), "float32"))
+                y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), 
dtype="float32"))
+                R.output(y)
+            return y
+
+    mod = TestVMCompileE2E
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    shape = (32, 16)
+    inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+    res = check_saved_func(vm, "foo", inp)
+    tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), 
rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_e2e_func_param_with_shape(exec_mode):
+    """A Relax function with symbolic parameter shapes calls a TIR matmul via call_tir.
+
+    The parameters have shapes (m, n) and (n, k); the result is compared with np.dot.
+    """
+    @tvm.script.ir_module
+    class TestVMCompileE2E2:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            # NOTE(review): these extents are int32 while the Relax function below uses
+            # int64 symbolic vars — confirm the mix is intended.
+            m = T.var("int32")
+            n = T.var("int32")
+            k = T.var("int32")
+            A = T.match_buffer(x, (m, n))
+            B = T.match_buffer(y, (n, k))
+            C = T.match_buffer(z, (m, k))
+
+            # NOTE(review): the loop variable `k` shadows the extent `k` declared above.
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def func(
+            x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
+        ) -> R.Tensor:
+            m, k = T.var("int64"), T.var("int64")
+            gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
+            return gv0
+
+    mod = TestVMCompileE2E2
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+    weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+    res = check_saved_func(vm, "func", data, weight)
+    expected = np.dot(data.numpy(), weight.numpy())
+    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_emit_te_extern(exec_mode):
+    if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+        print("skip because extern function is not available")
+        return
+    bb = relax.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    x = relax.Var("x", R.Tensor([n, m], "float32"))
+    y = relax.Var("y", R.Tensor([m, n], "float32"))
+
+    with bb.function("rx_cblas_matmul", [x, y]):
+        out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, 
transb=False)
+        bb.emit_func_output(out)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+    weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+    res = check_saved_func(vm, "rx_cblas_matmul", data, weight)
+    expected = np.dot(data.numpy(), weight.numpy())
+    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_emit_te_concat(exec_mode):
+    # concatenate of two vectors of size (n,) and (m,)
+    bb = relax.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    x = relax.Var("x", R.Tensor([n], "float32"))
+    y = relax.Var("y", R.Tensor([m], "float32"))
+
+    def te_func(A, B):
+        C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], 
B[i - n]))
+        return C
+
+    with bb.function("rx_func", [x, y]):
+        x1 = bb.emit_te(te_func, x, y)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    inp = tvm.nd.array(
+        np.random.rand(
+            1,
+        ).astype(np.float32)
+    )
+    inp2 = tvm.nd.array(
+        np.random.rand(
+            2,
+        ).astype(np.float32)
+    )
+    res = check_saved_func(vm, "rx_func", inp, inp2)
+    tvm.testing.assert_allclose(
+        res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7
+    )
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_emit_te_dtype_change(exec_mode):
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor([n], "float32"))
+
+    # convert a tensor with dtype of float32 to int16
+    def te_func(A):
+        B = te.compute((n,), lambda i: A[i].astype("int16"))
+        return B
+
+    with bb.function("rx_func", [x]):
+        y = bb.emit_te(te_func, x)
+        bb.emit_func_output(y)
+
+    mod = bb.get()
+
+    new_mod = relax.transform.CallTIRRewrite()(mod)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    inp = tvm.nd.array(
+        np.random.rand(
+            1,
+        ).astype(np.float32)
+    )
+    res = check_saved_func(vm, "rx_func", inp)
+    np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16"))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_emit_te_floor_symbolic_shape(exec_mode):
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor([n], "float32"))
+
+    def te_func(A):
+        C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1)
+        return C
+
+    with bb.function("rx_func", [x]):
+        x1 = bb.emit_te(te_func, x)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    shape = (9,)
+    inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+    res = check_saved_func(vm, "rx_func", inp)
+
+    def expected_output():
+        output_shape = (shape[0] // 2,)
+        return inp.numpy()[: output_shape[0]] + 1
+
+    tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, 
atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_emit_te_constant_param_cpu(exec_mode):
+    x_np = np.random.rand(2, 2).astype("float32")
+    c_np = np.random.rand(2, 2).astype("float32")
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 2), "float32"))
+    c = relax.const(c_np, "float32")
+    with bb.function("main", [x]):
+        with bb.dataflow():
+            lv0 = bb.emit_te(topi.add, x, c)
+            gv = bb.emit_output(lv0)
+        bb.emit_func_output(gv)
+
+    mod = bb.get()
+    exec = relax.vm.build(mod, "llvm", exec_mode=exec_mode)
+    dev = tvm.cpu()
+    vm = relax.VirtualMachine(exec, dev)
+
+    add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev))
+    tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, 
atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@tvm.testing.requires_gpu
+def test_vm_emit_te_constant_param_gpu(exec_mode):
+    x_np = np.random.rand(2, 2).astype("float32")
+    c_np = np.random.rand(2, 2).astype("float32")
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 2), "float32"))
+    c = relax.const(c_np, "float32")
+    with bb.function("main", [x]):
+        with bb.dataflow():
+            lv0 = bb.emit_te(topi.add, x, c)
+            gv = bb.emit_output(lv0)
+        bb.emit_func_output(gv)
+
+    mod = bb.get()
+    sch = tvm.tir.Schedule(mod, debug_mask="all")
+    loops = sch.get_loops(sch.get_block(name="T_add", func_name="add"))
+    sch.bind(loops[0], "threadIdx.x")
+
+    exec = relax.vm.build(sch.mod, "cuda", exec_mode=exec_mode)
+    dev = tvm.cuda()
+    vm = relax.VirtualMachine(exec, dev)
+
+    add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev))
+    tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, 
atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_shape(exec_mode):
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor([n], "float32"))
+    y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32"))
+
+    def te_func(A, B):
+        C = te.compute((n,), lambda i: A[i] + B[i // 2])
+        return C
+
+    with bb.function("rx_func", [x, y]):
+        x1 = bb.emit_te(te_func, x, y)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    shape1 = (5,)
+    shape2 = (3,)
+    inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32))
+    res = check_saved_func(vm, "rx_func", inp, inp2)
+
+    def expected_output():
+        return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5]
+
+    tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, 
atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_dyn_tir_shape(exec_mode):
+    """Build and run a function whose generated PrimFunc has an unbound TIR var.
+
+    ``te_func`` produces an ``(n + 1,)`` output from ``y`` alone, so the symbolic
+    ``n`` reaches the PrimFunc only through shapes. The built executable must also
+    survive an export/reload round trip with identical text.
+    """
+    # case where TIR variables are unbound in generated PrimFunc
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+
+    def te_func(A):
+        # Shape passed as an explicit 1-tuple; a bare ``(n + 1)`` is only a
+        # parenthesized expression.
+        C = te.compute((n + 1,), lambda i: A[i])
+        return C
+
+    with bb.function("rx_func"):
+        x = nn.Placeholder((n,), dtype="float32", name="x")
+        y = nn.Placeholder((n + 1,), dtype="float32", name="y")
+
+        x1 = bb.emit_te(te_func, y)
+        bb.emit_func_output(x1, params=[x, y])
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+    # Round-trip through a private temp dir instead of the CWD so that concurrently
+    # running tests (make_vm also exports an "exec.so") cannot clobber each other's
+    # library. The directory is removed when `temp` is collected, even if loading
+    # fails.
+    temp = utils.tempdir()
+    lib_path = temp.relpath("exec.so")
+    ex.mod.export_library(lib_path)
+    exec1 = relax.vm.Executable(tvm.runtime.load_module(lib_path))
+    assert ex.as_text() == exec1.as_text()
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    inp = tvm.nd.array(np.random.rand(2).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32))
+
+    res = check_saved_func(vm, "rx_func", inp, inp2)
+
+    tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_tuple(exec_mode):
+    """A Relax function can return a nested tuple built from its parameters."""
+    builder = relax.BlockBuilder()
+    length = tir.Var("n", "int64")
+
+    with builder.function("rx_func"):
+        lhs = nn.Placeholder((length,), dtype="float32", name="x")
+        rhs = nn.Placeholder((length,), dtype="float32", name="y")
+        pair = relax.Tuple([lhs, rhs])
+        first = pair[0]
+        # The function returns ((x, y), x).
+        builder.emit_func_output([pair, first], params=[lhs, rhs])
+
+    executable = relax.vm.build(
+        builder.get(), tvm.target.Target("llvm", host="llvm"), exec_mode=exec_mode
+    )
+    vm = relax.VirtualMachine(executable, tvm.cpu())
+
+    data_x, data_y = (tvm.nd.array(np.random.rand(5).astype(np.float32)) for _ in range(2))
+    (out_x, out_y), out_first = vm["rx_func"](data_x, data_y)
+
+    for actual, expected in ((out_x, data_x), (out_y, data_y), (out_first, data_x)):
+        tvm.testing.assert_allclose(actual.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_tuplegetitem(exec_mode):
+    """Tuple construction and indexing inside a Relax function feed a packed call."""
+    @tvm.script.ir_module
+    class TestVMTupleGetItem:
+        @R.function
+        def tuple_get_item(
+            x: R.Tensor(ndim=2, dtype="float32"),
+            y: R.Tensor(ndim=2, dtype="float32"),
+        ):
+            # Pack both inputs into a tuple, then pull them back out by index.
+            t = (x, y)
+            a = t[0]
+            b = t[1]
+            c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
+            return c
+
+    mod = TestVMTupleGetItem
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp)
+    # The check below expects "test.vm.add" to be an elementwise x + y.
+    tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_lower_memory_alloc_storage_tensor(exec_mode):
+    """Explicit storage/tensor allocation in Relax builds and runs on the VM.
+
+    ``main`` allocates a storage block and creates a (2, 3) float32 tensor in it at
+    offset 0. It copies ``x`` into that tensor with a TIR PrimFunc and returns it,
+    so the output must equal the input.
+    """
+    @tvm.script.ir_module
+    class TestMemoryAllocStorageTensor:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")):
+            # NOTE(review): 24 presumably means bytes (2 * 3 float32 * 4 bytes) — confirm.
+            storage = R.memory.alloc_storage(
+                (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
+            # `copy` fills `y` in place; its own return value is discarded.
+            _ = copy(x, y)
+            return y
+
+        @T.prim_func
+        def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            # Elementwise B = A.
+            for i0, i1 in T.grid(2, 3):
+                with T.block("block"):
+                    vi0, vi1 = T.axis.remap("SS", [i0, i1])
+                    B[vi0, vi1] = A[vi0, vi1]
+
+    mod = TestMemoryAllocStorageTensor
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    y = vm["main"](x)
+    tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_sub_func_call(exec_mode):
+    """A Relax function can call other Relax functions in the same module.
+
+    ``main`` runs a TIR matmul through ``relax_matmul_tir``. It then passes that
+    result as both arguments to ``relax_matmul_packed``. The expected value below
+    is the elementwise square of ``x @ w``.
+    """
+    @tvm.script.ir_module
+    class TestVMSubFunction:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            m = T.var("int32")
+            n = T.var("int32")
+            k = T.var("int32")
+            A = T.match_buffer(x, (m, n))
+            B = T.match_buffer(y, (n, k))
+            C = T.match_buffer(z, (m, k))
+
+            # C = A @ B. NOTE(review): the loop variable `k` shadows the symbolic
+            # extent `k` declared above. The T.grid(m, k, n) extents are read before
+            # the loop rebinds the name — confirm this against TVMScript scoping.
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def relax_matmul_tir(
+            x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+        ) -> R.Tensor((32, 32), dtype="float32"):
+            with R.dataflow():
+                gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
+                R.output(gv0)
+            return gv0
+
+        @R.function
+        def relax_matmul_packed(
+            x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+        ) -> R.Object:
+            # Despite this function's name, the expected value in the test treats
+            # "test.vm.mul" as an elementwise multiply.
+            gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
+            return gv0
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object:
+            gv0 = relax_matmul_tir(x, w)
+            gv1 = relax_matmul_packed(gv0, gv0)
+            return gv1
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestVMSubFunction, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+    y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+    res = check_saved_func(vm, "main", x_inp, y_inp)
+    product = np.dot(x_inp.numpy(), y_inp.numpy())
+    expected = product * product
+    # Looser tolerance than elsewhere: a 32-term float32 reduction is involved.
+    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_recursion(exec_mode):
+    """A Relax function can call itself recursively through ``if``/``else``.
+
+    The function returns 1.0 when ``n`` is zero; otherwise it returns
+    ``recursion(n - 1)`` added to itself. For a positive integer ``n`` the result
+    is therefore ``2 ** n``, assuming the ``test.vm.*`` packed helpers do what their
+    names say.
+    """
+    @tvm.script.ir_module
+    class TestVMRecursion:
+        @R.function
+        def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor:
+            cond = R.call_packed(
+                "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+            )
+            if cond:
+                # Base case.
+                res = R.const(1.0)
+            else:
+                # Recursive case: 2 * recursion(n - 1), computed as tmp + tmp.
+                gv0 = R.call_packed(
+                    "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+                )
+                tmp = recursion(gv0)
+                res = R.call_packed(
+                    "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+                )
+            return res
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestVMRecursion, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    # A 1-element float32 input holding a random integer in [1, 10).
+    inp = np.empty(1).astype("float32")
+    recursion_runs = np.random.randint(1, 10)
+    inp.fill(recursion_runs)
+    inp = tvm.nd.array(inp)
+    res = check_saved_func(vm, "recursion", inp)
+    tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_closure(exec_mode):
+    """A closure capturing ``x`` and invoked with ``y`` yields ``x + y``."""
+    @tvm.script.ir_module
+    class TestClosure:
+        @R.function
+        def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")):
+            return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor))
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), "float32"),
+            y: R.Tensor((2, 3), "float32"),
+        ):
+            # Capture `x` in a closure over lifted_func_1, then supply `y` at call time.
+            # The add is commutative, so argument order does not affect the result.
+            clo = R.make_closure(lifted_func_1, (x,))
+            res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor))
+            return res
+
+    mod = TestClosure
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32"))
+    res = check_saved_func(vm, "main", x_inp, y_inp)
+    tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy())
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_time_evaluator(exec_mode):
+    """``time_evaluator`` works on the stateful entry point and on a saved function."""
+    @tvm.script.ir_module
+    class TestTimeEvaluator:
+        @R.function
+        def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")):
+            return R.call_packed(
+                "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+            )
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x = tvm.nd.array(np.random.rand(1).astype("float32"))
+    y = tvm.nd.array(np.random.rand(1).astype("float32"))
+
+    # ensure we can use time_evaluator with the stateful API
+    vm.set_input("main", x, y)
+    # Time "invoke_stateful" with "main" as its argument, i.e. run main on the inputs set above.
+    timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main")
+    # just checking that it has some results at all
+    assert timing_res.results
+
+    # ensure we can use it with a closure
+    vm.save_function("main", "saved_main", x, y)
+    timing_res = vm.time_evaluator("saved_main", tvm.cpu())()
+    assert timing_res.results
+
+
+@tvm.script.ir_module
+class TestVMSetInput:
+    # Module shared by the set_input / save_function tests below.
+    # `main` multiplies two (32, 32) float32 tensors elementwise via `test_vm_mul`.
+    # The two tuple functions check flat and nested tuple return values.
+    @T.prim_func
+    def test_vm_mul(x: T.handle, y: T.handle, z: T.handle):
+        T.func_attr({"global_symbol": "test_vm_mul"})
+        m = T.var("int32")
+        n = T.var("int32")
+        A = T.match_buffer(x, (m, n))
+        B = T.match_buffer(y, (m, n))
+        C = T.match_buffer(z, (m, n))
+
+        # Elementwise C = A * B. The T.init() store is overwritten by the update.
+        for i, j in T.grid(m, n):
+            with T.block("mul"):
+                vi = T.axis.spatial(m, i)
+                vj = T.axis.spatial(n, j)
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = A[vi, vj] * B[vi, vj]
+
+    # test returning a tuple
+    @R.function
+    def test_vm_tuple(
+        x: R.Tensor((), "int32")
+    ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")):
+        return (x, x)
+
+    # nested tuple too
+    @R.function
+    def test_vm_nested_tuple(
+        x: R.Tensor((), "int32")
+    ) -> R.Tuple(
+        R.Tuple(
+            R.Tensor((), "int32"),
+            R.Tuple(
+                R.Tensor((), "int32"),
+            ),
+        ),
+        R.Tensor((), "int32"),
+    ):
+        return ((x, (x,)), x)
+
+    @R.function
+    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        # The PrimFunc is referenced by its global symbol name (a string).
+        gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32"))
+        return gv0
+
+
+def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    """Exercise the stateful set_input / invoke_stateful / get_outputs API.
+
+    Checks that:
+    - positional and keyword inputs to ``main`` give the same product;
+    - flat and nested tuple results keep their structure.
+    """
+
+    def run_stateful(func_name, *args, **kwargs):
+        # Set the inputs, run the function statefully, then fetch its outputs.
+        vm.set_input(func_name, *args, **kwargs)
+        vm.invoke_stateful(func_name)
+        return vm.get_outputs(func_name)
+
+    lhs = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    rhs = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    positional_out = run_stateful("main", lhs, rhs)
+    keyword_out = run_stateful("main", x=lhs, w=rhs)
+
+    expected = lhs.numpy() * rhs.numpy()
+    tvm.testing.assert_allclose(positional_out.numpy(), expected, rtol=1e-7, atol=1e-7)
+    tvm.testing.assert_allclose(
+        positional_out.numpy(), keyword_out.numpy(), rtol=1e-7, atol=1e-7
+    )
+
+    # bug! If you don't bind the NDArray to a var, the memory will get corrupted.
+    # Possibly due to object lifecycles and other FFI issues
+    two = tvm.nd.array(np.array(2).astype("int32"), device)
+    flat = run_stateful("test_vm_tuple", two)
+    # the results are NDArrays wrapped around scalars,
+    # so we have to get the scalar out of the NDArray
+    assert tuple(int(item.numpy()) for item in flat) == (2, 2)
+
+    one = tvm.nd.array(np.array(1).astype("int32"), device)
+    nested = run_stateful("test_vm_nested_tuple", one)
+    assert len(nested) == 2 and len(nested[0]) == 2 and len(nested[0][1]) == 1
+    as_ints = (
+        (int(nested[0][0].numpy()), (int(nested[0][1][0].numpy()),)),
+        int(nested[1].numpy()),
+    )
+    assert as_ints == ((1, (1,)), 1)
+
+
+def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    """Expected to fail: after set_input, ``main`` must be run via invoke_stateful."""
+    inputs = [tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) for _ in range(2)]
+    vm.set_input("main", *inputs)
+    # Deliberately call the function statelessly (no arguments) to trigger the error.
+    vm["main"]()
+
+
+def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    """Expected to fail: ``main`` takes inputs, but none were set before invoking it."""
+    func_name = "main"
+    vm.invoke_stateful(func_name)
+
+
+def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    """Expected to fail: outputs are requested before ``main`` was ever invoked."""
+    lhs = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    rhs = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.set_input("main", lhs, rhs)
+    # Intentionally no invoke_stateful("main") before fetching outputs.
+    vm.get_outputs("main")
+
+
+def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]:
+    """Build ``mod`` for LLVM and return a local VM over it together with its device.
+
+    The executable is exported to a shared library and reloaded before the VM is
+    created, so callers also exercise the serialization round trip.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The Relax module to build.
+    exec_mode : str
+        Forwarded unchanged to ``relax.vm.build``.
+
+    Returns
+    -------
+    (vm, device) : Tuple[relax.VirtualMachine, tvm.runtime.Device]
+    """
+    target = tvm.target.Target("llvm", host="llvm")
+    # Build the caller's module. Named `executable` to avoid shadowing builtin `exec`.
+    executable = relax.vm.build(mod, target, exec_mode=exec_mode)
+    # Export into a private temp dir rather than the CWD so parallel tests cannot
+    # race on the same "exec.so". The directory is removed when `temp` is collected.
+    temp = utils.tempdir()
+    lib_path = temp.relpath("exec.so")
+    executable.mod.export_library(lib_path)
+    exec_loaded = relax.vm.Executable(tvm.runtime.load_module(lib_path))
+    device = tvm.cpu()
+    return relax.VirtualMachine(exec_loaded, device), device
+
+
+def run_on_rpc(
+    mod: tvm.IRModule,
+    trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None],
+    exec_mode: str,
+):
+    """
+    Sets up a VM over localhost using the given mod and runs the given trial function.
+    The trial function should take a VM and a device.
+
+    The local RPC server started here is always terminated, even if the trial raises.
+    """
+    target = tvm.target.Target("llvm", host="llvm")
+    # `executable` rather than `exec`, which would shadow the Python builtin.
+    executable = relax.vm.build(mod, target, exec_mode=exec_mode)
+    temp = utils.tempdir()
+    path = temp.relpath("vm_library.so")
+    executable.mod.export_library(path)
+
+    # Use local rpc server for testing.
+    # Server must use popen so it doesn't inherit the current process state. It
+    # will crash otherwise.
+    # Adapted from relay/test_vm.py
+    def check_remote(server):
+        remote = rpc.connect(server.host, server.port, session_timeout=10)
+
+        # Upload the serialized Executable.
+        remote.upload(path)
+        # Get a handle to remote Executable.
+        rexec = remote.load_module("vm_library.so")
+
+        device = remote.cpu()
+        # Build a VM out of the executable and context.
+        vm = relax.vm.VirtualMachine(exec=rexec, device=device)
+        trial_func(vm, device)
+
+    server = rpc.Server("127.0.0.1")
+    try:
+        check_remote(server)
+    finally:
+        # Stop the server process explicitly instead of relying on garbage collection.
+        server.terminate()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_set_input(exec_mode):
+    """Stateful set_input / get_outputs round trips on a locally built VM."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    set_input_trial(vm, device)
+
+
+def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    """A function can be saved with its arguments bound by keyword."""
+    inputs = {
+        name: tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) for name in ("x", "w")
+    }
+    vm.save_function("main", "saved_main", **inputs)
+    saved_out = vm["saved_main"]()
+    tvm.testing.assert_allclose(
+        saved_out.numpy(), inputs["x"].numpy() * inputs["w"].numpy(), rtol=1e-7, atol=1e-7
+    )
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs(exec_mode):
+    """save_function with keyword arguments, on a locally built VM."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    save_function_kwargs_trial(vm, device)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs_rpc(exec_mode):
+    """save_function with keyword arguments, on a VM reached over local RPC."""
+    run_on_rpc(mod=TestVMSetInput, trial_func=save_function_kwargs_trial, exec_mode=exec_mode)
+
+
+def save_function_time_evaluator_trial(
+    vm: relax.VirtualMachine, device: tvm.runtime.Device
+) -> None:
+    """A saved function (with bound arguments) can be timed by time_evaluator."""
+    args = [tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) for _ in range(2)]
+    vm.save_function("main", "saved_main", *args)
+    timer = vm.time_evaluator("saved_main", device)
+    timer()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator(exec_mode):
+    """time_evaluator on a saved function, on a locally built VM."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    save_function_time_evaluator_trial(vm, device)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator_rpc(exec_mode):
+    """time_evaluator on a saved function, on a VM reached over local RPC.
+
+    Carries the ``_rpc`` suffix like the other RPC tests. Reusing the local test's
+    name would rebind it at module level, and the local variant would never run.
+    """
+    run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode)
+
+
+# if you set an input, you should not be able to call statelessly
+# NOTE(review): xfail() without raises= also passes on unrelated errors.
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure(exec_mode):
+    """A stateless call after set_input is expected to fail (local VM)."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    set_input_attempt_stateless(vm, device)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure_rpc(exec_mode):
+    """A stateless call after set_input is expected to fail (RPC VM)."""
+    run_on_rpc(mod=TestVMSetInput, trial_func=set_input_attempt_stateless, exec_mode=exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure(exec_mode):
+    """invoke_stateful without prior set_input is expected to fail (local VM)."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    set_input_attempt_invoke(vm, device)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure_rpc(exec_mode):
+    """invoke_stateful without prior set_input is expected to fail (RPC VM)."""
+    run_on_rpc(mod=TestVMSetInput, trial_func=set_input_attempt_invoke, exec_mode=exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure(exec_mode):
+    """get_outputs before any invocation is expected to fail (local VM)."""
+    vm, device = make_vm(TestVMSetInput, exec_mode)
+    set_input_attempt_get(vm, device)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure_rpc(exec_mode):
+    """get_outputs before any invocation is expected to fail (RPC VM)."""
+    run_on_rpc(mod=TestVMSetInput, trial_func=set_input_attempt_get, exec_mode=exec_mode)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()


Reply via email to