From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>

gcc/ChangeLog:

        * internal-fn.cc (expand_partial_store_optab_fn): Add
        LEN_MASK_{LOAD,STORE} vectorizer support.
        (internal_load_fn_p): Ditto.
        (internal_store_fn_p): Ditto.
        (internal_fn_mask_index): Ditto.
        (internal_fn_stored_value_index): Ditto.
        (internal_len_load_store_bias): Ditto.
        * optabs-query.cc (can_vec_mask_load_store_p): Ditto.
        (get_len_load_store_mode): Ditto.
        * tree-vect-stmts.cc (check_load_store_for_partial_vectors): Ditto.
        (get_all_ones_mask): New function.
        (partial_or_mask_vector_ifn): New function.
        (vectorizable_store): Add LEN_MASK_{LOAD,STORE} vectorizer support.
        (vectorizable_load): Ditto.

---
 gcc/internal-fn.cc     |  35 +++++-
 gcc/optabs-query.cc    |  25 ++++-
 gcc/tree-vect-stmts.cc | 234 ++++++++++++++++++++++++++++++-----------
 3 files changed, 227 insertions(+), 67 deletions(-)

diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
index c911ae790cb..e10c21de5f1 100644
--- a/gcc/internal-fn.cc
+++ b/gcc/internal-fn.cc
@@ -2949,7 +2949,7 @@ expand_partial_load_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
  * OPTAB.  */
 
 static void
-expand_partial_store_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
+expand_partial_store_optab_fn (internal_fn ifn, gcall *stmt, convert_optab 
optab)
 {
   class expand_operand ops[5];
   tree type, lhs, rhs, maskt, biast;
@@ -2957,7 +2957,7 @@ expand_partial_store_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
   insn_code icode;
 
   maskt = gimple_call_arg (stmt, 2);
-  rhs = gimple_call_arg (stmt, 3);
+  rhs = gimple_call_arg (stmt, internal_fn_stored_value_index (ifn));
   type = TREE_TYPE (rhs);
   lhs = expand_call_mem_ref (type, stmt, 0);
 
@@ -4435,6 +4435,7 @@ internal_load_fn_p (internal_fn fn)
     case IFN_GATHER_LOAD:
     case IFN_MASK_GATHER_LOAD:
     case IFN_LEN_LOAD:
+    case IFN_LEN_MASK_LOAD:
       return true;
 
     default:
@@ -4455,6 +4456,7 @@ internal_store_fn_p (internal_fn fn)
     case IFN_SCATTER_STORE:
     case IFN_MASK_SCATTER_STORE:
     case IFN_LEN_STORE:
+    case IFN_LEN_MASK_STORE:
       return true;
 
     default:
@@ -4494,6 +4496,10 @@ internal_fn_mask_index (internal_fn fn)
     case IFN_MASK_STORE_LANES:
       return 2;
 
+    case IFN_LEN_MASK_LOAD:
+    case IFN_LEN_MASK_STORE:
+      return 3;
+
     case IFN_MASK_GATHER_LOAD:
     case IFN_MASK_SCATTER_STORE:
       return 4;
@@ -4519,6 +4525,9 @@ internal_fn_stored_value_index (internal_fn fn)
     case IFN_LEN_STORE:
       return 3;
 
+    case IFN_LEN_MASK_STORE:
+      return 4;
+
     default:
       return -1;
     }
@@ -4583,13 +4592,31 @@ internal_len_load_store_bias (internal_fn ifn, machine_mode mode)
 {
   optab optab = direct_internal_fn_optab (ifn);
   insn_code icode = direct_optab_handler (optab, mode);
+  int bias_argno = 3;
+  if (icode == CODE_FOR_nothing)
+    {
+      machine_mode mask_mode
+       = targetm.vectorize.get_mask_mode (mode).require ();
+      if (ifn == IFN_LEN_LOAD)
+       {
+         /* Try LEN_MASK_LOAD.  */
+         optab = direct_internal_fn_optab (IFN_LEN_MASK_LOAD);
+       }
+      else
+       {
+         /* Try LEN_MASK_STORE.  */
+         optab = direct_internal_fn_optab (IFN_LEN_MASK_STORE);
+       }
+      icode = convert_optab_handler (optab, mode, mask_mode);
+      bias_argno = 4;
+    }
 
   if (icode != CODE_FOR_nothing)
     {
       /* For now we only support biases of 0 or -1.  Try both of them.  */
-      if (insn_operand_matches (icode, 3, GEN_INT (0)))
+      if (insn_operand_matches (icode, bias_argno, GEN_INT (0)))
        return 0;
-      if (insn_operand_matches (icode, 3, GEN_INT (-1)))
+      if (insn_operand_matches (icode, bias_argno, GEN_INT (-1)))
        return -1;
     }
 
diff --git a/gcc/optabs-query.cc b/gcc/optabs-query.cc
index 276f8408dd7..4394d391200 100644
--- a/gcc/optabs-query.cc
+++ b/gcc/optabs-query.cc
@@ -566,11 +566,14 @@ can_vec_mask_load_store_p (machine_mode mode,
                           bool is_load)
 {
   optab op = is_load ? maskload_optab : maskstore_optab;
+  optab len_op = is_load ? len_maskload_optab : len_maskstore_optab;
   machine_mode vmode;
 
   /* If mode is vector mode, check it directly.  */
   if (VECTOR_MODE_P (mode))
-    return convert_optab_handler (op, mode, mask_mode) != CODE_FOR_nothing;
+    return convert_optab_handler (op, mode, mask_mode) != CODE_FOR_nothing
+          || convert_optab_handler (len_op, mode, mask_mode)
+               != CODE_FOR_nothing;
 
   /* Otherwise, return true if there is some vector mode with
      the mask load/store supported.  */
@@ -584,7 +587,9 @@ can_vec_mask_load_store_p (machine_mode mode,
   vmode = targetm.vectorize.preferred_simd_mode (smode);
   if (VECTOR_MODE_P (vmode)
       && targetm.vectorize.get_mask_mode (vmode).exists (&mask_mode)
-      && convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing)
+      && (convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing
+         || convert_optab_handler (len_op, vmode, mask_mode)
+              != CODE_FOR_nothing))
     return true;
 
   auto_vector_modes vector_modes;
@@ -592,7 +597,9 @@ can_vec_mask_load_store_p (machine_mode mode,
   for (machine_mode base_mode : vector_modes)
     if (related_vector_mode (base_mode, smode).exists (&vmode)
        && targetm.vectorize.get_mask_mode (vmode).exists (&mask_mode)
-       && convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing)
+       && (convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing
+           || convert_optab_handler (len_op, vmode, mask_mode)
+                != CODE_FOR_nothing))
       return true;
   return false;
 }
@@ -608,17 +615,27 @@ opt_machine_mode
 get_len_load_store_mode (machine_mode mode, bool is_load)
 {
   optab op = is_load ? len_load_optab : len_store_optab;
+  optab masked_op = is_load ? len_maskload_optab : len_maskstore_optab;
   gcc_assert (VECTOR_MODE_P (mode));
 
   /* Check if length in lanes supported for this mode directly.  */
   if (direct_optab_handler (op, mode))
     return mode;
 
+  /* Check if length in lanes supported by len_maskload/store.  */
+  machine_mode mask_mode;
+  if (targetm.vectorize.get_mask_mode (mode).exists (&mask_mode)
+      && convert_optab_handler (masked_op, mode, mask_mode) != 
CODE_FOR_nothing)
+    return mode;
+
   /* Check if length in bytes supported for same vector size VnQI.  */
   machine_mode vmode;
   poly_uint64 nunits = GET_MODE_SIZE (mode);
   if (related_vector_mode (mode, QImode, nunits).exists (&vmode)
-      && direct_optab_handler (op, vmode))
+      && (direct_optab_handler (op, vmode)
+         || (targetm.vectorize.get_mask_mode (vmode).exists (&mask_mode)
+             && convert_optab_handler (masked_op, vmode, mask_mode)
+                  != CODE_FOR_nothing)))
     return vmode;
 
   return opt_machine_mode ();
diff --git a/gcc/tree-vect-stmts.cc b/gcc/tree-vect-stmts.cc
index 056a0ecb2be..3cb72d6017f 100644
--- a/gcc/tree-vect-stmts.cc
+++ b/gcc/tree-vect-stmts.cc
@@ -1819,16 +1819,8 @@ check_load_store_for_partial_vectors (loop_vec_info loop_vinfo, tree vectype,
   poly_uint64 nunits = TYPE_VECTOR_SUBPARTS (vectype);
   poly_uint64 vf = LOOP_VINFO_VECT_FACTOR (loop_vinfo);
   machine_mode mask_mode;
-  bool using_partial_vectors_p = false;
-  if (targetm.vectorize.get_mask_mode (vecmode).exists (&mask_mode)
-      && can_vec_mask_load_store_p (vecmode, mask_mode, is_load))
-    {
-      nvectors = group_memory_nvectors (group_size * vf, nunits);
-      vect_record_loop_mask (loop_vinfo, masks, nvectors, vectype, 
scalar_mask);
-      using_partial_vectors_p = true;
-    }
-
   machine_mode vmode;
+  bool using_partial_vectors_p = false;
   if (get_len_load_store_mode (vecmode, is_load).exists (&vmode))
     {
       nvectors = group_memory_nvectors (group_size * vf, nunits);
@@ -1837,6 +1829,13 @@ check_load_store_for_partial_vectors (loop_vec_info loop_vinfo, tree vectype,
       vect_record_loop_len (loop_vinfo, lens, nvectors, vectype, factor);
       using_partial_vectors_p = true;
     }
+  else if (targetm.vectorize.get_mask_mode (vecmode).exists (&mask_mode)
+          && can_vec_mask_load_store_p (vecmode, mask_mode, is_load))
+    {
+      nvectors = group_memory_nvectors (group_size * vf, nunits);
+      vect_record_loop_mask (loop_vinfo, masks, nvectors, vectype, 
scalar_mask);
+      using_partial_vectors_p = true;
+    }
 
   if (!using_partial_vectors_p)
     {
@@ -2809,6 +2808,58 @@ vect_build_zero_merge_argument (vec_info *vinfo,
   return vect_init_vector (vinfo, stmt_info, merge, vectype, NULL);
 }
 
+/* Return an all-true mask for an access of vector mode VMODE.  The lane
+   count is taken from VMODE because the target mask mode may be scalar.  */
+static tree
+get_all_ones_mask (machine_mode vmode)
+{
+  machine_mode maskmode = targetm.vectorize.get_mask_mode (vmode).require ();
+  poly_uint64 nunits = GET_MODE_NUNITS (vmode);
+  tree masktype = build_truth_vector_type_for_mode (nunits, maskmode);
+  return constant_boolean_node (true, masktype);
+}
+
+/* Return the internal function to use for a partial-vector contiguous
+   access of vector mode VECMODE.  IS_LOAD is true for loads and false
+   for stores.
+
+   For loads the candidates are IFN_LEN_MASK_LOAD, IFN_LEN_LOAD and
+   IFN_MASK_LOAD; for stores IFN_LEN_MASK_STORE, IFN_LEN_STORE and
+   IFN_MASK_STORE.  They are tried in that order and the first one the
+   target supports is returned.  Return IFN_LAST if none is supported,
+   so that callers can query unconditionally without an ICE.  */
+
+static internal_fn
+partial_or_mask_vector_ifn (machine_mode vecmode, bool is_load)
+{
+  machine_mode maskmode;
+  bool have_mask_mode
+    = targetm.vectorize.get_mask_mode (vecmode).exists (&maskmode);
+
+  /* Check the combined pattern directly: separate len_load/len_store and
+     maskload/maskstore patterns do not imply that a len_maskload or
+     len_maskstore pattern exists.  */
+  optab len_mask_op = is_load ? len_maskload_optab : len_maskstore_optab;
+  if (have_mask_mode
+      && convert_optab_handler (len_mask_op, vecmode, maskmode)
+           != CODE_FOR_nothing)
+    return is_load ? IFN_LEN_MASK_LOAD : IFN_LEN_MASK_STORE;
+
+  /* Length-only access; get_len_load_store_mode also accepts a VnQI
+     mode of the same size.  */
+  if (get_len_load_store_mode (vecmode, is_load).exists ())
+    return is_load ? IFN_LEN_LOAD : IFN_LEN_STORE;
+
+  /* Mask-only access.  */
+  if (have_mask_mode
+      && can_vec_mask_load_store_p (vecmode, maskmode, is_load))
+    return is_load ? IFN_MASK_LOAD : IFN_MASK_STORE;
+
+  /* No partial-vector load/store support for VECMODE; returning IFN_LAST
+     rather than asserting keeps unconditional queries safe.  */
+  return IFN_LAST;
+}
+
 /* Build a gather load call while vectorizing STMT_INFO.  Insert new
    instructions before GSI and add them to VEC_STMT.  GS_INFO describes
    the gather load operation.  If the load is conditional, MASK is the
@@ -8945,30 +8996,46 @@ vectorizable_store (vec_info *vinfo,
                }
 
              /* Arguments are ready.  Create the new vector stmt.  */
-             if (final_mask)
-               {
-                 tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
-                 gcall *call
-                   = gimple_build_call_internal (IFN_MASK_STORE, 4,
-                                                 dataref_ptr, ptr,
-                                                 final_mask, vec_oprnd);
-                 gimple_call_set_nothrow (call, true);
-                 vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
-                 new_stmt = call;
-               }
-             else if (loop_lens)
+             internal_fn partial_ifn
+               = partial_or_mask_vector_ifn (TYPE_MODE (vectype), false);
+             tree final_len = NULL_TREE;
+             machine_mode vmode = TYPE_MODE (vectype);
+             machine_mode new_vmode;
+
+             /* Produce 'len' argument.  */
+             if (loop_lens)
                {
-                 machine_mode vmode = TYPE_MODE (vectype);
                  opt_machine_mode new_ovmode
                    = get_len_load_store_mode (vmode, false);
-                 machine_mode new_vmode = new_ovmode.require ();
+                 new_vmode = new_ovmode.require ();
                  unsigned factor
                    = (new_ovmode == vmode) ? 1 : GET_MODE_UNIT_SIZE (vmode);
-                 tree final_len
-                   = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
-                                        vec_num * ncopies, vectype,
-                                        vec_num * j + i, factor);
-                 tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
+                 final_len = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
+                                                vec_num * ncopies, vectype,
+                                                vec_num * j + i, factor);
+               }
+             if (partial_ifn == IFN_LEN_MASK_STORE)
+               {
+                 if (!final_len)
+                   {
+                     /* Pass VF value to 'len' argument of LEN_MASK_STORE if
+                      * LOOP_LENS is invalid.  */
+                     tree iv_type = LOOP_VINFO_RGROUP_IV_TYPE (loop_vinfo);
+                     final_len
+                       = build_int_cst (iv_type,
+                                        TYPE_VECTOR_SUBPARTS (vectype));
+                   }
+                 if (!final_mask)
+                   {
+                     /* Pass all ones value to 'mask' argument of
+                      * LEN_MASK_STORE if final_mask is invalid.  */
+                     final_mask = get_all_ones_mask (vmode);
+                   }
+               }
+
+             tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
+             if (final_len)
+               {
                  /* Need conversion if it's wrapped with VnQI.  */
                  if (vmode != new_vmode)
                    {
@@ -8987,14 +9054,32 @@ vectorizable_store (vec_info *vinfo,
                      vec_oprnd = var;
                    }
 
-                 signed char biasval =
-                   LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
+                 signed char biasval
+                   = LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
 
                  tree bias = build_int_cst (intQI_type_node, biasval);
+                 gcall *call;
+
+                 if (final_mask)
+                   call = gimple_build_call_internal (IFN_LEN_MASK_STORE, 6,
+                                                      dataref_ptr, ptr,
+                                                      final_len, final_mask,
+                                                      vec_oprnd, bias);
+                 else
+                   call
+                     = gimple_build_call_internal (IFN_LEN_STORE, 5,
+                                                   dataref_ptr, ptr, final_len,
+                                                   vec_oprnd, bias);
+                 gimple_call_set_nothrow (call, true);
+                 vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
+                 new_stmt = call;
+               }
+             else if (final_mask)
+               {
                  gcall *call
-                   = gimple_build_call_internal (IFN_LEN_STORE, 5, dataref_ptr,
-                                                 ptr, final_len, vec_oprnd,
-                                                 bias);
+                   = gimple_build_call_internal (IFN_MASK_STORE, 4,
+                                                 dataref_ptr, ptr,
+                                                 final_mask, vec_oprnd);
                  gimple_call_set_nothrow (call, true);
                  vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
                  new_stmt = call;
@@ -10304,45 +10389,66 @@ vectorizable_load (vec_info *vinfo,
                                              align, misalign);
                    align = least_bit_hwi (misalign | align);
 
-                   if (final_mask)
-                     {
-                       tree ptr = build_int_cst (ref_type,
-                                                 align * BITS_PER_UNIT);
-                       gcall *call
-                         = gimple_build_call_internal (IFN_MASK_LOAD, 3,
-                                                       dataref_ptr, ptr,
-                                                       final_mask);
-                       gimple_call_set_nothrow (call, true);
-                       new_stmt = call;
-                       data_ref = NULL_TREE;
-                     }
-                   else if (loop_lens && memory_access_type != VMAT_INVARIANT)
+                   internal_fn partial_ifn
+                     = partial_or_mask_vector_ifn (TYPE_MODE (vectype), true);
+                   tree final_len = NULL_TREE;
+                   machine_mode vmode = TYPE_MODE (vectype);
+                   machine_mode new_vmode;
+
+                   /* Produce 'len' argument.  */
+                   if (loop_lens)
                      {
-                       machine_mode vmode = TYPE_MODE (vectype);
                        opt_machine_mode new_ovmode
-                         = get_len_load_store_mode (vmode, true);
-                       machine_mode new_vmode = new_ovmode.require ();
+                         = get_len_load_store_mode (vmode, false);
+                       new_vmode = new_ovmode.require ();
                        unsigned factor = (new_ovmode == vmode)
                                            ? 1
                                            : GET_MODE_UNIT_SIZE (vmode);
-                       tree final_len
+                       final_len
                          = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
                                               vec_num * ncopies, vectype,
                                               vec_num * j + i, factor);
-                       tree ptr
-                         = build_int_cst (ref_type, align * BITS_PER_UNIT);
+                     }
+                   if (partial_ifn == IFN_LEN_MASK_LOAD)
+                     {
+                       if (!final_len)
+                         {
+                           /* Pass VF value to 'len' argument of LEN_MASK_STORE
+                            * if LOOP_LENS is invalid.  */
+                           tree iv_type
+                             = LOOP_VINFO_RGROUP_IV_TYPE (loop_vinfo);
+                           final_len
+                             = build_int_cst (iv_type,
+                                              TYPE_VECTOR_SUBPARTS (vectype));
+                         }
+                       if (!final_mask)
+                         {
+                           /* Pass all ones value to 'mask' argument of
+                            * LEN_MASK_STORE if final_mask is invalid.  */
+                           final_mask = get_all_ones_mask (vmode);
+                         }
+                     }
+
+                   tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
+                   if (final_len && memory_access_type != VMAT_INVARIANT)
+                     {
+                       gcall *call;
 
                        tree qi_type = unsigned_intQI_type_node;
 
-                       signed char biasval =
-                         LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
+                       signed char biasval
+                         = LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
 
                        tree bias = build_int_cst (intQI_type_node, biasval);
-
-                       gcall *call
-                         = gimple_build_call_internal (IFN_LEN_LOAD, 4,
-                                                       dataref_ptr, ptr,
-                                                       final_len, bias);
+                       if (final_mask)
+                         call = gimple_build_call_internal (IFN_LEN_MASK_LOAD,
+                                                            5, dataref_ptr,
+                                                            ptr, final_len,
+                                                            final_mask, bias);
+                       else
+                         call = gimple_build_call_internal (IFN_LEN_LOAD, 4,
+                                                            dataref_ptr, ptr,
+                                                            final_len, bias);
                        gimple_call_set_nothrow (call, true);
                        new_stmt = call;
                        data_ref = NULL_TREE;
@@ -10363,6 +10469,16 @@ vectorizable_load (vec_info *vinfo,
                                                     VIEW_CONVERT_EXPR, op);
                          }
                      }
+                   else if (final_mask)
+                     {
+                       gcall *call
+                         = gimple_build_call_internal (IFN_MASK_LOAD, 3,
+                                                       dataref_ptr, ptr,
+                                                       final_mask);
+                       gimple_call_set_nothrow (call, true);
+                       new_stmt = call;
+                       data_ref = NULL_TREE;
+                     }
                    else
                      {
                        tree ltype = vectype;
-- 
2.36.1

Reply via email to