This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 6c34361369 [Hexagon] Adapt some intrinsics for high vector lanes (#14345) 6c34361369 is described below commit 6c3436136926613367ac3e19ece8bed90c4d5efb Author: apeskov <pesko...@gmail.com> AuthorDate: Mon Mar 27 08:08:30 2023 +0400 [Hexagon] Adapt some intrinsics for high vector lanes (#14345) * [HEX] Enhanced vector lanes for some intrinsics * fix pylint Signed-off-by: Alexander Peskov <pesko...@gmail.com> * fix lint 2 Signed-off-by: Alexander Peskov <pesko...@gmail.com> * Fix typo Signed-off-by: Alexander Peskov <pesko...@gmail.com> --------- Signed-off-by: Alexander Peskov <pesko...@gmail.com> --- python/tvm/topi/hexagon/tensor_intrin.py | 309 +++++++++++++++------ .../test_hexagon/test_fixed_point_multiply.py | 138 ++++++++- 2 files changed, 363 insertions(+), 84 deletions(-) diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index 3e9fd47b0f..24bbacf37c 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -22,44 +22,165 @@ from tvm.ir import register_intrin_lowering from tvm import te +def get_lanes(dtype: str): + if "x" not in dtype: + return 1 + + _, lanes = dtype.split("x") + return int(lanes) + + +def is_vector_type(dtype: str): + return get_lanes(dtype) != 1 + + +def is_power_of_2(n: int): + return (n & (n - 1) == 0) and n != 0 + + +def _adapt_to_highest_lanes(*args, intrinsic=None, intrinsic_lanes: int = 0): + """Apply provided lowering intrinsic to arguments with longer vector data type. + + This wrapper will do next actions: + * Split each argument into chunks with size equal intrinsic_lanes + * Apply provided intrinsic for each argument chunk + * Concatenate results + + Parameters + ---------- + args: List[PrimExpr] + List of arguments. Each arg expression should have vector type with lanes + equal `intrinsic_lanes * 2**n`. + + intrinsic: callable + Intrinsic implementation to apply. + + intrinsic_lanes: int + Vector length required by intrinsic implementation. + + Returns + ------- + res : PrimExpr + Resulting expression. + """ + + def split_args(args_set): + res_args_set = [] + for args_chunk in args_set: + res_args_chunk_l = [] + res_args_chunk_h = [] + for arg_chunk in args_chunk: + element, lanes = arg_chunk.dtype.split("x") + res_arg_chunk_dtype = f"{element}x{int(lanes) // 2}" + + res_args_chunk_l.append(tvm.tir.op.vectorlow(res_arg_chunk_dtype, arg_chunk)) + res_args_chunk_h.append(tvm.tir.op.vectorhigh(res_arg_chunk_dtype, arg_chunk)) + res_args_set += [res_args_chunk_l, res_args_chunk_h] + + return res_args_set + + def concat_args(res_chunks): + merged_res_chunks = [] + for i in range(0, len(res_chunks), 2): + arg_chunk_l = res_chunks[i] + arg_chunk_h = res_chunks[i + 1] + element, lanes = arg_chunk_l.dtype.split("x") + res_arg_chunk_dtype = f"{element}x{int(lanes) * 2}" + + merged_res_chunks.append( + tvm.tir.op.vectorcombine(res_arg_chunk_dtype, arg_chunk_l, arg_chunk_h) + ) + + return merged_res_chunks + + num_chunks = None + for arg in args: + _, lanes = arg.dtype.split("x") + lanes = int(lanes) + assert lanes % intrinsic_lanes == 0 + if num_chunks is None: + assert is_power_of_2(lanes // intrinsic_lanes) + num_chunks = lanes // intrinsic_lanes + + assert num_chunks == lanes // intrinsic_lanes + + # Split arguments + lowered_args = [args] + while len(lowered_args) != num_chunks: + lowered_args = split_args(lowered_args) + + # Intrinsic application + lowered_res = [] + for l_arg in lowered_args: + res = intrinsic(*l_arg) + lowered_res.append(res) + + # Result concatenation + while len(lowered_res) != 1: + lowered_res = concat_args(lowered_res) + + return lowered_res[0] + + def _q_multiply_shift_hexagon(op): """ Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31. """ - x = op.args[0] - y = op.args[1] - fractional_bits = op.args[2] - shift = op.args[3] - - # Don't use this intrinsic if we don't have a int32x32 vector - # or if we are not multiplying q31 numbers - if x.dtype != "int32x32" or fractional_bits.value != 31: - return op + arg_x = op.args[0] + arg_fractional_bits = op.args[2] - # Case 1, shift is negative - mul_e_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y - ) - fixup = 1 << (-shift - 1) - round_mul = mul_o_1 + fixup - out_negative_shift = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift - ) + # Don't use this intrinsic if we are not multiplying q31 numbers + if arg_fractional_bits.value != 31: + return op - # Case 2, shift is positive - x = x * (1 << (shift)) - mul_e_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y - ) + x_lanes = get_lanes(arg_x.dtype) + if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32): + return op - # Select depending on the shift - return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2) + # pylint: disable=unused-argument + def intrinsic_lowering_32(x, y, fractional_bits, shift): + lowered_dtype = "int32x32" + + # Case 1, shift is negative + mul_e_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_1, + x, + y, + ) + fixup = 1 << (-shift - 1) + round_mul = mul_o_1 + fixup + out_negative_shift = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vaslwv.128B", + tvm.tir.const(2, "uint32"), + round_mul, + shift, + ) + + # Case 2, shift is positive + x = x * (1 << (shift)) + mul_e_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_2, + x, + y, + ) + + # Select depending on the shift + return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2) + + return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_lowering_32, intrinsic_lanes=32) register_intrin_lowering( @@ -72,65 +193,87 @@ def _q_multiply_shift_per_axis_hexagon(op): Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when q == 31. """ - x = op.args[0] - y = op.args[1] - left_shift = op.args[2] - right_shift = op.args[3] - fractional_bits = op.args[4] - is_lshift_required = op.args[5] - is_rshift_required = op.args[6] - - # Don't use this intrinsic if we don't have a int32x32 vector - # or if we are not multiplying q31 numbers - if x.dtype != "int32x32" or fractional_bits.value != 31: + arg_x = op.args[0] + arg_fractional_bits = op.args[4] + arg_is_lshift_required = op.args[5] + arg_is_rshift_required = op.args[6] + + # Don't use this intrinsic if we are not multiplying q31 numbers + if arg_fractional_bits.value != 31: + return op + + x_lanes = get_lanes(arg_x.dtype) + if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32): return op # Don't use this intrinsic when we need do both: left and right shifts. # For now it is not clear how to implement this case through vector HVX instructions without # accuracy drop. - if is_rshift_required.value and is_lshift_required.value: + if arg_is_rshift_required.value and arg_is_lshift_required.value: return op - # Case 1: do the left shift - shifted_x = x << left_shift - mul_e_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y - ) - left_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, - "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", - tvm.tir.const(3, "uint32"), - mul_e_1, - shifted_x, - y, - ) - - # Case 2: do the right shift - mul_e_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y - ) - fixup = 1 << (right_shift - 1) - round_mul = mul_o_2 + fixup - right_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift - ) - - # Case 3: do neither right nor left shift - mul_e_3 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - no_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y - ) - - return tvm.tir.Select( - tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)), - no_shift_out, - tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out), - ) + # pylint: disable=unused-argument + def intrinsic_impl_32( + x, y, left_shift, right_shift, fractional_bits, is_lshift_required, is_rshift_required + ): + lowered_dtype = "int32x32" + + # Case 1: do the left shift + shifted_x = x << left_shift + mul_e_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y + ) + left_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_1, + shifted_x, + y, + ) + + # Case 2: do the right shift + mul_e_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_2, + x, + y, + ) + fixup = 1 << (right_shift - 1) + round_mul = mul_o_2 + fixup + right_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vasrwv.128B", + tvm.tir.const(2, "uint32"), + round_mul, + right_shift, + ) + + # Case 3: do neither right nor left shift + mul_e_3 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + no_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_3, + x, + y, + ) + + return tvm.tir.Select( + tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)), + no_shift_out, + tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out), + ) + + return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_impl_32, intrinsic_lanes=32) register_intrin_lowering( diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py index 5eac35f2d6..fdfe3ad2b7 100644 --- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py +++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py @@ -21,6 +21,7 @@ import numpy as np import tvm.testing from tvm import relay +from tvm import te from tvm.relay.backend import Executor from tvm.contrib.hexagon.session import Session from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET @@ -100,7 +101,7 @@ class TestFixedPointMultiply: ) @tvm.testing.requires_hexagon - def test_fixed_point_multiply(self, hexagon_session: Session, multiplier: int, shift: int): + def test_per_tensor(self, hexagon_session: Session, multiplier: int, shift: int): """Fixed point multiply test.""" ishape = (6, 32) a = relay.var("a", relay.TensorType(ishape, "int32")) @@ -169,6 +170,141 @@ class TestFixedPointMultiply: tvm.testing.assert_allclose(hexagon_output, expected_output) + vector_size = tvm.testing.parameter(32, 64, 128, 256) + + def test_per_tensor_with_lanes(self, hexagon_session: Session, vector_size): + """Test fixed point multiply with vectorization. + Vectorization size is more than hw vector length""" + ishape = [2, 256, 16] + + def q_mul_shift(shape): + x = te.placeholder(shape, name="X", dtype="int32") + out = te.compute( + shape, + lambda i, j, k: tvm.tir.q_multiply_shift( + x[i, j, k], + tvm.tir.const(1395864320, "int32"), + tvm.tir.const(31, "int32"), + tvm.tir.const(1, "int32"), + ), + name="compute", + ) + return te.create_prim_func([x, out]) + + mod = q_mul_shift(ishape) + + # Schedule with vectorization + sch = tvm.tir.Schedule(mod) + b00 = sch.get_block(name="compute", func_name="main") + fused = sch.fuse(*sch.get_loops(block=b00)) + _, v = sch.split(loop=fused, factors=[None, vector_size]) + sch.vectorize(v) + + with tvm.transform.PassContext(opt_level=3): + hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) + host_lib = tvm.build(mod, target=tvm.target.Target("llvm")) + + asm = hex_lib.get_source("asm") + + # Check that 'vmpye' instruction was generated in asm file. + vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)") + assert vmpye_regex.search(asm) is not None + + # Check that 'vmpyo' instruction was generated in asm file. + vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift") + assert vmpyo_regex.search(asm) is not None + + # Verify accuracy + a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32") + b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32") + hex_args = [ + tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global") + for arg in [a_np, b_np] + ] + host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]] + + hex_rt = hexagon_session.load_module(hex_lib) + hex_rt(*hex_args) + host_lib(*host_args) + + assert np.allclose(hex_args[1].numpy(), host_args[1].numpy()) + + def test_per_channel_with_lanes(self, hexagon_session: Session, vector_size): + """Test fixed point multiply with vectorization. + Vectorization size is more than hw vector length""" + a_shape = [2, 256, 16] + b_shape = [256] + + def q_mul_shift(shape): + shift_shape = [shape[1]] + x = te.placeholder(shape, name="X", dtype="int32") + y = te.placeholder(shift_shape, name="X", dtype="int32") + l_shift = te.placeholder(shift_shape, name="X", dtype="int32") + r_shift = te.placeholder(shift_shape, name="X", dtype="int32") + + out = te.compute( + shape, + lambda i, j, k: tvm.tir.q_multiply_shift_per_axis( + x[i, j, k], + y[j], + l_shift[j], + r_shift[j], + tvm.tir.const(31, "int32"), + tvm.tir.const(1, "bool"), + tvm.tir.const(0, "bool"), + ), + name="compute", + ) + return te.create_prim_func([x, y, l_shift, r_shift, out]) + + mod = q_mul_shift(a_shape) + + # Schedule with vectorization + sch = tvm.tir.Schedule(mod) + b00 = sch.get_block(name="compute", func_name="main") + fused = sch.fuse(*sch.get_loops(block=b00)) + _, v = sch.split(loop=fused, factors=[None, vector_size]) + sch.vectorize(v) + + with tvm.transform.PassContext(opt_level=3): + hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) + host_lib = tvm.build(mod, target=tvm.target.Target("llvm")) + + asm = hex_lib.get_source("asm") + + # Check that 'vmpye' instruction was generated in asm file. + vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)") + assert vmpye_regex.search(asm) is not None + + # Check that 'vmpyo' instruction was generated in asm file. + vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift") + assert vmpyo_regex.search(asm) is not None + + # Verify accuracy + x_np = ( + np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32") + ) + y_np = ( + np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + ) + lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + b_np = ( + np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32") + ) + np_args = [x_np, y_np, lsh_np, rsh_np, b_np] + hex_args = [ + tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global") + for arg in np_args + ] + host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args] + + hex_rt = hexagon_session.load_module(hex_lib) + hex_rt(*hex_args) + host_lib(*host_args) + + assert np.allclose(hex_args[4].numpy(), host_args[4].numpy()) + if __name__ == "__main__": tvm.testing.main()