This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8fb1c9c577e4c329c47203afada2e52fb49c2292
Author: Yuchen Jin <yuch...@cs.washington.edu>
AuthorDate: Fri Feb 10 16:19:37 2023 -0800

    [Unity] Relax VM codegen (#13954)
---
 python/tvm/relax/testing/runtime_builtin.py        |  34 ++
 src/relax/backend/vm/codegen_vm.cc                 | 447 +++++++++++++++++++++
 src/relax/op/op.cc                                 | 220 +++++++++-
 src/relax/op/op_common.h                           |  25 +-
 src/runtime/relax_vm/builtin.cc                    |  23 +-
 tests/python/relax/test_runtime_builtin.py         | 153 +++++++
 tests/python/relax/test_tvmscript_printer_relax.py |  41 +-
 tests/python/relax/test_vm_codegen_only.py         | 333 +++++++++++++++
 8 files changed, 1228 insertions(+), 48 deletions(-)

diff --git a/python/tvm/relax/testing/runtime_builtin.py b/python/tvm/relax/testing/runtime_builtin.py
new file mode 100644
index 0000000000..1b04364e69
--- /dev/null
+++ b/python/tvm/relax/testing/runtime_builtin.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities for runtime builtin functions."""
+from enum import IntEnum
+
+
+class MatchShapeCode(IntEnum):
+    """Code passed to the match shape builtin."""
+
+    ASSERT_EQUAL_TO_IMM = 0
+    STORE_TO_HEAP = 1
+    NO_OP = 2
+    ASSERT_EQUAL_TO_LOAD = 3
+
+
+class MakeShapeCode(IntEnum):
+    """Code passed to the make shape builtin."""
+
+    USE_IMM = 0
+    LOAD_SHAPE = 1
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
new file mode 100644
index 0000000000..1782f1107a
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/vm/codegen_vm.cc
+ * \brief A codegen to generate VM executable from a Relax IRModule.
+ */
+#include <tvm/driver/driver_api.h>
+#include <tvm/relax/exec_builder.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/runtime/relax_vm/bytecode.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../../../target/metadata_module.h"
+#include "../../../target/source/codegen_source_base.h"
+
+namespace tvm {
+namespace relax {
+namespace relax_vm {
+
+using tvm::Target;
+using namespace relax;
+using namespace tvm::runtime;
+using namespace tvm::runtime::relax_vm;
+
+// Helper function to get the function name of the registered packed function implementation of
+// a Relax operator.
+FCallPacked GetPackedFuncName(const Call& call) {
+  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
+  if (call->op.as<OpNode>()) {
+    Op op = Downcast<Op>(call->op);
+    if (op_map.count(op)) {
+      return op_map[op];
+    }
+  }
+  return {};
+}
+
+/*!
+ * \brief A class to generate VM executable for Relax functions.
+ */
+class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
+ public:
+  explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod)
+      : builder_(builder), ctx_mod_(ctx_mod) {}
+
+  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
+    IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
+    CodeGenVM codegen(builder, mod);
+    // Compile each Relax function into VM code; all other functions
+    // (e.g. TIR PrimFuncs) are carried over into the result module.
+    for (auto& p : mod->functions) {
+      if (auto* func = p.second.as<FunctionNode>()) {
+        codegen.Codegen(GetRef<Function>(func));
+      } else {
+        res_mod->Add(p.first, p.second);
+      }
+    }
+    return res_mod;
+  }
+
+ protected:
+  size_t NewRegister() { return registers_num_++; }
+
+  // Convert an Arg value to a register, triggering a copy if needed.
+  Instruction::Arg EnsureReg(Instruction::Arg arg) {
+    if (arg.kind() == Instruction::ArgKind::kRegister) {
+      return arg;
+    } else {
+      RegName dst_reg = NewRegister();
+      builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg);
+      return Instruction::Arg::Register(dst_reg);
+    }
+  }
+
+  void Codegen(const Function& func) {
+    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
+                                 "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
+
+    Array<String> param_names;
+    for (Var param : func->params) {
+      param_names.push_back(param->name_hint());
+    }
+
+    builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
+
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      RegName r = NewRegister();
+      ICHECK_EQ(r, static_cast<RegName>(i));
+      this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
+    }
+    Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
+    builder_->EmitRet(EnsureReg(ret));
+    builder_->EndFunction(gsymbol.value());
+    // Reset the register counter to 0.
+    registers_num_ = 0;
+    var_arg_map_.clear();
+  }
+
+  Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
+    for (auto block : op->blocks) {
+      for (Binding binding : block->bindings) {
+        Instruction::Arg value;
+        if (auto* var_binding = binding.as<VarBindingNode>()) {
+          value = this->VisitExpr(var_binding->value);
+        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
+          value = this->VisitExpr(match_cast->value);
+        } else {
+          LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
+        }
+        this->var_arg_map_.insert({binding->var, value});
+      }
+    }
+
+    Instruction::Arg ret_reg = this->VisitExpr(op->body);
+    return ret_reg;
+  }
+
+  Instruction::Arg VisitExpr_(const CallNode* call_node) final {
+    Call call = GetRef<Call>(call_node);
+
+    if (call_node->op == null_value_op_) {
+      return Instruction::Arg::Register(Instruction::kVoidRegister);
+    }
+
+    // Allocate the dst register.
+    RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
+    if (call->op.as<OpNode>()) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
+        // TODO(relax-team) migrate most handling of op to
+        // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
+        EmitCallBuiltinWithCtx(call, dst_reg);
+      } else if (call_node->op == alloc_storage_op_) {
+        EmitAllocStorage(call, dst_reg);
+      } else if (call_node->op == alloc_tensor_op_) {
+        EmitAllocTensor(call, dst_reg);
+      } else {
+        // Every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
+        // ops are handled in a pass when lowering them to TIR.
+        LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op;
+      }
+    } else {
+      EmitNormalCall(call, dst_reg);
+    }
+    return Instruction::Arg::Register(dst_reg);
+  }
+
+  Instruction::Arg VisitExpr_(const IfNode* op) final {
+    const If& ife = GetRef<If>(op);
+    Instruction::Arg cond_value = this->VisitExpr(ife->cond);
+
+    // Reserve a register for cond.
+    RegName cond_reg = NewRegister();
+    builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg);
+
+    // Obtain the executable being built.
+    vm::Executable* exec = builder_->exec();
+
+    // Record the offset of the If instruction.
+    size_t if_offset = exec->instr_offset.size();
+
+    builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3);
+    size_t num_instr = exec->instr_offset.size();
+    Instruction::Arg true_value = this->VisitExpr(ife->true_branch);
+    // Reserve a register for the return value.
+    size_t merge_register = NewRegister();
+    // Copy the output of the true branch to the merge register.
+    builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register);
+
+    // Record the offset of the Goto instruction.
+    size_t goto_offset = exec->instr_offset.size();
+
+    builder_->EmitGoto(1);
+
+    // Calculate the false offset of the If.
+    size_t false_offset = exec->instr_offset.size() - num_instr + 1;
+
+    Instruction::Arg false_value = this->VisitExpr(ife->false_branch);
+    // Copy the output of the false branch to the merge register.
+    builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register);
+
+    // Update the offset of the If instruction emitted above:
+    // jump just past the next Goto instruction (the start of the false branch).
+    exec->SetInstructionData(if_offset, 2, static_cast<ExecWord>(false_offset));
+    // Update the pc_offset of the Goto instruction:
+    // jump over the false branch.
+    size_t pc_offset = exec->instr_offset.size() - goto_offset;
+    exec->SetInstructionData(goto_offset, 1, static_cast<ExecWord>(pc_offset));
+    return Instruction::Arg::Register(merge_register);
+  }
+
+  Instruction::Arg VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    auto it = this->var_arg_map_.find(var);
+    ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined";
+    return it->second;
+  }
+
+  Instruction::Arg VisitExpr_(const ConstantNode* op) final {
+    return builder_->ConvertConstant(op->data);
+  }
+
+  Instruction::Arg VisitExpr_(const ShapeExprNode* op) final {
+    std::vector<int64_t> shape;
+    for (PrimExpr e : op->values) {
+      if (auto* int_value = e.as<IntImmNode>()) {
+        shape.push_back(int_value->value);
+      } else {
+        LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values;
+      }
+    }
+    return builder_->ConvertConstant(ShapeTuple(shape));
+  }
+
+  Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
+    if (auto* int_imm = op->value.as<IntImmNode>()) {
+      return builder_->ConvertConstant(int_imm->value);
+    } else {
+      auto* float_imm = op->value.as<FloatImmNode>();
+      ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now";
+      return builder_->ConvertConstant(float_imm->value);
+    }
+  }
+
+  Instruction::Arg VisitExpr_(const StringImmNode* op) final {
+    return builder_->ConvertConstant(op->value);
+  }
+
+  Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final {
+    return builder_->ConvertConstant(op->value);
+  }
+
+  Instruction::Arg VisitExpr_(const TupleNode* op) final {
+    Tuple tuple = GetRef<Tuple>(op);
+    std::vector<Instruction::Arg> args;
+    for (Expr arg : tuple->fields) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    size_t dst_register = NewRegister();
+    builder_->EmitCall("vm.builtin.make_tuple", args, dst_register);
+
+    return Instruction::Arg::Register(dst_register);
+  }
+
+  Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final {
+    TupleGetItem expr = GetRef<TupleGetItem>(op);
+    std::vector<Instruction::Arg> args = {this->VisitExpr(expr->tuple)};
+
+    args.push_back(builder_->ConvertConstant(expr->index));
+
+    size_t dst_register = NewRegister();
+    builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register);
+
+    return Instruction::Arg::Register(dst_register);
+  }
+
+  Instruction::Arg VisitExpr_(const GlobalVarNode* op) final {
+    GlobalVar gvar = GetRef<GlobalVar>(op);
+    Optional<String> symbol;
+    VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc;
+
+    // Run a lookup in the env to see if it maps to an extern func.
+    auto it = ctx_mod_->functions.find(gvar);
+    if (it != ctx_mod_->functions.end()) {
+      BaseFunc func = (*it).second;
+      if (auto* efunc = func.as<ExternFuncNode>()) {
+        symbol = efunc->global_symbol;
+        kind = VMFuncInfo::FuncKind::kPackedFunc;
+      } else if (func.as<FunctionNode>()) {
+        symbol = gvar->name_hint;
+        kind = VMFuncInfo::FuncKind::kVMFunc;
+      }
+    }
+    // A GlobalVar can be a reference to a Relax function or a TIR PrimFunc.
+    // At this point, every global var must correspond to the right symbol.
+    // TODO(relax-team): switch everything to extern before splitting TIR/relax
+    // so we do not have idle global vars here.
+    if (!symbol.defined()) {
+      symbol = gvar->name_hint;
+      kind = VMFuncInfo::FuncKind::kPackedFunc;
+    }
+    // Declare the function to be safe.
+    ICHECK(symbol.defined());
+    builder_->DeclareFunction(symbol.value(), kind);
+    return builder_->GetFunction(symbol.value());
+  }
+
+  Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {
+    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
+    return builder_->GetFunction(op->global_symbol);
+  }
+
+  void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
+    ICHECK_EQ(call_node->args.size(), 3);
+    // Handle args of the call
+    std::vector<Instruction::Arg> args;
+    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+    // buffer size, dtype, device index
+    for (auto arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
+  }
+
+  void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
+    ICHECK_EQ(call_node->args.size(), 4);
+    std::vector<Instruction::Arg> args;
+    args.reserve(4);
+    for (Expr arg : call_node->args) {
+      args.push_back(this->VisitExpr(arg));
+    }
+    builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
+  }
+
+  void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) {
+    std::vector<Instruction::Arg> args;
+    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
+
+    auto func = this->VisitExpr(call_node->args[0]);
+    auto tuple_arg = Downcast<Tuple>(call_node->args[1]);
+
+    // Handle args of the call
+    for (Expr arg : tuple_arg->fields) {
+      args.push_back(this->VisitExpr(arg));
+    }
+
+    builder_->EmitCall(func, args, dst_reg);
+  }
+
+  void EmitNormalCall(const Call& call_node, RegName dst_reg) {
+    Instruction::Arg func = VisitExpr(call_node->op);
+    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+    builder_->EmitCall(func, args, dst_reg);
+  }
+
+  // TODO(relax-team) revisit after PrimValue.
+  // Emit the `call_node` attributes as constants and append these constants to `args` vector.
+  void AppendAttrsAsConstants(const Call& call_node, std::vector<Instruction::Arg>& args) {
+    auto attrs = call_node->attrs;
+    if (!attrs.defined()) return;
+
+    LOG(FATAL) << "Support for attributes of Op " << call_node->op
+               << " has not been implemented yet.";
+    return;
+  }
+
+  // Emits a call to the packed function `name` with arguments copied over from `call_node` args
+  // and attributes.
+  void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
+    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
+    AppendAttrsAsConstants(call_node, args);
+    builder_->EmitCall(name, args, dst_reg);
+  }
+
+  std::vector<Instruction::Arg> VisitArray(const Array<Expr>& arr) {
+    std::vector<Instruction::Arg> ret;
+    for (size_t i = 0; i < arr.size(); ++i) {
+      ret.push_back(this->VisitExpr(arr[i]));
+    }
+    return ret;
+  }
+
+  /*! \brief Internal ExecBuilder. */
+  relax::ExecBuilder builder_;
+  /*!
+   * \brief Total number of virtual registers allocated.
+   * \note The first two registers are reserved for special registers.
+   */
+  size_t registers_num_ = 0;
+  /*! \brief Map from var to register number. */
+  std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
+  /*! \brief the context module. */
+  IRModule ctx_mod_;
+  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
+  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+  const Op& null_value_op_ = Op::Get("relax.null_value");
+};
+
+/*!
+ * \brief Create the Relax VM executable from all relax.Function in mod
+ * and add them to exec_builder.
+ * \param exec_builder Builder to collect executables.
+ * \param mod Input module.
+ * \return Leftover IRModule that may contain other functions.
+ */
+IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
+  return CodeGenVM::Run(exec_builder, mod);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
+
+/*!
+ * \brief Link the libraries together.
+ */
+Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, Array<Module> ext_libs,
+              Map<String, runtime::NDArray> params) {
+  // TODO(relax-team) Revisit the param and ext_lib options.
+  ObjectPtr<Executable> executable = builder->Get();
+  if (!lib.defined()) {
+    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+  }
+  std::unordered_map<std::string, runtime::NDArray> conv_params;
+  for (const auto& [name, param] : params) {
+    conv_params[name] = param;
+  }
+  Module combined_lib = codegen::CreateMetadataModule(
+      conv_params, lib.value(), ext_libs, target,
+
+      // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
+      // Before jumping into details, only support cpp runtime for now.
+      relay::Runtime::Create("cpp"),
+      relay::Executor::Create("graph"),  // TODO(@sunggg): pass an arbitrary executor. CPP runtime
+                                         // won't use this anyways.
+ relay::backend::ExecutorCodegenMetadata()); + executable->Import(combined_lib); + return Module(executable); +} + +TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 8640ed79ad..ca66b0a9ef 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,13 +18,46 @@ */ #include <tvm/relax/analysis.h> #include <tvm/relax/expr.h> -#include <tvm/relax/op_attr_types.h> #include <tvm/relax/utils.h> #include <tvm/relay/op.h> +#include "op_common.h" + namespace tvm { namespace relax { +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { + return TupleStructInfo(Array<StructInfo>()); +} + +StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { + return ObjectStructInfo(); +} + +StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { + return ShapeStructInfo(kUnknownNDim); +} + // call_tir StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { @@ -73,5 +106,190 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list, TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); +// call builtin +StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() == 0) { + // by default return void. 
+ return TupleStructInfo(Array<StructInfo>()); + } else { + ICHECK_EQ(call->sinfo_args.size(), 1); + return call->sinfo_args[0]; + } +} + +TVM_REGISTER_OP("relax.call_builtin_with_ctx") + .set_num_inputs(4) + .add_argument("func", "Expr", "The builtin packed func.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallBuiltinWithCtx); + +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array<StructInfo> sinfo_args) { + static const Op& op = Op::Get("relax.call_builtin_with_ctx"); + return Call(op, {func, args}, Attrs(), sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); + +TVM_REGISTER_OP("relax.null_value") + .set_num_inputs(0) + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeCallNullValue() { + static const Op& op = Op::Get("relax.null_value"); + return Call(op, {}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); + +// make_closure + +RELAY_REGISTER_OP("relax.make_closure") + .set_num_inputs(2) + .add_argument("func", "Expr", "The closure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeClosure(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.make_closure"); + return Call(op, {func, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); + +// invoke_closure + +StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } +} + +RELAY_REGISTER_OP("relax.invoke_closure") + .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoInvokeClosure); + +Expr InvokeClosure(Expr closure, Tuple args, Array<StructInfo> sinfo_args) { + static const Op& op = Op::Get("relax.invoke_closure"); + return Call(op, {closure, args}, {}, sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); + +// shape_of + +RELAY_REGISTER_OP("relax.shape_of") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnShapeStructInfo); + +Expr MakeShapeOf(Expr expr) { + static const Op& op = Op::Get("relax.shape_of"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); + +// alloc_tensor + +StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args[0].as<ShapeExprNode>()) + << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); + ICHECK(call->args[1].as<DataTypeImmNode>()) + << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[1].as<DataTypeImmNode>()) { + const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[0], out_dtype); +} + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") + .set_num_inputs(3) + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype 
of the tensor to allocate.") + .add_argument("runtime_device_index", "int64_t", + "The device index indicating on which device the tensor is to be " + "allocated at runtime. Index -1 is reserved for the host device.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor); + +Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + return Call(op, {shape, DataTypeImm(dtype), PrimValue::Int64(runtime_device_index)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); + +// vm alloc_storage + +RELAY_REGISTER_OP("relax.vm.alloc_storage") + .set_num_inputs(3) + .add_argument("size", "Expr", "The size of the storage to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "int64_t", + "The device index indicating on which device the tensor is " + "to be allocated at runtime.") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeVMAllocStorage(Expr size, int64_t runtime_device_index, DataType dtype) { + static const Op& op = Op::Get("relax.vm.alloc_storage"); + return Call(op, {size, PrimValue::Int64(runtime_device_index), DataTypeImm(dtype)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); + +// vm alloc_tensor + +Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; } + +StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) { + const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node); + out_dtype = dtype_imm->value; + } + if (const auto* output_shape = call->args[1].as<ShapeExprNode>()) { + return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype); + } + return TensorStructInfo(out_dtype, kUnknownNDim); +} + +RELAY_REGISTER_OP("relax.vm.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "int", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoVMAllocTensor); + +Expr MakeVMAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) { + static const Op& op = Op::Get("relax.vm.alloc_tensor"); + return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); + +// vm call_tir_dyn + +RELAY_REGISTER_OP("relax.vm.call_tir_dyn") + .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", + "The input arguments (list of tensors and last argument is ShapeExpr)") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeCallTIRDyn(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.vm.call_tir_dyn"); + return Call(op, {func, args}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 8e362bb4d5..c6d335b2a1 100644 --- 
a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -115,7 +115,30 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx
 }
 
 /*!
- * \brief Infer the struct info for unary arithmetic elementwise ops. It's also
+ * \brief Infer the struct info by returning the struct info of the input argument.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \tparam arg_index The index of the argument whose struct info is returned.
+ * \return The inferred struct info.
+ */
+template <int arg_index>
+StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) {
+  Op op = Downcast<Op>(call->op);
+  int n_input = op->arguments.size();
+  if (static_cast<int>(call->args.size()) != n_input) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << op << " op should have " << n_input << " arguments");
+  }
+  if (arg_index >= n_input) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << op << " op has only " << n_input
+                     << " arguments, but tried to get the arg with index " << arg_index);
+  }
+  return GetStructInfo(call->args[arg_index]);
+}
+
+/*!
+ * \brief Infer the struct info for unary arithmetic elementwise ops. It's also
  * used in some NN operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 0ef63c8a41..15a4f8702b 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -19,7 +19,6 @@
 /*!
  * \file src/runtime/relax_vm/builtin.cc
  */
-#include <tvm/runtime/container/adt.h>
 #include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/device_api.h>
@@ -214,14 +213,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo
  * \param err_ctx Additional context if error occurs.
  */
 void CheckTupleInfo(ObjectRef arg, int64_t size, Optional<String> err_ctx) {
-  using Tuple = runtime::ADT;
   // a function that lazily get context for error reporting
-  auto* ptr = arg.as<Tuple::ContainerType>();
+  auto* ptr = arg.as<runtime::ArrayNode>();
   CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get "
                         << arg->GetTypeKey();
-  CHECK(static_cast<int64_t>(ptr->size) == size)
+  CHECK(static_cast<int64_t>(ptr->size()) == size)
       << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, "
-      << " but get a Tuple with " << ptr->size << " elements.";
+      << " but get a Tuple with " << ptr->size() << " elements.";
 }
 
 TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo);
@@ -321,6 +319,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv
   *rv = args[0];
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeTuple new_shape) {
+  return data.CreateView(new_shape, data->dtype);
+});
+
 /*!
  * \brief Load the scalar value in cond and return the result value.
* \param cond The condition @@ -367,8 +369,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); //------------------------------------- // Data structure API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](runtime::ADT arr, int64_t index) { - return arr[index]; +TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem") + .set_body_typed([](runtime::Array<ObjectRef> arr, int64_t index) { return arr[index]; }); + +TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Array<ObjectRef> arr; + for (int i = 0; i < args.num_args; ++i) { + arr.push_back(args[i].operator ObjectRef()); + } + *rv = arr; }); } // namespace relax_vm diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py new file mode 100644 index 0000000000..b4ba54b455 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import pytest +import numpy as np + +from tvm.ir import assert_structural_equal +from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode + + +def test_make_shape(): + MK = MakeShapeCode + make_shape = tvm.get_global_func("vm.builtin.make_shape") + heap = tvm.nd.array(np.arange(10).astype("int64")) + s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2) + + assert s == tvm.runtime.container.ShapeTuple([10, 0, 2]) + + +def test_match_shape(): + MS = MatchShapeCode + match_shape = tvm.get_global_func("vm.builtin.match_shape") + heap = tvm.nd.array(np.zeros(10).astype("int64")) + + assert heap.numpy()[2] == 0 + + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + x = tvm.nd.array(np.zeros([1, 2, 3])) + + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + assert heap.numpy()[2] == 2 + + match_shape( + x, + heap, + 3, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_LOAD, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 3, + "", + ) + + with pytest.raises(RuntimeError): + match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "") + + with pytest.raises(RuntimeError): + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + +def test_check_shape_info(): + check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info") + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + + check_shape_info(s, 3, "") + check_shape_info(s, -1, "") + + # wrong ndim + with pytest.raises(ValueError): + check_shape_info(s, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_shape_info([], 2, "") + + +def test_check_tensor_info(): + check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_tensor_info(x, 2, "int32", "") + check_tensor_info(x, -1, "int32", "") + check_tensor_info(x, 2, "", "") + check_tensor_info(x, -1, "", "") + + # allow not passing in dtype + check_tensor_info(x, 2, "") + check_tensor_info(x, -1, "") + + # ndim mismatch + with pytest.raises(ValueError, match=r".* ndim .*"): + check_tensor_info(x, 3, "int32", "") + + # dtype mismatch + with pytest.raises(ValueError, match=r"myerror.* dtype .*"): + check_tensor_info(x, 2, "float32", "myerror") + + # error with context + with pytest.raises(ValueError, match=r".* myerror .*"): + check_tensor_info(x, 3, "myerror") + + # wrong type + with pytest.raises(TypeError): + check_tensor_info([], 2, "", "") + + +def test_check_tuple_info(): + check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, x, x]) + + check_tuple_info(t, 3, "") + + # size + with pytest.raises(ValueError, match=r".*elements.*"): + check_tuple_info(t, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_tuple_info(x, 2, "") + + +def test_check_func_info(): + check_func_info = tvm.get_global_func("vm.builtin.check_func_info") + f = tvm.runtime.convert(lambda x: x) + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_func_info(f, "") + + # wrong type + with pytest.raises(TypeError, match=".*myerror.*"): + check_func_info(x, "myerror") + + +def test_tuple_getitem(): + tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + y = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, y]) + + assert tuple_getitem(t, 0) == x + assert tuple_getitem(t, 1) == y + + +if __name__ 
== "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index e2cb8bc5fc..58596f968f 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring +import tvm import pytest from tvm import IRModule, relax, tir from tvm.script import relax as R @@ -447,42 +448,4 @@ else: if __name__ == "__main__": - test_function() - test_extern_func() - - test_object_struct_info() - test_prim_struct_info() - test_shape_struct_info_0() - test_shape_struct_info_1() - test_shape_struct_info_2() - test_tensor_struct_info() - test_tuple_struct_info_empty() - test_tuple_struct_info() - test_func_struct_info() - - test_shape_type() - test_object_type() - test_dyn_tensor_type() - test_packed_func_type() - test_tuple_type() - test_func_type() - - test_prim_value() - test_string_imm() - test_data_type_imm() - - test_var() - test_dataflow_var() - # - test_tuple() - test_tuple_get_item() - test_shape_expr() - test_call() - - test_seq_expr() - test_binding_block() - test_dataflow_block() - - test_match_cast() - test_var_binding() - test_if() + tvm.testing.main() diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py new file mode 100644 index 0000000000..b5e7709177 --- /dev/null +++ b/tests/python/relax/test_vm_codegen_only.py @@ -0,0 +1,333 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test last-stage of codegen VM. + +Restrictions: all shape lowered, explicit allocation. 
+""" +import tvm +import pytest +import numpy as np +from tvm import relax, TVMError +from tvm.script import relax as R, tir as T +from tvm.relax.testing.vm import check_saved_func +from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode + +EXEC_MODE = ["bytecode"] + + +def codegen(mod, target, exec_mode="bytecode"): + builder = relax.ExecBuilder() + tir_mod = relax.vm._vmcodegen(builder, mod, exec_mode=exec_mode) + return relax.vm._vmlink(builder, target, tir_mod) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_copy(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond_const(exec_mode): + @tvm.script.ir_module + class TestVMIfCondConst: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"): + R.func_attr({"global_symbol": "main"}) + if relax.const(True, dtype="bool"): + ret = x + else: + ret = x + return ret + + mod = TestVMIfCondConst + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_exec_serialize_export_library(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target) + from tvm.contrib import utils + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + ex.mod.export_library(path_exec) + + loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec)) + assert ex.as_text() == loaded_exec.as_text() + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond(exec_mode): + @tvm.script.ir_module + class TestVMCompileIf: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: + R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + mod = TestVMCompileIf + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["ife"](tvm.nd.array(1), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(True), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(0), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + res = 
vm["ife"](tvm.nd.array(False), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_return_const_tuple(exec_mode): + @tvm.script.ir_module + class ReturnConstTuple: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + mod = ReturnConstTuple + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2, 3)) + res0, res1, res2 = vm["main"](inp) + tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2])) + tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4])) + tvm.testing.assert_allclose(res2.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_const_as_call_arg(exec_mode): + @tvm.script.ir_module + class TestVMConstAsCallArg: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + a = R.call_packed( + "test.vm.add", + relax.const([1, 2]), + relax.const([3, 4]), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + b = R.call_packed( + "test.vm.add", + a, + x, + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + return b + + mod = TestVMConstAsCallArg + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(1, 2)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_shape_check_builtin(exec_mode): + MS = MatchShapeCode + MK = MakeShapeCode + # slot assignment: + # 0: n, 1: m + sindex = {"n": 0, "m": 1} + + @tvm.script.ir_module + class TestVMShapeCheck: + @R.function + def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): + R.func_attr({"global_symbol": "main"}) + n = T.Var("n", "int64") + k = T.Var("k", "int64") + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(3)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + # construct shape value for return + s = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 3, + MK.LOAD_SHAPE, + sindex["m"], + MK.LOAD_SHAPE, + sindex["n"], + MK.USE_IMM, + 2, + sinfo_args=[R.Shape(ndim=3)], + ) + return s + + mod = TestVMShapeCheck + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.zeros((1, 2)).astype("float32")) + res = vm["main"](x) + assert res == tvm.runtime.container.ShapeTuple([2, 1, 2]) + + # wrong input type + with pytest.raises(TypeError): + vm["main"]([]) + + # wrong ndim + with pytest.raises(ValueError, match=r".*ndim.*"): + vm["main"](tvm.nd.array(np.zeros(1).astype("float32"))) + + # wrong dtype + with pytest.raises(ValueError, match=r".*dtype.*"): + vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_prim_value(exec_mode): + @tvm.script.ir_module + class TestVMPrimValue: + 
@R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.prim_value(T.int64(1)) + return ret + + mod = TestVMPrimValue + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == 1 + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_string_imm(exec_mode): + @tvm.script.ir_module + class TestVMStringImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.str("hello") + return ret + + mod = TestVMStringImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "hello" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_datatype_imm(exec_mode): + @tvm.script.ir_module + class TestDataTypeImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.dtype("float32") + return ret + + mod = TestDataTypeImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "float32" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_builtin_reshape(exec_mode): + @tvm.script.ir_module + class TestVMBuiltinReshape: + @R.function + def main(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "main"}) + y = R.call_packed( + "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32") + ) + return y + + mod = TestVMBuiltinReshape + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(ex, dev) + + input_np = np.random.rand(3, 4).astype("float32") + input = tvm.nd.array(input_np, dev) + res = vm["main"](input) + expected = input_np.reshape(6, 2) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + tvm.testing.main()
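
A note on the branch bookkeeping in VisitExpr_(const IfNode*) above: the If and
Goto instructions are first emitted with placeholder offsets (3 and 1), which are
patched via SetInstructionData once both branches have been emitted. A minimal
sketch of the offset arithmetic in Python (illustrative only; instruction offsets
are modeled as plain list indices rather than the real ExecBuilder/Executable API):

    # Layout emitted for `if cond: T else: F`, assuming the true branch emits
    # two instructions and the false branch emits three.
    instrs = ["If", "t0", "t1", "copy_true", "Goto", "f0", "f1", "f2", "copy_false"]

    if_offset = 0    # instr_offset.size() right before EmitIf
    num_instr = 1    # instr_offset.size() right after EmitIf
    goto_offset = 4  # instr_offset.size() right before EmitGoto

    # After EmitGoto, instr_offset.size() == 5, so:
    false_offset = 5 - num_instr + 1       # == 5, relative jump stored in the If
    assert if_offset + false_offset == instrs.index("f0")  # first false-branch instr

    # After the false branch and its copy, instr_offset.size() == 9, so:
    pc_offset = len(instrs) - goto_offset  # == 5, relative jump stored in the Goto
    assert goto_offset + pc_offset == len(instrs)  # lands just past the false branch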
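
The MatchShapeCode/MakeShapeCode pairs added in runtime_builtin.py drive a small
run-time shape protocol: vm.builtin.match_shape checks a tensor's (or shape's)
dimensions against a heap of symbolic-dimension slots, optionally storing dims
into slots, and vm.builtin.make_shape rebuilds a ShapeTuple from slots and
immediates. Condensed from test_runtime_builtin.py and test_shape_check_builtin
above, a round trip looks like this (the slot assignment n -> 0, m -> 1 is an
arbitrary choice for this sketch):

    import numpy as np
    import tvm
    from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode

    match_shape = tvm.get_global_func("vm.builtin.match_shape")
    make_shape = tvm.get_global_func("vm.builtin.make_shape")

    # Heap with one int64 slot per symbolic dim: slot 0 holds n, slot 1 holds m.
    heap = tvm.nd.array(np.zeros(2, dtype="int64"))
    x = tvm.nd.array(np.zeros((4, 5), dtype="float32"))

    # Populate the heap from x's shape: n -> slot 0, m -> slot 1.
    match_shape(x, heap, 2, MatchShapeCode.STORE_TO_HEAP, 0, MatchShapeCode.STORE_TO_HEAP, 1, "")

    # Rebuild the shape value (m, n, 2) from the heap plus an immediate.
    s = make_shape(heap, 3, MakeShapeCode.LOAD_SHAPE, 1, MakeShapeCode.LOAD_SHAPE, 0,
                   MakeShapeCode.USE_IMM, 2)
    assert s == tvm.runtime.container.ShapeTuple([5, 4, 2])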