This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cd08356e66 [TIR] Fix segfaults from ordering of Let/Assert in 
MakePackedAPI (#16543)
cd08356e66 is described below

commit cd08356e66951ec6eceb9dbd7ea21289a350eae8
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Thu Apr 4 18:29:45 2024 -0500

    [TIR] Fix segfaults from ordering of Let/Assert in MakePackedAPI (#16543)
    
    * [TIR] Fix segfaults from ordering of Let/Assert in MakePackedAPI
    
    Prior to this commit, the `MakePackedAPI` pass would output steps in
    the following order:
    
    1. Check the number of arguments.
    2. All `LetStmt` produced by the `ArgBinder`
    3. `AssertStmt` for the Type code checks for each argument.
    4. Additional `AssertStmt` produced by the `ArgBinder`.
    
    This order can cause segfaults if a function was provided incorrect
    arguments.  For example, an integer argument passed to a function
    expecting a `DLTensor*` would be dereferenced to find the tensor's
    data pointer (step (2)) before checking if it is valid to perform that
    dereference (step (3)).  The same would occur when reading the size of
    a tensor's axes (step (2)) before checking whether the tensor is the
    correct dimensionality (step (4)).
    
    This commit updates the steps to the following order.
    
    1. Check the number of arguments.
    2. Check the type code of each argument.
    3. All `LetStmt` and `AssertStmt` produced by the `ArgBinder`, in the
       order in which they are generated.
    
    * Remove unrelated change
    
    * skip flaky test
---
 src/tir/transforms/arg_binder.cc                   | 46 ++++++++++----
 src/tir/transforms/arg_binder.h                    | 38 ++++++++++--
 src/tir/transforms/make_packed_api.cc              | 58 ++++++++++++------
 tests/python/tir-base/test_debug_info.py           |  4 +-
 .../test_tir_transform_make_packed_api.py          | 71 +++++++++++++++++++++-
 5 files changed, 179 insertions(+), 38 deletions(-)

diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index f3d799365d..5b9e005b7e 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -155,6 +155,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
   const Stmt nop = Evaluate(0);
+
+  init_nest_.emplace_back(AssertStmt(
+      !Call(DataType::Bool(), builtin::isnullptr(), {handle}),
+      tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* 
pointer"), nop));
+
   // dimension checks
   PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
 
@@ -173,7 +178,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   std::ostringstream ndim_err_msg;
   ndim_err_msg << arg_name << ".ndim is expected to equal " << 
buffer->shape.size();
   auto msg = tvm::tir::StringImm(ndim_err_msg.str());
-  asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+  init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
   // type checks
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
@@ -186,18 +191,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) 
||
         buffer->dtype == DataType::UInt(4))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
-    asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
     asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
   }
-  // data field
-  if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, 
builtin::kArrData),
-            arg_name + ".data", true)) {
-    Var vptr(buffer->data);
-    def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
-    // mark alignment of external bufs
-    init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
-                                     IntImm(DataType::Int(32), 
buffer->data_alignment), nop));
-  }
 
   // shape field
   Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), 
buffer->shape.size())}, tvm_shape_type,
@@ -243,7 +238,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
           foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, 
b, span); },
                 const_true(1), conds),
           stride_msg, Evaluate(0));
-      check = IfThenElse(Not(v_strides_is_null), check, Stmt());
+      check = IfThenElse(Not(v_strides_is_null), check);
       asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
     }
   } else if (buffer->buffer_type == kAutoBroadcast) {
@@ -300,6 +295,33 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
         arg_name + ".device_type", true);
   Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, 
builtin::kArrDeviceId),
         arg_name + ".device_id", true);
+
+  // Data field.  Because the validation of the data field may depend
+  // on a dynamic size defined by the other DLTensor* parameters, this
+  // field must be generated last.
+  if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, 
builtin::kArrData),
+            arg_name + ".data", true)) {
+    Var vptr(buffer->data);
+
+    // Check if the data pointer is NULL.  This check is skipped for
+    // size-0 arrays, since CUDA provides a NULL pointer for size-zero
+    // allocations.
+    auto alloc_size = [&]() -> PrimExpr {
+      PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
+      for (const auto& dim : buffer->shape) {
+        product *= dim;
+      }
+      return product;
+    }();
+    asserts_.emplace_back(AssertStmt(
+        alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), 
{vptr}),
+        tvm::tir::StringImm(arg_name + " is expected to have non-NULL data 
pointer"), nop));
+
+    def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
+    // mark alignment of external bufs
+    init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
+                                     IntImm(DataType::Int(32), 
buffer->data_alignment), nop));
+  }
 }
 
 }  // namespace tir
diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h
index 657ebdbec1..68cbbb6773 100644
--- a/src/tir/transforms/arg_binder.h
+++ b/src/tir/transforms/arg_binder.h
@@ -104,17 +104,43 @@ class ArgBinder {
 
   /*! \return The defs generated in binding. */
   const std::vector<Var>& defs() const { return defs_; }
-  /*! \return The asserts generated in binding */
+
+  /*! \return The asserts generated in binding
+   *
+   * This contains statements that assert the correct value has been
+   * bound.  For example, `binder.Bind(var, expr_1)` will produce an
+   * entry mapping `var` to `expr_1` in the `binder.defs()`.  If
+   * `binder.Bind(var, expr_2)` is called later, then this will
+   * produce an assert statement that `expr_1 == expr_2`.
+   *
+   * Note: Some assert statements produced by BindDLTensor are located
+   * in `binder.init_nest()`, not within `binder.asserts()`.  This is
+   * deliberate, as some values may require checks prior to
+   * initialization.  (e.g. Initializing `m = dl_tensor->shape[3]`
+   * requires first asserting that `3 < dl_tensor->ndim`.)
+   */
   const std::vector<Stmt>& asserts() const { return asserts_; }
+
   /*!
    * \brief Initialization nest generated
-   *  This is only non-empty when BindDLTensor is called.
    *
-   * \note The binder may choose to generate a let statement
-   *  and simply put def_map to map Variable to itself,
-   *  or update def_map to directly map to new value and not generate let 
statement.
+   * This contains both variable bindings and any assert statements
+   * that are required in order to safely produce those variable
+   * bindings.
+   *
+   * \note Variable bindings may be implemented either as a `LetStmt`
+   *     that defines the variable, or as a variable replacement.  Any
+   *     bindings implemented as a `LetStmt` will be in the
+   *     initialization list.  Any bindings implemented as a variable
+   *     replacement will be stored in the `var_def` map.
+   *
+   *     A `tir::LetStmt` is usually generated when binding to a
+   *     `DLTensor`.  This requires loading values from memory, which
+   *     should only be performed once.  If the binding to a
+   *     `DLTensor` were implemented as a variable replacement, it
+   *     would load values from memory once for each usage of the
+   *     variable.
    *
-   *  Let statement is usually generated when bind to DLTensor and memory load 
is involved.
    * \return The initialization nest generated during binding.
    */
   const std::vector<Stmt>& init_nest() const { return init_nest_; }
diff --git a/src/tir/transforms/make_packed_api.cc 
b/src/tir/transforms/make_packed_api.cc
index 94e245b636..bf1f3a9e7f 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -183,6 +183,11 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, 
std::string msg) {
   return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
 }
 
+inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
+  Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
+  return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
+}
+
 /* \brief Return the global_symbol of the function, if it should be updated
  *
  * \param func The function to be inspected
@@ -255,8 +260,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   std::unordered_map<const VarNode*, PrimExpr> vmap;
   ArgBinder binder(&vmap);
 
-  seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));
-
   // ---------------------------
   // local function definitions
   // load i-th argument as type t
@@ -273,6 +276,33 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     return res;
   };
 
+  // Find the device API context argument based on name
+  for (const auto& param : func_ptr->params) {
+    if (param->name_hint == kDeviceContextVar) {
+      num_args--;
+      v_resource_handle = param;
+      break;
+    }
+  }
+
+  // Assert correct type codes for each argument.  This must be done
+  // *before* any initialization steps produced by
+  // `binder.BindDLTensor()`.  The validity of those initialization
+  // steps depends on the correct types being present, and must not
+  // occur before the type codes are actually checked.
+  seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args, [&]() -> 
std::string {
+    std::ostringstream error_message;
+    error_message << name_hint << ": num_args should be " << num_args;
+    return error_message.str();
+  }()));
+
+  seq_init.push_back(
+      MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer 
was NULL"));
+  seq_init.push_back(
+      MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* 
type_codes was NULL"));
+
+  seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));
+
   // Need to delay binding of the buffers, in case some arguments also
   // appear in the buffer.
   std::vector<std::pair<PrimExpr, Var>> var_def;
@@ -281,10 +311,9 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
     Var param = func_ptr->params[i];
 
-    // Pluck the device API context out based on name
+    // Ignore the device context argument, as it will still be passed
+    // as a native argument.
     if (param->name_hint == kDeviceContextVar) {
-      num_args--;
-      v_resource_handle = param;
       continue;
     }
 
@@ -301,18 +330,18 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     if (t.is_handle()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be pointer";
-      seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == 
kTVMNDArrayHandle ||
-                                            tcode == kTVMDLTensorHandle || 
tcode == kTVMNullptr,
-                                        tvm::tir::StringImm(msg.str()), nop));
+      seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == 
kTVMNDArrayHandle ||
+                                           tcode == kTVMDLTensorHandle || 
tcode == kTVMNullptr,
+                                       tvm::tir::StringImm(msg.str()), nop));
     } else if (t.is_int() || t.is_uint()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be int";
-      seq_check.emplace_back(AssertStmt(tcode == kDLInt, 
tvm::tir::StringImm(msg.str()), nop));
+      seq_init.emplace_back(AssertStmt(tcode == kDLInt, 
tvm::tir::StringImm(msg.str()), nop));
     } else {
       ICHECK(t.is_float());
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be float";
-      seq_check.emplace_back(AssertStmt(tcode == kDLFloat, 
tvm::tir::StringImm(msg.str()), nop));
+      seq_init.emplace_back(AssertStmt(tcode == kDLFloat, 
tvm::tir::StringImm(msg.str()), nop));
     }
   }
 
@@ -360,13 +389,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   // Return error code of zero on success
   body = SeqStmt({body, Evaluate(ret(Integer(0)))});
 
-  // Apply all argument assertions
-  std::ostringstream num_args_error;
-  num_args_error << name_hint << ": num_args should be " << num_args;
-  std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, 
num_args_error.str())};
-  body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, 
binder.asserts(),
-                    arg_buffer_declarations},
-                   body);
+  body = MergeNest(
+      {seq_init, binder.init_nest(), seq_check, binder.asserts(), 
arg_buffer_declarations}, body);
   func_ptr->body = body;
   func_ptr->params = args;
 
diff --git a/tests/python/tir-base/test_debug_info.py 
b/tests/python/tir-base/test_debug_info.py
index 7fc9bcf316..ecd25b3a67 100644
--- a/tests/python/tir-base/test_debug_info.py
+++ b/tests/python/tir-base/test_debug_info.py
@@ -141,7 +141,7 @@ def test_llvm_ir_debug_info():
     source = runtime_module.get_source()
 
     locations = find_di_locations(source)
-    assert len(locations) == 35
+    assert len(locations) == 41
 
 
 def test_llvm_ir_debug_accuracy():
@@ -162,7 +162,7 @@ def test_llvm_ir_debug_accuracy():
 
     # Check that it matches the expected line number (in main.tir)
     debug_line_no = int(locations[directive_idx])
-    assert debug_line_no == 56
+    assert debug_line_no == 60
 
 
 def test_building_without_llvm_equivalent():
diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py 
b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
index 2f871a246f..bf182654d7 100644
--- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py
+++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
@@ -284,5 +284,74 @@ def 
test_subroutine_call_to_externally_visible_subroutine():
     )
 
 
+def test_function_call_with_wrong_argument_count():
+    """Argument counts must be checked before accessing the type codes"""
+
+    @T.prim_func
+    def func(
+        A: T.Buffer([16, 16], "int32"),
+        B: T.Buffer([16, 16], "int32"),
+        C: T.Buffer([16, 16], "int32"),
+        D: T.Buffer([16, 16], "int32"),
+    ):
+        pass
+
+    built = tvm.build(func, target="llvm")
+
+    with pytest.raises(tvm.TVMError):
+        built()
+
+
+def test_function_call_with_wrong_type_code():
+    """Type codes must be checked before accessing the arguments"""
+
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "int32")):
+        pass
+
+    built = tvm.build(func, target="llvm")
+
+    with pytest.raises(tvm.TVMError):
+        built(0)
+
+
+def test_function_call_with_null_data_pointer():
+    """The data pointer must be checked before accessing the array"""
+
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
+        for i, j in T.grid(16, 16):
+            B[i, j] = A[i, j]
+
+    built = tvm.build(func, target="llvm")
+
+    A = tvm.nd.empty([16, 16], "int32", tvm.cpu())
+    B = tvm.nd.empty([16, 16], "int32", tvm.cpu())
+
+    A.handle.contents.data = 0
+
+    with pytest.raises(tvm.TVMError):
+        built(A, B)
+
+
+def test_function_call_with_wrong_dimensionality():
+    """The dimensionality must be checked before validating the shape"""
+
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
+        for i, j in T.grid(16, 16):
+            B[i, j] = A[i, j]
+
+    built = tvm.build(func, target="llvm")
+
+    A = tvm.nd.empty([16], "int32", tvm.cpu())
+    B = tvm.nd.empty([16], "int32", tvm.cpu())
+
+    A.handle.contents.data = 0
+
+    with pytest.raises(tvm.TVMError):
+        built(A, B)
+
+
 if __name__ == "__main__":
-    test_makeapi()
+    tvm.testing.main()

Reply via email to