This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 55810737be7acdc39120baa5d1bfed0d84b3bc64
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Thu Feb 16 21:12:02 2023 -0500

    [Unity][VM] Supporting "compiled" exec mode. (#14015)
    
    This PR adds support for "compiled" mode to the VM. The compiled mode
    translates each Relax function into a TIR function and drives execution
    through that TIR function.
    
    It is different from the micro AOT codegen, which generates TIR code that
    targets the micro C runtime environment and is useful for resource-limited
    settings with a smaller set of features. Both leverage the low-level TIR
    build that is also shared with TensorIR.
    
    The current implementation targets the full TVM (VM) runtime, which comes
    with PackedFunc, object, tuple, closure and all kinds of rich structure
    support. This also means that we can leverage the full runtime support to
    handle things like allocation, dynamic shape, easy plugins and Python
    interaction, which are not available in more limited runtimes.
    
    Users load the generated code through the same API regardless of compiled
    mode or bytecode, and only need to change one line:
    
    ```python
    ex = relax.vm.build(mod, target, exec_mode="compiled")
    ```
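    
    For completeness, a minimal end-to-end sketch of the flow (here `mod`,
    `target` and the `"main"` entry name are placeholders, not part of this
    commit):
    
    ```python
    import numpy as np
    import tvm
    from tvm import relax
    
    # build with the TIR-based "compiled" mode instead of bytecode
    ex = relax.vm.build(mod, target, exec_mode="compiled")
    # loading and running is identical to the bytecode path
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.nd.array(np.zeros((3, 4), "float32")))
    ```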
    
    This simplicity is thanks to the TVM runtime architecture, which allows us
    to compose things together as objects. The only difference is how the
    high-level driving PackedFunc is provided: in the bytecode case it is
    normal interpretation, and in the compiled case it is TIR.
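    
    Concretely, each Relax function `foo` becomes a TIR driver
    `__vmtir__foo(ctx_ptr, r, c, f)` that moves values between the register
    anylist `r`, the constant anylist `c` and the function anylist `f`. A
    simplified sketch of the generated driver, adapted from the expected
    outputs in test_vm_codegen_tir.py below:
    
    ```python
    @T.prim_func
    def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        T.func_attr({"global_symbol": "__vmtir__foo"})
        # call the packed func, storing the result in register slot 2
        T.anylist_setitem_call_packed(
            r, T.int32(2), "test.vm.add",
            T.anylist_getitem(r, T.int32(0)),
            T.anylist_getitem(r, T.int32(0)),
        )
        # copy the result into the reserved return register slot
        T.anylist_setitem_call_packed(
            r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2))
        )
    ```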
    
    It is a complete implementation. Unit test cases are added, and all
    codegen build tests are updated to cover both exec_modes; they have
    passed locally.
    
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
---
 include/tvm/tir/builtin.h                  |  44 +++
 python/tvm/script/ir_builder/tir/ir.py     |   8 +
 python/tvm/tir/op.py                       |  68 ++++
 src/relax/backend/vm/codegen_vm_tir.cc     | 511 +++++++++++++++++++++++++++++
 src/runtime/library_module.cc              |   5 +-
 src/target/llvm/codegen_cpu.cc             |   6 +-
 src/tir/op/builtin.cc                      |  12 +
 src/tir/op/runtime.cc                      |  41 +++
 src/tir/transforms/lower_tvm_builtin.cc    | 169 +++++-----
 tests/python/relax/test_vm_build.py        |   2 +-
 tests/python/relax/test_vm_codegen_only.py |   2 +-
 tests/python/relax/test_vm_codegen_tir.py  | 224 +++++++++++++
 12 files changed, 1006 insertions(+), 86 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 708abde2cd..35022e0e75 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -762,6 +762,50 @@ TVM_DLL const Op& start_profile_intrinsic();
  */
 TVM_DLL const Op& end_profile_intrinsic();
 
+/*!
+ * \brief Get an item from any list and return it.
+ *
+ *  Any anylist_getitem(Handle anylist,
+ *                      int index) {
+ *     return anylist[index];
+ *  }
+ *
+ * \note This intrinsic is only applicable when appearing
+ *       in call_packed and anylist_setitem_call_packed.
+ */
+TVM_DLL const Op& anylist_getitem();
+
+/*!
+ * \brief Reset and clear an item in any list.
+ *
+ *  void anylist_resetitem(Handle anylist,
+ *                         int index) {
+ *    anylist[index] = nullptr;
+ *  }
+ *
+ * \note This intrinsic is only applicable when appearing
+ *       in call_packed and anylist_setitem_call_packed.
+ */
+TVM_DLL const Op& anylist_resetitem();
+
+/*!
+ * \brief Set an item into any list by running packed function call.
+ *
+ *  void anylist_setitem_call_packed(Handle anylist,
+ *                                   int index,
+ *                                   name, *args) {
+ *    anylist[index] = call_packed(name, *args);
+ *  }
+ *
+ * \note This intrinsic can be used in combination with anylist_getitem.
+ */
+TVM_DLL const Op& anylist_setitem_call_packed();
+
+/*!
+ * \brief Same as anylist_setitem_call_packed but uses the C calling convention.
+ */
+TVM_DLL const Op& anylist_setitem_call_cpacked();
+
 /*! \brief The kind of structure field info used in intrinsic */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5f4e9d4f2c..601963565f 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1713,6 +1713,10 @@ TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
 TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
 start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
 end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
+anylist_getitem = _op_wrapper(_tir_op.anylist_getitem)
+anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
+anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
+anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
 
 
 def _dtype_forward(func):
@@ -1988,6 +1992,10 @@ __all__ = [
     "start_profile_intrinsic",
     "end_profile_intrinsic",
     "meta_var",
+    "anylist_getitem",
+    "anylist_resetitem",
+    "anylist_setitem_call_packed",
+    "anylist_setitem_call_cpacked",
     "llvm_lookup_intrinsic_id",
     "type_annotation",
     "broadcast",
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 0a9c4fdfaa..14decca77e 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -2931,6 +2931,74 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr):
     return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
 
 
+def anylist_getitem(list_handle, index):
+    """Return an item from any list.
+
+    Parameters
+    ----------
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.anylist_getitem", list_handle, index)
+
+
+def anylist_resetitem(list_handle, index):
+    """Reset an item in any list.
+
+    Parameters
+    ----------
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("int", "tir.anylist_resetitem", list_handle, index)
+
+
+def anylist_setitem_call_packed(list_handle, index, func_name, *args):
+    """Set an anylist item to the result of a packed call.
+
+    Parameters
+    ----------
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+    func_name: str
+        The name of the function to be called.
+    args:
+        Extra arguments
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args
+    )
+
+
+def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
+    """Set an anylist item to the result of a c-packed call.
+
+    Parameters
+    ----------
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+    func_name: str
+        The name of the function to be called.
+    args:
+        Extra arguments
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args
+    )
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc
new file mode 100644
index 0000000000..2f63a50d37
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/vm/codegen_vm_tir.cc
+ * \brief A codegen to generate VMTIR functions (that can be compiled) from an executable.
+ */
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/exec_builder.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/runtime/relax_vm/executable.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
+
+#include <cctype>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+namespace relax_vm {
+
+using vm::VMFuncInfo;
+
+/*!
+ * \brief A class to generate VMTIR for Relax functions.
+ *
+ * \note Skip CallPacked with special attrs for now, as they can be
+ *       further simplified with PrimValue.
+ */
+class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
+ public:
+  explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
+      : builder_(builder), ctx_mod_(ctx_mod) {}
+
+  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
+    // create a new copy
+    IRModule res_mod = mod;
+    res_mod.CopyOnWrite();
+
+    CodeGenVMTIR codegen(builder, mod);
+    // Remove relax function and turn into TIR func.
+    for (auto& p : mod->functions) {
+      if (auto* func = p.second.as<FunctionNode>()) {
+        auto tir_func = codegen.Codegen(GetRef<Function>(func));
+        auto gsymbol = tir_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        res_mod->Add(GlobalVar(gsymbol.value()), tir_func);
+        res_mod->Remove(p.first);
+      }
+    }
+    return res_mod;
+  }
+
+ private:
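+  // Allocates a fresh virtual register slot; registers live in the "r" anylist.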
+  int64_t NewRegister() { return registers_num_++; }
+
+  static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); }
+
+  static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); }
+
+  PrimExpr RegListGet(int64_t slot) const {
+    // use 128 bits to represent any
+    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
+                     {reg_anylist_handle_, ConstInt32(slot)});
+  }
+
+  PrimExpr ConstListGet(int64_t slot) const {
+    // use 128 bits to represent any
+    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
+                     {const_anylist_handle_, ConstInt32(slot)});
+  }
+
+  PrimExpr FuncListGet(int64_t slot) const {
+    // use 128 bits to represent any
+    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
+                     {func_anylist_handle_, ConstInt32(slot)});
+  }
+
+  void EmitStmt(tir::Stmt stmt) {
+    ICHECK(!stmt_stack_.empty());
+    stmt_stack_.back().emplace_back(stmt);
+  }
+
+  void EmitCallPacked(String name, const Array<PrimExpr>& args, int64_t dst_anylist_slot = -1) {
+    Array<PrimExpr> all_args;
+    // a negative index indicates the return value can be discarded; emit plain call_packed
+    if (dst_anylist_slot >= 0) {
+      all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)};
+    }
+    all_args.push_back(tir::StringImm(name));
+    for (PrimExpr arg : args) {
+      all_args.push_back(arg);
+    }
+    if (dst_anylist_slot >= 0) {
+      this->EmitStmt(tir::Evaluate(
+          tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args)));
+    } else {
+      this->EmitStmt(
+          tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args)));
+    }
+  }
+
+  void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array<PrimExpr>& args,
+                       int64_t dst_anylist_slot = -1) {
+    Optional<String> gsymbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase";
+    Array<PrimExpr> all_args;
+    // a negative index indicates the return value can be discarded; emit plain call_cpacked
+    if (dst_anylist_slot >= 0) {
+      all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)};
+    }
+    all_args.push_back(tir::StringImm(gsymbol.value()));
+    for (PrimExpr arg : args) {
+      all_args.push_back(arg);
+    }
+    // push an empty handle to be compatible with current cpacked convention
+    // TODO(tqchen): revisit C Packed convention
+    all_args.push_back(tir::make_zero(DataType::Handle()));
+    if (dst_anylist_slot >= 0) {
+      this->EmitStmt(tir::Evaluate(
+          tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args)));
+    } else {
+      this->EmitStmt(
+          tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args)));
+    }
+  }
+
+  tir::PrimFunc Codegen(const Function& func) {
+    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
+                                 "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
+    // initialize the state
+    stmt_stack_ = {};
+    registers_num_ = 0;
+    var_map_.clear();
+    ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle());
+    reg_anylist_handle_ = tir::Var("r", DataType::Handle());
+    func_anylist_handle_ = tir::Var("f", DataType::Handle());
+    const_anylist_handle_ = tir::Var("c", DataType::Handle());
+
+    Array<String> param_names;
+    for (Var param : func->params) {
+      param_names.push_back(param->name_hint());
+    }
+    // declare this function.
+    builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc);
+
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      int64_t r = NewRegister();
+      ICHECK_EQ(static_cast<size_t>(r), i);
+      this->var_map_.insert({func->params[i], RegListGet(r)});
+    }
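+    // reserve one more register to receive the function's return value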
+    size_t ret_reg = NewRegister();
+
+    tir::Stmt body = WithNewScope([&]() {
+      Optional<PrimExpr> ret = ExprFunctor::VisitExpr(func->body);
+      if (ret.defined()) {
+        this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg);
+      }
+    });
+
+    // Mark the function entry internally.
+    builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names,
+                           VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_);
+    builder_->EndFunction(gsymbol.value());
+
+    Type ret_type = VoidType();
+    Array<tir::Var> tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_,
+                                  func_anylist_handle_};
+    String tir_func_name = "__vmtir__" + gsymbol.value();
+    tir::PrimFunc tir_func(tir_params, body, ret_type, {});
+    tir_func = WithAttr(tir_func, "global_symbol", tir_func_name);
+    registers_num_ = 0;
+    var_map_.clear();
+    stmt_stack_.clear();
+    return tir_func;
+  }
+
+  Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
+    for (auto block : op->blocks) {
+      for (Binding binding : block->bindings) {
+        Optional<PrimExpr> value;
+        if (auto* var_binding = binding.as<VarBindingNode>()) {
+          value = this->VisitExpr(var_binding->value);
+        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
+          value = this->VisitExpr(match_cast->value);
+        } else {
+          LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
+        }
+        this->var_map_.insert({binding->var, value});
+      }
+    }
+    return this->VisitExpr(op->body);
+  }
+
+  Optional<PrimExpr> VisitExpr_(const CallNode* call_node) final {
+    Call call = GetRef<Call>(call_node);
+
+    if (call_node->op == null_value_op_) {
+      return tir::Call(DataType::Handle(), tir::builtin::reinterpret(),
+                       {IntImm(DataType::Int(64), 0)});
+    }
+    int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
+    if (call->op.as<OpNode>()) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
+        EmitCallBuiltinWithCtx(call, dst_reg);
+      } else if (call_node->op == alloc_storage_op_) {
+        EmitAllocStorage(call, dst_reg);
+      } else if (call_node->op == alloc_tensor_op_) {
+        EmitAllocTensor(call, dst_reg);
+      } else {
+        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
+        // ops are handled in a pass when lowering them to TIR.
+        LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op;
+      }
+    } else {
+      EmitNormalCall(call, dst_reg);
+    }
+    if (dst_reg >= 0) {
+      return RegListGet(dst_reg);
+    } else {
+      return NullOpt;
+    }
+  }
+
+  Optional<PrimExpr> VisitExpr_(const IfNode* op) final {
+    // Reserve a register for return
+    size_t merge_register = NewRegister();
+    PrimExpr cond_value = this->VisitExpr(op->cond).value();
+
+    // turn ndarray cond value into scalar.
+    cond_value = tir::Cast(DataType::Bool(),
+                           tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(),
+                                     {tir::StringImm("vm.builtin.read_if_cond"), cond_value}));
+
+    tir::Stmt true_branch = WithNewScope([&]() {
+      PrimExpr true_value = this->VisitExpr(op->true_branch).value();
+      this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register);
+    });
+    tir::Stmt false_branch = WithNewScope([&]() {
+      PrimExpr false_value = this->VisitExpr(op->false_branch).value();
+      this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register);
+    });
+    this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch));
+    return RegListGet(merge_register);
+  }
+
+  Optional<PrimExpr> VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    auto it = this->var_map_.find(var);
+    ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined";
+    return it->second;
+  }
+
+  Optional<PrimExpr> VisitExpr_(const ConstantNode* op) final {
+    return ConstListGet(builder_->ConvertConstant(op->data).value());
+  }
+
+  Optional<PrimExpr> VisitExpr_(const ShapeExprNode* op) final {
+    std::vector<int64_t> shape;
+    for (PrimExpr e : op->values) {
+      if (auto* int_value = e.as<IntImmNode>()) {
+        shape.push_back(int_value->value);
+      } else {
+        LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values;
+      }
+    }
+    return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value());
+  }
+
+  Optional<PrimExpr> VisitExpr_(const PrimValueNode* op) final { return op->value; }
+
+  Optional<PrimExpr> VisitExpr_(const StringImmNode* op) final {
+    return ConstListGet(builder_->ConvertConstant(op->value).value());
+  }
+
+  Optional<PrimExpr> VisitExpr_(const DataTypeImmNode* op) final {
+    return ConstListGet(builder_->ConvertConstant(op->value).value());
+  }
+
+  Optional<PrimExpr> VisitExpr_(const TupleNode* op) final {
+    Tuple tuple = GetRef<Tuple>(op);
+    Array<PrimExpr> args;
+    for (auto arg : tuple->fields) {
+      args.push_back(this->VisitExpr(arg).value());
+    }
+    int32_t dst_register = NewRegister();
+    this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register);
+    return RegListGet(dst_register);
+  }
+
+  Optional<PrimExpr> VisitExpr_(const TupleGetItemNode* op) final {
+    TupleGetItem expr = GetRef<TupleGetItem>(op);
+    Array<PrimExpr> args = {this->VisitExpr(expr->tuple).value()};
+
+    args.push_back(ConstInt64(expr->index));
+
+    int64_t dst_register = NewRegister();
+    this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register);
+    return RegListGet(dst_register);
+  }
+
+  // Lookup the function and see if it matches
+  Optional<String> LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) {
+    if (auto* ext_func = expr.as<ExternFuncNode>()) {
+      *kind = VMFuncInfo::FuncKind::kPackedFunc;
+      return ext_func->global_symbol;
+    } else if (auto* gvar_ptr = expr.as<GlobalVarNode>()) {
+      GlobalVar gvar = GetRef<GlobalVar>(gvar_ptr);
+      // Run a look up in the env to see if it maps to an extern func.
+      auto it = ctx_mod_->functions.find(gvar);
+      if (it != ctx_mod_->functions.end()) {
+        BaseFunc func = (*it).second;
+        if (auto* efunc = func.as<ExternFuncNode>()) {
+          *kind = VMFuncInfo::FuncKind::kPackedFunc;
+          return efunc->global_symbol;
+        } else if (func.as<FunctionNode>()) {
+          *kind = VMFuncInfo::FuncKind::kVMTIRFunc;
+          return gvar->name_hint;
+        } else if (func.as<tir::PrimFuncNode>()) {
+          *kind = VMFuncInfo::FuncKind::kPackedFunc;
+          return gvar->name_hint;
+        } else {
+          *kind = VMFuncInfo::FuncKind::kPackedFunc;
+          return gvar->name_hint;
+        }
+      }
+      LOG(WARNING) << "Undefined global var " << gvar->name_hint;
+      // undefined global var; consider eliminating it later.
+      *kind = VMFuncInfo::FuncKind::kPackedFunc;
+      return gvar->name_hint;
+    } else {
+      return NullOpt;
+    }
+  }
+  // Lookup PrimFunc in the same module
+  // We can do direct PrimFunc call in such cases
+  Optional<tir::PrimFunc> LookupPrimFunc(const String& name) {
+    if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt;
+
+    GlobalVar gvar = ctx_mod_->GetGlobalVar(name);
+    auto it = ctx_mod_->functions.find(gvar);
+    if (it != ctx_mod_->functions.end()) {
+      BaseFunc func = (*it).second;
+      if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
+        return GetRef<tir::PrimFunc>(prim_func);
+      }
+    }
+    return NullOpt;
+  }
+
+  Optional<PrimExpr> VisitExpr_(const GlobalVarNode* op) final {
+    VMFuncInfo::FuncKind kind;
+    auto symbol = LookupFunction(GetRef<Expr>(op), &kind);
+    ICHECK(symbol.defined());
+    builder_->DeclareFunction(symbol.value(), kind);
+    return FuncListGet(builder_->GetFunction(symbol.value()).value());
+  }
+
+  Optional<PrimExpr> VisitExpr_(const ExternFuncNode* op) final {
+    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
+    return FuncListGet(builder_->GetFunction(op->global_symbol).value());
+  }
+
+  void EmitAllocStorage(const Call& call_node, int64_t dst_reg) {
+    // Handle args of the call
+    Array<PrimExpr> args;
+    args.push_back(ctx_ptr_);
+    for (Expr arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg).value());
+    }
+    this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg);
+  }
+
+  void EmitAllocTensor(const Call& call_node, int64_t dst_reg) {
+    ICHECK_EQ(call_node->args.size(), 4);
+    Array<PrimExpr> args;
+    args.reserve(4);
+    for (Expr arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg).value());
+    }
+    this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg);
+  }
+
+  void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) {
+    Array<PrimExpr> args;
+    // if context is required, pass as first argument.
+    args.push_back(ctx_ptr_);
+    auto* func = call_node->args[0].as<ExternFuncNode>();
+    ICHECK(func) << "CallBuiltin comes with extern func";
+
+    auto tuple_arg = Downcast<Tuple>(call_node->args[1]);
+
+    // Handle args of the call
+    for (Expr arg : tuple_arg->fields) {
+      args.push_back(this->VisitExpr(arg).value());
+    }
+
+    this->EmitCallPacked(func->global_symbol, args, dst_reg);
+  }
+
+  void EmitNormalCall(const Call& call_node, int64_t dst_reg) {
+    Array<PrimExpr> args = VisitArray(call_node->args);
+    // A function can be a closure that comes from parent
+    // Do call closure to be safe.
+    VMFuncInfo::FuncKind kind;
+    auto symbol = LookupFunction(call_node->op, &kind);
+
+    if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) {
+      // primfunc in the same module.
+      // use cpacked to directly invoke without name-based lookup
+      if (Optional<tir::PrimFunc> prim_func = LookupPrimFunc(symbol.value())) {
+        this->EmitCallCPacked(prim_func.value(), args, dst_reg);
+      } else {
+        this->EmitCallPacked(symbol.value(), args, dst_reg);
+      }
+    } else {
+      // Default path, leverage function table and invoke as closure
+      Array<PrimExpr> all_args;
+      all_args.push_back(ctx_ptr_);
+      all_args.push_back(this->VisitExpr(call_node->op).value());
+      for (auto arg : args) {
+        all_args.push_back(arg);
+      }
+      this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg);
+    }
+  }
+
+  template <typename FLambda>
+  tir::Stmt WithNewScope(const FLambda& callback) {
+    stmt_stack_.push_back({});
+    callback();
+    tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back());
+    stmt_stack_.pop_back();
+    return stmt;
+  }
+
+  Array<PrimExpr> VisitArray(const Array<Expr>& arr) {
+    Array<PrimExpr> ret;
+    for (size_t i = 0; i < arr.size(); ++i) {
+      ret.push_back(this->VisitExpr(arr[i]).value());
+    }
+    return ret;
+  }
+  /*! \brief Internal ExecBuilder. */
+  relax::ExecBuilder builder_;
+  /*! \brief The VM context pointer parameter. */
+  tir::Var ctx_ptr_;
+  /*! \brief List to store temp object registers */
+  tir::Var reg_anylist_handle_;
+  /*! \brief List to store closures */
+  tir::Var func_anylist_handle_;
+  /*! \brief List to store constants */
+  tir::Var const_anylist_handle_;
+  /*!
+   * \brief Total number of virtual registers allocated.
+   * \note The first two registers are reserved for special registers.
+   */
+  int64_t registers_num_ = 0;
+  /*! \brief Stack to build up statements */
+  std::vector<std::vector<tir::Stmt>> stmt_stack_;
+  /*! \brief Map from var to Expr. */
+  std::unordered_map<Var, Optional<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  /*! \brief the context module. */
+  IRModule ctx_mod_;
+  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
+  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+  const Op& null_value_op_ = Op::Get("relax.null_value");
+};
+
+/*!
+ * \brief Create the Relax VM executable from all relax.Function in mod,
+ *        add them to exec_builder, and create extra TIR functions.
+ *
+ * \param exec_builder Builder to collect executables.
+ * \param mod Input module.
+ * \return Extra TIR module created.
+ */
+IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) {
+  return CodeGenVMTIR::Run(exec_builder, mod);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen);
+
+}  // namespace relax_vm
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index d6c2f791de..17dfbec0d0 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -77,7 +77,10 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
     int ret_type_code = kTVMNullptr;
     int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
                        args.num_args, &ret_value, &ret_type_code, nullptr);
-    ICHECK_EQ(ret, 0) << TVMGetLastError();
+    // NOTE: important to keep the original error message.
+    if (ret != 0) {
+      LOG(FATAL) << TVMGetLastError();
+    }
     if (ret_type_code != kTVMNullptr) {
       *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
     }
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 21d2c6ebe0..10aa2688a8 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -905,8 +905,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
           llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
                                  func_name, module_.get());
     }
-
-    nargs -= 1;
+    // NOTE: This is a bugfix to a previous coupled convention (in lower_tvm_builtin).
+    // The begin, end should correspond to the right location in cpacked, excluding the resource handle.
+    // TODO(tqchen): upstream the fix.
+    // nargs -= 1;
     call_args.insert(call_args.end(), {
                                           builder_->CreateBitCast(arg_value, t_void_p_),
                                           arg_tcode.addr,
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 680202751f..f9d5228042 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -318,6 +318,18 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic)
 TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TIR_DEFINE_BUILTIN_FUNC(anylist_getitem)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
+
+TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAnyListResetItem");
+
+TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc
new file mode 100644
index 0000000000..9ee6c67ec9
--- /dev/null
+++ b/src/tir/op/runtime.cc
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/op/runtime.cc
+ * \brief TIR ops for runtime functions.
+ */
+#include <tvm/ir/op.h>
+#include <tvm/tir/op_attr_types.h>
+
+namespace tvm {
+namespace tir {
+
+TVM_REGISTER_OP("tir.TVMBackendAnyListSetPackedArg")
+    .set_num_inputs(5)
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAnyListSetPackedArg")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.TVMBackendAnyListMoveFromPackedReturn")
+    .set_num_inputs(3)
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 082a54f9c7..b0a87a3056 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -302,13 +302,21 @@ class BuiltinLower : public StmtExprMutator {
       return Stmt(n);
     }
   }
+
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::tvm_call_packed())) {
-      return MakeCallPacked(op, /* use_string_lookup */ true);
+      return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(),
+                                   /* use_string_lookup */ true);
     } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
-      return MakeCallPacked(op, /* use_string_lookup */ false);
+      return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(),
+                                   /* use_string_lookup */ false);
     } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
-      return MakeCallTracePacked(op);
+      return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(),
+                                   /* use_string_lookup */ true);
+    } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) {
+      return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true);
+    } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) {
+      return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered(), false);
     } else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
       return MakeShape(op);
     } else if (op->op.same_as(builtin::tvm_stack_make_array())) {
@@ -418,8 +426,68 @@ class BuiltinLower : public StmtExprMutator {
                                        cast(DataType::Int(32), device_type_)));
    return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
   }
-  // call packed.
-  PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) {
+
+  void SetPackedArg(PrimExpr arg, const Var& value_stack, const Buffer& tcode_stack,
+                    size_t stack_offset, std::vector<tir::Stmt>* prep_seq) {
+    auto* call_pattern = arg.as<CallNode>();
+    if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) {
+      // call runtime function to set anylist
+      prep_seq->emplace_back(
+          Evaluate(Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"),
+                        {call_pattern->args[0], call_pattern->args[1], value_stack,
+                         tcode_stack->data, ConstInt32(stack_offset)})));
+    } else {
+      DataType api_type = APIType(arg.dtype());
+      if (arg.dtype() != api_type) {
+        arg = Cast(api_type, arg);
+      }
+      prep_seq->emplace_back(
+          TVMStructSet(value_stack, stack_offset, builtin::kTVMValueContent, arg));
+      int arg_tcode = api_type.code();
+      if (api_type.is_handle() && arg.as<StringImmNode>()) {
+        arg_tcode = kTVMStr;
+      } else if (IsArrayHandle(arg)) {
+        arg_tcode = kTVMDLTensorHandle;
+      }
+      // opaque handle need to set the kind properly
+      if (arg_tcode == kTVMOpaqueHandle) {
+        prep_seq->emplace_back(IfThenElse(
+            Call(DataType::Bool(), builtin::isnullptr(), {arg}),
+            BufferStore(tcode_stack, ConstInt32(kTVMNullptr), {ConstInt32(stack_offset)}),
+            BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)})));
+      } else {
+        prep_seq->emplace_back(
+            BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)}));
+      }
+    }
+  }
+
+  PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op,
+                                        bool use_string_lookup) {
+    PrimExpr list_handle = op->args[0];
+    PrimExpr list_index = op->args[1];
+
+    Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup);
+    PrimExpr value_stack = call->args[1];
+    PrimExpr tcode_stack = call->args[2];
+    // The stack offset of return value stack_end
+    PrimExpr ret_offset = call->args[4];
+    auto& prep_seq = prep_seq_stack_.back();
+    prep_seq.emplace_back(Evaluate(call));
+    return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"),
+                {list_handle, list_index, value_stack, tcode_stack, ret_offset});
+  }
+  /*!
+   * \brief Generic tool to make low-level
+   *  packed_call(other_args..., func_name, packed_arg0, packed_arg1...)
+   *
+   * \param op The call
+   * \param name_offset The beginning of function name and call packed section.
+   * \param lowered_packed_op The target lowered op.
+   * \param use_string_lookup Whether to lookup function by string.
+   */
+  Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op,
+                             bool use_string_lookup) {
     auto& scope = alloca_scope_.back();
     auto& prep_seq = prep_seq_stack_.back();
 
@@ -427,34 +495,24 @@ class BuiltinLower : public StmtExprMutator {
     size_t restore_array_stack = scope.run_sizes.array_stack;
     size_t arg_stack_begin = scope.run_sizes.arg_stack;
 
-    size_t arg_count = op->args.size();
+    size_t args_begin = name_offset + 1;
+    size_t args_end = op->args.size();
 
     // cpacked expects a resource_handle parameter
     if (!use_string_lookup) {
-      arg_count--;
+      --args_end;
     }
+    size_t num_args = args_end - args_begin;
 
-    scope.run_sizes.arg_stack += arg_count;
+    // The extra slot is for the return value.
+    scope.run_sizes.arg_stack += num_args + 1;
     // Specially handle the buffer packed intrinsic
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
-    for (size_t i = 1; i < arg_count; ++i) {
-      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
-      PrimExpr arg = op->args[i];
-      DataType t = arg.dtype();
-      DataType api_type = APIType(t);
-      if (t != api_type) {
-        arg = Cast(api_type, arg);
-      }
-      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
-                                         static_cast<int>(arg_stack_begin + i - 1),
-                                         builtin::kTVMValueContent, arg));
-      int arg_tcode = api_type.code();
-      if (api_type.is_handle() && arg.as<StringImmNode>()) {
-        arg_tcode = kTVMStr;
-      }
-      if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
-      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
+
+    for (size_t i = 0; i < num_args; ++i) {
+      this->SetPackedArg(op->args[args_begin + i], scope.stack_value, scope.stack_tcode,
+                         arg_stack_begin + i, &prep_seq);
     }
     // Verify stack size matches earlier value.
     if (is_precheck_) {
@@ -465,13 +523,12 @@ class BuiltinLower : public StmtExprMutator {
     scope.run_sizes.shape_stack = restore_shape_stack;
     scope.run_sizes.array_stack = restore_array_stack;
     scope.run_sizes.arg_stack = arg_stack_begin;
-    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
-                                   ConstInt32(arg_stack_begin),
-                                   ConstInt32(arg_stack_begin + op->args.size() - 1)};
-
+    Array<PrimExpr> packed_args = {op->args[name_offset], scope.stack_value,
+                                   scope.stack_tcode->data, ConstInt32(arg_stack_begin),
+                                   ConstInt32(arg_stack_begin + num_args)};
     // cpacked call resource_handle
     if (!use_string_lookup) {
-      PrimExpr last_arg = op->args[arg_count];
+      PrimExpr last_arg = op->args[args_end];
       const VarNode* var_node = last_arg.as<VarNode>();
       if (var_node != nullptr) {
         tir::Var resource_handle = GetRef<Var>(var_node);
@@ -480,57 +537,7 @@ class BuiltinLower : public StmtExprMutator {
         packed_args.push_back(last_arg);
       }
     }
-
-    auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()
-                                          : builtin::tvm_call_cpacked_lowered();
-    return Call(op->dtype, builtin_call, packed_args);
-  }
-
-  PrimExpr MakeCallTracePacked(const CallNode* op) {
-    ICHECK(!alloca_scope_.empty());
-    auto& scope = alloca_scope_.back();
-    auto& prep_seq = prep_seq_stack_.back();
-
-    int64_t restore_shape_stack = scope.run_sizes.shape_stack;
-    size_t restore_array_stack = scope.run_sizes.array_stack;
-    size_t arg_stack_begin = scope.run_sizes.arg_stack;
-    scope.run_sizes.arg_stack += op->args.size();
-    size_t args_size = op->args.size();
-    ICHECK_GT(args_size, 0);
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CallNode>();
-    for (size_t i = 1; i < op->args.size(); ++i) {
-      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
-      PrimExpr arg = op->args[i];
-      DataType t = arg.dtype();
-      DataType api_type = APIType(t);
-      if (t != api_type) {
-        arg = Cast(api_type, arg);
-      }
-      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
-                                         static_cast<int>(arg_stack_begin + i - 1),
-                                         builtin::kTVMValueContent, arg));
-      int arg_tcode = api_type.code();
-      ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
-      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
-    }
-    // Verify stack size matches earlier value.
-    if (is_precheck_) {
-      scope.UpdateMax();
-    } else {
-      scope.AssertMaxIsValid();
-    }
-    scope.run_sizes.shape_stack = restore_shape_stack;
-    scope.run_sizes.array_stack = restore_array_stack;
-    // Update the top of the stack, so we can use more than one
-    // packed function's arguments with the one stack.
-    scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1;
-    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
-                                   ConstInt32(arg_stack_begin),
-                                   ConstInt32(arg_stack_begin + op->args.size() - 1),
-                                   // Pass traced value.
-                                   op->args[args_size - 1]};
-    return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
+    return Call(op->dtype, lowered_packed_op, packed_args);
   }
 
   Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) {
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 0a881691ac..d57efd8b99 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -30,7 +30,7 @@ from tvm.relax.testing import nn
 from tvm.script import relax as R, tir as T
 from tvm.relax.testing.vm import check_saved_func
 
-EXEC_MODE = ["bytecode"]
+EXEC_MODE = ["bytecode", "compiled"]
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index 4b79ecf70f..600d245617 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -28,7 +28,7 @@ from tvm.relax.testing.vm import check_saved_func
 from tvm.script import relax as R
 from tvm.script import tir as T
 
-EXEC_MODE = ["bytecode"]
+EXEC_MODE = ["bytecode", "compiled"]
 
 
 def codegen(mod, target, exec_mode="bytecode"):
diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py
new file mode 100644
index 0000000000..6f3bced385
--- /dev/null
+++ b/tests/python/relax/test_vm_codegen_tir.py
@@ -0,0 +1,224 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test the TIR codegen path of VM compiled mode.
+
+Restrictions: all shapes are lowered and allocation is explicit.
+"""
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.ir import assert_structural_equal
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def get_tir_mod(mod):
+    builder = relax.ExecBuilder()
+    return relax.vm._vmcodegen(builder, mod, exec_mode="compiled")
+
+
+def test_add():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor):
+            R.func_attr({"global_symbol": "foo"})
+            z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
+            T.func_attr({"global_symbol": "__vmtir__foo"})
+            T.anylist_setitem_call_packed(
+                r,
+                T.int32(2),
+                "test.vm.add",
+                T.anylist_getitem(r, T.int32(0)),
+                T.anylist_getitem(r, T.int32(0)),
+            )
+            T.anylist_setitem_call_packed(
+                r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2))
+            )
+
+    before = Before
+    expected = Expected
+    after = get_tir_mod(before)
+    assert_structural_equal(expected, after)
+
+
+def test_tir_call():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def shape_func(H: T.Buffer(T.int64(4), "int64")):
+            T.func_attr({"global_symbol": "shape_func"})
+            # generated compute function
+            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
+
+        @R.function
+        def foo(x: R.Tensor):
+            R.func_attr({"global_symbol": "foo"})
+            _ = shape_func(x)
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def shape_func(H: T.Buffer(T.int64(4), "int64")):
+            T.func_attr({"global_symbol": "shape_func"})
+            # generated compute function
+            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
+
+        @T.prim_func
+        def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
+            T.func_attr({"global_symbol": "__vmtir__foo"})
+            T.call_cpacked(
+                "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0))
+            )
+            T.anylist_setitem_call_packed(
+                r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0))
+            )
+
+    before = Before
+    expected = Expected
+    after = get_tir_mod(before)
+    assert_structural_equal(expected, after)
+
+
+def test_if_cond():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor:
+            R.func_attr({"global_symbol": "ife"})
+            if cond:
+                w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
+            else:
+                w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
+            return w
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
+            T.func_attr({"global_symbol": "__vmtir__ife"})
+            if T.cast(
+                T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))),
+                "bool",
+            ):
+                T.anylist_setitem_call_packed(
+                    r,
+                    T.int32(4),
+                    "test.vm.add",
+                    T.anylist_getitem(r, T.int32(1)),
+                    T.anylist_getitem(r, T.int32(1)),
+                )
+                T.anylist_setitem_call_packed(
+                    r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(4))
+                )
+            else:
+                T.anylist_setitem_call_packed(
+                    r,
+                    T.int32(5),
+                    "test.vm.mul",
+                    T.anylist_getitem(r, T.int32(1)),
+                    T.anylist_getitem(r, T.int32(1)),
+                )
+                T.anylist_setitem_call_packed(
+                    r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(5))
+                )
+            T.anylist_setitem_call_packed(
+                r, T.int32(2), "vm.builtin.copy", T.anylist_getitem(r, T.int32(3))
+            )
+
+    before = Before
+    expected = Expected
+    after = get_tir_mod(before)
+    assert_structural_equal(expected, after)
+
+
+def test_const():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            R.func_attr({"global_symbol": "main"})
+            y = R.const([1, 2])
+            z = (y, R.const([3, 4]), x)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
+            # function attr dict
+            T.func_attr({"global_symbol": "__vmtir__main"})
+            # body
+            T.anylist_setitem_call_packed(
+                r,
+                T.int32(2),
+                "vm.builtin.make_tuple",
+                T.anylist_getitem(c, T.int32(0)),
+                T.anylist_getitem(c, T.int32(1)),
+                T.anylist_getitem(r, T.int32(0)),
+            )
+            T.anylist_setitem_call_packed(
+                r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2))
+            )
+
+    before = Before
+    expected = Expected
+    after = get_tir_mod(before)
+    assert_structural_equal(expected, after)
+
+
+def test_const_call():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            R.func_attr({"global_symbol": "main"})
+            y = R.const([1, 2])
+            z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor))
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
+            # function attr dict
+            T.func_attr({"global_symbol": "__vmtir__main"})
+            # body
+            T.anylist_setitem_call_packed(
+                r,
+                2,
+                "test.vm.add",
+                T.anylist_getitem(r, 0),
+                T.anylist_getitem(c, 0),
+            )
+            T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))
+
+    before = Before
+    expected = Expected
+    after = get_tir_mod(before)
+    assert_structural_equal(expected, after)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
