EricWF retitled this revision from "[coroutines] Add CoawaitDependentExpr AST 
node and use it to properly build await_transform." to "[coroutines] Add 
DependentCoawaitExpr and fix re-building CoroutineBodyStmt.".
EricWF updated the summary for this revision.
EricWF updated this revision to Diff 76952.
EricWF marked an inline comment as done.
EricWF added a comment.

- Address review comments about `DependentCoawaitExpr` and using 
`UnresolvedLookupExpr`.
- Fix building of the initial/final coroutine suspend points.
- Fix transformation of `CoroutineBodyStmt` so that it transforms the 
final/initial suspend points instead of rebuilding them fully.

@rsmith: This change is a little big, but it's not trivial for me to split it 
up. Please let me know if you would prefer this submitted as multiple patches.


https://reviews.llvm.org/D26057

Files:
  include/clang/AST/ExprCXX.h
  include/clang/AST/RecursiveASTVisitor.h
  include/clang/AST/StmtCXX.h
  include/clang/Basic/DiagnosticSemaKinds.td
  include/clang/Basic/StmtNodes.td
  include/clang/Sema/ScopeInfo.h
  include/clang/Sema/Sema.h
  lib/AST/Expr.cpp
  lib/AST/ExprClassification.cpp
  lib/AST/ExprConstant.cpp
  lib/AST/ItaniumMangle.cpp
  lib/AST/StmtPrinter.cpp
  lib/AST/StmtProfile.cpp
  lib/Parse/ParseStmt.cpp
  lib/Sema/ScopeInfo.cpp
  lib/Sema/SemaCoroutine.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaExceptionSpec.cpp
  lib/Sema/SemaTemplateInstantiateDecl.cpp
  lib/Sema/TreeTransform.h
  lib/Serialization/ASTReaderStmt.cpp
  lib/Serialization/ASTWriterStmt.cpp
  lib/StaticAnalyzer/Core/ExprEngine.cpp
  test/SemaCXX/coroutines.cpp
  tools/libclang/CXCursor.cpp

Index: tools/libclang/CXCursor.cpp
===================================================================
--- tools/libclang/CXCursor.cpp
+++ tools/libclang/CXCursor.cpp
@@ -231,6 +231,7 @@
   case Stmt::TypeTraitExprClass:
   case Stmt::CoroutineBodyStmtClass:
   case Stmt::CoawaitExprClass:
+  case Stmt::DependentCoawaitExprClass:
   case Stmt::CoreturnStmtClass:
   case Stmt::CoyieldExprClass:
   case Stmt::CXXBindTemporaryExprClass:
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -59,25 +59,25 @@
 template <typename... T>
 struct std::experimental::coroutine_traits<int, T...> {};
 
-int no_promise_type() {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+  co_await a;
 }
 
 template <>
 struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; };
-double bad_promise_type(double) {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+  co_await a;
 }
 
 template <>
 struct std::experimental::coroutine_traits<double, int> {
   struct promise_type {};
 };
-double bad_promise_type_2(int) {
-  co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}}
+double bad_promise_type_2(int) { // expected-error {{no member named 'initial_suspend'}}
+  co_yield 0;                    // expected-error {{no member named 'yield_value'}}
 }
 
-struct promise; // expected-note 2{{forward declaration}}
+struct promise; // expected-note {{forward declaration}}
 struct promise_void;
 struct void_tag {};
 template <typename... T>
@@ -94,9 +94,7 @@
 }
 
 // FIXME: This diagnostic is terrible.
-void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}}
-  // FIXME: This diagnostic doesn't make any sense.
-  // expected-error@-2 {{incomplete definition of type 'promise'}}
+void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}}
   co_await a;
 }
 
@@ -217,6 +215,13 @@
 }
 
 struct outer {};
+struct await_arg_1 {};
+struct await_arg_2 {};
+
+namespace adl_ns {
+struct coawait_arg_type {};
+awaitable operator co_await(coawait_arg_type);
+}
 
 namespace dependent_operator_co_await_lookup {
   template<typename T> void await_template(T t) {
@@ -239,6 +244,94 @@
   };
   template void await_template(outer); // expected-note {{instantiation}}
   template void await_template_2(outer);
+
+  struct transform_awaitable {};
+  struct transformed {};
+
+  struct transform_promise {
+    typedef transform_awaitable await_arg;
+    coro<transform_promise> get_return_object();
+    transformed initial_suspend();
+    ::adl_ns::coawait_arg_type final_suspend();
+    transformed await_transform(transform_awaitable);
+  };
+  template <class AwaitArg>
+  struct basic_promise {
+    typedef AwaitArg await_arg;
+    coro<basic_promise> get_return_object();
+    awaitable initial_suspend();
+    awaitable final_suspend();
+  };
+
+  awaitable operator co_await(await_arg_1);
+
+  template <typename T, typename U>
+  coro<T> await_template_3(U t) {
+    co_await t;
+  }
+
+  template coro<basic_promise<await_arg_1>> await_template_3<basic_promise<await_arg_1>>(await_arg_1);
+
+  template <class T, int I = 0>
+  struct dependent_member {
+    coro<T> mem_fn() const {
+      co_await typename T::await_arg{}; // expected-error {{call to function 'operator co_await'}}}
+    }
+    template <class U>
+    coro<T> dep_mem_fn(U t) {
+      co_await t;
+    }
+  };
+
+  template <>
+  struct dependent_member<long> {
+    // FIXME this diagnostic is terrible
+    coro<transform_promise> mem_fn() const { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}}
+      // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
+      // expected-note@+1 {{function is a coroutine due to use of 'co_await' here}}
+      co_await transform_awaitable{};
+      // expected-error@-1 {{no member named 'await_ready'}}
+    }
+    template <class R, class U>
+    coro<R> dep_mem_fn(U u) { co_await u; }
+  };
+
+  awaitable operator co_await(await_arg_2); // expected-note {{'operator co_await' should be declared prior to the call site}}
+
+  template struct dependent_member<basic_promise<await_arg_1>, 0>;
+  template struct dependent_member<basic_promise<await_arg_2>, 0>; // expected-note {{in instantiation}}
+
+  template <>
+  coro<transform_promise>
+      // FIXME this diagnostic is terrible
+      dependent_member<long>::dep_mem_fn<transform_promise>(int) { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}}
+    //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
+    //expected-note@+1 {{function is a coroutine due to use of 'co_await' here}}
+    co_await transform_awaitable{};
+    // expected-error@-1 {{no member named 'await_ready'}}
+  }
+
+  void operator co_await(transform_awaitable) = delete;
+  awaitable operator co_await(transformed);
+
+  template coro<transform_promise>
+      dependent_member<long>::dep_mem_fn<transform_promise>(transform_awaitable);
+
+  template <>
+  coro<transform_promise> dependent_member<long>::dep_mem_fn<transform_promise>(long) {
+    co_await transform_awaitable{};
+  }
+
+  template <>
+  struct dependent_member<int> {
+    coro<transform_promise> mem_fn() const {
+      co_await transform_awaitable{};
+    }
+  };
+
+  template coro<transform_promise> await_template_3<transform_promise>(transform_awaitable);
+  template struct dependent_member<transform_promise>;
+  template coro<transform_promise> dependent_member<transform_promise>::dep_mem_fn(transform_awaitable);
 }
 
 struct yield_fn_tag {};
@@ -314,7 +407,8 @@
 };
 // FIXME: This diagnostic is terrible.
 coro<bad_promise_4> bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}}
-  co_await a;
+  // expected-note@-1 {{'initial_suspend' implicitly required}}
+  co_await a; // expected-note {{use of 'co_await' here}}
 }
 
 struct bad_promise_5 {
@@ -324,7 +418,8 @@
 };
 // FIXME: This diagnostic is terrible.
 coro<bad_promise_5> bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}}
-  co_await a;
+  // expected-note@-1 {{'final_suspend' implicitly required}}
+  co_await a; // expected-note {{use of 'co_await' here}}
 }
 
 struct bad_promise_6 {
@@ -355,20 +450,69 @@
 int *current_exception();
 }
 
-struct bad_promise_8 {
+struct bad_promise_base {
+private:
+  void return_void(); // expected-note {{declared private here}}
+};
+struct bad_promise_8 : bad_promise_base {
   coro<bad_promise_8> get_return_object();
   suspend_always initial_suspend();
   suspend_always final_suspend();
-  void return_void();
   void set_exception();                                   // expected-note {{function not viable}}
   void set_exception(int *) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
   void set_exception(void *);                             // expected-note {{candidate function}}
 };
 coro<bad_promise_8> calls_set_exception() {
   // expected-error@-1 {{call to unavailable member function 'set_exception'}}
+  // expected-error@-2 {{'return_void' is a private member of 'bad_promise_base'}}
   co_await a;
 }
 
+struct bad_promise_9 {
+  coro<bad_promise_9> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void await_transform(void *);                                // expected-note {{candidate}}
+  awaitable await_transform(int) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
+  void return_void();
+};
+coro<bad_promise_9> calls_await_transform() {
+  co_await 42; // expected-error {{call to unavailable member function 'await_transform'}}
+  // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct bad_promise_10 {
+  coro<bad_promise_10> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  int await_transform;
+  void return_void();
+};
+coro<bad_promise_10> bad_coawait() {
+  // FIXME this diagnostic is terrible
+  co_await 42; // expected-error {{called object type 'int' is not a function or function pointer}}
+  // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct call_operator {
+  template <class... Args>
+  awaitable operator()(Args...) const { return a; }
+};
+void ret_void();
+struct good_promise_1 {
+  coro<good_promise_1> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  static const call_operator await_transform;
+  using Fn = void (*)();
+  Fn return_void = ret_void;
+};
+const call_operator good_promise_1::await_transform;
+coro<good_promise_1> ok_static_coawait() {
+  // No diagnostic expected: 'await_transform' and 'return_void' are callable data members.
+  co_await 42;
+}
+
 template<> struct std::experimental::coroutine_traits<int, int, const char**>
 { using promise_type = promise; };
 
Index: lib/StaticAnalyzer/Core/ExprEngine.cpp
===================================================================
--- lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -774,6 +774,7 @@
     case Stmt::FunctionParmPackExprClass:
     case Stmt::CoroutineBodyStmtClass:
     case Stmt::CoawaitExprClass:
+    case Stmt::DependentCoawaitExprClass:
     case Stmt::CoreturnStmtClass:
     case Stmt::CoyieldExprClass:
     case Stmt::SEHTryStmtClass:
Index: lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- lib/Serialization/ASTWriterStmt.cpp
+++ lib/Serialization/ASTWriterStmt.cpp
@@ -315,6 +315,11 @@
   llvm_unreachable("unimplemented");
 }
 
+void ASTStmtWriter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+  // FIXME: Implement coroutine serialization.
+  llvm_unreachable("unimplemented");
+}
+
 void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) {
   // FIXME: Implement coroutine serialization.
   llvm_unreachable("unimplemented");
Index: lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- lib/Serialization/ASTReaderStmt.cpp
+++ lib/Serialization/ASTReaderStmt.cpp
@@ -400,6 +400,11 @@
   llvm_unreachable("unimplemented");
 }
 
+void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+  // FIXME: Implement coroutine serialization.
+  llvm_unreachable("unimplemented");
+}
+
 void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) {
   // FIXME: Implement coroutine serialization.
   llvm_unreachable("unimplemented");
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1306,16 +1306,29 @@
   ///
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
-  StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) {
-    return getSema().BuildCoreturnStmt(CoreturnLoc, Result);
+  StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result,
+                                 bool IsImplicitlyCreated) {
+    return getSema().BuildCoreturnStmt(CoreturnLoc, Result,
+                                       IsImplicitlyCreated);
   }
 
   /// \brief Build a new co_await expression.
   ///
   /// By default, performs semantic analysis to build the new expression.
   /// Subclasses may override this routine to provide different behavior.
-  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) {
-    return getSema().BuildCoawaitExpr(CoawaitLoc, Result);
+  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result,
+                                bool IsImplicitlyCreated) {
+    return getSema().BuildCoawaitExpr(CoawaitLoc, Result, IsImplicitlyCreated);
+  }
+
+  /// \brief Build a new dependent co_await expression.
+  ///
+  /// By default, performs semantic analysis to build the new expression.
+  /// Subclasses may override this routine to provide different behavior.
+  ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc,
+                                         Expr *Result,
+                                         UnresolvedLookupExpr *Lookup) {
+    return getSema().BuildDependentCoawaitExpr(CoawaitLoc, Result, Lookup);
   }
 
   /// \brief Build a new co_yield expression.
@@ -1326,6 +1339,15 @@
     return getSema().BuildCoyieldExpr(CoyieldLoc, Result);
   }
 
+  StmtResult RebuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend,
+                                      Stmt *FinalSuspend, Stmt *OnException,
+                                      Stmt *OnFallthrough,
+                                       Expr *Allocation,
+                                      Stmt *Deallocation, Expr *ReturnObject) {
+    return getSema().BuildCoroutineBodyStmt(
+        Body, Promise, InitSuspend, FinalSuspend, OnException,  OnFallthrough,
+        Allocation, Deallocation, ReturnObject);
+  }
   /// \brief Build a new Objective-C \@try statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -6655,7 +6677,87 @@
 TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) {
   // The coroutine body should be re-formed by the caller if necessary.
   // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody
-  return getDerived().TransformStmt(S->getBody());
+
+  auto *ScopeInfo = SemaRef.getCurFunction();
+  auto *FD = cast<FunctionDecl>(SemaRef.CurContext);
+  assert(ScopeInfo && !ScopeInfo->CoroutinePromise &&
+         !ScopeInfo->HasCoroutineSuspends &&
+         ScopeInfo->CoroutineStmts.empty() && "expected clean scope info");
+
+  // Set that we have (possibly-invalid) suspend points before we do anything
+  // that may fail.
+  ScopeInfo->setCoroutineSuspendsInvalid();
+
+  // The new CoroutinePromise object needs to be built and put into the current
+  // FunctionScopeInfo before any transformations or rebuilding occurs.
+  auto *Promise = S->getPromiseDecl();
+  auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation());
+  if (!NewPromise)
+    return StmtError();
+  getDerived().transformedLocalDecl(Promise, NewPromise);
+  ScopeInfo->CoroutinePromise = NewPromise;
+
+  // Transform the implicit coroutine statements we built during the initial
+  // parse.
+  StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt());
+  if (InitSuspend.isInvalid())
+    return StmtError();
+  StmtResult FinalSuspend =
+      getDerived().TransformStmt(S->getFinalSuspendStmt());
+  if (FinalSuspend.isInvalid())
+    return StmtError();
+  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+  StmtResult BodyRes = getDerived().TransformStmt(S->getBody());
+  if (BodyRes.isInvalid())
+    return StmtError();
+
+  Stmt *SetException = S->getExceptionHandler();
+  Stmt *Fallthrough = S->getFallthroughHandler();
+  if (Fallthrough) {
+    StmtResult Res = getDerived().TransformStmt(Fallthrough);
+    if (Res.isInvalid())
+      return StmtError();
+    Fallthrough = Res.get();
+  }
+
+  if (SetException) {
+    StmtResult Res = getDerived().TransformStmt(SetException);
+    if (Res.isInvalid())
+      return StmtError();
+    SetException = Res.get();
+  }
+
+  // Transform any additional statements we may have already built.
+  Expr *Allocation = nullptr;
+  Stmt *Deallocation = nullptr;
+  if (S->getAllocate() && S->getDeallocate()) {
+    ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate());
+    if (AllocRes.isInvalid())
+      return StmtError();
+    Allocation = AllocRes.get();
+
+    StmtResult DeallocRes = getDerived().TransformStmt(S->getDeallocate());
+    if (DeallocRes.isInvalid())
+      return StmtError();
+    Deallocation = DeallocRes.get();
+  }
+
+  Expr *ReturnObject = S->getReturnValueInit();
+  if (ReturnObject) {
+    ExprResult Res = getDerived().TransformInitializer(ReturnObject,
+            /*NoCopyInit*/false);
+    if (Res.isInvalid())
+      return StmtError();
+    ReturnObject = Res.get();
+  }
+
+  // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo
+  return getDerived().RebuildCoroutineBodyStmt(
+      BodyRes.get(), NewPromise, InitSuspend.get(), FinalSuspend.get(),
+      SetException, Fallthrough, Allocation,
+      Deallocation, ReturnObject);
+
 }
 
 template<typename Derived>
@@ -6668,7 +6770,8 @@
 
   // Always rebuild; we don't know if this needs to be injected into a new
   // context or if the promise type has changed.
-  return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get());
+  return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(),
+                                          S->isImplicitlyCreated());
 }
 
 template<typename Derived>
@@ -6681,12 +6784,26 @@
 
   // Always rebuild; we don't know if this needs to be injected into a new
   // context or if the promise type has changed.
-  return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get());
+  return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(),
+                                         E->isImplicitlyCreated());
 }
 
-template<typename Derived>
+template <typename Derived>
 ExprResult
-TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
+TreeTransform<Derived>::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) {
+  ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
+                                                        /*NotCopyInit*/ false);
+  if (Result.isInvalid())
+    return ExprError();
+
+  // Always rebuild; we don't know if this needs to be injected into a new
+  // context or if the promise type has changed.
+  return getDerived().RebuildDependentCoawaitExpr(
+      E->getKeywordLoc(), Result.get(), E->getOperatorCoawaitLookup());
+}
+
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
   ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
                                                         /*NotCopyInit*/false);
   if (Result.isInvalid())
Index: lib/Sema/SemaTemplateInstantiateDecl.cpp
===================================================================
--- lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -3714,6 +3714,8 @@
 
     if (Body.isInvalid())
       Function->setInvalidDecl();
+    else
+      assert(Body.get());
 
     ActOnFinishFunctionBody(Function, Body.get(),
                             /*IsInstantiation=*/true);
Index: lib/Sema/SemaExceptionSpec.cpp
===================================================================
--- lib/Sema/SemaExceptionSpec.cpp
+++ lib/Sema/SemaExceptionSpec.cpp
@@ -1146,6 +1146,7 @@
   case Expr::ArraySubscriptExprClass:
   case Expr::OMPArraySectionExprClass:
   case Expr::BinaryOperatorClass:
+  case Expr::DependentCoawaitExprClass:
   case Expr::CompoundAssignOperatorClass:
   case Expr::CStyleCastExprClass:
   case Expr::CXXStaticCastExprClass:
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -11383,7 +11383,7 @@
   if (canRedefineFunction(Definition, getLangOpts()))
     return;
 
-  // If we don't have a visible definition of the function, and it's inline or
+  // If we don't have a visible definition of the function, and it's inline or
   // a template, skip the new definition.
   if (SkipBody && !hasVisibleDefinition(Definition) &&
       (Definition->getFormalLinkage() == InternalLinkage ||
@@ -11675,7 +11675,7 @@
   sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy();
   sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr;
 
-  if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty())
+  if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise)
     CheckCompletedCoroutineBody(FD, Body);
 
   if (FD) {
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -21,21 +21,32 @@
 using namespace clang;
 using namespace sema;
 
+static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
+                         SourceLocation Loc) {
+  DeclarationName DN = S.PP.getIdentifierInfo(Name);
+  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+  // Suppress diagnostics when a private member is selected. The same warnings
+  // will be produced again when building the call.
+  LR.suppressDiagnostics();
+  return S.LookupQualifiedName(LR, RD);
+}
+
 /// Look up the std::coroutine_traits<...>::promise_type for the given
 /// function type.
 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
-                                  SourceLocation Loc) {
+                                  SourceLocation KwLoc,
+                                  SourceLocation FuncLoc) {
   // FIXME: Cache std::coroutine_traits once we've found it.
   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
   if (!StdExp) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
-                      Loc, Sema::LookupOrdinaryName);
+                      FuncLoc, Sema::LookupOrdinaryName);
   if (!S.LookupQualifiedName(Result, StdExp)) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
@@ -49,52 +60,58 @@
   }
 
   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
-  TemplateArgumentListInfo Args(Loc, Loc);
+  TemplateArgumentListInfo Args(KwLoc, KwLoc);
   Args.addArgument(TemplateArgumentLoc(
       TemplateArgument(FnType->getReturnType()),
-      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
+      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc)));
   // FIXME: If the function is a non-static member function, add the type
   // of the implicit object parameter before the formal parameters.
   for (QualType T : FnType->getParamTypes())
     Args.addArgument(TemplateArgumentLoc(
-        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
+        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
 
   // Build the template-id.
   QualType CoroTrait =
-      S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
+      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
   if (CoroTrait.isNull())
     return QualType();
-  if (S.RequireCompleteType(Loc, CoroTrait,
+  if (S.RequireCompleteType(KwLoc, CoroTrait,
                             diag::err_coroutine_traits_missing_specialization))
     return QualType();
 
-  CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
+  auto *RD = CoroTrait->getAsCXXRecordDecl();
   assert(RD && "specialization of class template is not a class?");
 
   // Look up the ::promise_type member.
-  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
+  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
                  Sema::LookupOrdinaryName);
   S.LookupQualifiedName(R, RD);
   auto *Promise = R.getAsSingle<TypeDecl>();
   if (!Promise) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_found)
         << RD;
     return QualType();
   }
-
   // The promise type is required to be a class type.
   QualType PromiseType = S.Context.getTypeDeclType(Promise);
-  if (!PromiseType->getAsCXXRecordDecl()) {
-    // Use the fully-qualified name of the type.
+
+  auto buildNNS = [&]() {
     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
                                       CoroTrait.getTypePtr());
-    PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+  };
 
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
-        << PromiseType;
+  if (!PromiseType->getAsCXXRecordDecl()) {
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_class)
+        << buildNNS();
     return QualType();
   }
+  if (S.RequireCompleteType(FuncLoc, buildNNS(),
+                            diag::err_coroutine_promise_type_incomplete))
+    return QualType();
 
   return PromiseType;
 }
@@ -160,41 +177,49 @@
   return !Diagnosed;
 }
 
-/// Check that this is a context in which a coroutine suspension can appear.
-static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
-                                                StringRef Keyword) {
-  if (!isValidCoroutineContext(S, Loc, Keyword))
-    return nullptr;
+static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
+                                                 SourceLocation Loc) {
+  DeclarationName OpName =
+      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+  LookupResult Operators(SemaRef, OpName, SourceLocation(),
+                         Sema::LookupOperatorName);
+  SemaRef.LookupName(Operators, S);
+
+  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+  const auto &Functions = Operators.asUnresolvedSet();
+  bool IsOverloaded =
+      Functions.size() > 1 ||
+      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+      Functions.begin(), Functions.end());
+  assert(CoawaitOp);
+  return CoawaitOp;
+}
 
-  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
-  auto *FD = cast<FunctionDecl>(S.CurContext);
-  auto *ScopeInfo = S.getCurFunction();
-  assert(ScopeInfo && "missing function scope for function");
+/// Build a call to 'operator co_await' if there is a suitable operator for
+/// the given expression.
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
+                                           Expr *E,
+                                           UnresolvedLookupExpr *Lookup) {
 
-  // If we don't have a promise variable, build one now.
-  if (!ScopeInfo->CoroutinePromise) {
-    QualType T = FD->getType()->isDependentType()
-                     ? S.Context.DependentTy
-                     : lookupPromiseType(
-                           S, FD->getType()->castAs<FunctionProtoType>(), Loc);
-    if (T.isNull())
-      return nullptr;
-
-    // Create and default-initialize the promise.
-    ScopeInfo->CoroutinePromise =
-        VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
-                        &S.PP.getIdentifierTable().get("__promise"), T,
-                        S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
-    S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
-    if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
-      S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
-  }
+  UnresolvedSet<16> Functions;
+  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
+  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+}
 
-  return ScopeInfo;
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
+                                           SourceLocation Loc, Expr *E) {
+  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+  if (R.isInvalid())
+    return ExprError();
+  return buildOperatorCoawaitCall(SemaRef, Loc, E,
+                                  cast<UnresolvedLookupExpr>(R.get()));
 }
 
 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
-                              MutableArrayRef<Expr *> CallArgs) {
+                              MultiExprArg CallArgs) {
   StringRef Name = S.Context.BuiltinInfo.getName(Id);
   LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
   S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
@@ -213,24 +238,14 @@
   return Call.get();
 }
 
-/// Build a call to 'operator co_await' if there is a suitable operator for
-/// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
-                                           SourceLocation Loc, Expr *E) {
-  UnresolvedSet<16> Functions;
-  SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
-                                       Functions);
-  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
-}
 
 struct ReadySuspendResumeResult {
   bool IsInvalid;
   Expr *Results[3];
 };
 
 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
-                                  StringRef Name,
-                                  MutableArrayRef<Expr *> Args) {
+                                  StringRef Name, MultiExprArg Args) {
   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
 
   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
@@ -268,25 +283,193 @@
   return Calls;
 }
 
+/// Build a call to the member function \p Name of the coroutine promise
+/// \p Promise, passing \p Args.
+static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
+                                   SourceLocation Loc, StringRef Name,
+                                   MultiExprArg Args) {
+  // Form a reference to the promise.
+  ExprResult PromiseRef = S.BuildDeclRefExpr(
+      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
+  if (PromiseRef.isInvalid())
+    return ExprError();
+
+  // Call 'Name' on the promise, passing in Args.
+  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
+}
+
+/// Create the implicit '__promise' variable for the coroutine being defined in
+/// the current function. Its type is dependent if the function type is
+/// dependent, and is otherwise computed by lookupPromiseType. Returns null if
+/// no valid promise variable could be formed.
+VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
+  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+  auto *FD = cast<FunctionDecl>(CurContext);
+
+  QualType T =
+      FD->getType()->isDependentType()
+          ? Context.DependentTy
+          : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+                              Loc, FD->getLocation());
+  if (T.isNull())
+    return nullptr;
+
+  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+                             &PP.getIdentifierTable().get("__promise"), T,
+                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
+  CheckVariableDeclarationType(VD);
+  if (VD->isInvalidDecl())
+    return nullptr;
+  ActOnUninitializedDecl(VD, false);
+  // Default-initializing the promise can itself fail and mark the variable
+  // invalid (e.g. a promise type that isn't default-constructible).
+  if (VD->isInvalidDecl())
+    return nullptr;
+  return VD;
+}
+
+/// Check that this is a context in which a coroutine suspension can appear,
+/// building the coroutine promise if it doesn't exist yet. Returns the current
+/// function's scope info, or null on failure.
+static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
+                                                StringRef Keyword) {
+  if (!isValidCoroutineContext(S, Loc, Keyword))
+    return nullptr;
+
+  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
+
+  auto *ScopeInfo = S.getCurFunction();
+  assert(ScopeInfo && "missing function scope for function");
+
+  if (ScopeInfo->CoroutinePromise)
+    return ScopeInfo;
+
+  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
+  if (!ScopeInfo->CoroutinePromise)
+    return nullptr;
+
+  return ScopeInfo;
+}
+
+/// Called for each coroutine keyword (co_await, co_yield, co_return) in the
+/// current function. The first call also builds the implicit initial and final
+/// suspend points. Returns false if this is not a valid coroutine context (or
+/// no promise could be built); a failure to build the suspend points is
+/// instead recorded in the function's scope info.
+static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc,
+                                    StringRef Keyword) {
+  auto *ScopeInfo = checkCoroutineContext(S, KWLoc, Keyword);
+  if (!ScopeInfo)
+    return false;
+  assert(ScopeInfo->CoroutinePromise);
+
+  // If we have already attempted to build the initial and final suspend
+  // points (successfully or not), don't build them again.
+  if (ScopeInfo->HasCoroutineSuspends)
+    return true;
+
+  // Treat the suspend points as invalid until both have been built.
+  ScopeInfo->setCoroutineSuspendsInvalid();
+
+  auto *Fn = cast<FunctionDecl>(S.CurContext);
+  SourceLocation Loc = Fn->getLocation();
+  // Build the implicit 'co_await p.Name()' for an initial or final suspend.
+  auto buildSuspends = [&](StringRef Name) -> StmtResult {
+    ExprResult Suspend =
+        buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None);
+    if (Suspend.isInvalid())
+      return StmtError();
+    Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get());
+    if (Suspend.isInvalid())
+      return StmtError();
+    Suspend = S.BuildCoawaitExpr(Loc, Suspend.get(),
+                                 /*IsImplicitlyCreated*/ true);
+    Suspend = S.ActOnFinishFullExpr(Suspend.get());
+    if (Suspend.isInvalid()) {
+      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+          << ((Name == "initial_suspend") ? 0 : 1);
+      S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
+      return StmtError();
+    }
+    return cast<Stmt>(Suspend.get());
+  };
+
+  StmtResult InitSuspend = buildSuspends("initial_suspend");
+  if (InitSuspend.isInvalid())
+    return true;
+
+  StmtResult FinalSuspend = buildSuspends("final_suspend");
+  if (FinalSuspend.isInvalid())
+    return true;
+
+  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+  return true;
+}
+
 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
-  if (!Coroutine) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
+
   if (E->getType()->isPlaceholderType()) {
     ExprResult R = CheckPlaceholderExpr(E);
     if (R.isInvalid()) return ExprError();
     E = R.get();
   }
+  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+  if (Lookup.isInvalid())
+    return ExprError();
+  return BuildDependentCoawaitExpr(Loc, E,
+                                   cast<UnresolvedLookupExpr>(Lookup.get()));
+}
+
+ExprResult Sema::BuildDependentCoawaitExpr(SourceLocation Loc, Expr *E,
+                                           UnresolvedLookupExpr *Lookup) {
+  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
+  if (!FSI)
+    return ExprError();
+
+  if (E->getType()->isPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(E);
+    if (R.isInvalid())
+      return ExprError();
+    E = R.get();
+  }
 
-  ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
+  auto *Promise = FSI->CoroutinePromise;
+  // With a dependent promise type, await_transform and operator co_await
+  // can't be resolved yet; build a DependentCoawaitExpr holding the lookup.
+  if (Promise->getType()->isDependentType()) {
+    Expr *Res =
+        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+    FSI->CoroutineStmts.push_back(Res);
+    return Res;
+  }
+
+  // If the promise type has a member named 'await_transform', the operand is
+  // first transformed by calling 'p.await_transform(E)'.
+  auto *RD = Promise->getType()->getAsCXXRecordDecl();
+  if (lookupMember(*this, "await_transform", RD, Loc)) {
+    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+    if (R.isInvalid()) {
+      Diag(Loc,
+           diag::note_coroutine_promise_implicit_await_transform_required_here)
+          << E->getSourceRange();
+      return ExprError();
+    }
+    E = R.get();
+  }
+  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
   if (Awaitable.isInvalid())
     return ExprError();
 
   return BuildCoawaitExpr(Loc, Awaitable.get());
 }
-ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
+
+ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E,
+                                  bool IsImplicitlyCreated) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
   if (!Coroutine)
     return ExprError();
@@ -298,8 +462,11 @@
   }
 
   if (E->getType()->isDependentType()) {
-    Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
-    Coroutine->CoroutineStmts.push_back(Res);
+    Expr *Res = new (Context)
+        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicitlyCreated);
+    // Don't record the implicit initial/final suspend points.
+    if (!IsImplicitlyCreated)
+      Coroutine->CoroutineStmts.push_back(Res);
     return Res;
   }
 
@@ -314,37 +480,22 @@
     return ExprError();
 
   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
-                                        RSS.Results[2]);
-  Coroutine->CoroutineStmts.push_back(Res);
+                                        RSS.Results[2], IsImplicitlyCreated);
+  // Don't record the implicit initial/final suspend points.
+  if (!IsImplicitlyCreated)
+    Coroutine->CoroutineStmts.push_back(Res);
   return Res;
 }
 
-static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
-                                   SourceLocation Loc, StringRef Name,
-                                   MutableArrayRef<Expr *> Args) {
-  assert(Coroutine->CoroutinePromise && "no promise for coroutine");
-
-  // Form a reference to the promise.
-  auto *Promise = Coroutine->CoroutinePromise;
-  ExprResult PromiseRef = S.BuildDeclRefExpr(
-      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
-  if (PromiseRef.isInvalid())
-    return ExprError();
-
-  // Call 'yield_value', passing in E.
-  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
-}
-
 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
-  if (!Coroutine) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
 
   // Build yield_value call.
-  ExprResult Awaitable =
-      buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
+  ExprResult Awaitable = buildPromiseCall(
+      *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
   if (Awaitable.isInvalid())
     return ExprError();
 
@@ -388,18 +538,18 @@
   return Res;
 }
 
-StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
-  if (!Coroutine) {
+StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) {
     CorrectDelayedTyposInExpr(E);
     return StmtError();
   }
   return BuildCoreturnStmt(Loc, E);
 }
 
-StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
-  if (!Coroutine)
+StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
+                                   bool IsImplicitlyCreated) {
+  auto *FSI = checkCoroutineContext(*this, Loc, "co_return");
+  if (!FSI)
     return StmtError();
 
   if (E && E->getType()->isPlaceholderType() &&
@@ -412,20 +562,23 @@
   // FIXME: If the operand is a reference to a variable that's about to go out
   // of scope, we should treat the operand as an xvalue for this overload
   // resolution.
+  VarDecl *Promise = FSI->CoroutinePromise;
   ExprResult PC;
   if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
-    PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
+    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
   } else {
     E = MakeFullDiscardedValueExpr(E).get();
-    PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
+    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
   }
   if (PC.isInvalid())
     return StmtError();
 
   Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
 
-  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
-  Coroutine->CoroutineStmts.push_back(Res);
+  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicitlyCreated);
+  // Don't record the implicit fallthrough 'co_return' built by Sema.
+  if (!IsImplicitlyCreated)
+    FSI->CoroutineStmts.push_back(Res);
   return Res;
 }
 
@@ -482,14 +634,87 @@
   return OperatorDelete;
 }
 
+/// Build the implicit 'co_return;' executed when control flows off the end of
+/// the coroutine body. \p OnFallthrough is only set when the (non-dependent)
+/// promise type declares 'return_void' but not 'return_value'. Returns false
+/// if an error was diagnosed.
+static bool buildFallthrough(Sema &S, SourceLocation Loc, FunctionDecl *FD,
+                             FunctionScopeInfo *FTI, Stmt *&OnFallthrough) {
+  assert(!OnFallthrough && "rebuilding existing OnFallthrough");
+  auto *Promise = FTI->CoroutinePromise;
+  if (Promise->getType()->isDependentType())
+    return true;
+
+  CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl();
+  assert(RD && "Type should have already been checked");
+
+  // [dcl.fct.def.coroutine]/4
+  // The unqualified-ids 'return_void' and 'return_value' are looked up in
+  // the scope of class P. If both are found, the program is ill-formed.
+  const bool HasRVoid = lookupMember(S, "return_void", RD, Loc);
+  const bool HasRValue = lookupMember(S, "return_value", RD, Loc);
+  if (HasRVoid && HasRValue) {
+    // FIXME Improve this diagnostic
+    S.Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
+        << RD;
+    return false;
+  }
+
+  if (HasRVoid) {
+    // If the unqualified-id return_void is found, flowing off the end of a
+    // coroutine is equivalent to a co_return with no operand. Otherwise,
+    // flowing off the end of a coroutine results in undefined behavior.
+    StmtResult Fallthrough = S.BuildCoreturnStmt(FD->getLocation(), nullptr,
+                                                 /*IsImplicitlyCreated*/ true);
+    if (!Fallthrough.isInvalid())
+      Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
+    if (Fallthrough.isInvalid())
+      return false;
+    OnFallthrough = Fallthrough.get();
+  }
+  return true;
+}
+
+/// Build 'p.set_exception(std::current_exception())', used to handle uncaught
+/// exceptions, if the (non-dependent) promise type declares 'set_exception'.
+/// Returns false if an error was diagnosed.
+static bool buildSetException(Sema &S, SourceLocation Loc, FunctionDecl *FD,
+                              FunctionScopeInfo *FTI, Stmt *&OnException) {
+  assert(!OnException && "rebuilding existing set_exception");
+  auto *Promise = FTI->CoroutinePromise;
+  if (Promise->getType()->isDependentType())
+    return true;
+
+  CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl();
+  assert(RD && "Type should have already been checked");
+
+  // [dcl.fct.def.coroutine]/3
+  // The unqualified-id set_exception is found in the scope of P by class
+  // member access lookup (3.4.5).
+  if (lookupMember(S, "set_exception", RD, Loc)) {
+    // Form the call 'p.set_exception(std::current_exception())'
+    ExprResult SetException = buildStdCurrentExceptionCall(S, Loc);
+    if (SetException.isInvalid())
+      return false;
+    Expr *E = SetException.get();
+    SetException = buildPromiseCall(S, Promise, Loc, "set_exception", E);
+    SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
+    if (SetException.isInvalid())
+      return false;
+    OnException = SetException.get();
+  }
+  return true;
+}
+
 // Builds allocation and deallocation for the coroutine. Returns false on
 // failure.
 static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
                                            FunctionScopeInfo *Fn,
                                            Expr *&Allocation,
                                            Stmt *&Deallocation) {
-  TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
-  QualType PromiseType = TInfo->getType();
+  assert(!Allocation && !Deallocation &&
+         "alloc/dealloc statements have already been built");
+  QualType PromiseType = Fn->CoroutinePromise->getType();
   if (PromiseType->isDependentType())
     return true;
 
@@ -532,8 +752,6 @@
   if (NewExpr.isInvalid())
     return false;
 
-  Allocation = NewExpr.get();
-
   // Make delete call.
 
   QualType OpDeleteQualType = OperatorDelete->getType();
@@ -559,149 +777,146 @@
   if (DeleteExpr.isInvalid())
     return false;
 
+  // Only publish the results once both calls have been built successfully.
+  Allocation = NewExpr.get();
   Deallocation = DeleteExpr.get();
 
   return true;
 }
 
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
-  FunctionScopeInfo *Fn = getCurFunction();
-  assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
-
-  // Coroutines [stmt.return]p1:
-  //   A return statement shall not appear in a coroutine.
-  if (Fn->FirstReturnLoc.isValid()) {
-    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
-    auto *First = Fn->CoroutineStmts[0];
-    Diag(First->getLocStart(), diag::note_declared_coroutine_here)
-        << (isa<CoawaitExpr>(First) ? 0 :
-            isa<CoyieldExpr>(First) ? 1 : 2);
-  }
-
-  bool AnyCoawaits = false;
-  bool AnyCoyields = false;
-  for (auto *CoroutineStmt : Fn->CoroutineStmts) {
-    AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
-    AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
-  }
-
-  if (!AnyCoawaits && !AnyCoyields)
-    Diag(Fn->CoroutineStmts.front()->getLocStart(),
-         diag::ext_coroutine_without_co_await_co_yield);
-
-  SourceLocation Loc = FD->getLocation();
-
+/// Build the CoroutineBodyStmt wrapping \p Body. The promise and the
+/// initial/final suspend points must already be built; any of the remaining
+/// implicit statements that are null are built here.
+StmtResult Sema::BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise,
+                                        Stmt *InitSuspend, Stmt *FinalSuspend,
+                                        Stmt *SetException, Stmt *OnFallthrough,
+                                        Expr *Allocation, Stmt *Deallocation,
+                                        Expr *ReturnObjectInit) {
+  assert(Promise && InitSuspend && FinalSuspend &&
+         "these nodes must already be built");
+  auto *FSI = getCurFunction();
+  assert(FSI->CoroutinePromise);
+  auto *FD = cast<FunctionDecl>(CurContext);
+  // FIXME Get real location
+  auto Loc = FD->getLocation();
+
+  // Resolve any placeholder types in the initial/final suspend expressions.
+  auto checkPlaceholders = [&](Stmt *&S) {
+    Expr *E = cast_or_null<Expr>(S);
+    if (E && E->getType()->isPlaceholderType() &&
+        !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
+      ExprResult R = CheckPlaceholderExpr(E);
+      if (R.isInvalid())
+        return false;
+      S = cast<Stmt>(R.get());
+    }
+    return true;
+  };
+  if (!checkPlaceholders(InitSuspend) || !checkPlaceholders(FinalSuspend))
+    return StmtError();
+
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
   StmtResult PromiseStmt =
-      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
+      ActOnDeclStmt(ConvertDeclToDeclGroup(Promise), Promise->getLocStart(),
+                    Promise->getLocEnd());
   if (PromiseStmt.isInvalid())
-    return FD->setInvalidDecl();
-
-  // Form and check implicit 'co_await p.initial_suspend();' statement.
-  ExprResult InitialSuspend =
-      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
-  // FIXME: Support operator co_await here.
-  if (!InitialSuspend.isInvalid())
-    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
-  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
-  if (InitialSuspend.isInvalid())
-    return FD->setInvalidDecl();
-
-  // Form and check implicit 'co_await p.final_suspend();' statement.
-  ExprResult FinalSuspend =
-      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
-  // FIXME: Support operator co_await here.
-  if (!FinalSuspend.isInvalid())
-    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
-  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
-  if (FinalSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
 
-  // Form and check allocation and deallocation calls.
-  Expr *Allocation = nullptr;
-  Stmt *Deallocation = nullptr;
-  if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
-    return FD->setInvalidDecl();
+  if (!OnFallthrough && !buildFallthrough(*this, Loc, FD, FSI, OnFallthrough))
+    return StmtError();
 
-  // control flowing off the end of the coroutine.
-  // Also try to form 'p.set_exception(std::current_exception());' to handle
-  // uncaught exceptions.
-  ExprResult SetException;
-  StmtResult Fallthrough;
-  if (Fn->CoroutinePromise &&
-      !Fn->CoroutinePromise->getType()->isDependentType()) {
-    CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
-    assert(RD && "Type should have already been checked");
-    // [dcl.fct.def.coroutine]/4
-    // The unqualified-ids 'return_void' and 'return_value' are looked up in
-    // the scope of class P. If both are found, the program is ill-formed.
-    DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
-    LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
-    const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
-    DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
-    LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
-    const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
-    if (HasRVoid && HasRValue) {
-      // FIXME Improve this diagnostic
-      Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
-          << RD;
-      return FD->setInvalidDecl();
-    } else if (HasRVoid) {
-      // If the unqualified-id return_void is found, flowing off the end of a
-      // coroutine is equivalent to a co_return with no operand. Otherwise,
-      // flowing off the end of a coroutine results in undefined behavior.
-      Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr);
-      Fallthrough = ActOnFinishFullStmt(Fallthrough.get());
-      if (Fallthrough.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  if (!SetException && !buildSetException(*this, Loc, FD, FSI, SetException))
+    return StmtError();
 
-    // [dcl.fct.def.coroutine]/3
-    // The unqualified-id set_exception is found in the scope of P by class
-    // member access lookup (3.4.5).
-    DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
-    LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
-    if (LookupQualifiedName(SetExResult, RD)) {
-      // Form the call 'p.set_exception(std::current_exception())'
-      SetException = buildStdCurrentExceptionCall(*this, Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-      Expr *E = SetException.get();
-      SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E);
-      SetException = ActOnFinishFullExpr(SetException.get(), Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  if (!Allocation || !Deallocation) {
+    assert(!Allocation && !Deallocation && "These should be a package deal");
+    if (!buildAllocationAndDeallocation(*this, Loc, FSI, Allocation,
+                                        Deallocation))
+      return StmtError();
   }
 
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
-  ExprResult ReturnObject =
-      buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
-  if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
-  QualType RetType = FD->getReturnType();
-  if (!RetType->isDependentType()) {
-    InitializedEntity Entity =
-        InitializedEntity::InitializeResult(Loc, RetType, false);
-    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
-                                                   ReturnObject.get());
+  if (!ReturnObjectInit) {
+    ExprResult ReturnObject =
+        buildPromiseCall(*this, Promise, Loc, "get_return_object", None);
     if (ReturnObject.isInvalid())
-      return FD->setInvalidDecl();
+      return StmtError();
+    QualType RetType = FD->getReturnType();
+    if (!RetType->isDependentType()) {
+      InitializedEntity Entity =
+          InitializedEntity::InitializeResult(Loc, RetType, false);
+      ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
+                                                     ReturnObject.get());
+      if (ReturnObject.isInvalid())
+        return StmtError();
+    }
+    ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
+    if (ReturnObject.isInvalid())
+      return StmtError();
+    ReturnObjectInit = ReturnObject.get();
   }
-  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
-  if (ReturnObject.isInvalid())
+
+  // FIXME: Perform move-initialization of parameters into frame-local copies.
+  return new (Context) CoroutineBodyStmt(
+      Body, PromiseStmt.get(), InitSuspend, FinalSuspend, SetException,
+      OnFallthrough, Allocation, Deallocation, ReturnObjectInit,
+      /*ParamMoves=*/None);
+}
+
+void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+  FunctionScopeInfo *FSI = getCurFunction();
+  assert(FSI && FSI->CoroutinePromise && FSI->HasCoroutineSuspends &&
+         "not a coroutine");
+  VarDecl *Promise = FSI->CoroutinePromise;
+
+  // Check if we failed to build the initial/final suspend points during the
+  // initial parse.
+  if (FSI->hasInvalidCoroutineSuspends())
     return FD->setInvalidDecl();
 
-  // FIXME: Perform move-initialization of parameters into frame-local copies.
-  SmallVector<Expr*, 16> ParamMoves;
+  if (Body && !isa<CoroutineBodyStmt>(Body)) {
+    StmtResult BodyRes = BuildCoroutineBodyStmt(
+        Body, Promise, FSI->CoroutineSuspends.first,
+        FSI->CoroutineSuspends.second, nullptr, nullptr, nullptr, nullptr,
+        nullptr);
+    if (BodyRes.isInvalid())
+      return FD->setInvalidDecl();
+    Body = BodyRes.get();
+  }
+
+  // Coroutines [stmt.return]p1:
+  //   A return statement shall not appear in a coroutine.
+  if (FSI->FirstReturnLoc.isValid()) {
+    Diag(FSI->FirstReturnLoc, diag::err_return_in_coroutine);
+    // Coroutine statements are only recorded once successfully built, so
+    // there may be none to point at.
+    if (!FSI->CoroutineStmts.empty()) {
+      auto *First = FSI->CoroutineStmts[0];
+      Diag(First->getLocStart(), diag::note_declared_coroutine_here)
+          << ((isa<CoawaitExpr>(First) || isa<DependentCoawaitExpr>(First))
+                  ? "co_await"
+                  : isa<CoyieldExpr>(First) ? "co_yield" : "co_return");
+    }
+  }
+
+  bool AnyCoawaits = false;
+  bool AnyDependentCoawaits = false;
+  bool AnyCoyields = false;
+  for (auto *CoroutineStmt : FSI->CoroutineStmts) {
+    // Don't count the implicitly generated initial/final suspend points
+    if (auto *CA = dyn_cast<CoawaitExpr>(CoroutineStmt))
+      AnyCoawaits |= !CA->isImplicitlyCreated();
+    AnyDependentCoawaits |= isa<DependentCoawaitExpr>(CoroutineStmt);
+    AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
+  }
+
+  if (!FD->isInvalidDecl() && !FSI->CoroutineStmts.empty() && !AnyCoawaits &&
+      !AnyCoyields && !AnyDependentCoawaits)
+    Diag(FSI->CoroutineStmts.front()->getLocStart(),
+         diag::ext_coroutine_without_co_await_co_yield);
+
+  assert((!AnyDependentCoawaits || Promise->getType()->isDependentType()) &&
+         "All dependent coawait expressions should already be resolved");
 
-  // Build body for the coroutine wrapper statement.
-  Body = new (Context) CoroutineBodyStmt(
-      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      SetException.get(), Fallthrough.get(), Allocation, Deallocation,
-      ReturnObject.get(), ParamMoves);
 }
Index: lib/Sema/ScopeInfo.cpp
===================================================================
--- lib/Sema/ScopeInfo.cpp
+++ lib/Sema/ScopeInfo.cpp
@@ -42,6 +42,8 @@
   SwitchStack.clear();
   Returns.clear();
   CoroutinePromise = nullptr;
+  HasCoroutineSuspends = false;
+  CoroutineSuspends = {nullptr, nullptr};
   CoroutineStmts.clear();
   ErrorTrap.reset();
   PossiblyUnreachableDiags.clear();
Index: lib/Parse/ParseStmt.cpp
===================================================================
--- lib/Parse/ParseStmt.cpp
+++ lib/Parse/ParseStmt.cpp
@@ -1898,7 +1898,7 @@
     }
   }
   if (IsCoreturn)
-    return Actions.ActOnCoreturnStmt(ReturnLoc, R.get());
+    return Actions.ActOnCoreturnStmt(getCurScope(), ReturnLoc, R.get());
   return Actions.ActOnReturnStmt(ReturnLoc, R.get(), getCurScope());
 }
 
Index: lib/AST/StmtProfile.cpp
===================================================================
--- lib/AST/StmtProfile.cpp
+++ lib/AST/StmtProfile.cpp
@@ -1552,6 +1552,10 @@
   VisitExpr(S);
 }
 
+void StmtProfiler::VisitDependentCoawaitExpr(const DependentCoawaitExpr *S) {
+  VisitExpr(S);
+}
+
 void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) {
   VisitExpr(S);
 }
Index: lib/AST/StmtPrinter.cpp
===================================================================
--- lib/AST/StmtPrinter.cpp
+++ lib/AST/StmtPrinter.cpp
@@ -2422,6 +2422,11 @@
   PrintExpr(S->getOperand());
 }
 
+void StmtPrinter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+  OS << "co_await ";
+  PrintExpr(S->getOperand());
+}
+
 void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) {
   OS << "co_yield ";
   PrintExpr(S->getOperand());
Index: lib/AST/ItaniumMangle.cpp
===================================================================
--- lib/AST/ItaniumMangle.cpp
+++ lib/AST/ItaniumMangle.cpp
@@ -3299,6 +3299,8 @@
   // These all can only appear in local or variable-initialization
   // contexts and so should never appear in a mangling.
   case Expr::AddrLabelExprClass:
+  // co_await can only appear inside a function body, never in a mangling.
+  case Expr::DependentCoawaitExprClass:
   case Expr::DesignatedInitUpdateExprClass:
   case Expr::ImplicitValueInitExprClass:
   case Expr::NoInitExprClass:
Index: lib/AST/ExprConstant.cpp
===================================================================
--- lib/AST/ExprConstant.cpp
+++ lib/AST/ExprConstant.cpp
@@ -9442,6 +9442,7 @@
   case Expr::LambdaExprClass:
   case Expr::CXXFoldExprClass:
   case Expr::CoawaitExprClass:
+  case Expr::DependentCoawaitExprClass:
   case Expr::CoyieldExprClass:
     return ICEDiag(IK_NotICE, E->getLocStart());
 
Index: lib/AST/ExprClassification.cpp
===================================================================
--- lib/AST/ExprClassification.cpp
+++ lib/AST/ExprClassification.cpp
@@ -188,6 +188,9 @@
   case Expr::CXXFoldExprClass:
   case Expr::NoInitExprClass:
   case Expr::DesignatedInitUpdateExprClass:
+  // FIXME How should we classify co_await expressions while they're still
+  // dependent?
+  case Expr::DependentCoawaitExprClass:
   case Expr::CoyieldExprClass:
     return Cl::CL_PRValue;
 
Index: lib/AST/Expr.cpp
===================================================================
--- lib/AST/Expr.cpp
+++ lib/AST/Expr.cpp
@@ -2923,6 +2923,7 @@
   case CXXNewExprClass:
   case CXXDeleteExprClass:
   case CoawaitExprClass:
+  case DependentCoawaitExprClass:
   case CoyieldExprClass:
     // These always have a side-effect.
     return true;
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8032,12 +8032,26 @@
   //
   ExprResult ActOnCoawaitExpr(Scope *S, SourceLocation KwLoc, Expr *E);
   ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E);
-  StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E);
+  StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E);
 
-  ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E);
+  ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E,
+                              bool IsImplicitlyCreated = false);
+  ExprResult BuildDependentCoawaitExpr(SourceLocation KwLoc, Expr *E,
+                                       UnresolvedLookupExpr *Lookup);
   ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
-  StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E);
-
+  StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
+                               bool IsImplicitlyCreated = false);
+  /// Build the CoroutineBodyStmt wrapping \p Body. Promise, InitSuspend and
+  /// FinalSuspend must be non-null; other implicit parts are built if null.
+  StmtResult BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise,
+                                    Stmt *InitSuspend, Stmt *FinalSuspend,
+                                    Stmt *SetException, Stmt *OnFallthrough,
+                                    Expr *Allocation, Stmt *Deallocation,
+                                    Expr *ReturnObjectInit);
+
+  /// Create the implicit promise variable for the current coroutine, or
+  /// return null on failure.
+  VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
 
   //===--------------------------------------------------------------------===//
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -135,6 +135,10 @@
   /// false if there is an invocation of an initializer on 'self'.
   bool ObjCWarnForNoInitDelegation : 1;
 
+  /// true iff we have attempted to build the initial and final coroutine
+  /// suspend points.
+  bool HasCoroutineSuspends : 1;
+
   /// First 'return' statement in the current function.
   SourceLocation FirstReturnLoc;
 
@@ -159,6 +163,9 @@
   /// \brief The promise object for this coroutine, if any.
   VarDecl *CoroutinePromise;
 
+  /// \brief The initial and final coroutine suspend points.
+  std::pair<Stmt *, Stmt *> CoroutineSuspends;
+
   /// \brief The list of coroutine control flow constructs (co_await, co_yield,
   /// co_return) that occur within the function or block. Empty if and only if
   /// this function or block is not (yet known to be) a coroutine.
@@ -376,22 +383,36 @@
         (HasIndirectGoto ||
           (HasBranchProtectedScope && HasBranchIntoScope));
   }
-  
+
+  /// Mark the suspend points as attempted but not (yet) successfully built.
+  void setCoroutineSuspendsInvalid() {
+    assert(!HasCoroutineSuspends && CoroutineSuspends.first == nullptr &&
+           "we already have valid suspend points");
+    HasCoroutineSuspends = true;
+  }
+
+  /// True if the suspend points were attempted but not successfully built.
+  bool hasInvalidCoroutineSuspends() const {
+    return HasCoroutineSuspends && CoroutineSuspends.first == nullptr;
+  }
+
+  /// Record the successfully built initial and final suspend points.
+  void setCoroutineSuspends(Stmt *Initial, Stmt *Final) {
+    assert(Initial && Final && "suspend points cannot be null");
+    HasCoroutineSuspends = true;
+    CoroutineSuspends.first = Initial;
+    CoroutineSuspends.second = Final;
+  }
+
   FunctionScopeInfo(DiagnosticsEngine &Diag)
-    : Kind(SK_Function),
-      HasBranchProtectedScope(false),
-      HasBranchIntoScope(false),
-      HasIndirectGoto(false),
-      HasDroppedStmt(false),
-      HasOMPDeclareReductionCombiner(false),
-      HasFallthroughStmt(false),
-      HasPotentialAvailabilityViolations(false),
-      ObjCShouldCallSuper(false),
-      ObjCIsDesignatedInit(false),
-      ObjCWarnForNoDesignatedInitChain(false),
-      ObjCIsSecondaryInit(false),
-      ObjCWarnForNoInitDelegation(false),
-      ErrorTrap(Diag) { }
+      : Kind(SK_Function), HasBranchProtectedScope(false),
+        HasBranchIntoScope(false), HasIndirectGoto(false),
+        HasDroppedStmt(false), HasOMPDeclareReductionCombiner(false),
+        HasFallthroughStmt(false), HasPotentialAvailabilityViolations(false),
+        ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false),
+        ObjCWarnForNoDesignatedInitChain(false), ObjCIsSecondaryInit(false),
+        ObjCWarnForNoInitDelegation(false), HasCoroutineSuspends(false),
+        CoroutinePromise(nullptr), ErrorTrap(Diag) {}
 
   virtual ~FunctionScopeInfo();
 
Index: include/clang/Basic/StmtNodes.td
===================================================================
--- include/clang/Basic/StmtNodes.td
+++ include/clang/Basic/StmtNodes.td
@@ -148,6 +148,7 @@
 // C++ Coroutines TS expressions
 def CoroutineSuspendExpr : DStmt<Expr, 1>;
 def CoawaitExpr : DStmt<CoroutineSuspendExpr>;
+def DependentCoawaitExpr : DStmt<Expr>; // Plain Expr: holds only the operand and the 'operator co_await' lookup.
 def CoyieldExpr : DStmt<CoroutineSuspendExpr>;
 
 // Obj-C Expressions.
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8634,8 +8634,7 @@
 def err_return_in_coroutine : Error<
   "return statement not allowed in coroutine; did you mean 'co_return'?">;
 def note_declared_coroutine_here : Note<
-  "function is a coroutine due to use of "
-  "'%select{co_await|co_yield|co_return}0' here">;
+  "function is a coroutine due to use of '%0' here">; // %0 presumably the keyword spelling (formerly a co_await/co_yield/co_return select) -- confirm with callers
 def err_coroutine_objc_method : Error<
   "Objective-C methods as coroutines are not yet supported">;
 def err_coroutine_unevaluated_context : Error<
@@ -8659,6 +8658,8 @@
   "this function cannot be a coroutine: %q0 has no member named 'promise_type'">;
 def err_implied_std_coroutine_traits_promise_type_not_class : Error<
   "this function cannot be a coroutine: %0 is not a class">;
+def err_coroutine_promise_type_incomplete : Error<
+  "this function cannot be a coroutine: %0 is an incomplete type">;
 def err_coroutine_traits_missing_specialization : Error<
   "this function cannot be a coroutine: missing definition of "
   "specialization %q0">;
@@ -8669,6 +8670,11 @@
   "'std::current_exception' must be a function">;
 def err_coroutine_promise_return_ill_formed : Error<
   "%0 declares both 'return_value' and 'return_void'">;
+def note_coroutine_promise_implicit_await_transform_required_here : Note<
+  "call to 'await_transform' implicitly required by 'co_await' here">;
+def note_coroutine_promise_call_implicitly_required : Note< // %0: 0 = initial suspend point, 1 = final suspend point
+  "call to '%select{initial_suspend|final_suspend}0' implicitly "
+  "required by the %select{initial suspend point|final suspend point}0">;
 }
 
 let CategoryName = "Documentation Issue" in {
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -327,6 +327,8 @@
     SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
     SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
     SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
+    assert(Promise && InitSuspend && FinalSuspend &&
+           "these members must never be null");
     // FIXME: Tail-allocate space for parameter move expressions and store them.
     assert(ParamMoves.empty() && "not implemented yet");
   }
@@ -336,32 +338,64 @@
   Stmt *getBody() const {
     return SubStmts[SubStmt::Body];
   }
-
+  void setBody(Stmt *B) {
+    assert(!B || !isa<CoroutineBodyStmt>(B));
+    SubStmts[SubStmt::Body] = B;
+  }
   Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; }
   VarDecl *getPromiseDecl() const {
     return cast<VarDecl>(cast<DeclStmt>(getPromiseDeclStmt())->getSingleDecl());
   }
 
+  // NOTE(review): the constructor asserts Promise, InitSuspend and
+  // FinalSuspend are non-null, so the '== nullptr' asserts in the three
+  // setters below look unsatisfiable for a constructed object; confirm.
+  void setPromiseDeclStmt(Stmt *S) {
+    assert(SubStmts[SubStmt::Promise] == nullptr);
+    SubStmts[SubStmt::Promise] = S;
+  }
+
   Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; }
   Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; }
 
+  void setInitialSuspendStmt(Stmt *Suspend) {
+    assert(SubStmts[SubStmt::InitSuspend] == nullptr);
+    SubStmts[SubStmt::InitSuspend] = Suspend;
+  }
+  void setFinalSuspendStmt(Stmt *Suspend) {
+    assert(SubStmts[SubStmt::FinalSuspend] == nullptr);
+    SubStmts[SubStmt::FinalSuspend] = Suspend;
+  }
+
   Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; }
+  void setExceptionHandler(Stmt *S) { SubStmts[SubStmt::OnException] = S; }
   Stmt *getFallthroughHandler() const {
     return SubStmts[SubStmt::OnFallthrough];
   }
-
-  Expr *getAllocate() const { return cast<Expr>(SubStmts[SubStmt::Allocate]); }
+  void setFallthroughHandler(Stmt *S) { SubStmts[SubStmt::OnFallthrough] = S; }
+  /// Misspelled alias of setFallthroughHandler(), kept for existing callers.
+  void setFalltroughHandler(Stmt *S) { setFallthroughHandler(S); }
+  /// Returns the allocation expression, or null if there is none.
+  Expr *getAllocate() const {
+    return cast_or_null<Expr>(SubStmts[SubStmt::Allocate]);
+  }
   Stmt *getDeallocate() const { return SubStmts[SubStmt::Deallocate]; }
+  void setAllocate(Expr *E) { SubStmts[SubStmt::Allocate] = E; }
+  void setDeallocate(Stmt *S) { SubStmts[SubStmt::Deallocate] = S; }
 
   Expr *getReturnValueInit() const {
-    return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
+    return cast_or_null<Expr>(SubStmts[SubStmt::ReturnValue]);
   }
 
+  void setReturnValueInit(Expr *E) { SubStmts[SubStmt::ReturnValue] = E; }
+
+  /// Falls back to the promise declaration's range when there is no body.
   SourceLocation getLocStart() const LLVM_READONLY {
-    return getBody()->getLocStart();
+    return getBody() ? getBody()->getLocStart()
+                     : getPromiseDecl()->getLocStart();
   }
   SourceLocation getLocEnd() const LLVM_READONLY {
-    return getBody()->getLocEnd();
+    return getBody() ? getBody()->getLocEnd() : getPromiseDecl()->getLocEnd();
   }
 
   child_range children() {
@@ -390,10 +424,15 @@
   enum SubStmt { Operand, PromiseCall, Count };
   Stmt *SubStmts[SubStmt::Count];
 
+  /// \brief True if this co_return was implicitly generated by the compiler.
+  bool IsImplicitlyCreated : 1;
+
   friend class ASTStmtReader;
 public:
-  CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall)
-      : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc) {
+  CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall,
+               bool IsImplicit = false)
+      : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc),
+        IsImplicitlyCreated(IsImplicit) {
     SubStmts[SubStmt::Operand] = Operand;
     SubStmts[SubStmt::PromiseCall] = PromiseCall;
   }
@@ -410,6 +449,8 @@
   Expr *getPromiseCall() const {
     return static_cast<Expr*>(SubStmts[PromiseCall]);
   }
+  bool isImplicitlyCreated() const { return IsImplicitlyCreated; }
+  void setImplicitlyCreated(bool Value = true) { IsImplicitlyCreated = Value; }
 
   SourceLocation getLocStart() const LLVM_READONLY { return CoreturnLoc; }
   SourceLocation getLocEnd() const LLVM_READONLY {
Index: include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- include/clang/AST/RecursiveASTVisitor.h
+++ include/clang/AST/RecursiveASTVisitor.h
@@ -2471,6 +2471,12 @@
     ShouldVisitChildren = false;
   }
 })
+DEF_TRAVERSE_STMT(DependentCoawaitExpr, {
+  if (!getDerived().shouldVisitImplicitCode()) {
+    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
+    ShouldVisitChildren = false; // Skips the 'operator co_await' lookup child.
+  }
+})
 DEF_TRAVERSE_STMT(CoyieldExpr, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
Index: include/clang/AST/ExprCXX.h
===================================================================
--- include/clang/AST/ExprCXX.h
+++ include/clang/AST/ExprCXX.h
@@ -4231,26 +4231,87 @@
 /// \brief Represents a 'co_await' expression.
 class CoawaitExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
+
+  /// \brief True if this co_await expression was implicitly generated by the
+  /// compiler.
+  bool IsImplicitlyCreated : 1;
+
 public:
   CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready,
-              Expr *Suspend, Expr *Resume)
+              Expr *Suspend, Expr *Resume, bool IsImplicit = false)
       : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready,
-                             Suspend, Resume) {}
-  CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand)
-      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {}
+                             Suspend, Resume),
+        IsImplicitlyCreated(IsImplicit) {}
+  CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand,
+              bool IsImplicit = false)
+      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand),
+        IsImplicitlyCreated(IsImplicit) {}
   CoawaitExpr(EmptyShell Empty)
-      : CoroutineSuspendExpr(CoawaitExprClass, Empty) {}
+      : CoroutineSuspendExpr(CoawaitExprClass, Empty),
+        IsImplicitlyCreated(false) {}
 
   Expr *getOperand() const {
     // FIXME: Dig out the actual operand or store it.
     return getCommonExpr();
   }
 
+  bool isImplicitlyCreated() const { return IsImplicitlyCreated; }
+  void setIsImplicitlyCreated(bool Value = true) {
+    IsImplicitlyCreated = Value;
+  }
+
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == CoawaitExprClass;
   }
 };
 
+/// \brief Represents a 'co_await' expression while the type of the promise
+/// is dependent.
+class DependentCoawaitExpr : public Expr {
+  SourceLocation KeywordLoc;
+  Stmt *SubExprs[2]; // [0] = operand, [1] = 'operator co_await' lookup.
+
+  friend class ASTStmtReader;
+
+public:
+  DependentCoawaitExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op,
+                       UnresolvedLookupExpr *OpCoawait)
+      : Expr(DependentCoawaitExprClass, Ty, VK_RValue, OK_Ordinary,
+             /*TypeDependent*/ true, /*ValueDependent*/ true,
+             /*InstantiationDependent*/ true,
+             Op->containsUnexpandedParameterPack()),
+        KeywordLoc(KeywordLoc) {
+    // The expression depends on the coroutine's promise type, so it can be
+    // dependent even when the operand itself is not type-dependent.
+    assert(Ty->isDependentType() &&
+           "wrong constructor for non-dependent co_await/co_yield expression");
+    SubExprs[0] = Op;
+    SubExprs[1] = OpCoawait;
+  }
+
+  DependentCoawaitExpr(EmptyShell Empty)
+      : Expr(DependentCoawaitExprClass, Empty) {
+    SubExprs[0] = SubExprs[1] = nullptr;
+  }
+
+  Expr *getOperand() const { return cast<Expr>(SubExprs[0]); }
+  UnresolvedLookupExpr *getOperatorCoawaitLookup() const {
+    return cast<UnresolvedLookupExpr>(SubExprs[1]);
+  }
+  SourceLocation getKeywordLoc() const { return KeywordLoc; }
+
+  SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; }
+  SourceLocation getLocEnd() const LLVM_READONLY {
+    return getOperand()->getLocEnd();
+  }
+
+  child_range children() { return child_range(SubExprs, SubExprs + 2); }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == DependentCoawaitExprClass;
+  }
+};
+
 /// \brief Represents a 'co_yield' expression.
 class CoyieldExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to