Hi Victor,

> On 10 Jul 2024, at 16:05, Victor Do Nascimento <victor.donascime...@arm.com> 
> wrote:
> 
> External email: Use caution opening links or attachments
> 
> 
> Given recent changes to the dot_prod standard pattern name, this patch
> fixes the aarch64 back-end by implementing the following changes:
> 
> 1. Add 2nd mode to all (u|s|us)dot_prod patterns in .md files.
> 2. Rewrite the initialization and function expansion mechanisms for SIMD
> builtins.
> 3. Fix all direct calls to back-end `dot_prod' patterns in SVE
> builtins.
> 
> Finally, given that it is now possible for the compiler to
> differentiate between the two- and four-way dot product, we add a test
> to ensure that autovectorization picks up on dot-product patterns
> where the result is twice the width of the operands.
> 
> gcc/ChangeLog:
> 
>        * config/aarch64/aarch64-builtins.cc (enum aarch64_builtins):
>        New AARCH64_BUILTIN_* enum values: SDOTV8QI, SDOTV16QI,
>        UDOTV8QI, UDOTV16QI, USDOTV8QI, USDOTV16QI.
>        (aarch64_init_builtin_dotprod_functions): New.
>        (aarch64_init_simd_builtins): Add call to
>        `aarch64_init_builtin_dotprod_functions'.
>        (aarch64_general_gimple_fold_builtin): Add DOT_PROD_EXPR
>        handling.
>        * config/aarch64/aarch64-simd-builtins.def: Remove macro
>        expansion-based initialization and expansion
>        of (u|s|us)dot_prod builtins.
>        * config/aarch64/aarch64-simd.md
>        (<sur>dot_prod<vsi2qi><vczle><vczbe>): Deleted.
>        (<sur>dot_prod<mode><vsi2qi><vczle><vczbe>): New.
>        (usdot_prod<vsi2qi><vczle><vczbe>): Deleted.
>        (usdot_prod<mode><vsi2qi><vczle><vczbe>): New.
>        (<su>sadv16qi): Adjust call to gen_udot_prod to take second mode.
>        (popcount<mode>2): Fix use of `udot_prod_optab'.
>        * config/aarch64/aarch64-sve-builtins-base.cc
>        (svdot_impl::expand): s/direct/convert/ in
>        `convert_optab_handler_for_sign' function call.
>        (svusdot_impl::expand): Add second mode argument in call to
>        `code_for_dot_prod'.
>        * config/aarch64/aarch64-sve-builtins.cc
>        (function_expander::convert_optab_handler_for_sign): New class
>        method.
>        * config/aarch64/aarch64-sve-builtins.h
>        (class function_expander): Add prototype for new
>        `convert_optab_handler_for_sign' method.
>        * config/aarch64/aarch64-sve.md
>        (<sur>dot_prod<vsi2qi>): Deleted.
>        (<sur>dot_prod<mode><vsi2qi>): New.
>        (@<sur>dot_prod<vsi2qi>): Deleted.
>        (@<sur>dot_prod<mode><vsi2qi>): New.
>        (<su>sad<vsi2qi>): Adjust call to gen_udot_prod to take second mode.
>        * config/aarch64/aarch64-sve2.md
>        (@aarch64_sve_<sur>dotvnx4sivnx8hi): Deleted.
>        (<sur>dot_prodvnx4sivnx8hi): New.
> 
> gcc/testsuite/ChangeLog:
> 
>        * gcc.target/aarch64/sme/vect-dotprod-twoway.c: New test.
> ---
> gcc/config/aarch64/aarch64-builtins.cc        | 71 +++++++++++++++++++
> gcc/config/aarch64/aarch64-simd-builtins.def  |  4 --
> gcc/config/aarch64/aarch64-simd.md            |  9 +--
> .../aarch64/aarch64-sve-builtins-base.cc      | 13 ++--
> gcc/config/aarch64/aarch64-sve-builtins.cc    | 17 +++++
> gcc/config/aarch64/aarch64-sve-builtins.h     |  3 +
> gcc/config/aarch64/aarch64-sve.md             |  6 +-
> gcc/config/aarch64/aarch64-sve2.md            |  2 +-
> gcc/config/aarch64/iterators.md               |  1 +
> .../aarch64/sme/vect-dotprod-twoway.c         | 25 +++++++
> 10 files changed, 133 insertions(+), 18 deletions(-)
> create mode 100644 gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
> 
> diff --git a/gcc/config/aarch64/aarch64-builtins.cc 
> b/gcc/config/aarch64/aarch64-builtins.cc
> index 30669f8aa18..6c7c86d0e6e 100644
> --- a/gcc/config/aarch64/aarch64-builtins.cc
> +++ b/gcc/config/aarch64/aarch64-builtins.cc
> @@ -783,6 +783,12 @@ enum aarch64_builtins
>   AARCH64_SIMD_PATTERN_START = AARCH64_SIMD_BUILTIN_LANE_CHECK + 1,
>   AARCH64_SIMD_BUILTIN_MAX = AARCH64_SIMD_PATTERN_START
>                              + ARRAY_SIZE (aarch64_simd_builtin_data) - 1,
> +  AARCH64_BUILTIN_SDOTV8QI,
> +  AARCH64_BUILTIN_SDOTV16QI,
> +  AARCH64_BUILTIN_UDOTV8QI,
> +  AARCH64_BUILTIN_UDOTV16QI,
> +  AARCH64_BUILTIN_USDOTV8QI,
> +  AARCH64_BUILTIN_USDOTV16QI,
>   AARCH64_CRC32_BUILTIN_BASE,
>   AARCH64_CRC32_BUILTINS
>   AARCH64_CRC32_BUILTIN_MAX,
> @@ -1642,6 +1648,60 @@ handle_arm_neon_h (void)
>   aarch64_init_simd_intrinsics ();
> }
> 
> +void
> +aarch64_init_builtin_dotprod_functions (void)
> +{
> +  tree fndecl = NULL;
> +  tree ftype = NULL;
> +
> +  tree uv8qi = aarch64_simd_builtin_type (V8QImode, qualifier_unsigned);
> +  tree sv8qi = aarch64_simd_builtin_type (V8QImode, qualifier_none);
> +  tree uv16qi = aarch64_simd_builtin_type (V16QImode, qualifier_unsigned);
> +  tree sv16qi = aarch64_simd_builtin_type (V16QImode, qualifier_none);
> +  tree uv2si = aarch64_simd_builtin_type (V2SImode, qualifier_unsigned);
> +  tree sv2si = aarch64_simd_builtin_type (V2SImode, qualifier_none);
> +  tree uv4si = aarch64_simd_builtin_type (V4SImode, qualifier_unsigned);
> +  tree sv4si = aarch64_simd_builtin_type (V4SImode, qualifier_none);
> +
> +  struct builtin_decls_data
> +  {
> +    tree out_type_node;
> +    tree in_type1_node;
> +    tree in_type2_node;
> +    const char *builtin_name;
> +    int function_code;
> +  };
> +
> +#define NAME(A) "__builtin_aarch64_" #A
> +#define ENUM(B) AARCH64_BUILTIN_##B
> +
> +  builtin_decls_data bdda[] =
> +  {
> +    { sv2si, sv8qi,  sv8qi,  NAME (sdot_prodv8qi),       ENUM (SDOTV8QI)   },
> +    { uv2si, uv8qi,  uv8qi,  NAME (udot_prodv8qi_uuuu),   ENUM (UDOTV8QI)   
> },
> +    { sv2si, uv8qi,  sv8qi,  NAME (usdot_prodv8qi_suss),  ENUM (USDOTV8QI)  
> },
> +    { sv4si, sv16qi, sv16qi, NAME (sdot_prodv16qi),      ENUM (SDOTV16QI)  },
> +    { uv4si, uv16qi, uv16qi, NAME (udot_prodv16qi_uuuu),  ENUM (UDOTV16QI)  
> },
> +    { sv4si, uv16qi, sv16qi, NAME (usdot_prodv16qi_suss), ENUM (USDOTV16QI) 
> },
> +  };
> +
> +#undef NAME
> +#undef ENUM
> +
> +  builtin_decls_data *bdd = bdda;
> +  builtin_decls_data *bdd_end = bdd + (ARRAY_SIZE (bdda));
> +
> +  for (; bdd < bdd_end; bdd++)
> +  {
> +    ftype = build_function_type_list (bdd->out_type_node, bdd->in_type1_node,
> +                                     bdd->in_type2_node, bdd->out_type_node,
> +                                     NULL_TREE);
> +    fndecl = aarch64_general_add_builtin (bdd->builtin_name,
> +                                         ftype, bdd->function_code);
> +    aarch64_builtin_decls[bdd->function_code] = fndecl;
> +  }
> +}
> +
> static void
> aarch64_init_simd_builtins (void)
> {
> @@ -1654,6 +1714,8 @@ aarch64_init_simd_builtins (void)
>   aarch64_init_simd_builtin_scalar_types ();
> 
>   aarch64_init_simd_builtin_functions (false);
> +  aarch64_init_builtin_dotprod_functions ();
> +

Perhaps we should take this opportunity to migrate the dot-product intrinsics
to the simulate_builtin_function_decl framework instead, so that they get
created as part of #pragma GCC aarch64 "arm_neon.h".

That’s the direction of travel we want for these builtins, so I’d rather not
complicate the legacy builtin handling code here.
I don’t think it should be much more work than this patch, as you’ve already
got the various static bookkeeping data on hand.

Thanks,
Kyrill


>   if (in_lto_p)
>     handle_arm_neon_h ();
> 
> @@ -3676,6 +3738,15 @@ aarch64_general_gimple_fold_builtin (unsigned int 
> fcode, gcall *stmt,
>          new_stmt = gimple_build_nop ();
>        }
>       break;
> +    case AARCH64_BUILTIN_SDOTV8QI:
> +    case AARCH64_BUILTIN_SDOTV16QI:
> +    case AARCH64_BUILTIN_UDOTV8QI:
> +    case AARCH64_BUILTIN_UDOTV16QI:
> +    case AARCH64_BUILTIN_USDOTV8QI:
> +    case AARCH64_BUILTIN_USDOTV16QI:
> +      new_stmt = gimple_build_assign (gimple_call_lhs (stmt),
> +                                     DOT_PROD_EXPR, args[0],
> +                                     args[1], args[2]);
>     default:
>       break;
>     }
> diff --git a/gcc/config/aarch64/aarch64-simd-builtins.def 
> b/gcc/config/aarch64/aarch64-simd-builtins.def
> index e65f73d7ba2..ea774ba1d49 100644
> --- a/gcc/config/aarch64/aarch64-simd-builtins.def
> +++ b/gcc/config/aarch64/aarch64-simd-builtins.def
> @@ -417,10 +417,6 @@
>   BUILTIN_VSDQ_I_DI (BINOP, srshl, 0, NONE)
>   BUILTIN_VSDQ_I_DI (BINOP_UUS, urshl, 0, NONE)
> 
> -  /* Implemented by <sur><dotprod>_prod<dot_mode>.  */
> -  BUILTIN_VB (TERNOP, sdot_prod, 10, NONE)
> -  BUILTIN_VB (TERNOPU, udot_prod, 10, NONE)
> -  BUILTIN_VB (TERNOP_SUSS, usdot_prod, 10, NONE)
>   /* Implemented by aarch64_<sur><dotprod>_lane{q}<dot_mode>.  */
>   BUILTIN_VB (QUADOP_LANE, sdot_lane, 0, NONE)
>   BUILTIN_VB (QUADOPU_LANE, udot_lane, 0, NONE)
> diff --git a/gcc/config/aarch64/aarch64-simd.md 
> b/gcc/config/aarch64/aarch64-simd.md
> index bbeee221f37..e8e1539fcf3 100644
> --- a/gcc/config/aarch64/aarch64-simd.md
> +++ b/gcc/config/aarch64/aarch64-simd.md
> @@ -568,7 +568,7 @@ (define_expand "cmul<conj_op><mode>3"
> ;; ...
> ;;
> ;; and so the vectorizer provides r, in which the result has to be 
> accumulated.
> -(define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
> +(define_insn "<sur>dot_prod<mode><vsi2qi><vczle><vczbe>"
>   [(set (match_operand:VS 0 "register_operand" "=w")
>        (plus:VS
>          (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
> @@ -582,7 +582,7 @@ (define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
> 
> ;; These instructions map to the __builtins for the Armv8.6-a I8MM usdot
> ;; (vector) Dot Product operation and the vectorized optab.
> -(define_insn "usdot_prod<vsi2qi><vczle><vczbe>"
> +(define_insn "usdot_prod<mode><vsi2qi><vczle><vczbe>"
>   [(set (match_operand:VS 0 "register_operand" "=w")
>        (plus:VS
>          (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
> @@ -1075,7 +1075,7 @@ (define_expand "<su>sadv16qi"
>        rtx ones = force_reg (V16QImode, CONST1_RTX (V16QImode));
>        rtx abd = gen_reg_rtx (V16QImode);
>        emit_insn (gen_aarch64_<su>abdv16qi (abd, operands[1], operands[2]));
> -       emit_insn (gen_udot_prodv16qi (operands[0], abd, ones, operands[3]));
> +       emit_insn (gen_udot_prodv4siv16qi (operands[0], abd, ones, 
> operands[3]));
>        DONE;
>       }
>     rtx reduc = gen_reg_rtx (V8HImode);
> @@ -3510,6 +3510,7 @@ (define_expand "popcount<mode>2"
>   {
>     /* Generate a byte popcount.  */
>     machine_mode mode = <bitsize> == 64 ? V8QImode : V16QImode;
> +    machine_mode mode2 = <bitsize> == 64 ? V2SImode : V4SImode;
>     rtx tmp = gen_reg_rtx (mode);
>     auto icode = optab_handler (popcount_optab, mode);
>     emit_insn (GEN_FCN (icode) (tmp, gen_lowpart (mode, operands[1])));
> @@ -3520,7 +3521,7 @@ (define_expand "popcount<mode>2"
>        /* For V4SI and V2SI, we can generate a UDOT with a 0 accumulator and a
>           1 multiplicand.  For V2DI, another UAADDLP is needed.  */
>        rtx ones = force_reg (mode, CONST1_RTX (mode));
> -       auto icode = optab_handler (udot_prod_optab, mode);
> +       auto icode = convert_optab_handler (udot_prod_optab, mode2, mode);
>        mode = <bitsize> == 64 ? V2SImode : V4SImode;
>        rtx dest = mode == <MODE>mode ? operands[0] : gen_reg_rtx (mode);
>        rtx zeros = force_reg (mode, CONST0_RTX (mode));
> diff --git a/gcc/config/aarch64/aarch64-sve-builtins-base.cc 
> b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
> index aa26370d397..12ffaf7e6ca 100644
> --- a/gcc/config/aarch64/aarch64-sve-builtins-base.cc
> +++ b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
> @@ -757,15 +757,16 @@ public:
>     e.rotate_inputs_left (0, 3);
>     insn_code icode;
>     if (e.type_suffix_ids[1] == NUM_TYPE_SUFFIXES)
> -      icode = e.direct_optab_handler_for_sign (sdot_prod_optab,
> -                                              udot_prod_optab,
> -                                              0, GET_MODE (e.args[0]));
> +      icode = e.convert_optab_handler_for_sign (sdot_prod_optab,
> +                                               udot_prod_optab,
> +                                               0, e.result_mode (),
> +                                               GET_MODE (e.args[0]));
>     else
>       icode = (e.type_suffix (0).float_p
>               ? CODE_FOR_aarch64_sve_fdotvnx4sfvnx8hf
>               : e.type_suffix (0).unsigned_p
> -              ? CODE_FOR_aarch64_sve_udotvnx4sivnx8hi
> -              : CODE_FOR_aarch64_sve_sdotvnx4sivnx8hi);
> +              ? CODE_FOR_udot_prodvnx4sivnx8hi
> +              : CODE_FOR_sdot_prodvnx4sivnx8hi);
>     return e.use_unpred_insn (icode);
>   }
> };
> @@ -2814,7 +2815,7 @@ public:
>        Hence we do the same rotation on arguments as svdot_impl does.  */
>     e.rotate_inputs_left (0, 3);
>     machine_mode mode = e.vector_mode (0);
> -    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, mode);
> +    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, e.result_mode (), 
> mode);
>     return e.use_exact_insn (icode);
>   }
> 
> diff --git a/gcc/config/aarch64/aarch64-sve-builtins.cc 
> b/gcc/config/aarch64/aarch64-sve-builtins.cc
> index f3983a123e3..0650e1d0a4d 100644
> --- a/gcc/config/aarch64/aarch64-sve-builtins.cc
> +++ b/gcc/config/aarch64/aarch64-sve-builtins.cc
> @@ -3745,6 +3745,23 @@ function_expander::direct_optab_handler_for_sign 
> (optab signed_op,
>   return ::direct_optab_handler (op, mode);
> }
> 
> +/* Choose between signed and unsigned convert optabs SIGNED_OP and
> +   UNSIGNED_OP based on the signedness of type suffix SUFFIX_I, then
> +   pick the appropriate optab handler for the mode.  Use MODE as the
> +   mode if given, otherwise use the mode of type suffix SUFFIX_I.  */
> +insn_code
> +function_expander::convert_optab_handler_for_sign (optab signed_op,
> +                                                  optab unsigned_op,
> +                                                  unsigned int suffix_i,
> +                                                  machine_mode to_mode,
> +                                                  machine_mode from_mode)
> +{
> +  if (from_mode == VOIDmode)
> +    from_mode = vector_mode (suffix_i);
> +  optab op = type_suffix (suffix_i).unsigned_p ? unsigned_op : signed_op;
> +  return ::convert_optab_handler (op, to_mode, from_mode);
> +}
> +
> /* Return true if X overlaps any input.  */
> bool
> function_expander::overlaps_input_p (rtx x)
> diff --git a/gcc/config/aarch64/aarch64-sve-builtins.h 
> b/gcc/config/aarch64/aarch64-sve-builtins.h
> index 9cc07d5fa3d..c277632e1dc 100644
> --- a/gcc/config/aarch64/aarch64-sve-builtins.h
> +++ b/gcc/config/aarch64/aarch64-sve-builtins.h
> @@ -659,6 +659,9 @@ public:
>   insn_code direct_optab_handler (optab, unsigned int = 0);
>   insn_code direct_optab_handler_for_sign (optab, optab, unsigned int = 0,
>                                           machine_mode = E_VOIDmode);
> +  insn_code convert_optab_handler_for_sign (optab, optab, unsigned int = 0,
> +                                           machine_mode = E_VOIDmode,
> +                                           machine_mode = E_VOIDmode);
> 
>   machine_mode result_mode () const;
> 
> diff --git a/gcc/config/aarch64/aarch64-sve.md 
> b/gcc/config/aarch64/aarch64-sve.md
> index 5331e7121d5..ce83a109725 100644
> --- a/gcc/config/aarch64/aarch64-sve.md
> +++ b/gcc/config/aarch64/aarch64-sve.md
> @@ -7196,7 +7196,7 @@ (define_insn_and_rewrite "*cond_fnma<mode>_any"
> ;; -------------------------------------------------------------------------
> 
> ;; Four-element integer dot-product with accumulation.
> -(define_insn "<sur>dot_prod<vsi2qi>"
> +(define_insn "<sur>dot_prod<mode><vsi2qi>"
>   [(set (match_operand:SVE_FULL_SDI 0 "register_operand")
>        (plus:SVE_FULL_SDI
>          (unspec:SVE_FULL_SDI
> @@ -7234,7 +7234,7 @@ (define_insn 
> "@aarch64_<sur>dot_prod_lane<SVE_FULL_SDI:mode><SVE_FULL_BHI:mode>"
>   }
> )
> 
> -(define_insn "@<sur>dot_prod<vsi2qi>"
> +(define_insn "@<sur>dot_prod<mode><vsi2qi>"
>   [(set (match_operand:VNx4SI_ONLY 0 "register_operand")
>         (plus:VNx4SI_ONLY
>          (unspec:VNx4SI_ONLY
> @@ -7292,7 +7292,7 @@ (define_expand "<su>sad<vsi2qi>"
>     rtx ones = force_reg (<VSI2QI>mode, CONST1_RTX (<VSI2QI>mode));
>     rtx diff = gen_reg_rtx (<VSI2QI>mode);
>     emit_insn (gen_<su>abd<vsi2qi>3 (diff, operands[1], operands[2]));
> -    emit_insn (gen_udot_prod<vsi2qi> (operands[0], diff, ones, operands[3]));
> +    emit_insn (gen_udot_prod<mode><vsi2qi> (operands[0], diff, ones, 
> operands[3]));
>     DONE;
>   }
> )
> diff --git a/gcc/config/aarch64/aarch64-sve2.md 
> b/gcc/config/aarch64/aarch64-sve2.md
> index 972b03a4fef..725092cc95f 100644
> --- a/gcc/config/aarch64/aarch64-sve2.md
> +++ b/gcc/config/aarch64/aarch64-sve2.md
> @@ -2021,7 +2021,7 @@ (define_insn 
> "@aarch64_sve_qsub_<sve_int_op>_lane_<mode>"
> )
> 
> ;; Two-way dot-product.
> -(define_insn "@aarch64_sve_<sur>dotvnx4sivnx8hi"
> +(define_insn "<sur>dot_prodvnx4sivnx8hi"
>   [(set (match_operand:VNx4SI 0 "register_operand")
>        (plus:VNx4SI
>          (unspec:VNx4SI
> diff --git a/gcc/config/aarch64/iterators.md b/gcc/config/aarch64/iterators.md
> index f527b2cfeb8..1864462ccfc 100644
> --- a/gcc/config/aarch64/iterators.md
> +++ b/gcc/config/aarch64/iterators.md
> @@ -2119,6 +2119,7 @@ (define_mode_attr vp [(V8QI "v") (V16QI "v")
> 
> (define_mode_attr vsi2qi [(V2SI "v8qi") (V4SI "v16qi")
>                          (VNx4SI "vnx16qi") (VNx2DI "vnx8hi")])
> +
> (define_mode_attr VSI2QI [(V2SI "V8QI") (V4SI "V16QI")
>                          (VNx4SI "VNx16QI") (VNx2DI "VNx8HI")])
> 
> diff --git a/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c 
> b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
> new file mode 100644
> index 00000000000..453f3a75e6f
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
> @@ -0,0 +1,25 @@
> +/* { dg-additional-options "-march=armv9.2-a+sme2 -O2 -ftree-vectorize" } */
> +
> +#include <stdint.h>
> +
> +uint32_t udot2(int n, uint16_t* data) __arm_streaming
> +{
> +  uint32_t sum = 0;
> +  for (int i=0; i<n; i+=1) {
> +    sum += data[i] * data[i];
> +  }
> +  return sum;
> +}
> +
> +int32_t sdot2(int n, int16_t* data) __arm_streaming
> +{
> +  int32_t sum = 0;
> +  for (int i=0; i<n; i+=1) {
> +    sum += data[i] * data[i];
> +  }
> +  return sum;
> +}
> +
> +/* { dg-final { scan-assembler-times {\tudot\tz[0-9]+\.s, z[0-9]+\.h, 
> z[0-9]+\.h\n} 5 } } */
> +/* { dg-final { scan-assembler-times {\tsdot\tz[0-9]+\.s, z[0-9]+\.h, 
> z[0-9]+\.h\n} 5 } } */
> +/* { dg-final { scan-assembler-times {\twhilelo\t} 4 } } */
> --
> 2.34.1
> 

Reply via email to