Tamar Christina <tamar.christ...@arm.com> writes:
>> -----Original Message-----
>> From: Richard Sandiford <richard.sandif...@arm.com>
>> Sent: Tuesday, November 28, 2023 5:56 PM
>> To: Tamar Christina <tamar.christ...@arm.com>
>> Cc: gcc-patches@gcc.gnu.org; nd <n...@arm.com>; Richard Earnshaw
>> <richard.earns...@arm.com>; Marcus Shawcroft
>> <marcus.shawcr...@arm.com>; Kyrylo Tkachov <kyrylo.tkac...@arm.com>
>> Subject: Re: [PATCH 17/21]AArch64: Add implementation for vector cbranch for
>> Advanced SIMD
>> 
>> Richard Sandiford <richard.sandif...@arm.com> writes:
>> > Tamar Christina <tamar.christ...@arm.com> writes:
>> >> Hi All,
>> >>
>> >> This adds an implementation for conditional branch optab for AArch64.
>> >>
>> >> For e.g.
>> >>
>> >> void f1 ()
>> >> {
>> >>   for (int i = 0; i < N; i++)
>> >>     {
>> >>       b[i] += a[i];
>> >>       if (a[i] > 0)
>> >>   break;
>> >>     }
>> >> }
>> >>
>> >> For 128-bit vectors we generate:
>> >>
>> >>         cmgt    v1.4s, v1.4s, #0
>> >>         umaxp   v1.4s, v1.4s, v1.4s
>> >>         fmov    x3, d1
>> >>         cbnz    x3, .L8
>> >>
>> >> and for 64-bit vectors we can omit the compression:
>> >>
>> >>         cmgt    v1.2s, v1.2s, #0
>> >>         fmov    x2, d1
>> >>         cbz     x2, .L13
>> >>
>> >> Bootstrapped and regtested on aarch64-none-linux-gnu with no issues.
>> >>
>> >> Ok for master?
>> >>
>> >> Thanks,
>> >> Tamar
>> >>
>> >> gcc/ChangeLog:
>> >>
>> >>   * config/aarch64/aarch64-simd.md (cbranch<mode>4): New.
>> >>
>> >> gcc/testsuite/ChangeLog:
>> >>
>> >>   * gcc.target/aarch64/vect-early-break-cbranch.c: New test.
>> >>
>> >> --- inline copy of patch --
>> >> diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md
>> >> index 90118c6348e9614bef580d1dc94c0c1841dd5204..cd5ec35c3f53028f14828bd70a92924f62524c15 100644
>> >> --- a/gcc/config/aarch64/aarch64-simd.md
>> >> +++ b/gcc/config/aarch64/aarch64-simd.md
>> >> @@ -3830,6 +3830,46 @@ (define_expand "vcond_mask_<mode><v_int_equiv>"
>> >>    DONE;
>> >>  })
>> >>
>> >> +;; Patterns comparing two vectors and conditionally jump
>> >> +
>> >> +(define_expand "cbranch<mode>4"
>> >> +  [(set (pc)
>> >> +        (if_then_else
>> >> +          (match_operator 0 "aarch64_equality_operator"
>> >> +            [(match_operand:VDQ_I 1 "register_operand")
>> >> +             (match_operand:VDQ_I 2 "aarch64_simd_reg_or_zero")])
>> >> +          (label_ref (match_operand 3 ""))
>> >> +          (pc)))]
>> >> +  "TARGET_SIMD"
>> >> +{
>> >> +  auto code = GET_CODE (operands[0]);
>> >> +  rtx tmp = operands[1];
>> >> +
>> >> +  /* If comparing against a non-zero vector we have to do a comparison first
>> >> +     so we can have a != 0 comparison with the result.  */
>> >> +  if (operands[2] != CONST0_RTX (<MODE>mode))
>> >> +    emit_insn (gen_vec_cmp<mode><mode> (tmp, operands[0], operands[1],
>> >> +                                 operands[2]));
>> >> +
>> >> +  /* For 64-bit vectors we need no reductions.  */
>> >> +  if (known_eq (128, GET_MODE_BITSIZE (<MODE>mode)))
>> >> +    {
>> >> +      /* Always reduce using a V4SI.  */
>> >> +      rtx reduc = gen_lowpart (V4SImode, tmp);
>> >> +      rtx res = gen_reg_rtx (V4SImode);
>> >> +      emit_insn (gen_aarch64_umaxpv4si (res, reduc, reduc));
>> >> +      emit_move_insn (tmp, gen_lowpart (<MODE>mode, res));
>> >> +    }
>> >> +
>> >> +  rtx val = gen_reg_rtx (DImode);
>> >> +  emit_move_insn (val, gen_lowpart (DImode, tmp));
>> >> +
>> >> +  rtx cc_reg = aarch64_gen_compare_reg (code, val, const0_rtx);
>> >> +  rtx cmp_rtx = gen_rtx_fmt_ee (code, DImode, cc_reg, const0_rtx);
>> >> +  emit_jump_insn (gen_condjump (cmp_rtx, cc_reg, operands[3]));
>> >> +  DONE;
>> >
>> > Are you sure this is correct for the operands[2] != const0_rtx case?
>> > It looks like it uses the same comparison code for the vector comparison
>> > and the scalar comparison.
>> >
>> > E.g. if the pattern is passed a comparison:
>> >
>> >   (eq (reg:V2SI x) (reg:V2SI y))
>> >
>> > it looks like we'd generate a CMEQ for the x and y, then branch
>> > when the DImode bitcast of the CMEQ result equals zero.  This means
>> > that we branch when no elements of x and y are equal, rather than
>> > when all elements of x and y are equal.
>> >
>> > E.g. for:
>> >
>> >    { 1, 2 } == { 1, 2 }
>> >
>> > CMEQ will produce { -1, -1 }, the scalar comparison will be -1 == 0,
>> > and the branch won't be taken.
>> >
>> > ISTM it would be easier for the operands[2] != const0_rtx case to use
>> > EOR instead of a comparison.  That gives a zero result if the input
>> > vectors are equal and a nonzero result if the input vectors are
>> > different.  We can then branch on the result using CODE and const0_rtx.
>> >
>> > (Hope I've got that right.)
>> >
>> > Maybe that also removes the need for patch 18.
>> 
>> Sorry, I forgot to say: we can't use operands[1] as a temporary,
>> since it's only an input to the pattern.  The EOR destination would
>> need to be a fresh register.
>
> I've updated the patch, but it doesn't help since cbranch doesn't really push
> comparisons in.  So we don't seem to ever get called with anything non-zero.

I suppose it won't trigger for the early-break stuff, since a scalar
== break condition wants:

  foo = a == b
  if (foo != 0)
    break

(break if one element equal) rather than:

  foo = a == b
  if (foo == -1)
    break

which is what would fold to:

  if (a == b)
    break

and so be a cbranch on (eq a b).  But keeping it as it was would probably
be storing up problems for later.
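
To make the any-vs-all distinction concrete, here's a rough scalar model
(purely illustrative, not anything the vectoriser actually emits; the
function names are made up).  mask stands for the lanewise result of
a == b, with each lane either all-ones or all-zeros:

  #include <stdint.h>

  /* Early-break condition: branch if *any* lane compared equal,
     i.e. the combined mask is nonzero.  */
  static int
  break_if_any_equal (uint32_t mask[2])
  {
    return (mask[0] | mask[1]) != 0;
  }

  /* A folded (eq a b) cbranch: branch only if *all* lanes compared
     equal, i.e. every lane of the mask is all-ones.  */
  static int
  break_if_all_equal (uint32_t mask[2])
  {
    return mask[0] == UINT32_MAX && mask[1] == UINT32_MAX;
  }

The vectorised early-break test is the first form, which is why the
expander only ever sees a != 0 comparison from that path.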

> That said, I'm not entirely convinced that the == case is correct, since ==
> means all bits equal rather than any bit set, so it needs to generate cbz
> instead of cbnz, and I'm not sure that's guaranteed.

I see you've changed it from:

+  rtx cc_reg = aarch64_gen_compare_reg (code, val, const0_rtx);
+  rtx cmp_rtx = gen_rtx_fmt_ee (code, DImode, cc_reg, const0_rtx);
+  emit_jump_insn (gen_condjump (cmp_rtx, cc_reg, operands[3]));

to:

+  emit_jump_insn (gen_cbranchdi4 (operands[0], val, CONST0_RTX (DImode),
+                                 operands[3]));

Was that to fix a specific problem?  The original looked OK to me
for that part (it was the vector comparison that I was asking about).

If we do keep the cbranchdi4, I think it's more obviously correct to
recreate operands[0] with the new comparison operands, even if it
happens to work without.
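
Concretely, something along these lines (untested sketch, reusing the
names already in the expander):

  /* Rebuild the comparison so that it refers to the DImode value we
     actually branch on, rather than to the original vector operands.  */
  operands[0] = gen_rtx_fmt_ee (code, VOIDmode, val, const0_rtx);
  emit_jump_insn (gen_cbranchdi4 (operands[0], val, CONST0_RTX (DImode),
                                  operands[3]));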

For the == case, both the condjump and cbranch versions will branch iff
all bits of val are zero, which is true iff the result of the EOR is zero,
which is true iff the vector operands were bitwise identical.  So it looks
like it should work.
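
As a sanity check, here's a C model of the 64-bit (V2SI) sequence the
updated expander would produce for a == b (again just illustrative,
with made-up names):

  #include <stdint.h>

  /* EOR the two vectors, view the result as a 64-bit scalar (the fmov),
     and branch when that scalar is zero (the cbz).  */
  static int
  eq_branch_taken (uint32_t a[2], uint32_t b[2])
  {
    uint32_t x0 = a[0] ^ b[0];
    uint32_t x1 = a[1] ^ b[1];
    uint64_t val = ((uint64_t) x1 << 32) | x0;
    return val == 0;
  }

For { 1, 2 } == { 1, 2 } the EOR result is all-zeros, so the branch is
taken; if any lane differs, val is nonzero and we fall through.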

Thanks,
Richard

> I do have a failing testcase with this, but I haven't yet tracked down whether
> the mid-end did the right thing.  I think there might be a similar issue in a
> match.pd simplification.
>
> Thoughts on the == case?
>
> Thanks,
> Tamar
>
> --- inline copy of patch ---
>
> diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md
> index c6f2d5828373f2a5272b9d1227bfe34365f9fd09..7b289b1fbec6b1f15fbf51b6c862bcf9a5588b6b 100644
> --- a/gcc/config/aarch64/aarch64-simd.md
> +++ b/gcc/config/aarch64/aarch64-simd.md
> @@ -3911,6 +3911,46 @@ (define_expand "vcond_mask_<mode><v_int_equiv>"
>    DONE;
>  })
>  
> +;; Patterns comparing two vectors and conditionally jump
> +
> +(define_expand "cbranch<mode>4"
> +  [(set (pc)
> +        (if_then_else
> +          (match_operator 0 "aarch64_equality_operator"
> +            [(match_operand:VDQ_I 1 "register_operand")
> +             (match_operand:VDQ_I 2 "aarch64_simd_reg_or_zero")])
> +          (label_ref (match_operand 3 ""))
> +          (pc)))]
> +  "TARGET_SIMD"
> +{
> +  auto code = GET_CODE (operands[0]);
> +  rtx tmp = operands[1];
> +
> +  /* If comparing against a non-zero vector we have to do a comparison first
> +     so we can have a != 0 comparison with the result.  */
> +  if (operands[2] != CONST0_RTX (<MODE>mode))
> +    {
> +      tmp = gen_reg_rtx (<MODE>mode);
> +      emit_insn (gen_xor<mode>3 (tmp, operands[1], operands[2]));
> +    }
> +
> +  /* For 64-bit vectors we need no reductions.  */
> +  if (known_eq (128, GET_MODE_BITSIZE (<MODE>mode)))
> +    {
> +      /* Always reduce using a V4SI.  */
> +      rtx reduc = gen_lowpart (V4SImode, tmp);
> +      rtx res = gen_reg_rtx (V4SImode);
> +      emit_insn (gen_aarch64_umaxpv4si (res, reduc, reduc));
> +      emit_move_insn (tmp, gen_lowpart (<MODE>mode, res));
> +    }
> +
> +  rtx val = gen_reg_rtx (DImode);
> +  emit_move_insn (val, gen_lowpart (DImode, tmp));
> +  emit_jump_insn (gen_cbranchdi4 (operands[0], val, CONST0_RTX (DImode),
> +                               operands[3]));
> +  DONE;
> +})
> +
>  ;; Patterns comparing two vectors to produce a mask.
>  
>  (define_expand "vec_cmp<mode><mode>"
> diff --git a/gcc/testsuite/gcc.target/aarch64/vect-early-break-cbranch.c b/gcc/testsuite/gcc.target/aarch64/vect-early-break-cbranch.c
> new file mode 100644
> index 0000000000000000000000000000000000000000..c0363c3787270507d7902bb2ac0e39faef63a852
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/aarch64/vect-early-break-cbranch.c
> @@ -0,0 +1,124 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O3" } */
> +/* { dg-final { check-function-bodies "**" "" "" { target lp64 } } } */
> +
> +#pragma GCC target "+nosve"
> +
> +#define N 640
> +int a[N] = {0};
> +int b[N] = {0};
> +
> +
> +/*
> +** f1:
> +**   ...
> +**   cmgt    v[0-9]+.4s, v[0-9]+.4s, #0
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f1 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] > 0)
> +     break;
> +    }
> +}
> +
> +/*
> +** f2:
> +**   ...
> +**   cmge    v[0-9]+.4s, v[0-9]+.4s, #0
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f2 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] >= 0)
> +     break;
> +    }
> +}
> +
> +/*
> +** f3:
> +**   ...
> +**   cmeq    v[0-9]+.4s, v[0-9]+.4s, #0
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f3 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] == 0)
> +     break;
> +    }
> +}
> +
> +/*
> +** f4:
> +**   ...
> +**   cmtst   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f4 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] != 0)
> +     break;
> +    }
> +}
> +
> +/*
> +** f5:
> +**   ...
> +**   cmlt    v[0-9]+.4s, v[0-9]+.4s, #0
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f5 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] < 0)
> +     break;
> +    }
> +}
> +
> +/*
> +** f6:
> +**   ...
> +**   cmle    v[0-9]+.4s, v[0-9]+.4s, #0
> +**   umaxp   v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
> +**   fmov    x[0-9]+, d[0-9]+
> +**   cbnz    x[0-9]+, \.L[0-9]+
> +**   ...
> +*/
> +void f6 ()
> +{
> +  for (int i = 0; i < N; i++)
> +    {
> +      b[i] += a[i];
> +      if (a[i] <= 0)
> +     break;
> +    }
> +}
