On Tue, Jul 30, 2019 at 12:01 PM Richard Sandiford
<richard.sandif...@arm.com> wrote:
>
> This patch extends the FMA handling in tree-ssa-math-opts.c so
> that it can cope with conditional multiplications as well as
> unconditional multiplications.  The addition or subtraction must then
> have the same condition as the multiplication (at least for now).
>
> E.g. we can currently fold:
>
>   (IFN_COND_ADD cond (mul x y) z fallback)
>     -> (IFN_COND_FMA cond x y z fallback)
>
> This patch also allows:
>
>   (IFN_COND_ADD cond (IFN_COND_MUL cond x y <whatever>) z fallback)
>     -> (IFN_COND_FMA cond x y z fallback)
>
> Tested on aarch64-linux-gnu, aarch64_be-elf and x86_64-linux-gnu.
> OK to install?

OK.

> Richard
>
>
> 2019-07-30  Richard Sandiford  <richard.sandif...@arm.com>
>
> gcc/
>         * tree-ssa-math-opts.c (convert_mult_to_fma): Add a mul_cond
>         parameter.  When nonnull, make sure that the addition or subtraction
>         has the same condition.
>         (math_opts_dom_walker::after_dom_children): Try convert_mult_to_fma
>         for CFN_COND_MUL too.
>
> gcc/testsuite/
>         * gcc.dg/vect/vect-cond-arith-7.c: New test.
>
> Index: gcc/tree-ssa-math-opts.c
> ===================================================================
> --- gcc/tree-ssa-math-opts.c    2019-07-30 10:51:22.000000000 +0100
> +++ gcc/tree-ssa-math-opts.c    2019-07-30 10:51:51.827405171 +0100
> @@ -3044,6 +3044,8 @@ last_fma_candidate_feeds_initial_phi (fm
>  /* Combine the multiplication at MUL_STMT with operands MULOP1 and MULOP2
>     with uses in additions and subtractions to form fused multiply-add
>     operations.  Returns true if successful and MUL_STMT should be removed.
> +   If MUL_COND is nonnull, the multiplication in MUL_STMT is conditional
> +   on MUL_COND, otherwise it is unconditional.
>
>     If STATE indicates that we are deferring FMA transformation, that means
>     that we do not produce FMAs for basic blocks which look like:
> @@ -3060,7 +3062,7 @@ last_fma_candidate_feeds_initial_phi (fm
>
>  static bool
>  convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
> -                    fma_deferring_state *state)
> +                    fma_deferring_state *state, tree mul_cond = NULL_TREE)
>  {
>    tree mul_result = gimple_get_lhs (mul_stmt);
>    tree type = TREE_TYPE (mul_result);
> @@ -3174,6 +3176,9 @@ convert_mult_to_fma (gimple *mul_stmt, t
>           return false;
>         }
>
> +      if (mul_cond && cond != mul_cond)
> +       return false;
> +
>        if (cond)
>         {
>           if (cond == result || else_value == result)
> @@ -3785,38 +3790,48 @@ math_opts_dom_walker::after_dom_children
>         }
>        else if (is_gimple_call (stmt))
>         {
> -         tree fndecl = gimple_call_fndecl (stmt);
> -         if (fndecl && gimple_call_builtin_p (stmt, BUILT_IN_NORMAL))
> +         switch (gimple_call_combined_fn (stmt))
>             {
> -             switch (DECL_FUNCTION_CODE (fndecl))
> +           CASE_CFN_POW:
> +             if (gimple_call_lhs (stmt)
> +                 && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
> +                 && real_equal (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
> +                                &dconst2)
> +                 && convert_mult_to_fma (stmt,
> +                                         gimple_call_arg (stmt, 0),
> +                                         gimple_call_arg (stmt, 0),
> +                                         &fma_state))
>                 {
> -               case BUILT_IN_POWF:
> -               case BUILT_IN_POW:
> -               case BUILT_IN_POWL:
> -                 if (gimple_call_lhs (stmt)
> -                     && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
> -                     && real_equal
> -                     (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
> -                      &dconst2)
> -                     && convert_mult_to_fma (stmt,
> -                                             gimple_call_arg (stmt, 0),
> -                                             gimple_call_arg (stmt, 0),
> -                                             &fma_state))
> -                   {
> -                     unlink_stmt_vdef (stmt);
> -                     if (gsi_remove (&gsi, true)
> -                         && gimple_purge_dead_eh_edges (bb))
> -                       *m_cfg_changed_p = true;
> -                     release_defs (stmt);
> -                     continue;
> -                   }
> -                 break;
> +                 unlink_stmt_vdef (stmt);
> +                 if (gsi_remove (&gsi, true)
> +                     && gimple_purge_dead_eh_edges (bb))
> +                   *m_cfg_changed_p = true;
> +                 release_defs (stmt);
> +                 continue;
> +               }
> +             break;
>
> -               default:;
> +           case CFN_COND_MUL:
> +             if (convert_mult_to_fma (stmt,
> +                                      gimple_call_arg (stmt, 1),
> +                                      gimple_call_arg (stmt, 2),
> +                                      &fma_state,
> +                                      gimple_call_arg (stmt, 0)))
> +
> +               {
> +                 gsi_remove (&gsi, true);
> +                 release_defs (stmt);
> +                 continue;
>                 }
> +             break;
> +
> +           case CFN_LAST:
> +             cancel_fma_deferring (&fma_state);
> +             break;
> +
> +           default:
> +             break;
>             }
> -         else
> -           cancel_fma_deferring (&fma_state);
>         }
>        gsi_next (&gsi);
>      }
> Index: gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c
> ===================================================================
> --- /dev/null   2019-07-30 08:53:31.317691683 +0100
> +++ gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c       2019-07-30 10:51:51.823405201 +0100
> @@ -0,0 +1,60 @@
> +/* { dg-require-effective-target scalar_all_fma } */
> +/* { dg-additional-options "-fdump-tree-optimized -ffp-contract=fast" } */
> +
> +#include "tree-vect.h"
> +
> +#define N (VECTOR_BITS * 11 / 64 + 3)
> +
> +#define DEF(INV)                                       \
> +  void __attribute__ ((noipa))                         \
> +  f_##INV (double *restrict a, double *restrict b,     \
> +          double *restrict c, double *restrict d)      \
> +  {                                                    \
> +    for (int i = 0; i < N; ++i)                                \
> +      {                                                        \
> +       double mb = (INV & 1 ? -b[i] : b[i]);           \
> +       double mc = c[i];                               \
> +       double md = (INV & 2 ? -d[i] : d[i]);           \
> +       a[i] = b[i] < 10 ? mb * mc + md : 10.0;         \
> +      }                                                        \
> +  }
> +
> +#define TEST(INV)                                      \
> +  {                                                    \
> +    f_##INV (a, b, c, d);                              \
> +    for (int i = 0; i < N; ++i)                                \
> +      {                                                        \
> +       double mb = (INV & 1 ? -b[i] : b[i]);           \
> +       double mc = c[i];                               \
> +       double md = (INV & 2 ? -d[i] : d[i]);           \
> +       double fma = __builtin_fma (mb, mc, md);        \
> +       if (a[i] != (i % 17 < 10 ? fma : 10.0))         \
> +         __builtin_abort ();                           \
> +       asm volatile ("" ::: "memory");                 \
> +      }                                                        \
> +  }
> +
> +#define FOR_EACH_INV(T) \
> +  T (0) T (1) T (2) T (3)
> +
> +FOR_EACH_INV (DEF)
> +
> +int
> +main (void)
> +{
> +  double a[N], b[N], c[N], d[N];
> +  for (int i = 0; i < N; ++i)
> +    {
> +      b[i] = i % 17;
> +      c[i] = i % 9 + 11;
> +      d[i] = i % 13 + 14;
> +      asm volatile ("" ::: "memory");
> +    }
> +  FOR_EACH_INV (TEST)
> +  return 0;
> +}
> +
> +/* { dg-final { scan-tree-dump-times { = \.COND_FMA } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FMS } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FNMA } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FNMS } 1 "optimized" { target vect_double_cond_arith } } } */

Reply via email to