This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 7ef36ebb5d [Unity] Support symbolic PrimValue arguments (#15980) 7ef36ebb5d is described below commit 7ef36ebb5d056320676faede712f2052d92f7a5d Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Wed Oct 25 15:50:35 2023 -0500 [Unity] Support symbolic PrimValue arguments (#15980) Prior to this commit, all symbolic variables needed to be defined either by tensor shapes, or by an explicit `tvm.runtime.ShapeTuple` argument. This commit allows arguments `arg: R.Prim(value="n")` to serve as a source of definition for symbolic variables. --- src/relax/backend/vm/codegen_vm.cc | 7 +- src/relax/backend/vm/vm_shape_lower.cc | 158 +++++++++++++++++++++++---------- src/runtime/ndarray.cc | 4 +- src/runtime/relax_vm/builtin.cc | 88 ++++++++++++++++++ tests/python/relax/test_vm_build.py | 120 +++++++++++++++++++++++++ 5 files changed, 325 insertions(+), 52 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index caee0a0c13..64b87c6c12 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -246,10 +246,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> { Instruction::Arg VisitExpr_(const PrimValueNode* op) final { if (auto* int_imm = op->value.as<IntImmNode>()) { return builder_->ConvertConstant(int_imm->value); - } else { - auto* float_imm = op->value.as<FloatImmNode>(); - ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now"; + } else if (auto* float_imm = op->value.as<FloatImmNode>()) { return builder_->ConvertConstant(float_imm->value); + } else { + LOG(FATAL) << "PrimValue should only contain constant after VMShapeLower, " + << "but received " << GetRef<Expr>(op) << " with type " << op->value->GetTypeKey(); } } diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 8b8eb33f5b..41b27ea625 100644 --- 
a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -347,6 +347,41 @@ class VMShapeLowerMutator return GetRef<Expr>(op); } + std::pair<Expr, Expr> MakeSymbolicShapeArg(const PrimExpr& expr) { + using runtime::relax_vm::MakeShapeCode; + + if (auto* int_expr = expr.as<IntImmNode>()) { + return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)), + PrimValue::Int64(int_expr->value)}; + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed"; + return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)), + PrimValue::Int64(slot->index)}; + } + } + + Expr VisitExpr_(const PrimValueNode* op) final { + using runtime::relax_vm::MakeShapeCode; + // Constant shape can be preserved. + bool is_const_value = + op->value->IsInstance<IntImmNode>() || op->value->IsInstance<FloatImmNode>(); + if (is_const_value) { + return GetRef<Expr>(op); + } + + Array<Expr> args = {shape_heap_}; + auto [code, value_or_index] = MakeSymbolicShapeArg(op->value); + args.push_back(code); + args.push_back(value_or_index); + + // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) + Call call(builtin_make_prim_value_, args, Attrs(), {Downcast<StructInfo>(op->struct_info_)}); + return call; + } + Expr VisitExpr_(const ShapeExprNode* op) final { using runtime::relax_vm::MakeShapeCode; // Constant shape can be preserved. 
@@ -359,17 +394,9 @@ class VMShapeLowerMutator Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))}; for (PrimExpr expr : op->values) { - if (auto* int_expr = expr.as<IntImmNode>()) { - args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm))); - args.push_back(PrimValue::Int64(int_expr->value)); - } else { - auto it = slot_map_.find(expr); - ICHECK(it != slot_map_.end()); - auto* slot = it->second; - ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed"; - args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape))); - args.push_back(PrimValue::Int64(slot->index)); - } + auto [code, value_or_index] = MakeSymbolicShapeArg(expr); + args.push_back(code); + args.push_back(value_or_index); } // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) @@ -402,6 +429,45 @@ class VMShapeLowerMutator // Place this pass as last pass before codegen. StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; } + /* \brief Internal utility function used for RunMatch() + * + * \param expr The expression to be matched + * + * \param require_value_computed Whether we require all expr to be computed. + * + * \return The MatchShapeCode, and a relax expression specifying the + * argument used by that MatchShapeCode. 
+ */ + std::pair<runtime::relax_vm::MatchShapeCode, Expr> MakeMatchArgs(const PrimExpr& expr, + bool require_value_computed) { + using runtime::relax_vm::MatchShapeCode; + + if (auto* int_expr = expr.as<IntImmNode>()) { + return {MatchShapeCode::kAssertEqualToImm, PrimValue::Int64(int_expr->value)}; + } + + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + if (slot->value_computed) { + return {MatchShapeCode::kAssertEqualToLoad, PrimValue::Int64(slot->index)}; + } + + // the value is not yet computed + ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; + if (expr.as<tir::VarNode>()) { + // It is a var we will populate it in this round. + + slot->value_computed = true; + ready_vars_.push_back(slot); + + return {MatchShapeCode::kStoreToHeap, PrimValue::Int64(slot->index)}; + } + + // otherwise, we skip and mark it as outstanding + return {MatchShapeCode::kNoOp, PrimValue::Int64(0)}; + } + //------------------------------------------------------- // Shape computations. 
//------------------------------------------------------- @@ -426,52 +492,33 @@ class VMShapeLowerMutator using runtime::relax_vm::MatchShapeCode; for (const MatchShapeTodoItem& item : match_todos) { - int64_t shape_len = static_cast<int64_t>(item.pattern.size()); bool all_nop = true; - int num_outstanding_exprs = 0; + bool any_nop = false; - Array<Expr> args = {item.input, shape_heap_, PrimValue::Int64(shape_len)}; + Array<Expr> args = {item.input, shape_heap_}; + + Expr match_op; + if (item.input->struct_info_.as<PrimStructInfoNode>()) { + match_op = builtin_match_prim_value_; + ICHECK_EQ(item.pattern.size(), 1); + } else { + match_op = builtin_match_shape_; + args.push_back(PrimValue::Int64(item.pattern.size())); + } for (PrimExpr expr : item.pattern) { - MatchShapeCode code = MatchShapeCode::kNoOp; - int64_t rvalue = 0; - if (auto* int_expr = expr.as<IntImmNode>()) { - code = MatchShapeCode::kAssertEqualToImm; - rvalue = int_expr->value; - } else { - auto it = slot_map_.find(expr); - ICHECK(it != slot_map_.end()); - auto* slot = it->second; - if (slot->value_computed) { - code = MatchShapeCode::kAssertEqualToLoad; - rvalue = slot->index; - } else { - // the value is not yet computed - ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; - if (expr.as<tir::VarNode>()) { - // if it is a var, we will populate it in this round. 
- // otherwise, we skip and mark it as outstanding - code = MatchShapeCode::kStoreToHeap; - rvalue = slot->index; - slot->value_computed = true; - ready_vars_.push_back(slot); - } else { - code = MatchShapeCode::kNoOp; - rvalue = 0; - ++num_outstanding_exprs; - } - } - } + auto [code, rvalue] = MakeMatchArgs(expr, require_value_computed); all_nop = all_nop && code == MatchShapeCode::kNoOp; + any_nop = any_nop || code == MatchShapeCode::kNoOp; args.push_back(PrimValue::Int64(static_cast<int>(code))); - args.push_back(PrimValue::Int64(rvalue)); + args.push_back(rvalue); } - if (num_outstanding_exprs != 0) { + if (any_nop) { outstanding_todos.push_back(item); } args.push_back(GetErrContext(item.err_ctx)); if (!all_nop) { - Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_}); + Call call(match_op, args, Attrs(), {void_sinfo_}); builder_->Emit(call, "_"); } } @@ -592,8 +639,20 @@ class VMShapeLowerMutator void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, bool dynamic_only, const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { - // TODO(relax-team) add PrimValue checks later. 
- LOG(FATAL) << "MatchCast of PrimValue is not yet supported"; + // emit runtime check of shape + if (always_check || !IsBaseOf(PrimStructInfo(op->dtype), GetStructInfo(value))) { + // check_shape_info(value, ndim, err_ctx) + Call call(builtin_check_prim_value_info_, + {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + if (op->value.defined()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = {op->value.value()}; + item.err_ctx = err_ctx; + match_todos->push_back(item); + } } void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, @@ -729,6 +788,9 @@ class VMShapeLowerMutator const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"}; const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"}; + const ExternFunc builtin_match_prim_value_{"vm.builtin.match_prim_value"}; + const ExternFunc builtin_make_prim_value_{"vm.builtin.make_prim_value"}; + const ExternFunc builtin_check_prim_value_info_{"vm.builtin.check_prim_value_info"}; const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"}; const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"}; const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"}; diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index b7153ab50f..e47a399ae5 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -305,7 +305,9 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str DeviceAPI::Get(dev)->CopyDataFromTo(const_cast<DLTensor*>(from), to, stream); } -ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; } +ShapeTuple NDArray::Shape() const { + return static_cast<const NDArray::Container*>(data_.get())->shape_; +} runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); diff --git 
a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 8b27bb2d9e..a764c34cfa 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -66,6 +66,46 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); +/*! + * \brief Builtin match R.Prim function. + * + * \param input_value The runtime value provided by the user + * + * \param heap The VM storage for symbolic shapes + * + * \param code_value The op code, defined in MatchShapeCode, + * indicating how this value should be interpreted. + * + * \param reg The register, if using kStoreToHeap or + * kAssertEqualToLoad, or a literal value if using kAssertEqualToImm + * + * \param err_ctx An optional string used in error messages, providing + * additional context + * + * \sa MatchShape + */ +void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t reg, + Optional<String> err_ctx) { + int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data); + MatchShapeCode code = static_cast<MatchShapeCode>(code_value); + + if (code == MatchShapeCode::kAssertEqualToImm) { + CHECK_EQ(input_value, reg) << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " PrimValue mismatch to specified constant."; + } else if (code == MatchShapeCode::kStoreToHeap) { + heap_data[reg] = input_value; + } else if (code == MatchShapeCode::kNoOp) { + } else if (code == MatchShapeCode::kAssertEqualToLoad) { + CHECK_EQ(input_value, heap_data[reg]) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " PrimValue mismatch to a previous populated value."; + } else { + LOG(FATAL) << "Unknown match shape code: " << static_cast<int>(code); + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); + /*! * \brief Builtin match shape function. * \param args The packed function arguments. 
@@ -117,6 +157,30 @@ void MatchShape(TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape); +/*! + * \brief Builtin make prim value function. + * \param heap The shape heap to use + * \param shape_code The shape code of the value + * \param rv The return value. + * + * \sa MakeShape + */ +int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { + // NOTE: heap can be nullptr + int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data); + + MakeShapeCode code = static_cast<MakeShapeCode>(shape_code); + if (code == MakeShapeCode::kUseImm) { + return reg; + } else if (code == MakeShapeCode::kLoadShape) { + return heap_data[reg]; + } else { + LOG(FATAL) << "Invalid shape code: " << shape_code; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); + /*! * \brief Builtin make shape function. * \param args The packed function arguments. @@ -208,6 +272,30 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional<String> err_ctx) { TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); +/*! + * \brief Builtin function to check if arg is PrimValue(dtype) + * \param arg The input argument. + * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. + * \param err_ctx Additional context if error occurs. + */ +void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional<String> err_ctx) { + if (dtype.is_bool()) { + arg.operator bool(); + } else if (dtype.is_int()) { + arg.operator int64_t(); + } else if (dtype.is_uint()) { + arg.operator uint64_t(); + } else if (dtype.is_float()) { + arg.operator double(); + } else if (dtype.is_handle()) { + arg.operator void*(); + } else { + LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", unsupported dtype " << dtype; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); + /*! 
* \brief Builtin function to check if arg is Tuple with size elements. * \param arg The input argument. diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 82a6d6a2a4..b4816fd096 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -29,6 +29,7 @@ from tvm.contrib import utils, cc, popen_pool from tvm.relax.testing import nn from tvm.script import relax as R, tir as T, ir as I from tvm.relax.testing.vm import check_saved_func +from tvm.runtime import ShapeTuple EXEC_MODE = ["bytecode", "compiled"] @@ -515,6 +516,125 @@ def test_vm_relax_symbolic_shape(exec_mode): tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_symbolic_shape_tuple(exec_mode): + @I.ir_module + class mod: + @R.function + def main(shape: R.Shape(["m", "n"])): + m = T.int64() + n = T.int64() + return R.shape([2 * m, 3 * n]) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + func = vm["main"] + + assert func(ShapeTuple([2, 3])) == [4, 9] + + with pytest.raises(ValueError): + func(ShapeTuple([2, 3, 4])) + + with pytest.raises(TypeError): + func(R.prim_value(2)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_symbolic_prim_value(exec_mode): + @I.ir_module + class mod: + @R.function + def main(shape: R.Prim(value="n")): + n = T.int64() + return R.prim_value(n * n) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + func = vm["main"] + + assert func(2) == 4 + + with pytest.raises(tvm.TVMError): + func(ShapeTuple([2])) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_multiple_symbolic_prim_value(exec_mode): + """Like test_vm_relax_symbolic_prim_value, but with multiple variables""" 
+ + @I.ir_module + class mod: + @R.function + def main( + # Provides definition of "n" + _n: R.Prim(value="n"), + # Requires definitions of both "n" and "m", but cannot + # provide either. + _shape: R.Shape(["n*2", "m*2"]), + # Provides definition of "m" + _m: R.Prim(value="m"), + ): + n = T.int64() + m = T.int64() + return R.shape([n * n, m + 1]) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + func = vm["main"] + + assert func(2, ShapeTuple([4, 12]), 6) == [4, 7] + + with pytest.raises(RuntimeError): + func(2, ShapeTuple([4, 12]), 1) + + with pytest.raises(tvm.TVMError): + func(ShapeTuple([2])) + + +@pytest.mark.xfail(reason="Current support for R.Prim with known value is primarily for int64") +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_prim_value_fp32(exec_mode): + """A PrimValue may be R.prim('float32') + + Unlike shape tuples, which must contain int64, a PrimValue may be + any type that can be represented as a single primitive value. + """ + + @I.ir_module + class mod: + @R.function + def main( + # First failure occurs during parsing. The syntactic + # sugar for symbolic variables assumes that all symbolic + # variables are int64, rather than using the type that is + # later declared. + _x: R.Prim(value="half_fill_value"), + ): + half_fill_value = T.float32() + # Second failure occurs when calling `relax.op.full`. The + # `fill_value` is expected to be a scalar constant + # (R.Tensor with 0-dim shape), not a primitive value, even + # though these are semantically the same. + return R.full(shape=[16, 16], fill_value=R.prim_value(2 * half_fill_value)) + + target = tvm.target.Target("llvm", host="llvm") + # Third failure occurs here. The current codegen assumes that all + # symbolic variables are int64. 
+ ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + func = vm["main"] + + res = func(16.0).numpy() + assert np.all(res == 32.0) + + @pytest.mark.parametrize("exec_mode", EXEC_MODE) def test_vm_relax_dyn_tir_shape(exec_mode): # case where TIR variables are unbound in generated PrimFunc