Laurawly commented on a change in pull request #6839: URL: https://github.com/apache/tvm/pull/6839#discussion_r539699599
########## File path: python/tvm/topi/cuda/nms.py ########## @@ -97,47 +97,44 @@ def get_valid_counts_ir( valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) out_indices = ib.buffer_ptr(out_indices) - atomic_add_return = ib.allocate( - valid_count.dtype, (1,), name="atomic_add_return", scope="local" - ) - one_count = tvm.tir.const(1, dtype=valid_count.dtype) one = tvm.tir.const(1, dtype=out.dtype) - score_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=score_threshold) + if isinstance(score_threshold, float): + score_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=score_threshold) id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index) score_index = tvm.ir.make_node("IntImm", dtype="int32", value=score_index) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - idxd = tvm.tir.indexdiv - - # initialize valid_count - with ib.if_scope(tid < batch_size): - valid_count[tid] = 0 - with ib.if_scope(tid < batch_size * num_anchors): - i = idxd(tid, num_anchors) - with ib.if_scope( - tvm.tir.all( - data[tid * elem_length + score_index] > score_threshold, - tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0), - ) - ): - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count - ) - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = data[tid * elem_length + k] - out_indices[tid + k] = tid + k - with ib.else_scope(): - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = -one - out_indices[tid + k] = -one_count - + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = batch_size // max_threads + 
1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + valid_count[tid] = 0 + i = tid + with ib.for_range(0, num_anchors) as j: + score = data[(i * num_anchors + j) * elem_length + score_index] + with ib.if_scope( + tvm.tir.all( + score > score_threshold, + tvm.tir.any( + id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0 + ), + ) + ): + with ib.for_range(0, elem_length) as k: + out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[ + (i * num_anchors + j) * elem_length + k + ] + out_indices[i * num_anchors + valid_count[i]] = j + valid_count[i] += 1 Review comment: So for your case, if you want to pass the correct, sorted output and output indices after `get_valid_counts_ir`, you need to run an argsort over the values `out` and the indices `out_indices` that `get_valid_counts_ir` produces in the current main branch: ```sort_tensor = argsort_thrust( score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype) ``` That is what I did at the beginning of nms: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/nms.py#L561 I don't pass in `valid_count` because at that point I haven't yet moved the invalid elements to the end of the output array; I have only marked them as -1. After the whole array is sorted, the invalid elements end up at the end of the array, and so do their indices. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org