This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 8fb1c9c577e4c329c47203afada2e52fb49c2292
Author: Yuchen Jin <yuch...@cs.washington.edu>
AuthorDate: Fri Feb 10 16:19:37 2023 -0800

    [Unity] Relax VM codegen (#13954)
---
 python/tvm/relax/testing/runtime_builtin.py        |  34 ++
 src/relax/backend/vm/codegen_vm.cc                 | 447 +++++++++++++++++++++
 src/relax/op/op.cc                                 | 220 +++++++++-
 src/relax/op/op_common.h                           |  25 +-
 src/runtime/relax_vm/builtin.cc                    |  23 +-
 tests/python/relax/test_runtime_builtin.py         | 153 +++++++
 tests/python/relax/test_tvmscript_printer_relax.py |  41 +-
 tests/python/relax/test_vm_codegen_only.py         | 333 +++++++++++++++
 8 files changed, 1228 insertions(+), 48 deletions(-)

diff --git a/python/tvm/relax/testing/runtime_builtin.py b/python/tvm/relax/testing/runtime_builtin.py
new file mode 100644
index 0000000000..1b04364e69
--- /dev/null
+++ b/python/tvm/relax/testing/runtime_builtin.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities for runtime builtin functions."""
+from enum import IntEnum
+
+
+class MatchShapeCode(IntEnum):
+    """Code passed to match shape builtin.
+
+    The integer values are what get passed to the runtime builtin, so they
+    must stay stable. NOTE(review): the meaning of each member below is
+    inferred from its name only -- confirm against the C++ builtin.
+    """
+
+    # presumably: assert the dimension equals an immediate value
+    ASSERT_EQUAL_TO_IMM = 0
+    # presumably: store the dimension into a heap slot
+    STORE_TO_HEAP = 1
+    # presumably: leave this dimension unchecked
+    NO_OP = 2
+    # presumably: assert the dimension equals a value loaded from the heap
+    ASSERT_EQUAL_TO_LOAD = 3
+
+
+class MakeShapeCode(IntEnum):
+    """Code passed to match shape builtin"""
+
+    USE_IMM = 0
+    LOAD_SHAPE = 1
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
new file mode 100644
index 0000000000..1782f1107a
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/vm/codegen_vm.cc
+ * \brief A codegen to generate VM executable from a Relax IRModule.
+ */
+#include <tvm/driver/driver_api.h>
+#include <tvm/relax/exec_builder.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/runtime/relax_vm/bytecode.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../../../target/metadata_module.h"
+#include "../../../target/source/codegen_source_base.h"
+
+namespace tvm {
+namespace relax {
+namespace relax_vm {
+
+using tvm::Target;
+using namespace relax;
+using namespace tvm::runtime;
+using namespace tvm::runtime::relax_vm;
+
+/*!
+ * \brief Look up the packed-function implementation registered for a relax operator.
+ * \param call The call whose callee is inspected.
+ * \return The op's "FCallPacked" attribute, or an empty FCallPacked when the callee is not an
+ *         Op or the Op has no such attribute.
+ * \note NOTE(review): not referenced elsewhere in codegen_vm.cc.
+ */
+FCallPacked GetPackedFuncName(const Call& call) {
+  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
+  const auto* op_node = call->op.as<OpNode>();
+  if (op_node == nullptr) {
+    return {};
+  }
+  Op op = GetRef<Op>(op_node);
+  if (!op_map.count(op)) {
+    return {};
+  }
+  return op_map[op];
+}
+
+/*!
+ * \brief A class to generate VM executable for Relax functions.
+ *
+ * Visiting an expression emits, through the ExecBuilder, the bytecode that computes it and
+ * returns the Instruction::Arg (a register, a constant, or a function reference) that holds
+ * its value.
+ */
+class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
+ public:
+  explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod)
+      : builder_(builder), ctx_mod_(ctx_mod) {}
+
+  /*!
+   * \brief Emit bytecode for every relax::Function in \p mod into \p builder.
+   * \param builder The builder that receives the generated VM functions.
+   * \param mod The input module; it is also used to resolve GlobalVar references.
+   * \return A new module containing only the functions of \p mod that are not relax functions.
+   */
+  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
+    IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
+    CodeGenVM codegen(builder, mod);
+    // Relax functions are lowered into the builder and left out of the result; every other
+    // function (e.g. TIR PrimFuncs) is passed through unchanged.
+    for (auto& p : mod->functions) {
+      if (auto* func = p.second.as<FunctionNode>()) {
+        codegen.Codegen(GetRef<Function>(func));
+      } else {
+        res_mod->Add(p.first, p.second);
+      }
+    }
+    return res_mod;
+  }
+
+ protected:
+  /*! \brief Allocate the next virtual register of the function being generated. */
+  size_t NewRegister() { return registers_num_++; }
+
+  // Convert Arg value to a register, trigger copy if needed
+  Instruction::Arg EnsureReg(Instruction::Arg arg) {
+    if (arg.kind() == Instruction::ArgKind::kRegister) {
+      return arg;
+    } else {
+      RegName dst_reg = NewRegister();
+      builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg);
+      return Instruction::Arg::Register(dst_reg);
+    }
+  }
+
+  /*!
+   * \brief Emit a single relax function.
+   * \note The function must carry a global symbol. Its parameters are bound to registers
+   *       0..N-1, and the per-function register counter and var map are reset afterwards.
+   */
+  void Codegen(const Function& func) {
+    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
+                                 "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
+
+    Array<String> param_names;
+    for (Var param : func->params) {
+      param_names.push_back(param->name_hint());
+    }
+
+    builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
+
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      RegName r = NewRegister();
+      ICHECK_EQ(r, static_cast<RegName>(i));
+      this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
+    }
+    Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
+    builder_->EmitRet(EnsureReg(ret));
+    builder_->EndFunction(gsymbol.value());
+    // reset register number to be 0;
+    registers_num_ = 0;
+    var_arg_map_.clear();
+  }
+
+  Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
+    for (auto block : op->blocks) {
+      for (Binding binding : block->bindings) {
+        Instruction::Arg value;
+        if (auto* var_binding = binding.as<VarBindingNode>()) {
+          value = this->VisitExpr(var_binding->value);
+        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
+          // Only the value is forwarded here; no instruction is emitted for the cast itself.
+          value = this->VisitExpr(match_cast->value);
+        } else {
+          LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
+        }
+        this->var_arg_map_.insert({binding->var, value});
+      }
+    }
+
+    Instruction::Arg ret_reg = this->VisitExpr(op->body);
+    return ret_reg;
+  }
+
+  Instruction::Arg VisitExpr_(const CallNode* call_node) final {
+    Call call = GetRef<Call>(call_node);
+
+    if (call_node->op == null_value_op_) {
+      return Instruction::Arg::Register(Instruction::kVoidRegister);
+    }
+
+    // allocate dst register.
+    RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
+    if (call->op.as<OpNode>()) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
+        // TODO(relax-team) migrate most handling of op to
+        // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
+        EmitCallBuiltinWithCtx(call, dst_reg);
+      } else if (call_node->op == alloc_storage_op_) {
+        EmitAllocStorage(call, dst_reg);
+      } else if (call_node->op == alloc_tensor_op_) {
+        EmitAllocTensor(call, dst_reg);
+      } else {
+        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
+        // ops are handled in a pass when lowering them to TIR.
+        LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op;
+      }
+    } else {
+      EmitNormalCall(call, dst_reg);
+    }
+    return Instruction::Arg::Register(dst_reg);
+  }
+
+  Instruction::Arg VisitExpr_(const IfNode* op) final {
+    const If& ife = GetRef<If>(op);
+    Instruction::Arg cond_value = this->VisitExpr(ife->cond);
+
+    // Reserve a register for cond
+    RegName cond_reg = NewRegister();
+    builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg);
+
+    // obtain the temp exec in progress.
+    vm::Executable* exec = builder_->exec();
+
+    // Record the offset of If instruction
+    size_t if_offset = exec->instr_offset.size();
+
+    // The jump offset passed here is a placeholder; it is patched below once the size of the
+    // true branch is known.
+    builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3);
+    size_t num_instr = exec->instr_offset.size();
+    Instruction::Arg true_value = this->VisitExpr(ife->true_branch);
+    // Reserve a register for return
+    size_t merge_register = NewRegister();
+    // Copy the output from true branch to merge register
+    builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register);
+
+    // Record the offset of Goto instruction
+    size_t goto_offset = exec->instr_offset.size();
+
+    // Placeholder jump offset, patched below once the false branch has been emitted.
+    builder_->EmitGoto(1);
+
+    // Calculate the false offset of If: the distance from the If instruction to the first
+    // instruction of the false branch (assuming EmitIf appended one entry to instr_offset).
+    size_t false_offset = exec->instr_offset.size() - num_instr + 1;
+
+    Instruction::Arg false_value = this->VisitExpr(ife->false_branch);
+    // Copy the output data of false branch to merge register
+    builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register);
+
+    // Patch the If instruction emitted above so that it jumps to the instruction right after
+    // the Goto, i.e. the start of the false branch.
+    exec->SetInstructionData(if_offset, 2, static_cast<ExecWord>(false_offset));
+    // Update the pc_offset of Goto instruction
+    // Jump over the false branch
+    size_t pc_offset = exec->instr_offset.size() - goto_offset;
+    exec->SetInstructionData(goto_offset, 1, static_cast<ExecWord>(pc_offset));
+    return Instruction::Arg::Register(merge_register);
+  }
+
+  Instruction::Arg VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    auto it = this->var_arg_map_.find(var);
+    ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined";
+    return it->second;
+  }
+
+  Instruction::Arg VisitExpr_(const ConstantNode* op) final {
+    return builder_->ConvertConstant(op->data);
+  }
+
+  Instruction::Arg VisitExpr_(const ShapeExprNode* op) final {
+    std::vector<int64_t> shape;
+    for (PrimExpr e : op->values) {
+      if (auto* int_value = e.as<IntImmNode>()) {
+        shape.push_back(int_value->value);
+      } else {
+        LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values;
+      }
+    }
+    return builder_->ConvertConstant(ShapeTuple(shape));
+  }
+
+  Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
+    if (auto* int_imm = op->value.as<IntImmNode>()) {
+      return builder_->ConvertConstant(int_imm->value);
+    } else {
+      auto* float_imm = op->value.as<FloatImmNode>();
+      ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now";
+      return builder_->ConvertConstant(float_imm->value);
+    }
+  }
+
+  Instruction::Arg VisitExpr_(const StringImmNode* op) final {
+    return builder_->ConvertConstant(op->value);
+  }
+
+  Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final {
+    return builder_->ConvertConstant(op->value);
+  }
+
+  Instruction::Arg VisitExpr_(const TupleNode* op) final {
+    Tuple tuple = GetRef<Tuple>(op);
+    std::vector<Instruction::Arg> args;
+    for (Expr arg : tuple->fields) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    size_t dst_register = NewRegister();
+    builder_->EmitCall("vm.builtin.make_tuple", args, dst_register);
+
+    return Instruction::Arg::Register(dst_register);
+  }
+
+  Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final {
+    TupleGetItem expr = GetRef<TupleGetItem>(op);
+    std::vector<Instruction::Arg> args = {this->VisitExpr(expr->tuple)};
+
+    args.push_back(builder_->ConvertConstant(expr->index));
+
+    size_t dst_register = NewRegister();
+    builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register);
+
+    return Instruction::Arg::Register(dst_register);
+  }
+
+  Instruction::Arg VisitExpr_(const GlobalVarNode* op) final {
+    GlobalVar gvar = GetRef<GlobalVar>(op);
+    Optional<String> symbol;
+    VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc;
+
+    // Run a look up in the env to see if it maps to an extern func.
+    auto it = ctx_mod_->functions.find(gvar);
+    if (it != ctx_mod_->functions.end()) {
+      BaseFunc func = (*it).second;
+      if (auto* efunc = func.as<ExternFuncNode>()) {
+        symbol = efunc->global_symbol;
+        kind = VMFuncInfo::FuncKind::kPackedFunc;
+      } else if (func.as<FunctionNode>()) {
+        symbol = gvar->name_hint;
+        kind = VMFuncInfo::FuncKind::kVMFunc;
+      }
+    }
+    // GlobalVar can be reference to a Relax function or a TIR primfunc
+    // At this point: all global var must corresponds to the right symbol.
+    // TODO(relax-team): switch everything to extern before splitting TIR/relax
+    // so we do not have idle global var here.
+    if (!symbol.defined()) {
+      symbol = gvar->name_hint;
+      kind = VMFuncInfo::FuncKind::kPackedFunc;
+    }
+    // declare the function to be safe.
+    ICHECK(symbol.defined());
+    builder_->DeclareFunction(symbol.value(), kind);
+    return builder_->GetFunction(symbol.value());
+  }
+
+  Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {
+    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
+    return builder_->GetFunction(op->global_symbol);
+  }
+
+  /*!
+   * \brief Lower relax.vm.alloc_storage to a call of vm.builtin.alloc_storage.
+   * \note The VM register (Instruction::kVMRegister) is prepended to the op's own arguments.
+   */
+  void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
+    ICHECK_EQ(call_node->args.size(), 3);
+    // Handle args of the call
+    std::vector<Instruction::Arg> args;
+    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+    // buffer size, device index, dtype -- forwarded in the order built by MakeVMAllocStorage.
+    for (auto arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
+  }
+
+  /*!
+   * \brief Lower relax.vm.alloc_tensor to a call of vm.builtin.alloc_tensor.
+   * \note The four op arguments (storage, offset, shape, dtype) are forwarded as-is.
+   */
+  void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
+    ICHECK_EQ(call_node->args.size(), 4);
+    std::vector<Instruction::Arg> args;
+    args.reserve(4);
+    for (Expr arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
+  }
+
+  /*!
+   * \brief Lower relax.call_builtin_with_ctx: call args[0] with the VM register followed by the
+   *        fields of the tuple in args[1].
+   */
+  void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) {
+    std::vector<Instruction::Arg> args;
+    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+
+    auto func = this->VisitExpr(call_node->args[0]);
+    auto tuple_arg = Downcast<Tuple>(call_node->args[1]);
+
+    // Handle args of the call
+    for (Expr arg : tuple_arg->fields) {
+      args.push_back(this->VisitExpr(arg));
+    }
+
+    builder_->EmitCall(func, args, dst_reg);
+  }
+
+  /*! \brief Emit a call whose callee is not an Op (e.g. a GlobalVar or ExternFunc). */
+  void EmitNormalCall(const Call& call_node, RegName dst_reg) {
+    Instruction::Arg func = VisitExpr(call_node->op);
+    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+    builder_->EmitCall(func, args, dst_reg);
+  }
+
+  // TODO(relax-team) revisit after PrimValue.
+  // Emit the `call_node` attributes as constants and append these constants to `args` vector.
+  // NOTE: not implemented yet -- any call that carries attrs aborts with LOG(FATAL).
+  void AppendAttrsAsConstants(const Call& call_node, std::vector<Instruction::Arg>& args) {
+    auto attrs = call_node->attrs;
+    if (!attrs.defined()) return;
+
+    LOG(FATAL) << "Support for attributes of Op " << call_node->op
+               << " has not been implemented yet.";
+    return;
+  }
+
+  // Emits call to packed function `name` with arguments copied over from `call_node` args and
+  // attributes.
+  void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
+    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+    AppendAttrsAsConstants(call_node, args);
+    builder_->EmitCall(name, args, dst_reg);
+  }
+
+  /*! \brief Visit each expression in \p arr in order and collect the resulting Args. */
+  std::vector<Instruction::Arg> VisitArray(const Array<Expr>& arr) {
+    std::vector<Instruction::Arg> ret;
+    for (size_t i = 0; i < arr.size(); ++i) {
+      ret.push_back(this->VisitExpr(arr[i]));
+    }
+    return ret;
+  }
+
+  /*! \brief Internal ExecBuilder. */
+  relax::ExecBuilder builder_;
+  /*!
+   * \brief Number of virtual registers allocated so far in the current function.
+   * \note Numbering starts at 0 (function parameters take registers 0..N-1) and is reset after
+   *       each function. Special registers such as Instruction::kVoidRegister and
+   *       Instruction::kVMRegister are fixed constants and are not allocated from this counter.
+   */
+  size_t registers_num_ = 0;
+  /*! \brief Map from var to the Arg (register or constant) holding its value. */
+  std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
+  /*! \brief the context module. */
+  IRModule ctx_mod_;
+  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
+  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+  const Op& null_value_op_ = Op::Get("relax.null_value");
+};
+
+/*!
+ * \brief Create the Relax VM executable from all relax.Function in mod
+ *        and add them to exec_builder.
+ * \param exec_builder Builder to collect executables.
+ * \param mod Input module.
+ * \return Left over IRModule that may contain other (non-relax) functions, e.g. TIR PrimFuncs,
+ *         which were not lowered by this codegen.
+ */
+IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
+  return CodeGenVM::Run(exec_builder, mod);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
+
+/*!
+ * \brief Link the libraries together.
+ * \param builder The ExecBuilder holding the generated VM executable.
+ * \param target The target forwarded to codegen::CreateMetadataModule.
+ * \param lib The compiled library; when undefined, an empty C source module is used instead.
+ * \param ext_libs External modules forwarded to codegen::CreateMetadataModule.
+ * \param params Named parameters forwarded to codegen::CreateMetadataModule.
+ * \return The VM executable with the combined library imported into it.
+ */
+Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, Array<Module> ext_libs,
+              Map<String, runtime::NDArray> params) {
+  // TODO(relax-team) Revisit the param and ext_lib options.
+  ObjectPtr<Executable> executable = builder->Get();
+  if (!lib.defined()) {
+    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+  }
+  std::unordered_map<std::string, runtime::NDArray> conv_params;
+  for (const auto& [name, param] : params) {
+    conv_params[name] = param;
+  }
+  Module combined_lib = codegen::CreateMetadataModule(
+      conv_params, lib.value(), ext_libs, target,
+
+      // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
+      // Before jumping into details, only support cpp runtime for now.
+      relay::Runtime::Create("cpp"),
+      relay::Executor::Create("graph"),  // TODO(@sunggg): pass arbitrarily executor. CPP runtime
+                                         // won't use this anyways.
+      relay::backend::ExecutorCodegenMetadata());
+  executable->Import(combined_lib);
+  return Module(executable);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink);
+
+}  // namespace relax_vm
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 8640ed79ad..ca66b0a9ef 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -18,13 +18,46 @@
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr.h>
-#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/utils.h>
 #include <tvm/relay/op.h>
 
+#include "op_common.h"
+
 namespace tvm {
 namespace relax {
 
+/*!
+ * \brief Check whether \p lhs is a compile-time integer constant equal to \p value.
+ * \return false when \p lhs is not a constant integer.
+ */
+bool EqualConstInt(const PrimExpr& lhs, int64_t value) {
+  const int64_t* const_value = tir::as_const_int(lhs);
+  return const_value != nullptr && *const_value == value;
+}
+
+/*!
+ * \brief Check whether \p lhs and \p rhs are provably equal.
+ *
+ * First tries the raw difference; if it is not already a constant, the difference is
+ * simplified with an arith::Analyzer. Returns false when equality cannot be proven.
+ */
+bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto is_const_zero = [](const PrimExpr& expr) {
+    const int64_t* const_value = tir::as_const_int(expr);
+    return const_value != nullptr && *const_value == 0;
+  };
+  PrimExpr diff = lhs - rhs;
+  if (is_const_zero(diff)) {
+    return true;
+  }
+  tvm::arith::Analyzer ana;
+  return is_const_zero(ana.Simplify(diff));
+}
+
+/*! \brief FInferStructInfo helper: the op returns void, modeled as an empty tuple. */
+StructInfo ReturnVoidStructInfo(const Call& /*call*/, const BlockBuilder& /*ctx*/) {
+  return TupleStructInfo(Array<StructInfo>{});
+}
+
+/*! \brief FInferStructInfo helper: the op returns an opaque object. */
+StructInfo ReturnObjectStructInfo(const Call& /*call*/, const BlockBuilder& /*ctx*/) {
+  return ObjectStructInfo();
+}
+
+/*! \brief FInferStructInfo helper: the op returns a shape of unknown rank. */
+StructInfo ReturnShapeStructInfo(const Call& /*call*/, const BlockBuilder& /*ctx*/) {
+  return ShapeStructInfo(/*ndim=*/kUnknownNDim);
+}
+
 // call_tir
 
 StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
@@ -73,5 +106,190 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
 
 TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);
 
+// call builtin
+
+/*!
+ * \brief Struct info inference for relax.call_builtin_with_ctx.
+ * \return The single sinfo_arg when one is given, otherwise void (an empty tuple).
+ */
+StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
+  if (call->sinfo_args.size() == 0) {
+    // by default return void.
+    return TupleStructInfo(Array<StructInfo>());
+  } else {
+    ICHECK_EQ(call->sinfo_args.size(), 1);
+    return call->sinfo_args[0];
+  }
+}
+
+// The op takes exactly two inputs (func, args): this matches the arguments declared below and
+// the Call built by MakeCallBuiltinWithCtx.
+TVM_REGISTER_OP("relax.call_builtin_with_ctx")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The builtin packed func.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallBuiltinWithCtx);
+
+/*!
+ * \brief Build a call to relax.call_builtin_with_ctx.
+ * \param func The builtin packed function to invoke.
+ * \param args The arguments passed to it.
+ * \param sinfo_args Output struct info; at most one entry (checked during inference).
+ */
+Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array<StructInfo> sinfo_args) {
+  static const Op& op = Op::Get("relax.call_builtin_with_ctx");
+  return Call(op, {func, args}, Attrs(), sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx);
+
+// null_value
+
+TVM_REGISTER_OP("relax.null_value")
+    .set_num_inputs(0)
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+/*! \brief Build a call to relax.null_value; it takes no inputs and yields an opaque object. */
+Expr MakeCallNullValue() {
+  static const Op& null_value_op = Op::Get("relax.null_value");
+  return Call(null_value_op, /*args=*/{}, /*attrs=*/{}, /*sinfo_args=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue);
+
+// make_closure
+
+RELAY_REGISTER_OP("relax.make_closure")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The closure.")
+    .add_argument("args", "Tuple", "The captured variables.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+/*!
+ * \brief Build a call to relax.make_closure.
+ * \param func The function to close over.
+ * \param args The captured variables.
+ */
+Expr MakeClosure(Expr func, Tuple args) {
+  static const Op& make_closure_op = Op::Get("relax.make_closure");
+  Array<Expr> call_args = {func, args};
+  return Call(make_closure_op, call_args, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure);
+
+// invoke_closure
+
+/*!
+ * \brief Struct info inference for relax.invoke_closure.
+ * \return ObjectStructInfo when no sinfo_args are given, the single sinfo_arg when exactly one
+ *         is given, and a TupleStructInfo of all of them otherwise.
+ */
+StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) {
+  const Array<StructInfo>& out_sinfo = call->sinfo_args;
+  switch (out_sinfo.size()) {
+    case 0:
+      return ObjectStructInfo();
+    case 1:
+      return out_sinfo[0];
+    default:
+      return TupleStructInfo(out_sinfo);
+  }
+}
+
+RELAY_REGISTER_OP("relax.invoke_closure")
+    .set_num_inputs(2)
+    .add_argument("closure", "Expr", "The VMClosure.")
+    .add_argument("args", "Tuple", "The captured variables.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoInvokeClosure);
+
+/*!
+ * \brief Build a call to relax.invoke_closure.
+ * \param closure The closure to invoke.
+ * \param args Tuple passed as the second call argument.
+ * \param sinfo_args Expected output struct info (see InferStructInfoInvokeClosure).
+ */
+Expr InvokeClosure(Expr closure, Tuple args, Array<StructInfo> sinfo_args) {
+  static const Op& invoke_closure_op = Op::Get("relax.invoke_closure");
+  Array<Expr> call_args = {closure, args};
+  return Call(invoke_closure_op, call_args, {}, sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure);
+
+// shape_of
+
+RELAY_REGISTER_OP("relax.shape_of")
+    .set_num_inputs(1)
+    .add_argument("input", "Expr", "The input expression")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnShapeStructInfo);
+
+/*!
+ * \brief Build a call to relax.shape_of on \p expr.
+ * \note The result is inferred as a ShapeStructInfo of unknown rank.
+ */
+Expr MakeShapeOf(Expr expr) {
+  static const Op& shape_of_op = Op::Get("relax.shape_of");
+  Array<Expr> call_args = {expr};
+  return Call(shape_of_op, call_args, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf);
+
+// alloc_tensor
+
+/*!
+ * \brief Struct info inference for relax.builtin.alloc_tensor.
+ * \note args[0] must be a ShapeExpr and args[1] a DataTypeImm; anything else is a fatal error.
+ */
+StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) {
+  const Expr& shape = call->args[0];
+  ICHECK(shape.as<ShapeExprNode>()) << "must be ShapeExpr, but got " << shape->GetTypeKey();
+  const auto* dtype_node = call->args[1].as<DataTypeImmNode>();
+  ICHECK(dtype_node) << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey();
+  return TensorStructInfo(shape, dtype_node->value);
+}
+
+RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
+    .set_num_inputs(3)
+    .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .add_argument("runtime_device_index", "int64_t",
+                  "The device index indicating on which device the tensor is to be "
+                  "allocated at runtime. Index -1 is reserved for the host device.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor);
+
+/*!
+ * \brief Build a call to relax.builtin.alloc_tensor.
+ * \param shape The shape of the tensor.
+ * \param dtype The element type.
+ * \param runtime_device_index The device index; -1 is reserved for the host device.
+ */
+Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) {
+  static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+  Array<Expr> call_args = {shape, DataTypeImm(dtype), PrimValue::Int64(runtime_device_index)};
+  return Call(alloc_tensor_op, call_args, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
+
+// vm alloc_storage
+
+// The declared argument order must match the Call built by MakeVMAllocStorage below, which the
+// VM codegen forwards positionally to vm.builtin.alloc_storage: size, runtime_device_index, dtype.
+RELAY_REGISTER_OP("relax.vm.alloc_storage")
+    .set_num_inputs(3)
+    .add_argument("size", "Expr", "The size of the storage to allocate.")
+    .add_argument("runtime_device_index", "int64_t",
+                  "The device index indicating on which device the tensor is "
+                  "to be allocated at runtime.")
+    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+/*!
+ * \brief Build a call to relax.vm.alloc_storage.
+ * \param size The size of the storage.
+ * \param runtime_device_index The device to allocate on.
+ * \param dtype The dtype hint for the storage.
+ */
+Expr MakeVMAllocStorage(Expr size, int64_t runtime_device_index, DataType dtype) {
+  static const Op& op = Op::Get("relax.vm.alloc_storage");
+  return Call(op, {size, PrimValue::Int64(runtime_device_index), DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage);
+
+// vm alloc_tensor
+
+/*!
+ * \brief Shape inference for relax.vm.alloc_tensor.
+ * \note Arguments are (storage, offset, shape, dtype), so the shape is args[2].
+ */
+Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) {
+  return call->args[2];
+}
+
+/*!
+ * \brief Struct info inference for relax.vm.alloc_tensor.
+ *
+ * Arguments are laid out as (storage, offset, shape, dtype); see MakeVMAllocTensor. When the
+ * shape is a ShapeExpr the result carries that exact shape; otherwise only the dtype is known
+ * and the rank is left unknown.
+ */
+StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) {
+  DataType out_dtype;
+  if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
+    const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
+    out_dtype = dtype_imm->value;
+  }
+  // args[2] is the shape. args[1] is the storage offset (a PrimValue) and can never be a
+  // ShapeExpr, so reading it here would always drop the shape information.
+  if (const auto* output_shape = call->args[2].as<ShapeExprNode>()) {
+    return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
+  }
+  return TensorStructInfo(out_dtype, kUnknownNDim);
+}
+
+RELAY_REGISTER_OP("relax.vm.alloc_tensor")
+    .set_num_inputs(4)
+    .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
+    .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+    .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoVMAllocTensor);
+
+/*!
+ * \brief Build a call to relax.vm.alloc_tensor with arguments (storage, offset, shape, dtype).
+ */
+Expr MakeVMAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+  static const Op& op = Op::Get("relax.vm.alloc_tensor");
+  return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor);
+
+// vm call_tir_dyn
+
+RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The destination-passing-style function.")
+    .add_argument("args", "Tuple",
+                  "The input arguments (list of tensors and last argument is ShapeExpr)")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+/*!
+ * \brief Build a call to relax.vm.call_tir_dyn; the call itself has void struct info.
+ * \param func The destination-passing-style function.
+ * \param args The tensor arguments, with a ShapeExpr as the last element.
+ */
+Expr MakeCallTIRDyn(Expr func, Tuple args) {
+  static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
+  Array<Expr> call_args = {func, args};
+  return Call(call_tir_dyn_op, call_args, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 8e362bb4d5..c6d335b2a1 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -115,7 +115,30 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx
 }
 
 /*!
- * \brief Infer  the struct info for unary arithmetic elementwise ops. It's also
+ * \brief Infer the struct info by returning the struct info of the input argument.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \tparam arg_index The index of the argument whose struct info is returned.
+ * \return The inferred struct info.
+ * \note Reports a fatal diagnostic when the call's argument count differs from the op's
+ *       declared arguments, or when arg_index is out of range.
+ */
+template <int arg_index>
+StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) {
+  static_assert(arg_index >= 0, "arg_index must be non-negative");
+  Op op = Downcast<Op>(call->op);
+  int n_input = op->arguments.size();
+  if (static_cast<int>(call->args.size()) != n_input) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << op << " op should have " << n_input << " arguments");
+  }
+  if (arg_index >= n_input) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << op << " op has only " << n_input
+                     << " arguments, but try to get the arg with index " << arg_index);
+  }
+  return GetStructInfo(call->args[arg_index]);
+}
+
+/*!
+ * \brief Infer the struct info for unary arithmetic elementwise ops. It's also
  * used in some NN operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 0ef63c8a41..15a4f8702b 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -19,7 +19,6 @@
 /*!
  * \file src/runtime/relax_vm/builtin.cc
  */
-#include <tvm/runtime/container/adt.h>
 #include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/device_api.h>
@@ -214,14 +213,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo
  * \param err_ctx Additional context if error occurs.
  */
 void CheckTupleInfo(ObjectRef arg, int64_t size, Optional<String> err_ctx) {
-  using Tuple = runtime::ADT;
-  // a function that lazily get context for error reporting
+  // Runtime tuples are runtime::Array objects (see vm.builtin.make_tuple).
-  auto* ptr = arg.as<Tuple::ContainerType>();
+  auto* ptr = arg.as<runtime::ArrayNode>();
   CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get "
                          << arg->GetTypeKey();
-  CHECK(static_cast<int64_t>(ptr->size) == size)
+  CHECK(static_cast<int64_t>(ptr->size()) == size)
       << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, "
-      << " but get a Tuple with " << ptr->size << " elements.";
+      << " but get a Tuple with " << ptr->size() << " elements.";
 }
 
 TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo);
@@ -321,6 +319,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv
   *rv = args[0];
 });
 
+/*!
+ * \brief Return a view of \p data with \p new_shape and the same dtype, via NDArray::CreateView.
+ */
+TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeTuple new_shape) {
+  const DLDataType dtype = data->dtype;
+  return data.CreateView(new_shape, dtype);
+});
+
 /*!
  * \brief Load the scalar value in cond and return the result value.
  * \param cond The condition
@@ -367,8 +369,15 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond);
 //-------------------------------------
 //  Data structure API
 //-------------------------------------
-TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](runtime::ADT arr, int64_t index) {
-  return arr[index];
+/*!
+ * \brief Return the element at position \p index of a runtime tuple.
+ * \note Runtime tuples are represented as runtime::Array<ObjectRef>.
+ */
+TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem")
+    .set_body_typed([](runtime::Array<ObjectRef> arr, int64_t index) { return arr[index]; });
+
+/*!
+ * \brief Pack every packed-function argument, in order, into a runtime tuple.
+ * \note Each argument is converted with operator ObjectRef(), so all arguments
+ *       must be object values.
+ */
+TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetValue* rv) {
+  runtime::Array<ObjectRef> arr;
+  // The final size is known up front; reserve once instead of growing per push_back.
+  arr.reserve(args.num_args);
+  for (int i = 0; i < args.num_args; ++i) {
+    arr.push_back(args[i].operator ObjectRef());
+  }
+  *rv = arr;
 });
 
 }  // namespace relax_vm
diff --git a/tests/python/relax/test_runtime_builtin.py 
b/tests/python/relax/test_runtime_builtin.py
new file mode 100644
index 0000000000..b4ba54b455
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+import numpy as np
+
+from tvm.ir import assert_structural_equal
+from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+
+
+def test_make_shape():
+    """vm.builtin.make_shape builds a ShapeTuple from immediates and shape-heap slots."""
+    build_shape = tvm.get_global_func("vm.builtin.make_shape")
+    code = MakeShapeCode
+    heap = tvm.nd.array(np.arange(10).astype("int64"))
+
+    # Dim 0 is the immediate 10; dims 1 and 2 load heap slots 0 and 2 (values 0 and 2).
+    result = build_shape(heap, 3, code.USE_IMM, 10, code.LOAD_SHAPE, 0, code.LOAD_SHAPE, 2)
+    assert result == tvm.runtime.container.ShapeTuple([10, 0, 2])
+
+
+def test_match_shape():
+    """vm.builtin.match_shape stores, loads and asserts shape values against the heap."""
+    MS = MatchShapeCode
+    match_shape = tvm.get_global_func("vm.builtin.match_shape")
+    heap = tvm.nd.array(np.zeros(10).astype("int64"))
+
+    assert heap.numpy()[2] == 0
+
+    s = tvm.runtime.container.ShapeTuple([1, 2, 3])
+    x = tvm.nd.array(np.zeros([1, 2, 3]))
+
+    # Check s[0] == 1 and store s[1] (== 2) into heap slot 2.
+    match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")
+
+    assert heap.numpy()[2] == 2
+
+    # A tensor can be matched as well; its dim 1 is compared with heap slot 2.
+    match_shape(
+        x,
+        heap,
+        3,
+        MS.ASSERT_EQUAL_TO_IMM,
+        1,
+        MS.ASSERT_EQUAL_TO_LOAD,
+        2,
+        MS.ASSERT_EQUAL_TO_IMM,
+        3,
+        "",
+    )
+
+    # Presumably an ndim mismatch: s has 3 dims but only 2 are described.
+    with pytest.raises(RuntimeError):
+        match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "")
+
+    # Value mismatch: s[0] is 1, not 2.
+    with pytest.raises(RuntimeError):
+        match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")
+
+
+def test_check_shape_info():
+    """vm.builtin.check_shape_info accepts a matching ShapeTuple and rejects mismatches."""
+    check = tvm.get_global_func("vm.builtin.check_shape_info")
+    shape = tvm.runtime.container.ShapeTuple([1, 2, 3])
+
+    # The exact ndim and -1 (presumably "any ndim") are both accepted.
+    for ndim in (3, -1):
+        check(shape, ndim, "")
+
+    # An ndim that disagrees with the shape is a ValueError.
+    with pytest.raises(ValueError):
+        check(shape, 2, "")
+
+    # A value that is not a shape is a TypeError.
+    with pytest.raises(TypeError):
+        check([], 2, "")
+
+
+def test_check_tensor_info():
+    """vm.builtin.check_tensor_info validates ndim and dtype and reports context on error."""
+    check = tvm.get_global_func("vm.builtin.check_tensor_info")
+    x = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+
+    # The matching ndim (or -1) combined with the matching or an empty dtype passes.
+    for dtype in ("int32", ""):
+        for ndim in (2, -1):
+            check(x, ndim, dtype, "")
+
+    # The dtype argument may be omitted entirely.
+    for ndim in (2, -1):
+        check(x, ndim, "")
+
+    # ndim mismatch
+    with pytest.raises(ValueError, match=r".* ndim .*"):
+        check(x, 3, "int32", "")
+
+    # dtype mismatch; the message mentions the context string
+    with pytest.raises(ValueError, match=r"myerror.* dtype .*"):
+        check(x, 2, "float32", "myerror")
+
+    # error with context, dtype omitted
+    with pytest.raises(ValueError, match=r".* myerror .*"):
+        check(x, 3, "myerror")
+
+    # wrong type
+    with pytest.raises(TypeError):
+        check([], 2, "", "")
+
+
+def test_check_tuple_info():
+    """vm.builtin.check_tuple_info validates both the type and the arity of a tuple."""
+    check = tvm.get_global_func("vm.builtin.check_tuple_info")
+    elem = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+    tup = tvm.runtime.convert([elem] * 3)
+
+    check(tup, 3, "")
+
+    # Arity mismatch.
+    with pytest.raises(ValueError, match=r".*elements.*"):
+        check(tup, 2, "")
+
+    # A tensor is not a tuple.
+    with pytest.raises(TypeError):
+        check(elem, 2, "")
+
+
+def test_check_func_info():
+    """vm.builtin.check_func_info accepts a converted callable and rejects a tensor."""
+    check = tvm.get_global_func("vm.builtin.check_func_info")
+    identity = tvm.runtime.convert(lambda x: x)
+    not_a_func = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+
+    check(identity, "")
+
+    # The context string should show up in the TypeError message.
+    with pytest.raises(TypeError, match=".*myerror.*"):
+        check(not_a_func, "myerror")
+
+
+def test_tuple_getitem():
+    """vm.builtin.tuple_getitem returns the element stored at the requested index."""
+    getitem = tvm.get_global_func("vm.builtin.tuple_getitem")
+    elems = [tvm.nd.array(np.zeros((2, 3)).astype("int32")) for _ in range(2)]
+    tup = tvm.runtime.convert(elems)
+
+    for idx, elem in enumerate(elems):
+        assert getitem(tup, idx) == elem
+
+
+if __name__ == "__main__":
+    # `import tvm` alone is not guaranteed to load the tvm.testing submodule.
+    import tvm.testing
+
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py 
b/tests/python/relax/test_tvmscript_printer_relax.py
index e2cb8bc5fc..58596f968f 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+import tvm
 import pytest
 from tvm import IRModule, relax, tir
 from tvm.script import relax as R
@@ -447,42 +448,4 @@ else:
 
 
 if __name__ == "__main__":
-    test_function()
-    test_extern_func()
-
-    test_object_struct_info()
-    test_prim_struct_info()
-    test_shape_struct_info_0()
-    test_shape_struct_info_1()
-    test_shape_struct_info_2()
-    test_tensor_struct_info()
-    test_tuple_struct_info_empty()
-    test_tuple_struct_info()
-    test_func_struct_info()
-
-    test_shape_type()
-    test_object_type()
-    test_dyn_tensor_type()
-    test_packed_func_type()
-    test_tuple_type()
-    test_func_type()
-
-    test_prim_value()
-    test_string_imm()
-    test_data_type_imm()
-
-    test_var()
-    test_dataflow_var()
-    #
-    test_tuple()
-    test_tuple_get_item()
-    test_shape_expr()
-    test_call()
-
-    test_seq_expr()
-    test_binding_block()
-    test_dataflow_block()
-
-    test_match_cast()
-    test_var_binding()
-    test_if()
+    # `import tvm` alone is not guaranteed to load the tvm.testing submodule.
+    import tvm.testing
+
+    tvm.testing.main()
diff --git a/tests/python/relax/test_vm_codegen_only.py 
b/tests/python/relax/test_vm_codegen_only.py
new file mode 100644
index 0000000000..b5e7709177
--- /dev/null
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -0,0 +1,333 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test the last stage of VM codegen.
+
+Restrictions: all shapes are already lowered and all allocations are explicit.
+"""
+import tvm
+import pytest
+import numpy as np
+from tvm import relax, TVMError
+from tvm.script import relax as R, tir as T
+from tvm.relax.testing.vm import check_saved_func
+from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+
+# Executable modes that every parametrized test in this file is run under.
+EXEC_MODE = ["bytecode"]
+
+
+def codegen(mod, target, exec_mode="bytecode"):
+    """Run VM codegen on ``mod`` and link the result for ``target``.
+
+    ``exec_mode`` is forwarded to ``relax.vm._vmcodegen``. The module that call
+    returns is then linked, together with the ExecBuilder's contents, by
+    ``relax.vm._vmlink``, whose result (the executable) is returned.
+    """
+    exec_builder = relax.ExecBuilder()
+    remaining_mod = relax.vm._vmcodegen(exec_builder, mod, exec_mode=exec_mode)
+    return relax.vm._vmlink(exec_builder, target, remaining_mod)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_copy(exec_mode):
+    """vm.builtin.copy invoked through call_packed returns its input unchanged."""
+
+    @tvm.script.ir_module
+    class TestVMCopy:
+        @R.function
+        def foo(x: R.Tensor((3, 4), "float32")):
+            R.func_attr({"global_symbol": "foo"})
+            z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")))
+            return z
+
+    mod = TestVMCopy
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    res = check_saved_func(vm, "foo", inp)
+    tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_if_cond_const(exec_mode):
+    """An `if` whose condition is a constant relax boolean compiles and runs."""
+
+    @tvm.script.ir_module
+    class TestVMIfCondConst:
+        @R.function
+        def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
+            R.func_attr({"global_symbol": "main"})
+            if relax.const(True, dtype="bool"):
+                ret = x
+            else:
+                ret = x
+            return ret
+
+    mod = TestVMIfCondConst
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    inp = tvm.nd.array(np.random.rand(3, 4))
+    res = vm["main"](inp)
+    tvm.testing.assert_allclose(res.numpy(), inp.numpy())
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_exec_serialize_export_library(exec_mode):
+    """An executable exported as a shared library reloads with identical text form."""
+
+    @tvm.script.ir_module
+    class TestVMMove:
+        @R.function
+        def foo(x: R.Tensor((3, 4), "float32")):
+            R.func_attr({"global_symbol": "foo"})
+            z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")))
+            return z
+
+    mod = TestVMMove
+    target = tvm.target.Target("llvm", host="llvm")
+    # Honor the parametrized exec_mode; it was previously dropped, so every
+    # parametrization silently tested only the default mode.
+    ex = codegen(mod, target, exec_mode)
+    from tvm.contrib import utils
+
+    temp_dir = utils.tempdir()
+    path_exec = temp_dir.relpath("exec.so")
+    ex.mod.export_library(path_exec)
+
+    loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec))
+    assert ex.as_text() == loaded_exec.as_text()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_if_cond(exec_mode):
+    """A runtime tensor condition selects between the two branches of an `if`."""
+
+    @tvm.script.ir_module
+    class TestVMCompileIf:
+        @R.function
+        def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor:
+            R.func_attr({"global_symbol": "ife"})
+            if cond:
+                w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
+            else:
+                w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
+            return w
+
+    mod = TestVMCompileIf
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    inp = tvm.nd.array(np.random.rand(3, 4))
+    x = inp.numpy()
+    # Truthy conditions take the "add" branch; falsy ones take the "mul" branch.
+    for cond, expected in [(1, x + x), (True, x + x), (0, x * x), (False, x * x)]:
+        res = vm["ife"](tvm.nd.array(cond), inp)
+        tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_return_const_tuple(exec_mode):
+    """A tuple that mixes constants and a parameter can be returned from a VM function."""
+
+    @tvm.script.ir_module
+    class ReturnConstTuple:
+        @R.function
+        def main(x: R.Tensor(ndim=2, dtype="float32")):
+            R.func_attr({"global_symbol": "main"})
+            y = R.const([1, 2])
+            z = (y, R.const([3, 4]), x)
+            return z
+
+    target = tvm.target.Target("llvm", host="llvm")
+    vm = relax.VirtualMachine(codegen(ReturnConstTuple, target, exec_mode), tvm.cpu())
+    inp = tvm.nd.array(np.random.rand(2, 3))
+    res0, res1, res2 = vm["main"](inp)
+    checks = [(res0, np.array([1, 2])), (res1, np.array([3, 4])), (res2, inp.numpy())]
+    for actual, expected in checks:
+        tvm.testing.assert_allclose(actual.numpy(), expected)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_const_as_call_arg(exec_mode):
+    """Relax constants can be passed directly as call_packed arguments."""
+
+    @tvm.script.ir_module
+    class TestVMConstAsCallArg:
+        @R.function
+        def main(x: R.Tensor(ndim=2, dtype="float32")):
+            R.func_attr({"global_symbol": "main"})
+            a = R.call_packed(
+                "test.vm.add",
+                relax.const([1, 2]),
+                relax.const([3, 4]),
+                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+            )
+            b = R.call_packed(
+                "test.vm.add",
+                a,
+                x,
+                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+            )
+            return b
+
+    target = tvm.target.Target("llvm", host="llvm")
+    executable = codegen(TestVMConstAsCallArg, target, exec_mode)
+    vm = relax.VirtualMachine(executable, tvm.cpu())
+    inp = tvm.nd.array(np.random.rand(1, 2))
+    # [1, 2] + [3, 4] == [4, 6], which is then added to the input.
+    expected = np.array([4, 6]) + inp.numpy()
+    tvm.testing.assert_allclose(vm["main"](inp).numpy(), expected)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_shape_check_builtin(exec_mode):
+    """Shape-check builtins called from Relax code work end to end.
+
+    main() validates x, matches its shape (n, m) into a shape heap, and returns the
+    shape (m, n, 2) rebuilt from that heap.
+    """
+    MS = MatchShapeCode
+    MK = MakeShapeCode
+    # slot assignment:
+    # 0: n, 1: m
+    sindex = {"n": 0, "m": 1}
+
+    @tvm.script.ir_module
+    class TestVMShapeCheck:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3):
+            R.func_attr({"global_symbol": "main"})
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(3)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # construct shape value for return
+            s = R.call_packed(
+                "vm.builtin.make_shape",
+                shape_heap,
+                3,
+                MK.LOAD_SHAPE,
+                sindex["m"],
+                MK.LOAD_SHAPE,
+                sindex["n"],
+                MK.USE_IMM,
+                2,
+                sinfo_args=[R.Shape(ndim=3)],
+            )
+            return s
+
+    mod = TestVMShapeCheck
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x = tvm.nd.array(np.zeros((1, 2)).astype("float32"))
+    res = vm["main"](x)
+    assert res == tvm.runtime.container.ShapeTuple([2, 1, 2])
+
+    # wrong input type
+    with pytest.raises(TypeError):
+        vm["main"]([])
+
+    # wrong ndim
+    with pytest.raises(ValueError, match=r".*ndim.*"):
+        vm["main"](tvm.nd.array(np.zeros(1).astype("float32")))
+
+    # wrong dtype
+    with pytest.raises(ValueError, match=r".*dtype.*"):
+        vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32")))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_prim_value(exec_mode):
+    """Returning R.prim_value(T.int64(1)) yields a value that compares equal to 1."""
+
+    @tvm.script.ir_module
+    class TestVMPrimValue:
+        @R.function
+        def main():
+            R.func_attr({"global_symbol": "main"})
+            ret = R.prim_value(T.int64(1))
+            return ret
+
+    executable = codegen(TestVMPrimValue, tvm.target.Target("llvm", host="llvm"), exec_mode)
+    main_fn = relax.VirtualMachine(executable, tvm.cpu())["main"]
+    assert main_fn() == 1
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_string_imm(exec_mode):
+    """Returning R.str("hello") yields a value that compares equal to "hello"."""
+
+    @tvm.script.ir_module
+    class TestVMStringImm:
+        @R.function
+        def main():
+            R.func_attr({"global_symbol": "main"})
+            ret = R.str("hello")
+            return ret
+
+    executable = codegen(TestVMStringImm, tvm.target.Target("llvm", host="llvm"), exec_mode)
+    main_fn = relax.VirtualMachine(executable, tvm.cpu())["main"]
+    assert main_fn() == "hello"
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_datatype_imm(exec_mode):
+    """Returning R.dtype("float32") yields a value that compares equal to "float32"."""
+
+    @tvm.script.ir_module
+    class TestDataTypeImm:
+        @R.function
+        def main():
+            R.func_attr({"global_symbol": "main"})
+            ret = R.dtype("float32")
+            return ret
+
+    executable = codegen(TestDataTypeImm, tvm.target.Target("llvm", host="llvm"), exec_mode)
+    main_fn = relax.VirtualMachine(executable, tvm.cpu())["main"]
+    assert main_fn() == "float32"
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_builtin_reshape(exec_mode):
+    """vm.builtin.reshape produces the same values as numpy's reshape."""
+
+    @tvm.script.ir_module
+    class TestVMBuiltinReshape:
+        @R.function
+        def main(x: R.Tensor((3, 4), "float32")):
+            R.func_attr({"global_symbol": "main"})
+            y = R.call_packed(
+                "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
+            )
+            return y
+
+    mod = TestVMBuiltinReshape
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    dev = tvm.cpu()
+    vm = relax.VirtualMachine(ex, dev)
+
+    input_np = np.random.rand(3, 4).astype("float32")
+    # Named `inp` rather than `input` to avoid shadowing the builtin.
+    inp = tvm.nd.array(input_np, dev)
+    res = vm["main"](inp)
+    expected = input_np.reshape(6, 2)
+    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7)
+
+
+if __name__ == "__main__":
+    # `import tvm` alone is not guaranteed to load the tvm.testing submodule.
+    import tvm.testing
+
+    tvm.testing.main()

Reply via email to