From: Jakub Dupak <d...@jakubdupak.com>

gcc/rust/ChangeLog:

        * typecheck/rust-tyty.h (class ClosureType): Inherit interface.
        (class FnPtr): Inherit interface.
        (class FnType): Inherit interface.
        (class CallableTypeInterface): New interface.
        (BaseType::is): Detect interface members API.

Signed-off-by: Jakub Dupak <d...@jakubdupak.com>
---
 gcc/rust/typecheck/rust-tyty.h | 116 ++++++++++++++++++++++++++++-----
 1 file changed, 98 insertions(+), 18 deletions(-)

diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h
index da4b901724d..6ce760df66e 100644
--- a/gcc/rust/typecheck/rust-tyty.h
+++ b/gcc/rust/typecheck/rust-tyty.h
@@ -38,6 +38,10 @@ class AssociatedImplTrait;
 } // namespace Resolver
 
 namespace TyTy {
+class ClosureType;
+class FnPtr;
+class FnType;
+class CallableTypeInterface;
 
 // https://rustc-dev-guide.rust-lang.org/type-inference.html#inference-variables
 // https://doc.rust-lang.org/nightly/nightly-rustc/rustc_middle/ty/enum.TyKind.html#variants
@@ -242,6 +246,22 @@ protected:
   Analysis::Mappings *mappings;
 };
 
+/** Shared params/return-type API of FnType, FnPtr and ClosureType. */
+class CallableTypeInterface : public BaseType
+{
+public:
+  explicit CallableTypeInterface (HirId ref, HirId ty_ref, TypeKind kind,
+                                 RustIdent ident,
+                                 std::set<HirId> refs = std::set<HirId> ())
+    : BaseType (ref, ty_ref, kind, ident, refs)
+  {} // Only forwards to BaseType; the interface adds no state.
+
+  WARN_UNUSED_RESULT virtual size_t get_num_params () const = 0;
+  WARN_UNUSED_RESULT virtual BaseType *
+  get_param_type_at (size_t index) const = 0;
+  WARN_UNUSED_RESULT virtual BaseType *get_return_type () const = 0;
+};
+
 class InferType : public BaseType
 {
 public:
@@ -736,7 +756,7 @@ private:
   ReprOptions repr;
 };
 
-class FnType : public BaseType, public SubstitutionRef
+class FnType : public CallableTypeInterface, public SubstitutionRef
 {
 public:
   static constexpr auto KIND = TypeKind::FNDEF;
@@ -751,7 +771,7 @@ public:
          std::vector<std::pair<HIR::Pattern *, BaseType *>> params,
          BaseType *type, std::vector<SubstitutionParamMapping> subst_refs,
          std::set<HirId> refs = std::set<HirId> ())
-    : BaseType (ref, ref, TypeKind::FNDEF, ident, refs),
+    : CallableTypeInterface (ref, ref, TypeKind::FNDEF, ident, refs),
       SubstitutionRef (std::move (subst_refs),
                       SubstitutionArgumentMappings::error ()),
       params (std::move (params)), type (type), flags (flags),
@@ -766,7 +786,7 @@ public:
          std::vector<std::pair<HIR::Pattern *, BaseType *>> params,
          BaseType *type, std::vector<SubstitutionParamMapping> subst_refs,
          std::set<HirId> refs = std::set<HirId> ())
-    : BaseType (ref, ty_ref, TypeKind::FNDEF, ident, refs),
+    : CallableTypeInterface (ref, ty_ref, TypeKind::FNDEF, ident, refs),
       SubstitutionRef (std::move (subst_refs),
                       SubstitutionArgumentMappings::error ()),
       params (params), type (type), flags (flags), identifier (identifier),
@@ -832,8 +852,6 @@ public:
     return params.at (idx);
   }
 
-  BaseType *get_return_type () const { return type; }
-
   BaseType *clone () const final override;
 
   FnType *
@@ -842,6 +860,21 @@ public:
   ABI get_abi () const { return abi; }
   uint8_t get_flags () const { return flags; }
 
+  WARN_UNUSED_RESULT size_t get_num_params () const override
+  {
+    return params.size (); // one (pattern, type) pair per parameter
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_param_type_at (size_t index) const override
+  {
+    return param_at (index).second; // type half of (pattern, type)
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_return_type () const override
+  {
+    return type; // the `type' member set by the constructors
+  }
+
 private:
   std::vector<std::pair<HIR::Pattern *, BaseType *>> params;
   BaseType *type;
@@ -851,33 +884,50 @@ private:
   ABI abi;
 };
 
-class FnPtr : public BaseType
+class FnPtr : public CallableTypeInterface
 {
 public:
   static constexpr auto KIND = TypeKind::FNPTR;
 
   FnPtr (HirId ref, location_t locus, std::vector<TyVar> params,
         TyVar result_type, std::set<HirId> refs = std::set<HirId> ())
-    : BaseType (ref, ref, TypeKind::FNPTR,
-               {Resolver::CanonicalPath::create_empty (), locus}, refs),
+    : CallableTypeInterface (ref, ref, TypeKind::FNPTR,
+                            {Resolver::CanonicalPath::create_empty (), locus},
+                            refs),
       params (std::move (params)), result_type (result_type)
   {}
 
   FnPtr (HirId ref, HirId ty_ref, location_t locus, std::vector<TyVar> params,
         TyVar result_type, std::set<HirId> refs = std::set<HirId> ())
-    : BaseType (ref, ty_ref, TypeKind::FNPTR,
-               {Resolver::CanonicalPath::create_empty (), locus}, refs),
+    : CallableTypeInterface (ref, ty_ref, TypeKind::FNPTR,
+                            {Resolver::CanonicalPath::create_empty (), locus},
+                            refs),
       params (params), result_type (result_type)
   {}
 
   std::string get_name () const override final { return as_string (); }
 
-  BaseType *get_return_type () const { return result_type.get_tyty (); }
+  WARN_UNUSED_RESULT size_t get_num_params () const override
+  {
+    return params.size (); // params holds one TyVar per parameter
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_param_type_at (size_t index) const override
+  {
+    return params.at (index).get_tyty (); // type behind the index-th TyVar
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_return_type () const override
+  {
+    return result_type.get_tyty (); // type behind the result TyVar
+  }
+
   const TyVar &get_var_return_type () const { return result_type; }
 
   size_t num_params () const { return params.size (); }
 
-  BaseType *param_at (size_t idx) const { return params.at (idx).get_tyty (); }
+  // DEPRECATED: Use get_param_type_at
+  BaseType *param_at (size_t idx) const { return get_param_type_at (idx); }
 
   void accept_vis (TyVisitor &vis) override;
   void accept_vis (TyConstVisitor &vis) const override;
@@ -898,19 +948,19 @@ private:
   TyVar result_type;
 };
 
-class ClosureType : public BaseType, public SubstitutionRef
+class ClosureType : public CallableTypeInterface, public SubstitutionRef
 {
 public:
   static constexpr auto KIND = TypeKind::CLOSURE;
 
-  ClosureType (HirId ref, DefId id, RustIdent ident,
-              TyTy::TupleType *parameters, TyVar result_type,
+  ClosureType (HirId ref, DefId id, RustIdent ident, TupleType *parameters,
+              TyVar result_type,
               std::vector<SubstitutionParamMapping> subst_refs,
               std::set<NodeId> captures,
               std::set<HirId> refs = std::set<HirId> (),
               std::vector<TypeBoundPredicate> specified_bounds
               = std::vector<TypeBoundPredicate> ())
-    : BaseType (ref, ref, TypeKind::CLOSURE, ident, refs),
+    : CallableTypeInterface (ref, ref, TypeKind::CLOSURE, ident, refs),
       SubstitutionRef (std::move (subst_refs),
                       SubstitutionArgumentMappings::error ()),
       parameters (parameters), result_type (std::move (result_type)), id (id),
@@ -922,13 +972,13 @@ public:
   }
 
   ClosureType (HirId ref, HirId ty_ref, RustIdent ident, DefId id,
-              TyTy::TupleType *parameters, TyVar result_type,
+              TupleType *parameters, TyVar result_type,
               std::vector<SubstitutionParamMapping> subst_refs,
               std::set<NodeId> captures,
               std::set<HirId> refs = std::set<HirId> (),
               std::vector<TypeBoundPredicate> specified_bounds
               = std::vector<TypeBoundPredicate> ())
-    : BaseType (ref, ty_ref, TypeKind::CLOSURE, ident, refs),
+    : CallableTypeInterface (ref, ty_ref, TypeKind::CLOSURE, ident, refs),
       SubstitutionRef (std::move (subst_refs),
                       SubstitutionArgumentMappings::error ()),
       parameters (parameters), result_type (std::move (result_type)), id (id),
@@ -942,6 +992,21 @@ public:
   void accept_vis (TyVisitor &vis) override;
   void accept_vis (TyConstVisitor &vis) const override;
 
+  WARN_UNUSED_RESULT size_t get_num_params () const override
+  {
+    return parameters->num_fields (); // params are kept as one tuple
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_param_type_at (size_t index) const override
+  {
+    return parameters->get_field (index); // index-th field of the tuple
+  }
+
+  WARN_UNUSED_RESULT BaseType *get_return_type () const override
+  {
+    return result_type.get_tyty (); // type behind the result TyVar
+  }
+
   std::string as_string () const override;
   std::string get_name () const override final { return as_string (); }
 
@@ -1495,6 +1560,21 @@ private:
   DefId item;
 };
 
+template <>
+WARN_UNUSED_RESULT inline bool
+BaseType::is<CallableTypeInterface> () const
+{
+  auto kind = this->get_kind (); // interface membership is decided by kind
+  return kind == FNPTR || kind == FNDEF || kind == CLOSURE;
+}
+
+template <>
+WARN_UNUSED_RESULT inline bool
+BaseType::is<const CallableTypeInterface> () const
+{
+  return this->is<CallableTypeInterface> (); // same test as non-const
+}
+
 } // namespace TyTy
 } // namespace Rust
 
-- 
2.42.1

Reply via email to