From: Pan Li <pan2...@intel.com>

Update in v3:

* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.

Update in v2:

* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.

Original log:

This patch adds the framework to support the RVV overloaded intrinsic
API in riscv-xxx-xxx-gcc, as is already done for riscv-xxx-xxx-g++.

It is built on the TARGET_RESOLVE_OVERLOADED_BUILTIN hook, in the
following steps:

* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.

We validated this framework with the vmv_v intrinsic APIs, and will add
support for more intrinsic APIs in follow-up patches.

gcc/ChangeLog:

        * config/riscv/riscv-c.cc
        (riscv_resolve_overloaded_builtin): New function for the hook.
        (riscv_register_pragmas): Register the hook.
        * config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
        * config/riscv/riscv-vector-builtins-shapes.cc (build_one):
        Register overloaded function.
        (struct overloaded_base): New struct for overloaded shape.
        (struct non_overloaded_base): New struct for non overloaded shape.
        (struct move_def): Inherit overloaded shape.
        * config/riscv/riscv-vector-builtins.cc
        (function_base::get_non_overloaded_instance): New API impl.
        (function_builder::add_function): Add overloaded arg.
        (function_resolver::function_resolver): New constructor.
        (function_builder::add_overloaded_function): New API impl.
        (function_resolver::resolve): Ditto.
        (function_resolver::lookup): Ditto.
        (function_resolver::get_sub_code): Ditto.
        (resolve_overloaded_builtin): New function impl.
        * config/riscv/riscv-vector-builtins.h:
        (class function_resolver): New class.

gcc/testsuite/ChangeLog:

        * gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
        * gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
        * gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.

Signed-off-by: Pan Li <pan2...@intel.com>
---
 gcc/config/riscv/riscv-c.cc                   |  36 ++++
 gcc/config/riscv/riscv-protos.h               |   1 +
 .../riscv/riscv-vector-builtins-shapes.cc     |  20 ++-
 gcc/config/riscv/riscv-vector-builtins.cc     | 155 +++++++++++++++++-
 gcc/config/riscv/riscv-vector-builtins.h      |  36 +++-
 .../riscv/rvv/base/overloaded_rv32_vmv_v.c    |   8 +
 .../riscv/rvv/base/overloaded_rv64_vmv_v.c    |   8 +
 .../riscv/rvv/base/overloaded_vmv_v.h         |  27 +++
 8 files changed, 288 insertions(+), 3 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h

diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> 
arg_loc, tree fndecl,
   gcc_unreachable ();
 }
 
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN.  For a call at
+   UNCAST_LOCATION to FNDECL with arguments UNCAST_ARGLIST, return a new
+   call to the non-overloaded builtin chosen by the vector resolver, or
+   NULL_TREE to let the front end handle the call to FNDECL as written.
+   Only vector builtins are resolved; general builtins yield NULL_TREE.  */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+                                  void *uncast_arglist)
+{
+  location_t loc = (location_t) uncast_location;
+  unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+
+  switch (code & RISCV_BUILTIN_CLASS)
+    {
+    case RISCV_BUILTIN_GENERAL:
+      /* General builtins have no overloaded forms.  */
+      return NULL_TREE;
+    case RISCV_BUILTIN_VECTOR:
+      break;
+    default:
+      gcc_unreachable ();
+    }
+
+  /* The argument vector may be null (e.g. for a call without arguments);
+     substitute an empty one so the resolver always sees a vector.  */
+  vec<tree, va_gc> empty = {};
+  vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+  if (arglist == NULL)
+    arglist = &empty;
+
+  unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+  tree new_fndecl
+    = riscv_vector::resolve_overloaded_builtin (loc, subcode, arglist);
+  if (new_fndecl == NULL_TREE)
+    return NULL_TREE;
+
+  return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+                                  fndecl);
+}
+
 /* Implement REGISTER_TARGET_PRAGMAS.  */

 void
 riscv_register_pragmas (void)
 {
+  /* Route calls to overloaded RVV intrinsics through
+     riscv_resolve_overloaded_builtin so they are mapped onto the
+     corresponding non-overloaded builtins.  */
+  targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
   targetm.check_builtin_call = riscv_check_builtin_call;
+
   c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
 }
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, 
gimple_stmt_iterator *, gcall *);
 rtx expand_builtin (unsigned int, tree, rtx);
 bool check_builtin_call (location_t, vec<location_t>, unsigned int,
                           tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
 bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
 bool legitimize_move (rtx, rtx);
 void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc 
b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info 
&group,
     group.ops_infos.types[vec_type_idx].index);
   b.allocate_argument_types (function_instance, argument_types);
   b.apply_predication (function_instance, return_type, argument_types);
+
+  b.add_overloaded_function (function_instance, *group.shape);
   b.add_unique_function (function_instance, (*group.shape), return_type,
                         argument_types);
 }
@@ -87,6 +89,22 @@ struct build_base : public function_shape
   }
 };
 
+/* Base class for shapes that have an overloaded form.  An overloaded call
+   resolves to whatever non-overloaded function function_resolver::lookup
+   finds for the call's arguments.  */
+struct overloaded_base : public build_base
+{
+  tree resolve (function_resolver &resolver) const override
+  {
+    tree fndecl = resolver.lookup ();
+    return fndecl;
+  }
+};
+
+/* Base class for shapes that have no overloaded form.  Asking such a shape
+   to resolve an overloaded call is treated as an internal error.  */
+struct non_overloaded_base : public build_base
+{
+  tree resolve (function_resolver &) const override { gcc_unreachable (); }
+};
+
 /* vsetvl_def class.  */
 struct vsetvl_def : public build_base
 {
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
 };
 
 /* move_def class. Handle vmv.v.v/vmv.v.x.  */
-struct move_def : public build_base
+struct move_def : public overloaded_base
 {
   char *get_name (function_builder &b, const function_instance &instance,
                  bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc 
b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..4f6fbdc3e28 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
 
   /* The decl itself.  */
   tree GTY ((skip)) decl;
+
+  /* True if the decl represents an overloaded function that needs to be
+     resolved by function_resolver.  */
+  bool overloaded_p;
 };
 
 /* Hash traits for registered_function.  */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
   return false;
 }
 
+/* Return true if argument ARG of an overloaded call can be passed to a
+   parameter of type EXPECTED.  A vector parameter requires a vector
+   argument with the same mode, the same number of elements and the same
+   element type; comparing the element type is what tells apart types that
+   share a machine mode, such as vint32m1_t and vuint32m1_t.  Any other
+   parameter only requires the argument to have the same mode.  Erroneous
+   arguments never match (and must not reach TYPE_MODE).  */
+static bool
+overloaded_arg_matches_p (tree expected, tree arg)
+{
+  if (arg == error_mark_node)
+    return false;
+
+  tree actual = TREE_TYPE (arg);
+  if (actual == error_mark_node)
+    return false;
+
+  if (!VECTOR_TYPE_P (expected))
+    return TYPE_MODE (expected) == TYPE_MODE (actual);
+
+  return (VECTOR_TYPE_P (actual)
+          && TYPE_MODE (expected) == TYPE_MODE (actual)
+          && known_eq (TYPE_VECTOR_SUBPARTS (expected),
+                       TYPE_VECTOR_SUBPARTS (actual))
+          && (TYPE_MAIN_VARIANT (TREE_TYPE (expected))
+              == TYPE_MAIN_VARIANT (TREE_TYPE (actual))));
+}
+
+/* Return the non-overloaded function instance that a call with arguments
+   ARGLIST to the overloaded function with sub code CODE resolves to, or
+   NULL if there is none.
+
+   After the overloaded functions have been registered, the registered
+   functions table may look like:
+
+   +--------+---------------------------+-------------------+
+   | index  | name                      | kind              |
+   +--------+---------------------------+-------------------+
+   | 124733 | __riscv_vmv_v             | Overloaded        | <- Hook fun code
+   +--------+---------------------------+-------------------+
+   | 124735 | __riscv_vmv_v_v_i8mf8     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124737 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124739 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124741 | __riscv_vmv_v_v_i8mf4     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124743 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124745 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124747 | __riscv_vmv_v_v_i8mf2     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124749 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124751 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124753 | __riscv_vmv_v_v_i8m1      | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124755 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+
+   When an overloaded call is resolved through the hook, CODE is always
+   the function code of the first entry of the API group (__riscv_vmv_v in
+   the table above).  The search starts at that index.  The table also holds
+   the entries of every other intrinsic and policy variant, so only
+   non-overloaded entries that share the function base and predication of
+   the entry at CODE are candidates.  The first candidate whose parameters
+   match ARGLIST exactly, both in number and argument by argument (see
+   overloaded_arg_matches_p), is returned.  */
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+                                            vec<tree, va_gc> &arglist) const
+{
+  unsigned int code_limit = vec_safe_length (registered_functions);
+
+  if (code >= code_limit)
+    return NULL;
+
+  const function_instance &overloaded
+    = (*registered_functions)[code]->instance;
+
+  for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+    {
+      registered_function *rfun = (*registered_functions)[fun_code];
+      const function_instance &instance = rfun->instance;
+
+      /* Skip the overloaded/placeholder entries and anything that belongs
+         to a different intrinsic or policy variant.  */
+      if (rfun->overloaded_p
+          || instance.base != overloaded.base
+          || instance.pred != overloaded.pred)
+        continue;
+
+      const rvv_arg_type_info *args = instance.op_info->args;
+      unsigned k;
+
+      for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+        if (k >= arglist.length ()
+            || !overloaded_arg_matches_p (instance.get_arg_type (k),
+                                          arglist[k]))
+          break;
+
+      /* Every parameter matched and no argument is left over.  */
+      if (args[k].base_type == NUM_BASE_TYPES && k == arglist.length ())
+        return &rfun->instance;
+    }
+
+  return NULL;
+}
+
 function_builder::function_builder ()
 {
   m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance 
&instance)
 registered_function &
 function_builder::add_function (const function_instance &instance,
                                const char *name, tree fntype, tree attrs,
-                               bool placeholder_p)
+                               bool placeholder_p,
+                               bool overloaded_p = false)
 {
   unsigned int code = vec_safe_length (registered_functions);
   code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance 
&instance,
   registered_function &rfn = *ggc_alloc<registered_function> ();
   rfn.instance = instance;
   rfn.decl = decl;
+  rfn.overloaded_p = overloaded_p;
   vec_safe_push (registered_functions, &rfn);
 
   return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const 
function_instance &instance,
   obstack_free (&m_string_obstack, name);
 }
 
+/* Register the overloaded form of INSTANCE under the name that SHAPE
+   builds for it.  Do nothing if check_required_extensions rejects
+   INSTANCE or if SHAPE provides no overloaded name for it.  The entry is
+   added as a placeholder when m_direct_overloads is set.  */
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+                                           const function_shape *shape)
+{
+  if (!check_required_extensions (instance))
+    return;
+
+  char *overload_name = shape->get_name (*this, instance, true);
+  if (!overload_name)
+    return;
+
+  /* Give the overloaded function a void (void) type so that its decl
+     cannot conflict with the non-overloaded APIs.  */
+  tree fntype = build_function_type (void_type_node, void_list_node);
+  add_function (instance, overload_name, fntype, NULL_TREE,
+                m_direct_overloads, true);
+  obstack_free (&m_string_obstack, overload_name);
+}
+
 function_call_info::function_call_info (location_t location_in,
                                        const function_instance &instance_in,
                                        tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
     m_nargs (nargs), m_args (args)
 {}
 
+/* Construct a resolver for a call at LOC to the overloaded function
+   DECL, whose registered instance is INST and whose actual arguments
+   are ARGS.  ARGS is held by reference.  */
+function_resolver::function_resolver (location_t loc,
+                                      const function_instance &inst,
+                                      tree decl,
+                                      vec<tree, va_gc> &args)
+  : function_call_info (loc, inst, decl),
+    m_arglist (args)
+{
+}
+
 /* Report that LOCATION has a call to FNDECL in which argument ARGNO
    was not an integer constant expression.  ARGNO counts from zero.  */
 void
@@ -3967,6 +4071,39 @@ function_checker::check ()
   return shape->check (*this);
 }
 
+/* Return the function code of fndecl with the RISCV_BUILTIN_SHIFT class
+   bits shifted out; this is the index used for registered_functions.  */
+unsigned int
+function_resolver::get_sub_code ()
+{
+  return DECL_MD_FUNCTION_CODE (fndecl) >> RISCV_BUILTIN_SHIFT;
+}
+
+/* Resolve the overloaded call by deferring to the resolve hook of the
+   function's shape.  Return the decl of the non-overloaded function to
+   call instead, or NULL_TREE if the shape cannot resolve it.  */
+tree
+function_resolver::resolve ()
+{
+  return shape->resolve (*this);
+}
+
+/* Find the non-overloaded instance whose parameters match the arguments
+   of the call and return its decl from function_table.  Return NULL_TREE
+   if no instance matches or the matching instance is not in the table.  */
+tree
+function_resolver::lookup ()
+{
+  function_instance *match
+    = base->get_non_overloaded_instance (get_sub_code (), m_arglist);
+  if (match == NULL)
+    return NULL_TREE;
+
+  registered_function *rfun
+    = function_table->find_with_hash (*match, match->hash ());
+  return rfun ? rfun->decl : NULL_TREE;
+}
+
 inline hashval_t
 registered_function_hasher::hash (value_type value)
 {
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, 
vec<location_t>, unsigned int code,
                           TREE_TYPE (rfn.decl), nargs, args).check ();
 }
 
+/* Resolve a call at LOC to the vector builtin with sub code CODE and
+   arguments ARGLIST.  Return the decl of the non-overloaded function to
+   call instead, or NULL_TREE if CODE is out of range, does not name an
+   overloaded function, or cannot be resolved.  ARGLIST must not be
+   null.  */
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+                            vec<tree, va_gc> *arglist)
+{
+  if (code < vec_safe_length (registered_functions))
+    {
+      const registered_function *rfun = (*registered_functions)[code];
+      if (rfun && rfun->overloaded_p)
+        {
+          function_resolver resolver (loc, rfun->instance, rfun->decl,
+                                      *arglist);
+          return resolver.resolve ();
+        }
+    }
+
+  return NULL_TREE;
+}
+
 function_instance
 get_read_vl_instance (void)
 {
diff --git a/gcc/config/riscv/riscv-vector-builtins.h 
b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..e20f0f14ce4 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
   void apply_predication (const function_instance &, tree, vec<tree> &) const;
   void add_unique_function (const function_instance &, const function_shape *,
                            tree, vec<tree> &);
+  void add_overloaded_function (const function_instance &,
+                               const function_shape *);
   void register_function_group (const function_group_info &);
   void append_name (const char *);
   void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
   tree get_attributes (const function_instance &);
 
   registered_function &add_function (const function_instance &, const char *,
-                                    tree, tree, bool);
+                                    tree, tree, bool, bool);
 
   /* True if we should create a separate decl for each instance of an
      overloaded function, instead of using function_builder.  */
@@ -424,6 +426,11 @@ public:
   /* Expand the given call into rtl.  Return the result of the function,
      or an arbitrary value if the function doesn't return a result.  */
   virtual rtx expand (function_expander &) const = 0;
+
+  /* Return the non-overloaded instance of this function that a call with
+     arguments ARGLIST to the overloaded function with sub code CODE
+     resolves to, or NULL if the registered function table has none.  */
+  virtual function_instance *
+  get_non_overloaded_instance (unsigned int code,
+                               vec<tree, va_gc> &arglist) const;
 };
 
 /* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
   tree *m_args;
 };
 
+/* Resolves a call to an overloaded RVV function to the matching
+   non-overloaded function.  */
+class function_resolver : public function_call_info
+{
+public:
+  function_resolver (location_t location, const function_instance &instance,
+                     tree fndecl, vec<tree, va_gc> &arglist);
+
+  /* Return the decl of the non-overloaded function that the call resolves
+     to, as decided by the function's shape, or NULL_TREE.  */
+  tree resolve ();
+
+  /* Look up the non-overloaded function whose parameters match the call's
+     arguments in the registered function table.  Return its decl, or
+     NULL_TREE if there is none.  */
+  tree lookup ();
+
+  /* Return the function code of fndecl with the builtin class bits
+     shifted out.  */
+  unsigned int get_sub_code ();
+
+private:
+  /* The actual arguments of the overloaded call.  */
+  vec<tree, va_gc> &m_arglist;
+};
+
 /* Classifies functions into "shapes" base on:
 
    - Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
   /* Check whether the given call is semantically valid.  Return true
    if it is, otherwise report an error and return false.  */
   virtual bool check (function_checker &) const { return true; }
+
+  /* Try to resolve an overloaded call.  Return the decl of the matching
+     non-overloaded function on success and NULL_TREE on failure.  This
+     default implementation resolves nothing.  */
+  virtual tree resolve (function_resolver &) const { return NULL_TREE; }
 };
 
 extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,9 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* Each configuration appears twice: once for the overloaded call and once for its non-overloaded twin.  */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,9 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* Each configuration appears twice: once for the overloaded call and once for its non-overloaded twin.  */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h 
b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,39 @@
+/* Each overloaded call below has a non-overloaded twin with the same
+   types, so both forms are expected to produce the same vsetvli.  */
+#include "riscv_vector.h"
+
+vint32m1_t
+test_vmv_overloaded_0 (vint32m1_t src, size_t vl)
+{
+  return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t
+test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl)
+{
+  return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t
+test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src, size_t vl)
+{
+  return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t
+test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl)
+{
+  return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t
+test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl)
+{
+  return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t
+test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src, size_t vl)
+{
+  return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
-- 
2.34.1

Reply via email to