EricWF updated this revision to Diff 76167.
EricWF added a comment.

- Suppress diagnostics emitted during the initial name lookup for
`await_transform`, `return_value` and `return_void`, since these diagnostics
are re-emitted if we actually build the calls.
- Add more tests.


https://reviews.llvm.org/D26057

Files:
  include/clang/AST/ExprCXX.h
  include/clang/AST/RecursiveASTVisitor.h
  include/clang/Basic/DiagnosticSemaKinds.td
  include/clang/Basic/StmtNodes.td
  include/clang/Sema/Sema.h
  lib/AST/Expr.cpp
  lib/AST/ExprClassification.cpp
  lib/AST/ExprConstant.cpp
  lib/AST/ItaniumMangle.cpp
  lib/AST/StmtPrinter.cpp
  lib/AST/StmtProfile.cpp
  lib/Sema/SemaCoroutine.cpp
  lib/Sema/SemaExceptionSpec.cpp
  lib/Sema/TreeTransform.h
  lib/Serialization/ASTReaderStmt.cpp
  lib/Serialization/ASTWriterStmt.cpp
  lib/StaticAnalyzer/Core/ExprEngine.cpp
  test/SemaCXX/coroutines.cpp
  tools/libclang/CXCursor.cpp

Index: tools/libclang/CXCursor.cpp
===================================================================
--- tools/libclang/CXCursor.cpp
+++ tools/libclang/CXCursor.cpp
@@ -231,6 +231,7 @@
   case Stmt::TypeTraitExprClass:
   case Stmt::CoroutineBodyStmtClass:
   case Stmt::CoawaitExprClass:
+  case Stmt::CoawaitDependentExprClass:
   case Stmt::CoreturnStmtClass:
   case Stmt::CoyieldExprClass:
   case Stmt::CXXBindTemporaryExprClass:
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -59,14 +59,14 @@
 template <typename... T>
 struct std::experimental::coroutine_traits<int, T...> {};
 
-int no_promise_type() {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+  co_await a;
 }
 
 template <>
 struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; };
-double bad_promise_type(double) {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+  co_await a;
 }
 
 template <>
@@ -77,7 +77,7 @@
   co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}}
 }
 
-struct promise; // expected-note 2{{forward declaration}}
+struct promise; // expected-note {{forward declaration}}
 struct promise_void;
 struct void_tag {};
 template <typename... T>
@@ -94,9 +94,7 @@
 }
 
 // FIXME: This diagnostic is terrible.
-void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}}
-  // FIXME: This diagnostic doesn't make any sense.
-  // expected-error@-2 {{incomplete definition of type 'promise'}}
+void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}}
   co_await a;
 }
 
@@ -239,6 +237,21 @@
   };
   template void await_template(outer); // expected-note {{instantiation}}
   template void await_template_2(outer);
+
+  template <typename T, typename U> coro<T> await_template_3(U t) {
+    co_await t;
+  }
+  struct transform_awaitable {};
+  struct transformed {};
+  struct transform_promise {
+    coro<transform_promise> get_return_object();
+    suspend_always initial_suspend();
+    suspend_always final_suspend();
+    transformed await_transform(transform_awaitable);
+  };
+  void operator co_await(transform_awaitable) = delete;
+  awaitable operator co_await(transformed);
+  template coro<transform_promise> await_template_3<transform_promise>(transform_awaitable);
 }
 
 struct yield_fn_tag {};
@@ -355,20 +368,68 @@
 int *current_exception();
 }
 
-struct bad_promise_8 {
+struct bad_promise_base {
+private:
+  void return_void(); // expected-note {{declared private here}}
+};
+struct bad_promise_8 : bad_promise_base {
   coro<bad_promise_8> get_return_object();
   suspend_always initial_suspend();
   suspend_always final_suspend();
-  void return_void();
   void set_exception();                                   // expected-note {{function not viable}}
   void set_exception(int *) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
   void set_exception(void *);                             // expected-note {{candidate function}}
 };
 coro<bad_promise_8> calls_set_exception() {
   // expected-error@-1 {{call to unavailable member function 'set_exception'}}
+  // expected-error@-2 {{'return_void' is a private member of 'bad_promise_base'}}
   co_await a;
 }
 
+struct bad_promise_9 {
+  coro<bad_promise_9> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void await_transform(void *);                                // expected-note {{candidate}}
+  awaitable await_transform(int) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
+  void return_void();
+};
+coro<bad_promise_9> calls_await_transform() {
+  co_await 42; // expected-error {{call to unavailable member function 'await_transform'}}
+  // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct bad_promise_10 {
+  coro<bad_promise_10> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  int await_transform;
+  void return_void();
+};
+coro<bad_promise_10> bad_coawait() {
+  // FIXME this diagnostic is terrible
+  co_await 42; // expected-error {{called object type 'int' is not a function or function pointer}}
+  // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct call_operator {
+  template <class ...Args> awaitable operator()(Args...) const { return a; }
+};
+void ret_void();
+struct good_promise_1 {
+  coro<good_promise_1> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  static const call_operator await_transform;
+  using Fn = void(*)();
+  Fn return_void = ret_void;
+};
+const call_operator good_promise_1::await_transform;
+coro<good_promise_1> ok_static_coawait() {
+  // OK: 'await_transform' and 'return_void' are callable objects here.
+  co_await 42;
+}
+
 template<> struct std::experimental::coroutine_traits<int, int, const char**>
 { using promise_type = promise; };
 
Index: lib/StaticAnalyzer/Core/ExprEngine.cpp
===================================================================
--- lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -774,6 +774,7 @@
     case Stmt::FunctionParmPackExprClass:
     case Stmt::CoroutineBodyStmtClass:
     case Stmt::CoawaitExprClass:
+    case Stmt::CoawaitDependentExprClass:
     case Stmt::CoreturnStmtClass:
     case Stmt::CoyieldExprClass:
     case Stmt::SEHTryStmtClass:
Index: lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- lib/Serialization/ASTWriterStmt.cpp
+++ lib/Serialization/ASTWriterStmt.cpp
@@ -315,6 +315,11 @@
   llvm_unreachable("unimplemented");
 }
 
+void ASTStmtWriter::VisitCoawaitDependentExpr(CoawaitDependentExpr *S) {
+  // FIXME: Implement coroutine serialization.
+  llvm_unreachable("unimplemented");
+}
+
 void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) {
   // FIXME: Implement coroutine serialization.
   llvm_unreachable("unimplemented");
Index: lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- lib/Serialization/ASTReaderStmt.cpp
+++ lib/Serialization/ASTReaderStmt.cpp
@@ -400,6 +400,11 @@
   llvm_unreachable("unimplemented");
 }
 
+void ASTStmtReader::VisitCoawaitDependentExpr(CoawaitDependentExpr *S) {
+  // FIXME: Implement coroutine serialization.
+  llvm_unreachable("unimplemented");
+}
+
 void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) {
   // FIXME: Implement coroutine serialization.
   llvm_unreachable("unimplemented");
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1318,6 +1318,16 @@
     return getSema().BuildCoawaitExpr(CoawaitLoc, Result);
   }
 
+  /// \brief Rebuild a co_await expression whose promise type was dependent.
+  ///
+  /// By default, performs semantic analysis to build the new expression.
+  /// Subclasses may override this routine to provide different behavior.
+  ExprResult RebuildCoawaitDependentExpr(SourceLocation CoawaitLoc,
+                                         Expr *Result,
+                                         const UnresolvedSet<16> &Candidates) {
+    return getSema().BuildCoawaitDependentExpr(CoawaitLoc, Result, Candidates);
+  }
+
   /// \brief Build a new co_yield expression.
   ///
   /// By default, performs semantic analysis to build the new expression.
@@ -6684,9 +6694,22 @@
   return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get());
 }
 
-template<typename Derived>
+template <typename Derived>
 ExprResult
-TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
+TreeTransform<Derived>::TransformCoawaitDependentExpr(CoawaitDependentExpr *E) {
+  ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
+                                                        /*NotCopyInit*/ false);
+  if (Result.isInvalid())
+    return ExprError();
+
+  // Always rebuild; we don't know if this needs to be injected into a new
+  // context or if the promise type has changed.
+  return getDerived().RebuildCoawaitDependentExpr(
+      E->getKeywordLoc(), Result.get(), E->getOperatorCandidates());
+}
+
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
   ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
                                                         /*NotCopyInit*/false);
   if (Result.isInvalid())
Index: lib/Sema/SemaExceptionSpec.cpp
===================================================================
--- lib/Sema/SemaExceptionSpec.cpp
+++ lib/Sema/SemaExceptionSpec.cpp
@@ -1146,6 +1146,7 @@
   case Expr::ArraySubscriptExprClass:
   case Expr::OMPArraySectionExprClass:
   case Expr::BinaryOperatorClass:
+  case Expr::CoawaitDependentExprClass:
   case Expr::CompoundAssignOperatorClass:
   case Expr::CStyleCastExprClass:
   case Expr::CXXStaticCastExprClass:
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -21,21 +21,32 @@
 using namespace clang;
 using namespace sema;
 
+static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
+                         SourceLocation Loc) {
+  DeclarationName DN = S.PP.getIdentifierInfo(Name);
+  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+  // Suppress diagnostics when a private member is selected. The same
+  // diagnostics will be produced again when building the call.
+  LR.suppressDiagnostics();
+  return S.LookupQualifiedName(LR, RD);
+}
+
 /// Look up the std::coroutine_traits<...>::promise_type for the given
 /// function type.
 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
-                                  SourceLocation Loc) {
+                                  SourceLocation KwLoc,
+                                  SourceLocation FuncLoc) {
   // FIXME: Cache std::coroutine_traits once we've found it.
   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
   if (!StdExp) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
-                      Loc, Sema::LookupOrdinaryName);
+                      FuncLoc, Sema::LookupOrdinaryName);
   if (!S.LookupQualifiedName(Result, StdExp)) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
@@ -49,52 +60,59 @@
   }
 
   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
-  TemplateArgumentListInfo Args(Loc, Loc);
+  TemplateArgumentListInfo Args(KwLoc, KwLoc);
   Args.addArgument(TemplateArgumentLoc(
       TemplateArgument(FnType->getReturnType()),
-      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
+      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc)));
   // FIXME: If the function is a non-static member function, add the type
   // of the implicit object parameter before the formal parameters.
   for (QualType T : FnType->getParamTypes())
     Args.addArgument(TemplateArgumentLoc(
-        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
+        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
 
   // Build the template-id.
   QualType CoroTrait =
-      S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
+      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
   if (CoroTrait.isNull())
     return QualType();
-  if (S.RequireCompleteType(Loc, CoroTrait,
+  if (S.RequireCompleteType(KwLoc, CoroTrait,
                             diag::err_coroutine_traits_missing_specialization))
     return QualType();
 
   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
   assert(RD && "specialization of class template is not a class?");
 
   // Look up the ::promise_type member.
-  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
+  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
                  Sema::LookupOrdinaryName);
   S.LookupQualifiedName(R, RD);
   auto *Promise = R.getAsSingle<TypeDecl>();
   if (!Promise) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_found)
         << RD;
     return QualType();
   }
-
   // The promise type is required to be a class type.
   QualType PromiseType = S.Context.getTypeDeclType(Promise);
-  if (!PromiseType->getAsCXXRecordDecl()) {
-    // Use the fully-qualified name of the type.
+
+  auto buildNNS = [&]() {
     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
                                       CoroTrait.getTypePtr());
-    PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+  };
 
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
-        << PromiseType;
+  RD = PromiseType->getAsCXXRecordDecl();
+  if (!RD) {
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_class)
+        << buildNNS();
     return QualType();
   }
+  if (S.RequireCompleteType(FuncLoc, buildNNS(),
+                            diag::err_coroutine_promise_type_incomplete))
+    return QualType();
 
   return PromiseType;
 }
@@ -173,10 +191,11 @@
 
   // If we don't have a promise variable, build one now.
   if (!ScopeInfo->CoroutinePromise) {
-    QualType T = FD->getType()->isDependentType()
-                     ? S.Context.DependentTy
-                     : lookupPromiseType(
-                           S, FD->getType()->castAs<FunctionProtoType>(), Loc);
+    QualType T =
+        FD->getType()->isDependentType()
+            ? S.Context.DependentTy
+            : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
+                                Loc, FD->getLocation());
     if (T.isNull())
       return nullptr;
 
@@ -215,11 +234,20 @@
 
 /// Build a call to 'operator co_await' if there is a suitable operator for
 /// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
-                                           SourceLocation Loc, Expr *E) {
+static UnresolvedSet<16> lookupOperatorCoawaitCall(Sema &SemaRef, Scope *S,
+                                                   SourceLocation Loc,
+                                                   Expr *E) {
   UnresolvedSet<16> Functions;
   SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
                                        Functions);
+  return Functions;
+}
+
+/// Build a call to 'operator co_await' if there is a suitable operator for
+/// the given expression.
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
+                                           Expr *E,
+                                           const UnresolvedSet<16> &Functions) {
   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
 }
 
@@ -268,6 +296,22 @@
   return Calls;
 }
 
+static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
+                                   SourceLocation Loc, StringRef Name,
+                                   MutableArrayRef<Expr *> Args) {
+  assert(Coroutine->CoroutinePromise && "no promise for coroutine");
+
+  // Form a reference to the promise.
+  auto *Promise = Coroutine->CoroutinePromise;
+  ExprResult PromiseRef = S.BuildDeclRefExpr(
+      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
+  if (PromiseRef.isInvalid())
+    return ExprError();
+
+  // Call the promise member function 'Name', passing in Args.
+  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
+}
+
 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
   if (!Coroutine) {
@@ -279,13 +323,51 @@
     if (R.isInvalid()) return ExprError();
     E = R.get();
   }
+  return BuildCoawaitDependentExpr(Loc, E,
+                                   lookupOperatorCoawaitCall(*this, S, Loc, E));
+}
+
+ExprResult
+Sema::BuildCoawaitDependentExpr(SourceLocation Loc, Expr *E,
+                                const UnresolvedSet<16> &Candidates) {
+  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
+  if (!Coroutine)
+    return ExprError();
 
-  ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
+  if (E->getType()->isPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(E);
+    if (R.isInvalid())
+      return ExprError();
+    E = R.get();
+  }
+
+  auto *Promise = Coroutine->CoroutinePromise;
+  if (Promise->getType()->isDependentType()) {
+    Expr *Res = new (Context)
+        CoawaitDependentExpr(Loc, Context.DependentTy, E, Candidates);
+    Coroutine->CoroutineStmts.push_back(Res);
+    return Res;
+  }
+
+  CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl();
+  if (lookupMember(*this, "await_transform", RD, Loc)) {
+    ExprResult R =
+        buildPromiseCall(*this, Coroutine, Loc, "await_transform", E);
+    if (R.isInvalid()) {
+      Diag(Loc,
+           diag::note_coroutine_promise_implicit_await_transform_required_here)
+          << E->getSourceRange();
+      return ExprError();
+    }
+    E = R.get();
+  }
+  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Candidates);
   if (Awaitable.isInvalid())
     return ExprError();
 
   return BuildCoawaitExpr(Loc, Awaitable.get());
 }
+
 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
   if (!Coroutine)
@@ -319,22 +401,6 @@
   return Res;
 }
 
-static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
-                                   SourceLocation Loc, StringRef Name,
-                                   MutableArrayRef<Expr *> Args) {
-  assert(Coroutine->CoroutinePromise && "no promise for coroutine");
-
-  // Form a reference to the promise.
-  auto *Promise = Coroutine->CoroutinePromise;
-  ExprResult PromiseRef = S.BuildDeclRefExpr(
-      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
-  if (PromiseRef.isInvalid())
-    return ExprError();
-
-  // Call 'yield_value', passing in E.
-  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
-}
-
 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
   if (!Coroutine) {
@@ -349,7 +415,8 @@
     return ExprError();
 
   // Build 'operator co_await' call.
-  Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
+  auto Functions = lookupOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
+  Awaitable = buildOperatorCoawaitCall(*this, Loc, Awaitable.get(), Functions);
   if (Awaitable.isInvalid())
     return ExprError();
 
@@ -579,16 +646,22 @@
   }
 
   bool AnyCoawaits = false;
+  bool AnyDependentCoawaits = false;
   bool AnyCoyields = false;
   for (auto *CoroutineStmt : Fn->CoroutineStmts) {
     AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
+    AnyDependentCoawaits |= isa<CoawaitDependentExpr>(CoroutineStmt);
     AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
   }
 
-  if (!AnyCoawaits && !AnyCoyields)
+  if (!AnyCoawaits && !AnyCoyields && !AnyDependentCoawaits)
     Diag(Fn->CoroutineStmts.front()->getLocStart(),
          diag::ext_coroutine_without_co_await_co_yield);
 
+  assert((!AnyDependentCoawaits ||
+          Fn->CoroutinePromise->getType()->isDependentType()) &&
+         "All dependent coawait expressions should already be resolved");
+
   SourceLocation Loc = FD->getLocation();
 
   // Form a declaration statement for the promise declaration, so that AST
@@ -633,17 +706,12 @@
       !Fn->CoroutinePromise->getType()->isDependentType()) {
     CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
     assert(RD && "Type should have already been checked");
+
     // [dcl.fct.def.coroutine]/4
     // The unqualified-ids 'return_void' and 'return_value' are looked up in
     // the scope of class P. If both are found, the program is ill-formed.
-    DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
-    LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
-    const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
-    DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
-    LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
-    const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
+    const bool HasRVoid = lookupMember(*this, "return_void", RD, Loc);
+    const bool HasRValue = lookupMember(*this, "return_value", RD, Loc);
     if (HasRVoid && HasRValue) {
       // FIXME Improve this diagnostic
       Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
@@ -662,9 +730,8 @@
     // [dcl.fct.def.coroutine]/3
     // The unqualified-id set_exception is found in the scope of P by class
     // member access lookup (3.4.5).
-    DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
-    LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
-    if (LookupQualifiedName(SetExResult, RD)) {
+
+    if (lookupMember(*this, "set_exception", RD, Loc)) {
       // Form the call 'p.set_exception(std::current_exception())'
       SetException = buildStdCurrentExceptionCall(*this, Loc);
       if (SetException.isInvalid())
Index: lib/AST/StmtProfile.cpp
===================================================================
--- lib/AST/StmtProfile.cpp
+++ lib/AST/StmtProfile.cpp
@@ -1552,6 +1552,10 @@
   VisitExpr(S);
 }
 
+void StmtProfiler::VisitCoawaitDependentExpr(const CoawaitDependentExpr *S) {
+  VisitExpr(S);
+}
+
 void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) {
   VisitExpr(S);
 }
Index: lib/AST/StmtPrinter.cpp
===================================================================
--- lib/AST/StmtPrinter.cpp
+++ lib/AST/StmtPrinter.cpp
@@ -2422,6 +2422,11 @@
   PrintExpr(S->getOperand());
 }
 
+void StmtPrinter::VisitCoawaitDependentExpr(CoawaitDependentExpr *S) {
+  OS << "co_await ";
+  PrintExpr(S->getOperand());
+}
+
 void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) {
   OS << "co_yield ";
   PrintExpr(S->getOperand());
Index: lib/AST/ItaniumMangle.cpp
===================================================================
--- lib/AST/ItaniumMangle.cpp
+++ lib/AST/ItaniumMangle.cpp
@@ -3281,6 +3281,8 @@
   // These all can only appear in local or variable-initialization
   // contexts and so should never appear in a mangling.
   case Expr::AddrLabelExprClass:
+  // A dependent co_await should no longer exist in the AST by now.
+  case Expr::CoawaitDependentExprClass:
   case Expr::DesignatedInitUpdateExprClass:
   case Expr::ImplicitValueInitExprClass:
   case Expr::NoInitExprClass:
Index: lib/AST/ExprConstant.cpp
===================================================================
--- lib/AST/ExprConstant.cpp
+++ lib/AST/ExprConstant.cpp
@@ -9442,6 +9442,7 @@
   case Expr::LambdaExprClass:
   case Expr::CXXFoldExprClass:
   case Expr::CoawaitExprClass:
+  case Expr::CoawaitDependentExprClass:
   case Expr::CoyieldExprClass:
     return ICEDiag(IK_NotICE, E->getLocStart());
 
Index: lib/AST/ExprClassification.cpp
===================================================================
--- lib/AST/ExprClassification.cpp
+++ lib/AST/ExprClassification.cpp
@@ -188,6 +188,9 @@
   case Expr::CXXFoldExprClass:
   case Expr::NoInitExprClass:
   case Expr::DesignatedInitUpdateExprClass:
+  // FIXME: How should we classify co_await expressions while they're still
+  // dependent?
+  case Expr::CoawaitDependentExprClass:
   case Expr::CoyieldExprClass:
     return Cl::CL_PRValue;
 
Index: lib/AST/Expr.cpp
===================================================================
--- lib/AST/Expr.cpp
+++ lib/AST/Expr.cpp
@@ -2897,6 +2897,7 @@
   case CXXNewExprClass:
   case CXXDeleteExprClass:
   case CoawaitExprClass:
+  case CoawaitDependentExprClass:
   case CoyieldExprClass:
     // These always have a side-effect.
     return true;
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8035,6 +8035,9 @@
   StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E);
 
   ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E);
+  ExprResult
+  BuildCoawaitDependentExpr(SourceLocation KwLoc, Expr *E,
+                            const UnresolvedSet<16> &CoawaitOperatorCandidates);
   ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
   StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E);
 
Index: include/clang/Basic/StmtNodes.td
===================================================================
--- include/clang/Basic/StmtNodes.td
+++ include/clang/Basic/StmtNodes.td
@@ -148,6 +148,7 @@
 // C++ Coroutines TS expressions
 def CoroutineSuspendExpr : DStmt<Expr, 1>;
 def CoawaitExpr : DStmt<CoroutineSuspendExpr>;
+def CoawaitDependentExpr : DStmt<Expr>;
 def CoyieldExpr : DStmt<CoroutineSuspendExpr>;
 
 // Obj-C Expressions.
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8645,6 +8645,8 @@
   "this function cannot be a coroutine: %q0 has no member named 'promise_type'">;
 def err_implied_std_coroutine_traits_promise_type_not_class : Error<
   "this function cannot be a coroutine: %0 is not a class">;
+def err_coroutine_promise_type_incomplete : Error<
+  "this function cannot be a coroutine: %0 is an incomplete type">;
 def err_coroutine_traits_missing_specialization : Error<
   "this function cannot be a coroutine: missing definition of "
   "specialization %q0">;
@@ -8655,6 +8657,8 @@
   "'std::current_exception' must be a function">;
 def err_coroutine_promise_return_ill_formed : Error<
   "%0 declares both 'return_value' and 'return_void'">;
+def note_coroutine_promise_implicit_await_transform_required_here : Note<
+  "call to 'await_transform' implicitly required by 'co_await' here">;
 }
 
 let CategoryName = "Documentation Issue" in {
Index: include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- include/clang/AST/RecursiveASTVisitor.h
+++ include/clang/AST/RecursiveASTVisitor.h
@@ -2471,6 +2471,12 @@
     ShouldVisitChildren = false;
   }
 })
+DEF_TRAVERSE_STMT(CoawaitDependentExpr, {
+  if (!getDerived().shouldVisitImplicitCode()) {
+    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
+    ShouldVisitChildren = false;
+  }
+})
 DEF_TRAVERSE_STMT(CoyieldExpr, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
Index: include/clang/AST/ExprCXX.h
===================================================================
--- include/clang/AST/ExprCXX.h
+++ include/clang/AST/ExprCXX.h
@@ -4251,6 +4251,49 @@
   }
 };
 
+/// \brief Represents a 'co_await' expression while the type of the promise
+/// is dependent.
+class CoawaitDependentExpr : public Expr {
+  SourceLocation KeywordLoc;
+  Stmt *Operand;
+  UnresolvedSet<16> CoawaitOperatorCandidates;
+
+  friend class ASTStmtReader;
+
+public:
+  CoawaitDependentExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op,
+                       UnresolvedSet<16> OperatorCandidates)
+      : Expr(CoawaitDependentExprClass, Ty, VK_RValue, OK_Ordinary, true, true,
+             true, Op->containsUnexpandedParameterPack()),
+        KeywordLoc(KeywordLoc), Operand(Op),
+        CoawaitOperatorCandidates(OperatorCandidates) {
+    assert(Op->isTypeDependent() && Ty->isDependentType() &&
+           "wrong constructor for non-dependent co_await/co_yield expression");
+  }
+
+  CoawaitDependentExpr(EmptyShell Empty)
+      : Expr(CoawaitDependentExprClass, Empty) {}
+
+  Expr *getOperand() const { return static_cast<Expr *>(Operand); }
+
+  const UnresolvedSet<16> &getOperatorCandidates() const {
+    return CoawaitOperatorCandidates;
+  }
+
+  SourceLocation getKeywordLoc() const { return KeywordLoc; }
+
+  SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; }
+  SourceLocation getLocEnd() const LLVM_READONLY {
+    return getOperand()->getLocEnd();
+  }
+
+  child_range children() { return child_range(&Operand, &Operand + 1); }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == CoawaitDependentExprClass;
+  }
+};
+
 /// \brief Represents a 'co_yield' expression.
 class CoyieldExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to