This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ff1a02c66aa69fd0e8707a671172f3942a7f61e5
Author: Yuchen Jin <yuch...@cs.washington.edu>
AuthorDate: Fri Feb 10 23:33:46 2023 -0800

    [Unity] Relax VM shape lowering pass (#13956)
    
    This PR introduces the Relax `FunctionPass` and `DataflowBlockPass` APIs, and
    the `VMShapeLower` pass to lower the shape expressions in Relax to TIR functions
    and VM shape heap builtin functions.
    
    Co-Authored-by: Ziheng Jiang <zih...@apache.org>
    Co-Authored-by: Lesheng Jin <34279105+lesheng...@users.noreply.github.com>
    Co-Authored-by: Altan Haan <alt...@cs.washington.edu>
    Co-Authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-Authored-by: Prakalp Srivastava <prak...@octoml.ai>
    Co-Authored-by: Ruihang Lai <ruiha...@cs.cmu.edu>
    Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn>
    Co-Authored-by: Steven S. Lyubomirsky <slyubomir...@octoml.ai>
    Co-Authored-by: Sunghyun Park <49998730+sun...@users.noreply.github.com>
    Co-Authored-by: Tianqi Chen <tianqi.tc...@gmail.com>
    Co-Authored-by: Yong Wu <yongc...@gmail.com>
---
 include/tvm/relax/backend.h                        |  44 ++
 include/tvm/relax/transform.h                      |  72 ++
 python/tvm/relax/__init__.py                       |   1 +
 python/tvm/relax/transform/__init__.py             |  20 +
 python/tvm/relax/transform/_ffi_api.py             |  19 +
 python/tvm/relax/transform/transform.py            | 345 ++++++++++
 src/relax/backend/vm/vm_shape_lower.cc             | 725 +++++++++++++++++++++
 src/relax/ir/transform.cc                          | 413 ++++++++++++
 .../relax/test_backend_transform_shape_lower.py    | 429 ++++++++++++
 9 files changed, 2068 insertions(+)
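
For readers unfamiliar with the new APIs, the following is an editorial sketch (not part of the
patch) of how the pieces introduced here fit together. It assumes `mod` is an existing Relax
IRModule built elsewhere; `identity_pass` is a hypothetical no-op pass used purely for illustration.

    # Editorial sketch, not part of the patch. Assumes `mod` is a Relax IRModule
    # that has already been constructed (e.g. via TVMScript or the BlockBuilder API).
    import tvm
    from tvm import relax

    # Decorating a plain function yields a relax FunctionPass (see transform.py below).
    @relax.transform.function_pass(opt_level=0, name="IdentityFunctionPass")
    def identity_pass(func, mod, ctx):
        return func  # no-op transformation

    # The new passes compose with the existing tvm.transform infrastructure.
    # VMShapeLower is meant to run late, after passes such as LambdaLift,
    # since it does not handle local (nested) functions.
    seq = tvm.transform.Sequential([identity_pass, relax.transform.VMShapeLower()])
    lowered_mod = seq(mod)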

diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
new file mode 100644
index 0000000000..4ebeacac0f
--- /dev/null
+++ b/include/tvm/relax/backend.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/backend.h
+ * \brief Relax backend specific transformation passes.
+ */
+#ifndef TVM_RELAX_BACKEND_H_
+#define TVM_RELAX_BACKEND_H_
+
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+/*!
+ * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMShapeLower();
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
new file mode 100644
index 0000000000..fa288a7f06
--- /dev/null
+++ b/include/tvm/relax/transform.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform.h
+ * \brief Relax specific transformation passes.
+ */
+#ifndef TVM_RELAX_TRANSFORM_H_
+#define TVM_RELAX_TRANSFORM_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+using Pass = tvm::transform::Pass;
+using PassInfo = tvm::transform::PassInfo;
+using PassContext = tvm::transform::PassContext;
+using Function = tvm::relax::Function;
+using DataflowBlock = tvm::relax::DataflowBlock;
+
+/*!
+ * \brief Create a function pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+TVM_DLL Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<String> required);
+
+/*!
+ * \brief Create a dataflowblock pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the dataflowblock pass.
+ * \param name The name of the dataflowblock pass.
+ * \param required The list of the passes that the dataflowblock pass is dependent on.
+ *
+ * \return The created dataflowblock pass.
+ */
+TVM_DLL Pass CreateDataflowBlockPass(
+    const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<String> required);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_H_
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index ce175354d0..a6306b788e 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -20,6 +20,7 @@ from . import exec_builder
 from . import expr
 from . import ty
 from . import analysis
+from . import transform
 from . import vm
 from . import block_builder
 from . import op
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
new file mode 100644
index 0000000000..eb4d5f710c
--- /dev/null
+++ b/python/tvm/relax/transform/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax transformations. """
+
+from .transform import *
diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py
new file mode 100644
index 0000000000..667aa62c2c
--- /dev/null
+++ b/python/tvm/relax/transform/_ffi_api.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+"""FFI APIs for tvm.transform"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.transform", __name__)
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
new file mode 100644
index 0000000000..f20f06c522
--- /dev/null
+++ b/python/tvm/relax/transform/transform.py
@@ -0,0 +1,345 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Relax transformation passes."""
+import functools
+import inspect
+import types
+from typing import Callable, Union
+
+import tvm.ir
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relax.FunctionPass")
+class FunctionPass(tvm.ir.transform.Pass):
+    """A pass that works on each tvm.relax.Function in a module. A function
+    pass class should be created through `function_pass`.
+    """
+
+
+@tvm._ffi.register_object("relax.DataflowBlockPass")
+class DataflowBlockPass(tvm.ir.transform.Pass):
+    """A pass that works on each tvm.relax.DataflowBlock in a module."""
+
+
+def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
+    """Lower the symbolic shape and argument and match-cast structinfo 
matching.
+
+    Parameters
+    ----------
+    emit_err_ctx: Optional[bool]
+        Whether to emit the error context string; can be turned off for testing purposes.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.VMShapeLower(emit_err_ctx)  # type: ignore
+
+
+def _wrap_class_function_pass(pass_cls, pass_info):
+    """Wrap a python class as function pass."""
+
+    class PyFunctionPass(FunctionPass):
+        """Internal wrapper class to create a class instance."""
+
+        def __init__(self, *args, **kwargs):
+            # initialize handle in case pass_cls creation failed.
+            self.handle = None
+            inst = pass_cls(*args, **kwargs)
+
+            # it is important not to capture self to
+            # avoid a cyclic dependency
+            def _pass_func(func, mod, ctx):
+                return inst.transform_function(func, mod, ctx)
+
+            self.__init_handle_by_constructor__(
+                _ffi_api.MakeFunctionPass, _pass_func, pass_info  # type: ignore
+            )
+            self._inst = inst
+
+        def __getattr__(self, name):
+            # fall back to instance attribute if there is not any
+            return self._inst.__getattribute__(name)
+
+    functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
+    PyFunctionPass.__name__ = pass_cls.__name__
+    PyFunctionPass.__doc__ = pass_cls.__doc__
+    PyFunctionPass.__module__ = pass_cls.__module__
+    return PyFunctionPass
+
+
+def function_pass(
+    pass_func=None,
+    opt_level=None,
+    name=None,
+    required=None,
+) -> Union[Callable, FunctionPass]:
+    """Decorate a function pass.
+
+    This function returns a callback when pass_func
+    is provided. Otherwise, it returns the created function pass using the
+    given optimization function.
+
+    Parameters
+    ----------
+    pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
+        The transformation function or class.
+
+    opt_level : int
+        The optimization level of this function pass.
+
+    name : Optional[str]
+        The name of the function pass. The name could be empty. In this case, the
+        name of the optimization function will be used as the pass name.
+
+    required : Optional[List[str]]
+        The list of passes that the function pass is dependent on.
+
+    Returns
+    -------
+    create_function_pass : Union[Callable, FunctionPass]
+
+        A decorator will be returned if pass_func is not provided,
+        otherwise return the decorated result.
+        The returned decorator has two behaviors depending on the input:
+        A new FunctionPass will be returned when we decorate a pass function.
+        A new FunctionPass class will be returned when we decorate a class type.
+
+    Examples
+    --------
+    The following code block decorates a function pass class.
+
+    .. code-block:: python
+
+        @relax.transform.function_pass(opt_level=1)
+        class TestReplaceFunc:
+            def __init__(self, new_func):
+                self.new_func = new_func
+
+            def transform_function(self, func, mod, ctx):
+                # just for demo purposes
+                # transform func to new_func
+                return self.new_func
+
+        @R.function
+        def f1(x: Tensor[(m, n), "float32"]):
+            return x
+
+        @tvm.script.ir_module
+        class InputMod:
+            @R.function
+            def f2(x: Tensor[(m, n), "float32"]):
+                gv0 = relax.add(x, x)
+                return gv0
+        # fpass is now a special pass that replaces every
+        # function to f1
+        fpass = TestReplaceFunc(f1)
+        # now every function in InputMod is replaced by f1
+        updated_mod = fpass(InputMod)
+
+
+    The following code creates a function pass by decorating
+    a user defined transform function.
+
+    .. code-block:: python
+
+        @relax.transform.function_pass(opt_level=2)
+        def transform(func, mod, ctx):
+            # my transformations here.
+            return func
+
+        function_pass = transform
+        assert isinstance(function_pass, relax.transform.FunctionPass)
+        assert function_pass.info.opt_level == 2
+
+        # Given a module m, the optimization could be invoked as the following:
+        updated_mod = function_pass(m)
+        # Now transform should have been applied to every function in
+        # the provided module m. And the updated module will be returned.
+    """
+
+    if opt_level is None:
+        raise ValueError("Please provide opt_level for the function pass.")
+
+    required = required if required else []
+    if not isinstance(required, (list, tuple)):
+        raise TypeError("Required is expected to be the type of " + 
"list/tuple.")
+
+    def create_function_pass(pass_arg):
+        """Internal function that creates a function pass"""
+        fname = name if name else pass_arg.__name__
+        info = tvm.transform.PassInfo(opt_level, fname, required)
+        if inspect.isclass(pass_arg):
+            return _wrap_class_function_pass(pass_arg, info)
+        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for Function pass")
+        return _ffi_api.MakeFunctionPass(pass_arg, info)  # type: ignore
+
+    if pass_func:
+        return create_function_pass(pass_func)
+    return create_function_pass
+
+
+def _wrap_class_dataflowblock_pass(pass_cls, pass_info):
+    """Wrap a python class as dataflowblock pass"""
+
+    class PyDataflowBlockPass(DataflowBlockPass):
+        """Internal wrapper class to create a class instance."""
+
+        def __init__(self, *args, **kwargs):
+            # initialize handle in case pass_cls creation failed.
+            self.handle = None
+            inst = pass_cls(*args, **kwargs)
+
+            # it is important not to capture self to
+            # avoid a cyclic dependency
+            def _pass_func(func, mod, ctx):
+                return inst.transform_dataflowblock(func, mod, ctx)
+
+            self.__init_handle_by_constructor__(
+                _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info  # type: ignore
+            )
+            self._inst = inst
+
+        def __getattr__(self, name):
+            # fall back to instance attribute if there is not any
+            return self._inst.__getattribute__(name)
+
+    functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__)
+    PyDataflowBlockPass.__name__ = pass_cls.__name__
+    PyDataflowBlockPass.__doc__ = pass_cls.__doc__
+    PyDataflowBlockPass.__module__ = pass_cls.__module__
+    return PyDataflowBlockPass
+
+
+def dataflowblock_pass(
+    pass_func=None, opt_level=None, name=None, required=None
+) -> Union[Callable, DataflowBlockPass]:
+    """Decorate a dataflowblock pass.
+
+    This function returns a callback when pass_func
+    is provided. Otherwise, it returns the created dataflowblock pass using the
+    given optimization function.
+
+    Parameters
+    ----------
+    pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]]
+        The transformation function or class.
+
+    opt_level : int
+        The optimization level of this dataflowblock pass.
+
+    name : Optional[str]
+        The name of the dataflowblock pass. The name could be empty. In this case, the
+        name of the optimization function will be used as the pass name.
+
+    required : Optional[List[str]]
+        The list of passes that the dataflowblock pass is dependent on.
+
+    Returns
+    -------
+    create_dataflowblock_pass : Union[Callable, DataflowBlockPass]
+
+        A decorator will be returned if pass_func is not provided,
+        otherwise return the decorated result.
+        The returned decorator has two behaviors depending on the input:
+        A new DataflowBlockPass will be returned when we decorate a pass function.
+        A new DataflowBlockPass class will be returned when we decorate a class type.
+
+    Examples
+    --------
+    The following code block decorates a dataflowblock pass class.
+
+    .. code-block:: python
+
+        @relax.transform.dataflowblock_pass(opt_level=1)
+        class TestReplaceBinding:
+            # Simple test function to replace the first VarBinding to another.
+
+            def __init__(self):
+                # create a new VarBinding
+                m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
+                lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32"))
+                val = relax.const(np.random.rand(24, 56))
+                self.new_binding = relax.VarBinding(lv0, val)
+
+            def transform_dataflowblock(self, block, mod, ctx):
+                # just for demo purposes
+                # Replace the first binding in the DataflowBlock
+                new_bindings = [self.new_binding, block.bindings[1]]
+                new_block = relax.expr.DataflowBlock(new_bindings, block.span)
+                return new_block
+
+        @tvm.script.ir_module
+        class InputMod:
+            @R.function
+            def f1(x: Tensor[(m, n), "float32"]):
+                with relax.dataflow():
+                    lv0 = relax.multiply(x, x)
+                    gv0 = relax.add(x, x)
+                    relax.output(gv0)
+                return gv0
+        # block_pass is now a special pass that replaces every
+        # first binding to the constant value binding
+        block_pass = TestReplaceBinding()
+        # now every first binding in DataflowBlock of InputMod
+        # is replaced by new_binding
+        updated_mod = block_pass(InputMod)
+
+
+    The following code creates a dataflowblock pass by decorating
+    a user defined transform function.
+
+    .. code-block:: python
+
+        @relax.transform.dataflowblock_pass(opt_level=2)
+        def transform(block, mod, ctx):
+            # my transformations here.
+            return block
+
+        block_pass = transform
+        assert isinstance(block_pass, relax.transform.DataflowBlockPass)
+        assert block_pass.info.opt_level == 2
+
+        # Given a module m, the optimization could be invoked as the following:
+        updated_mod = block_pass(m)
+        # Now transform should have been applied to every DataflowBlock in
+        # the provided module m. And the updated module will be returned.
+    """
+
+    if opt_level is None:
+        raise ValueError("Please provide opt_level for the dataflowblock 
pass.")
+
+    required = required if required else []
+    if not isinstance(required, (list, tuple)):
+        raise TypeError("Required is expected to be the type of " + 
"list/tuple.")
+
+    def create_dataflowblock_pass(pass_arg):
+        """Internal function that creates a dataflowblock pass"""
+        fname = name if name else pass_arg.__name__
+        info = tvm.transform.PassInfo(opt_level, fname, required)
+        if inspect.isclass(pass_arg):
+            return _wrap_class_dataflowblock_pass(pass_arg, info)
+        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for DataflowBlock 
pass")
+        return _ffi_api.MakeDataflowBlockPass(pass_arg, info)  # type: ignore
+
+    if pass_func:
+        return create_dataflowblock_pass(pass_func)
+    return create_dataflowblock_pass
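
The decorators above support both forms described in their docstrings: decorating a plain function
returns a pass object directly, while decorating a class returns a new class whose instances are
passes. A small editorial sketch (not part of the patch), using a hypothetical no-op block pass:

    # Editorial sketch, not part of the patch: the class form of the new decorators.
    from tvm import relax

    @relax.transform.dataflowblock_pass(opt_level=0, name="IdentityBlockPass")
    class IdentityBlockPass:
        """A no-op block pass that returns each DataflowBlock unchanged."""

        def transform_dataflowblock(self, block, mod, ctx):
            return block

    # Decorating the class returned a new class; instantiating it yields the pass object.
    block_pass = IdentityBlockPass()
    assert isinstance(block_pass, relax.transform.DataflowBlockPass)
    assert block_pass.info.opt_level == 0
    # Applying it to an (assumed) IRModule `mod` rewrites every DataflowBlock:
    # updated_mod = block_pass(mod)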
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
new file mode 100644
index 0000000000..090bcf01b5
--- /dev/null
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -0,0 +1,725 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/backend/vm/vm_shape_lower.cc
+ * \brief Lower the function boundary type checks and symbolic shape computations.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/runtime/relax_vm/builtin.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief A slot used in PrimExpr lowering. */
+struct PrimExprSlot {
+  /*! \brief The PrimExpr stored in this slot. */
+  PrimExpr expr;
+  /*! \brief The slot index */
+  int index;
+  // The following three members are auxiliary data
+  // to help shape rewriting.
+  /*!
+   * \brief List of slots whose PrimExpr uses this PrimExpr.
+   * \note user_slots is non-empty only if the PrimExpr is a Var, and it does not include the slot itself.
+   */
+  std::vector<PrimExprSlot*> user_slots;
+  /*!
+   * \brief Number of outstanding vars that are not defined in this PrimExpr.
+   * \note This is a helper counter used in analysis to perform computations.
+   */
+  int outstanding_defs = 0;
+  /*! \brief Whether we have computed the value. */
+  bool value_computed = false;
+};
+
+/*!
+ * \brief Helper data structure to collect pairs of match shapes
+ *        in a recursive matching process.
+ */
+struct MatchShapeTodoItem {
+  Expr input;
+  Array<PrimExpr> pattern;
+  String err_ctx;
+};
+
+/*! \brief Slot map used for shape lowering. */
+using PrimExprSlotMap =
+    std::unordered_map<PrimExpr, PrimExprSlot*, StructuralHash, tir::ExprDeepEqual>;
+
+// Collector to collect PrimExprSlotMap
+class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
+ public:
+  // collect the PrimExpr slot for a given function
+  static void Collect(Function func, std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec,
+                      PrimExprSlotMap* slot_map) {
+    PrimExprSlotCollector collector;
+    collector.slot_vec_ = slot_vec;
+    collector.slot_map_ = slot_map;
+    // collect shape declaration in func params
+    for (auto param : func->params) {
+      collector.VisitStructInfo(GetStructInfo(param));
+      collector.VisitExpr(param);
+    }
+    collector.VisitExpr(func->body);
+  }
+
+ private:
+  void VisitPrimExpr(const PrimExpr& expr) final {
+    if (expr->IsInstance<IntImmNode>()) return;
+    if (slot_map_->count(expr) == 0) {
+      auto slot = std::make_unique<PrimExprSlot>();
+      slot->expr = expr;
+      slot->index = static_cast<int>(slot_vec_->size());
+      slot_map_->emplace(expr, slot.get());
+      slot_vec_->emplace_back(std::move(slot));
+    }
+  }
+
+  void VisitBinding_(const MatchCastNode* op) final {
+    // Visit the match cast struct info so we can define
+    // the symbolic variables here.
+    this->VisitStructInfo(op->struct_info);
+  }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    // Do not recurse into function node as it is self-contained
+  }
+
+  void VisitStructInfo_(const FuncStructInfoNode* op) final {
+    // Do not recurse into function struct info as it is self-contained
+  }
+
+  void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); }
+
+  void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); }
+
+  std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec_;
+  PrimExprSlotMap* slot_map_;
+};
+
+/*!
+ * \brief Main logic to transform the shape lowered functions
+ *
+ * Consider the following input:
+ *
+ * \code
+ *
+ *  def f(x: R.Tuple(R.Tensor([m, n+1]), R.Tensor([n, 2])) -> R.Tensor:
+ *     return x
+ *
+ * \endcode
+ *
+ * Overall flow of the algorithm:
+ * - Preprocess: PrimExprSlot collection, we scan the function and allocate a PrimExprSlot
+ *   for each PrimExpr. In the above example, the resulting mapping from slot index
+ *   to expr would be {0: m, 1: n+1, 2: n}. Note that "n+1" also gets a slot.
+ *   PrimExprSlot also comes with auxiliary fields that track whether its value
+ *   can be readily computed.
+ *
+ * Steps at each matching point:
+ * - Step 0: We call CheckMatchCast,
+ *   which will recursively unpack the StructInfo, and generate static information checks.
+ *   Note that this step only generates functions for checking types and ndim info, but not
+ *   the symbolic shape variables. The symbolic shape-matching results will be returned as
+ *   vector<MatchShapeTodoItem>. This is because symbolic shape matching may not be completed
+ *   in a single round. Importantly, CheckMatchCast also deals with tuple unpacking.
+ *
+ * - Step 1: We then call RunMatch to generate the statements for matching symbolic shapes.
+ *   In the above example, the first round will store the values of m, n to their corresponding
+ *   slots. RunMatch may return outstanding items. In the above example x.shape[1] == n+1 cannot
+ *   be checked in the first round. RunMatch will populate new vars (in this case n, m); these vars
+ *   are added to a ready queue (ready_vars_).
+ *
+ * - Step 2: We call EmitOutstandingPrimExprCompute to check if ready_vars will trigger new values
+ *   to be computed. We eagerly compute all the outstanding values. The trigger is done through
+ *   a ref counter which decreases when each outstanding def is satisfied.
+ *   This step can also generate additional TIR functions to carry out shape computations.
+ *
+ * - Step 3: RunMatch again for the given outstanding match todos. This time all invariants
+ *   should be checked.
+ *
+ * The above steps populate each slot (which is backed by an element in shape_heap).
+ * Each time we find a symbolic shape tuple, we call MakeShape for the given slot indices
+ * in the shape_heap.
+ *
+ *
+ * Key functions in the flow:
+ * - PrimExprSlotCollector: preprocessing and collecting the slots
+ * - CheckMatchCast: recursive structinfo unpacking; generates checks and match items.
+ * - RunMatch: generate symbolic shape matches
+ * - EmitOutstandingPrimExprCompute: tracks the variables to be computed and emits shape computation
+ * - VisitExpr_(ShapeExprNode*): makes symbolic shape tuple.
+ *
+ * The checks and symbolic shape computations all map to runtime builtin functions. Please check out
+ * runtime/relax_vm/builtin.cc for their definitions.
+ *
+ * Shape computations are lowered to host-side TIR functions that load vars from their slots
+ * and store computed results back into slots. For a given slot map {0: m, 1: n+1, 2: n},
+ * it will create the shape_func below, which loads data from H[2] (n's slot), runs the compute,
+ * and stores the result back to H[1] (n+1's slot).
+ *
+ * \code
+ *
+ * @T.prim_func
+ * def shape_func(H: T.Buffer([3], "int64")):
+ *     H[1] = H[2] + 1
+ *
+ * \endcode
+ *
+ * The current implementation will batch all shape computations at each match point.
+ * For example, all the expressions that depend on n, m will be computed in a single
+ * shape_func at the function boundary. If there are follow-up match_cast points
+ * that define new variables, then we will generate new shape functions
+ * to compute expressions that depend on these variables.
+ */
+class VMShapeLowerMutator
+    : public ExprMutator,
+      public StructInfoFunctor<void(const StructInfo&, Expr, bool, const String&,
+                                    std::vector<MatchShapeTodoItem>*)> {
+ public:
+  static IRModule Lower(IRModule mod, bool emit_err_ctx) {
+    VMShapeLowerMutator mutator(mod, emit_err_ctx);
+
+    for (auto& kv : mod->functions) {
+      if (auto* func = kv.second.as<FunctionNode>()) {
+        Function updated_func = mutator.Rewrite(kv.first, GetRef<Function>(func));
+        mutator.builder_->UpdateFunction(kv.first, updated_func);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit VMShapeLowerMutator(IRModule mod, bool emit_err_ctx)
+      : ExprMutator(mod), emit_err_ctx_(emit_err_ctx) {}
+
+  using ExprMutator::VisitExpr_;
+
+  // Unit rewrite function per function.
+  Function Rewrite(GlobalVar gvar, Function func) {
+    // prepare mapping and heap var
+    PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_);
+    heap_size_ = IntImm(ShapeDType(), static_cast<int64_t>(slot_vec_.size()));
+    VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_);
+    shape_heap_ = shape_heap_binding->var;
+
+    // prepare slot information
+    this->PopulateSlotInfo();
+
+    Array<BindingBlock> blocks;
+
+    builder_->BeginScope(func->params);
+
+    {
+      // Check the parameter section.
+      builder_->BeginBindingBlock();
+      this->builder_->EmitNormalized(shape_heap_binding);
+      std::vector<MatchShapeTodoItem> match_todos;
+      for (size_t i = 0; i < func->params.size(); ++i) {
+        StructInfo sinfo = GetStructInfo(func->params[i]);
+        std::ostringstream err_ctx;
+        err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
+                << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") ";
+        this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), &match_todos);
+      }
+      // insert heap generation logic.
+      match_todos = this->RunMatch(match_todos, false);
+      this->EmitOutstandingPrimExprCompute();
+      this->RunMatch(match_todos, true);
+
+      BindingBlock pre_block = builder_->EndBlock();
+      blocks.push_back(pre_block);
+    }
+
+    // new body.
+    auto body_seq = Downcast<SeqExpr>(this->VisitWithNewScope(func->body, func->params));
+    blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end());
+
+    {
+      // Insert the return value check
+      builder_->BeginBindingBlock();
+      std::ostringstream err_ctx;
+      err_ctx << "ErrorContext(fn=" << gvar->name_hint
+              << ", loc=return, annotation=" << func->ret_struct_info << ") ";
+      std::vector<MatchShapeTodoItem> match_todos;
+      // NOTE: the return value's shape computation must already be defined.
+      this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, err_ctx.str(),
+                           &match_todos);
+      // NOTE: the return value's shape computation must already be defined.
+      this->RunMatch(match_todos, true);
+      BindingBlock post_block = builder_->EndBlock();
+      blocks.push_back(post_block);
+    }
+
+    auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body));
+    // create a new function
+    return Function(func->params, new_body, func->ret_struct_info, func->attrs);
+  }
+
+  //-------------------------------------------------------
+  // PrimExpr slot handling
+  //-------------------------------------------------------
+  static DataType ShapeDType() { return DataType::Int(64); }
+
+  /*! \brief populate additional information in the slot. */
+  void PopulateSlotInfo() {
+    for (auto& kv : slot_map_) {
+      auto* slot = kv.second;
+      if (!slot->expr.as<tir::VarNode>()) {
+        Array<tir::Var> dep_vars = tir::UndefinedVars(slot->expr);
+        for (auto var : dep_vars) {
+          auto it = slot_map_.find(var);
+          ICHECK(it != slot_map_.end())
+              << "Var " << var << " is not defined in the function but is referenced by "
+              << slot->expr;
+          auto* var_slot = it->second;
+          // populate the use slot.
+          var_slot->user_slots.push_back(slot);
+        }
+        // set outstanding defs.
+        slot->outstanding_defs += static_cast<int>(dep_vars.size());
+      }
+    }
+  }
+  //-------------------------------------------------------
+  // Helper functions
+  //-------------------------------------------------------
+  StringImm GetErrContext(String err_ctx) const {
+    return emit_err_ctx_ ? StringImm(err_ctx) : StringImm("");
+  }
+
+  VarBinding AllocShapeHeapBinding(IntImm heap_size) {
+    if (heap_size->value > 0) {
+      TensorStructInfo heap_sinfo(ShapeDType(), 1);
+      Var var("shape_heap", heap_sinfo);
+      // set up the builtin func.
+      Call call(call_builtin_with_ctx_op_,
+                {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_sinfo});
+      UpdateStructInfo(call, heap_sinfo);
+      return VarBinding(var, call);
+    } else {
+      Var var("shape_heap", ObjectStructInfo());
+      Call call(null_value_op_, {});
+      UpdateStructInfo(call, ObjectStructInfo());
+      return VarBinding(var, call);
+    }
+  }
+
+  //-------------------------------------------------------
+  // Expr mutation overloading.
+  //-------------------------------------------------------
+  Expr VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "VMShapeLower do not work for local functions, make sure "
+               << " to run it after LambdaLift";
+    return GetRef<Expr>(op);
+  }
+
+  Expr VisitExpr_(const ShapeExprNode* op) final {
+    using runtime::relax_vm::MakeShapeCode;
+    // Constant shape can be preserved.
+    bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) {
+      return e->IsInstance<IntImmNode>();
+    });
+    if (is_const_shape) {
+      return GetRef<Expr>(op);
+    }
+
+    Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
+    for (PrimExpr expr : op->values) {
+      if (auto* int_expr = expr.as<IntImmNode>()) {
+        args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)));
+        args.push_back(PrimValue::Int64(int_expr->value));
+      } else {
+        auto it = slot_map_.find(expr);
+        ICHECK(it != slot_map_.end());
+        auto* slot = it->second;
+        ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed";
+        args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)));
+        args.push_back(PrimValue::Int64(slot->index));
+      }
+    }
+
+    // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
+    Call call(builtin_make_shape_, args, Attrs(),
+              {ShapeStructInfo(static_cast<int>(op->values.size()))});
+    return call;
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) final {
+    Expr value = ExprMutator::VisitExpr(binding->value);
+    std::vector<MatchShapeTodoItem> match_todos;
+    std::ostringstream err_ctx;
+    err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info 
<< ") ";
+    // always_check=false
+    this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), &match_todos);
+
+    match_todos = this->RunMatch(match_todos, false);
+    this->EmitOutstandingPrimExprCompute();
+    this->RunMatch(match_todos, true);
+
+    // These checks are emitted as extras; in codegen,
+    // match-cast is simply ignored and treated as a normal binding.
+    builder_->EmitNormalized(GetRef<MatchCast>(binding));
+  }
+
+  // Do not override shape in struct info fields.
+  // We only override the shapes that are already part of the normal function values.
+  // If future passes lift those values out into the values,
+  // then codegen may not be able to handle symbolic values.
+  // Place this pass as the last pass before codegen.
+  StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; }
+
+  //-------------------------------------------------------
+  // Shape computations.
+  //-------------------------------------------------------
+  /*!
+   * \brief Execute the match todo items.
+   *
+   * This function can populate vars in the match items when seeing them for the first time.
+   * These new vars will be added to this->ready_vars_.
+   *
+   * If an item contains PrimExprs that are yet to be computed (but may be computable through
+   * vars defined in this round), it will be returned to the caller.
+   *
+   * The caller should call EmitOutstandingPrimExprCompute, then call RunMatch again.
+   *
+   * \param match_todos The list of match items to be executed.
+   * \param require_value_computed Whether we require all exprs to be computed.
+   * \return List of outstanding items that contain values that are yet to be computed.
+   */
+  std::vector<MatchShapeTodoItem> RunMatch(const std::vector<MatchShapeTodoItem>& match_todos,
+                                           bool require_value_computed) {
+    std::vector<MatchShapeTodoItem> outstanding_todos;
+
+    using runtime::relax_vm::MatchShapeCode;
+    for (const MatchShapeTodoItem& item : match_todos) {
+      int64_t shape_len = static_cast<int64_t>(item.pattern.size());
+      bool all_nop = true;
+      int num_outstanding_exprs = 0;
+
+      Array<Expr> args = {item.input, shape_heap_, PrimValue::Int64(shape_len)};
+
+      for (PrimExpr expr : item.pattern) {
+        MatchShapeCode code = MatchShapeCode::kNoOp;
+        int64_t rvalue = 0;
+        if (auto* int_expr = expr.as<IntImmNode>()) {
+          code = MatchShapeCode::kAssertEqualToImm;
+          rvalue = int_expr->value;
+        } else {
+          auto it = slot_map_.find(expr);
+          ICHECK(it != slot_map_.end());
+          auto* slot = it->second;
+          if (slot->value_computed) {
+            code = MatchShapeCode::kAssertEqualToLoad;
+            rvalue = slot->index;
+          } else {
+            // the value is not yet computed
+            ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed";
+            if (expr.as<tir::VarNode>()) {
+              // if it is a var, we will populate it in this round.
+              // otherwise, we skip and mark it as outstanding
+              code = MatchShapeCode::kStoreToHeap;
+              rvalue = slot->index;
+              slot->value_computed = true;
+              ready_vars_.push_back(slot);
+            } else {
+              code = MatchShapeCode::kNoOp;
+              rvalue = 0;
+              ++num_outstanding_exprs;
+            }
+          }
+        }
+        all_nop = all_nop && code == MatchShapeCode::kNoOp;
+        args.push_back(PrimValue::Int64(static_cast<int>(code)));
+        args.push_back(PrimValue::Int64(rvalue));
+      }
+      if (num_outstanding_exprs != 0) {
+        outstanding_todos.push_back(item);
+      }
+      args.push_back(GetErrContext(item.err_ctx));
+      if (!all_nop) {
+        Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_});
+        builder_->Emit(call, "_");
+      }
+    }
+    return std::move(outstanding_todos);
+  }
+
+  /*!
+   * \brief Compute the list of PrimExprs that can now be computed
+   *        for the given ready vars.
+   */
+  std::vector<PrimExprSlot*> GetReadyPrimExprSlots() {
+    std::vector<PrimExprSlot*> to_compute;
+    for (PrimExprSlot* slot : ready_vars_) {
+      for (PrimExprSlot* user : slot->user_slots) {
+        ICHECK_GT(user->outstanding_defs, 0);
+        user->outstanding_defs -= 1;
+        if (user->outstanding_defs == 0) {
+          to_compute.push_back(user);
+        }
+      }
+    }
+    ready_vars_.clear();
+    return to_compute;
+  }
+
+  /*!
+   * \brief Check the dependent expressions of ready_vars_.
+   *
+   * If there are outstanding PrimExprs that can now be computed,
+   * we generate a PrimFunc that computes the extra shape values.
+   *
+   * We will then clear the ready_vars.
+   *
+   * \return Number of PrimExprs computed.
+   */
+  size_t EmitOutstandingPrimExprCompute() {
+    std::vector<PrimExprSlot*> to_compute = GetReadyPrimExprSlots();
+    if (to_compute.size() == 0) return 0;
+    ICHECK_GT(heap_size_->value, 0);
+    // construct a PrimFunc that computes the shape.
+    tir::Var heap("heap", DataType::Handle());
+    Array<PrimExpr> buffer_shape{heap_size_};
+    tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global");
+    Map<tir::Var, tir::Buffer> buffer_map;
+    buffer_map.Set(heap, buffer);
+
+    auto var_map = [&](const tir::Var& var) -> Optional<PrimExpr> {
+      auto it = slot_map_.find(var);
+      ICHECK(it != slot_map_.end());
+      return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)});
+    };
+
+    Array<tir::Stmt> seq;
+    for (PrimExprSlot* slot : to_compute) {
+      ICHECK(!slot->value_computed);
+      slot->value_computed = true;
+      PrimExpr value = tir::Substitute(slot->expr, var_map);
+      seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)}));
+    }
+
+    tir::Stmt body = tir::SeqStmt::Flatten(seq);
+    Array<tir::Var> params{heap};
+    Type ret_type = VoidType();
+
+    // TODO(relax-team): Consider attaching the target attribute to
+    // the shape_func to indicate that this is a host function.
+    // This could require us to attach a target to the relax function here.
+    tir::PrimFunc shape_func(params, body, ret_type, buffer_map);
+    GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func");
+    builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
+    return to_compute.size();
+  }
+  //-------------------------------------------------------
+  // StructInfo value match logic
+  //
+  // CheckMatchCast is the only function needed by
+  // other code sections
+  //-------------------------------------------------------
+  /*!
+   * \brief Insert runtime checks of the match cast condition (value, struct_info).
+   *
+   * \param struct_info The struct info to be matched.
+   * \param value The input value.
+   * \param always_check Whether we insert runtime check even if we can prove
+   *        that value's struct info already satisfies the condition.
+   *        This option is necessary for argument checking per our calling convention.
+   *
+   * \param err_ctx Extra error context to bring more informative error reporting.
+   * \param match_todos List of match shape todo items collected when recursively
+   *                    visiting the match cast.
+   */
+  void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check,
+                      const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) {
+    return this->VisitStructInfo(struct_info, value, always_check, err_ctx, match_todos);
+  }
+
+  void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check,
+                       const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    // short-cut, if the struct info already satisfies the
+    // constraint during match cast, we can skip matching
+    if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return;
+    return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, err_ctx,
+                                              match_todos);
+  }
+
+  void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+  }
+
+  void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    // TODO(relax-team) add PrimValue checks later.
+    LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
+  }
+
+  void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    // emit runtime check of shape
+    if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) {
+      // check_shape_info(value, ndim, err_ctx)
+      Call call(builtin_check_shape_info_,
+                {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(),
+                {void_sinfo_});
+      builder_->Emit(call, "_");
+    }
+    if (op->values.defined()) {
+      MatchShapeTodoItem item;
+      item.input = value;
+      item.pattern = op->values.value();
+      item.err_ctx = err_ctx;
+      match_todos->push_back(item);
+    }
+  }
+
+  void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    // emit runtime check of shape
+    if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) {
+      // check_tensor_info(value, ndim, dtype, err_ctx)
+      Call call(builtin_check_tensor_info_,
+                {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)},
+                Attrs(), {void_sinfo_});
+      builder_->Emit(call, "_");
+    }
+
+    if (auto* shape_expr = op->shape.as<ShapeExprNode>()) {
+      MatchShapeTodoItem item;
+      item.input = value;
+      item.pattern = shape_expr->values;
+      item.err_ctx = err_ctx;
+      match_todos->push_back(item);
+    } else if (op->shape.as<VarNode>()) {
+      // NOTE: This part of the logic is left empty for future support as it is less common.
+      // Future implementors: we can emit a binding here and assert here.
+      LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears multiple times";
+    } else {
+      ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var";
+    }
+  }
+
+  // Internal helper function to make tuple get item.
+  // This function will try to simplify constant tuples;
+  // the return value **always** has struct info.
+  Expr MakeTupleGetItem(Expr value, int64_t index) {
+    if (auto* tuple_expr = value.as<TupleNode>()) {
+      return tuple_expr->fields[index];
+    } else if (auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(value)) {
+      // value is tuple type, it is OK to run tuple get item.
+      auto ret = TupleGetItem(value, index);
+      UpdateStructInfo(ret, tuple_sinfo->fields[index]);
+      return ret;
+    } else {
+      // call runtime tuple get item, and return an object.
+      Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_});
+      UpdateStructInfo(call, ObjectStructInfo());
+      return call;
+    }
+  }
+
+  void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value);
+    if (value_tinfo) {
+      CHECK_EQ(value_tinfo->fields.size(), op->fields.size())
+          << "TypeError: " << err_ctx << " during match-cast we find tuple size mismatch";
+    }
+    if (always_check || !value_tinfo) {
+      // check_tuple_info(value, tuple_size)
+      Call call(builtin_check_tuple_info_,
+                {value, PrimValue::Int64(static_cast<int64_t>(op->fields.size())),
+                 GetErrContext(err_ctx)},
+                Attrs(), {void_sinfo_});
+      builder_->Emit(call, "_");
+    }
+    // recursively visit each sub-field and run matching
+    for (size_t i = 0; i < op->fields.size(); ++i) {
+      this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, err_ctx,
+                            match_todos);
+    }
+  }
+
+  void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check,
+                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+    // we only check that the function is callable.
+    if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return;
+    // check_func_info(value, err_ctx)
+    Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_sinfo_});
+    builder_->Emit(call, "_");
+  }
+
+  //-------------------------------------------------------
+  // Private member fields.
+  //-------------------------------------------------------
+  /*! \brief whether to emit error context, can be turned off for testing purposes. */
+  bool emit_err_ctx_{true};
+  /*! \brief heap ptr to store the PrimExpr slots. */
+  Var shape_heap_;
+  /*! \brief heap size. */
+  IntImm heap_size_;
+  /*! \brief index => slot. */
+  std::vector<std::unique_ptr<PrimExprSlot>> slot_vec_;
+  /*! \brief Expr => slot. */
+  PrimExprSlotMap slot_map_;
+  /*!
+   * \brief List of vars that are being defined but
+   * have not yet gone through the outstanding shape compute check.
+   */
+  std::vector<PrimExprSlot*> ready_vars_;
+  // call builtin op
+  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+  const Op& null_value_op_ = Op::Get("relax.null_value");
+  // common struct info
+  const StructInfo object_sinfo_ = ObjectStructInfo();
+  const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
+  // check function
+  const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"};
+  const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"};
+  const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"};
+  const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"};
+  const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"};
+  const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"};
+  const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"};
+  const ExternFunc builtin_tuple_getitem_{"vm.builtin.tuple_getitem"};
+};
+
+namespace transform {
+
+Pass VMShapeLower(bool emit_err_ctx) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator::Lower(mod, emit_err_ctx); };
+  return CreateModulePass(pass_func, 0, "VMShapeLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool 
emit_err_ctx) {
+  return VMShapeLower(emit_err_ctx);
+});
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
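
To make the lowering comment in VMShapeLowerMutator concrete, here is an editorial sketch (plain
Python, not part of the patch) of the runtime behaviour for the slot map {0: m, 1: n+1, 2: n}
used in the example above; the concrete values of m and n are invented for illustration:

    # Editorial sketch, not part of the patch: a plain-Python analogue of the lowering.
    # The VM allocates a shape heap H with one int64 slot per PrimExpr; matching the
    # argument shapes stores m and n into their slots, and the generated host-side
    # shape_func then computes the derived expression n+1.
    H = [0, 0, 0]        # vm.builtin.alloc_shape_heap(3)
    m, n = 4, 7          # values observed from the argument shapes at runtime
    H[0], H[2] = m, n    # RunMatch: MatchShapeCode.kStoreToHeap for the vars m and n
    H[1] = H[2] + 1      # shape_func from EmitOutstandingPrimExprCompute: H[1] = H[2] + 1
    assert H == [4, 8, 7]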
diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc
new file mode 100644
index 0000000000..1b077d8b88
--- /dev/null
+++ b/src/relax/ir/transform.cc
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/ir/transform.cc
+ * \brief Relax specific transformation passes.
+ */
+#include <dmlc/thread_local.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relay/function.h>
+#include <tvm/runtime/registry.h>
+namespace tvm {
+namespace relax {
+namespace transform {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm);
+
+// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream
+class FunctionPass;
+
+/*!
+ * \brief Function-level passes are used to implement various global
+ * optimizations for a given Relax IRModule. It fetches one function at a time
+ * from the function list in the IRModule for optimization.
+ *
+ * Note that the scope of passes at this level is a Relax function. Therefore,
+ * we cannot add or delete a function through these passes as they are not aware
+ * of the global information.
+ */
+class FunctionPassNode : public tvm::transform::PassNode {
+ public:
+  /* \brief The pass meta data.*/
+  PassInfo pass_info;
+
+  /*! \brief The packed pass function sketches the real optimization. For
+   * instance, we can implement a pass that works on a Relax function as a
+   * `pass_func` and let it run on a given IRModule. The same `pass_func` will
+   * then be applied on each function in the IRModule.
+   */
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
+
+  FunctionPassNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
+
+  /*!
+   * \brief Run a function pass on given pass context.
+   *
+   * \param mod The IRModule that an optimization pass is applied on.
+   * \param pass_ctx The context that an optimization pass executes on.
+   *
+   * \return Return the updated IRModule.
+   */
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
+
+  /*!
+   * \brief Get the pass information/meta data.
+   */
+  PassInfo Info() const override { return pass_info; }
+
+  static constexpr const char* _type_key = "relax.FunctionPass";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
+
+ private:
+  /*
+   * \brief Check if a function should be skipped for optimization.
+   *
+   * \param func The target function to be checked.
+   *
+   * \return Return true if the function will be skipped, otherwise false.
+   */
+  bool SkipFunction(const Function& func) const;
+};
+
+class FunctionPass : public Pass {
+ public:
+  /*!
+   * \brief The constructor
+   * \param pass_func The packed function which implements a pass.
+   * \param pass_info The pass info.
+   */
+  TVM_DLL FunctionPass(
+      runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+      PassInfo pass_info);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
+};
+
+FunctionPass::FunctionPass(
+    runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  auto n = make_object<FunctionPassNode>();
+  n->pass_func = std::move(pass_func);
+  n->pass_info = std::move(pass_info);
+  data_ = std::move(n);
+}
+
+// Perform IRModule -> IRModule optimizations at the Function level.
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+  if (pass_ctx->diag_ctx) {
+    DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+    pass_ctx->diag_ctx = previous;
+    previous = tmp;
+  } else {
+    pass_ctx->diag_ctx = previous;
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block this is a 
bug.";
+
+  const PassInfo& pass_info = Info();
+
+  ICHECK(mod.defined());
+
+  VLOG_CONTEXT << pass_info->name;
+  VLOG(0) << "Executing function pass with opt level: " << 
pass_info->opt_level;
+  VLOG(1) << "Input module:" << std::endl << mod;
+
+  IRModule updated_mod = mod->ShallowCopy();
+
+  std::vector<std::pair<GlobalVar, Function> > updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up relax::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);
+      updates.push_back({it.first, updated_func});
+    }
+  }
+
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a 
bug.";
+
+  pass_ctx->diag_ctx.value().Render();
+  pass_ctx->diag_ctx = previous;
+
+  VLOG(1) << "Output module:" << std::endl << updated_mod;
+
+  return updated_mod;
+}
+
+bool FunctionPassNode::SkipFunction(const Function& func) const {
+  // TODO(@yuchen): will need to revisit in the future
+  return (func->GetAttr<String>(relay::attr::kCompiler).defined()) ||
+         func->GetAttr<Integer>(relay::attr::kSkipOptimization, 0) != 0;
+}
+
+Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<String> required) {
+  PassInfo pass_info = PassInfo(opt_level, name, required);
+  return FunctionPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(FunctionPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass")
+    .set_body_typed(
+        [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+           PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const FunctionPassNode*>(ref.get());
+      const PassInfo info = node->Info();
+      p->stream << "Run Function pass: " << info->name << " at the 
optimization level "
+                << info->opt_level;
+    });
+
+class DataflowBlockPass;
+
+/*!
+ * \brief DataflowBlock-level passes are used to implement various dataflow block
+ * optimizations for a given Relax IRModule. It fetches one dataflow block at a time
+ * from the functions in an IRModule, and yields a rewritten DataflowBlock.
+ *
+ * Note that the scope of passes at this level is a Relax DataflowBlock. Therefore,
+ * we cannot modify the global scope Vars and symbolic shape Vars defined inside the dataflow block.
+ */
+class DataflowBlockPassNode : public tvm::transform::PassNode {
+ public:
+  /* \brief The pass meta data.*/
+  PassInfo pass_info;
+
+  /*! \brief The packed pass function sketches the real optimization. For
+   * instance, we can implement a pass that works on a Relax DataflowBlock as a
+   * `pass_func` and let it run on a given IRModule. The same `pass_func` will
+   * then be applied on each DataflowBlock in the IRModule.
+   */
+  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func;
+
+  DataflowBlockPassNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
+
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
+
+  PassInfo Info() const override { return pass_info; }
+
+  static constexpr const char* _type_key = "relax.DataflowBlockPass";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode);
+};
+
+/*! \brief Helper to apply the passed function to dataflow blocks.*/
+class DataflowBlockMutator : public ExprMutator {
+ public:
+  DataflowBlockMutator(
+      runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+      IRModule mod, PassContext pass_ctx)
+      : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {}
+
+  /*!
+   * \brief Rewrite the DataflowBlockNode with pass_func_
+   *
+   * This function will check that there are no rewrites of the global scope Vars
+   * and symbolic shape Vars defined inside the dataflow block.
+   */
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final {
+    // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock
+    Map<String, Var> global_scope_vars;
+    Map<String, tir::Var> symbolic_vars;
+    for (const Binding& binding : n->bindings) {
+      Var var = binding->var;
+      if (const auto* match_cast = binding.as<MatchCastNode>()) {
+        auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info);
+        for (const tir::VarNode* var : collected_vars) {
+          symbolic_vars.Set(var->name_hint, GetRef<tir::Var>(var));
+        }
+      }
+      if (!var.as<DataflowVarNode>()) {
+        global_scope_vars.Set(var->name_hint(), var);
+      }
+    }
+
+    // apply pass_func_ to the DataflowBlock
+    DataflowBlock block = GetRef<DataflowBlock>(n);
+    DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_);
+
+    // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars
+    for (const Binding& binding : updated_block->bindings) {
+      Var var = binding->var;
+      if (const auto* match_cast = binding.as<MatchCastNode>()) {
+        auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info);
+        for (const tir::VarNode* var : collected_vars) {
+          if (symbolic_vars.count(var->name_hint) > 0) {
+            tir::Var old_var = symbolic_vars[var->name_hint];
+            ICHECK(var == old_var.get())
+                << "Error: DataflowBlock Pass should not rewrite any Symbolic 
Var.";
+            symbolic_vars.erase(var->name_hint);
+          }
+        }
+      }
+      if (!var.as<DataflowVarNode>() && global_scope_vars.count(var->name_hint()) > 0) {
+        ICHECK(var.same_as(global_scope_vars[var->name_hint()]))
+            << "Error: DataflowBlock Pass should not rewrite any GlobalScope 
Var.";
+        global_scope_vars.erase(var->name_hint());
+      }
+    }
+    ICHECK(global_scope_vars.empty() && symbolic_vars.empty())
+        << "Error: DataflowBlock Pass should not delete any 
GlobalScope/Symbolic Var.";
+
+    return std::move(updated_block);
+  }
+
+ private:
+  class SymbolicVarCollector : public StructInfoVisitor {
+   public:
+    static std::unordered_set<const tir::VarNode*> Collect(const StructInfo& info) {
+      SymbolicVarCollector collector;
+      collector.VisitStructInfo(info);
+      return std::move(collector.symbolic_vars_);
+    }
+
+   private:
+    void VisitStructInfoExprField(const PrimExpr& expr) final {
+      if (const tir::VarNode* sym_var = expr.as<tir::VarNode>()) {
+        symbolic_vars_.insert(sym_var);
+      }
+    }
+
+   private:
+    std::unordered_set<const tir::VarNode*> symbolic_vars_;
+  };
+
+  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func_;
+  IRModule mod_;
+  PassContext pass_ctx_;
+};
+
+class DataflowBlockPass : public Pass {
+ public:
+  /*!
+   * \brief The constructor
+   * \param pass_func The packed function which implements a pass.
+   * \param pass_info The pass info.
+   */
+  TVM_DLL DataflowBlockPass(
+      runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+      PassInfo pass_info);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode);
+};
+
+DataflowBlockPass::DataflowBlockPass(
+    runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  auto n = make_object<DataflowBlockPassNode>();
+  n->pass_func = std::move(pass_func);
+  n->pass_info = std::move(pass_info);
+  data_ = std::move(n);
+}
+
+// Perform IRModule -> IRModule transformations at the DataflowBlock level.
+IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+  if (pass_ctx->diag_ctx) {
+    DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+    pass_ctx->diag_ctx = previous;
+    previous = tmp;
+  } else {
+    pass_ctx->diag_ctx = previous;
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a 
bug.";
+
+  const PassInfo& pass_info = Info();
+
+  ICHECK(mod.defined());
+
+  VLOG_CONTEXT << pass_info->name;
+  VLOG(0) << "Executing DataflowBlock pass with opt level: " << 
pass_info->opt_level;
+  VLOG(1) << "Input module:" << std::endl << mod;
+
+  IRModule updated_mod = mod->ShallowCopy();
+
+  DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod, pass_ctx);
+  std::vector<std::pair<GlobalVar, Function> > updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up relax::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      Function updated_func = Downcast<Function>(dataflow_block_mutator.VisitExpr(func));
+      updates.push_back({it.first, updated_func});
+    }
+  }
+
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block this is a 
bug.";
+
+  pass_ctx->diag_ctx.value().Render();
+  pass_ctx->diag_ctx = previous;
+
+  VLOG(1) << "Output module:" << std::endl << updated_mod;
+
+  return updated_mod;
+}
+
+Pass CreateDataflowBlockPass(
+    const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<String> required) {
+  PassInfo pass_info = PassInfo(opt_level, name, required);
+  return DataflowBlockPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass")
+    .set_body_typed(
+        [](runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+           PassInfo pass_info) { return DataflowBlockPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<DataflowBlockPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const DataflowBlockPassNode*>(ref.get());
+      const PassInfo info = node->Info();
+      p->stream << "Run DataflowBlock pass: " << info->name << " at the 
optimization level "
+                << info->opt_level;
+    });
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
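The passes defined above are IRModule-to-IRModule transforms: FunctionPassNode::operator() applies pass_func to every relax::Function in the module, and DataflowBlockPassNode::operator() rewrites each dataflow block through DataflowBlockMutator, rejecting any rewrite of global-scope Vars or symbolic shape Vars. A minimal Python sketch of how such passes are expected to be authored is shown below; it is not part of the diff and assumes the relax.transform.function_pass and relax.transform.dataflow_block_pass decorators (added in the Python transform module of this change, not shown in this excerpt) mirror Relay's function_pass decorator and ultimately call the relax.transform.MakeFunctionPass / MakeDataflowBlockPass globals registered above.

    from tvm import relax

    # Hedged sketch (assumed decorator API): the decorated class provides
    # transform_function / transform_dataflow_block, which become the pass_func
    # applied per Function / per DataflowBlock by the C++ machinery above.
    @relax.transform.function_pass(opt_level=0, name="IdentityFunctionPass")
    class IdentityFunctionPass:
        def transform_function(self, func, mod, ctx):
            # Called once for every relax.Function in the IRModule.
            return func

    @relax.transform.dataflow_block_pass(opt_level=0, name="IdentityBlockPass")
    class IdentityBlockPass:
        def transform_dataflow_block(self, block, mod, ctx):
            # Must return a DataflowBlock and must not rewrite global-scope vars
            # or symbolic shape vars (DataflowBlockMutator checks this).
            return block

    # Instantiating the decorated classes yields IRModule -> IRModule passes:
    #     mod = IdentityFunctionPass()(mod)
    #     mod = IdentityBlockPass()(mod)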
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
new file mode 100644
index 0000000000..0bf0f175dd
--- /dev/null
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -0,0 +1,429 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.ir import assert_structural_equal
+from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_const_shape_arg():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer(T.int64(4), "int64")):
+            """Extra function, checks if the pass preserves it."""
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            shape_heap = R.null_value()
+            _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", 
sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", 
sinfo_args=[R.Tuple()])
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer(T.int64(4), "int64")):
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_static_fn_check():
+    """Check static shape and function."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            return y
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            shape_heap = R.null_value()
+            _ = R.call_packed("vm.builtin.check_func_info", f, "", 
sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", 
sinfo_args=[R.Tuple()])
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return y
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_simple_symbolic_shape():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            return x
+
+    sindex = {
+        "n": 0,
+        "m": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "", 
sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_symbolic_compute():
+    MS = MatchShapeCode
+    MK = MakeShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            return (k + 1, m, 2)
+
+    # slot assignment:
+    # 0: n, 1: m, 2: k, 3: k+1
+    sindex = {"n": 0, "m": 1, "k": 2, "k+1": 3}
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def shape_func(H: T.Buffer(T.int64(4), "int64")):
+            # generated compute function
+            H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
+
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(4)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", 
sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", 
sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.NO_OP,
+                0,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = shape_func(shape_heap)
+            # extra assertion on y's shape after shape computation
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k+1"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            # construct shape value for return
+            s = R.call_packed(
+                "vm.builtin.make_shape",
+                shape_heap,
+                3,
+                MK.LOAD_SHAPE,
+                sindex["k+1"],
+                MK.LOAD_SHAPE,
+                sindex["m"],
+                MK.USE_IMM,
+                2,
+                sinfo_args=[R.Shape(ndim=3)],
+            )
+            return s
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_tuple_handling():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, 
R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            return x
+
+    # slot assignment:
+    sindex = {"n": 0, "m": 1, "k": 2}
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, 
R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(3)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            # recursively unpack tuple for static info check
+            _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", 
sinfo_args=[R.Tuple()])
+            t0 = x[0]
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                t0,
+                2,
+                R.dtype("float32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            t1 = x[1]
+            _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", 
sinfo_args=[R.Tuple()])
+            t1x0 = t1[0]
+            _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", 
sinfo_args=[R.Tuple()])
+            t1x1 = t1[1]
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                t1x1,
+                2,
+                R.dtype("int32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # match shape checks.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                t0,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                t1x1,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_return_match_check():
+    """Test when return body is not same as ret_struct_info, runtime match 
check needed."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Object
+        ) -> R.Tuple(R.Tensor(["n", "m"], "float32")):
+            return y
+
+    # slot assignment:
+    sindex = {
+        "n": 0,
+        "m": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Object
+        ) -> R.Tuple(R.Tensor(["n", "m"], "float32")):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", 
sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", 
sinfo_args=[R.Tuple()])
+            # emit runtime function call since y do not have the right type.
+            y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, 
sinfo_args=[R.Object])
+            # run check
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                y1,
+                2,
+                R.dtype("float32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # shape check
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y1,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+
+            return y
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
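For reading the expected modules above: each vm.builtin.match_shape / vm.builtin.make_shape call carries one (code, operand) pair per dimension, using the MatchShapeCode / MakeShapeCode enums imported at the top of this test file. The sketch below is only an orientation aid inferred from how the codes are used in these tests; the authoritative logic is the C++ VM builtin, and the helper names here are hypothetical.

    # Hedged sketch: a Python-level reading of the per-dimension opcodes,
    # assuming MatchShapeCode / MakeShapeCode are in scope as in the imports above.
    def interpret_match_shape(shape, heap, pairs):
        # pairs: [(code, operand), ...], one pair per dimension of `shape`.
        for dim, (code, operand) in zip(shape, pairs):
            if code == MatchShapeCode.NO_OP:
                continue
            if code == MatchShapeCode.STORE_TO_HEAP:
                heap[operand] = dim          # record a symbolic size such as n, m, k
            elif code == MatchShapeCode.ASSERT_EQUAL_TO_IMM:
                assert dim == operand        # static dimension, e.g. the literal 2
            elif code == MatchShapeCode.ASSERT_EQUAL_TO_LOAD:
                assert dim == heap[operand]  # symbolic size recorded earlier

    def interpret_make_shape(heap, pairs):
        # USE_IMM takes the operand literally; LOAD_SHAPE reads a heap slot filled above.
        return tuple(
            operand if code == MakeShapeCode.USE_IMM else heap[operand]
            for code, operand in pairs
        )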
