jonathanmeier updated this revision to Diff 225928.
jonathanmeier changed the repository for this revision from rC Clang to rG LLVM 
Github Monorepo.
jonathanmeier added a comment.

- Rebased onto, and adapted to, the spaceship-operator and comparison-operator 
rewrite support that landed in rL375305 <https://reviews.llvm.org/rL375305> and 
rL375306 <https://reviews.llvm.org/rL375306>.
- Added tests for comparison operator rewrites in fold expressions.
- Switched to passing an `llvm::iterator_range` instead of separate begin/end 
iterators.

ping @rsmith


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D67247/new/

https://reviews.llvm.org/D67247

Files:
  clang/include/clang/AST/ExprCXX.h
  clang/include/clang/AST/UnresolvedSet.h
  clang/include/clang/Sema/Sema.h
  clang/lib/Parse/ParseExpr.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaLookup.cpp
  clang/lib/Sema/SemaTemplateVariadic.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReaderStmt.cpp
  clang/lib/Serialization/ASTWriterStmt.cpp
  clang/test/SemaTemplate/cxx1z-fold-expressions.cpp

Index: clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
===================================================================
--- clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
+++ clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
@@ -79,6 +79,36 @@
 static_assert(&apply(a, &A::b, &A::B::c, &A::B::C::d, &A::B::C::D::e) == &a.b.c.d.e);
 
 #if __cplusplus > 201703L
+
+namespace N {
+
+  struct Bool {
+    constexpr Bool(const bool& b) : b(b) {}
+    bool b;
+  };
+
+}
+
+constexpr bool operator==(const N::Bool& b1, const N::Bool& b2) { return b1.b == b2.b; }
+constexpr int operator<=>(const N::Bool& b1, const N::Bool& b2) { return b1.b - b2.b; }
+
+template<typename ...T> constexpr auto fold_eq(T ...t) { return (t == ...); }
+template<typename ...T> constexpr auto fold_neq(T ...t) { return (t != ...); }
+template<typename ...T> constexpr auto fold_lt(T ...t) { return (t < ...); }
+template<typename ...T> constexpr auto fold_le(T ...t) { return (t <= ...); }
+template<typename ...T> constexpr auto fold_gt(T ...t) { return (t > ...); }
+template<typename ...T> constexpr auto fold_ge(T ...t) { return (t >= ...); }
+
+static_assert(fold_eq(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}));
+static_assert(!fold_eq(N::Bool{true}, N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{true}));
+static_assert(fold_neq(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}));
+static_assert(!fold_neq(N::Bool{true}, N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{true}));
+static_assert(!fold_lt(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}));
+static_assert(fold_le(N::Bool{false}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}));
+static_assert(fold_gt(N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{false}, N::Bool{false}));
+static_assert(!fold_gt(N::Bool{false}, N::Bool{false}, N::Bool{true}, N::Bool{false}, N::Bool{false}));
+static_assert(!fold_ge(N::Bool{false}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}));
+
 // The <=> operator is unique among binary operators in not being a
 // fold-operator.
 // FIXME: This diagnostic is not great.
@@ -102,3 +132,49 @@
 
   Sum<1>::type<1, 2> x; // expected-note {{instantiation of}}
 }
+
+namespace N {
+  
+  struct A { int i; };
+  struct B { int i; };
+  
+  constexpr B operator+(const B& a, const B& b) { return { a.i + b.i }; }
+  
+}
+
+struct C { int i; };
+
+constexpr C operator+(const C& a, const C& b) { return { a.i + b.i }; }
+constexpr N::A operator+(const N::A& a, const N::A& b) { return { a.i + b.i }; }
+
+template<typename T1, typename ...T2> constexpr auto custom_fold(T1 t1, T2 ...t2) {
+  return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2);
+}
+
+static_assert(custom_fold(N::A{1}, N::A{2}, N::A{3}, N::A{4}, N::A{5}).i == 58);
+static_assert(custom_fold(N::B{1}, N::B{2}, N::B{3}, N::B{4}, N::B{5}).i == 58);
+static_assert(custom_fold(C{1}, C{2}, C{3}, C{4}, C{5}).i == 58);
+
+template<typename T, int I1, int ...I2> constexpr auto func_fold(
+    decltype((T{ I2 } + ...) + (... + T{ I2 }) + (T{ I2 } + ... + T{ I1 }) + (T{ I1 } + ... + T{ I2 })) t) {
+  return t.i;
+}
+
+static_assert(func_fold<N::A, 1, 2, 3, 4, 5>(N::A{ 42 }) == 42);
+static_assert(func_fold<N::B, 1, 2, 3, 4, 5>(N::B{ 42 }) == 42);
+static_assert(func_fold<C, 1, 2, 3, 4, 5>(C{ 42 }) == 42);
+
+struct D { int i; };
+
+namespace N {
+  
+  constexpr D operator+(const D& a, const D& b) { return { a.i + b.i }; }
+  
+}
+
+template<typename T1, typename ...T2> constexpr auto custom_fold_using(T1 t1, T2 ...t2) {
+  using N::operator+;
+  return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2);
+}
+
+static_assert(custom_fold_using(D{1}, D{2}, D{3}, D{4}, D{5}).i == 58);
Index: clang/lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTWriterStmt.cpp
+++ clang/lib/Serialization/ASTWriterStmt.cpp
@@ -1843,6 +1843,7 @@
 
 void ASTStmtWriter::VisitCXXFoldExpr(CXXFoldExpr *E) {
   VisitExpr(E);
+  Record.push_back(E->getNumOverloadCands());
   Record.AddSourceLocation(E->LParenLoc);
   Record.AddSourceLocation(E->EllipsisLoc);
   Record.AddSourceLocation(E->RParenLoc);
@@ -1850,6 +1851,12 @@
   Record.AddStmt(E->SubExprs[0]);
   Record.AddStmt(E->SubExprs[1]);
   Record.push_back(E->Opcode);
+  for (UnresolvedSetIterator I = E->overloadCands().begin(),
+                             End = E->overloadCands().end();
+       I != End; ++I) {
+    Record.AddDeclRef(I.getDecl());
+    Record.push_back(I.getAccess());
+  }
   Code = serialization::EXPR_CXX_FOLD;
 }
 
Index: clang/lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderStmt.cpp
+++ clang/lib/Serialization/ASTReaderStmt.cpp
@@ -1908,6 +1908,7 @@
 
 void ASTStmtReader::VisitCXXFoldExpr(CXXFoldExpr *E) {
   VisitExpr(E);
+  unsigned NumOverloadCands = Record.readInt();
   E->LParenLoc = ReadSourceLocation();
   E->EllipsisLoc = ReadSourceLocation();
   E->RParenLoc = ReadSourceLocation();
@@ -1915,6 +1916,13 @@
   E->SubExprs[0] = Record.readSubExpr();
   E->SubExprs[1] = Record.readSubExpr();
   E->Opcode = (BinaryOperatorKind)Record.readInt();
+
+  DeclAccessPair *OverloadCands = E->getTrailingObjects<DeclAccessPair>();
+  for (unsigned I = 0; I != NumOverloadCands; ++I) {
+    auto *D = ReadDeclAs<NamedDecl>();
+    auto AS = (AccessSpecifier)Record.readInt();
+    OverloadCands[I].set(D, AS);
+  }
 }
 
 void ASTStmtReader::VisitOpaqueValueExpr(OpaqueValueExpr *E) {
@@ -3479,7 +3487,8 @@
       break;
 
     case EXPR_CXX_FOLD:
-      S = new (Context) CXXFoldExpr(Empty);
+      S = CXXFoldExpr::CreateEmpty(Context,
+                                   Record[ASTStmtReader::NumExprFields]);
       break;
 
     case EXPR_OPAQUE_VALUE:
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -2347,10 +2347,12 @@
   ///
   /// By default, performs semantic analysis to build the new expression.
   /// Subclasses may override this routine to provide different behavior.
-  ExprResult RebuildBinaryOperator(SourceLocation OpLoc,
-                                         BinaryOperatorKind Opc,
-                                         Expr *LHS, Expr *RHS) {
-    return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS);
+  ExprResult
+  RebuildBinaryOperator(SourceLocation OpLoc, BinaryOperatorKind Opc, Expr *LHS,
+                        Expr *RHS,
+                        const UnresolvedSetImpl *OverloadCands = nullptr) {
+    return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS,
+                                OverloadCands);
   }
 
   /// Build a new rewritten operator expression.
@@ -3317,13 +3319,14 @@
   ///
   /// By default, performs semantic analysis in order to build a new fold
   /// expression.
-  ExprResult RebuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                                BinaryOperatorKind Operator,
-                                SourceLocation EllipsisLoc, Expr *RHS,
-                                SourceLocation RParenLoc,
-                                Optional<unsigned> NumExpansions) {
+  ExprResult RebuildCXXFoldExpr(
+      SourceLocation LParenLoc, Expr *LHS, BinaryOperatorKind Operator,
+      SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc,
+      Optional<unsigned> NumExpansions,
+      llvm::iterator_range<UnresolvedSetIterator> OverloadCands) {
     return getSema().BuildCXXFoldExpr(LParenLoc, LHS, Operator, EllipsisLoc,
-                                      RHS, RParenLoc, NumExpansions);
+                                      RHS, RParenLoc, NumExpansions,
+                                      OverloadCands);
   }
 
   /// Build an empty C++1z fold-expression with the given operator.
@@ -12191,7 +12194,7 @@
 
     return getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), LHS.get(), E->getOperator(), E->getEllipsisLoc(),
-        RHS.get(), E->getEndLoc(), NumExpansions);
+        RHS.get(), E->getEndLoc(), NumExpansions, E->overloadCands());
   }
 
   // The transform has determined that we should perform an elementwise
@@ -12212,11 +12215,13 @@
 
     Result = getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), Out.get(), E->getOperator(), E->getEllipsisLoc(),
-        Result.get(), E->getEndLoc(), OrigNumExpansions);
+        Result.get(), E->getEndLoc(), OrigNumExpansions, E->overloadCands());
     if (Result.isInvalid())
       return true;
   }
 
+  UnresolvedSet<8> OverloadCands;
+  OverloadCands.append(E->overloadCands().begin(), E->overloadCands().end());
   for (unsigned I = 0; I != *NumExpansions; ++I) {
     Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(
         getSema(), LeftFold ? I : *NumExpansions - I - 1);
@@ -12230,13 +12235,13 @@
           E->getBeginLoc(), LeftFold ? Result.get() : Out.get(),
           E->getOperator(), E->getEllipsisLoc(),
           LeftFold ? Out.get() : Result.get(), E->getEndLoc(),
-          OrigNumExpansions);
+          OrigNumExpansions, E->overloadCands());
     } else if (Result.isUsable()) {
       // We've got down to a single element; build a binary operator.
       Result = getDerived().RebuildBinaryOperator(
           E->getEllipsisLoc(), E->getOperator(),
           LeftFold ? Result.get() : Out.get(),
-          LeftFold ? Out.get() : Result.get());
+          LeftFold ? Out.get() : Result.get(), &OverloadCands);
     } else
       Result = Out;
 
@@ -12255,7 +12260,7 @@
 
     Result = getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), Result.get(), E->getOperator(), E->getEllipsisLoc(),
-        Out.get(), E->getEndLoc(), OrigNumExpansions);
+        Out.get(), E->getEndLoc(), OrigNumExpansions, E->overloadCands());
     if (Result.isInvalid())
       return true;
   }
Index: clang/lib/Sema/SemaTemplateVariadic.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateVariadic.cpp
+++ clang/lib/Sema/SemaTemplateVariadic.cpp
@@ -1153,8 +1153,8 @@
   }
 }
 
-ExprResult Sema::ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                                  tok::TokenKind Operator,
+ExprResult Sema::ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc,
+                                  Expr *LHS, tok::TokenKind Operator,
                                   SourceLocation EllipsisLoc, Expr *RHS,
                                   SourceLocation RParenLoc) {
   // LHS and RHS must be cast-expressions. We allow an arbitrary expression
@@ -1195,18 +1195,34 @@
   }
 
   BinaryOperatorKind Opc = ConvertTokenKindToBinaryOpcode(Operator);
+
+  UnresolvedSet<8> Functions;
+  OverloadedOperatorKind OverOp = BinaryOperator::getOverloadedOperator(Opc);
+  if (Sc) {
+    if (OverOp != OO_None && OverOp != OO_Equal)
+      LookupOverloadedOperatorName(OverOp, Sc, Functions);
+
+    // In C++20 onwards, we may have a second operator to look up.
+    if (getLangOpts().CPlusPlus2a) {
+      if (OverloadedOperatorKind ExtraOp =
+              getRewrittenOverloadedOperator(OverOp))
+        LookupOverloadedOperatorName(ExtraOp, Sc, Functions);
+    }
+  }
+
   return BuildCXXFoldExpr(LParenLoc, LHS, Opc, EllipsisLoc, RHS, RParenLoc,
-                          None);
+                          None,
+                          llvm::make_range(Functions.begin(), Functions.end()));
 }
 
-ExprResult Sema::BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                                  BinaryOperatorKind Operator,
-                                  SourceLocation EllipsisLoc, Expr *RHS,
-                                  SourceLocation RParenLoc,
-                                  Optional<unsigned> NumExpansions) {
-  return new (Context) CXXFoldExpr(Context.DependentTy, LParenLoc, LHS,
-                                   Operator, EllipsisLoc, RHS, RParenLoc,
-                                   NumExpansions);
+ExprResult Sema::BuildCXXFoldExpr(
+    SourceLocation LParenLoc, Expr *LHS, BinaryOperatorKind Operator,
+    SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc,
+    Optional<unsigned> NumExpansions,
+    llvm::iterator_range<UnresolvedSetIterator> OverloadCands) {
+  return CXXFoldExpr::Create(Context, Context.DependentTy, LParenLoc, LHS,
+                             Operator, EllipsisLoc, RHS, RParenLoc,
+                             NumExpansions, OverloadCands);
 }
 
 ExprResult Sema::BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc,
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -3037,7 +3037,6 @@
 }
 
 void Sema::LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S,
-                                        QualType T1, QualType T2,
                                         UnresolvedSetImpl &Functions) {
   // C++ [over.match.oper]p3:
   //     -- The set of non-member candidates is the result of the
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -13291,8 +13291,9 @@
 
 /// Build an overloaded binary operator expression in the given scope.
 static ExprResult BuildOverloadedBinOp(Sema &S, Scope *Sc, SourceLocation OpLoc,
-                                       BinaryOperatorKind Opc,
-                                       Expr *LHS, Expr *RHS) {
+                                       BinaryOperatorKind Opc, Expr *LHS,
+                                       Expr *RHS,
+                                       const UnresolvedSetImpl *OverloadCands) {
   switch (Opc) {
   case BO_Assign:
   case BO_DivAssign:
@@ -13308,22 +13309,25 @@
     break;
   }
 
+  UnresolvedSet<16> Functions;
+  if (OverloadCands)
+    Functions.append(OverloadCands->begin(), OverloadCands->end());
+
   // Find all of the overloaded operators visible from this
   // point. We perform both an operator-name lookup from the local
   // scope and an argument-dependent lookup based on the types of
   // the arguments.
-  UnresolvedSet<16> Functions;
-  OverloadedOperatorKind OverOp
-    = BinaryOperator::getOverloadedOperator(Opc);
-  if (Sc && OverOp != OO_None && OverOp != OO_Equal)
-    S.LookupOverloadedOperatorName(OverOp, Sc, LHS->getType(),
-                                   RHS->getType(), Functions);
+  OverloadedOperatorKind OverOp = BinaryOperator::getOverloadedOperator(Opc);
+  if (Sc) {
+    if (OverOp != OO_None && OverOp != OO_Equal)
+      S.LookupOverloadedOperatorName(OverOp, Sc, Functions);
 
-  // In C++20 onwards, we may have a second operator to look up.
-  if (S.getLangOpts().CPlusPlus2a) {
-    if (OverloadedOperatorKind ExtraOp = getRewrittenOverloadedOperator(OverOp))
-      S.LookupOverloadedOperatorName(ExtraOp, Sc, LHS->getType(),
-                                     RHS->getType(), Functions);
+    // In C++20 onwards, we may have a second operator to look up.
+    if (S.getLangOpts().CPlusPlus2a) {
+      if (OverloadedOperatorKind ExtraOp =
+              getRewrittenOverloadedOperator(OverOp))
+        S.LookupOverloadedOperatorName(ExtraOp, Sc, Functions);
+    }
   }
 
   // Build the (potentially-overloaded, potentially-dependent)
@@ -13332,8 +13336,9 @@
 }
 
 ExprResult Sema::BuildBinOp(Scope *S, SourceLocation OpLoc,
-                            BinaryOperatorKind Opc,
-                            Expr *LHSExpr, Expr *RHSExpr) {
+                            BinaryOperatorKind Opc, Expr *LHSExpr,
+                            Expr *RHSExpr,
+                            const UnresolvedSetImpl *OverloadCands) {
   ExprResult LHS, RHS;
   std::tie(LHS, RHS) = CorrectDelayedTyposInBinOp(*this, Opc, LHSExpr, RHSExpr);
   if (!LHS.isUsable() || !RHS.isUsable())
@@ -13367,7 +13372,8 @@
 
       if (RHSExpr->isTypeDependent() ||
           RHSExpr->getType()->isOverloadableType())
-        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                    OverloadCands);
     }
 
     // If we're instantiating "a.x < b" or "A::x < b" and 'x' names a function
@@ -13405,7 +13411,8 @@
       if (getLangOpts().CPlusPlus &&
           (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent() ||
            LHSExpr->getType()->isOverloadableType()))
-        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                    OverloadCands);
 
       return CreateBuiltinBinOp(OpLoc, Opc, LHSExpr, RHSExpr);
     }
@@ -13413,7 +13420,8 @@
     // Don't resolve overloads if the other type is overloadable.
     if (getLangOpts().CPlusPlus && pty->getKind() == BuiltinType::Overload &&
         LHSExpr->getType()->isOverloadableType())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
 
     ExprResult resolvedRHS = CheckPlaceholderExpr(RHSExpr);
     if (!resolvedRHS.isUsable()) return ExprError();
@@ -13424,13 +13432,15 @@
     // If either expression is type-dependent, always build an
     // overloaded op.
     if (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
 
     // Otherwise, build an overloaded op if either expression has an
     // overloadable type.
     if (LHSExpr->getType()->isOverloadableType() ||
         RHSExpr->getType()->isOverloadableType())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
   }
 
   // Build a built-in binary operation.
@@ -13746,8 +13756,7 @@
     UnresolvedSet<16> Functions;
     OverloadedOperatorKind OverOp = UnaryOperator::getOverloadedOperator(Opc);
     if (S && OverOp != OO_None)
-      LookupOverloadedOperatorName(OverOp, S, Input->getType(), QualType(),
-                                   Functions);
+      LookupOverloadedOperatorName(OverOp, S, Functions);
 
     return CreateOverloadedUnaryOp(OpLoc, Opc, Functions, Input);
   }
Index: clang/lib/Parse/ParseExpr.cpp
===================================================================
--- clang/lib/Parse/ParseExpr.cpp
+++ clang/lib/Parse/ParseExpr.cpp
@@ -2870,8 +2870,9 @@
                         : diag::ext_fold_expression);
 
   T.consumeClose();
-  return Actions.ActOnCXXFoldExpr(T.getOpenLocation(), LHS.get(), Kind,
-                                  EllipsisLoc, RHS.get(), T.getCloseLocation());
+  return Actions.ActOnCXXFoldExpr(getCurScope(), T.getOpenLocation(), LHS.get(),
+                                  Kind, EllipsisLoc, RHS.get(),
+                                  T.getCloseLocation());
 }
 
 /// ParseExpressionList - Used for C/C++ (argument-)expression-list.
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -3491,7 +3491,6 @@
   bool LookupInSuper(LookupResult &R, CXXRecordDecl *Class);
 
   void LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S,
-                                    QualType T1, QualType T2,
                                     UnresolvedSetImpl &Functions);
 
   LabelDecl *LookupOrCreateLabel(IdentifierInfo *II, SourceLocation IdentLoc,
@@ -4780,8 +4779,9 @@
 public:
   ExprResult ActOnBinOp(Scope *S, SourceLocation TokLoc,
                         tok::TokenKind Kind, Expr *LHSExpr, Expr *RHSExpr);
-  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc,
-                        BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr);
+  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc, BinaryOperatorKind Opc,
+                        Expr *LHSExpr, Expr *RHSExpr,
+                        const UnresolvedSetImpl *OverloadCands = nullptr);
   ExprResult CreateBuiltinBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc,
                                 Expr *LHSExpr, Expr *RHSExpr);
 
@@ -5440,15 +5440,16 @@
                             SourceLocation RParenLoc);
 
   /// Handle a C++1z fold-expression: ( expr op ... op expr ).
-  ExprResult ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
+  ExprResult ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc, Expr *LHS,
                               tok::TokenKind Operator,
                               SourceLocation EllipsisLoc, Expr *RHS,
                               SourceLocation RParenLoc);
-  ExprResult BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                              BinaryOperatorKind Operator,
-                              SourceLocation EllipsisLoc, Expr *RHS,
-                              SourceLocation RParenLoc,
-                              Optional<unsigned> NumExpansions);
+  ExprResult
+  BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
+                   BinaryOperatorKind Operator, SourceLocation EllipsisLoc,
+                   Expr *RHS, SourceLocation RParenLoc,
+                   Optional<unsigned> NumExpansions,
+                   llvm::iterator_range<UnresolvedSetIterator> OverloadCands);
   ExprResult BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc,
                                    BinaryOperatorKind Operator);
 
Index: clang/include/clang/AST/UnresolvedSet.h
===================================================================
--- clang/include/clang/AST/UnresolvedSet.h
+++ clang/include/clang/AST/UnresolvedSet.h
@@ -33,6 +33,7 @@
                                   std::random_access_iterator_tag, NamedDecl *,
                                   std::ptrdiff_t, NamedDecl *, NamedDecl *> {
   friend class ASTUnresolvedSet;
+  friend class CXXFoldExpr;
   friend class OverloadExpr;
   friend class UnresolvedSetImpl;
 
Index: clang/include/clang/AST/ExprCXX.h
===================================================================
--- clang/include/clang/AST/ExprCXX.h
+++ clang/include/clang/AST/ExprCXX.h
@@ -4532,7 +4532,9 @@
 ///    ( expr op ... )
 ///    ( ... op expr )
 ///    ( expr op ... op expr )
-class CXXFoldExpr : public Expr {
+class CXXFoldExpr final
+    : public Expr,
+      private llvm::TrailingObjects<CXXFoldExpr, DeclAccessPair> {
   friend class ASTStmtReader;
   friend class ASTStmtWriter;
 
@@ -4545,24 +4547,64 @@
   Stmt *SubExprs[2];
   BinaryOperatorKind Opcode;
 
-public:
+  unsigned NumOverloadCands;
+
   CXXFoldExpr(QualType T, SourceLocation LParenLoc, Expr *LHS,
               BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
-              SourceLocation RParenLoc, Optional<unsigned> NumExpansions)
+              SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+              llvm::iterator_range<UnresolvedSetIterator> OverloadCands)
       : Expr(CXXFoldExprClass, T, VK_RValue, OK_Ordinary,
              /*Dependent*/ true, true, true,
              /*ContainsUnexpandedParameterPack*/ false),
         LParenLoc(LParenLoc), EllipsisLoc(EllipsisLoc), RParenLoc(RParenLoc),
-        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode) {
+        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode),
+        NumOverloadCands(
+            std::distance(OverloadCands.begin(), OverloadCands.end())) {
     SubExprs[0] = LHS;
     SubExprs[1] = RHS;
+    DeclAccessPair *Results = getTrailingObjects<DeclAccessPair>();
+    memcpy(Results, OverloadCands.begin().I,
+           NumOverloadCands * sizeof(DeclAccessPair));
   }
 
-  CXXFoldExpr(EmptyShell Empty) : Expr(CXXFoldExprClass, Empty) {}
+  CXXFoldExpr(EmptyShell Empty, unsigned NumOverloadCands)
+      : Expr(CXXFoldExprClass, Empty), NumOverloadCands(NumOverloadCands) {}
+
+public:
+  static CXXFoldExpr *
+  Create(const ASTContext &Ctx, QualType T, SourceLocation LParenLoc, Expr *LHS,
+         BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
+         SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+         llvm::iterator_range<UnresolvedSetIterator> OverloadCands) {
+    unsigned Size = CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(
+        std::distance(OverloadCands.begin(), OverloadCands.end()));
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem) CXXFoldExpr(T, LParenLoc, LHS, Opcode, EllipsisLoc, RHS,
+                                 RParenLoc, NumExpansions, OverloadCands);
+  }
+
+  static CXXFoldExpr *CreateEmpty(const ASTContext &Ctx,
+                                  unsigned NumOverloadCands) {
+    unsigned Size =
+        CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(NumOverloadCands);
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem) CXXFoldExpr(EmptyShell(), NumOverloadCands);
+  }
 
   Expr *getLHS() const { return static_cast<Expr*>(SubExprs[0]); }
   Expr *getRHS() const { return static_cast<Expr*>(SubExprs[1]); }
 
+  unsigned getNumOverloadCands() const { return NumOverloadCands; }
+
+  unsigned numTrailingObjects(OverloadToken<DeclAccessPair>) const {
+    return getNumOverloadCands();
+  }
+
+  llvm::iterator_range<UnresolvedSetIterator> overloadCands() const {
+    auto Begin = UnresolvedSetIterator(getTrailingObjects<DeclAccessPair>());
+    return llvm::make_range(Begin, Begin + NumOverloadCands);
+  }
+
   /// Does this produce a right-associated sequence of operators?
   bool isRightFold() const {
     return getLHS() && getLHS()->containsUnexpandedParameterPack();
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to