This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7ef36ebb5d [Unity] Support symbolic PrimValue arguments (#15980)
7ef36ebb5d is described below

commit 7ef36ebb5d056320676faede712f2052d92f7a5d
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Wed Oct 25 15:50:35 2023 -0500

    [Unity] Support symbolic PrimValue arguments (#15980)
    
    Prior to this commit, all symbolic variables needed to be defined
    either by tensor shapes, or by an explicit `tvm.runtime.ShapeTuple`
    argument.  This commit allows arguments `arg: R.Prim(value="n")` to
    serve as a source of definition for symbolic variables.
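    
    As a brief illustration (a minimal sketch adapted from the new
    test_vm_relax_symbolic_prim_value test added by this commit; the
    module and variable names here are illustrative), a function taking
    `R.Prim(value="n")` can now be built and called with a plain integer:
    
        @I.ir_module
        class Module:
            @R.function
            def main(x: R.Prim(value="n")):
                n = T.int64()
                return R.prim_value(n * n)
    
        ex = relax.build(Module, tvm.target.Target("llvm", host="llvm"))
        vm = relax.VirtualMachine(ex, tvm.cpu())
        assert vm["main"](2) == 4  # "n" is defined by the PrimValue argument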
---
 src/relax/backend/vm/codegen_vm.cc     |   7 +-
 src/relax/backend/vm/vm_shape_lower.cc | 158 +++++++++++++++++++++++----------
 src/runtime/ndarray.cc                 |   4 +-
 src/runtime/relax_vm/builtin.cc        |  88 ++++++++++++++++++
 tests/python/relax/test_vm_build.py    | 120 +++++++++++++++++++++++++
 5 files changed, 325 insertions(+), 52 deletions(-)

diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index caee0a0c13..64b87c6c12 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -246,10 +246,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
     if (auto* int_imm = op->value.as<IntImmNode>()) {
       return builder_->ConvertConstant(int_imm->value);
-    } else {
-      auto* float_imm = op->value.as<FloatImmNode>();
-      ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now";
+    } else if (auto* float_imm = op->value.as<FloatImmNode>()) {
       return builder_->ConvertConstant(float_imm->value);
+    } else {
+      LOG(FATAL) << "PrimValue should only contain constant after VMShapeLower, "
+                 << "but received " << GetRef<Expr>(op) << " with type " << op->value->GetTypeKey();
     }
   }
 
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 8b8eb33f5b..41b27ea625 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -347,6 +347,41 @@ class VMShapeLowerMutator
     return GetRef<Expr>(op);
   }
 
+  std::pair<Expr, Expr> MakeSymbolicShapeArg(const PrimExpr& expr) {
+    using runtime::relax_vm::MakeShapeCode;
+
+    if (auto* int_expr = expr.as<IntImmNode>()) {
+      return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)),
+              PrimValue::Int64(int_expr->value)};
+    } else {
+      auto it = slot_map_.find(expr);
+      ICHECK(it != slot_map_.end());
+      auto* slot = it->second;
+      ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed";
+      return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)),
+              PrimValue::Int64(slot->index)};
+    }
+  }
+
+  Expr VisitExpr_(const PrimValueNode* op) final {
+    using runtime::relax_vm::MakeShapeCode;
+    // Constant shape can be preserved.
+    bool is_const_value =
+        op->value->IsInstance<IntImmNode>() || op->value->IsInstance<FloatImmNode>();
+    if (is_const_value) {
+      return GetRef<Expr>(op);
+    }
+
+    Array<Expr> args = {shape_heap_};
+    auto [code, value_or_index] = MakeSymbolicShapeArg(op->value);
+    args.push_back(code);
+    args.push_back(value_or_index);
+
+    // make_prim_value(heap, code, value_or_index)
+    Call call(builtin_make_prim_value_, args, Attrs(), {Downcast<StructInfo>(op->struct_info_)});
+    return call;
+  }
+
   Expr VisitExpr_(const ShapeExprNode* op) final {
     using runtime::relax_vm::MakeShapeCode;
     // Constant shape can be preserved.
@@ -359,17 +394,9 @@ class VMShapeLowerMutator
 
     Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
     for (PrimExpr expr : op->values) {
-      if (auto* int_expr = expr.as<IntImmNode>()) {
-        args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)));
-        args.push_back(PrimValue::Int64(int_expr->value));
-      } else {
-        auto it = slot_map_.find(expr);
-        ICHECK(it != slot_map_.end());
-        auto* slot = it->second;
-        ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed";
-        args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)));
-        args.push_back(PrimValue::Int64(slot->index));
-      }
+      auto [code, value_or_index] = MakeSymbolicShapeArg(expr);
+      args.push_back(code);
+      args.push_back(value_or_index);
     }
 
     // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
@@ -402,6 +429,45 @@ class VMShapeLowerMutator
   // Place this pass as last pass before codegen.
   StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; }
 
+  /* \brief Internal utility function used for RunMatch()
+   *
+   * \param expr The expression to be matched
+   *
+   * \param require_value_computed Whether we require all expr to be computed.
+   *
+   * \return The MatchShapeCode, and a relax expression specifying the
+   *    argument used by that MatchShapeCode.
+   */
+  std::pair<runtime::relax_vm::MatchShapeCode, Expr> MakeMatchArgs(const PrimExpr& expr,
+                                                                   bool require_value_computed) {
+    using runtime::relax_vm::MatchShapeCode;
+
+    if (auto* int_expr = expr.as<IntImmNode>()) {
+      return {MatchShapeCode::kAssertEqualToImm, PrimValue::Int64(int_expr->value)};
+    }
+
+    auto it = slot_map_.find(expr);
+    ICHECK(it != slot_map_.end());
+    auto* slot = it->second;
+    if (slot->value_computed) {
+      return {MatchShapeCode::kAssertEqualToLoad, PrimValue::Int64(slot->index)};
+    }
+
+    // the value is not yet computed
+    ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed";
+    if (expr.as<tir::VarNode>()) {
+      // If it is a var, we will populate it in this round.
+
+      slot->value_computed = true;
+      ready_vars_.push_back(slot);
+
+      return {MatchShapeCode::kStoreToHeap, PrimValue::Int64(slot->index)};
+    }
+
+    // otherwise, we skip and mark it as outstanding
+    return {MatchShapeCode::kNoOp, PrimValue::Int64(0)};
+  }
+
   //-------------------------------------------------------
   // Shape computations.
   //-------------------------------------------------------
@@ -426,52 +492,33 @@ class VMShapeLowerMutator
 
     using runtime::relax_vm::MatchShapeCode;
     for (const MatchShapeTodoItem& item : match_todos) {
-      int64_t shape_len = static_cast<int64_t>(item.pattern.size());
       bool all_nop = true;
-      int num_outstanding_exprs = 0;
+      bool any_nop = false;
 
-      Array<Expr> args = {item.input, shape_heap_, PrimValue::Int64(shape_len)};
+      Array<Expr> args = {item.input, shape_heap_};
+
+      Expr match_op;
+      if (item.input->struct_info_.as<PrimStructInfoNode>()) {
+        match_op = builtin_match_prim_value_;
+        ICHECK_EQ(item.pattern.size(), 1);
+      } else {
+        match_op = builtin_match_shape_;
+        args.push_back(PrimValue::Int64(item.pattern.size()));
+      }
 
       for (PrimExpr expr : item.pattern) {
-        MatchShapeCode code = MatchShapeCode::kNoOp;
-        int64_t rvalue = 0;
-        if (auto* int_expr = expr.as<IntImmNode>()) {
-          code = MatchShapeCode::kAssertEqualToImm;
-          rvalue = int_expr->value;
-        } else {
-          auto it = slot_map_.find(expr);
-          ICHECK(it != slot_map_.end());
-          auto* slot = it->second;
-          if (slot->value_computed) {
-            code = MatchShapeCode::kAssertEqualToLoad;
-            rvalue = slot->index;
-          } else {
-            // the value is not yet computed
-            ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed";
-            if (expr.as<tir::VarNode>()) {
-              // if it is a var, we will populate it in this round.
-              // otherwise, we skip and mark it as outstanding
-              code = MatchShapeCode::kStoreToHeap;
-              rvalue = slot->index;
-              slot->value_computed = true;
-              ready_vars_.push_back(slot);
-            } else {
-              code = MatchShapeCode::kNoOp;
-              rvalue = 0;
-              ++num_outstanding_exprs;
-            }
-          }
-        }
+        auto [code, rvalue] = MakeMatchArgs(expr, require_value_computed);
         all_nop = all_nop && code == MatchShapeCode::kNoOp;
+        any_nop = any_nop || code == MatchShapeCode::kNoOp;
         args.push_back(PrimValue::Int64(static_cast<int>(code)));
-        args.push_back(PrimValue::Int64(rvalue));
+        args.push_back(rvalue);
       }
-      if (num_outstanding_exprs != 0) {
+      if (any_nop) {
         outstanding_todos.push_back(item);
       }
       args.push_back(GetErrContext(item.err_ctx));
       if (!all_nop) {
-        Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_});
+        Call call(match_op, args, Attrs(), {void_sinfo_});
         builder_->Emit(call, "_");
       }
     }
@@ -592,8 +639,20 @@ class VMShapeLowerMutator
   void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check,
                         bool dynamic_only, const String& err_ctx,
                         std::vector<MatchShapeTodoItem>* match_todos) final {
-    // TODO(relax-team) add PrimValue checks later.
-    LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
+    // emit runtime check of the PrimValue's dtype
+    if (always_check || !IsBaseOf(PrimStructInfo(op->dtype), GetStructInfo(value))) {
+      // check_prim_value_info(value, dtype, err_ctx)
+      Call call(builtin_check_prim_value_info_,
+                {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, Attrs(), {void_sinfo_});
+      builder_->Emit(call, "_");
+    }
+    if (op->value.defined()) {
+      MatchShapeTodoItem item;
+      item.input = value;
+      item.pattern = {op->value.value()};
+      item.err_ctx = err_ctx;
+      match_todos->push_back(item);
+    }
   }
 
   void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check,
@@ -729,6 +788,9 @@ class VMShapeLowerMutator
   const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"};
   const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"};
   const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"};
+  const ExternFunc builtin_match_prim_value_{"vm.builtin.match_prim_value"};
+  const ExternFunc builtin_make_prim_value_{"vm.builtin.make_prim_value"};
+  const ExternFunc builtin_check_prim_value_info_{"vm.builtin.check_prim_value_info"};
   const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"};
   const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"};
   const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"};
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index b7153ab50f..e47a399ae5 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -305,7 +305,9 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
   DeviceAPI::Get(dev)->CopyDataFromTo(const_cast<DLTensor*>(from), to, stream);
 }
 
-ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; }
+ShapeTuple NDArray::Shape() const {
+  return static_cast<const NDArray::Container*>(data_.get())->shape_;
+}
 
 runtime::DataType NDArray::DataType() const {
   return runtime::DataType(get_mutable()->dl_tensor.dtype);
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 8b27bb2d9e..a764c34cfa 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -66,6 +66,46 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap);
 
+/*!
+ * \brief Builtin match R.Prim function.
+ *
+ * \param input_value The runtime value provided by the user
+ *
+ * \param heap The VM storage for symbolic shapes
+ *
+ * \param code_value The op code, defined in MatchShapeCode,
+ *     indicating how this value should be interpreted.
+ *
+ * \param reg The register, if using kStoreToHeap or
+ *     kAssertEqualToLoad, or a literal value if using kAssertEqualToImm
+ *
+ * \param err_ctx An optional string used in error messages, providing
+ *     additional context
+ *
+ * \sa MatchShape
+ */
+void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t reg,
+                    Optional<String> err_ctx) {
+  int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data);
+  MatchShapeCode code = static_cast<MatchShapeCode>(code_value);
+
+  if (code == MatchShapeCode::kAssertEqualToImm) {
+    CHECK_EQ(input_value, reg) << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, "
+                               << " PrimValue mismatch to specified constant.";
+  } else if (code == MatchShapeCode::kStoreToHeap) {
+    heap_data[reg] = input_value;
+  } else if (code == MatchShapeCode::kNoOp) {
+  } else if (code == MatchShapeCode::kAssertEqualToLoad) {
+    CHECK_EQ(input_value, heap_data[reg])
+        << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, "
+        << " PrimValue mismatch to a previous populated value.";
+  } else {
+    LOG(FATAL) << "Unknown match shape code: " << static_cast<int>(code);
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue);
+
 /*!
  * \brief Builtin match shape function.
  * \param args The packed function arguments.
@@ -117,6 +157,30 @@ void MatchShape(TVMArgs args, TVMRetValue* rv) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape);
 
+/*!
+ * \brief Builtin make prim value function.
+ * \param heap The shape heap to use
+ * \param shape_code The shape code of the value
+ * \param reg The heap index to load when using kLoadShape, or a literal value when using kUseImm
+ * \return The integer value of the PrimValue
+ *
+ * \sa MakeShape
+ */
+int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) {
+  // NOTE: heap can be nullptr
+  int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data);
+
+  MakeShapeCode code = static_cast<MakeShapeCode>(shape_code);
+  if (code == MakeShapeCode::kUseImm) {
+    return reg;
+  } else if (code == MakeShapeCode::kLoadShape) {
+    return heap_data[reg];
+  } else {
+    LOG(FATAL) << "Invalid shape code: " << shape_code;
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue);
+
 /*!
  * \brief Builtin make shape function.
  * \param args The packed function arguments.
@@ -208,6 +272,30 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional<String> err_ctx) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo);
 
+/*!
+ * \brief Builtin function to check if arg is PrimValue(dtype)
+ * \param arg The input argument.
+ * \param dtype Expected dtype of the PrimValue.  Can be DataType::Void() for unknown dtype.
+ * \param err_ctx Additional context if error occurs.
+ */
+void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional<String> err_ctx) {
+  if (dtype.is_bool()) {
+    arg.operator bool();
+  } else if (dtype.is_int()) {
+    arg.operator int64_t();
+  } else if (dtype.is_uint()) {
+    arg.operator uint64_t();
+  } else if (dtype.is_float()) {
+    arg.operator double();
+  } else if (dtype.is_handle()) {
+    arg.operator void*();
+  } else {
+    LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", unsupported 
dtype " << dtype;
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo);
+
 /*!
  * \brief Builtin function to check if arg is Tuple with size elements.
  * \param arg The input argument.
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 82a6d6a2a4..b4816fd096 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -29,6 +29,7 @@ from tvm.contrib import utils, cc, popen_pool
 from tvm.relax.testing import nn
 from tvm.script import relax as R, tir as T, ir as I
 from tvm.relax.testing.vm import check_saved_func
+from tvm.runtime import ShapeTuple
 
 EXEC_MODE = ["bytecode", "compiled"]
 
@@ -515,6 +516,125 @@ def test_vm_relax_symbolic_shape(exec_mode):
     tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7)
 
 
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_shape_tuple(exec_mode):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(shape: R.Shape(["m", "n"])):
+            m = T.int64()
+            n = T.int64()
+            return R.shape([2 * m, 3 * n])
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(ShapeTuple([2, 3])) == [4, 9]
+
+    with pytest.raises(ValueError):
+        func(ShapeTuple([2, 3, 4]))
+
+    with pytest.raises(TypeError):
+        func(R.prim_value(2))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_prim_value(exec_mode):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(shape: R.Prim(value="n")):
+            n = T.int64()
+            return R.prim_value(n * n)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(2) == 4
+
+    with pytest.raises(tvm.TVMError):
+        func(ShapeTuple([2]))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
+    """Like test_vm_relax_symbolic_prim_value, but with multiple variables"""
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(
+            # Provides definition of "n"
+            _n: R.Prim(value="n"),
+            # Requires definitions of both "n" and "m", but cannot
+            # provide either.
+            _shape: R.Shape(["n*2", "m*2"]),
+            # Provides definition of "m"
+            _m: R.Prim(value="m"),
+        ):
+            n = T.int64()
+            m = T.int64()
+            return R.shape([n * n, m + 1])
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(2, ShapeTuple([4, 12]), 6) == [4, 7]
+
+    with pytest.raises(RuntimeError):
+        func(2, ShapeTuple([4, 12]), 1)
+
+    with pytest.raises(tvm.TVMError):
+        func(ShapeTuple([2]))
+
+
+@pytest.mark.xfail(reason="Current support for R.Prim with known value is primarily for int64")
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_prim_value_fp32(exec_mode):
+    """A PrimValue may be R.prim('float32')
+
+    Unlike shape tuples, which must contain int64, a PrimValue may be
+    any type that can be represented as a single primitive value.
+    """
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(
+            # First failure occurs during parsing.  The syntactic
+            # sugar for symbolic variables assumes that all symbolic
+            # variables are int64, rather than using the type that is
+            # later declared.
+            _x: R.Prim(value="half_fill_value"),
+        ):
+            half_fill_value = T.float32()
+            # Second failure occurs when calling `relax.op.full`.  The
+            # `fill_value` is expected to be a scalar constant
+            # (R.Tensor with 0-dim shape), not a primitive value, even
+            # though these are semantically the same.
+            return R.full(shape=[16, 16], fill_value=R.prim_value(2 * half_fill_value))
+
+    target = tvm.target.Target("llvm", host="llvm")
+    # Third failure occurs here.  The current codegen assumes that all
+    # symbolic variables are int64.
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    res = func(16.0).numpy()
+    assert np.all(res == 32.0)
+
+
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_vm_relax_dyn_tir_shape(exec_mode):
     # case where TIR variables are unbound in generated PrimFunc
