This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 21fc3bb  [TOPI] Use fixed thread block size in unique op for Vulkan (#7718)
21fc3bb is described below

commit 21fc3bb08e2cc4928e1cd06f2280fe83431c80f0
Author: masahi <masahi...@gmail.com>
AuthorDate: Tue Mar 23 00:29:21 2021 +0900

    [TOPI] Use fixed thread block size in unique op for Vulkan (#7718)

    * [TOPI] Use fixed thread block size in unique op for Vulkan

    * forgot to add min for non vk backend
---
 python/tvm/topi/cuda/unique.py                     | 15 ++++++++---
 tests/python/unittest/test_target_codegen_spirv.py | 30 +++++++++++++++++-----
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py
index 02a5cf3..2bca3c4 100644
--- a/python/tvm/topi/cuda/unique.py
+++ b/python/tvm/topi/cuda/unique.py
@@ -24,6 +24,15 @@ from .sort import sort, argsort
 from ..utils import ceil_div
 
 
+def _get_max_threads(batch_size):
+    target = tvm.target.Target.current()
+    max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
+    if "vulkan" in str(target) and not isinstance(batch_size, tvm.tir.IntImm):
+        # SPIR-V does not support dynamic thread group size
+        return max_threads
+    return tir.min(batch_size, max_threads)
+
+
 def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
     """Low level IR to calculate adjacent difference in an 1-D array.
 
@@ -46,7 +55,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
     data_ptr = ib.buffer_ptr(data)
     output_ptr = ib.buffer_ptr(output)
     batch_size = data.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
     with ib.new_scope():
         nthread_tx = max_threads
         nthread_bx = ceil_div(batch_size, max_threads)
@@ -157,7 +166,7 @@ def _calc_unique_ir(
     unique_seq_indices_ptr = ib.buffer_ptr(indices)
 
     batch_size = data.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
 
     # if need to return counts
     if isinstance(counts, tir.Buffer):
@@ -238,7 +247,7 @@ def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
     inc_scan_ptr = ib.buffer_ptr(inc_scan)
     first_occurence_ptr = ib.buffer_ptr(first_occurence)
     batch_size = argsorted_indices.shape[0]
-    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = _get_max_threads(batch_size)
     with ib.new_scope():
         nthread_tx = max_threads
         nthread_bx = ceil_div(batch_size, max_threads)
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
index 68be5c4..bf47bbe 100644
--- a/tests/python/unittest/test_target_codegen_spirv.py
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -72,17 +72,18 @@ def test_bool_load():
     tvm.testing.assert_allclose(b.asnumpy(), ref)
 
 
+def check_mod(mod, x_np, res_np):
+    target = "vulkan"
+    ctx = tvm.context(target, 0)
+    ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
+    res = ex.evaluate()(x_np).asnumpy()
+    tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+
+
 def test_pushconstants():
     if not tvm.testing.device_enabled("vulkan"):
         return
 
-    def check_mod(mod, x_np, res_np):
-        target = "vulkan"
-        ctx = tvm.context(target, 0)
-        ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
-        res = ex.evaluate()(x_np).asnumpy()
-        tvm.testing.assert_allclose(res, res_np, atol=1e-5)
-
     # Three 32 bit pushconstants: any_dim, stride, stride
     dtype = "float32"
     x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
@@ -104,6 +105,21 @@ def test_pushconstants():
     check_mod(mod, x_np, res_np)
 
 
+def test_unique():
+    if not tvm.testing.device_enabled("vulkan"):
+        return
+
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    [unique, _, num_unique] = relay.unique(x, is_sorted=True)
+    mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.unique(x_np)
+    check_mod(mod, x_np, res_np)
+
+
 if __name__ == "__main__":
     test_bool_load()
     test_pushconstants()
+    test_unique()