EricWF created this revision.
EricWF added a reviewer: rsmith.
Herald added a subscriber: mgrang.

This is a work-in-progress attempt to add `operator<=>` rewriting. It's nowhere 
close to complete, but I would like some initial feedback on the direction.

As currently implemented, rewritten and synthesized candidates are only 
partially checked for viability when they're added to the overload candidates 
(the conversion sequence from the argument type -> parameter type is checked, 
but nothing else). This can lead to non-viable candidates being selected or 
causing ambiguity when final overload resolution is attempted.

The solution implemented in this patch is to fully build "potentially viable" 
candidates if they are selected or cause ambiguity. If building the candidate 
fails, it is marked as non-viable, and a new best viable function is computed.

The problem with this approach is that it's expensive and potentially wasteful. 
For example, if overload resolution results in ambiguity with `N` rewritten 
candidates with the same partial ordering, then all `N` candidates are fully 
built, and those results are potentially thrown away.

For builtin candidates this can be avoided by separating out the bits of 
`CheckBinaryOperands` which compute its result type from the bits which 
actually build the expression and convert the arguments. Once we know the 
return type of the builtin, we can deduce if the comparison category type can 
be used in the expression `0 @ <comp-category>` without much effort.

However, for non-builtin overloads of `operator<=>` (which don't return a 
comparison category type), it seems non-trivial to check that `0 @ <result>` is 
valid without actually attempting to build the overloaded binary operator.

@rsmith Could you provide initial feedback on this direction? Or any ideas you 
have about how to best implement it?


Repository:
  rC Clang

https://reviews.llvm.org/D45680

Files:
  include/clang/Sema/Overload.h
  include/clang/Sema/Sema.h
  lib/Sema/SemaOverload.cpp
  test/SemaCXX/compare-cxx2a.cpp

Index: test/SemaCXX/compare-cxx2a.cpp
===================================================================
--- test/SemaCXX/compare-cxx2a.cpp
+++ test/SemaCXX/compare-cxx2a.cpp
@@ -293,20 +293,21 @@
 
 template <int>
 struct Tag {};
-// expected-note@+1 {{candidate}}
-Tag<0> operator<=>(EnumA, EnumA) {
-  return {};
+std::strong_ordering operator<=>(EnumA, EnumA) {
+  return std::strong_ordering::equal;
 }
-Tag<1> operator<=>(EnumA, EnumB) {
-  return {};
+// expected-note@+1 {{candidate function}},
+std::strong_ordering operator<=>(EnumA a, EnumB b) {
+  return ((int)a <=> (int)b);
 }
 
 void test_enum_ovl_provided() {
   auto r1 = (EnumA::A <=> EnumA::A);
-  ASSERT_EXPR_TYPE(r1, Tag<0>);
+  ASSERT_EXPR_TYPE(r1, std::strong_ordering);
   auto r2 = (EnumA::A <=> EnumB::B);
-  ASSERT_EXPR_TYPE(r2, Tag<1>);
-  (void)(EnumB::B <=> EnumA::A); // expected-error {{invalid operands to binary expression ('EnumCompareTests::EnumB' and 'EnumCompareTests::EnumA')}}
+  ASSERT_EXPR_TYPE(r2, std::strong_ordering);
+  (void)(EnumB::B <=> EnumA::A); // OK, chooses reverse order synthesized candidate.
+  (void)(EnumB::B <=> EnumC::C); // expected-error {{invalid operands to binary expression ('EnumCompareTests::EnumB' and 'EnumCompareTests::EnumC')}}
 }
 
 void enum_float_test() {
@@ -375,3 +376,60 @@
   ASSERT_EXPR_TYPE(r4, std::partial_ordering);
 
 }
+
+namespace TestRewritting {
+
+struct T {
+  int x;
+  // expected-note@+1 {{candidate}}
+  constexpr std::strong_ordering operator<=>(T y) const {
+    return (x <=> y.x);
+  }
+};
+
+struct U {
+  int x;
+  // FIXME: This diagnostic is terrible.
+  // expected-note@+1 {{candidate function not viable: requires single argument 'y', but 2 arguments were provided}}
+  constexpr std::strong_equality operator<=>(T y) const {
+    if (x == y.x)
+      return std::strong_equality::equal;
+    return std::strong_equality::nonequal;
+  }
+};
+
+struct X { int x; };
+struct Y { int x; };
+struct Tag {};
+// expected-note@+1 2 {{candidate}}
+Tag operator<=>(X, Y) {
+  return {};
+}
+// expected-note@+1 2 {{candidate}}
+constexpr auto operator<=>(Y y, X x) {
+  return y.x <=> x.x;
+}
+
+void foo() {
+  T t{42};
+  T t2{0};
+  U u{101};
+  auto r1 = (t <=> u);
+  ASSERT_EXPR_TYPE(r1, std::strong_equality);
+  auto r2 = (t <=> t2);
+  ASSERT_EXPR_TYPE(r2, std::strong_ordering);
+
+  auto r3 = t == u;
+  ASSERT_EXPR_TYPE(r3, bool);
+
+  (void)(t < u); // expected-error {{invalid operands to binary expression ('TestRewritting::T' and 'TestRewritting::U')}}
+
+  constexpr X x{1};
+  constexpr Y y{2};
+  constexpr auto r4 = (y < x);
+  static_assert(r4 == false);
+  constexpr auto r5 = (x < y);
+  static_assert(r5 == true);
+}
+
+} // namespace TestRewritting
Index: lib/Sema/SemaOverload.cpp
===================================================================
--- lib/Sema/SemaOverload.cpp
+++ lib/Sema/SemaOverload.cpp
@@ -831,6 +831,29 @@
   }
 }
 
+const ImplicitConversionSequence &
+OverloadCandidate::getConversion(unsigned ArgIdx) const {
+  return Conversions[getTrueArgIndex(ArgIdx)];
+}
+
+unsigned OverloadCandidate::getTrueArgIndex(unsigned Idx) const {
+  if (getRewrittenKind() != ROC_Synthesized)
+    return Idx;
+  // FIXME(EricWF): Handle these cases.
+  assert(!IsSurrogate);
+  assert(!IgnoreObjectArgument);
+  // assert(getNumParams() == 2);
+  assert(Idx < 2);
+  return Idx == 0 ? 1 : 0;
+}
+
+QualType OverloadCandidate::getParamType(unsigned ArgIdx) const {
+  ArgIdx = getTrueArgIndex(ArgIdx);
+  if (Function)
+    return Function->getParamDecl(ArgIdx)->getType();
+  return BuiltinParamTypes[ArgIdx];
+}
+
 void OverloadCandidateSet::destroyCandidates() {
   for (iterator i = begin(), e = end(); i != e; ++i) {
     for (auto &C : i->Conversions)
@@ -5935,13 +5958,15 @@
   //   (possibly cv-qualified) T2", when T2 is an enumeration type, are
   //   candidate functions.
   if (CandidateSet.getKind() == OverloadCandidateSet::CSK_Operator &&
-      !IsAcceptableNonMemberOperatorCandidate(Context, Function, Args))
+      !IsAcceptableNonMemberOperatorCandidate(Context, Function, Args)) {
     return;
+  }
 
   // C++11 [class.copy]p11: [DR1402]
   //   A defaulted move constructor that is defined as deleted is ignored by
   //   overload resolution.
   CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(Function);
+
   if (Constructor && Constructor->isDefaulted() && Constructor->isDeleted() &&
       Constructor->isMoveConstructor())
     return;
@@ -5957,6 +5982,7 @@
   Candidate.Function = Function;
   Candidate.Viable = true;
   Candidate.IsSurrogate = false;
+  Candidate.RewrittenOpKind = ROC_None;
   Candidate.IgnoreObjectArgument = false;
   Candidate.ExplicitCallArguments = Args.size();
 
@@ -6647,6 +6673,7 @@
     Candidate.Function = MethodTmpl->getTemplatedDecl();
     Candidate.Viable = false;
     Candidate.IsSurrogate = false;
+    Candidate.RewrittenOpKind = ROC_None;
     Candidate.IgnoreObjectArgument =
         cast<CXXMethodDecl>(Candidate.Function)->isStatic() ||
         ObjectType.isNull();
@@ -6711,6 +6738,7 @@
     Candidate.Function = FunctionTemplate->getTemplatedDecl();
     Candidate.Viable = false;
     Candidate.IsSurrogate = false;
+    Candidate.RewrittenOpKind = ROC_None;
     // Ignore the object argument if there is one, since we don't have an object
     // type.
     Candidate.IgnoreObjectArgument =
@@ -6882,6 +6910,7 @@
   Candidate.FoundDecl = FoundDecl;
   Candidate.Function = Conversion;
   Candidate.IsSurrogate = false;
+  Candidate.RewrittenOpKind = ROC_None;
   Candidate.IgnoreObjectArgument = false;
   Candidate.FinalConversion.setAsIdentityConversion();
   Candidate.FinalConversion.setFromType(ConvType);
@@ -7043,6 +7072,7 @@
     Candidate.Viable = false;
     Candidate.FailureKind = ovl_fail_bad_deduction;
     Candidate.IsSurrogate = false;
+    Candidate.RewrittenOpKind = ROC_None;
     Candidate.IgnoreObjectArgument = false;
     Candidate.ExplicitCallArguments = 1;
     Candidate.DeductionFailure = MakeDeductionFailureInfo(Context, Result,
@@ -7083,6 +7113,7 @@
   Candidate.Surrogate = Conversion;
   Candidate.Viable = true;
   Candidate.IsSurrogate = true;
+  Candidate.RewrittenOpKind = ROC_None;
   Candidate.IgnoreObjectArgument = false;
   Candidate.ExplicitCallArguments = Args.size();
 
@@ -7240,6 +7271,7 @@
   Candidate.FoundDecl = DeclAccessPair::make(nullptr, AS_none);
   Candidate.Function = nullptr;
   Candidate.IsSurrogate = false;
+  Candidate.RewrittenOpKind = ROC_None;
   Candidate.IgnoreObjectArgument = false;
   std::copy(ParamTys, ParamTys + Args.size(), Candidate.BuiltinParamTypes);
 
@@ -8845,20 +8877,85 @@
   }
 }
 
+/// \brief Add the rewritten and synthesized candidates for binary comparison
+/// operators. No additional semantic checking is done to see if the candidate
+/// is well formed.
+void Sema::AddRewrittenOperatorCandidates(OverloadedOperatorKind Op,
+                                          SourceLocation OpLoc,
+                                          ArrayRef<Expr *> InputArgs,
+                                          OverloadCandidateSet &CandidateSet,
+                                          bool PerformADL) {
+  assert(getLangOpts().CPlusPlus2a);
+  auto Opc = BinaryOperator::getOverloadedOpcode(Op);
+
+  bool IsRelationalOrEquality =
+      BinaryOperator::isRelationalOp(Opc) || BinaryOperator::isEqualityOp(Opc);
+  if (!IsRelationalOrEquality && Opc != BO_Cmp)
+    return;
+  assert(InputArgs.size() == 2);
+
+  OverloadedOperatorKind CmpOp = OO_Spaceship;
+  DeclarationName OpName = Context.DeclarationNames.getCXXOperatorName(CmpOp);
+
+  // Lookup possible candidates for the rewritten operator.
+  // FIXME:  should this really be done in the current scope?
+  LookupResult Operators(*this, OpName, SourceLocation(),
+                         Sema::LookupOperatorName);
+  LookupName(Operators, getCurScope());
+  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+  const auto &Functions = Operators.asUnresolvedSet();
+
+  // AddCandidates - Add operator<=> candidates for the specified set of args,
+  // and mark all newly generated candidates as having the specified
+  // 'RewrittenOverloadCandidateKind'.
+  auto AddCandidates = [&](ArrayRef<Expr *> Args,
+                           RewrittenOverloadCandidateKind Kind) {
+    OverloadCandidateSet::RewrittenCandidateContextGuard Guard(CandidateSet);
+
+    unsigned InitialSize = CandidateSet.size();
+    AddFunctionCandidates(Functions, Args, CandidateSet);
+    AddMemberOperatorCandidates(CmpOp, OpLoc, Args, CandidateSet);
+    if (PerformADL)
+      AddArgumentDependentLookupCandidates(OpName, OpLoc, Args,
+                                           /*ExplicitTemplateArgs*/ nullptr,
+                                           CandidateSet);
+    AddBuiltinOperatorCandidates(CmpOp, OpLoc, Args, CandidateSet);
+
+    for (auto It = std::next(CandidateSet.begin(), InitialSize);
+         It != CandidateSet.end(); ++It) {
+      OverloadCandidate &Ovl = *It;
+      Ovl.RewrittenOpKind = Kind;
+    }
+  };
+
+  // If we have a relational or equality operation, add the rewritten candidates
+  // of the form: (LHS <=> RHS) @ 0
+  if (IsRelationalOrEquality)
+   AddCandidates(InputArgs, ROC_Rewritten);
+
+  // TODO: We should be able to avoid adding synthesized candidates when LHS and
+  // RHS have the same type, since the synthesized candidates for <=> should be
+  // the same as the rewritten ones. Note: It's still possible for the result
+  // of operator<=> to be usable only on the left or right side of the
+  // expression (0 @ <result>) or (<result> @ 0).
+
+  // For relational, equality, and three-way comparisons, add the rewritten and
+  // synthesized candidates of the form: 0 @ (RHS <=> LHS)
+  SmallVector<Expr *, 2> ReverseArgs(InputArgs.rbegin(), InputArgs.rend());
+  AddCandidates(ReverseArgs, ROC_Synthesized);
+}
+
 /// \brief Add function candidates found via argument-dependent lookup
 /// to the set of overloading candidates.
 ///
 /// This routine performs argument-dependent name lookup based on the
 /// given function name (which may also be an operator name) and adds
 /// all of the overload candidates found by ADL to the overload
 /// candidate set (C++ [basic.lookup.argdep]).
-void
-Sema::AddArgumentDependentLookupCandidates(DeclarationName Name,
-                                           SourceLocation Loc,
-                                           ArrayRef<Expr *> Args,
+void Sema::AddArgumentDependentLookupCandidates(
+    DeclarationName Name, SourceLocation Loc, ArrayRef<Expr *> Args,
     TemplateArgumentListInfo *ExplicitTemplateArgs,
-                                           OverloadCandidateSet& CandidateSet,
-                                           bool PartialOverloading) {
+    OverloadCandidateSet &CandidateSet, bool PartialOverloading) {
   ADLResult Fns;
 
   // FIXME: This approach for uniquing ADL results (and removing
@@ -8991,8 +9088,8 @@
   assert(Cand2.Conversions.size() == NumArgs && "Overload candidate mismatch");
   bool HasBetterConversion = false;
   for (unsigned ArgIdx = StartArg; ArgIdx < NumArgs; ++ArgIdx) {
-    bool Cand1Bad = IsIllFormedConversion(Cand1.Conversions[ArgIdx]);
-    bool Cand2Bad = IsIllFormedConversion(Cand2.Conversions[ArgIdx]);
+    bool Cand1Bad = IsIllFormedConversion(Cand1.getConversion(ArgIdx));
+    bool Cand2Bad = IsIllFormedConversion(Cand2.getConversion(ArgIdx));
     if (Cand1Bad != Cand2Bad) {
       if (Cand1Bad)
         return false;
@@ -9008,9 +9105,8 @@
   //   viable function F2 if for all arguments i, ICSi(F1) is not a worse
   //   conversion sequence than ICSi(F2), and then...
   for (unsigned ArgIdx = StartArg; ArgIdx < NumArgs; ++ArgIdx) {
-    switch (CompareImplicitConversionSequences(S, Loc,
-                                               Cand1.Conversions[ArgIdx],
-                                               Cand2.Conversions[ArgIdx])) {
+    switch (CompareImplicitConversionSequences(
+        S, Loc, Cand1.getConversion(ArgIdx), Cand2.getConversion(ArgIdx))) {
     case ImplicitConversionSequence::Better:
       // Cand1 has a better conversion sequence.
       HasBetterConversion = true;
@@ -9116,6 +9212,31 @@
     // Inherited from sibling base classes: still ambiguous.
   }
 
+  // Check C++2a tie-breakers for rewritten candidates
+  {
+    // --- F2 is a rewritten candidate ([over.match.oper]) and F1 is not.
+    if (Cand2.getRewrittenKind() && !Cand1.getRewrittenKind())
+      return true;
+    if (Cand1.getRewrittenKind() && Cand2.getRewrittenKind() &&
+        Cand2.getRewrittenKind() == ROC_Synthesized &&
+        Cand1.getRewrittenKind() != ROC_Synthesized) {
+      assert(StartArg == 0);
+      assert(NumArgs == 2);
+      // --- F1 and F2 are rewritten candidates, and F2 is a synthesized
+      // candidate with reversed order of parameters and F1 is not.
+      bool Matches = true;
+      for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
+        if (Cand1.getParamType(ArgIdx).getCanonicalType() !=
+            Cand2.getParamType(ArgIdx).getCanonicalType()) {
+          Matches = false;
+          break;
+        }
+      }
+      if (Matches)
+        return true;
+    }
+  }
+
   // Check C++17 tie-breakers for deduction guides.
   {
     auto *Guide1 = dyn_cast_or_null<CXXDeductionGuideDecl>(Cand1.Function);
@@ -12211,6 +12332,314 @@
   return CreateBuiltinUnaryOp(OpLoc, Opc, Input);
 }
 
+static ExprResult buildBinaryOperatorCandidate(Sema &S, SourceLocation OpLoc,
+                                               BinaryOperatorKind Opc,
+                                               const OverloadCandidate &Ovl,
+                                               Expr *LHSE, Expr *RHSE,
+                                               bool HadMultipleCandidates) {
+  Expr *Args[2] = {LHSE, RHSE};
+  OverloadedOperatorKind Op = BinaryOperator::getOverloadedOperator(Opc);
+  // We found a built-in operator or an overloaded operator.
+  FunctionDecl *FnDecl = Ovl.Function;
+
+  if (FnDecl) {
+    Expr *Base = nullptr;
+    // We matched an overloaded operator. Build a call to that
+    // operator.
+
+    // Convert the arguments.
+    if (CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(FnDecl)) {
+      // Ovl.Access is only meaningful for class members.
+      S.CheckMemberOperatorAccess(OpLoc, Args[0], Args[1], Ovl.FoundDecl);
+
+      ExprResult Arg1 =
+          S.PerformCopyInitialization(InitializedEntity::InitializeParameter(
+                                          S.Context, FnDecl->getParamDecl(0)),
+                                      SourceLocation(), Args[1]);
+      if (Arg1.isInvalid())
+        return ExprError();
+
+      ExprResult Arg0 = S.PerformObjectArgumentInitialization(
+          Args[0], /*Qualifier=*/nullptr, Ovl.FoundDecl, Method);
+      if (Arg0.isInvalid())
+        return ExprError();
+      Base = Args[0] = Arg0.getAs<Expr>();
+      Args[1] = Arg1.getAs<Expr>();
+    } else {
+      // Convert the arguments.
+      ExprResult Arg0 =
+          S.PerformCopyInitialization(InitializedEntity::InitializeParameter(
+                                          S.Context, FnDecl->getParamDecl(0)),
+                                      SourceLocation(), Args[0]);
+      if (Arg0.isInvalid())
+        return ExprError();
+
+      ExprResult Arg1 =
+          S.PerformCopyInitialization(InitializedEntity::InitializeParameter(
+                                          S.Context, FnDecl->getParamDecl(1)),
+                                      SourceLocation(), Args[1]);
+      if (Arg1.isInvalid())
+        return ExprError();
+      Args[0] = Arg0.getAs<Expr>();
+      Args[1] = Arg1.getAs<Expr>();
+    }
+
+    // Build the actual expression node.
+    ExprResult FnExpr = CreateFunctionRefExpr(S, FnDecl, Ovl.FoundDecl, Base,
+                                              HadMultipleCandidates, OpLoc);
+    if (FnExpr.isInvalid())
+      return ExprError();
+
+    // Determine the result type.
+    QualType ResultTy = FnDecl->getReturnType();
+    ExprValueKind VK = Expr::getValueKindForType(ResultTy);
+    ResultTy = ResultTy.getNonLValueExprType(S.Context);
+
+    CXXOperatorCallExpr *TheCall = new (S.Context) CXXOperatorCallExpr(
+        S.Context, Op, FnExpr.get(), Args, ResultTy, VK, OpLoc, S.FPFeatures);
+
+    if (S.CheckCallReturnType(FnDecl->getReturnType(), OpLoc, TheCall, FnDecl))
+      return ExprError();
+
+    ArrayRef<const Expr *> ArgsArray(Args, 2);
+    const Expr *ImplicitThis = nullptr;
+    // Cut off the implicit 'this'.
+    if (isa<CXXMethodDecl>(FnDecl)) {
+      ImplicitThis = ArgsArray[0];
+      ArgsArray = ArgsArray.slice(1);
+    }
+
+    // Check for a self move.
+    if (Op == OO_Equal)
+      S.DiagnoseSelfMove(Args[0], Args[1], OpLoc);
+
+    S.checkCall(FnDecl, nullptr, ImplicitThis, ArgsArray,
+                isa<CXXMethodDecl>(FnDecl), OpLoc, TheCall->getSourceRange(),
+                Sema::VariadicDoesNotApply);
+
+    return S.MaybeBindToTemporary(TheCall);
+
+  } else {
+    // We matched a built-in operator. Convert the arguments, then
+    // break out so that we will build the appropriate built-in
+    // operator node.
+    ExprResult ArgsRes0 =
+        S.PerformImplicitConversion(Args[0], Ovl.BuiltinParamTypes[0],
+                                    Ovl.Conversions[0], Sema::AA_Passing);
+    if (ArgsRes0.isInvalid())
+      return ExprError();
+    Args[0] = ArgsRes0.get();
+
+    ExprResult ArgsRes1 =
+        S.PerformImplicitConversion(Args[1], Ovl.BuiltinParamTypes[1],
+                                    Ovl.Conversions[1], Sema::AA_Passing);
+    if (ArgsRes1.isInvalid())
+      return ExprError();
+    Args[1] = ArgsRes1.get();
+  }
+  // We matched a built-in operator; build it.
+  return S.CreateBuiltinBinOp(OpLoc, Opc, Args[0], Args[1]);
+}
+
+namespace {
+
+/// \brief RewrittenCandidateOverloadResolver - This class handles initial
+/// overload resolution for candidate sets which include rewritten candidates.
+///
+/// Rewritten candidates haven't been fully checked for validity. They may still
+/// be invalid if:
+///    (A) The rewritten candidate is a builtin, but semantic checking of the
+///        builtin would fail.
+///    (B) The result of the "partially rewritten expression"
+///        (ie the (LHS <=> RHS) part) is ill-formed when used as an operand to
+///        (<result> @ 0) or (0 @ <result>).
+///
+/// TODO: Separate out the bits of semantic checking for builtin spaceship
+/// operators which determine validity and the return type, and use that instead
+/// of building the full expression to check validity.
+class RewrittenCandidateOverloadResolver {
+public:
+  RewrittenCandidateOverloadResolver(Sema &S, SourceLocation OpLoc,
+                                     BinaryOperatorKind Opc,
+                                     ArrayRef<Expr *> Args,
+                                     const UnresolvedSetImpl &Fns,
+                                     bool PerformADL, OverloadCandidateSet &CS)
+      : S(S), OpLoc(OpLoc), Opc(Opc), Args(Args), Fns(Fns),
+        PerformADL(PerformADL), CandidateSet(CS) {}
+
+  ExprResult ResolveRewrittenCandidates() {
+    ExprResult FinalResult = ExprError();
+    OverloadCandidateSet::iterator Best;
+    OverloadingResult OvlRes;
+    do {
+      OvlRes = CandidateSet.BestViableFunction(S, OpLoc, Best);
+    } while (RemoveNonViableRewrittenCandidates(OvlRes, Best, FinalResult));
+    return FinalResult;
+  }
+private:
+  bool RemoveNonViableRewrittenCandidates(OverloadingResult OvlRes,
+                                          OverloadCandidateSet::iterator Best,
+                                          ExprResult &FinalResult);
+  ExprResult BuildRewrittenCandidate(const OverloadCandidate &Ovl);
+
+  RewrittenCandidateOverloadResolver(
+      RewrittenCandidateOverloadResolver const &) = delete;
+  RewrittenCandidateOverloadResolver &
+  operator=(RewrittenCandidateOverloadResolver const &) = delete;
+
+private:
+  Sema &S;
+  SourceLocation OpLoc;
+  BinaryOperatorKind Opc;
+  ArrayRef<Expr *> Args;
+  const UnresolvedSetImpl &Fns;
+  bool PerformADL;
+  OverloadCandidateSet &CandidateSet;
+};
+} // end namespace
+
+
+ExprResult RewrittenCandidateOverloadResolver::BuildRewrittenCandidate(
+    const OverloadCandidate &Ovl) {
+  Expr *RewrittenArgs[2] = {Args[0], Args[1]};
+  bool IsSynthesized = Ovl.getRewrittenKind() == ROC_Synthesized;
+  if (IsSynthesized)
+    std::swap(RewrittenArgs[0], RewrittenArgs[1]);
+
+  // Suppress diagnostics when building the expressions for the specified
+  // candidate. If evaluation fails the candidate will be marked non-viable
+  // and the best viable candidate re-computed.
+  Sema::TentativeAnalysisScope DiagnosticScopeGuard(S);
+
+  // Build the '(LHS <=> RHS)' operand to the full expression.
+  ExprResult RewrittenRes = buildBinaryOperatorCandidate(
+      S, OpLoc, BO_Cmp, Ovl, RewrittenArgs[0], RewrittenArgs[1],
+      /*HadMultipleCandidates*/ false);
+  if (RewrittenRes.isInvalid())
+    return ExprError();
+
+  // Now attempt to build the full expression '(LHS <=> RHS) @ 0' using the
+  // evaluated operand and the literal 0.
+  llvm::APInt I =
+      llvm::APInt::getNullValue(S.Context.getIntWidth(S.Context.IntTy));
+  Expr *Zero =
+      IntegerLiteral::Create(S.Context, I, S.Context.IntTy, SourceLocation());
+
+  Expr *NewLHS = RewrittenRes.get();
+  Expr *NewRHS = Zero;
+  if (Ovl.getRewrittenKind() == ROC_Synthesized)
+    std::swap(NewLHS, NewRHS);
+
+  return S.CreateOverloadedBinOp(OpLoc, Opc, Fns, NewLHS, NewRHS, PerformADL,
+                                 /*AllowRewrittenCandidates*/ false);
+}
+
+/// Rewritten candidates have been added but not checked for validity. They
+/// could still be non-viable if:
+///  (A) The rewritten call (x <=> y) is a builtin, but it will be ill-formed
+///      when built (for example it has narrowing conversions).
+///  (B) The expression (x <=> y) @ 0 is ill-formed for the result of (x <=> y).
+///
+/// If either is the case, this function should be considered non-viable and
+/// another best viable function needs to be computed.
+///
+/// Therefore, we do the following:
+///  (1) If we have no viable candidate, or a deleted candidate, stop.
+///      Otherwise, if we have a best viable candidate or a set of ambiguous
+///      candidates, none of which are rewritten, stop.
+///
+///  (2) If the best viable candidate is a rewritten candidate, build and
+///      check the full expression for that candidate. If it succeeds return
+///      that result. Otherwise, mark the candidate as non-viable, re-compute
+///      the best viable function, and continue.
+///
+///  (3) If we have ambiguity order the set of viable candidates. For each
+///      rewritten candidate causing ambiguity:
+///
+///        (3.1) build the full expression for the specified candidate.
+///        (3.2) If the result is invalid, mark the candidate as non-viable.
+///        (3.3) Otherwise, cache the valid result for later.
+///
+///      If any of the rewritten candidates were marked non-viable, recompute
+///      the best viable function and continue. Otherwise, stop since the
+///      lookup is still ambiguous,
+bool RewrittenCandidateOverloadResolver::RemoveNonViableRewrittenCandidates(
+    OverloadingResult OvlRes, OverloadCandidateSet::iterator Best,
+    ExprResult &FinalResult) {
+  switch (OvlRes) {
+  case OR_Deleted:
+    // FIXME(EricWF): If we've found a deleted rewritten operator, it's
+    // possible we should have never considered it a viable candidate.
+  case OR_No_Viable_Function:
+    return false;
+
+  case OR_Success: {
+    OverloadCandidate &Ovl = *Best;
+    if (!Ovl.getRewrittenKind())
+      return false;
+    // Build the full expression for the rewritten candidate, and return it if
+    // it's valid. Otherwise mark this candidate as non-viable and continue.
+    ExprResult Res = BuildRewrittenCandidate(Ovl);
+    if (Res.isInvalid()) {
+      Ovl.Viable = false;
+      return true;
+    }
+    FinalResult = Res;
+    return false;
+  }
+  case OR_Ambiguous: {
+    SmallVector<OverloadCandidate *, 16> Overloads;
+    // Gather all viable candidates. If none of the viable candidates are
+    // rewritten, stop.
+    bool HasRewritten = false;
+    for (auto &Ovl : CandidateSet) {
+      if (!Ovl.Viable)
+        continue;
+      HasRewritten |= Ovl.getRewrittenKind();
+      Overloads.push_back(&Ovl);
+    }
+    if (!HasRewritten)
+      return false;
+    auto CmpOverloads = [&](const OverloadCandidate *C1,
+                            const OverloadCandidate *C2) {
+      return isBetterOverloadCandidate(S, *C1, *C2, CandidateSet.getLocation(),
+                                       CandidateSet.getKind());
+    };
+    // Sort the candidate functions based on their partial ordering,
+    // and find the first N functions which rank equally.
+    std::sort(Overloads.begin(), Overloads.end(), CmpOverloads);
+    auto End = std::lower_bound(Overloads.begin(), Overloads.end(),
+                                Overloads[0], CmpOverloads);
+
+    int NumViableCandidates = 0;
+    ExprResult ViableRewritten = ExprError();
+    for (auto It = Overloads.begin(); It != End; ++It) {
+      OverloadCandidate &Ovl = **It;
+      if (Ovl.getRewrittenKind()) {
+        ExprResult Res = BuildRewrittenCandidate(Ovl);
+        if (Res.isInvalid()) {
+          Ovl.Viable = false;
+          continue;
+        }
+        ViableRewritten = Res;
+      }
+      ++NumViableCandidates;
+    }
+    // If only one of the candidates turns out to be viable, and it's a
+    // rewritten candidate, return that candidate as the result.
+    if (NumViableCandidates == 1 && !ViableRewritten.isInvalid()) {
+      FinalResult = ViableRewritten;
+      return false;
+    }
+    // If several candidates are still viable, the ambiguity remains; stop.
+    if (NumViableCandidates > 1)
+      return false;
+    return true;
+  }
+  }
+  llvm_unreachable("unhandled case");
+}
+
 /// \brief Create a binary operation that may resolve to an overloaded
 /// operator.
 ///
@@ -12227,11 +12656,11 @@
 ///
 /// \param LHS Left-hand argument.
 /// \param RHS Right-hand argument.
-ExprResult
-Sema::CreateOverloadedBinOp(SourceLocation OpLoc,
+ExprResult Sema::CreateOverloadedBinOp(SourceLocation OpLoc,
                                        BinaryOperatorKind Opc,
-                            const UnresolvedSetImpl &Fns,
-                            Expr *LHS, Expr *RHS, bool PerformADL) {
+                                       const UnresolvedSetImpl &Fns, Expr *LHS,
+                                       Expr *RHS, bool PerformADL,
+                                       bool AllowRewrittenCandidates) {
   Expr *Args[2] = { LHS, RHS };
   LHS=RHS=nullptr; // Please use only Args instead of LHS/RHS couple
 
@@ -12313,119 +12742,40 @@
   // Add builtin operator candidates.
   AddBuiltinOperatorCandidates(Op, OpLoc, Args, CandidateSet);
 
-  bool HadMultipleCandidates = (CandidateSet.size() > 1);
+  size_t BeforeRewrittenSize = CandidateSet.size(); // a count, compared below
+  // C++2a: add rewritten and synthesized operator candidates, if allowed.
+  if (getLangOpts().CPlusPlus2a && AllowRewrittenCandidates)
+    AddRewrittenOperatorCandidates(Op, OpLoc, Args, CandidateSet, PerformADL);
 
-  // Perform overload resolution.
-  OverloadCandidateSet::iterator Best;
-  switch (CandidateSet.BestViableFunction(*this, OpLoc, Best)) {
-    case OR_Success: {
-      // We found a built-in operator or an overloaded operator.
-      FunctionDecl *FnDecl = Best->Function;
-
-      if (FnDecl) {
-        Expr *Base = nullptr;
-        // We matched an overloaded operator. Build a call to that
-        // operator.
+  bool HasRewrittenCandidates = BeforeRewrittenSize != CandidateSet.size();
 
-        // Convert the arguments.
-        if (CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(FnDecl)) {
-          // Best->Access is only meaningful for class members.
-          CheckMemberOperatorAccess(OpLoc, Args[0], Args[1], Best->FoundDecl);
+  if (HasRewrittenCandidates) {
+    RewrittenCandidateOverloadResolver RewrittenOvlResolver(
+        *this, OpLoc, Opc, Args, Fns, PerformADL, CandidateSet);
 
-          ExprResult Arg1 =
-            PerformCopyInitialization(
-              InitializedEntity::InitializeParameter(Context,
-                                                     FnDecl->getParamDecl(0)),
-              SourceLocation(), Args[1]);
-          if (Arg1.isInvalid())
-            return ExprError();
+    // Perform initial overload resolution that includes partially checked
+    // rewritten candidates, removing rewritten candidates which turn out to be
+    // invalid as needed.
+    ExprResult RewrittenResult =
+        RewrittenOvlResolver.ResolveRewrittenCandidates();
 
-          ExprResult Arg0 =
-            PerformObjectArgumentInitialization(Args[0], /*Qualifier=*/nullptr,
-                                                Best->FoundDecl, Method);
-          if (Arg0.isInvalid())
-            return ExprError();
-          Base = Args[0] = Arg0.getAs<Expr>();
-          Args[1] = RHS = Arg1.getAs<Expr>();
-        } else {
-          // Convert the arguments.
-          ExprResult Arg0 = PerformCopyInitialization(
-            InitializedEntity::InitializeParameter(Context,
-                                                   FnDecl->getParamDecl(0)),
-            SourceLocation(), Args[0]);
-          if (Arg0.isInvalid())
-            return ExprError();
-
-          ExprResult Arg1 =
-            PerformCopyInitialization(
-              InitializedEntity::InitializeParameter(Context,
-                                                     FnDecl->getParamDecl(1)),
-              SourceLocation(), Args[1]);
-          if (Arg1.isInvalid())
-            return ExprError();
-          Args[0] = LHS = Arg0.getAs<Expr>();
-          Args[1] = RHS = Arg1.getAs<Expr>();
-        }
-
-        // Build the actual expression node.
-        ExprResult FnExpr = CreateFunctionRefExpr(*this, FnDecl,
-                                                  Best->FoundDecl, Base,
-                                                  HadMultipleCandidates, OpLoc);
-        if (FnExpr.isInvalid())
-          return ExprError();
-
-        // Determine the result type.
-        QualType ResultTy = FnDecl->getReturnType();
-        ExprValueKind VK = Expr::getValueKindForType(ResultTy);
-        ResultTy = ResultTy.getNonLValueExprType(Context);
-
-        CXXOperatorCallExpr *TheCall =
-          new (Context) CXXOperatorCallExpr(Context, Op, FnExpr.get(),
-                                            Args, ResultTy, VK, OpLoc,
-                                            FPFeatures);
-
-        if (CheckCallReturnType(FnDecl->getReturnType(), OpLoc, TheCall,
-                                FnDecl))
-          return ExprError();
-
-        ArrayRef<const Expr *> ArgsArray(Args, 2);
-        const Expr *ImplicitThis = nullptr;
-        // Cut off the implicit 'this'.
-        if (isa<CXXMethodDecl>(FnDecl)) {
-          ImplicitThis = ArgsArray[0];
-          ArgsArray = ArgsArray.slice(1);
-        }
-
-        // Check for a self move.
-        if (Op == OO_Equal)
-          DiagnoseSelfMove(Args[0], Args[1], OpLoc);
-
-        checkCall(FnDecl, nullptr, ImplicitThis, ArgsArray,
-                  isa<CXXMethodDecl>(FnDecl), OpLoc, TheCall->getSourceRange(),
-                  VariadicDoesNotApply);
-
-        return MaybeBindToTemporary(TheCall);
-      } else {
-        // We matched a built-in operator. Convert the arguments, then
-        // break out so that we will build the appropriate built-in
-        // operator node.
-        ExprResult ArgsRes0 =
-            PerformImplicitConversion(Args[0], Best->BuiltinParamTypes[0],
-                                      Best->Conversions[0], AA_Passing);
-        if (ArgsRes0.isInvalid())
-          return ExprError();
-        Args[0] = ArgsRes0.get();
-
-        ExprResult ArgsRes1 =
-            PerformImplicitConversion(Args[1], Best->BuiltinParamTypes[1],
-                                      Best->Conversions[1], AA_Passing);
-        if (ArgsRes1.isInvalid())
-          return ExprError();
-        Args[1] = ArgsRes1.get();
-        break;
-      }
+    // If overload resolution was successful and the result was a re-written
+    // overload candidate, then that candidate was evaluated and we can return
+    // that value directly.
+    if (!RewrittenResult.isInvalid())
+      return RewrittenResult;
   }
 
+  // Perform final overload resolution.
+  bool HadMultipleCandidates = (CandidateSet.size() > 1);
+  OverloadCandidateSet::iterator Best;
+  switch (CandidateSet.BestViableFunction(*this, OpLoc, Best)) {
+  case OR_Success:
+    assert(
+        !Best->getRewrittenKind() &&
+        "rewritten candidates should have already been resolved and evaluated");
+    return buildBinaryOperatorCandidate(*this, OpLoc, Opc, *Best, Args[0],
+                                        Args[1], HadMultipleCandidates);
   case OR_No_Viable_Function: {
     // C++ [over.match.oper]p9:
     //   If the operator is the operator , [...] and there are no
@@ -12438,15 +12788,15 @@
     // operator do not fall through to handling in built-in, but report that
     // no overloaded assignment operator found
     ExprResult Result = ExprError();
-      if (Args[0]->getType()->isRecordType() &&
-          Opc >= BO_Assign && Opc <= BO_OrAssign) {
+    if (Args[0]->getType()->isRecordType() && Opc >= BO_Assign &&
+        Opc <= BO_OrAssign) {
       Diag(OpLoc, diag::err_ovl_no_viable_oper)
-             << BinaryOperator::getOpcodeStr(Opc)
-             << Args[0]->getSourceRange() << Args[1]->getSourceRange();
+          << BinaryOperator::getOpcodeStr(Opc) << Args[0]->getSourceRange()
+          << Args[1]->getSourceRange();
       if (Args[0]->getType()->isIncompleteType()) {
         Diag(OpLoc, diag::note_assign_lhs_incomplete)
-            << Args[0]->getType()
-            << Args[0]->getSourceRange() << Args[1]->getSourceRange();
+            << Args[0]->getType() << Args[0]->getSourceRange()
+            << Args[1]->getSourceRange();
       }
     } else {
       // This is an erroneous use of an operator which can be overloaded by
@@ -12467,12 +12817,11 @@
                                   BinaryOperator::getOpcodeStr(Opc), OpLoc);
     return Result;
   }
-
   case OR_Ambiguous:
     Diag(OpLoc, diag::err_ovl_ambiguous_oper_binary)
-          << BinaryOperator::getOpcodeStr(Opc)
-          << Args[0]->getType() << Args[1]->getType()
-          << Args[0]->getSourceRange() << Args[1]->getSourceRange();
+        << BinaryOperator::getOpcodeStr(Opc) << Args[0]->getType()
+        << Args[1]->getType() << Args[0]->getSourceRange()
+        << Args[1]->getSourceRange();
     CandidateSet.NoteCandidates(*this, OCD_ViableCandidates, Args,
                                 BinaryOperator::getOpcodeStr(Opc), OpLoc);
     return ExprError();
@@ -12490,8 +12839,7 @@
       return ExprError();
     } else {
       Diag(OpLoc, diag::err_ovl_deleted_oper)
-          << Best->Function->isDeleted()
-          << BinaryOperator::getOpcodeStr(Opc)
+          << Best->Function->isDeleted() << BinaryOperator::getOpcodeStr(Opc)
           << getDeletedOrUnavailableSuffix(Best->Function)
           << Args[0]->getSourceRange() << Args[1]->getSourceRange();
     }
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -2777,6 +2777,11 @@
   void AddBuiltinOperatorCandidates(OverloadedOperatorKind Op,
                                     SourceLocation OpLoc, ArrayRef<Expr *> Args,
                                     OverloadCandidateSet& CandidateSet);
+  void AddRewrittenOperatorCandidates(OverloadedOperatorKind Op,
+                                      SourceLocation OpLoc,
+                                      ArrayRef<Expr *> Args,
+                                      OverloadCandidateSet &CandidateSet,
+                                      bool PerformADL);
   void AddArgumentDependentLookupCandidates(DeclarationName Name,
                                             SourceLocation Loc,
                                             ArrayRef<Expr *> Args,
@@ -2919,11 +2924,10 @@
                                      const UnresolvedSetImpl &Fns,
                                      Expr *input, bool RequiresADL = true);
 
-  ExprResult CreateOverloadedBinOp(SourceLocation OpLoc,
-                                   BinaryOperatorKind Opc,
-                                   const UnresolvedSetImpl &Fns,
-                                   Expr *LHS, Expr *RHS,
-                                   bool RequiresADL = true);
+  ExprResult CreateOverloadedBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc,
+                                   const UnresolvedSetImpl &Fns, Expr *LHS,
+                                   Expr *RHS, bool RequiresADL = true,
+                                   bool AllowRewrittenCandidates = true);
 
   ExprResult CreateOverloadedArraySubscriptExpr(SourceLocation LLoc,
                                                 SourceLocation RLoc,
@@ -10366,11 +10370,13 @@
                             const FunctionProtoType *Proto,
                             SourceLocation Loc);
 
+public:
   void checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
                  const Expr *ThisArg, ArrayRef<const Expr *> Args,
                  bool IsMemberFunction, SourceLocation Loc, SourceRange Range,
                  VariadicCallType CallType);
 
+private:
   bool CheckObjCString(Expr *Arg);
   ExprResult CheckOSLogFormatStringArg(Expr *Arg);
 
Index: include/clang/Sema/Overload.h
===================================================================
--- include/clang/Sema/Overload.h
+++ include/clang/Sema/Overload.h
@@ -72,6 +72,17 @@
     OCD_ViableCandidates
   };
 
+  /// RewrittenOverloadCandidateKind - The kind of the operator candidate in
+  /// accordance with [over.match.oper].
+  enum RewrittenOverloadCandidateKind : unsigned char {
+    /// Not a rewritten candidate.
+    ROC_None,
+    /// Rewritten but not synthesized.
+    ROC_Rewritten,
+    /// Both rewritten and synthesized.
+    ROC_Synthesized
+  };
+
   /// ImplicitConversionKind - The kind of implicit conversion used to
   /// convert an argument to a parameter's type. The enumerator values
   /// match with the table titled 'Conversions' in [over.ics.scs] and are listed
@@ -755,21 +766,25 @@
     ConversionFixItGenerator Fix;
 
     /// Viable - True to indicate that this overload candidate is viable.
-    bool Viable;
+    bool Viable : 1;
 
     /// IsSurrogate - True to indicate that this candidate is a
     /// surrogate for a conversion to a function pointer or reference
     /// (C++ [over.call.object]).
-    bool IsSurrogate;
+    bool IsSurrogate : 1;
 
     /// IgnoreObjectArgument - True to indicate that the first
     /// argument's conversion, which for this function represents the
     /// implicit object argument, should be ignored. This will be true
     /// when the candidate is a static member function (where the
     /// implicit object argument is just a placeholder) or a
     /// non-static member function when the call doesn't have an
     /// object argument.
-    bool IgnoreObjectArgument;
+    bool IgnoreObjectArgument : 1;
+
+    /// RewrittenOpKind - The kind of rewritten operator candidate this is,
+    /// stored as a RewrittenOverloadCandidateKind (see getRewrittenKind()).
+    unsigned char RewrittenOpKind : 2;
 
     /// FailureKind - The reason why this candidate is not viable.
     /// Actually an OverloadFailureKind.
@@ -812,6 +827,19 @@
       return CanFix;
     }
 
+    /// \brief Return the "true" index for the specified argument. If this
+    /// is not a synthesized candidate, 'Idx' is returned. Otherwise the
+    /// index corresponding to the reversed parameter is returned.
+    unsigned getTrueArgIndex(unsigned Idx) const;
+
+    /// \brief Return the conversion sequence for the specified argument index.
+    /// If this is a synthesized candidate, the argument index is reversed.
+    const ImplicitConversionSequence &getConversion(unsigned ArgIdx) const;
+
+    /// \brief Return the parameter type for the specified index. If this is
+    /// a synthesized candidate, the argument index is reversed.
+    QualType getParamType(unsigned ArgIdx) const;
+
     unsigned getNumParams() const {
       if (IsSurrogate) {
         auto STy = Surrogate->getConversionType();
@@ -823,6 +851,10 @@
         return Function->getNumParams();
       return ExplicitCallArguments;
     }
+
+    RewrittenOverloadCandidateKind getRewrittenKind() const {
+      return static_cast<RewrittenOverloadCandidateKind>(RewrittenOpKind);
+    }
   };
 
   /// OverloadCandidateSet - A set of overload candidates, used in C++
@@ -853,8 +885,10 @@
 
   private:
     SmallVector<OverloadCandidate, 16> Candidates;
-    llvm::SmallPtrSet<Decl *, 16> Functions;
+    using DeclSet = llvm::SmallPtrSet<Decl *, 16>;
+    DeclSet Functions;
 
+  private:
     // Allocator for ConversionSequenceLists. We store the first few of these
     // inline to avoid allocation for small sets.
     llvm::BumpPtrAllocator SlabAllocator;
@@ -896,6 +930,31 @@
     void destroyCandidates();
 
   public:
+    /// \brief RewrittenCandidateContextGuard - Enter a context suitable for
+    /// adding rewritten overload candidates. Rewritten candidates can
+    /// re-consider previously seen functions, so the set of considered
+    /// functions is moved aside on entry, and merged back into the set (along
+    /// with any functions added in the meantime) when the context is exited.
+    struct RewrittenCandidateContextGuard {
+      RewrittenCandidateContextGuard(OverloadCandidateSet &CS)
+          : CandidateSet(CS) {
+        assert(CS.Kind == CSK_Operator &&
+               "rewritten expressions can only occur for operators");
+        OldFunctions = std::move(CandidateSet.Functions);
+      }
+
+      ~RewrittenCandidateContextGuard() {
+        CandidateSet.Functions.insert(OldFunctions.begin(), OldFunctions.end());
+      }
+
+    private:
+      OverloadCandidateSet &CandidateSet;
+      DeclSet OldFunctions;
+    };
+
+    friend struct RewrittenCandidateContextGuard;
+
+  public:
     OverloadCandidateSet(SourceLocation Loc, CandidateSetKind CSK)
         : Loc(Loc), Kind(CSK) {}
     OverloadCandidateSet(const OverloadCandidateSet &) = delete;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to