This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c34361369 [Hexagon] Adapt some intrinsics for high vector lanes (#14345)
6c34361369 is described below

commit 6c3436136926613367ac3e19ece8bed90c4d5efb
Author: apeskov <pesko...@gmail.com>
AuthorDate: Mon Mar 27 08:08:30 2023 +0400

    [Hexagon] Adapt some intrinsics for high vector lanes (#14345)
    
    * [HEX] Enhanced vector lanes for some intrinsics
    
    * fix pylint
    
    Signed-off-by: Alexander Peskov <pesko...@gmail.com>
    
    * fix lint 2
    
    Signed-off-by: Alexander Peskov <pesko...@gmail.com>
    
    * Fix typo
    
    Signed-off-by: Alexander Peskov <pesko...@gmail.com>
    
    ---------
    
    Signed-off-by: Alexander Peskov <pesko...@gmail.com>
---
 python/tvm/topi/hexagon/tensor_intrin.py           | 309 +++++++++++++++------
 .../test_hexagon/test_fixed_point_multiply.py      | 138 ++++++++-
 2 files changed, 363 insertions(+), 84 deletions(-)

diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py
index 3e9fd47b0f..24bbacf37c 100644
--- a/python/tvm/topi/hexagon/tensor_intrin.py
+++ b/python/tvm/topi/hexagon/tensor_intrin.py
@@ -22,44 +22,165 @@ from tvm.ir import register_intrin_lowering
 from tvm import te
 
 
+def get_lanes(dtype: str):
+    if "x" not in dtype:
+        return 1
+
+    _, lanes = dtype.split("x")
+    return int(lanes)
+
+
+def is_vector_type(dtype: str):
+    return get_lanes(dtype) != 1
+
+
+def is_power_of_2(n: int):
+    return (n & (n - 1) == 0) and n != 0
+
+
+def _adapt_to_highest_lanes(*args, intrinsic=None, intrinsic_lanes: int = 0):
+    """Apply provided lowering intrinsic to arguments with longer vector data 
type.
+
+    This wrapper will do next actions:
+      * Split each argument into chunks with size equal intrinsic_lanes
+      * Apply provided intrinsic for each argument chunk
+      * Concatenate results
+
+    Parameters
+    ----------
+    args: List[PrimExpr]
+        List of arguments. Each arg expression should have a vector type with lanes
+        equal to `intrinsic_lanes * 2**n`.
+
+    intrinsic: callable
+        Intrinsic implementation to apply.
+
+    intrinsic_lanes: int
+        Vector length required by intrinsic implementation.
+
+    Returns
+    -------
+    res : PrimExpr
+        Resulting expression.
+    """
+
+    def split_args(args_set):
+        res_args_set = []
+        for args_chunk in args_set:
+            res_args_chunk_l = []
+            res_args_chunk_h = []
+            for arg_chunk in args_chunk:
+                element, lanes = arg_chunk.dtype.split("x")
+                res_arg_chunk_dtype = f"{element}x{int(lanes) // 2}"
+
+                res_args_chunk_l.append(tvm.tir.op.vectorlow(res_arg_chunk_dtype, arg_chunk))
+                res_args_chunk_h.append(tvm.tir.op.vectorhigh(res_arg_chunk_dtype, arg_chunk))
+            res_args_set += [res_args_chunk_l, res_args_chunk_h]
+
+        return res_args_set
+
+    def concat_args(res_chunks):
+        merged_res_chunks = []
+        for i in range(0, len(res_chunks), 2):
+            arg_chunk_l = res_chunks[i]
+            arg_chunk_h = res_chunks[i + 1]
+            element, lanes = arg_chunk_l.dtype.split("x")
+            res_arg_chunk_dtype = f"{element}x{int(lanes) * 2}"
+
+            merged_res_chunks.append(
+                tvm.tir.op.vectorcombine(res_arg_chunk_dtype, arg_chunk_l, arg_chunk_h)
+            )
+
+        return merged_res_chunks
+
+    num_chunks = None
+    for arg in args:
+        _, lanes = arg.dtype.split("x")
+        lanes = int(lanes)
+        assert lanes % intrinsic_lanes == 0
+        if num_chunks is None:
+            assert is_power_of_2(lanes // intrinsic_lanes)
+            num_chunks = lanes // intrinsic_lanes
+
+        assert num_chunks == lanes // intrinsic_lanes
+
+    # Split arguments
+    lowered_args = [args]
+    while len(lowered_args) != num_chunks:
+        lowered_args = split_args(lowered_args)
+
+    # Intrinsic application
+    lowered_res = []
+    for l_arg in lowered_args:
+        res = intrinsic(*l_arg)
+        lowered_res.append(res)
+
+    # Result concatenation
+    while len(lowered_res) != 1:
+        lowered_res = concat_args(lowered_res)
+
+    return lowered_res[0]
+
+
 def _q_multiply_shift_hexagon(op):
     """
     Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31.
     """
-    x = op.args[0]
-    y = op.args[1]
-    fractional_bits = op.args[2]
-    shift = op.args[3]
-
-    # Don't use this intrinsic if we don't have a int32x32 vector
-    # or if we are not multiplying q31 numbers
-    if x.dtype != "int32x32" or fractional_bits.value != 31:
-        return op
+    arg_x = op.args[0]
+    arg_fractional_bits = op.args[2]
 
-    # Case 1, shift is negative
-    mul_e_1 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
-    )
-    mul_o_1 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
-    )
-    fixup = 1 << (-shift - 1)
-    round_mul = mul_o_1 + fixup
-    out_negative_shift = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift
-    )
+    # Don't use this intrinsic if we are not multiplying q31 numbers
+    if arg_fractional_bits.value != 31:
+        return op
 
-    # Case 2, shift is positive
-    x = x * (1 << (shift))
-    mul_e_2 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
-    )
-    mul_o_2 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y
-    )
+    x_lanes = get_lanes(arg_x.dtype)
+    if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32):
+        return op
 
-    # Select depending on the shift
-    return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2)
+    # pylint: disable=unused-argument
+    def intrinsic_lowering_32(x, y, fractional_bits, shift):
+        lowered_dtype = "int32x32"
+
+        # Case 1, shift is negative
+        mul_e_1 = tvm.tir.call_llvm_intrin(
+            lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+        )
+        mul_o_1 = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vmpyowh.sacc.128B",
+            tvm.tir.const(3, "uint32"),
+            mul_e_1,
+            x,
+            y,
+        )
+        fixup = 1 << (-shift - 1)
+        round_mul = mul_o_1 + fixup
+        out_negative_shift = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vaslwv.128B",
+            tvm.tir.const(2, "uint32"),
+            round_mul,
+            shift,
+        )
+
+        # Case 2, shift is positive
+        x = x * (1 << (shift))
+        mul_e_2 = tvm.tir.call_llvm_intrin(
+            lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+        )
+        mul_o_2 = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+            tvm.tir.const(3, "uint32"),
+            mul_e_2,
+            x,
+            y,
+        )
+
+        # Select depending on the shift
+        return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2)
+
+    return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_lowering_32, intrinsic_lanes=32)
 
 
 register_intrin_lowering(
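
The new `_adapt_to_highest_lanes` helper above repeatedly halves each argument with
vectorlow/vectorhigh until every chunk matches the 32-lane width the Hexagon intrinsics
expect, applies the intrinsic per chunk, and stitches the results back together with
vectorcombine. A minimal plain-Python sketch of that recursion (lists stand in for TIR
vector expressions; the names here are illustrative only, not part of the patch):

    # Toy model of _adapt_to_highest_lanes: Python lists stand in for TIR
    # vectors, slicing for vectorlow/vectorhigh, concatenation for
    # vectorcombine. Illustrative only.
    def adapt_to_lanes(args, intrinsic, intrinsic_lanes):
        num_chunks = len(args[0]) // intrinsic_lanes

        # Halve every argument until each chunk has `intrinsic_lanes` elements.
        chunks = [args]
        while len(chunks) != num_chunks:
            chunks = [
                half
                for chunk in chunks
                for half in (
                    [a[: len(a) // 2] for a in chunk],   # "vectorlow"
                    [a[len(a) // 2 :] for a in chunk],   # "vectorhigh"
                )
            ]

        # Apply the narrow intrinsic per chunk, then concatenate the results.
        results = [intrinsic(*chunk) for chunk in chunks]
        out = []
        for res in results:
            out += res
        return out

    # Example: a 128-lane elementwise add lowered through a 32-lane "intrinsic".
    x = list(range(128))
    y = list(range(128))
    add32 = lambda a, b: [ai + bi for ai, bi in zip(a, b)]
    assert adapt_to_lanes([x, y], add32, 32) == [ai + bi for ai, bi in zip(x, y)]

The repeated halving is also why the helper asserts that the lane count is
`intrinsic_lanes * 2**n`.
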
@@ -72,65 +193,87 @@ def _q_multiply_shift_per_axis_hexagon(op):
     Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when
     q == 31.
     """
-    x = op.args[0]
-    y = op.args[1]
-    left_shift = op.args[2]
-    right_shift = op.args[3]
-    fractional_bits = op.args[4]
-    is_lshift_required = op.args[5]
-    is_rshift_required = op.args[6]
-
-    # Don't use this intrinsic if we don't have a int32x32 vector
-    # or if we are not multiplying q31 numbers
-    if x.dtype != "int32x32" or fractional_bits.value != 31:
+    arg_x = op.args[0]
+    arg_fractional_bits = op.args[4]
+    arg_is_lshift_required = op.args[5]
+    arg_is_rshift_required = op.args[6]
+
+    # Don't use this intrinsic if we are not multiplying q31 numbers
+    if arg_fractional_bits.value != 31:
+        return op
+
+    x_lanes = get_lanes(arg_x.dtype)
+    if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32):
         return op
 
     # Don't use this intrinsic when we need to do both left and right shifts.
     # For now it is not clear how to implement this case through vector HVX instructions without
     # an accuracy drop.
-    if is_rshift_required.value and is_lshift_required.value:
+    if arg_is_rshift_required.value and arg_is_lshift_required.value:
         return op
 
-    # Case 1: do the left shift
-    shifted_x = x << left_shift
-    mul_e_1 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y
-    )
-    left_shift_out = tvm.tir.call_llvm_intrin(
-        op.dtype,
-        "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
-        tvm.tir.const(3, "uint32"),
-        mul_e_1,
-        shifted_x,
-        y,
-    )
-
-    # Case 2: do the right shift
-    mul_e_2 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
-    )
-    mul_o_2 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y
-    )
-    fixup = 1 << (right_shift - 1)
-    round_mul = mul_o_2 + fixup
-    right_shift_out = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift
-    )
-
-    # Case 3: do neither right nor left shift
-    mul_e_3 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
-    )
-    no_shift_out = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y
-    )
-
-    return tvm.tir.Select(
-        tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)),
-        no_shift_out,
-        tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out),
-    )
+    # pylint: disable=unused-argument
+    def intrinsic_impl_32(
+        x, y, left_shift, right_shift, fractional_bits, is_lshift_required, is_rshift_required
+    ):
+        lowered_dtype = "int32x32"
+
+        # Case 1: do the left shift
+        shifted_x = x << left_shift
+        mul_e_1 = tvm.tir.call_llvm_intrin(
+            lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y
+        )
+        left_shift_out = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+            tvm.tir.const(3, "uint32"),
+            mul_e_1,
+            shifted_x,
+            y,
+        )
+
+        # Case 2: do the right shift
+        mul_e_2 = tvm.tir.call_llvm_intrin(
+            lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+        )
+        mul_o_2 = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vmpyowh.sacc.128B",
+            tvm.tir.const(3, "uint32"),
+            mul_e_2,
+            x,
+            y,
+        )
+        fixup = 1 << (right_shift - 1)
+        round_mul = mul_o_2 + fixup
+        right_shift_out = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vasrwv.128B",
+            tvm.tir.const(2, "uint32"),
+            round_mul,
+            right_shift,
+        )
+
+        # Case 3: do neither right nor left shift
+        mul_e_3 = tvm.tir.call_llvm_intrin(
+            lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+        )
+        no_shift_out = tvm.tir.call_llvm_intrin(
+            lowered_dtype,
+            "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+            tvm.tir.const(3, "uint32"),
+            mul_e_3,
+            x,
+            y,
+        )
+
+        return tvm.tir.Select(
+            tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)),
+            no_shift_out,
+            tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out),
+        )
+
+    return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_impl_32, intrinsic_lanes=32)
 
 
 register_intrin_lowering(
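
To summarize when the new lowering paths above apply: both intrinsics bail out unless
q == 31 and the vector length is a power-of-two multiple of the native 32-lane HVX width
(the per-axis variant additionally bails out when both a left and a right shift are
required). A hedged restatement of that dtype gating as a standalone predicate
(`lowering_applies` is a hypothetical helper for illustration, not part of the patch):

    def get_lanes(dtype: str) -> int:
        # Same parsing as the patch: "int32x64" -> 64, "int32" -> 1.
        return int(dtype.split("x")[1]) if "x" in dtype else 1

    def is_power_of_2(n: int) -> bool:
        return n != 0 and (n & (n - 1)) == 0

    def lowering_applies(dtype: str, fractional_bits: int = 31) -> bool:
        # q_multiply_shift(_per_axis) only operates on int32 values; on top of
        # that, the patch requires q == 31 and a lane count of 32 * 2**n.
        lanes = get_lanes(dtype)
        return fractional_bits == 31 and lanes % 32 == 0 and is_power_of_2(lanes // 32)

    assert lowering_applies("int32x32")
    assert lowering_applies("int32x256")
    assert not lowering_applies("int32x48")  # not 32 * 2**n, falls back to default lowering
    assert not lowering_applies("int32")     # scalar, falls back to default lowering
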
diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
index 5eac35f2d6..fdfe3ad2b7 100644
--- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
+++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
@@ -21,6 +21,7 @@ import numpy as np
 
 import tvm.testing
 from tvm import relay
+from tvm import te
 from tvm.relay.backend import Executor
 from tvm.contrib.hexagon.session import Session
 from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
@@ -100,7 +101,7 @@ class TestFixedPointMultiply:
     )
 
     @tvm.testing.requires_hexagon
-    def test_fixed_point_multiply(self, hexagon_session: Session, multiplier: int, shift: int):
+    def test_per_tensor(self, hexagon_session: Session, multiplier: int, shift: int):
         """Fixed point multiply test."""
         ishape = (6, 32)
         a = relay.var("a", relay.TensorType(ishape, "int32"))
@@ -169,6 +170,141 @@ class TestFixedPointMultiply:
 
         tvm.testing.assert_allclose(hexagon_output, expected_output)
 
+    vector_size = tvm.testing.parameter(32, 64, 128, 256)
+
+    def test_per_tensor_with_lanes(self, hexagon_session: Session, vector_size):
+        """Test fixed point multiply with vectorization.
+        Vectorization size is greater than the HW vector length."""
+        ishape = [2, 256, 16]
+
+        def q_mul_shift(shape):
+            x = te.placeholder(shape, name="X", dtype="int32")
+            out = te.compute(
+                shape,
+                lambda i, j, k: tvm.tir.q_multiply_shift(
+                    x[i, j, k],
+                    tvm.tir.const(1395864320, "int32"),
+                    tvm.tir.const(31, "int32"),
+                    tvm.tir.const(1, "int32"),
+                ),
+                name="compute",
+            )
+            return te.create_prim_func([x, out])
+
+        mod = q_mul_shift(ishape)
+
+        # Schedule with vectorization
+        sch = tvm.tir.Schedule(mod)
+        b00 = sch.get_block(name="compute", func_name="main")
+        fused = sch.fuse(*sch.get_loops(block=b00))
+        _, v = sch.split(loop=fused, factors=[None, vector_size])
+        sch.vectorize(v)
+
+        with tvm.transform.PassContext(opt_level=3):
+            hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68"))
+            host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))
+
+        asm = hex_lib.get_source("asm")
+
+        # Check that 'vmpye' instruction was generated in asm file.
+        vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)")
+        assert vmpye_regex.search(asm) is not None
+
+        # Check that 'vmpyo' instruction was generated in asm file.
+        vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift")
+        assert vmpyo_regex.search(asm) is not None
+
+        # Verify accuracy
+        a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
+        b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
+        hex_args = [
+            tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global")
+            for arg in [a_np, b_np]
+        ]
+        host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]]
+
+        hex_rt = hexagon_session.load_module(hex_lib)
+        hex_rt(*hex_args)
+        host_lib(*host_args)
+
+        assert np.allclose(hex_args[1].numpy(), host_args[1].numpy())
+
+    def test_per_channel_with_lanes(self, hexagon_session: Session, vector_size):
+        """Test per-channel fixed point multiply with vectorization.
+        Vectorization size is greater than the HW vector length."""
+        a_shape = [2, 256, 16]
+        b_shape = [256]
+
+        def q_mul_shift(shape):
+            shift_shape = [shape[1]]
+            x = te.placeholder(shape, name="X", dtype="int32")
+            y = te.placeholder(shift_shape, name="Y", dtype="int32")
+            l_shift = te.placeholder(shift_shape, name="l_shift", dtype="int32")
+            r_shift = te.placeholder(shift_shape, name="r_shift", dtype="int32")
+
+            out = te.compute(
+                shape,
+                lambda i, j, k: tvm.tir.q_multiply_shift_per_axis(
+                    x[i, j, k],
+                    y[j],
+                    l_shift[j],
+                    r_shift[j],
+                    tvm.tir.const(31, "int32"),
+                    tvm.tir.const(1, "bool"),
+                    tvm.tir.const(0, "bool"),
+                ),
+                name="compute",
+            )
+            return te.create_prim_func([x, y, l_shift, r_shift, out])
+
+        mod = q_mul_shift(a_shape)
+
+        # Schedule with vectorization
+        sch = tvm.tir.Schedule(mod)
+        b00 = sch.get_block(name="compute", func_name="main")
+        fused = sch.fuse(*sch.get_loops(block=b00))
+        _, v = sch.split(loop=fused, factors=[None, vector_size])
+        sch.vectorize(v)
+
+        with tvm.transform.PassContext(opt_level=3):
+            hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68"))
+            host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))
+
+        asm = hex_lib.get_source("asm")
+
+        # Check that 'vmpye' instruction was generated in asm file.
+        vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)")
+        assert vmpye_regex.search(asm) is not None
+
+        # Check that 'vmpyo' instruction was generated in asm file.
+        vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift")
+        assert vmpyo_regex.search(asm) is not None
+
+        # Verify accuracy
+        x_np = (
+            np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
+        )
+        y_np = (
+            np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+        )
+        lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+        rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+        b_np = (
+            np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
+        )
+        np_args = [x_np, y_np, lsh_np, rsh_np, b_np]
+        hex_args = [
+            tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global")
+            for arg in np_args
+        ]
+        host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args]
+
+        hex_rt = hexagon_session.load_module(hex_lib)
+        hex_rt(*hex_args)
+        host_lib(*host_args)
+
+        assert np.allclose(hex_args[4].numpy(), host_args[4].numpy())
+
 
 if __name__ == "__main__":
     tvm.testing.main()
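
For context on what the new tests check numerically: with q == 31 the intrinsic path
implements gemmlowp-style fixed-point requantization, roughly out = round(x * y *
2**(shift - 31)) saturated to int32. A rough NumPy reference under that assumed
semantics (function name is illustrative; tie-breaking and saturation details may
differ from TVM's exact lowering):

    import numpy as np

    def q_multiply_shift_ref(x, y, q=31, shift=0):
        # Assumed semantics: out = round(x * y * 2**(shift - q)), saturated to
        # int32. Tie-breaking may differ from TVM's exact fixed-point lowering.
        prod = x.astype(np.int64) * np.int64(y)
        scaled = np.rint(prod.astype(np.float64) * 2.0 ** (shift - q))
        return np.clip(scaled, -(2**31), 2**31 - 1).astype("int32")

    x = np.array([1000, -1000, 123456], dtype="int32")
    # Multiplier and shift taken from test_per_tensor_with_lanes: roughly x * 1.3.
    print(q_multiply_shift_ref(x, 1395864320, q=31, shift=1))
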
