This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 21fc3bb  [TOPI] Use fixed thread block size in unique op for Vulkan (#7718)
21fc3bb is described below

commit 21fc3bb08e2cc4928e1cd06f2280fe83431c80f0
Author: masahi <masahi...@gmail.com>
AuthorDate: Tue Mar 23 00:29:21 2021 +0900

    [TOPI] Use fixed thread block size in unique op for Vulkan (#7718)
    
    * [TOPI] Use fixed thread block size in unique op for Vulkan
    
    * forgot to add min for non vk backend
---
 python/tvm/topi/cuda/unique.py                     | 15 ++++++++---
 tests/python/unittest/test_target_codegen_spirv.py | 30 +++++++++++++++++-----
 2 files changed, 35 insertions(+), 10 deletions(-)

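For context, the core of the change is the new _get_max_threads helper in the diff below: with a Vulkan target and a dynamic (non-IntImm) batch size, the kernel uses the target's fixed max_num_threads rather than tir.min(batch_size, max_threads), because SPIR-V requires the thread group size to be a compile-time constant. A minimal sketch of that decision, using hypothetical plain-Python stand-ins for the target query and the IntImm check:

    # Sketch of the thread-count logic this commit adds (hypothetical names;
    # the real helper reads tvm.target.Target.current() and tvm.tir.IntImm).
    def pick_thread_count(batch_size, batch_size_is_static, is_vulkan, max_num_threads):
        if is_vulkan and not batch_size_is_static:
            # SPIR-V requires a compile-time-constant workgroup size, so a
            # dynamic batch size cannot appear inside min(); fall back to
            # the target's fixed maximum.
            return max_num_threads
        # Other backends (and static shapes) can clamp to the batch size.
        return min(batch_size, max_num_threads)

    assert pick_thread_count(10, True, True, 256) == 10    # static shape: clamp
    assert pick_thread_count(-1, False, True, 256) == 256  # dynamic on Vulkan: fixed size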
diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py
index 02a5cf3..2bca3c4 100644
--- a/python/tvm/topi/cuda/unique.py
+++ b/python/tvm/topi/cuda/unique.py
@@ -24,6 +24,15 @@ from .sort import sort, argsort
 from ..utils import ceil_div
 
 
+def _get_max_threads(batch_size):
+    target = tvm.target.Target.current()
+    max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
+    if "vulkan" in str(target) and not isinstance(batch_size, tvm.tir.IntImm):
+        # SPIR-V does not support dynamic thread group size
+        return max_threads
+    return tir.min(batch_size, max_threads)
+
+
 def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
     """Low level IR to calculate adjacent difference in an 1-D array.
 
@@ -46,7 +55,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
     data_ptr = ib.buffer_ptr(data)
     output_ptr = ib.buffer_ptr(output)
     batch_size = data.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
     with ib.new_scope():
         nthread_tx = max_threads
         nthread_bx = ceil_div(batch_size, max_threads)
@@ -157,7 +166,7 @@ def _calc_unique_ir(
         unique_seq_indices_ptr = ib.buffer_ptr(indices)
 
     batch_size = data.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
 
     # if need to return counts
     if isinstance(counts, tir.Buffer):
@@ -238,7 +247,7 @@ def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
     inc_scan_ptr = ib.buffer_ptr(inc_scan)
     first_occurence_ptr = ib.buffer_ptr(first_occurence)
     batch_size = argsorted_indices.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
     with ib.new_scope():
         nthread_tx = max_threads
         nthread_bx = ceil_div(batch_size, max_threads)
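With the block size fixed, only the grid dimension scales with the (possibly dynamic) batch size, via ceil_div as in the launch scopes above. A small sketch of the arithmetic, assuming a fixed block size of 256:

    def ceil_div(a, b):
        # Integer ceiling division, mirroring tvm.topi.utils.ceil_div.
        return (a + b - 1) // b

    nthread_tx = 256                      # fixed workgroup size from _get_max_threads
    n = 1000                              # example runtime batch size
    nthread_bx = ceil_div(n, nthread_tx)  # grid size still tracks n
    assert nthread_bx == 4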
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
index 68be5c4..bf47bbe 100644
--- a/tests/python/unittest/test_target_codegen_spirv.py
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -72,17 +72,18 @@ def test_bool_load():
     tvm.testing.assert_allclose(b.asnumpy(), ref)
 
 
+def check_mod(mod, x_np, res_np):
+    target = "vulkan"
+    ctx = tvm.context(target, 0)
+    ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
+    res = ex.evaluate()(x_np).asnumpy()
+    tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+
+
 def test_pushconstants():
     if not tvm.testing.device_enabled("vulkan"):
         return
 
-    def check_mod(mod, x_np, res_np):
-        target = "vulkan"
-        ctx = tvm.context(target, 0)
-        ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
-        res = ex.evaluate()(x_np).asnumpy()
-        tvm.testing.assert_allclose(res, res_np, atol=1e-5)
-
     # Three 32 bit pushconstants: any_dim, stride, stride
     dtype = "float32"
     x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
@@ -104,6 +105,21 @@ def test_pushconstants():
     check_mod(mod, x_np, res_np)
 
 
+def test_unique():
+    if not tvm.testing.device_enabled("vulkan"):
+        return
+
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    [unique, _, num_unique] = relay.unique(x, is_sorted=True)
+    mod["main"] = relay.Function([x], relay.op.strided_slice(unique, 
begin=[0], end=num_unique))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.unique(x_np)
+    check_mod(mod, x_np, res_np)
+
+
 if __name__ == "__main__":
     test_bool_load()
     test_pushconstants()
+    test_unique()
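For reference, the new test_unique above checks the Vulkan result against NumPy; a standalone sketch of the reference computation it relies on (no Vulkan device needed):

    import numpy as np

    x_np = np.random.randint(0, high=10, size=(10,)).astype("int32")
    res_np = np.unique(x_np)  # sorted unique values, the expected output
    assert np.all(np.diff(res_np) > 0)  # strictly increasing, hence sorted and unique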
