mbrookhart commented on a change in pull request #6839:
URL: https://github.com/apache/tvm/pull/6839#discussion_r537873092



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -54,64 +54,66 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def rearrange_indices_out_ir(data, out, valid_box_count):
+def rearrange_indices_out_ir(data, output, valid_box_count):
     """Hybrid routine to rearrange nms output to
     move all valid entries to top.
 
     Parameters
     ----------
     data : tvm.te.Tensor or numpy NDArray
+        NMS output. 3-D tensor with shape
+        [batch_size, num_anchors, 6] or
+        [batch_size, num_anchors, 5], or 2-D
         tensor with shape [batch_size, num_anchors].
 
+    one: tvm.tir.const
+        Constant one with the same dtype as data.
+
+    batch_size: tvm.tir.IntImm or tvm.tir.Var
+        Batch size. We need to pass it in since hybrid script doesn't support
+        binding variable to symbolic dim.
+
+    num_anchors: tvm.tir.IntImm or tvm.tir.Var
+        Number of anchors.
 
     Returns
     -------
-    stmt : Stmt
-        The result IR statement.
+    output : tvm.te.Tensor or numpy NDArray
+        2-D tensor with shape [batch_size, num_anchors].
+
+    valid_box_count : tvm.te.Tensor or numpy NDArray
+        Tensor with shape [batch_size, 1], indicates
+        the valid number of boxes.
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
 
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
-    out = ib.buffer_ptr(out)
     valid_box_count = ib.buffer_ptr(valid_box_count)
-
-    one_count = tvm.tir.const(1, dtype="int32")
-    atomic_add_return = ib.allocate(
-        valid_box_count.dtype, (batch_size,), name="atomic_add_return", scope="local"
-    )
-
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    nthread_tx = max_threads
-    tx = te.thread_axis("threadIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    len_inner_for = (batch_size * num_anchors) // nthread_tx + 2
-
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    with ib.for_range(0, len_inner_for, name="i") as i:
-        idx = tx * len_inner_for + i
-        batch_idx = idxd(idx, num_anchors)
-        with ib.if_scope(idx < batch_size):
-            valid_box_count[idx] = 0
-        with ib.if_scope(idx < batch_size * num_anchors):
-            with ib.if_scope(data[idx] >= 0):
-                atomic_add_return[batch_idx] = atomic_add(
-                    tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]),
-                    one_count,
-                )
-                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx]
-            with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < -num_anchors)):
-                atomic_add_return[batch_idx] = atomic_add(
-                    tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]),
-                    one_count,
-                )
-                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0
-
-            with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]):
-                out[idx] = -1

Review comment:
       This implementation of rearrange_indices_out_ir returns an undersized
tensor in some cases. I think the threading isn't quite right, but I haven't
been able to fix it.
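
For reference, a minimal NumPy sketch of the behavior I would expect (my assumption
from the docstring, not code from this PR): valid entries compacted to the front of
each row, the remainder padded with -1, plus a per-batch count of valid boxes.

    import numpy as np

    def rearrange_indices_out_ref(data):
        # Hypothetical reference: compact the non-negative entries of each
        # row to the front, pad the rest with -1, and count valid boxes.
        batch_size, num_anchors = data.shape
        output = np.full((batch_size, num_anchors), -1, dtype=data.dtype)
        valid_box_count = np.zeros((batch_size, 1), dtype="int32")
        for i in range(batch_size):
            valid = data[i][data[i] >= 0]
            output[i, : valid.size] = valid
            valid_box_count[i, 0] = valid.size
        return output, valid_box_count

A parallel kernel has to reproduce this row-wise compaction exactly; if threads race
on the per-row write offset, entries get dropped, which would show up as the
undersized result described above.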

##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -97,47 +97,44 @@ def get_valid_counts_ir(
     valid_count = ib.buffer_ptr(valid_count)
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
-    atomic_add_return = ib.allocate(
-        valid_count.dtype, (1,), name="atomic_add_return", scope="local"
-    )
-    one_count = tvm.tir.const(1, dtype=valid_count.dtype)
     one = tvm.tir.const(1, dtype=out.dtype)
-    score_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=score_threshold)
+    if isinstance(score_threshold, float):
+        score_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=score_threshold)
     id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.ir.make_node("IntImm", dtype="int32", value=score_index)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    nthread_tx = max_threads
-    nthread_bx = batch_size * num_anchors // max_threads + 1
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * max_threads + tx
-    idxd = tvm.tir.indexdiv
-
-    # initialize valid_count
-    with ib.if_scope(tid < batch_size):
-        valid_count[tid] = 0
-    with ib.if_scope(tid < batch_size * num_anchors):
-        i = idxd(tid, num_anchors)
-        with ib.if_scope(
-            tvm.tir.all(
-                data[tid * elem_length + score_index] > score_threshold,
-                tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0),
-            )
-        ):
-            atomic_add_return[0] = atomic_add(
-                tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count
-            )
-            with ib.for_range(0, elem_length) as k:
-                out[tid * elem_length + k] = data[tid * elem_length + k]
-                out_indices[tid + k] = tid + k
-        with ib.else_scope():
-            with ib.for_range(0, elem_length) as k:
-                out[tid * elem_length + k] = -one
-                out_indices[tid + k] = -one_count
-
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = batch_size // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size):
+            valid_count[tid] = 0
+            i = tid
+            with ib.for_range(0, num_anchors) as j:
+                score = data[(i * num_anchors + j) * elem_length + score_index]
+                with ib.if_scope(
+                    tvm.tir.all(
+                        score > score_threshold,
+                        tvm.tir.any(
+                            id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
+                        ),
+                    )
+                ):
+                    with ib.for_range(0, elem_length) as k:
+                        out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[
+                            (i * num_anchors + j) * elem_length + k
+                        ]
+                    out_indices[i * num_anchors + valid_count[i]] = j
+                    valid_count[i] += 1

Review comment:
       I do not have a race condition here because I'm not parallelizing over 
that axis. The issue with the current kernel is that it cannot implement the 
ONNX API because it cannot properly sort the output indices, and ONNX needs 
those outputs in the correct order to do some post-processing.
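
To make that concrete, here is a NumPy sketch of the sequential per-row scan this
kernel performs (a hypothetical reference, not the PR code itself): one GPU thread
owns one batch row, the inner anchor loop runs in order, so surviving boxes keep
their original anchor order and no atomic counter is needed.

    import numpy as np

    def get_valid_counts_ref(data, score_threshold, id_index, score_index):
        # data: [batch_size, num_anchors, elem_length]
        batch_size, num_anchors, _ = data.shape
        out = np.full_like(data, -1.0)
        out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
        valid_count = np.zeros(batch_size, dtype="int32")
        for i in range(batch_size):       # the only axis parallelized on GPU
            for j in range(num_anchors):  # sequential scan: preserves order
                keep = data[i, j, score_index] > score_threshold and (
                    id_index < 0 or data[i, j, id_index] >= 0
                )
                if keep:
                    out[i, valid_count[i]] = data[i, j]
                    out_indices[i, valid_count[i]] = j
                    valid_count[i] += 1
        return valid_count, out, out_indices

The trade-off is that each row is processed by a single thread, so per-row work is
serialized in exchange for deterministic, ONNX-compatible ordering.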



