This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
     new c0e996e291 [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
c0e996e291 is described below

commit c0e996e2914585fe6b0c11fb2efdaea5c6b9daf9
Author: abhikran-quic <63697863+abhikran-q...@users.noreply.github.com>
AuthorDate: Sat Jul 16 20:40:39 2022 +0530

    [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)

    * [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops

    * Fix documentation
---
 python/tvm/topi/hexagon/slice_ops/reshape.py       | 13 +++---
 python/tvm/topi/hexagon/utils.py                   | 21 ++++++++++
 .../python/contrib/test_hexagon/infrastructure.py  | 12 +++++-
 .../contrib/test_hexagon/topi/test_reshape.py      | 47 +++++++++++++++-------
 4 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
index 374c20bb72..2220253e21 100644
--- a/python/tvm/topi/hexagon/slice_ops/reshape.py
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -40,13 +40,14 @@ def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
     return topi.transform.reshape(inp, new_shape)
 
 
-def stir_schedule_nhwc_1024c(
+def stir_sched_nhwc_2d_op(
     out: te.Tensor,
     inp: te.Tensor,
     out_layout: str,
     in_layout: str,
+    c_split: int,
 ) -> tir.Schedule:
-    """Schedule for output layout: nhwc-1024c-2d"""
+    """Schedule for output layout: nc-1024-2d, nc-2048-2d"""
     reshape_func = te.create_prim_func([inp, out])
     sch = tir.Schedule(reshape_func, debug_mask="all")
     compute = sch.get_block("T_reshape")
@@ -57,7 +58,7 @@ def stir_schedule_nhwc_1024c(
     jout, channel = sch.split(j, [None, inp.shape[3]])
     height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
     channelo, channeli = sch.split(channel, [None, 1024])
-    channelio, channelii = sch.split(channeli, [None, 64])
+    channelio, channelii = sch.split(channeli, [None, c_split])
     sch.reorder(i, height, width, channelo, channelio, channelii)
     sch.vectorize(channelii)
     return sch
@@ -101,8 +102,10 @@ def reshape_stir_schedule(
     sch : tvm.tir.Schedule
         The STIR schedule for slice reshape compute
     """
-    if output_layout == "nhwc-8h2w32c2w-2d":
+    if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
         return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
     if output_layout == "nc-1024-2d":
-        return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
+        return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 64)
+    if output_layout == "nc-2048-2d":
+        return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 128)
     raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 4458c55e62..3b8914ffe9 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -87,6 +87,21 @@ def nc_1024_2d(n, c):
     return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
 
 
+def nhwc_2048c_2d(n, h, w, c):
+    """Return index map for nhwc_2048 2d layout"""
+    return [n, h, w, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nc_2048_2d(n, c):
+    """Return index map for nc_2048 2d layout"""
+    return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nhwc_8h8w32c_2d(n, h, w, c):
+    """Return index map for nhwc_8h8w32c 2d layout"""
+    return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
+
+
 def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
     return [
         in_channel // 32,
@@ -129,4 +144,10 @@ def get_layout_transform_fn(layout):
         return nc_1024c_2d
     if layout == "iohw-16i32o2i-1d":
         return iohw_16i32o2i_1d
+    if layout == "nhwc-2048c-2d":
+        return nhwc_2048c_2d
+    if layout == "nc-2048-2d":
+        return nc_2048_2d
+    if layout == "nhwc-8h8w32c-2d":
+        return nhwc_8h8w32c_2d
     raise RuntimeError(f"Unexpected layout '{layout}'")
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index a1fbfdefcd..7108ac5598 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -256,7 +256,17 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
         if new_layout == "nhwc-1024c-2d":
             N, H, W, C = arr_np.shape
             return arr_np.reshape([N, H, W, C // 1024, 1024])
-        raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+        if new_layout == "nc-2048-2d":
+            N, C = arr_np.shape
+            return arr_np.reshape([N, C // 2048, 2048])
+        if new_layout == "nhwc-2048c-2d":
+            N, H, W, C = arr_np.shape
+            return arr_np.reshape([N, H, W, C // 2048, 2048])
+        if new_layout in ["nhwc-8h8w32c-2d"]:
+            n, h, w, c = arr_np.shape
+            return arr_np.reshape([n, h // 8, 8, w // 8, 8, c // 32, 32]).transpose(
+                0, 1, 3, 5, 2, 4, 6
+            )
 
     if current_layout == "nc":
         n, c = arr_np.shape
diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
index 2def86ad83..7df29a02ab 100644
--- a/tests/python/contrib/test_hexagon/topi/test_reshape.py
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -56,23 +56,23 @@ def reshape_helper(
         input_layout,
     )
     with tvm.transform.PassContext(opt_level=3):
-        print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
         runtime_module = tvm.build(tir_s.mod, target=target, name=func)
     mod = hexagon_session.load_module(runtime_module)
 
-    a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+    a_numpy = (np.random.uniform(-10, 10, input_shape)).astype(data_type)
     ref = np.reshape(a_numpy, output_shape)
 
     input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
     ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
     input_axis_sep = [4]
-    if output_layout == "nhwc-8h2w32c2w-2d":
+    if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
         output_axis_sep = [4]
-    elif output_layout == "nc-1024-2d":
+    elif output_layout in ["nc-1024-2d", "nc-2048-2d"]:
         output_axis_sep = [2]
     else:
         raise RuntimeError(f"Unexpected layout '{output_layout}'")
+
     a_tvm = allocate_hexagon_array(
         hexagon_session.device,
         data=input_np_transformed,
@@ -86,11 +86,12 @@ def reshape_helper(
         axis_separators=output_axis_sep,
         mem_scope="global.vtcm",
     )
+
     mod(a_tvm, output)
     np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
 
 
-batch_flatten_tests = (
+batch_flatten_fp16_tests = (
     ([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
     ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
     ([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
@@ -98,14 +99,17 @@ batch_flatten_tests = (
 )
 
 
+batch_flatten_uint8_tests = (
+    ([1, 1, 1, 2048], [1, 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+    ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+)
+
+
 class BaseTestBatchFlatten:
-    (
-        input_shape,
-        output_shape,
-        input_layout,
-        output_layout,
-        data_type,
-    ) = tvm.testing.parameters(*batch_flatten_tests)
+    (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
+        *batch_flatten_fp16_tests,
+        *batch_flatten_uint8_tests,
+    )
 
 
 class TestBatchFlatten(BaseTestBatchFlatten):
@@ -132,11 +136,24 @@ class TestBatchFlatten(BaseTestBatchFlatten):
     )
 
 
+reshape_fp16_tests = (
+    ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+    ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+)
+
+
+reshape_uint8_tests = (
+    ([1, 8, 8, 128], [1, 8, 16, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+    ([1, 16, 64, 128], [1, 16, 128, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+)
+
+
 class BaseTestReshape(BaseTestBatchFlatten):
     (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
-        *batch_flatten_tests,
-        ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
-        ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+        *batch_flatten_fp16_tests,
+        *batch_flatten_uint8_tests,
+        *reshape_fp16_tests,
+        *reshape_uint8_tests,
     )
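
For reference, a minimal numpy-only sketch of the tiling described by the new uint8 layouts ("nhwc-8h8w32c-2d" and "nc-2048-2d"), mirroring the transform_numpy() branches added above. The helper names and shapes below are made up for illustration only; they are not part of the patch.

import numpy as np

def to_nhwc_8h8w32c_2d(arr):
    # Tile NHWC into 8x8x32 blocks: [n, h//8, w//8, c//32, h%8, w%8, c%32],
    # the same reshape/transpose that transform_numpy() performs.
    n, h, w, c = arr.shape
    assert h % 8 == 0 and w % 8 == 0 and c % 32 == 0
    return arr.reshape(n, h // 8, 8, w // 8, 8, c // 32, 32).transpose(0, 1, 3, 5, 2, 4, 6)

def to_nc_2048_2d(arr):
    # Batch-flatten layout: split the flattened axis into 2048-element chunks.
    n, c = arr.shape
    assert c % 2048 == 0
    return arr.reshape(n, c // 2048, 2048)

a = np.random.randint(0, 256, size=(1, 8, 16, 64), dtype="uint8")
tiled = to_nhwc_8h8w32c_2d(a)  # shape (1, 1, 2, 2, 8, 8, 32)

# A logical element (n, h, w, c) lands in block (h//8, w//8, c//32)
# at intra-block offset (h%8, w%8, c%32).
n_, h_, w_, c_ = 0, 5, 11, 40
assert tiled[n_, h_ // 8, w_ // 8, c_ // 32, h_ % 8, w_ % 8, c_ % 32] == a[n_, h_, w_, c_]

flat = np.random.randint(0, 256, size=(1, 4096), dtype="uint8")
assert to_nc_2048_2d(flat).shape == (1, 2, 2048)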