This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 562b338f6c [Unity] Relax VM codegen (#13954)
562b338f6c is described below
commit 562b338f6c3bdc7d304beb1a4dbb3075d75791cc
Author: Yuchen Jin <[email protected]>
AuthorDate: Fri Feb 10 16:19:37 2023 -0800
[Unity] Relax VM codegen (#13954)
---
python/tvm/relax/testing/runtime_builtin.py | 34 ++
src/relax/backend/vm/codegen_vm.cc | 447 +++++++++++++++++++++
src/relax/op/op.cc | 220 +++++++++-
src/relax/op/op_common.h | 25 +-
src/runtime/relax_vm/builtin.cc | 23 +-
tests/python/relax/test_runtime_builtin.py | 153 +++++++
tests/python/relax/test_tvmscript_printer_relax.py | 41 +-
tests/python/relax/test_vm_codegen_only.py | 333 +++++++++++++++
8 files changed, 1228 insertions(+), 48 deletions(-)
diff --git a/python/tvm/relax/testing/runtime_builtin.py b/python/tvm/relax/testing/runtime_builtin.py
new file mode 100644
index 0000000000..1b04364e69
--- /dev/null
+++ b/python/tvm/relax/testing/runtime_builtin.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities for runtime builtin functions."""
+from enum import IntEnum
+
+
+class MatchShapeCode(IntEnum):
+ """Code passed to match shape builtin"""
+
+ ASSERT_EQUAL_TO_IMM = 0
+ STORE_TO_HEAP = 1
+ NO_OP = 2
+ ASSERT_EQUAL_TO_LOAD = 3
+
+
+class MakeShapeCode(IntEnum):
+ """Code passed to match shape builtin"""
+
+ USE_IMM = 0
+ LOAD_SHAPE = 1
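
For orientation, here is a minimal sketch of how these codes drive the shape builtins at runtime; the call pattern mirrors the new tests/python/relax/test_runtime_builtin.py below and assumes a TVM build with the Relax VM available:

    import numpy as np
    import tvm
    from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode

    # The shape heap is an int64 NDArray used as scratch space for symbolic dims.
    heap = tvm.nd.array(np.zeros(2, dtype="int64"))

    # Store both dims of an incoming shape into heap slots 0 and 1.
    match_shape = tvm.get_global_func("vm.builtin.match_shape")
    s = tvm.runtime.container.ShapeTuple([3, 4])
    match_shape(s, heap, 2, MatchShapeCode.STORE_TO_HEAP, 0, MatchShapeCode.STORE_TO_HEAP, 1, "")

    # Rebuild a shape from an immediate plus the two heap slots.
    make_shape = tvm.get_global_func("vm.builtin.make_shape")
    out = make_shape(
        heap, 3, MakeShapeCode.USE_IMM, 1, MakeShapeCode.LOAD_SHAPE, 0, MakeShapeCode.LOAD_SHAPE, 1
    )
    assert out == tvm.runtime.container.ShapeTuple([1, 3, 4])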
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
new file mode 100644
index 0000000000..1782f1107a
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/vm/codegen_vm.cc
+ * \brief A codegen to generate VM executable from a Relax IRModule.
+ */
+#include <tvm/driver/driver_api.h>
+#include <tvm/relax/exec_builder.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/runtime/relax_vm/bytecode.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../../../target/metadata_module.h"
+#include "../../../target/source/codegen_source_base.h"
+
+namespace tvm {
+namespace relax {
+namespace relax_vm {
+
+using tvm::Target;
+using namespace relax;
+using namespace tvm::runtime;
+using namespace tvm::runtime::relax_vm;
+
+// Helper function to get the function name of the registered packed function implementation of
+// the relax operator.
+FCallPacked GetPackedFuncName(const Call& call) {
+ static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
+ if (call->op.as<OpNode>()) {
+ Op op = Downcast<Op>(call->op);
+ if (op_map.count(op)) {
+ return op_map[op];
+ }
+ }
+ return {};
+}
+
+/*!
+ * \brief A class to generate VM executable for Relax functions.
+ */
+class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
+ public:
+ explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod)
+ : builder_(builder), ctx_mod_(ctx_mod) {}
+
+ static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
+ IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
+ CodeGenVM codegen(builder, mod);
+ // Generate VM code for Relax functions; keep all other functions in the result module.
+ for (auto& p : mod->functions) {
+ if (auto* func = p.second.as<FunctionNode>()) {
+ codegen.Codegen(GetRef<Function>(func));
+ } else {
+ res_mod->Add(p.first, p.second);
+ }
+ }
+ return res_mod;
+ }
+
+ protected:
+ size_t NewRegister() { return registers_num_++; }
+
+ // Convert Arg value to a register, trigger copy if needed
+ Instruction::Arg EnsureReg(Instruction::Arg arg) {
+ if (arg.kind() == Instruction::ArgKind::kRegister) {
+ return arg;
+ } else {
+ RegName dst_reg = NewRegister();
+ builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg);
+ return Instruction::Arg::Register(dst_reg);
+ }
+ }
+
+ void Codegen(const Function& func) {
+ Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
+                              "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
+
+ Array<String> param_names;
+ for (Var param : func->params) {
+ param_names.push_back(param->name_hint());
+ }
+
+ builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
+
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ RegName r = NewRegister();
+ ICHECK_EQ(r, static_cast<RegName>(i));
+ this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
+ }
+ Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
+ builder_->EmitRet(EnsureReg(ret));
+ builder_->EndFunction(gsymbol.value());
+ // Reset the register counter to 0.
+ registers_num_ = 0;
+ var_arg_map_.clear();
+ }
+
+ Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
+ for (auto block : op->blocks) {
+ for (Binding binding : block->bindings) {
+ Instruction::Arg value;
+ if (auto* var_binding = binding.as<VarBindingNode>()) {
+ value = this->VisitExpr(var_binding->value);
+ } else if (auto* match_cast = binding.as<MatchCastNode>()) {
+ value = this->VisitExpr(match_cast->value);
+ } else {
+ LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
+ }
+ this->var_arg_map_.insert({binding->var, value});
+ }
+ }
+
+ Instruction::Arg ret_reg = this->VisitExpr(op->body);
+ return ret_reg;
+ }
+
+ Instruction::Arg VisitExpr_(const CallNode* call_node) final {
+ Call call = GetRef<Call>(call_node);
+
+ if (call_node->op == null_value_op_) {
+ return Instruction::Arg::Register(Instruction::kVoidRegister);
+ }
+
+ // allocate dst register.
+ RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
+ if (call->op.as<OpNode>()) {
+ if (call_node->op == call_builtin_with_ctx_op_) {
+ // TODO(relax-team) migrate most handling of op to
+ // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
+ EmitCallBuiltinWithCtx(call, dst_reg);
+ } else if (call_node->op == alloc_storage_op_) {
+ EmitAllocStorage(call, dst_reg);
+ } else if (call_node->op == alloc_tensor_op_) {
+ EmitAllocTensor(call, dst_reg);
+ } else {
+ // every "normal" operator is lowered to a global var in the IRModule.
The Attrs for those
+ // ops are handled in a pass when lowering them to TIR.
+ LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" <<
call_node->op;
+ }
+ } else {
+ EmitNormalCall(call, dst_reg);
+ }
+ return Instruction::Arg::Register(dst_reg);
+ }
+
+ Instruction::Arg VisitExpr_(const IfNode* op) final {
+ const If& ife = GetRef<If>(op);
+ Instruction::Arg cond_value = this->VisitExpr(ife->cond);
+
+ // Reserve a register for cond
+ RegName cond_reg = NewRegister();
+ builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg);
+
+ // obtain the temp exec in progress.
+ vm::Executable* exec = builder_->exec();
+
+ // Record the offset of If instruction
+ size_t if_offset = exec->instr_offset.size();
+
+ builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3);
+ size_t num_instr = exec->instr_offset.size();
+ Instruction::Arg true_value = this->VisitExpr(ife->true_branch);
+ // Reserve a register for return
+ size_t merge_register = NewRegister();
+ // Copy the output from true branch to merge register
+ builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register);
+
+ // Record the offset of Goto instruction
+ size_t goto_offset = exec->instr_offset.size();
+
+ builder_->EmitGoto(1);
+
+ // Calculate the false offset of If
+ size_t false_offset = exec->instr_offset.size() - num_instr + 1;
+
+ Instruction::Arg false_value = this->VisitExpr(ife->false_branch);
+ // Copy the output data of false branch to merge register
+ builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register);
+
+ // Update the offsets of the If instruction emitted above
+ // Jump to just past the next Goto instruction
+ exec->SetInstructionData(if_offset, 2, static_cast<ExecWord>(false_offset));
+ // Update the pc_offset of Goto instruction
+ // Jump over the false branch
+ size_t pc_offset = exec->instr_offset.size() - goto_offset;
+ exec->SetInstructionData(goto_offset, 1, static_cast<ExecWord>(pc_offset));
+ return Instruction::Arg::Register(merge_register);
+ }
+
+ Instruction::Arg VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+ auto it = this->var_arg_map_.find(var);
+ ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined";
+ return it->second;
+ }
+
+ Instruction::Arg VisitExpr_(const ConstantNode* op) final {
+ return builder_->ConvertConstant(op->data);
+ }
+
+ Instruction::Arg VisitExpr_(const ShapeExprNode* op) final {
+ std::vector<int64_t> shape;
+ for (PrimExpr e : op->values) {
+ if (auto* int_value = e.as<IntImmNode>()) {
+ shape.push_back(int_value->value);
+ } else {
+ LOG(FATAL) << "Should only use constant shape after shape lowering: "
<< op->values;
+ }
+ }
+ return builder_->ConvertConstant(ShapeTuple(shape));
+ }
+
+ Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
+ if (auto* int_imm = op->value.as<IntImmNode>()) {
+ return builder_->ConvertConstant(int_imm->value);
+ } else {
+ auto* float_imm = op->value.as<FloatImmNode>();
+ ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now";
+ return builder_->ConvertConstant(float_imm->value);
+ }
+ }
+
+ Instruction::Arg VisitExpr_(const StringImmNode* op) final {
+ return builder_->ConvertConstant(op->value);
+ }
+
+ Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final {
+ return builder_->ConvertConstant(op->value);
+ }
+
+ Instruction::Arg VisitExpr_(const TupleNode* op) final {
+ Tuple tuple = GetRef<Tuple>(op);
+ std::vector<Instruction::Arg> args;
+ for (Expr arg : tuple->fields) {
+ args.push_back(this->VisitExpr(arg));
+ }
+ size_t dst_register = NewRegister();
+ builder_->EmitCall("vm.builtin.make_tuple", args, dst_register);
+
+ return Instruction::Arg::Register(dst_register);
+ }
+
+ Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final {
+ TupleGetItem expr = GetRef<TupleGetItem>(op);
+ std::vector<Instruction::Arg> args = {this->VisitExpr(expr->tuple)};
+
+ args.push_back(builder_->ConvertConstant(expr->index));
+
+ size_t dst_register = NewRegister();
+ builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register);
+
+ return Instruction::Arg::Register(dst_register);
+ }
+
+ Instruction::Arg VisitExpr_(const GlobalVarNode* op) final {
+ GlobalVar gvar = GetRef<GlobalVar>(op);
+ Optional<String> symbol;
+ VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc;
+
+ // Look up the context module to see if the GlobalVar maps to an extern func.
+ auto it = ctx_mod_->functions.find(gvar);
+ if (it != ctx_mod_->functions.end()) {
+ BaseFunc func = (*it).second;
+ if (auto* efunc = func.as<ExternFuncNode>()) {
+ symbol = efunc->global_symbol;
+ kind = VMFuncInfo::FuncKind::kPackedFunc;
+ } else if (func.as<FunctionNode>()) {
+ symbol = gvar->name_hint;
+ kind = VMFuncInfo::FuncKind::kVMFunc;
+ }
+ }
+ // A GlobalVar can refer to a Relax function or a TIR PrimFunc.
+ // At this point, every GlobalVar must correspond to the right symbol.
+ // TODO(relax-team): switch everything to extern before splitting TIR/relax
+ // so we do not have idle global vars here.
+ if (!symbol.defined()) {
+ symbol = gvar->name_hint;
+ kind = VMFuncInfo::FuncKind::kPackedFunc;
+ }
+ // declare the function to be safe.
+ ICHECK(symbol.defined());
+ builder_->DeclareFunction(symbol.value(), kind);
+ return builder_->GetFunction(symbol.value());
+ }
+
+ Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {
+ builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
+ return builder_->GetFunction(op->global_symbol);
+ }
+
+ void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
+ ICHECK_EQ(call_node->args.size(), 3);
+ // Handle args of the call
+ std::vector<Instruction::Arg> args;
+ args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+ // buffer size, dtype, device index
+ for (auto arg : call_node->args) {
+ args.push_back(this->VisitExpr(arg));
+ }
+ builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
+ }
+
+ void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
+ ICHECK_EQ(call_node->args.size(), 4);
+ std::vector<Instruction::Arg> args;
+ args.reserve(4);
+ for (Expr arg : call_node->args) {
+ args.push_back(this->VisitExpr(arg));
+ }
+ builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
+ }
+
+ void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) {
+ std::vector<Instruction::Arg> args;
+ args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+
+ auto func = this->VisitExpr(call_node->args[0]);
+ auto tuple_arg = Downcast<Tuple>(call_node->args[1]);
+
+ // Handle args of the call
+ for (Expr arg : tuple_arg->fields) {
+ args.push_back(this->VisitExpr(arg));
+ }
+
+ builder_->EmitCall(func, args, dst_reg);
+ }
+
+ void EmitNormalCall(const Call& call_node, RegName dst_reg) {
+ Instruction::Arg func = VisitExpr(call_node->op);
+ std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+ builder_->EmitCall(func, args, dst_reg);
+ }
+
+ // TODO(relax-team) revisit after PrimValue.
+ // Emit the `call_node` attributes as constants and append them to the `args` vector.
+ void AppendAttrsAsConstants(const Call& call_node, std::vector<Instruction::Arg>& args) {
+ auto attrs = call_node->attrs;
+ if (!attrs.defined()) return;
+
+ LOG(FATAL) << "Support for attributes of Op " << call_node->op
+ << " has not been implemented yet.";
+ return;
+ }
+
+ // Emits call to packed function `name` with arguments copied over from `call_node` args and
+ // attributes.
+ void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
+ std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+ AppendAttrsAsConstants(call_node, args);
+ builder_->EmitCall(name, args, dst_reg);
+ }
+
+ std::vector<Instruction::Arg> VisitArray(const Array<Expr>& arr) {
+ std::vector<Instruction::Arg> ret;
+ for (size_t i = 0; i < arr.size(); ++i) {
+ ret.push_back(this->VisitExpr(arr[i]));
+ }
+ return ret;
+ }
+
+ /*! \brief Internal ExecBuilder. */
+ relax::ExecBuilder builder_;
+ /*!
+ * \brief Total number of virtual registers allocated.
+ * \note The first two registers are reserved for special registers.
+ */
+ size_t registers_num_ = 0;
+ /*! \brief Map from var to register number. */
+ std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
+ /*! \brief the context module. */
+ IRModule ctx_mod_;
+ /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
+ const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+ const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+ const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+ const Op& null_value_op_ = Op::Get("relax.null_value");
+};
+
+/*!
+ * \brief Create the Relax VM executable from all relax.Function in mod
+ * and add them to exec_builder.
+ * \param exec_builder Builder to collect executables.
+ * \param mod Input module.
+ * \return Leftover IRModule that may contain other functions.
+ */
+IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
+ return CodeGenVM::Run(exec_builder, mod);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
+
+/*!
+ * \brief Link the libraries together.
+ */
+Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, Array<Module> ext_libs,
+              Map<String, runtime::NDArray> params) {
+ // TODO(relax-team) Revisit the param and ext_lib options.
+ ObjectPtr<Executable> executable = builder->Get();
+ if (!lib.defined()) {
+ lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+ }
+ std::unordered_map<std::string, runtime::NDArray> conv_params;
+ for (const auto& [name, param] : params) {
+ conv_params[name] = param;
+ }
+ Module combined_lib = codegen::CreateMetadataModule(
+ conv_params, lib.value(), ext_libs, target,
+
+ // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
+ // Before jumping into details, only support cpp runtime for now.
+ relay::Runtime::Create("cpp"),
+ relay::Executor::Create("graph"),  // TODO(@sunggg): pass an arbitrary executor. CPP runtime
+                                    // won't use this anyway.
+ relay::backend::ExecutorCodegenMetadata());
+ executable->Import(combined_lib);
+ return Module(executable);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink);
+
+} // namespace relax_vm
+} // namespace relax
+} // namespace tvm
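
As a usage sketch, the two globals registered above (relax.VMCodeGen and relax.VMLink) are exposed through relax.vm; the helper below mirrors the codegen() function in the new test_vm_codegen_only.py and assumes the input IRModule is already fully lowered (explicit shapes, explicit allocation):

    import tvm
    from tvm import relax

    def codegen(mod, target, exec_mode="bytecode"):
        builder = relax.ExecBuilder()
        # relax.VMCodeGen: emit VM code for the Relax functions and return the
        # leftover module (e.g. TIR PrimFuncs) for further compilation.
        leftover_mod = relax.vm._vmcodegen(builder, mod, exec_mode=exec_mode)
        # relax.VMLink: build the executable and link it with the compiled libs.
        return relax.vm._vmlink(builder, target, leftover_mod)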
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 8640ed79ad..ca66b0a9ef 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -18,13 +18,46 @@
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
-#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/utils.h>
#include <tvm/relay/op.h>
+#include "op_common.h"
+
namespace tvm {
namespace relax {
+bool EqualConstInt(const PrimExpr& lhs, int64_t value) {
+ if (const int64_t* pvalue = tir::as_const_int(lhs)) {
+ return pvalue[0] == value;
+ }
+ return false;
+}
+
+bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
+ PrimExpr diff = lhs - rhs;
+ if (const int64_t* pdiff = tir::as_const_int(diff)) {
+ return pdiff[0] == 0;
+ }
+ tvm::arith::Analyzer ana;
+ diff = ana.Simplify(diff);
+ if (const int64_t* pdiff = tir::as_const_int(diff)) {
+ return pdiff[0] == 0;
+ }
+ return false;
+}
+
+StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) {
+ return TupleStructInfo(Array<StructInfo>());
+}
+
+StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) {
+ return ObjectStructInfo();
+}
+
+StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) {
+ return ShapeStructInfo(kUnknownNDim);
+}
+
// call_tir
StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
@@ -73,5 +106,190 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);
+// call builtin
+StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
+ if (call->sinfo_args.size() == 0) {
+ // by default return void.
+ return TupleStructInfo(Array<StructInfo>());
+ } else {
+ ICHECK_EQ(call->sinfo_args.size(), 1);
+ return call->sinfo_args[0];
+ }
+}
+
+TVM_REGISTER_OP("relax.call_builtin_with_ctx")
+ .set_num_inputs(2)
+ .add_argument("func", "Expr", "The builtin packed func.")
+ .add_argument("args", "Tuple", "The input arguments.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCallBuiltinWithCtx);
+
+Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array<StructInfo> sinfo_args) {
+ static const Op& op = Op::Get("relax.call_builtin_with_ctx");
+ return Call(op, {func, args}, Attrs(), sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx);
+
+TVM_REGISTER_OP("relax.null_value")
+ .set_num_inputs(0)
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+Expr MakeCallNullValue() {
+ static const Op& op = Op::Get("relax.null_value");
+ return Call(op, {}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue);
+
+// make_closure
+
+RELAY_REGISTER_OP("relax.make_closure")
+ .set_num_inputs(2)
+ .add_argument("func", "Expr", "The closure.")
+ .add_argument("args", "Tuple", "The captured variables.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+Expr MakeClosure(Expr func, Tuple args) {
+ static const Op& op = Op::Get("relax.make_closure");
+ return Call(op, {func, args}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure);
+
+// invoke_closure
+
+StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) {
+ if (call->sinfo_args.empty()) {
+ return ObjectStructInfo();
+ } else if (call->sinfo_args.size() == 1) {
+ return call->sinfo_args[0];
+ } else {
+ return TupleStructInfo(call->sinfo_args);
+ }
+}
+
+RELAY_REGISTER_OP("relax.invoke_closure")
+ .set_num_inputs(2)
+ .add_argument("closure", "Expr", "The VMClosure.")
+ .add_argument("args", "Tuple", "The captured variables.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoInvokeClosure);
+
+Expr InvokeClosure(Expr closure, Tuple args, Array<StructInfo> sinfo_args) {
+ static const Op& op = Op::Get("relax.invoke_closure");
+ return Call(op, {closure, args}, {}, sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure);
+
+// shape_of
+
+RELAY_REGISTER_OP("relax.shape_of")
+ .set_num_inputs(1)
+ .add_argument("input", "Expr", "The input expression")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnShapeStructInfo);
+
+Expr MakeShapeOf(Expr expr) {
+ static const Op& op = Op::Get("relax.shape_of");
+ return Call(op, {expr}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf);
+
+// alloc_tensor
+
+StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) {
+ ICHECK(call->args[0].as<ShapeExprNode>())
+ << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey();
+ ICHECK(call->args[1].as<DataTypeImmNode>())
+ << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey();
+ DataType out_dtype;
+ if (const auto* dtype_node = call->args[1].as<DataTypeImmNode>()) {
+ const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
+ out_dtype = dtype_imm->value;
+ }
+ return TensorStructInfo(call->args[0], out_dtype);
+}
+
+RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
+ .set_num_inputs(3)
+ .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+ .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+ .add_argument("runtime_device_index", "int64_t",
+ "The device index indicating on which device the tensor is
to be "
+ "allocated at runtime. Index -1 is reserved for the host
device.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAllocateTensor);
+
+Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) {
+ static const Op& op = Op::Get("relax.builtin.alloc_tensor");
+ return Call(op, {shape, DataTypeImm(dtype), PrimValue::Int64(runtime_device_index)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
+
+// vm alloc_storage
+
+RELAY_REGISTER_OP("relax.vm.alloc_storage")
+ .set_num_inputs(3)
+ .add_argument("size", "Expr", "The size of the storage to allocate.")
+ .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+ .add_argument("runtime_device_index", "int64_t",
+ "The device index indicating on which device the tensor is "
+ "to be allocated at runtime.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+Expr MakeVMAllocStorage(Expr size, int64_t runtime_device_index, DataType dtype) {
+ static const Op& op = Op::Get("relax.vm.alloc_storage");
+ return Call(op, {size, PrimValue::Int64(runtime_device_index), DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage);
+
+// vm alloc_tensor
+
+Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; }
+
+StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) {
+ DataType out_dtype;
+ if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
+ const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
+ out_dtype = dtype_imm->value;
+ }
+ if (const auto* output_shape = call->args[1].as<ShapeExprNode>()) {
+ return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
+ }
+ return TensorStructInfo(out_dtype, kUnknownNDim);
+}
+
+RELAY_REGISTER_OP("relax.vm.alloc_tensor")
+ .set_num_inputs(4)
+ .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
+ .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+ .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+ .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoVMAllocTensor);
+
+Expr MakeVMAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+ static const Op& op = Op::Get("relax.vm.alloc_tensor");
+ return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor);
+
+// vm call_tir_dyn
+
+RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
+ .set_num_inputs(2)
+ .add_argument("func", "Expr", "The destination-passing-style function.")
+ .add_argument("args", "Tuple",
+ "The input arguments (list of tensors and last argument is
ShapeExpr)")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+Expr MakeCallTIRDyn(Expr func, Tuple args) {
+ static const Op& op = Op::Get("relax.vm.call_tir_dyn");
+ return Call(op, {func, args}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn);
+
} // namespace relax
} // namespace tvm
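
For reference, a small sketch of how one of the ops registered above surfaces in TVMScript; this mirrors the TestVMShapeCheck module in the new test file, where R.call_builtin_with_ctx maps to relax.call_builtin_with_ctx and sinfo_args feeds InferStructInfoCallBuiltinWithCtx (the standalone function form here is illustrative):

    from tvm.script import relax as R

    @R.function
    def alloc_heap() -> R.Tensor(ndim=1, dtype="int64"):
        R.func_attr({"global_symbol": "alloc_heap"})
        # The VM passes its context as an implicit first argument;
        # sinfo_args determines the struct info of the result.
        heap = R.call_builtin_with_ctx(
            "vm.builtin.alloc_shape_heap",
            [R.prim_value(3)],
            sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
        )
        return heap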
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 8e362bb4d5..c6d335b2a1 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -115,7 +115,30 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx
}
/*!
- * \brief Infer the struct info for unary arithmetic elementwise ops. It's also
+ * \brief Infer the struct info by returning the struct info of the input argument.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \tparam arg_index The index of the argument whose struct info is returned.
+ * \return The inferred struct info.
+ */
+template <int arg_index>
+StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) {
+ Op op = Downcast<Op>(call->op);
+ int n_input = op->arguments.size();
+ if (static_cast<int>(call->args.size()) != n_input) {
+ ctx->ReportFatal(Diagnostic::Error(call->span)
+ << op << " op should have " << n_input << " arguments");
+ }
+ if (arg_index >= n_input) {
+ ctx->ReportFatal(Diagnostic::Error(call->span)
+ << op << " op has only " << n_input
+ << "arguments, but try to get the arg with index " <<
arg_index);
+ }
+ return GetStructInfo(call->args[arg_index]);
+}
+
+/*!
+ * \brief Infer the struct info for unary arithmetic elementwise ops. It's also
* used in some NN operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 0ef63c8a41..15a4f8702b 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -19,7 +19,6 @@
/*!
* \file src/runtime/relax_vm/builtin.cc
*/
-#include <tvm/runtime/container/adt.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
@@ -214,14 +213,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo
* \param err_ctx Additional context if error occurs.
*/
void CheckTupleInfo(ObjectRef arg, int64_t size, Optional<String> err_ctx) {
- using Tuple = runtime::ADT;
// a function that lazily get context for error reporting
- auto* ptr = arg.as<Tuple::ContainerType>();
+ auto* ptr = arg.as<runtime::ArrayNode>();
 CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get "
<< arg->GetTypeKey();
- CHECK(static_cast<int64_t>(ptr->size) == size)
+ CHECK(static_cast<int64_t>(ptr->size()) == size)
<< "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " <<
size << " elements, "
- << " but get a Tuple with " << ptr->size << " elements.";
+ << " but get a Tuple with " << ptr->size() << " elements.";
}
TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo);
@@ -321,6 +319,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv
*rv = args[0];
});
+TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data,
ShapeTuple new_shape) {
+ return data.CreateView(new_shape, data->dtype);
+});
+
/*!
* \brief Load the scalar value in cond and return the result value.
* \param cond The condition
@@ -367,8 +369,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond);
//-------------------------------------
// Data structure API
//-------------------------------------
-TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](runtime::ADT
arr, int64_t index) {
- return arr[index];
+TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem")
+    .set_body_typed([](runtime::Array<ObjectRef> arr, int64_t index) { return arr[index]; });
+
+TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args,
TVMRetValue* rv) {
+ runtime::Array<ObjectRef> arr;
+ for (int i = 0; i < args.num_args; ++i) {
+ arr.push_back(args[i].operator ObjectRef());
+ }
+ *rv = arr;
});
} // namespace relax_vm
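
A quick sanity sketch for the reworked builtins (same call pattern as the tests below; vm.builtin.reshape is zero-copy since it goes through NDArray CreateView, and tuples are now plain runtime Arrays rather than ADTs):

    import numpy as np
    import tvm

    x = tvm.nd.array(np.arange(12, dtype="float32").reshape(3, 4))

    # Zero-copy reshape to a compatible shape.
    reshape = tvm.get_global_func("vm.builtin.reshape")
    y = reshape(x, tvm.runtime.container.ShapeTuple([6, 2]))
    assert y.shape == (6, 2)

    # Build a tuple and index into it.
    make_tuple = tvm.get_global_func("vm.builtin.make_tuple")
    tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem")
    t = make_tuple(x, y)
    assert tuple_getitem(t, 1).shape == (6, 2)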
diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py
new file mode 100644
index 0000000000..b4ba54b455
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+import numpy as np
+
+from tvm.ir import assert_structural_equal
+from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+
+
+def test_make_shape():
+ MK = MakeShapeCode
+ make_shape = tvm.get_global_func("vm.builtin.make_shape")
+ heap = tvm.nd.array(np.arange(10).astype("int64"))
+ s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2)
+
+ assert s == tvm.runtime.container.ShapeTuple([10, 0, 2])
+
+
+def test_match_shape():
+ MS = MatchShapeCode
+ match_shape = tvm.get_global_func("vm.builtin.match_shape")
+ heap = tvm.nd.array(np.zeros(10).astype("int64"))
+
+ assert heap.numpy()[2] == 0
+
+ s = tvm.runtime.container.ShapeTuple([1, 2, 3])
+ x = tvm.nd.array(np.zeros([1, 2, 3]))
+
+ match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")
+
+ assert heap.numpy()[2] == 2
+
+ match_shape(
+ x,
+ heap,
+ 3,
+ MS.ASSERT_EQUAL_TO_IMM,
+ 1,
+ MS.ASSERT_EQUAL_TO_LOAD,
+ 2,
+ MS.ASSERT_EQUAL_TO_IMM,
+ 3,
+ "",
+ )
+
+ with pytest.raises(RuntimeError):
+ match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "")
+
+ with pytest.raises(RuntimeError):
+ match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")
+
+
+def test_check_shape_info():
+ check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info")
+ s = tvm.runtime.container.ShapeTuple([1, 2, 3])
+
+ check_shape_info(s, 3, "")
+ check_shape_info(s, -1, "")
+
+ # wrong ndim
+ with pytest.raises(ValueError):
+ check_shape_info(s, 2, "")
+
+ # wrong type
+ with pytest.raises(TypeError):
+ check_shape_info([], 2, "")
+
+
+def test_check_tensor_info():
+ check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info")
+ x = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+
+ check_tensor_info(x, 2, "int32", "")
+ check_tensor_info(x, -1, "int32", "")
+ check_tensor_info(x, 2, "", "")
+ check_tensor_info(x, -1, "", "")
+
+ # allow not passing in dtype
+ check_tensor_info(x, 2, "")
+ check_tensor_info(x, -1, "")
+
+ # ndim mismatch
+ with pytest.raises(ValueError, match=r".* ndim .*"):
+ check_tensor_info(x, 3, "int32", "")
+
+ # dtype mismatch
+ with pytest.raises(ValueError, match=r"myerror.* dtype .*"):
+ check_tensor_info(x, 2, "float32", "myerror")
+
+ # error with context
+ with pytest.raises(ValueError, match=r".* myerror .*"):
+ check_tensor_info(x, 3, "myerror")
+
+ # wrong type
+ with pytest.raises(TypeError):
+ check_tensor_info([], 2, "", "")
+
+
+def test_check_tuple_info():
+ check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info")
+ x = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+ t = tvm.runtime.convert([x, x, x])
+
+ check_tuple_info(t, 3, "")
+
+ # size
+ with pytest.raises(ValueError, match=r".*elements.*"):
+ check_tuple_info(t, 2, "")
+
+ # wrong type
+ with pytest.raises(TypeError):
+ check_tuple_info(x, 2, "")
+
+
+def test_check_func_info():
+ check_func_info = tvm.get_global_func("vm.builtin.check_func_info")
+ f = tvm.runtime.convert(lambda x: x)
+ x = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+
+ check_func_info(f, "")
+
+ # wrong type
+ with pytest.raises(TypeError, match=".*myerror.*"):
+ check_func_info(x, "myerror")
+
+
+def test_tuple_getitem():
+ tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem")
+ x = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+ y = tvm.nd.array(np.zeros((2, 3)).astype("int32"))
+ t = tvm.runtime.convert([x, y])
+
+ assert tuple_getitem(t, 0) == x
+ assert tuple_getitem(t, 1) == y
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 75fc4d1429..793951b999 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
+import tvm
import pytest
from tvm import IRModule, relax, tir
from tvm.script import relax as R
@@ -448,42 +449,4 @@ else:
if __name__ == "__main__":
- test_function()
- test_extern_func()
-
- test_object_struct_info()
- test_prim_struct_info()
- test_shape_struct_info_0()
- test_shape_struct_info_1()
- test_shape_struct_info_2()
- test_tensor_struct_info()
- test_tuple_struct_info_empty()
- test_tuple_struct_info()
- test_func_struct_info()
-
- test_shape_type()
- test_object_type()
- test_dyn_tensor_type()
- test_packed_func_type()
- test_tuple_type()
- test_func_type()
-
- test_prim_value()
- test_string_imm()
- test_data_type_imm()
-
- test_var()
- test_dataflow_var()
- #
- test_tuple()
- test_tuple_get_item()
- test_shape_expr()
- test_call()
-
- test_seq_expr()
- test_binding_block()
- test_dataflow_block()
-
- test_match_cast()
- test_var_binding()
- test_if()
+ tvm.testing.main()
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
new file mode 100644
index 0000000000..b5e7709177
--- /dev/null
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -0,0 +1,333 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test last-stage of codegen VM.
+
+Restrictions: all shape lowered, explicit allocation.
+"""
+import tvm
+import pytest
+import numpy as np
+from tvm import relax, TVMError
+from tvm.script import relax as R, tir as T
+from tvm.relax.testing.vm import check_saved_func
+from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+
+EXEC_MODE = ["bytecode"]
+
+
+def codegen(mod, target, exec_mode="bytecode"):
+ builder = relax.ExecBuilder()
+ tir_mod = relax.vm._vmcodegen(builder, mod, exec_mode=exec_mode)
+ return relax.vm._vmlink(builder, target, tir_mod)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_copy(exec_mode):
+ @tvm.script.ir_module
+ class TestVMMove:
+ @R.function
+ def foo(x: R.Tensor((3, 4), "float32")):
+ R.func_attr({"global_symbol": "foo"})
+ z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3,
4), dtype="float32")))
+ return z
+
+ mod = TestVMMove
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = check_saved_func(vm, "foo", inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_if_cond_const(exec_mode):
+ @tvm.script.ir_module
+ class TestVMIfCondConst:
+ @R.function
+ def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
+ R.func_attr({"global_symbol": "main"})
+ if relax.const(True, dtype="bool"):
+ ret = x
+ else:
+ ret = x
+ return ret
+
+ mod = TestVMIfCondConst
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(np.random.rand(3, 4))
+ res = vm["main"](inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy())
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_exec_serialize_export_library(exec_mode):
+ @tvm.script.ir_module
+ class TestVMMove:
+ @R.function
+ def foo(x: R.Tensor((3, 4), "float32")):
+ R.func_attr({"global_symbol": "foo"})
+ z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3,
4), dtype="float32")))
+ return z
+
+ mod = TestVMMove
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target)
+ from tvm.contrib import utils
+
+ temp_dir = utils.tempdir()
+ path_exec = temp_dir.relpath("exec.so")
+ ex.mod.export_library(path_exec)
+
+ loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec))
+ assert ex.as_text() == loaded_exec.as_text()
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_if_cond(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileIf:
+ @R.function
+ def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) ->
R.Tensor:
+ R.func_attr({"global_symbol": "ife"})
+ if cond:
+ w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
+ else:
+ w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
+ return w
+
+ mod = TestVMCompileIf
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(np.random.rand(3, 4))
+ res = vm["ife"](tvm.nd.array(1), inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7)
+ res = vm["ife"](tvm.nd.array(True), inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7)
+ res = vm["ife"](tvm.nd.array(0), inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7)
+ res = vm["ife"](tvm.nd.array(False), inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_return_const_tuple(exec_mode):
+ @tvm.script.ir_module
+ class ReturnConstTuple:
+ @R.function
+ def main(x: R.Tensor(ndim=2, dtype="float32")):
+ R.func_attr({"global_symbol": "main"})
+ y = R.const([1, 2])
+ z = (y, R.const([3, 4]), x)
+ return z
+
+ mod = ReturnConstTuple
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(np.random.rand(2, 3))
+ res0, res1, res2 = vm["main"](inp)
+ tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2]))
+ tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4]))
+ tvm.testing.assert_allclose(res2.numpy(), inp.numpy())
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_const_as_call_arg(exec_mode):
+ @tvm.script.ir_module
+ class TestVMConstAsCallArg:
+ @R.function
+ def main(x: R.Tensor(ndim=2, dtype="float32")):
+ R.func_attr({"global_symbol": "main"})
+ a = R.call_packed(
+ "test.vm.add",
+ relax.const([1, 2]),
+ relax.const([3, 4]),
+ sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+ )
+ b = R.call_packed(
+ "test.vm.add",
+ a,
+ x,
+ sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+ )
+ return b
+
+ mod = TestVMConstAsCallArg
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(np.random.rand(1, 2))
+ res = vm["main"](inp)
+ tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy())
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_shape_check_builtin(exec_mode):
+ MS = MatchShapeCode
+ MK = MakeShapeCode
+ # slot assignment:
+ # 0: n, 1: m
+ sindex = {"n": 0, "m": 1}
+
+ @tvm.script.ir_module
+ class TestVMShapeCheck:
+ @R.function
+ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3):
+ R.func_attr({"global_symbol": "main"})
+ n = T.Var("n", "int64")
+ k = T.Var("k", "int64")
+ shape_heap = R.call_builtin_with_ctx(
+ "vm.builtin.alloc_shape_heap",
+ [R.prim_value(3)],
+ sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+ )
+ _ = R.call_packed(
+ "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "",
sinfo_args=[R.Tuple()]
+ )
+ _ = R.call_packed(
+ "vm.builtin.match_shape",
+ x,
+ shape_heap,
+ 2,
+ MS.STORE_TO_HEAP,
+ sindex["n"],
+ MS.STORE_TO_HEAP,
+ sindex["m"],
+ "",
+ sinfo_args=[R.Tuple()],
+ )
+ # construct shape value for return
+ s = R.call_packed(
+ "vm.builtin.make_shape",
+ shape_heap,
+ 3,
+ MK.LOAD_SHAPE,
+ sindex["m"],
+ MK.LOAD_SHAPE,
+ sindex["n"],
+ MK.USE_IMM,
+ 2,
+ sinfo_args=[R.Shape(ndim=3)],
+ )
+ return s
+
+ mod = TestVMShapeCheck
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x = tvm.nd.array(np.zeros((1, 2)).astype("float32"))
+ res = vm["main"](x)
+ assert res == tvm.runtime.container.ShapeTuple([2, 1, 2])
+
+ # wrong input type
+ with pytest.raises(TypeError):
+ vm["main"]([])
+
+ # wrong ndim
+ with pytest.raises(ValueError, match=r".*ndim.*"):
+ vm["main"](tvm.nd.array(np.zeros(1).astype("float32")))
+
+ # wrong dtype
+ with pytest.raises(ValueError, match=r".*dtype.*"):
+ vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32")))
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_prim_value(exec_mode):
+ @tvm.script.ir_module
+ class TestVMPrimValue:
+ @R.function
+ def main():
+ R.func_attr({"global_symbol": "main"})
+ ret = R.prim_value(T.int64(1))
+ return ret
+
+ mod = TestVMPrimValue
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = vm["main"]()
+ assert res == 1
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_string_imm(exec_mode):
+ @tvm.script.ir_module
+ class TestVMStringImm:
+ @R.function
+ def main():
+ R.func_attr({"global_symbol": "main"})
+ ret = R.str("hello")
+ return ret
+
+ mod = TestVMStringImm
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = vm["main"]()
+ assert res == "hello"
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_datatype_imm(exec_mode):
+ @tvm.script.ir_module
+ class TestDataTypeImm:
+ @R.function
+ def main():
+ R.func_attr({"global_symbol": "main"})
+ ret = R.dtype("float32")
+ return ret
+
+ mod = TestDataTypeImm
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = vm["main"]()
+ assert res == "float32"
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_builtin_reshape(exec_mode):
+ @tvm.script.ir_module
+ class TestVMBuiltinReshape:
+ @R.function
+ def main(x: R.Tensor((3, 4), "float32")):
+ R.func_attr({"global_symbol": "main"})
+ y = R.call_packed(
+ "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2),
"float32")
+ )
+ return y
+
+ mod = TestVMBuiltinReshape
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ dev = tvm.cpu()
+ vm = relax.VirtualMachine(ex, dev)
+
+ input_np = np.random.rand(3, 4).astype("float32")
+ input = tvm.nd.array(input_np, dev)
+ res = vm["main"](input)
+ expected = input_np.reshape(6, 2)
+ tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()